diff --git a/.bazelrc b/.bazelrc index 066b0db10bc..396b84f70b3 100644 --- a/.bazelrc +++ b/.bazelrc @@ -174,6 +174,12 @@ build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0 build:mkl_opensource_only --define=build_with_mkl_opensource=true build:mkl_opensource_only -c opt +# Config setting to build with oneDNN for Arm. +build:mkl_aarch64 --define=build_with_mkl_aarch64=true --define=enable_mkl=true +build:mkl_aarch64 --define=tensorflow_mkldnn_contraction_kernel=0 +build:mkl_aarch64 --define=build_with_mkl_opensource=true +build:mkl_aarch64 -c opt + # This config refers to building with CUDA available. It does not necessarily # mean that we build CUDA op kernels. build:using_cuda --define=using_cuda=true diff --git a/.github/bot_config.yml b/.github/bot_config.yml index 952ff91fef7..c6dc0ec9c85 100644 --- a/.github/bot_config.yml +++ b/.github/bot_config.yml @@ -12,12 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -# -# THIS IS A GENERATED DOCKERFILE. -# -# This file was assembled from multiple pieces, whose use is documented -# throughout. Please refer to the TensorFlow dockerfiles documentation -# for more information. # A list of assignees assignees: diff --git a/.github/workflows/update-nightly.yml b/.github/workflows/update-nightly.yml new file mode 100644 index 00000000000..01b5147d053 --- /dev/null +++ b/.github/workflows/update-nightly.yml @@ -0,0 +1,28 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +on: + workflow_dispatch: # Allow manual triggers + schedule: + - cron: 0 4 * * * # 4am UTC is 9pm PDT and 8pm PST +name: Set nightly branch to master HEAD +jobs: + master-to-nightly: + runs-on: ubuntu-latest + steps: + - uses: zofrex/mirror-branch@v1 + name: Set nightly branch to master HEAD + with: + target-branch: 'nightly' diff --git a/CODEOWNERS b/CODEOWNERS index 83ad24b2845..9de1922a262 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -4,7 +4,7 @@ /tensorflow/core/common_runtime/eager @qqfish @kkimdev /tenosrflow/core/debug @caisq /tensorflow/core/nccl/ @azaks2 @chsigg -/tensorflow/core/platform/windows/ @gunan @mihaimaruseac +/tensorflow/core/platform/windows/ @mihaimaruseac /tensorflow/lite/experimental/micro @petewarden @advaitjain /tensorflow/python/autograph/ @mdanatg @kkimdev /tensorflow/python/debug @caisq diff --git a/RELEASE.md b/RELEASE.md index 5e06d0d473e..23324d56ca7 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -34,6 +34,7 @@ shape assumptions (note that you can pass shapes with `None` entries for axes that are meant to be dynamic). You can also disable the input checking entirely by setting `model.input_spec = None`. +* TF pip packages now use CUDA11 and cuDNN 8.0.2. * XLA:CPU and XLA:GPU devices are no longer registered by default. 
Use `TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be removed). @@ -46,6 +47,13 @@ * `tf.data.experimental.service.WorkerServer` now takes a config tuple instead of individual arguments. Usages should be updated to `tf.data.experimental.service.WorkerServer(worker_config)`. +* `tf.quantization.quantize_and_dequantize_v2` has been introduced, which + updates the gradient definition so that inputs outside the quantization + range receive a gradient of 0. To simulate the V1 behavior of + `tf.quantization.quantize_and_dequantize(...)`, use + `tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...)`. +* `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please + use `tf.data.Dataset.from_tensor_slices` instead. ## Known Caveats @@ -63,143 +71,168 @@ ## Bug Fixes and Other Changes -* -* -* -* Security: - * Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch` - ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190)) - * Fixes three vulnerabilities in conversion to DLPack format - ([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191), - [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192), - [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193)) - * Fixes two vulnerabilities in `SparseFillEmptyRowsGrad` - ([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194), - [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195)) - * Fixes several vulnerabilities in `RaggedCountSparseOutput` and - `SparseCountSparseOutput` operations - ([CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196), - [CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197), - [CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198), - [CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199), - [CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200), - [CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201)) - * Fixes an integer truncation vulnerability in code using the work sharder API - ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202)) - * Fixes a format string vulnerability in `tf.strings.as_string` - ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203)) - * Fixes segfault raised by calling session-only ops in eager mode - ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204)) - * Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams` - ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205)) - * Fixes segfaults caused by incomplete `SavedModel` validation - ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206)) - * Fixes a data corruption due to a bug in negative indexing support in TFLite - ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207)) - * Fixes a data corruption due to dimension mismatch in TFLite - ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208)) - * Fixes several vulnerabilities in TFLite saved model format - ([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209), - [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210), - [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211)) - 
* Fixes several vulnerabilities in TFLite implementation of segment sum - ([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212), - [CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213), - [CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214)) -* TF Core: - * `tf.types.experimental.TensorLike` is a new `Union` type that can be used as - type annotation for variables representing a Tensor or a value that can be - converted to Tensor by `tf.convert_to_tensor`. - * Calling ops with a python constants or numpy values is now consistent with - tf.convert_to_tensor behavior. This avoids operations like tf.reshape - truncating inputs such as from int64 to int32. - * Added `tf.sparse.map_values` to apply a function to the `.value`s of `SparseTensror` arguments. - * The Python bitwise operators for `Tensor` (`__and__`, `__or__`, `__xor__` - and `__invert__` now support non-`bool` arguments and apply the - corresponding bitwise ops. `bool` arguments continue to be supported and - dispatch to logical ops. This brings them more in line with Python and NumPy - benavior. - * Added `tf.SparseTensor.with_values`. This returns a new SparseTensor with - the same sparsity pattern, but with new provided values. It is similar to - the `with_values` function of `RaggedTensor`. - * Added `StatelessCase` op, and uses it if none of case branches has stateful ops. - * Added `tf.config.experimental.get_memory_usage` to return total memory usage - of the device. -* `tf.data`: - * tf.data service: - * Added new `tf.data.experimental.service.register_dataset` and - `tf.data.experimental.service.from_dataset_id` APIs to enable one process - to register a dataset with the tf.data service, and another process to - consume data from the dataset. - * Added support for dispatcher fault tolerance. To enable fault tolerance, - configure a `work_dir` when running your dispatcher server and set - `dispatcher_fault_tolerance=True`. The dispatcher will store its state to - `work_dir`, so that on restart it can continue from its previous state - after restart. - * Added support for sharing dataset graphs via shared filesystem instead of - over RPC. This reduces load on the dispatcher, improving performance of - distributing datasets. For this to work, the dispatcher's `work_dir` must - be accessible from workers. If the worker fails to read from the - `work_dir`, it falls back to using RPC for dataset graph transfer. - * Added support for a new "distributed_epoch" processing mode. This - processing mode distributes a dataset across all tf.data workers, instead - of having each worker process the full dataset. See - [the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode) - to learn more. - * Added optional `exclude_cols` parameter to CsvDataset. This parameter is - the complement of `select_cols`; at most one of these should be specified. - * We have implemented an optimization which reorders data-discarding - transformations such as `take` and `shard` to happen earlier in the - dataset when it is safe to do so. The optimization can be disabled via - the `experimental_optimization.reorder_data_discarding_ops` dataset - option. - * `tf.data.Options` were previously immutable and can now be overriden. 
- * `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors - with a new `output_signature` argument, which allows `from_generator` to - produce any type describable by a `tf.TypeSpec`. - * `tf.data.experimental.AUTOTUNE` is now available in the core API as - `tf.data.AUTOTUNE`. -* `tf.image`: - * Added deterministic `tf.image.stateless_random_*` functions for each - `tf.image.random_*` function. Added a new op - `stateless_sample_distorted_bounding_box` which is a determinstic - version of `sample_distorted_bounding_box` op. Given the same seed, these - stateless functions/ops produce the same results independent of how many - times the function is called, and independent of global seed settings. +* +* +* +* Security: + * Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch` + ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190)) + * Fixes three vulnerabilities in conversion to DLPack format + ([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191), + [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192), + [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193)) + * Fixes two vulnerabilities in `SparseFillEmptyRowsGrad` + ([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194), + [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195)) + * Fixes several vulnerabilities in `RaggedCountSparseOutput` and + `SparseCountSparseOutput` operations + ([CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196), + [CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197), + [CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198), + [CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199), + [CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200), + [CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201)) + * Fixes an integer truncation vulnerability in code using the work sharder + API + ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202)) + * Fixes a format string vulnerability in `tf.strings.as_string` + ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203)) + * Fixes segfault raised by calling session-only ops in eager mode + ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204)) + * Fixes data leak and potential ASLR violation from + `tf.raw_ops.StringNGrams` + ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205)) + * Fixes segfaults caused by incomplete `SavedModel` validation + ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206)) + * Fixes a data corruption due to a bug in negative indexing support in + TFLite + ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207)) + * Fixes a data corruption due to dimension mismatch in TFLite + ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208)) + * Fixes several vulnerabilities in TFLite saved model format + ([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209), + [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210), + [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211)) + * Fixes several vulnerabilities in TFLite implementation of segment sum + 
([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212), + [CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213), + [CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214)) +* TF Core: + * `tf.types.experimental.TensorLike` is a new `Union` type that can be + used as type annotation for variables representing a Tensor or a value + that can be converted to Tensor by `tf.convert_to_tensor`. + * Calling ops with Python constants or NumPy values is now consistent + with `tf.convert_to_tensor` behavior. This avoids operations like + `tf.reshape` truncating inputs such as from int64 to int32. + * Added `tf.sparse.map_values` to apply a function to the `.value`s of + `SparseTensor` arguments. + * The Python bitwise operators for `Tensor` (`__and__`, `__or__`, + `__xor__` and `__invert__`) now support non-`bool` arguments and apply + the corresponding bitwise ops. `bool` arguments continue to be supported + and dispatch to logical ops. This brings them more in line with Python + and NumPy behavior. + * Added `tf.SparseTensor.with_values`. This returns a new SparseTensor + with the same sparsity pattern, but with newly provided values. It is + similar to the `with_values` function of `RaggedTensor`. + * Added `StatelessCase` op, which is used if none of the case branches + has stateful ops. + * Added `tf.config.experimental.get_memory_usage` to return total memory + usage of the device. +* `tf.data`: + * tf.data service: + * Added new `tf.data.experimental.service.register_dataset` and + `tf.data.experimental.service.from_dataset_id` APIs to enable one + process to register a dataset with the tf.data service, and another + process to consume data from the dataset. + * Added support for dispatcher fault tolerance. To enable fault tolerance, + configure a `work_dir` when running your dispatcher server and set + `dispatcher_fault_tolerance=True`. The dispatcher will store its state + to `work_dir`, so that on restart it can continue from its previous + state. + * Added support for sharing dataset graphs via shared filesystem instead + of over RPC. This reduces load on the dispatcher, improving performance + of distributing datasets. For this to work, the dispatcher's `work_dir` + must be accessible from workers. If the worker fails to read from the + `work_dir`, it falls back to using RPC for dataset graph transfer. + * Added support for a new "distributed_epoch" processing mode. This + processing mode distributes a dataset across all tf.data workers, + instead of having each worker process the full dataset. See + [the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode) + to learn more. + * Added optional `exclude_cols` parameter to CsvDataset. This parameter is + the complement of `select_cols`; at most one of these should be + specified. + * We have implemented an optimization which reorders data-discarding + transformations such as `take` and `shard` to happen earlier in the + dataset when it is safe to do so. The optimization can be disabled via + the `experimental_optimization.reorder_data_discarding_ops` dataset + option. + * `tf.data.Options` were previously immutable and can now be overridden. + * `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors + with a new `output_signature` argument, which allows `from_generator` to + produce any type describable by a `tf.TypeSpec`. 
+ * `tf.data.experimental.AUTOTUNE` is now available in the core API as + `tf.data.AUTOTUNE`. +* `tf.image`: + * Added deterministic `tf.image.stateless_random_*` functions for each + `tf.image.random_*` function. Added a new op + `stateless_sample_distorted_bounding_box` which is a deterministic + version of `sample_distorted_bounding_box` op. Given the same seed, + these stateless functions/ops produce the same results independent of + how many times the function is called, and independent of global seed + settings. * `tf.distribute`: - * -* `tf.keras`: - * Improvements from the functional API refactoring: - * Functional model construction does not need to maintain a global workspace graph, removing memory leaks especially when building many models or very large models. - * Functional model construction should be ~8-10% faster on average. - * Functional models can now contain non-symbolic values in their call inputs inside of the first positional argument. - * Several classes of TF ops that were not reliably converted to Keras layers during functional API construction should now work, e.g. `tf.image.ssim_multiscale` - * Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand. - * `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape` - as an alternative to accepting a `callable` loss. - * Added `beta` hyperparameter to FTRL optimizer classes (Keras and others) - to match FTRL paper (https://research.google.com/pubs/archive/41159.pdf). - * Added `mobilenet_v3` to keras application model. - * `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for - customization of how gradients are aggregated across devices, as well as - `gradients_transformers` to allow for custom gradient transformations - (such as gradient clipping). - * The `steps_per_execution` argument in `compile()` is no longer - experimental; if you were passing `experimental_steps_per_execution`, - rename it to `steps_per_execution` in your code. This argument controls - the number of batches to run during each `tf.function` call when calling - `fit()`. Running multiple batches inside a single `tf.function` call can - greatly improve performance on TPUs or small models with a large Python - overhead. -* `tf.function` / AutoGraph: - * Added `experimental_follow_type_hints` argument for `tf.function`. When - True, the function may use type annotations to optimize the tracing - performance. - * Added support for `iter(DistributedDataset)` in AutoGraph `for` loops. - * AutoGraph now allows creating new symbols inside a TensorFLow loop, if - the values of these symbols at an iteration does not depend on the previous - iteration. These types of loops must run at least one iteration, and will - raise a runtime error otherwise. + * +* `tf.keras`: + * Improvements from the functional API refactoring: + * Functional model construction does not need to maintain a global + workspace graph, removing memory leaks especially when building many + models or very large models. + * Functional model construction should be ~8-10% faster on average. + * Functional models can now contain non-symbolic values in their call + inputs inside of the first positional argument. + * Several classes of TF ops that were not reliably converted to Keras + layers during functional API construction should now work, e.g. 
+ `tf.image.ssim_multiscale` + * Error messages when Functional API construction goes wrong (and when + ops cannot be converted to Keras layers automatically) should be + clearer and easier to understand. + * `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape` + as an alternative to accepting a `callable` loss. + * Added `beta` hyperparameter to FTRL optimizer classes (Keras and others) + to match the FTRL paper + (https://research.google.com/pubs/archive/41159.pdf). + * Added `mobilenet_v3` to Keras application models. + * `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for + customization of how gradients are aggregated across devices, as well as + `gradient_transformers` to allow for custom gradient transformations + (such as gradient clipping). + * The `steps_per_execution` argument in `compile()` is no longer + experimental; if you were passing `experimental_steps_per_execution`, + rename it to `steps_per_execution` in your code. This argument controls + the number of batches to run during each `tf.function` call when calling + `fit()`. Running multiple batches inside a single `tf.function` call can + greatly improve performance on TPUs or small models with a large Python + overhead. + * Improvements to Keras preprocessing layers: + * TextVectorization can now accept a vocabulary list or file as an + init arg. + * Normalization can now accept mean and variance values as init args. + * In `Attention` and `AdditiveAttention` layers, the `call()` method now + accepts a `return_attention_scores` argument. When set to + True, the layer returns the attention scores as an additional output + argument. + * Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints + with the same implementation as their `tf.losses` equivalents. +* `tf.function` / AutoGraph: + * Added `experimental_follow_type_hints` argument for `tf.function`. When + True, the function may use type annotations to optimize the tracing + performance. + * Added support for `iter(DistributedDataset)` in AutoGraph `for` loops. + * AutoGraph now allows creating new symbols inside a TensorFlow loop, if + the values of these symbols at an iteration do not depend on the + previous iteration. These types of loops must run at least one + iteration, and will raise a runtime error otherwise.
This + allows users to modify the model input and output type to integer + types (`tf.int8`, `tf.uint8`) instead of defaulting to float type + (`tf.float32`). + * TFLite Profiler for Android is available. See the detailed + [guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android). + * NNAPI + * Added NNAPI Delegation support for requantization use cases by + converting the operation into a dequantize-quantize pair. + * Removed deprecated `Interpreter.setUseNNAPI(boolean)` Java API. + * Use `Interpreter.Options.setUseNNAPI` instead. + * Deprecate `Interpreter::UseNNAPI(bool)` C++ API. + * Use `NnApiDelegate()` and related delegate configuration methods + directly. + * Deprecate `Interpreter::SetAllowFp16PrecisionForFp32(bool)` C++ API + * Prefer controlling this via delegate options, e.g. + `tflite::StatefulNnApiDelegate::Options::allow_fp16' or + `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`. + * `DynamicBuffer::AddJoinedString()` will now add a separator if the first + string to be joined is empty. + * + * `tf.random`: - * + + * + * Math and Linear Algebra: - * + + * + * TPU Enhancements: - * Added support for the `beta` parameter of the FTRL optimizer for TPU - embeddings. Users of other TensorFlow platforms can implement equivalent - behavior by adjusting the `l2` parameter. - * + + * Added support for the `beta` parameter of the FTRL optimizer for TPU + embeddings. Users of other TensorFlow platforms can implement equivalent + behavior by adjusting the `l2` parameter. + * + * XLA Support: - * xla.experimental.compile is deprecated, use - `tf.function(experimental_compile=True)` instead - * Added `tf.function.experimental_get_compiler_ir` which returns compiler IR - (currently 'hlo' and 'optimized_hlo') for given input for given function. - * + + * xla.experimental.compile is deprecated, use + `tf.function(experimental_compile=True)` instead + * Added `tf.function.experimental_get_compiler_ir` which returns compiler + IR (currently 'hlo' and 'optimized_hlo') for given input for given + function. + * + * Tracing and Debugging: - * + + * + * `tf.train.Checkpoint`: - * Now accepts a `root` argument in the initialization, which generates a - checkpoint with a root object. This allows users to create a `Checkpoint` - object that is compatible with Keras `model.save_weights()` and - `model.load_weights`. The checkpoint is also compatible with the - checkpoint saved in the `variables/` folder in the SavedModel. - * When restoring, `save_path` can be a path to a SavedModel. The function - will automatically find the checkpoint in the SavedModel. + + * Now accepts a `root` argument in the initialization, which generates a + checkpoint with a root object. This allows users to create a + `Checkpoint` object that is compatible with Keras `model.save_weights()` + and `model.load_weights`. The checkpoint is also compatible with the + checkpoint saved in the `variables/` folder in the SavedModel. + * When restoring, `save_path` can be a path to a SavedModel. The function + will automatically find the checkpoint in the SavedModel. + * `tf.nn`: - * `tf.nn.max_pool2d` now supports explicit padding. + + * `tf.nn.max_pool2d` now supports explicit padding. + +* `tf.debugging`: + + * `tf.debugging.assert_shapes()` now works on `SparseTensor`s (#36268). 
+ +* `tf.print`: + + * Bug fix in `tf.print()` with `OrderedDict` where if an `OrderedDict` + didn't have the keys sorted, the keys and values were not being printed + in accordance with their correct mapping. + * Other: - * We have replaced uses of "whitelist" and "blacklist" with "allowlist" - and "denylist" where possible. Please see - https://developers.google.com/style/word-list#blacklist for more context. - + + * We have replaced uses of "whitelist" and "blacklist" with "allowlist" + and "denylist" where possible. Please see + https://developers.google.com/style/word-list#blacklist for more + context. + * Add `tf.config.experimental.mlir_bridge_rollout` which will help us + rollout the new MLIR TPU bridge. + * ## Thanks to our Contributors @@ -500,42 +579,87 @@ stjohnso98, , , , , # Release 2.3.0 ## Major Features and Improvements - * `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and save resources: - * [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot) - * [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service). - In addition checkout the detailed [guide](https://www.tensorflow.org/guide/data_performance_analysis) for analyzing input pipeline performance with TF Profiler. +* `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and + save resources: - * [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy) is now a stable API and no longer considered experimental for TensorFlow. (earlier `tf.distribute.experimental.TPUStrategy`). + * [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot) + * [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service). - * [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new tools: a memory profiler to visualize your model’s memory usage over time and a [python tracer](https://www.tensorflow.org/guide/profiler#events) which allows you to trace python function calls in your model. Usability improvements include better diagnostic messages and [profile options](https://tensorflow.org/guide/profiler#collect_performance_data) to customize the host and device trace verbosity level. + In addition checkout the detailed + [guide](https://www.tensorflow.org/guide/data_performance_analysis) for + analyzing input pipeline performance with TF Profiler. - * Introduces experimental support for Keras Preprocessing Layers API ([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly)) to handle data preprocessing operations, with support for composite tensor inputs. Please see below for additional details on these layers. +* [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy) + is now a stable API and no longer considered experimental for TensorFlow. + (earlier `tf.distribute.experimental.TPUStrategy`). - * TFLite now properly supports dynamic shapes during conversion and inference. We’ve also added opt-in support on Android and iOS for [XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack), a highly optimized set of CPU kernels, as well as opt-in support for [executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental). 
+* [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new + tools: a memory profiler to visualize your model’s memory usage over time + and a [python tracer](https://www.tensorflow.org/guide/profiler#events) + which allows you to trace python function calls in your model. Usability + improvements include better diagnostic messages and + [profile options](https://tensorflow.org/guide/profiler#collect_performance_data) + to customize the host and device trace verbosity level. - * Libtensorflow packages are available in GCS starting this release. We have also started to [release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds). +* Introduces experimental support for Keras Preprocessing Layers API + ([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly)) + to handle data preprocessing operations, with support for composite tensor + inputs. Please see below for additional details on these layers. - * The experimental Python API [`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info) now allows you to instrument a TensorFlow program and dump debugging information to a directory on the file system. The directory can be read and visualized by a new interactive dashboard in TensorBoard 2.3 called [Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which reveals the details of the TensorFlow program including graph structures, history of op executions at the Python (eager) and intra-graph levels, the runtime dtype, shape, and numerical composistion of tensors, as well as their code locations. +* TFLite now properly supports dynamic shapes during conversion and inference. + We’ve also added opt-in support on Android and iOS for + [XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack), + a highly optimized set of CPU kernels, as well as opt-in support for + [executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental). + +* Libtensorflow packages are available in GCS starting this release. We have + also started to + [release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds). + +* The experimental Python API + [`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info) + now allows you to instrument a TensorFlow program and dump debugging + information to a directory on the file system. The directory can be read and + visualized by a new interactive dashboard in TensorBoard 2.3 called + [Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which + reveals the details of the TensorFlow program including graph structures, + history of op executions at the Python (eager) and intra-graph levels, the + runtime dtype, shape, and numerical composition of tensors, as well as their + code locations. ## Breaking Changes -* Increases the **minimum bazel version** required to build TF to **3.1.0**. -* `tf.data` - * Makes the following (breaking) changes to the `tf.data`. 
- * C++ API: - `IteratorBase::RestoreInternal`, `IteratorBase::SaveInternal`, and `DatasetBase::CheckExternalState` become pure-virtual and subclasses are now expected to provide an implementation. - * The deprecated `DatasetBase::IsStateful` method is removed in favor of `DatasetBase::CheckExternalState`. - * Deprecated overrides of `DatasetBase::MakeIterator` and `MakeIteratorFromInputElement` are removed. - * The signature of `tensorflow::data::IteratorBase::SaveInternal` and `tensorflow::data::IteratorBase::SaveInput` has been extended with `SerializationContext` argument to enable overriding the default policy for the handling external state during iterator checkpointing. This is not a backwards compatible change and all subclasses of `IteratorBase` *need to be updated* accordingly. -* `tf.keras` - * Add a new `BackupAndRestore` callback for handling distributed training failures & restarts. Please take a look at this [tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for details on how to use the callback. -* `tf.image.extract_glimpse` has been updated to correctly process the case - where `centered=False` and `normalized=False`. This is a breaking change as - the output is different from (incorrect) previous versions. Note this - breaking change only impacts `tf.image.extract_glimpse` and - `tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of - `tf.compat.v1.image.extract_glimpse` does not change. The behavior of - exsiting C++ kernel `ExtractGlimpse` does not change either, so saved - models using `tf.raw_ops.ExtractGlimpse` will not be impacted. + +* Increases the **minimum bazel version** required to build TF to **3.1.0**. +* `tf.data` + * Makes the following (breaking) changes to the `tf.data`. + * C++ API: - `IteratorBase::RestoreInternal`, + `IteratorBase::SaveInternal`, and `DatasetBase::CheckExternalState` + become pure-virtual and subclasses are now expected to provide an + implementation. + * The deprecated `DatasetBase::IsStateful` method is removed in favor of + `DatasetBase::CheckExternalState`. + * Deprecated overrides of `DatasetBase::MakeIterator` and + `MakeIteratorFromInputElement` are removed. + * The signature of `tensorflow::data::IteratorBase::SaveInternal` and + `tensorflow::data::IteratorBase::SaveInput` has been extended with + `SerializationContext` argument to enable overriding the default policy + for the handling external state during iterator checkpointing. This is + not a backwards compatible change and all subclasses of `IteratorBase` + *need to be updated* accordingly. +* `tf.keras` + * Add a new `BackupAndRestore` callback for handling distributed training + failures & restarts. Please take a look at this + [tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) + for details on how to use the callback. +* `tf.image.extract_glimpse` has been updated to correctly process the case + where `centered=False` and `normalized=False`. This is a breaking change as + the output is different from (incorrect) previous versions. Note this + breaking change only impacts `tf.image.extract_glimpse` and + `tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of + `tf.compat.v1.image.extract_glimpse` does not change. The behavior of + existing C++ kernel `ExtractGlimpse` does not change either, so saved models + using `tf.raw_ops.ExtractGlimpse` will not be impacted. 
## Known Caveats * `tf.lite` @@ -1105,7 +1229,7 @@ This release contains contributions from many people at Google, as well as: 8bitmp3, Aaron Ma, AbdüLhamit Yilmaz, Abhai Kollara, aflc, Ag Ramesh, Albert Z. Guo, Alex Torres, amoitra, Andrii Prymostka, angeliand, Anshuman Tripathy, Anthony Barbier, Anton Kachatkou, Anubh-V, Anuja Jakhade, Artem Ryabov, autoih, Bairen Yi, Bas Aarts, Basit Ayantunde, Ben Barsdell, Bhavani Subramanian, Brett Koonce, candy.dc, Captain-Pool, caster, cathy, Chong Yan, Choong Yin Thong, Clayne Robison, Colle, Dan Ganea, David Norman, David Refaeli, dengziming, Diego Caballero, Divyanshu, djshen, Douman, Duncan Riach, EFanZh, Elena Zhelezina, Eric Schweitz, Evgenii Zheltonozhskii, Fei Hu, fo40225, Fred Reiss, Frederic Bastien, Fredrik Knutsson, fsx950223, fwcore, George Grzegorz Pawelczak, George Sterpu, Gian Marco Iodice, Giorgio Arena, giuros01, Gomathi Ramamurthy, Guozhong Zhuang, Haifeng Jin, Haoyu Wu, HarikrishnanBalagopal, HJYOO, Huang Chen-Yi, Ilham Firdausi Putra, Imran Salam, Jared Nielsen, Jason Zaman, Jasper Vicenti, Jeff Daily, Jeff Poznanovic, Jens Elofsson, Jerry Shih, jerryyin, Jesper Dramsch, jim.meyer, Jongwon Lee, Jun Wan, Junyuan Xie, Kaixi Hou, kamalkraj, Kan Chen, Karthik Muthuraman, Keiji Ariyama, Kevin Rose, Kevin Wang, Koan-Sin Tan, kstuedem, Kwabena W. Agyeman, Lakshay Tokas, latyas, Leslie-Fang-Intel, Li, Guizi, Luciano Resende, Lukas Folle, Lukas Geiger, Mahmoud Abuzaina, Manuel Freiberger, Mark Ryan, Martin Mlostek, Masaki Kozuki, Matthew Bentham, Matthew Denton, mbhuiyan, mdfaijul, Muhwan Kim, Nagy Mostafa, nammbash, Nathan Luehr, Nathan Wells, Niranjan Hasabnis, Oleksii Volkovskyi, Olivier Moindrot, olramde, Ouyang Jin, OverLordGoldDragon, Pallavi G, Paul Andrey, Paul Wais, pkanwar23, Pooya Davoodi, Prabindh Sundareson, Rajeshwar Reddy T, Ralovich, Kristof, Refraction-Ray, Richard Barnes, richardbrks, Robert Herbig, Romeo Kienzler, Ryan Mccormick, saishruthi, Saket Khandelwal, Sami Kama, Sana Damani, Satoshi Tanaka, Sergey Mironov, Sergii Khomenko, Shahid, Shawn Presser, ShengYang1, Siddhartha Bagaria, Simon Plovyt, skeydan, srinivasan.narayanamoorthy, Stephen Mugisha, sunway513, Takeshi Watanabe, Taylor Jakobson, TengLu, TheMindVirus, ThisIsIsaac, Tim Gates, Timothy Liu, Tomer Gafner, Trent Lo, Trevor Hickey, Trevor Morris, vcarpani, Wei Wang, Wen-Heng (Jack) Chung, wenshuai, Wenshuai-Xiaomi, wenxizhu, william, William D. Irons, Xinan Jiang, Yannic, Yasir Modak, Yasuhiro Matsumoto, Yong Tang, Yongfeng Gu, Youwei Song, Zaccharie Ramzi, Zhang, Zhenyu Guo, 王振华 (Zhenhua Wang), 韩董, 이중건 Isaac Lee # Release 1.15.0 -This is the last 1.x release for TensorFlow. We do not expect to update the 1.x branch with features, although we will issue patch releases to fix vulnerabilities for at least one year. +This is the last 1.x release for TensorFlow. We do not expect to update the 1.x branch with features, although we will issue patch releases to fix vulnerabilities for at least one year. ## Major Features and Improvements * As [announced](https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0), `tensorflow` pip package will by default include GPU support (same as `tensorflow-gpu` now) for the platforms we currently have GPU support (Linux and Windows). It will work on machines with and without Nvidia GPUs. `tensorflow-gpu` will still be available, and CPU-only packages can be downloaded at `tensorflow-cpu` for users who are concerned about package size. 
@@ -1115,7 +1239,7 @@ This enables writing forward compatible code: by explicitly importing either `te * Add toggles `tf.enable_control_flow_v2()` and `tf.disable_control_flow_v2()` for enabling/disabling v2 control flow. * Enable v2 control flow as part of `tf.enable_v2_behavior()` and `TF2_BEHAVIOR=1`. * AutoGraph translates Python control flow into TensorFlow expressions, allowing users to write regular Python inside `tf.function`-decorated functions. AutoGraph is also applied in functions used with `tf.data`, `tf.distribute` and `tf.keras` APIS. -* Adds `enable_tensor_equality()`, which switches the behavior such that: +* Adds `enable_tensor_equality()`, which switches the behavior such that: * Tensors are no longer hashable. * Tensors can be compared with `==` and `!=`, yielding a Boolean Tensor with element-wise comparison results. This will be the default behavior in 2.0. @@ -1271,12 +1395,12 @@ For information on upgrading your existing TensorFlow 1.x models, please refer t * TensorFlow 2.0.0 is built using devtoolset7 (GCC7) on Ubuntu 16. This may lead to ABI incompatibilities with extensions built against earlier versions of TensorFlow. * Tensorflow code now produces 2 different pip packages: tensorflow_core containing all the code (in the future it will contain only the private implementation) and tensorflow which is a virtual pip package doing forwarding to tensorflow_core (and in the future will contain only the public API of tensorflow). We don't expect this to be breaking, unless you were importing directly from the implementation. Removed the `freeze_graph` command line tool; `SavedModel` should be used in place of frozen graphs. - + * `tf.contrib`: * `tf.contrib` has been deprecated, and functionality has been either migrated to the core TensorFlow API, to an ecosystem project such as [tensorflow/addons](https://www.github.com/tensorflow/addons) or [tensorflow/io](https://www.github.com/tensorflow/io), or removed entirely. * Remove `tf.contrib.timeseries` dependency on TF distributions. * Replace contrib references with `tf.estimator.experimental.*` for apis in `early_stopping.py`. - + * `tf.estimator`: * Premade estimators in the tf.estimator.DNN/Linear/DNNLinearCombined family have been updated to use `tf.keras.optimizers` instead of the `tf.compat.v1.train.Optimizer`s. If you do not pass in an `optimizer=` arg or if you use a string, the premade estimator will use the Keras optimizer. This is checkpoint breaking, as the optimizers have separate variables. A checkpoint converter tool for converting optimizers is included with the release, but if you want to avoid any change, switch to the v1 version of the estimator: `tf.compat.v1.estimator.DNN/Linear/DNNLinearCombined*`. * Default aggregation for canned Estimators is now `SUM_OVER_BATCH_SIZE`. To maintain previous default behavior, please pass `SUM` as the loss aggregation method. @@ -1284,13 +1408,13 @@ For information on upgrading your existing TensorFlow 1.x models, please refer t * `Estimator.export_savedmodel` has been renamed to `export_saved_model`. * When saving to SavedModel, Estimators will strip default op attributes. This is almost always the correct behavior, as it is more forwards compatible, but if you require that default attributes to be saved with the model, please use `tf.compat.v1.Estimator`. * Feature Columns have been upgraded to be more Eager-friendly and to work with Keras. As a result, `tf.feature_column.input_layer` has been deprecated in favor of `tf.keras.layers.DenseFeatures`. 
v1 feature columns have direct analogues in v2 except for `shared_embedding_columns`, which are not cross-compatible with v1 and v2. Use `tf.feature_column.shared_embeddings` instead. - + * `tf.keras`: * `OMP_NUM_THREADS` is no longer used by the default Keras config. To configure the number of threads, use `tf.config.threading` APIs. * `tf.keras.model.save_model` and `model.save` now defaults to saving a TensorFlow SavedModel. HDF5 files are still supported. * Deprecated `tf.keras.experimental.export_saved_model` and `tf.keras.experimental.function`. Please use `tf.keras.models.save_model(..., save_format='tf')` and `tf.keras.models.load_model` instead. * Layers now default to float32, and automatically cast their inputs to the layer's dtype. If you had a model that used float64, it will probably silently use float32 in TensorFlow 2, and a warning will be issued that starts with `Layer ` is casting an input tensor from dtype float64 to the layer's dtype of float32. To fix, either set the default dtype to float64 with `tf.keras.backend.set_floatx('float64')`, or pass `dtype='float64'` to each of the Layer constructors. See `tf.keras.layers.Layer` for more information. - + * `tf.lite`: * Removed `lite.OpHint`, `lite.experimental`, and `lite.constant` from 2.0 API. * Tensors are no longer hashable, but instead compare element-wise with `==` and `!=`. Use `tf.compat.v1.disable_tensor_equality()` to return to the previous behavior. @@ -1525,8 +1649,8 @@ If you experience any snags when using TF 2.0, please let us know at the [TF 2.0 conversion. TensorRT initialization arguments are now passed wrapped in a named-tuple, `TrtConversionParams`, rather than as separate arguments as in `TrtGraphConverter`. - * Changed API to optimize TensorRT enginges during graph optimization. - This is now done by calling `converter.build()` where previously + * Changed API to optimize TensorRT engines during graph optimization. This + is now done by calling `converter.build()` where previously `is_dynamic_op=False` would be set. * `converter.convert()` no longer returns a `tf.function`. Now the function must be accessed from the saved model. @@ -2536,7 +2660,7 @@ Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, A * [`tf.contrib.estimator.RNNEstimator`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/estimator/RNNClassifier) * The [distributions.Bijector](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/distributions/bijectors/Bijector) API supports broadcasting for Bijectors with new API changes. - + ## Breaking Changes * If you're opening empty variable scopes; replace `variable_scope('', ...)` by `variable_scope(tf.get_variable_scope(), ...)`. @@ -3015,7 +3139,7 @@ Samuel He, Sandeep Dcunha, sandipmgiri, Sang Han, scott, Scott Mudge, Se-Won Kim Simone Cirillo, Steffen Schmitz, Suvojit Manna, Sylvus, Taehoon Lee, Ted Chang, Thomas Deegan, Till Hoffmann, Tim, Toni Kunic, Toon Verstraelen, Tristan Rice, Urs KöSter, Utkarsh Upadhyay, Vish (Ishaya) Abrams, Winnie Tsang, Yan Chen, Yan Facai (颜发才), Yi Yang, Yong Tang, -Youssef Hesham, Yuan (Terry) Tang, Zhengsheng Wei, zxcqwe4906, 张志豪, 田传武 +Youssef Hesham, Yuan (Terry) Tang, Zhengsheng Wei, zxcqwe4906, 张志豪, 田传武 We are also grateful to all who filed issues or helped resolve them, asked and answered questions, and were part of inspiring discussions. 
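The `tf.quantization.quantize_and_dequantize_v2` note in the RELEASE.md hunk above can be illustrated with a short sketch; the input values and quantization range below are made-up assumptions, not part of this change:

```
import tensorflow as tf

x = tf.constant([-1.5, 0.2, 0.7, 1.3])  # example inputs; assume a [-1.0, 1.0] range

with tf.GradientTape() as tape:
  tape.watch(x)
  # New behavior: inputs outside [input_min, input_max] get a gradient of 0.
  y = tf.quantization.quantize_and_dequantize_v2(
      x, input_min=-1.0, input_max=1.0, range_given=True)
print(tape.gradient(y, x))  # zeros where |x| > 1.0

with tf.GradientTape() as tape:
  tape.watch(x)
  # V1-style behavior, as the note suggests: pass gradients straight through.
  y_v1 = tf.grad_pass_through(
      lambda t: tf.quantization.quantize_and_dequantize_v2(
          t, input_min=-1.0, input_max=1.0, range_given=True))(x)
print(tape.gradient(y_v1, x))  # ones everywhere
```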
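Similarly, the `tf.data.Dataset.from_generator` / `output_signature` entry above admits a minimal sketch; the generator and the `RaggedTensorSpec` here are illustrative assumptions:

```
import tensorflow as tf

def gen():
  # Any value describable by a tf.TypeSpec can now be yielded, e.g. a RaggedTensor.
  yield tf.ragged.constant([[1, 2], [3]])

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature=tf.RaggedTensorSpec(shape=[2, None], dtype=tf.int32))

for value in ds:
  print(value)  # <tf.RaggedTensor [[1, 2], [3]]>
```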
diff --git a/configure.py b/configure.py index 5b9fd55b740..e381c8c20db 100644 --- a/configure.py +++ b/configure.py @@ -1485,6 +1485,7 @@ def main(): 'adding "--config=<>" to your build command. See .bazelrc for more ' 'details.') config_info_line('mkl', 'Build with MKL support.') + config_info_line('mkl_aarch64', 'Build with oneDNN support for Aarch64.') config_info_line('monolithic', 'Config for mostly static monolithic build.') config_info_line('ngraph', 'Build with Intel nGraph support.') config_info_line('numa', 'Build with NUMA support.') diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 0269dae7d72..420c7479e50 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -568,17 +568,7 @@ selects.config_setting_group( # If you need functionality that is not exposed, we will work with you to expand our public APIs. package_group( name = "internal", - packages = [ - "//learning/brain/distribute/...", - "//learning/brain/swift/x10/...", - "//perftools/accelerators/xprof/api/...", - "//tensorflow/...", - "//tensorflow_estimator/python/estimator/...", - "//tensorflow_models/official/...", - "//third_party/py/autograph/...", - "//third_party/swift/tensorflow/x10/...", - "//third_party/swift/tensorflow_apis/...", - ], + packages = ["//tensorflow/..."], ) package_group( @@ -588,10 +578,8 @@ package_group( # Packages that use private types symbols, until they are exported. # TODO(b/154650521) Remove. -package_group( - name = "types_whitelist", - packages = ["//learning/deepmind/tensorflow/replicator/..."], -) +# If this is modified, then copy.bara.sky must also be modified. +package_group(name = "types_whitelist") # Packages that use StructuredTensors. # TODO(b/159007891) Remove this package once StructuredTensor is exported. @@ -719,7 +707,7 @@ tf_cc_shared_object( deps = [ "//tensorflow/c/experimental/filesystem:filesystem_interface", "//tensorflow/cc/saved_model:loader_lite_impl", - "//tensorflow/core:core_cpu_impl", + "//tensorflow/core/common_runtime:core_cpu_impl", "//tensorflow/core:framework_internal_impl", "//tensorflow/core/common_runtime/gpu:gpu_runtime_impl", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl", diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 3fbd289021f..693bae5cfc7 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -217,6 +217,8 @@ tf_cuda_library( "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/kernels:logging_ops", + "//tensorflow/compiler/mlir/tfr:node_expansion_pass", + "//tensorflow/compiler/mlir/tfr:graph_decompose_pass", ], }), alwayslink = 1, @@ -254,6 +256,30 @@ tf_cuda_library( }), ) +cc_library( + name = "tf_shape", + srcs = ["tf_shape.cc"], + hdrs = ["tf_shape.h"], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":c_api_macros", + ":tf_shape_internal", + "//tensorflow/core:framework", + ], +) + +cc_library( + name = "tf_shape_internal", + hdrs = ["tf_shape_internal.h"], + copts = tf_copts(), + visibility = ["//tensorflow:internal"], + deps = [ + ":conversion_macros", + "//tensorflow/core:framework", + ], +) + cc_library( name = "tf_status", srcs = ["tf_status.cc"], diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 2e1759ecea0..a03e9227a75 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -2488,6 +2488,48 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) { return ret; } +void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, + TF_Status* status) { + 
using tensorflow::RecordMutation; + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(&new_src.oper->node); + + if (ic->num_outputs() <= new_src.index) { + status->status = tensorflow::errors::OutOfRange( + "Cannot update edge. Output index [", new_src.index, + "] is greater than the number of total outputs [", ic->num_outputs(), + "]."); + return; + } + tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index); + + tensorflow::shape_inference::InferenceContext* ic_dst = + graph->refiner.GetContext(&dst.oper->node); + if (ic_dst->num_inputs() <= dst.index) { + status->status = tensorflow::errors::OutOfRange( + "Cannot update edge. Input index [", dst.index, + "] is greater than the number of total inputs [", ic_dst->num_inputs(), + "]."); + return; + } + if (!ic_dst->MergeInput(dst.index, shape)) { + status->status = tensorflow::errors::InvalidArgument( + "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape), + " and ", ic_dst->DebugString(ic_dst->input(dst.index)), "."); + return; + } + status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, + &dst.oper->node, dst.index); + + if (TF_GetCode(status) == TF_OK) { + // This modification only updates the destination node for + // the purposes of running this graph in a session. Thus, we don't + // record the source node as being modified. + RecordMutation(graph, *dst.oper, "updating input tensor"); + } +} + // TF_Server functions ---------------------------------------------- #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 0b4d9993e4d..db5f8fd68f8 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1524,6 +1524,10 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status); TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp( const char* name, TF_Status* status); +// Update edge, switch input/ output in a node +TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, + TF_Input dst, TF_Status* status); + // -------------------------------------------------------------------------- // In-process TensorFlow server functionality, for use in distributed training. // A Server instance encapsulates a set of devices and a Session target that diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index bbbbb8f7d56..fc1fdccee16 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -634,6 +634,40 @@ TEST(CAPI, Graph) { TF_DeleteStatus(s); } +TEST(CAPI, UpdateEdge) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // Make two scalar constants. + TF_Operation* one = ScalarConst(1, graph, s, "one"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* two = ScalarConst(2, graph, s, "two"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Add oper. + TF_Operation* add = Add(one, two, graph, s, "add"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Add another oper to the graph. 
+ TF_Operation* neg = Neg(add, graph, s, "neg"); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + NodeDef node_def_neg; + ASSERT_TRUE(GetNodeDef(neg, &node_def_neg)); + EXPECT_EQ(string("add"), node_def_neg.input(0)); + + // update edge of neg + TF_UpdateEdge(graph, TF_Output{one, 0}, TF_Input{neg, 0}, s); + + ASSERT_TRUE(GetNodeDef(neg, &node_def_neg)); + EXPECT_EQ(string("one:0"), node_def_neg.input(0)); + + // Clean up + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + /* TODO(skyewm): this test currently DCHECKs, change to bad status diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index a71fdc8aa06..b90b2644269 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -3,7 +3,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", - "if_tpu", + "if_libtpu", "tf_cc_test", "tf_copts", "tf_cuda_cc_test", @@ -116,7 +116,6 @@ filegroup( "immediate_execution_context.h", "immediate_execution_operation.h", "immediate_execution_tensor_handle.h", - "mnist_gradients_testutil.h", "tape.h", "tfe_cancellation_manager_internal.h", "tfe_context_internal.h", @@ -290,7 +289,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/lib/llvm_rtti", - ] + if_tpu( + ] + if_libtpu( if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"], if_true = [], ), @@ -314,6 +313,7 @@ cc_library( ":gradients_internal", ":gradients_util", ":tape", + "//tensorflow/c/experimental/gradients/tape:tape_context", "//tensorflow/c/experimental/ops:array_ops", "//tensorflow/c/experimental/ops:math_ops", "//tensorflow/c/experimental/ops:nn_ops", @@ -354,7 +354,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/lib/llvm_rtti", - ] + if_tpu( + ] + if_libtpu( if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"], if_true = [], ), diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 37604951453..5f388bfe0cd 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -39,7 +39,7 @@ limitations under the License. 
#include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/tf_tensor_internal.h" -#if defined(PLATFORM_GOOGLE) && !defined(LIBTFTPU) +#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) #include "tensorflow/core/tfrt/eager/c_api_tfrt.h" #endif #include "tensorflow/core/common_runtime/device.h" @@ -729,7 +729,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { if (opts->use_tfrt) { -#if defined(PLATFORM_GOOGLE) && !defined(LIBTFTPU) +#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async)); #else status->status = tensorflow::errors::Unimplemented("TFRT is not supported"); @@ -904,9 +904,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx, void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetThreadLocalDevicePlacementPolicy( + tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy( static_cast(policy)); } @@ -915,10 +913,8 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy( // safe to call this function from the async EagerExecutor threads. extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); return static_cast( - context->GetDevicePlacementPolicy()); + tensorflow::unwrap(ctx)->GetDevicePlacementPolicy()); } TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) { @@ -1429,21 +1425,15 @@ void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name, } unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - return context->FindFunctionDef(name) != nullptr; + return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr; } void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetShouldStoreGraphs(true); + tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true); } void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetShouldStoreGraphs(false); + tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false); } } // extern "C" diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index a58c681e8fe..0afb69bb82c 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -74,7 +74,7 @@ typedef enum TFE_ContextDevicePlacementPolicy { // Placement policy which silently copies int32 tensors but not other dtypes. TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, } TFE_ContextDevicePlacementPolicy; -// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h) +// LINT.ThenChange(//tensorflow/c/eager/immediate_execution_context.h) // Sets the default execution mode (sync/async). Note that this can be // overridden per thread using TFE_ContextSetExecutorForThread. 
diff --git a/tensorflow/c/eager/c_api_distributed_test.cc b/tensorflow/c/eager/c_api_distributed_test.cc index 2718c75c3ee..d21cadfd0cb 100644 --- a/tensorflow/c/eager/c_api_distributed_test.cc +++ b/tensorflow/c/eager/c_api_distributed_test.cc @@ -545,7 +545,9 @@ TEST(CAPI, DistributedFunctionNoError) { TestDistributedFunctionCancellation(false); } -TEST(CAPI, DistributedFunctionCancelledOnError) { +// TODO(b/170399182): Update test once an alternative to using the function +// optimization hook is in place. +TEST(CAPI, DISABLED_DistributedFunctionCancelledOnError) { TestDistributedFunctionCancellation(true); } diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index eabb159a631..cc2270755bf 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -49,15 +49,11 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name, } void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetShouldStoreGraphs(true); + tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true); } void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetShouldStoreGraphs(false); + tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false); } uint64_t TFE_GetContextId(TFE_Context* ctx) { @@ -544,22 +540,16 @@ void TFE_ExecutorClearError(TFE_Executor* executor) { } void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetExecutorForThread(executor->executor()); + tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor()); } TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - return new TFE_Executor(&context->Executor()); + return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor()); } void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); auto address_space = tensorflow::DeviceNameUtils::AddressSpace( - context->HostCPU()->parsed_name()); + tensorflow::unwrap(ctx)->HostCPUParsedName()); auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space); void* data = tensorflow::port::Malloc(str.length()); str.copy(static_cast(data), str.length(), 0); @@ -572,9 +562,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) { void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, TF_Buffer* buf, TF_Status* status) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - auto* function_def = context->FindFunctionDef(function_name); + auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name); if (function_def == nullptr) { status->status = tensorflow::errors::NotFound( "Unable to find FunctionDef with name: ", function_name); @@ -643,14 +631,10 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx, void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable, TF_Status* status) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetAllowSoftPlacement(enable); + 
tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable); } void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable, TF_Status* status) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->SetLogDevicePlacement(enable); + tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable); } diff --git a/tensorflow/c/eager/gradients.cc b/tensorflow/c/eager/gradients.cc index fc9c19d127c..58ffcf247cf 100644 --- a/tensorflow/c/eager/gradients.cc +++ b/tensorflow/c/eager/gradients.cc @@ -191,7 +191,7 @@ Status TapeVSpace::CallBackwardFunction( &ctx, incoming_gradients, result); } -Status TapeVSpace::BuildOnesLike(TapeTensor t, +Status TapeVSpace::BuildOnesLike(const TapeTensor& t, AbstractTensorHandle** result) const { AbstractOperationPtr op(ctx_->CreateOperation()); TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr)); diff --git a/tensorflow/c/eager/gradients.h b/tensorflow/c/eager/gradients.h index d240806c044..f7d80cbeb34 100644 --- a/tensorflow/c/eager/gradients.h +++ b/tensorflow/c/eager/gradients.h @@ -180,10 +180,6 @@ int64 ToId(AbstractTensorHandle* t); // allow us to trace the data dependencies between operations and hence compute // gradients. // -// This also implements `OnesLike` to create the default -// incoming gradients for tensors which do not already have an incoming -// gradient. -// // `ZerosLike` is not expected to be called and returns a nullptr. The creation // of default zeros grads is handled by the `DefaultGradientFunction` registered // for each op. @@ -233,7 +229,7 @@ class TapeVSpace std::vector* result) const override; // Builds a tensor filled with ones with the same shape and dtype as `t`. - Status BuildOnesLike(TapeTensor t, + Status BuildOnesLike(const TapeTensor& t, AbstractTensorHandle** result) const override; // Looks up the ID of a Gradient. diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc index 01834879b67..cd4febba8c1 100644 --- a/tensorflow/c/eager/gradients_test.cc +++ b/tensorflow/c/eager/gradients_test.cc @@ -61,6 +61,7 @@ Status RegisterGradients(GradientRegistry* registry) { TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer)); TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer)); TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer)); return Status::OK(); } @@ -131,6 +132,37 @@ Status ExpGradModel(AbstractContext* ctx, return Status::OK(); } +// Computes +// y = sqrt(inputs[0]) +// return grad(y, {inputs[0]}) +Status SqrtGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch x. 
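  // Editorial aside (not part of this change): for y = sqrt(x), dy/dx = 1 / (2 * sqrt(x)),
  // and the SqrtGrad kernel used by the registered gradient computes dy * 0.5 / y.
  // With x = 1.0, TestSqrtGrad below therefore expects a gradient of 0.5.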
+ std::vector sqrt_outputs(1); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR( + ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt")); + std::unordered_map + source_tensors_that_are_targets; + + std::vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(sqrt_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto sqrt_output : sqrt_outputs) { + sqrt_output->Unref(); + } + outputs[0] = out_grads[0]; + delete tape; + return Status::OK(); +} + // Computes // ignored, y = IdentityN(inputs[0], inputs[1]) // return grad(y, {inputs[0], inputs[1]}) @@ -401,6 +433,50 @@ TEST_P(CppGradients, TestExpGrad) { result_tensor = nullptr; } +TEST_P(CppGradients, TestSqrtGrad) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Pseudo-code: + // + // tape.watch(x) + // y = sqrt(x) + // outputs = tape.gradient(y, x) + std::vector outputs(1); + s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* result_tensor; + s = getValue(outputs[0], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_NEAR(*result_value, 0.5, 0.001); + outputs[0]->Unref(); + TF_DeleteTensor(result_tensor); + result_tensor = nullptr; +} + TEST_P(CppGradients, TestIdentityNGrad) { // Pseudo-code: // diff --git a/tensorflow/c/eager/immediate_execution_context.h b/tensorflow/c/eager/immediate_execution_context.h index 02a3320ef65..a3e3857b34b 100644 --- a/tensorflow/c/eager/immediate_execution_context.h +++ b/tensorflow/c/eager/immediate_execution_context.h @@ -29,8 +29,25 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { +class EagerExecutor; + +// LINT.IfChange +// Note: Keep in sync with exported copy of enum in eager/c_api.h. +enum ContextDevicePlacementPolicy { + // Running operations with input tensors on the wrong device will fail. + DEVICE_PLACEMENT_EXPLICIT = 0, + // Copy the tensor to the right device but log a warning. + DEVICE_PLACEMENT_WARN = 1, + // Silently copy the tensor, which has a performance cost since the operation + // will be blocked till the copy completes. This is the default policy. + DEVICE_PLACEMENT_SILENT = 2, + // Placement policy which silently copies int32 tensors but not other dtypes. + DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, +}; +// LINT.ThenChange(//tensorflow/c/eager/c_api.h) // Abstract interface to a context. 
// @@ -81,14 +98,6 @@ class ImmediateExecutionContext : public AbstractContext { // List attributes of available devices virtual void ListDevices(std::vector* devices) = 0; - virtual void ClearCachesAndThreadExecutors() = 0; - - // Initialize the step resource container for a training step. This is used - // in current TF runtime. For tfrt, it is used by fallback op handler. - virtual void StartStep() = 0; - // Destroy the step resource container for a training step. - virtual void EndStep() = 0; - // Block until all pending nodes are finished. virtual Status AsyncWait() = 0; @@ -97,11 +106,52 @@ class ImmediateExecutionContext : public AbstractContext { // already exists. virtual Status AddFunctionDef(const FunctionDef& fdef) = 0; + // Find and return a added function by its name. + virtual const FunctionDef* FindFunctionDef(const string& name) const = 0; + + // Return the ParsedName of Host CPU device. + virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0; + + // Configure soft device placement policy. + virtual void SetAllowSoftPlacement(bool enable) = 0; + + // Configure device placement policy logging. + virtual void SetLogDevicePlacement(bool enable) = 0; + + // Sets the device placement policy for the current thread. + virtual void SetThreadLocalDevicePlacementPolicy( + ContextDevicePlacementPolicy policy) = 0; + // Returns the device placement policy for the current thread. + virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0; + // For LLVM style RTTI. static bool classof(const AbstractContext* ptr) { return ptr->getKind() == kEager || ptr->getKind() == kTfrt; } + //===--------------------------------------------------------------------===// + // Following are legacy features in TF Eager Runtime. + // TODO(tf-runtime): Figure out a way to deprecate following features after + // migrated to TFRT. + //===--------------------------------------------------------------------===// + // Clear pending nodes in thread executors and kernel caches. + virtual void ClearCachesAndThreadExecutors() = 0; + + // Initialize the step resource container for a training step. This is used + // in current TF runtime. For tfrt, it is used by fallback op handler. + virtual void StartStep() = 0; + // Destroy the step resource container for a training step. + virtual void EndStep() = 0; + + // Return the Eager Executor for current thread. Please note that Eager + // Executor is only used in current TF but not in TFRT. + virtual EagerExecutor& Executor() = 0; + // Update the Eager Executor for current thread. + virtual void SetExecutorForThread(EagerExecutor* executor) = 0; + + // Configure graph collection in RunMetadata. + virtual void SetShouldStoreGraphs(bool value) = 0; + protected: explicit ImmediateExecutionContext(AbstractContextKind kind) : AbstractContext(kind) {} diff --git a/tensorflow/c/eager/mnist_gradients_testutil.cc b/tensorflow/c/eager/mnist_gradients_testutil.cc index 8354b37354e..6688d9d4e75 100644 --- a/tensorflow/c/eager/mnist_gradients_testutil.cc +++ b/tensorflow/c/eager/mnist_gradients_testutil.cc @@ -25,133 +25,18 @@ limitations under the License. 
#include "tensorflow/c/eager/gradients.h" #include "tensorflow/c/eager/gradients_internal.h" #include "tensorflow/c/eager/gradients_util.h" +#include "tensorflow/c/experimental/gradients/tape/tape_context.h" #include "tensorflow/c/experimental/ops/array_ops.h" #include "tensorflow/c/experimental/ops/math_ops.h" #include "tensorflow/c/experimental/ops/nn_ops.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" -// ========================== Tape Ops ============================== namespace tensorflow { namespace gradients { namespace internal { using std::vector; -using tensorflow::tracing::TracingOperation; - -// Computes `inputs[0] + inputs[1]` and records it on the tape. -Status Add(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - AbstractOperationPtr add_op(ctx->CreateOperation()); - ForwardOperation forward_op; - TF_RETURN_IF_ERROR( - Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op)); - if (isa(add_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(add_op.get())->SetOpName("my_add")); - } - TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op)); - TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op)); - int num_retvals = 1; - return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, - registry); -} - -// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape. -Status MatMul(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, const char* name, - bool transpose_a, bool transpose_b, - const GradientRegistry& registry) { - AbstractOperationPtr matmul_op(ctx->CreateOperation()); - ForwardOperation forward_op; - TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul", - /*raw_device_name=*/nullptr, &forward_op)); - if (isa(matmul_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(matmul_op.get())->SetOpName(name)); - } - - TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op)); - TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op)); - TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool( - matmul_op.get(), "transpose_a", transpose_a, &forward_op)); - TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool( - matmul_op.get(), "transpose_b", transpose_b, &forward_op)); - - int num_retvals = 1; - return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, - registry); -} - -Status Mul(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, const char* name, - const GradientRegistry& registry) { - AbstractOperationPtr mul_op(ctx->CreateOperation()); - ForwardOperation forward_op; - TF_RETURN_IF_ERROR( - Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op)); - if (isa(mul_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(mul_op.get())->SetOpName(name)); - } - - TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op)); - TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op)); - - int num_retvals = 1; - return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, - registry); -} - -// Computes `Relu(inputs[0])` and records it on the tape. 
-Status Relu(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, const char* name, - const GradientRegistry& registry) { - AbstractOperationPtr relu_op(ctx->CreateOperation()); - ForwardOperation forward_op; - TF_RETURN_IF_ERROR( - Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op)); - if (isa(relu_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(relu_op.get())->SetOpName(name)); - } - TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op)); - int num_retvals = 1; - return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, - registry); -} - -// Computes `SoftmaxLoss(scores, labels)` where labels are categorical (not -// one-hot) and records it on the tape. -Status SparseSoftmaxCrossEntropyWithLogits( - AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, const char* name, - const GradientRegistry& registry) { - AbstractTensorHandle* scores = inputs[0]; - AbstractTensorHandle* labels = inputs[1]; - - AbstractOperationPtr sm_op(ctx->CreateOperation()); - ForwardOperation forward_op; - TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits", - /*raw_device_name=*/nullptr, &forward_op)); - if (isa(sm_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(sm_op.get())->SetOpName(name)); - } - - TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op)); - TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op)); - - int num_retvals = 2; // returns loss values and backprop - return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, - registry); -} //===================== Test Models to run ========================= @@ -167,8 +52,9 @@ Status AddGradModel(AbstractContext* ctx, tape->Watch(ToId(inputs[0])); // Watch x. tape->Watch(ToId(inputs[1])); // Watch y. std::vector add_outputs(1); - TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs), - registry)); // Compute x+y. + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR( + ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add")); std::unordered_map source_tensors_that_are_targets; @@ -200,9 +86,11 @@ Status MatMulGradModel(AbstractContext* ctx, tape->Watch(ToId(inputs[0])); // Watch x. tape->Watch(ToId(inputs[1])); // Watch y. vector mm_outputs(1); - TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs), - "matmul0", /*transpose_a=*/false, - /*transpose_b=*/false, registry)); // Compute x*y. + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs, + absl::MakeSpan(mm_outputs), "matmul0", + /*transpose_a=*/false, + /*transpose_b=*/false)); // Compute x*y. std::unordered_map source_tensors_that_are_targets; @@ -256,25 +144,27 @@ Status MNISTForwardModel(AbstractContext* ctx, tape->Watch(ToId(W2)); // Watch W2. 
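  // Editorial aside (not part of this change): the rewrites below replace the tape-aware
  // helpers deleted above (Add/MatMul/Mul/Relu/SparseSoftmaxCrossEntropyWithLogits) with
  // the plain ops:: builders invoked through a TapeContext. TapeContext bundles the real
  // context with the tape and gradient registry, so every op executed through it is
  // recorded for differentiation without a per-op wrapper.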
vector temp_outputs(1); - TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs), - "matmul0", /*transpose_a=*/false, - /*transpose_b=*/false, registry)); // Compute X*W1 + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1}, + absl::MakeSpan(temp_outputs), "matmul0", + /*transpose_a=*/false, + /*transpose_b=*/false)); // Compute X*W1 - TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]}, - absl::MakeSpan(temp_outputs), "relu", - registry)); // Compute Relu(X*W1) + TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {temp_outputs[0]}, + absl::MakeSpan(temp_outputs), + "relu")); // Compute Relu(X*W1) - TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2}, - absl::MakeSpan(temp_outputs), "matmul1", - /*transpose_a=*/false, /*transpose_b=*/false, - registry)); // Compute W2*Relu(X*W1) + TF_RETURN_IF_ERROR(ops::MatMul( + tape_ctx.get(), {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs), + "matmul1", + /*transpose_a=*/false, /*transpose_b=*/false)); // Compute W2*Relu(X*W1) AbstractTensorHandle* scores = temp_outputs[0]; temp_outputs.resize(2); - TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits( - ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs), - "softmax_loss", registry)); // Compute Softmax(Scores,labels) + TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits( + tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs), + "softmax_loss")); // Compute Softmax(Scores,labels) AbstractTensorHandle* loss_vals = temp_outputs[0]; @@ -297,9 +187,11 @@ Status MatMulTransposeModel(AbstractContext* ctx, tape->Watch(ToId(W1)); vector temp_outputs(1); - TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs), - "matmul0", /*transpose_a=*/true, - /*transpose_b=*/false, registry)); // Compute X*W1 + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1}, + absl::MakeSpan(temp_outputs), "matmul0", + /*transpose_a=*/true, + /*transpose_b=*/false)); // Compute X*W1 outputs[0] = temp_outputs[0]; @@ -315,8 +207,10 @@ Status ReluGradModel(AbstractContext* ctx, auto tape = new Tape(/*persistent=*/false); tape->Watch(ToId(inputs[0])); // Watch X vector relu_outputs(1); - TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs), - "relu0", registry)); // Relu(X) + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs, + absl::MakeSpan(relu_outputs), + "relu0")); // Relu(X) std::unordered_map source_tensors_that_are_targets; @@ -346,8 +240,9 @@ Status SoftmaxLossGradModel(AbstractContext* ctx, tape->Watch(ToId(inputs[0])); // Watch scores. tape->Watch(ToId(inputs[1])); // Watch labels. vector sm_outputs(2); - TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits( - ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry)); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits( + tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0")); std::unordered_map source_tensors_that_are_targets; @@ -381,29 +276,30 @@ Status MNISTGradModel(AbstractContext* ctx, tape->Watch(ToId(W1)); // Watch W1. tape->Watch(ToId(W2)); // Watch W1. 
vector temp_outputs(1); - TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs), - "matmul0", /*transpose_a=*/false, - /*transpose_b=*/false, registry)); // Compute X*W1 + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1}, + absl::MakeSpan(temp_outputs), "matmul0", + /*transpose_a=*/false, + /*transpose_b=*/false)); // Compute X*W1 AbstractTensorHandle* mm = temp_outputs[0]; - TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm}, - absl::MakeSpan(temp_outputs), // Relu(X*W1) - "relu0", registry)); + TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm}, + absl::MakeSpan(temp_outputs), // Relu(X*W1) + "relu0")); AbstractTensorHandle* hidden = temp_outputs[0]; - TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2}, - absl::MakeSpan(temp_outputs), "matmul1", - /*transpose_a=*/false, /*transpose_b=*/false, - registry)); // W2*Relu(X*W1) + TF_RETURN_IF_ERROR(ops::MatMul( + tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1", + /*transpose_a=*/false, /*transpose_b=*/false)); // W2*Relu(X*W1) AbstractTensorHandle* scores = temp_outputs[0]; temp_outputs.resize(2); - TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits( - ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs), - "softmaxloss", registry)); // W2*Relu(X*W1) + TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits( + tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs), + "softmaxloss")); // W2*Relu(X*W1) AbstractTensorHandle* loss = temp_outputs[0]; @@ -440,8 +336,10 @@ Status ScalarMulModel(AbstractContext* ctx, auto tape = new Tape(/*persistent=*/false); vector temp_outputs(1); - TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs), - "scalarMul0", registry)); // Compute eta*A + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {eta, A}, + absl::MakeSpan(temp_outputs), + "scalarMul0")); // Compute eta*A outputs[0] = temp_outputs[0]; @@ -459,9 +357,11 @@ Status MatMulModel(AbstractContext* ctx, TapeVSpace vspace(ctx); auto tape = new Tape(/*persistent=*/false); std::vector temp_outputs(1); - TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs), - "matmul0", /*transpose_a=*/false, - /*transpose_b=*/false, registry)); // Compute X*W1 + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1}, + absl::MakeSpan(temp_outputs), "matmul0", + /*transpose_a=*/false, + /*transpose_b=*/false)); // Compute X*W1 outputs[0] = temp_outputs[0]; delete tape; @@ -478,8 +378,10 @@ Status MulModel(AbstractContext* ctx, TapeVSpace vspace(ctx); auto tape = new Tape(/*persistent=*/false); std::vector temp_outputs(1); - TF_RETURN_IF_ERROR(Mul(ctx, tape, {x, y}, absl::MakeSpan(temp_outputs), - "mul0", registry)); // Compute x*y + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {x, y}, + absl::MakeSpan(temp_outputs), + "mul0")); // Compute x*y outputs[0] = temp_outputs[0]; delete tape; @@ -496,9 +398,9 @@ Status SoftmaxModel(AbstractContext* ctx, TapeVSpace vspace(ctx); auto tape = new Tape(/*persistent=*/false); std::vector temp_outputs(2); - TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits( - ctx, tape, {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss", - registry)); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits( + 
tape_ctx.get(), {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss")); outputs[0] = temp_outputs[0]; // loss values diff --git a/tensorflow/c/eager/mnist_gradients_testutil.h b/tensorflow/c/eager/mnist_gradients_testutil.h index 1cf87bb9dee..b173446ac9b 100644 --- a/tensorflow/c/eager/mnist_gradients_testutil.h +++ b/tensorflow/c/eager/mnist_gradients_testutil.h @@ -29,45 +29,10 @@ limitations under the License. #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/platform/status.h" -// ========================== Tape Ops ============================== namespace tensorflow { namespace gradients { namespace internal { -// Computes `inputs[0] + inputs[1]` and records it on the tape. -Status Add(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry); - -// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape. -Status MatMul(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, const char* name, - bool transpose_a, bool transpose_b, - const GradientRegistry& registry); - -// Computes `inputs[0] * inputs[1]` and records it on the tape. -Status Mul(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, const char* name, - const GradientRegistry& registry); - -// Computes `Relu(inputs[0])` and records it on the tape. -Status Relu(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, const char* name, - const GradientRegistry& registry); - -// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the -// tape. -Status SparseSoftmaxCrossEntropyWithLogits( - AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, const char* name, - const GradientRegistry& registry); - -// ====================== End Tape Ops ============================ // Computes // y = inputs[0] + inputs[1] diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 80db01c7924..efab4dfbeb2 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -100,7 +100,8 @@ class VSpace { std::vector* result) const = 0; // Builds a tensor filled with ones with the same shape and dtype as `t`. - virtual Status BuildOnesLike(TapeTensor t, Gradient** result) const = 0; + virtual Status BuildOnesLike(const TapeTensor& t, + Gradient** result) const = 0; // Looks up the ID of a Gradient. virtual int64 TensorId(Gradient* tensor) const = 0; diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD b/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD index cd697906121..765c4e5f06e 100644 --- a/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD @@ -29,6 +29,7 @@ cc_library( }), deps = [ "//tensorflow/c:env", + "//tensorflow/c:logging", "//tensorflow/c:tf_status", "//tensorflow/c/experimental/filesystem:filesystem_interface", "//third_party/hadoop:hdfs", diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc index ac82ff6df12..5ff28e4229a 100644 --- a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc @@ -22,11 +22,10 @@ limitations under the License. 
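The hadoop_filesystem.cc hunk that follows replaces the per-call hdfsBuilderConnect with a connection cache keyed by scheme plus namenode and guarded by a mutex (Fixes #43187). A minimal sketch of that caching pattern under stated assumptions: std::mutex and std::map stand in for the plugin's absl::Mutex and its cache map, and IllustrativeFs / DialNamenode are hypothetical names, not the plugin's API.

// connection_cache_sketch.cc -- illustrative only.
#include <iostream>
#include <map>
#include <mutex>
#include <string>

using IllustrativeFs = int;  // stand-in for an hdfsFS connection handle

struct CacheState {
  std::mutex mu;                                // stands in for absl::Mutex
  std::map<std::string, IllustrativeFs> cache;  // one live connection per cache key
};

// Hypothetical dial function standing in for hdfsBuilderConnect.
IllustrativeFs DialNamenode(const std::string& key) {
  std::cout << "connecting to " << key << "\n";
  return 42;
}

// Connect-with-cache: reuse an existing connection for the same scheme + namenode.
IllustrativeFs Connect(CacheState* state, const std::string& scheme,
                       const std::string& namenode) {
  const std::string key = scheme + namenode;
  std::lock_guard<std::mutex> lock(state->mu);
  auto it = state->cache.find(key);
  if (it == state->cache.end()) {
    it = state->cache.emplace(key, DialNamenode(key)).first;
  }
  return it->second;
}

int main() {
  CacheState state;
  Connect(&state, "hdfs://", "namenode:8020");  // dials once
  Connect(&state, "hdfs://", "namenode:8020");  // served from the cache
  return 0;
}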
#include #include -#include "absl/synchronization/mutex.h" #include "tensorflow/c/env.h" #include "tensorflow/c/experimental/filesystem/filesystem_interface.h" +#include "tensorflow/c/logging.h" #include "tensorflow/c/tf_status.h" -#include "third_party/hadoop/hdfs.h" // Implementation of a filesystem for HADOOP environments. // This filesystem will support `hdfs://`, `viewfs://` and `har://` URI schemes. @@ -149,15 +148,20 @@ class LibHDFS { char* hdfs_home = getenv("HADOOP_HDFS_HOME"); if (hdfs_home != nullptr) { auto JoinPath = [](std::string home, std::string lib) { +#if defined(_WIN32) + if (home.back() != '\\') home.push_back('\\'); + return home + "lib\\native\\" + lib; +#else if (home.back() != '/') home.push_back('/'); return home + "lib/native/" + lib; +#endif }; std::string path = JoinPath(hdfs_home, kLibHdfsDso); TryLoadAndBind(path.c_str(), &handle_, status); if (TF_GetCode(status) == TF_OK) { return; } else { - std::cerr << "HadoopFileSystem load error: " << TF_Message(status); + TF_Log(TF_FATAL, "HadoopFileSystem load error: %s", TF_Message(status)); } } @@ -169,13 +173,15 @@ class LibHDFS { void* handle_; }; -// We rely on HDFS connection caching here. The HDFS client calls -// org.apache.hadoop.fs.FileSystem.get(), which caches the connection -// internally. -hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) { +// We implement connection caching in Tensorflow, which can significantly +// improve performance. Fixes #43187 +hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file, + const std::string& path, TF_Status* status) { + auto libhdfs = hadoop_file->libhdfs; std::string scheme, namenode, hdfs_path; ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + std::string cacheKey(scheme); hdfsBuilder* builder = libhdfs->hdfsNewBuilder(); if (scheme == "file") { libhdfs->hdfsBuilderSetNameNode(builder, nullptr); @@ -200,15 +206,24 @@ hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) { SplitArchiveNameAndPath(&path_har, &namenode, status); if (TF_GetCode(status) != TF_OK) return nullptr; libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str()); + cacheKey += namenode; } else { libhdfs->hdfsBuilderSetNameNode( builder, namenode.empty() ? "default" : namenode.c_str()); + cacheKey += namenode; } - auto fs = libhdfs->hdfsBuilderConnect(builder); - if (fs == nullptr) - TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno)); - else - TF_SetStatus(status, TF_OK, ""); + absl::MutexLock l(&hadoop_file->connection_cache_lock); + if (hadoop_file->connection_cache.find(cacheKey) == + hadoop_file->connection_cache.end()) { + auto cacheFs = libhdfs->hdfsBuilderConnect(builder); + if (cacheFs == nullptr) { + TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno)); + return cacheFs; + } + hadoop_file->connection_cache[cacheKey] = cacheFs; + } + auto fs = hadoop_file->connection_cache[cacheKey]; + TF_SetStatus(status, TF_OK, ""); return fs; } @@ -409,30 +424,36 @@ void Close(const TF_WritableFile* file, TF_Status* status) { // SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion` // ---------------------------------------------------------------------------- namespace tf_read_only_memory_region { - -// TODO(vnvo2409): Implement later - +// Hadoop doesn't support Readonly Memory Region } // namespace tf_read_only_memory_region // SECTION 4. 
Implementation for `TF_Filesystem`, the actual filesystem // ---------------------------------------------------------------------------- namespace tf_hadoop_filesystem { +HadoopFile::HadoopFile(TF_Status* status) + : libhdfs(new LibHDFS(status)), + connection_cache_lock(), + connection_cache() {} + void Init(TF_Filesystem* filesystem, TF_Status* status) { - filesystem->plugin_filesystem = new LibHDFS(status); + filesystem->plugin_filesystem = new HadoopFile(status); if (TF_GetCode(status) != TF_OK) return; TF_SetStatus(status, TF_OK, ""); } void Cleanup(TF_Filesystem* filesystem) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; delete libhdfs; + delete hadoop_file; } void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, TF_RandomAccessFile* file, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return; std::string scheme, namenode, hdfs_path; @@ -448,8 +469,9 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, void NewWritableFile(const TF_Filesystem* filesystem, const char* path, TF_WritableFile* file, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return; std::string scheme, namenode, hdfs_path; @@ -465,8 +487,9 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path, void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, TF_WritableFile* file, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return; std::string scheme, namenode, hdfs_path; @@ -497,8 +520,9 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem, void PathExists(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return; std::string scheme, namenode, hdfs_path; @@ -513,8 +537,9 @@ void PathExists(const TF_Filesystem* filesystem, const char* path, void Stat(const TF_Filesystem* filesystem, const char* path, TF_FileStatistics* stats, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return; std::string scheme, namenode, hdfs_path; @@ -532,8 +557,9 @@ void Stat(const TF_Filesystem* filesystem, const char* path, int64_t GetFileSize(const TF_Filesystem* filesystem, const char* 
path, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return -1; std::string scheme, namenode, hdfs_path; @@ -553,8 +579,9 @@ int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path, void DeleteFile(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return; std::string scheme, namenode, hdfs_path; @@ -568,8 +595,9 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path, void CreateDir(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return; std::string scheme, namenode, hdfs_path; @@ -583,8 +611,9 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path, void DeleteDir(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return; std::string scheme, namenode, hdfs_path; @@ -619,8 +648,9 @@ void DeleteDir(const TF_Filesystem* filesystem, const char* path, void RenameFile(const TF_Filesystem* filesystem, const char* src, const char* dst, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, src, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, src, status); if (TF_GetCode(status) != TF_OK) return; std::string scheme, namenode, hdfs_path_src, hdfs_path_dst; @@ -640,8 +670,9 @@ void RenameFile(const TF_Filesystem* filesystem, const char* src, int GetChildren(const TF_Filesystem* filesystem, const char* path, char*** entries, TF_Status* status) { - auto libhdfs = static_cast(filesystem->plugin_filesystem); - auto fs = Connect(libhdfs, path, status); + auto hadoop_file = static_cast(filesystem->plugin_filesystem); + auto libhdfs = hadoop_file->libhdfs; + auto fs = Connect(hadoop_file, path, status); if (TF_GetCode(status) != TF_OK) return -1; std::string scheme, namenode, hdfs_path; @@ -677,7 +708,9 @@ int GetChildren(const TF_Filesystem* filesystem, const char* path, return num_entries; } -// TODO(vnvo2409): Implement later +static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) { + return strdup(uri); +} } // namespace tf_hadoop_filesystem @@ -685,6 +718,42 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, const char* uri) { TF_SetFilesystemVersionMetadata(ops); ops->scheme = strdup(uri); + + ops->random_access_file_ops = static_cast( + 
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE)); + ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup; + ops->random_access_file_ops->read = tf_random_access_file::Read; + + ops->writable_file_ops = static_cast( + plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE)); + ops->writable_file_ops->cleanup = tf_writable_file::Cleanup; + ops->writable_file_ops->append = tf_writable_file::Append; + ops->writable_file_ops->tell = tf_writable_file::Tell; + ops->writable_file_ops->flush = tf_writable_file::Flush; + ops->writable_file_ops->sync = tf_writable_file::Sync; + ops->writable_file_ops->close = tf_writable_file::Close; + + ops->filesystem_ops = static_cast( + plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE)); + ops->filesystem_ops->init = tf_hadoop_filesystem::Init; + ops->filesystem_ops->cleanup = tf_hadoop_filesystem::Cleanup; + ops->filesystem_ops->new_random_access_file = + tf_hadoop_filesystem::NewRandomAccessFile; + ops->filesystem_ops->new_writable_file = + tf_hadoop_filesystem::NewWritableFile; + ops->filesystem_ops->new_appendable_file = + tf_hadoop_filesystem::NewAppendableFile; + ops->filesystem_ops->new_read_only_memory_region_from_file = + tf_hadoop_filesystem::NewReadOnlyMemoryRegionFromFile; + ops->filesystem_ops->path_exists = tf_hadoop_filesystem::PathExists; + ops->filesystem_ops->stat = tf_hadoop_filesystem::Stat; + ops->filesystem_ops->get_file_size = tf_hadoop_filesystem::GetFileSize; + ops->filesystem_ops->delete_file = tf_hadoop_filesystem::DeleteFile; + ops->filesystem_ops->create_dir = tf_hadoop_filesystem::CreateDir; + ops->filesystem_ops->delete_dir = tf_hadoop_filesystem::DeleteDir; + ops->filesystem_ops->rename_file = tf_hadoop_filesystem::RenameFile; + ops->filesystem_ops->get_children = tf_hadoop_filesystem::GetChildren; + ops->filesystem_ops->translate_name = tf_hadoop_filesystem::TranslateName; } void TF_InitPlugin(TF_FilesystemPluginInfo* info) { diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h index 8de66c05bac..06b91a68123 100644 --- a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h +++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h @@ -15,10 +15,13 @@ limitations under the License. 
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_ #define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_ +#include #include +#include "absl/synchronization/mutex.h" #include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "tensorflow/c/tf_status.h" +#include "third_party/hadoop/hdfs.h" void ParseHadoopPath(const std::string& fname, std::string* scheme, std::string* namenode, std::string* path); @@ -43,6 +46,14 @@ void Close(const TF_WritableFile* file, TF_Status* status); } // namespace tf_writable_file namespace tf_hadoop_filesystem { +typedef struct HadoopFile { + LibHDFS* libhdfs; + absl::Mutex connection_cache_lock; + std::map connection_cache + ABSL_GUARDED_BY(connection_cache_lock); + HadoopFile(TF_Status* status); +} HadoopFile; + void Init(TF_Filesystem* filesystem, TF_Status* status); void Cleanup(TF_Filesystem* filesystem); void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem_test.cc b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem_test.cc index 77079fb5325..df85ba9e4dd 100644 --- a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem_test.cc +++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem_test.cc @@ -352,6 +352,48 @@ TEST_F(HadoopFileSystemTest, WriteWhileReading) { EXPECT_TF_OK(status_); } +TEST_F(HadoopFileSystemTest, ReadWhileOverwriting) { + static char set_disable_var[] = "HDFS_DISABLE_READ_EOF_RETRIED=1"; + putenv(set_disable_var); + + const std::string path = TmpDir("ReadWhileOverwriting"); + if (path.find_first_of("hdfs://") != 0) GTEST_SKIP(); + + const string content1 = "content1"; + WriteString(path, content1); + ASSERT_TF_OK(status_); + + auto reader = GetReader(); + tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), + reader.get(), status_); + EXPECT_TF_OK(status_); + + std::string result; + result.resize(content1.size()); + auto read = tf_random_access_file::Read(reader.get(), 0, content1.size(), + &result[0], status_); + result.resize(read); + EXPECT_TF_OK(status_); + EXPECT_EQ(content1, result); + + tf_hadoop_filesystem::DeleteFile(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); + + string content2 = "overwrite"; + WriteString(path, content1 + content2); + ASSERT_TF_OK(status_); + + result.resize(content2.size()); + read = tf_random_access_file::Read(reader.get(), content1.size(), + content2.size(), &result[0], status_); + result.resize(read); + EXPECT_TF_OK(status_); + EXPECT_EQ(0, result.size()); + + static char set_enable_var[] = "HDFS_DISABLE_READ_EOF_RETRIED=0"; + putenv(set_enable_var); +} + TEST_F(HadoopFileSystemTest, HarSplit) { const std::string har_path = "har://hdfs-root/user/j.doe/my_archive.har/dir0/dir1/file.txt"; diff --git a/tensorflow/c/experimental/gradients/math_grad.cc b/tensorflow/c/experimental/gradients/math_grad.cc index c2aa9caf814..5cba7b28fda 100644 --- a/tensorflow/c/experimental/gradients/math_grad.cc +++ b/tensorflow/c/experimental/gradients/math_grad.cc @@ -24,6 +24,7 @@ using std::vector; using tensorflow::ops::Conj; using tensorflow::ops::MatMul; using tensorflow::ops::Mul; +using tensorflow::ops::SqrtGrad; namespace tensorflow { namespace gradients { @@ -72,6 +73,25 @@ class ExpGradientFunction : public GradientFunction { AbstractTensorHandlePtr exp_; }; +class SqrtGradientFunction : public GradientFunction { + public: + explicit 
SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) { + sqrt->Ref(); + } + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + std::string name = "Sqrt_Grad"; + grad_outputs->resize(1); + TF_RETURN_IF_ERROR(SqrtGrad(ctx->ctx, {sqrt_.get(), grad_inputs[0]}, + absl::MakeSpan(*grad_outputs), name.c_str())); + return Status::OK(); + } + ~SqrtGradientFunction() override {} + + private: + AbstractTensorHandlePtr sqrt_; +}; + class MatMulGradientFunction : public GradientFunction { public: explicit MatMulGradientFunction(vector f_inputs, @@ -210,5 +230,14 @@ BackwardFunction* MatMulRegisterer(const ForwardOperation& op) { return new BackwardFunction(gradient_function, default_gradients); } +BackwardFunction* SqrtRegisterer(const ForwardOperation& op) { + auto gradient_function = new SqrtGradientFunction(op.outputs[0]); + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/math_grad.h b/tensorflow/c/experimental/gradients/math_grad.h index 205419e1201..7faeadcca81 100644 --- a/tensorflow/c/experimental/gradients/math_grad.h +++ b/tensorflow/c/experimental/gradients/math_grad.h @@ -19,10 +19,13 @@ limitations under the License. namespace tensorflow { namespace gradients { + BackwardFunction* AddRegisterer(const ForwardOperation& op); BackwardFunction* ExpRegisterer(const ForwardOperation& op); BackwardFunction* MatMulRegisterer(const ForwardOperation& op); +BackwardFunction* SqrtRegisterer(const ForwardOperation& op); + } // namespace gradients } // namespace tensorflow -#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_ \ No newline at end of file +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_ diff --git a/tensorflow/c/experimental/gradients/tape/BUILD b/tensorflow/c/experimental/gradients/tape/BUILD index e57cc11b84f..bada49ea669 100644 --- a/tensorflow/c/experimental/gradients/tape/BUILD +++ b/tensorflow/c/experimental/gradients/tape/BUILD @@ -38,3 +38,29 @@ cc_library( "//tensorflow/c/eager:gradients_internal", ], ) + +cc_library( + name = "tape", + hdrs = [ + "tape_context.h", + "tape_operation.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":tape_context", + ":tape_operation", + ], +) + +filegroup( + name = "pywrap_required_hdrs", + srcs = [ + "tape_context.h", + "tape_operation.h", + ], + visibility = [ + "//tensorflow:internal", + ], +) diff --git a/tensorflow/c/experimental/ops/math_ops.cc b/tensorflow/c/experimental/ops/math_ops.cc index 3ff10923d48..20aab8a77d3 100644 --- a/tensorflow/c/experimental/ops/math_ops.cc +++ b/tensorflow/c/experimental/ops/math_ops.cc @@ -144,5 +144,33 @@ Status Exp(AbstractContext* ctx, absl::Span inputs, return exp_op->Execute(outputs, &num_retvals); } +Status Sqrt(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr sqrt_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(sqrt_op->Reset("Sqrt", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_op.get(), name)); + TF_RETURN_IF_ERROR(sqrt_op->AddInput(inputs[0])); + + int num_retvals = 1; + Status s = sqrt_op->Execute(outputs, &num_retvals); + return s; +} + +Status 
SqrtGrad(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr sqrt_grad_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR( + sqrt_grad_op->Reset("SqrtGrad", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_grad_op.get(), name)); + TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[0])); + TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[1])); + + int num_retvals = 1; + Status s = sqrt_grad_op->Execute(outputs, &num_retvals); + return s; +} + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/math_ops.h b/tensorflow/c/experimental/ops/math_ops.h index be324c0b55a..7051e38656f 100644 --- a/tensorflow/c/experimental/ops/math_ops.h +++ b/tensorflow/c/experimental/ops/math_ops.h @@ -50,6 +50,15 @@ Status DivNoNan(AbstractContext* ctx, Status Exp(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name); + +Status Sqrt(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + +Status SqrtGrad(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index bc532506a46..4cf868e4714 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -91,15 +91,24 @@ cc_library( ":signature_def_function_metadata", "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/types:span", ], ) cc_library( name = "signature_def_function_metadata", + srcs = [ + "signature_def_function_metadata.cc", + ], hdrs = [ "signature_def_function_metadata.h", ], + deps = [ + ":tensor_spec", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], ) cc_library( @@ -268,6 +277,20 @@ tf_cc_test( ], ) +cc_library( + name = "tensor_spec", + srcs = [ + "tensor_spec.cc", + ], + hdrs = [ + "tensor_spec.h", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + ], +) + tf_cc_test( name = "tf_concrete_function_loading_test", srcs = [ diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD index eaccc2fac69..ac168830a0e 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -92,6 +92,8 @@ cc_library( "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + "//tensorflow/c/experimental/saved_model/core:tensor_spec", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/lib/llvm_rtti", @@ -164,6 +166,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/lib/llvm_rtti", "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc index 5cc06e6c54f..1c615405644 100644 --- 
a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h" +#include #include +#include #include #include "absl/types/span.h" @@ -30,14 +32,26 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" +#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" namespace tensorflow { namespace { +using StructuredValueDictEntry = + protobuf::MapPair; + +using NamedParamMap = + gtl::FlatMap; + Status AssertAllCreateResourceFunctionsHaveNoCaptures( const PartiallyRevivedObjects& objects) { for (const auto& id_and_resource : objects.restored_resources) { @@ -124,6 +138,142 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph, } } +std::vector SignatureDefParamsFromNamedParamMap( + const NamedParamMap& params) { + // The underlying functiondef associated with the SignatureDef has + // nest.flattened inputs and outputs, which are sorted by string key. + std::vector result; + result.reserve(params.size()); + for (const auto& named_param : params) { + result.push_back(SignatureDefParam(std::string(named_param.first), + TensorSpec(*named_param.second))); + } + std::sort(result.begin(), result.end(), + [](const SignatureDefParam& x, const SignatureDefParam& y) { + return x.name() < y.name(); + }); + + return result; +} + +// SignatureDefArgsFromInputs takes the "canonicalized_input_signature" +// field of a SavedConcreteFunction, ensures it conforms to the structure of +// tuple(tuple(), dict()), and "returns" a list of +// SignatureDefParams of the SignatureDefFunction's arguments. +Status SignatureDefArgsFromInputs( + const StructuredValue& canonicalized_input_signature, + std::vector* out) { + // Note(bmzhao): canonicalized_input_signature should be a tuple of + // (args, kwargs), where args is an empty tuple, and kwargs is a dictionary of + // string keys to TensorSpecs. 
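  // Editorial aside (illustrative only, not part of this change): in textproto form the
  // expected StructuredValue is roughly
  //   tuple_value {
  //     values { tuple_value { } }              # args: an empty tuple
  //     values { dict_value {                   # kwargs: name -> TensorSpec
  //       fields { key: "x" value { tensor_spec_value { /* dtype, shape, name */ } } }
  //     } }
  //   }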
+ if (!canonicalized_input_signature.has_tuple_value()) { + return errors::FailedPrecondition( + "SignatureDefFunction's canonicalized_input_signature should be " + "of form tuple(tuple(), dict()), but was instead: \n", + canonicalized_input_signature.DebugString()); + } + + const TupleValue& args_kwargs_tuple = + canonicalized_input_signature.tuple_value(); + if (args_kwargs_tuple.values_size() != 2) { + return errors::FailedPrecondition( + "SignatureDefFunction's canonicalized_input_signature should be " + "a tuple of two elements (args, kwargs), but was instead: \n", + args_kwargs_tuple.DebugString()); + } + + const StructuredValue& args = args_kwargs_tuple.values(0); + if (!args.has_tuple_value() || !args.tuple_value().values().empty()) { + return errors::FailedPrecondition( + "SignatureDefFunction's canonicalized_input_signature's args " + "should be an empty tuple, but instead got: \n", + args.DebugString()); + } + + const StructuredValue& kwargs = args_kwargs_tuple.values(1); + if (!kwargs.has_dict_value()) { + return errors::FailedPrecondition( + "SignatureDefFunction's canonicalized_input_signature's kwargs " + "should be a dictionary, but instead got: \n", + kwargs.DebugString()); + } + + const DictValue& kwargs_dict = kwargs.dict_value(); + NamedParamMap result; + result.reserve(kwargs_dict.fields_size()); + + for (const auto& key_value : kwargs_dict.fields()) { + const std::string& key = key_value.first; + const StructuredValue& value = key_value.second; + if (!value.has_tensor_spec_value()) { + return errors::FailedPrecondition( + "SignatureDefFunction's canonicalized_input_signature's kwargs " + "dictionary contained a non-tensorspec value for key-value pair: \n", + "Key: ", key, " Value: \n", value.DebugString()); + } + result[key] = &value.tensor_spec_value(); + } + + *out = SignatureDefParamsFromNamedParamMap(result); + + return Status(); +} + +// SignatureDefReturnsFromOutputs takes the "output_signature" field of a +// SavedConcreteFunction, ensures it conforms to the structure of +// dict(), and "returns" a list of SignatureDefParams of the +// SignatureDefFunction's returns.
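The output side is flatter; a comparable sketch (hypothetical key "output_0", same caveats as above) of the dict-shaped output_signature consumed below. Note that both helpers hand their entries to SignatureDefParamsFromNamedParamMap, so callers see params sorted by name rather than in proto insertion order:

#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"

// Builds dict("output_0" -> scalar float TensorSpec).
tensorflow::StructuredValue MakeExampleOutputSignature() {
  tensorflow::StructuredValue signature;
  tensorflow::StructuredValue spec_value;
  spec_value.mutable_tensor_spec_value()->set_dtype(tensorflow::DT_FLOAT);
  (*signature.mutable_dict_value()->mutable_fields())["output_0"] = spec_value;
  return signature;
}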
+Status SignatureDefReturnsFromOutputs(const StructuredValue& output_signature, + std::vector* out) { + if (!output_signature.has_dict_value()) { + return errors::FailedPrecondition( + "SignatureDefFunction's output_signature must be a dictionary, but " + "instead got: ", + output_signature.DebugString()); + } + + const DictValue& output_dict = output_signature.dict_value(); + NamedParamMap result; + result.reserve(output_dict.fields_size()); + + for (const auto& key_value : output_dict.fields()) { + const std::string& key = key_value.first; + const StructuredValue& value = key_value.second; + if (!value.has_tensor_spec_value()) { + return errors::FailedPrecondition( + "SignatureDefFunction's output_signature dictionary contained a " + "non-tensorspec value for key-value pair: \n", + "Key: ", key, "Value: \n", value.DebugString()); + } + result[key] = &value.tensor_spec_value(); + } + *out = SignatureDefParamsFromNamedParamMap(result); + + return Status(); +} + +// The implementation takes advantage of the fact that SignatureDefFunction's +// "traced" Signature wrapper function always has inputs/outputs of dictionaries +// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L119-L126 +// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L153-L178 +// Additionally, we take advantage of the fact that the SignatureDefFunction's +// associated functiondef has lexicographically ordered inputs/outputs due to +// nest.flatten. +Status LoadSignatureDefFunctionMetadata( + const SavedConcreteFunction& saved_concrete_function, + SignatureDefFunctionMetadata* out) { + std::vector args; + TF_RETURN_IF_ERROR(SignatureDefArgsFromInputs( + saved_concrete_function.canonicalized_input_signature(), &args)); + + std::vector rets; + TF_RETURN_IF_ERROR(SignatureDefReturnsFromOutputs( + saved_concrete_function.output_signature(), &rets)); + + *out = SignatureDefFunctionMetadata(std::move(args), std::move(rets)); + return Status(); +} + // This function finds the necessary captures, then forwards to the builder // method Status CreateConcreteFunction(ImmediateExecutionContext* ctx, @@ -162,10 +312,14 @@ Status CreateSignatureDefFunction( &capture_handle)); captures.push_back(capture_handle); } - // TODO(bmzhao): Create Metadata here + + SignatureDefFunctionMetadata metadata; + TF_RETURN_IF_ERROR(LoadSignatureDefFunctionMetadata( + *builder.saved_concrete_func, &metadata)); + return TFSignatureDefFunction::Create(/*function_def=*/builder.fdef, /*captures=*/std::move(captures), - /*metadata=*/{}, + /*metadata=*/std::move(metadata), /*ctx=*/ctx, /*out=*/out); } @@ -378,6 +532,7 @@ Status PartiallyRevivedObjects::Build(ImmediateExecutionContext* ctx, revived->variables = std::move(variables); revived->assets = std::move(assets); revived->constants = std::move(constants); + revived->signatures_map = std::move(signatures_map); // 3b. Move over resources. 
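For a sense of how the signatures_map moved above is meant to be consumed, here is a rough sketch; it mirrors the TFSavedModelAPI::GetSignatureDefFunction change further down in this diff, and the helper name is hypothetical:

#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"
#include "tensorflow/core/platform/errors.h"

// (Assumed to live in namespace tensorflow.) The SignatureDef key, e.g.
// "serving_default", resolves to a node id in the SavedObjectGraph; the node
// id then resolves to the revived SignatureDefFunction.
Status LookupSignatureDefFunction(const RevivedObjects& revived,
                                  const std::string& signature_def_key,
                                  SignatureDefFunction** out) {
  auto key_it = revived.signatures_map.find(signature_def_key);
  if (key_it == revived.signatures_map.end()) {
    return errors::NotFound("No signature with key ", signature_def_key,
                            " was found");
  }
  auto fn_it = revived.signature_def_functions.find(key_it->second);
  if (fn_it == revived.signature_def_functions.end()) {
    return errors::Internal("Signature key ", signature_def_key,
                            " refers to a node with no revived function.");
  }
  *out = fn_it->second.get();
  return Status();
}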
TF_RETURN_IF_ERROR(BuildResources(ctx, obj_graph, this, revived)); diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h index cce52748a1d..78960b8c95f 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h @@ -36,7 +36,14 @@ namespace tensorflow { // Notably, resources and functions can be in a state where they reference // other resources/functions that have not been constructed yet. We collect // *all* objects in a partially valid state here, then properly initialize -// resources and functions. +// resources and functions. Implementation-wise, PartiallyRevivedObjects +// contains maps keyed by the node number of the SavedObjectGraph, and map to an +// object of the corresponding type. So, if node 2 in the object graph is a +// variable, PartiallyRevivedObjects.variables[2] exists, and corresponds to a +// tensorflow::Variable object. The only exception to this is the +// "signatures_map", which is keyed by the "signature" key +// (https://github.com/tensorflow/tensorflow/blob/372918decee7f558b3c194b04f77c20dcc679a31/tensorflow/core/protobuf/meta_graph.proto#L89), +// and maps to the SignatureDefFunction node in the SavedObjectGraph. struct PartiallyRevivedObjects { gtl::FlatMap> variables; gtl::FlatMap> assets; @@ -44,6 +51,7 @@ struct PartiallyRevivedObjects { gtl::FlatMap concrete_functions; gtl::FlatMap signature_def_functions; gtl::FlatMap restored_resources; + gtl::FlatMap signatures_map; Status Build(ImmediateExecutionContext* ctx, const SavedObjectGraph& obj_graph, RevivedObjects* revived); diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h b/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h index 3b8fb0f6a8a..cc9be0b937d 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h @@ -44,6 +44,7 @@ struct RevivedObjects { gtl::FlatMap> signature_def_functions; gtl::FlatMap restored_resources; + gtl::FlatMap signatures_map; }; } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc b/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc index a212c25bd28..2ede228e4ed 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc @@ -20,8 +20,10 @@ limitations under the License. 
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h" #include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" @@ -62,15 +64,53 @@ Status Variable::ReadValue(ImmediateTensorHandlePtr* out) { return internal::ReadVariable(ctx_, handle_.get(), dtype_, out); } -Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx, - DataType dtype, TensorShape shape, - absl::optional name, - const char* raw_device_name, - std::unique_ptr* output) { +Status Variable::CreateUninitialized( + ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, + absl::optional name, const char* raw_device_name, + const std::vector& component_devices, + std::unique_ptr* output) { ImmediateTensorHandlePtr handle; - TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable( - ctx, dtype, shape, raw_device_name, &handle)); + if (component_devices.empty()) { + TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable( + ctx, dtype, shape, raw_device_name, &handle)); + output->reset( + new Variable(ctx, dtype, shape, std::move(name), std::move(handle))); + return Status(); + } + + if (!tensorflow::isa(ctx)) { + return errors::InvalidArgument( + "Can only load distributed variables with EagerContext."); + } + + EagerContext* eager_ctx = reinterpret_cast(ctx); + + std::vector handles; + for (const auto& device : component_devices) { + ImmediateTensorHandlePtr handlePtr; + TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable( + ctx, dtype, shape, device.empty() ? nullptr : device.c_str(), + &handlePtr)); + if (!tensorflow::isa(handlePtr.get())) { + return errors::Internal("Returned replica handle has unsupported type."); + } + handles.push_back(reinterpret_cast(handlePtr.release())); + } + TensorHandle* packed_handle; + TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle( + std::move(handles), eager_ctx, &packed_handle)); + // The call to `CreatePackedHandle` incremented the handles' reference count, + // which we must now decrement to make the packed handle the owner of those + // handles. We can't loop through the `handles` vector because it was + // `std::move`d in the call above. + for (int i = 0; i != packed_handle->NumPackedHandles(); ++i) { + TensorHandle* component; + TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &component)); + component->Unref(); + } + + handle.reset(packed_handle); output->reset( new Variable(ctx, dtype, shape, std::move(name), std::move(handle))); return Status(); diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/variable.h b/tensorflow/c/experimental/saved_model/core/revived_types/variable.h index 13f56fda5f3..6d630b54562 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/variable.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/variable.h @@ -34,11 +34,11 @@ class Variable : public TensorHandleConvertible { public: // Creates an uninitialized resource variable. Note that a caller must // call "assign" to associate a value with the variable. 
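A hedged sketch of calling the new overload for a two-replica distributed variable; the device strings are purely illustrative, and passing an empty component_devices vector keeps the old single-handle behaviour:

#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"

// (Assumed to live in namespace tensorflow.) ctx must be an EagerContext,
// since the packed-handle path casts to it; the variable still needs an
// explicit assign before it can be read.
Status CreateExampleDistributedVariable(ImmediateExecutionContext* ctx,
                                        std::unique_ptr<Variable>* var) {
  std::vector<std::string> component_devices = {
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  return Variable::CreateUninitialized(
      ctx, DT_FLOAT, TensorShape({2, 2}), /*name=*/absl::nullopt,
      /*raw_device_name=*/nullptr, component_devices, var);
}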
- static Status CreateUninitialized(ImmediateExecutionContext* ctx, - DataType dtype, TensorShape shape, - absl::optional name, - const char* raw_device_name, - std::unique_ptr* output); + static Status CreateUninitialized( + ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, + absl::optional name, const char* raw_device_name, + const std::vector& component_devices, + std::unique_ptr* output); // The dtype of the underlying variable. DataType dtype(); diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc index ef75f1a382b..2a4297e2b67 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc @@ -235,10 +235,17 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx, const std::string& name = variable.name(); tensorflow::TensorShape shape(variable.shape()); tensorflow::DataType dtype = variable.dtype(); + std::vector component_devices; + + for (const auto& component : + variable.experimental_distributed_variable_components()) { + component_devices.push_back(component.device()); + } TF_RETURN_IF_ERROR(Variable::CreateUninitialized( ctx, dtype, shape, name, - variable.device().empty() ? nullptr : variable.device().c_str(), output)); + variable.device().empty() ? nullptr : variable.device().c_str(), + component_devices, output)); return Status(); } @@ -519,6 +526,8 @@ Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph, } } + objects->signatures_map = std::move(signatures_map); + return Status(); } diff --git a/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc b/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc index 45b0ac00c9b..a5a4e900843 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc +++ b/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc @@ -119,7 +119,7 @@ TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) { Status status; std::unique_ptr var; TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape, - absl::nullopt, nullptr, &var)); + absl::nullopt, nullptr, {}, &var)); // Create a TensorHandle ImmediateTensorHandlePtr expected_handle = diff --git a/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.cc b/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.cc new file mode 100644 index 00000000000..4e455f08f49 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.cc @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
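Relatedly, a small sketch (same caveats, illustrative device names) of a SavedVariable proto carrying per-replica components, which is the input shape LoadSavedVariable above now understands:

#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"

// Builds a scalar float SavedVariable with two distributed components; the
// loader copies each component's device into component_devices.
tensorflow::SavedVariable MakeExampleDistributedSavedVariable() {
  tensorflow::SavedVariable saved;
  saved.set_dtype(tensorflow::DT_FLOAT);
  saved.mutable_shape();  // empty TensorShapeProto, i.e. a scalar
  for (const char* device : {"/job:worker/replica:0/task:0/device:CPU:0",
                             "/job:worker/replica:0/task:1/device:CPU:0"}) {
    auto* component = saved.add_experimental_distributed_variable_components();
    component->set_dtype(tensorflow::DT_FLOAT);
    component->set_device(device);
  }
  return saved;
}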
+==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" + +namespace tensorflow { + +SignatureDefParam::SignatureDefParam(std::string name, TensorSpec spec) + : name_(std::move(name)), spec_(std::move(spec)) {} + +const std::string& SignatureDefParam::name() const { return name_; } + +const TensorSpec& SignatureDefParam::spec() const { return spec_; } + +SignatureDefFunctionMetadata::SignatureDefFunctionMetadata( + std::vector arguments, + std::vector returns) + : arguments_(std::move(arguments)), returns_(std::move(returns)) {} + +const std::vector& SignatureDefFunctionMetadata::arguments() + const { + return arguments_; +} + +const std::vector& SignatureDefFunctionMetadata::returns() + const { + return returns_; +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h b/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h index 5a579676d4e..e9cc0b11b00 100644 --- a/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h +++ b/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h @@ -16,10 +16,42 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_ #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_ +#include +#include + +#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/struct.pb.h" + namespace tensorflow { +// SignatureDefParam represents a named Tensor input or output to a +// SignatureDefFunction. +class SignatureDefParam { + public: + SignatureDefParam(std::string name, TensorSpec spec); + + const std::string& name() const; + + const TensorSpec& spec() const; + + private: + std::string name_; + TensorSpec spec_; +}; + class SignatureDefFunctionMetadata { - // TODO(bmzhao): Fill in with fields as necessary + public: + SignatureDefFunctionMetadata() = default; + SignatureDefFunctionMetadata(std::vector arguments, + std::vector returns); + + const std::vector& arguments() const; + const std::vector& returns() const; + + private: + std::vector arguments_; + std::vector returns_; }; } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tensor_spec.cc b/tensorflow/c/experimental/saved_model/core/tensor_spec.cc new file mode 100644 index 00000000000..4d68ec73b1b --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/tensor_spec.cc @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
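A hypothetical hand-built instance of the metadata type implemented here, matching the "serving_default" example exercised later in this change (the loader derives the metadata from the SavedModel protos instead; this is a fragment, not a complete translation unit):

// TensorSpec() defaults to a scalar DT_FLOAT spec, which is exactly what the
// example signature uses for "a", "b" and "output_0".
SignatureDefFunctionMetadata metadata(
    /*arguments=*/{SignatureDefParam("a", TensorSpec()),
                   SignatureDefParam("b", TensorSpec())},
    /*returns=*/{SignatureDefParam("output_0", TensorSpec())});
// metadata.arguments()[0].name() == "a"
// metadata.arguments()[0].spec().dtype() == DT_FLOAT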
+==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h" + +#include + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace tensorflow { + +TensorSpec::TensorSpec() + : shape_(std::initializer_list()), dtype_(DT_FLOAT) {} + +TensorSpec::TensorSpec(PartialTensorShape shape, DataType dtype) + : shape_(std::move(shape)), dtype_(dtype) {} + +TensorSpec::TensorSpec(const TensorSpecProto& proto) + : shape_(proto.shape()), dtype_(proto.dtype()) {} + +const PartialTensorShape& TensorSpec::shape() const { return shape_; } + +DataType TensorSpec::dtype() const { return dtype_; } + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tensor_spec.h b/tensorflow/c/experimental/saved_model/core/tensor_spec.h new file mode 100644 index 00000000000..dcdff8900bd --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/tensor_spec.h @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_ + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { + +// Note(bmzhao): TensorSpec deliberately does not store the "name" from a +// TensorSpecProto. From edloper@, "Names should really be associated with +// parameters, not the tensors inside those parameters. This would be +// inconsistent with the corresponding Python class, but I don't think that's +// necessarily a problem. If it turns out later that we really need a name +// attribute here, we can always add it back in; but let's see how far we can +// get without it." 
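Ahead of the class definition that follows, a brief usage sketch (a fragment, assuming this header plus tensorflow/core/framework/tensor_shape.h); the default constructor yields a scalar DT_FLOAT spec, and shape plus dtype can also be given explicitly:

TensorSpec scalar_float;                                   // scalar, DT_FLOAT
TensorSpec matrix(PartialTensorShape({2, -1}), DT_INT32);  // 2 x unknown dim
// matrix.shape().dims() == 2, matrix.dtype() == DT_INT32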
+class TensorSpec { + public: + // Constructs a scalar, DT_FLOAT TensorSpec + TensorSpec(); + + TensorSpec(PartialTensorShape shape, DataType dtype); + + explicit TensorSpec(const TensorSpecProto& proto); + + const PartialTensorShape& shape() const; + DataType dtype() const; + + private: + PartialTensorShape shape_; + DataType dtype_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_ diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc index 6386d0dbb79..f0990235963 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc @@ -192,9 +192,23 @@ Status TFSavedModelAPI::GetFunction(const std::string& function_path, Status TFSavedModelAPI::GetSignatureDefFunction( const std::string& signature_def_key, SignatureDefFunction** function) { - // TODO(bmzhao): Add support for retrieving a signaturedef function. - return errors::Unimplemented( - "Retrieving SignatureDef functions is unimplemented currently"); + auto signatures_iter = + revived_objects_.signatures_map.find(signature_def_key); + if (signatures_iter == revived_objects_.signatures_map.end()) { + return errors::NotFound("No signature with key ", signature_def_key, + " was found"); + } + int node = signatures_iter->second; + + auto function_iter = revived_objects_.signature_def_functions.find(node); + if (function_iter == revived_objects_.signature_def_functions.end()) { + return errors::Internal( + "Unable to find SignatureDefFunction associated with key ", + signature_def_key, " despite key being valid."); + } + + *function = function_iter->second.get(); + return Status(); } std::vector TFSavedModelAPI::ListFunctions() { diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index 091f5c86967..06fbc7aef0a 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -224,6 +224,8 @@ cc_library( ], deps = [ ":signature_def_function_metadata_type", + ":signature_def_param_list", + ":signature_def_param_list_type", "//tensorflow/c:c_api_macros", "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", ], @@ -240,6 +242,104 @@ cc_library( ], ) +cc_library( + name = "signature_def_param", + srcs = [ + "signature_def_param.cc", + ], + hdrs = [ + "//tensorflow/c/experimental/saved_model/public:signature_def_param.h", + ], + copts = tf_copts(), + visibility = [ + "//tensorflow/c/experimental/saved_model/public:__pkg__", + ], + deps = [ + ":signature_def_param_type", + ":tensor_spec", + ":tensor_spec_type", + "//tensorflow/c:c_api_macros", + "//tensorflow/c:tf_shape_internal", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + ], +) + +cc_library( + name = "signature_def_param_type", + hdrs = [ + "signature_def_param_type.h", + ], + deps = [ + "//tensorflow/c:conversion_macros", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + ], +) + +cc_library( + name = "signature_def_param_list", + srcs = [ + "signature_def_param_list.cc", + ], + hdrs = [ + "//tensorflow/c/experimental/saved_model/public:signature_def_param_list.h", + ], + copts = tf_copts(), + visibility = [ + "//tensorflow/c/experimental/saved_model/public:__pkg__", + ], + deps = [ + ":signature_def_param", + 
":signature_def_param_list_type", + ":signature_def_param_type", + "//tensorflow/c:c_api_macros", + ], +) + +cc_library( + name = "signature_def_param_list_type", + hdrs = [ + "signature_def_param_list_type.h", + ], + deps = [ + "//tensorflow/c:conversion_macros", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + ], +) + +cc_library( + name = "tensor_spec", + srcs = [ + "tensor_spec.cc", + ], + hdrs = [ + "//tensorflow/c/experimental/saved_model/public:tensor_spec.h", + ], + copts = tf_copts(), + visibility = [ + "//tensorflow/c/experimental/saved_model/public:__pkg__", + ], + deps = [ + ":tensor_spec_type", + "//tensorflow/c:c_api_macros", + "//tensorflow/c:tf_datatype", + "//tensorflow/c:tf_shape", + "//tensorflow/c:tf_shape_internal", + "//tensorflow/c/experimental/saved_model/core:tensor_spec", + ], +) + +cc_library( + name = "tensor_spec_type", + hdrs = [ + "tensor_spec_type.h", + ], + deps = [ + "//tensorflow/c:conversion_macros", + "//tensorflow/c:tf_shape_internal", + "//tensorflow/c/experimental/saved_model/core:tensor_spec", + ], +) + tf_cc_test( name = "saved_model_api_test", size = "small", @@ -252,6 +352,8 @@ tf_cc_test( ], deps = [ ":saved_model_api_type", + "//tensorflow/c:tf_datatype", + "//tensorflow/c:tf_shape", "//tensorflow/c:tf_status", "//tensorflow/c:tf_tensor", "//tensorflow/c/eager:c_api", @@ -260,6 +362,11 @@ tf_cc_test( "//tensorflow/c/experimental/saved_model/core:tf_saved_model_api", "//tensorflow/c/experimental/saved_model/public:concrete_function", "//tensorflow/c/experimental/saved_model/public:saved_model_api", + "//tensorflow/c/experimental/saved_model/public:signature_def_function", + "//tensorflow/c/experimental/saved_model/public:signature_def_function_metadata", + "//tensorflow/c/experimental/saved_model/public:signature_def_param", + "//tensorflow/c/experimental/saved_model/public:signature_def_param_list", + "//tensorflow/c/experimental/saved_model/public:tensor_spec", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc index eca1885a6fe..5a4f676ec06 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc @@ -24,6 +24,13 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h" #include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h" #include "tensorflow/c/experimental/saved_model/public/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h" +#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_shape.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/io/path.h" @@ -143,6 +150,146 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) { TFE_DeleteContext(ctx); } +// This tests running the "serving_default" SignatureDefFunction from the +// VarsAndArithmeticObjectGraph savedmodel. 
Here's what the signature_defs +// protobuf in the metagraph looks like: +// signature_def: { +// key : "serving_default" +// value: { +// inputs: { +// key : "a" +// value: { +// name : "serving_default_a:0" +// dtype: DT_FLOAT +// tensor_shape: { +// } +// } +// } +// inputs: { +// key : "b" +// value: { +// name : "serving_default_b:0" +// dtype: DT_FLOAT +// tensor_shape: { +// } +// } +// } +// outputs: { +// key : "output_0" +// value: { +// name : "StatefulPartitionedCall:0" +// dtype: DT_FLOAT +// tensor_shape: { +// } +// } +// } +// method_name: "tensorflow/serving/predict" +// } +// } +TEST_P(CSavedModelAPITest, RunsSignatureDefFunction) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + bool use_tfrt = GetParam(); + if (use_tfrt) { + TFE_DeleteContextOptions(opts); + TF_DeleteStatus(status); + GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. + } + + TFE_ContextOptionsSetTfrt(opts, use_tfrt); + + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph"); + + TF_SavedModel* saved_model = + TF_LoadSavedModel(model_dir.c_str(), ctx, status); + + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TF_SignatureDefFunction* serving_default = + TF_GetSavedModelSignatureDefFunction(saved_model, "serving_default", + status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_SignatureDefFunctionMetadata* metadata = + TF_SignatureDefFunctionGetMetadata(serving_default); + + const TF_SignatureDefParamList* args = + TF_SignatureDefFunctionMetadataArgs(metadata); + const TF_SignatureDefParamList* returns = + TF_SignatureDefFunctionMetadataReturns(metadata); + + EXPECT_EQ(TF_SignatureDefParamListSize(args), 2); + const TF_SignatureDefParam* param_a = TF_SignatureDefParamListGet(args, 0); + const TF_TensorSpec* tensor_spec_a = TF_SignatureDefParamTensorSpec(param_a); + const TF_Shape* shape_a = TF_TensorSpecShape(tensor_spec_a); + + // Input "a" is a scalar, float32 tensor + EXPECT_EQ("a", std::string(TF_SignatureDefParamName(param_a))); + EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_a)); + EXPECT_EQ(0, TF_ShapeDims(shape_a)); + + const TF_SignatureDefParam* param_b = TF_SignatureDefParamListGet(args, 1); + const TF_TensorSpec* tensor_spec_b = TF_SignatureDefParamTensorSpec(param_b); + const TF_Shape* shape_b = TF_TensorSpecShape(tensor_spec_b); + + // Input "b" is a scalar, float32 tensor + EXPECT_EQ("b", std::string(TF_SignatureDefParamName(param_b))); + EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_b)); + EXPECT_EQ(0, TF_ShapeDims(shape_b)); + + EXPECT_EQ(TF_SignatureDefParamListSize(returns), 1); + + const TF_SignatureDefParam* param_out = + TF_SignatureDefParamListGet(returns, 0); + const TF_TensorSpec* tensor_spec_out = + TF_SignatureDefParamTensorSpec(param_out); + const TF_Shape* shape_out = TF_TensorSpecShape(tensor_spec_out); + + // Output "output_0" is a scalar, float32 tensor + EXPECT_EQ("output_0", std::string(TF_SignatureDefParamName(param_out))); + EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_out)); + EXPECT_EQ(0, TF_ShapeDims(shape_out)); + + std::vector compute_fn_inputs; + TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f); + TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f); + compute_fn_inputs.push_back(input_a); + compute_fn_inputs.push_back(input_b); + + TFE_Op* 
serving_default_op = TF_SignatureDefFunctionMakeCallOp( + serving_default, compute_fn_inputs.data(), compute_fn_inputs.size(), + status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + std::vector compute_fn_outputs( + TF_SignatureDefParamListSize(returns)); + int num_retvals = TF_SignatureDefParamListSize(returns); + + TFE_Execute(serving_default_op, compute_fn_outputs.data(), &num_retvals, + status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + EXPECT_EQ(TF_NumDims(result), 0); + float output_value = *static_cast(TF_TensorData(result)); + // (1 + 2) * (2 + 1) / 3 + 5 should be 8 + EXPECT_FLOAT_EQ(output_value, 8.0); + + TF_DeleteTensor(result); + TFE_DeleteTensorHandle(compute_fn_outputs[0]); + TFE_DeleteTensorHandle(input_a); + TFE_DeleteTensorHandle(input_b); + TFE_DeleteOp(serving_default_op); + TF_DeleteSavedModel(saved_model); + TF_DeleteStatus(status); + TFE_DeleteContext(ctx); +} + TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); diff --git a/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata.cc b/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata.cc index c5c3616211c..1c547a94155 100644 --- a/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata.cc +++ b/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata.cc @@ -16,5 +16,18 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" #include "tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h" +#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h" -// TODO(bmzhao): Add getter functions here as necessary. +extern "C" { + +extern const TF_SignatureDefParamList* TF_SignatureDefFunctionMetadataArgs( + const TF_SignatureDefFunctionMetadata* list) { + return tensorflow::wrap(&tensorflow::unwrap(list)->arguments()); +} + +extern const TF_SignatureDefParamList* TF_SignatureDefFunctionMetadataReturns( + const TF_SignatureDefFunctionMetadata* list) { + return tensorflow::wrap(&tensorflow::unwrap(list)->returns()); +} + +} // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/signature_def_param.cc b/tensorflow/c/experimental/saved_model/internal/signature_def_param.cc new file mode 100644 index 00000000000..ac54f8f5700 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/signature_def_param.cc @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h" + +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" +#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h" +#include "tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h" + +extern "C" { + +extern const char* TF_SignatureDefParamName(const TF_SignatureDefParam* param) { + return tensorflow::unwrap(param)->name().c_str(); +} + +extern const TF_TensorSpec* TF_SignatureDefParamTensorSpec( + const TF_SignatureDefParam* param) { + return tensorflow::wrap(&tensorflow::unwrap(param)->spec()); +} + +} // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/signature_def_param_list.cc b/tensorflow/c/experimental/saved_model/internal/signature_def_param_list.cc new file mode 100644 index 00000000000..328f21635c3 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/signature_def_param_list.cc @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h" + +#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h" +#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h" + +extern "C" { + +extern size_t TF_SignatureDefParamListSize( + const TF_SignatureDefParamList* list) { + return tensorflow::unwrap(list)->size(); +} + +extern const TF_SignatureDefParam* TF_SignatureDefParamListGet( + const TF_SignatureDefParamList* list, int i) { + return tensorflow::wrap(&tensorflow::unwrap(list)->at(i)); +} + +} // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h b/tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h new file mode 100644 index 00000000000..6f535110cee --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
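The param, param-list and tensor-spec bindings above all lean on the same opaque-handle idiom; here is a self-contained sketch of that pattern with hypothetical names (the real files generate wrap/unwrap through DEFINE_CONVERSION_FUNCTIONS rather than writing them by hand):

#include <cstddef>
#include <vector>

typedef struct ExampleParamList ExampleParamList;  // opaque to C callers

// wrap/unwrap are reinterpret_casts between the C++ container and the opaque
// struct, so C code can hold and pass the handle without ever seeing
// std::vector's layout.
inline ExampleParamList* wrap(std::vector<int>* v) {
  return reinterpret_cast<ExampleParamList*>(v);
}
inline const std::vector<int>* unwrap(const ExampleParamList* list) {
  return reinterpret_cast<const std::vector<int>*>(list);
}

extern "C" size_t ExampleParamListSize(const ExampleParamList* list) {
  return unwrap(list)->size();
}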
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_ + +#include + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" + +typedef struct TF_SignatureDefParamList TF_SignatureDefParamList; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(std::vector, + TF_SignatureDefParamList) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_ diff --git a/tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h b/tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h new file mode 100644 index 00000000000..fd634bcddb0 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" + +typedef struct TF_SignatureDefParam TF_SignatureDefParam; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::SignatureDefParam, TF_SignatureDefParam) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_ diff --git a/tensorflow/c/experimental/saved_model/internal/tensor_spec.cc b/tensorflow/c/experimental/saved_model/internal/tensor_spec.cc new file mode 100644 index 00000000000..f310adef449 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/tensor_spec.cc @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h" + +#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h" +#include "tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h" +#include "tensorflow/c/tf_shape_internal.h" + +extern "C" { + +TF_DataType TF_TensorSpecDataType(const TF_TensorSpec* spec) { + return static_cast(tensorflow::unwrap(spec)->dtype()); +} + +const TF_Shape* TF_TensorSpecShape(const TF_TensorSpec* spec) { + return tensorflow::wrap(&tensorflow::unwrap(spec)->shape()); +} + +} // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h b/tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h new file mode 100644 index 00000000000..7284c8a8fb2 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h" + +typedef struct TF_TensorSpec TF_TensorSpec; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::TensorSpec, TF_TensorSpec) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_ diff --git a/tensorflow/c/experimental/saved_model/public/BUILD b/tensorflow/c/experimental/saved_model/public/BUILD index 33c3ce91aef..4198b0e7ee7 100644 --- a/tensorflow/c/experimental/saved_model/public/BUILD +++ b/tensorflow/c/experimental/saved_model/public/BUILD @@ -28,6 +28,9 @@ exports_files( "saved_model_api.h", "signature_def_function.h", "signature_def_function_metadata.h", + "signature_def_param.h", + "signature_def_param_list.h", + "tensor_spec.h", ], visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"], ) @@ -45,6 +48,9 @@ cc_library( ":saved_model_api", ":signature_def_function", ":signature_def_function_metadata", + ":signature_def_param", + ":signature_def_param_list", + ":tensor_spec", ], ) @@ -77,3 +83,18 @@ alias( name = "signature_def_function_metadata", actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_function_metadata", ) + +alias( + name = "signature_def_param", + actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_param", +) + +alias( + name = "signature_def_param_list", + actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_param_list", +) + +alias( + name = "tensor_spec", + actual = "//tensorflow/c/experimental/saved_model/internal:tensor_spec", +) diff --git a/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h 
b/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h index cedb9de66b8..68f1ece2991 100644 --- a/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h @@ -23,6 +23,9 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" #include "tensorflow/c/experimental/saved_model/public/signature_def_function.h" #include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h" +#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h" // IWYU pragma: end_exports #endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_ diff --git a/tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h b/tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h index 6f4459732c4..b7a7f67eb19 100644 --- a/tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h +++ b/tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h" + #ifdef __cplusplus extern "C" { #endif // __cplusplus @@ -24,6 +27,18 @@ extern "C" { // SavedModel. typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata; +// Retrieves the arguments of the SignatureDefFunction. The caller is not +// responsible for freeing the returned pointer. +TF_CAPI_EXPORT extern const TF_SignatureDefParamList* +TF_SignatureDefFunctionMetadataArgs( + const TF_SignatureDefFunctionMetadata* list); + +// Retrieves the returns of the SignatureDefFunction. The caller is not +// responsible for freeing the returned pointer. +TF_CAPI_EXPORT extern const TF_SignatureDefParamList* +TF_SignatureDefFunctionMetadataReturns( + const TF_SignatureDefFunctionMetadata* list); + #ifdef __cplusplus } // end extern "C" #endif // __cplusplus diff --git a/tensorflow/c/experimental/saved_model/public/signature_def_param.h b/tensorflow/c/experimental/saved_model/public/signature_def_param.h new file mode 100644 index 00000000000..82993d7fedf --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/signature_def_param.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
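Putting the new public accessors together, a hedged sketch of walking a function's input signature from client code (error handling elided; the metadata pointer is obtained via TF_SignatureDefFunctionGetMetadata as in the test above):

#include <stdio.h>

#include "tensorflow/c/experimental/saved_model/public/c_saved_model_api.h"

// Prints each argument's name, dtype enum value and rank.
void PrintSignatureArgs(const TF_SignatureDefFunctionMetadata* metadata) {
  const TF_SignatureDefParamList* args =
      TF_SignatureDefFunctionMetadataArgs(metadata);
  int n = static_cast<int>(TF_SignatureDefParamListSize(args));
  for (int i = 0; i < n; ++i) {
    const TF_SignatureDefParam* param = TF_SignatureDefParamListGet(args, i);
    const TF_TensorSpec* spec = TF_SignatureDefParamTensorSpec(param);
    printf("%s: dtype=%d rank=%d\n", TF_SignatureDefParamName(param),
           static_cast<int>(TF_TensorSpecDataType(spec)),
           TF_ShapeDims(TF_TensorSpecShape(spec)));
  }
}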
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_ + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that containing metadata of an input/output of a +// TF_SignatureDefFunction loaded from a SavedModel. +typedef struct TF_SignatureDefParam TF_SignatureDefParam; + +// Returns the name of the given parameter. The caller is not responsible for +// freeing the returned char*. +TF_CAPI_EXPORT extern const char* TF_SignatureDefParamName( + const TF_SignatureDefParam* param); + +// Returns the TensorSpec associated with the given parameter. The caller is +// not reponsible for freeing the returned TF_TensorSpec*. +TF_CAPI_EXPORT extern const TF_TensorSpec* TF_SignatureDefParamTensorSpec( + const TF_SignatureDefParam* param); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_ diff --git a/tensorflow/c/experimental/saved_model/public/signature_def_param_list.h b/tensorflow/c/experimental/saved_model/public/signature_def_param_list.h new file mode 100644 index 00000000000..0cb3a0d6d33 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/signature_def_param_list.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_ + +#include + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that containing metadata of an input/output of a +// ConcreteFunction loaded from a SavedModel. +typedef struct TF_SignatureDefParamList TF_SignatureDefParamList; + +// Returns the size of `list`. +TF_CAPI_EXPORT extern size_t TF_SignatureDefParamListSize( + const TF_SignatureDefParamList* list); + +// Returns the `i`th TF_SignatureDefParam in the list. +TF_CAPI_EXPORT extern const TF_SignatureDefParam* TF_SignatureDefParamListGet( + const TF_SignatureDefParamList* list, int i); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_ diff --git a/tensorflow/c/experimental/saved_model/public/tensor_spec.h b/tensorflow/c/experimental/saved_model/public/tensor_spec.h new file mode 100644 index 00000000000..82972ef74ef --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/tensor_spec.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_ + +#include + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_shape.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type corresponding to TensorSpec +typedef struct TF_TensorSpec TF_TensorSpec; + +// Returns the dtype associated with the TensorSpec. +TF_CAPI_EXPORT extern TF_DataType TF_TensorSpecDataType( + const TF_TensorSpec* spec); + +// Returns the shape associated with the TensorSpec. The returned Shape is not +// owned by the caller. Caller must not call TF_DeleteShape on the returned +// shape. +TF_CAPI_EXPORT extern const TF_Shape* TF_TensorSpecShape( + const TF_TensorSpec* spec); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_ diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD index 1a421dca51e..94418aa9ad9 100644 --- a/tensorflow/c/experimental/stream_executor/BUILD +++ b/tensorflow/c/experimental/stream_executor/BUILD @@ -22,6 +22,8 @@ cc_library( "//tensorflow/c:tf_status", "//tensorflow/c:tf_status_helper", "//tensorflow/core:lib", + "//tensorflow/core/platform:regexp", + "//tensorflow/core/platform:strcat", "//tensorflow/stream_executor:executor_cache", "//tensorflow/stream_executor:multi_platform_manager", "//tensorflow/stream_executor:platform", diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index 093499c298e..a4136e04b84 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -27,7 +27,10 @@ limitations under the License. #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/stream_executor/executor_cache.h" #include "tensorflow/stream_executor/multi_platform_manager.h" #include "tensorflow/stream_executor/platform.h" @@ -39,6 +42,8 @@ limitations under the License. using tensorflow::StatusFromTF_Status; namespace stream_executor { +using tensorflow::StringPiece; + namespace { #define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \ @@ -58,10 +63,35 @@ namespace { } \ } while (0) +port::Status ValidateDeviceType(StringPiece type) { + // Validate device type. Device type must start with a capital letter and + // consist of capital letters and underscores. 
Reasoning behind this decision: + // * At the minimum we want to disallow '/' and ':' since + // these characters are used in device spec, for e.g. + // /job:foo/replica:12/device:GPU:1. + // * Underscores seem useful, for e.g. XLA_GPU uses underscores. + // * Allowing lowercase might get confusing. For example, say someone + // registers a new type called "Gpu". It might be confusing for users that + // "Gpu" is not the same device type as "GPU". + // Note that lowercase "cpu" and "gpu" are currently supported only for + // legacy reasons: + // https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/python/framework/device_spec.py;l=46;drc=d3a378f9665d8eee827c74cb9ecbee81e4c288dd + static const LazyRE2 kTfDeviceTypeRegEx = {"[A-Z][A-Z_]*"}; + bool matches = RE2::FullMatch(type, *kTfDeviceTypeRegEx); + if (!matches) { + return port::FailedPreconditionError( + tensorflow::strings::StrCat("Device name/type '", type, "' must match ", + kTfDeviceTypeRegEx->pattern(), ".")); + } + return port::Status::OK(); +} + port::Status ValidateSPPlatform(const SP_Platform& platform) { VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE); VALIDATE_MEMBER(SP_Platform, platform, name); VALIDATE_MEMBER(SP_Platform, platform, type); + TF_RETURN_IF_ERROR(ValidateDeviceType(platform.name)); + TF_RETURN_IF_ERROR(ValidateDeviceType(platform.type)); // `visible_device_count` could be 0 at initialization time. return port::Status::OK(); } diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.h b/tensorflow/c/experimental/stream_executor/stream_executor.h index ba6b1c564a8..bec77ef520b 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.h +++ b/tensorflow/c/experimental/stream_executor/stream_executor.h @@ -52,7 +52,7 @@ limitations under the License. // params.device = &device; // // /* Plugin code below */ -// constexpr char DEVICE_NAME[] = "MyDevice"; +// constexpr char DEVICE_NAME[] = "MY_DEVICE"; // constexpr char DEVICE_TYPE[] = "GPU"; // // void create_device(const SP_Platform* platform, @@ -416,10 +416,15 @@ typedef struct SP_Platform { void* ext; // free-form data set by plugin - // Platform name. Must be null-terminated. + // Platform name (also referred to as subtype), for example MY_DEVICE. + // The name must start with a capital letter and consist of + // capital letters and underscores. + // Must be null-terminated. const char* name; // Device type name, for example GPU. Must be null-terminated. + // The name must start with a capital letter and consist of + // capital letters and underscores. 
const char* type; // Number of visible devices diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc index b28d1f6fc6d..56c4ea09052 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc @@ -41,9 +41,9 @@ struct SP_Timer_st { namespace stream_executor { namespace { -constexpr int DEVICE_COUNT = 2; -constexpr char DEVICE_NAME[] = "MyDevice"; -constexpr char DEVICE_TYPE[] = "GPU"; +constexpr int kDeviceCount = 2; +constexpr char kDeviceName[] = "MY_DEVICE"; +constexpr char kDeviceType[] = "GPU"; /*** Create SP_StreamExecutor (with empty functions) ***/ void allocate(const SP_Device* const device, uint64_t size, @@ -190,9 +190,9 @@ void destroy_device_fns(const SP_Platform* platform, SP_DeviceFns* device_fns) { void PopulateDefaultPlatform(SP_Platform* platform, SP_PlatformFns* platform_fns) { *platform = {SP_PLATFORM_STRUCT_SIZE}; - platform->name = DEVICE_NAME; - platform->type = DEVICE_TYPE; - platform->visible_device_count = DEVICE_COUNT; + platform->name = kDeviceName; + platform->type = kDeviceType; + platform->visible_device_count = kDeviceCount; platform_fns->create_device = create_device; platform_fns->destroy_device = destroy_device; platform_fns->create_device_fns = create_device_fns; @@ -218,11 +218,11 @@ TEST(StreamExecutor, SuccessfulRegistration) { port::Status status = InitStreamExecutorPlugin(plugin_init); TF_ASSERT_OK(status); port::StatusOr maybe_platform = - MultiPlatformManager::PlatformWithName("MyDevice"); + MultiPlatformManager::PlatformWithName("MY_DEVICE"); TF_ASSERT_OK(maybe_platform.status()); Platform* platform = maybe_platform.ConsumeValueOrDie(); - ASSERT_EQ(platform->Name(), DEVICE_NAME); - ASSERT_EQ(platform->VisibleDeviceCount(), DEVICE_COUNT); + ASSERT_EQ(platform->Name(), kDeviceName); + ASSERT_EQ(platform->VisibleDeviceCount(), kDeviceCount); port::StatusOr maybe_executor = platform->ExecutorForDevice(0); @@ -244,6 +244,39 @@ TEST(StreamExecutor, NameNotSet) { ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set."); } +TEST(StreamExecutor, InvalidNameWithSemicolon) { + auto plugin_init = [](SE_PlatformRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultPlatform(params->platform, params->platform_fns); + params->platform->name = "INVALID:NAME"; + params->destroy_platform = destroy_platform; + params->destroy_platform_fns = destroy_platform_fns; + }; + + port::Status status = InitStreamExecutorPlugin(plugin_init); + ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); + EXPECT_THAT( + status.error_message(), + testing::ContainsRegex("Device name/type 'INVALID:NAME' must match")); +} + +TEST(StreamExecutor, InvalidNameWithSlash) { + auto plugin_init = [](SE_PlatformRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultPlatform(params->platform, params->platform_fns); + params->platform->name = "INVALID/"; + params->destroy_platform = destroy_platform; + params->destroy_platform_fns = destroy_platform_fns; + }; + + port::Status status = InitStreamExecutorPlugin(plugin_init); + ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); + EXPECT_THAT(status.error_message(), + testing::ContainsRegex("Device name/type 'INVALID/' must match")); +} + TEST(StreamExecutor, CreateDeviceNotSet) { auto 
plugin_init = [](SE_PlatformRegistrationParams* const params, TF_Status* const status) -> void { diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 3bdaa866ee6..7ec1f4cc951 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -57,43 +57,7 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status) { - mutex_lock l(graph->mu); - tensorflow::shape_inference::InferenceContext* ic = - graph->refiner.GetContext(&new_src.oper->node); - - if (ic->num_outputs() <= new_src.index) { - status->status = tensorflow::errors::OutOfRange( - "Cannot update edge. Output index [", new_src.index, - "] is greater than the number of total outputs [", ic->num_outputs(), - "]."); - return; - } - tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index); - - tensorflow::shape_inference::InferenceContext* ic_dst = - graph->refiner.GetContext(&dst.oper->node); - if (ic_dst->num_inputs() <= dst.index) { - status->status = tensorflow::errors::OutOfRange( - "Cannot update edge. Input index [", dst.index, - "] is greater than the number of total inputs [", ic_dst->num_inputs(), - "]."); - return; - } - if (!ic_dst->MergeInput(dst.index, shape)) { - status->status = tensorflow::errors::InvalidArgument( - "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape), - " and ", ic_dst->DebugString(ic_dst->input(dst.index)), "."); - return; - } - status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, - &dst.oper->node, dst.index); - - if (TF_GetCode(status) == TF_OK) { - // This modification only updates the destination node for - // the purposes of running this graph in a session. Thus, we don't - // record the source node as being modified. - RecordMutation(graph, *dst.oper, "updating input tensor"); - } + TF_UpdateEdge(graph, new_src, dst, status); } void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) { @@ -136,6 +100,7 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) { auto* out_shape_and_type = handle_data.add_shape_and_type(); ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape()); out_shape_and_type->set_dtype(p.dtype); + out_shape_and_type->set_specialized_type(p.specialized_type); } } string result; @@ -163,7 +128,8 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, status->status = ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); if (TF_GetCode(status) != TF_OK) return; - shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype()); + shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(), + shape_and_type_proto.specialized_type()); } ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); } diff --git a/tensorflow/c/tf_shape.cc b/tensorflow/c/tf_shape.cc new file mode 100644 index 00000000000..a715544a13f --- /dev/null +++ b/tensorflow/c/tf_shape.cc @@ -0,0 +1,39 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/tf_shape.h" + +#include + +#include "tensorflow/c/tf_shape_internal.h" +#include "tensorflow/core/framework/tensor_shape.h" + +extern "C" { + +TF_Shape* TF_NewShape() { + return tensorflow::wrap(new tensorflow::PartialTensorShape()); +} + +int TF_ShapeDims(const TF_Shape* shape) { + return tensorflow::unwrap(shape)->dims(); +} + +int64_t TF_ShapeDimSize(const TF_Shape* shape, int d) { + return tensorflow::unwrap(shape)->dim_size(d); +} + +void TF_DeleteShape(TF_Shape* shape) { delete tensorflow::unwrap(shape); } + +} // end extern "C" diff --git a/tensorflow/c/tf_shape.h b/tensorflow/c/tf_shape.h new file mode 100644 index 00000000000..f218d05e274 --- /dev/null +++ b/tensorflow/c/tf_shape.h @@ -0,0 +1,50 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/c/c_api_macros.h" + +#ifndef TENSORFLOW_C_TF_SHAPE_H_ +#define TENSORFLOW_C_TF_SHAPE_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +// An opaque type corresponding to a shape in tensorflow. In the future, +// we may expose the ABI of TF_Shape for performance reasons. +typedef struct TF_Shape TF_Shape; + +// Return a new, unknown rank shape object. The caller is responsible for +// calling TF_DeleteShape to deallocate and destroy the returned shape. +TF_CAPI_EXPORT extern TF_Shape* TF_NewShape(); + +// Returns the rank of `shape`. If `shape` has unknown rank, returns -1. +TF_CAPI_EXPORT extern int TF_ShapeDims(const TF_Shape* shape); + +// Returns the `d`th dimension of `shape`. If `shape` has unknown rank, +// invoking this function is undefined behavior. Returns -1 if dimension is +// unknown. +TF_CAPI_EXPORT extern int64_t TF_ShapeDimSize(const TF_Shape* shape, int d); + +// Deletes `shape`. +TF_CAPI_EXPORT extern void TF_DeleteShape(TF_Shape* shape); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_TF_SHAPE_H_ diff --git a/tensorflow/c/tf_shape_internal.h b/tensorflow/c/tf_shape_internal.h new file mode 100644 index 00000000000..fe97726460f --- /dev/null +++ b/tensorflow/c/tf_shape_internal.h @@ -0,0 +1,30 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_TF_SHAPE_INTERNAL_H_ +#define TENSORFLOW_C_TF_SHAPE_INTERNAL_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/core/framework/tensor_shape.h" + +typedef struct TF_Shape TF_Shape; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::PartialTensorShape, TF_Shape); + +} + +#endif // TENSORFLOW_C_TF_SHAPE_INTERNAL_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index bf5ada89cdd..8f7e447d322 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -251,7 +251,6 @@ cc_library_with_android_deps( deps = [ "//tensorflow/core:core_cpu", "//tensorflow/core:lib", - "//tensorflow/core:lib_experimental", "//tensorflow/core:protos_all_cc", ], ) @@ -266,7 +265,6 @@ tf_cc_test( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:lib_experimental", "//tensorflow/core:tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index e9173227aad..480243a29e6 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -15,13 +15,12 @@ limitations under the License. #include +#include "tensorflow/cc/framework/grad_op_registry.h" +#include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/ops/array_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/cc/framework/grad_op_registry.h" -#include "tensorflow/cc/framework/gradients.h" - namespace tensorflow { namespace ops { namespace { @@ -90,15 +89,25 @@ Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad); -Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op, - const std::vector& grad_inputs, - std::vector* grad_outputs) { - grad_outputs->push_back(Identity(scope, grad_inputs[0])); - grad_outputs->push_back(NoGradient()); - grad_outputs->push_back(NoGradient()); +Status QuantizeAndDequantizeV4GradHelper(const Scope& scope, + const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + Input input = Shape(scope, op.input(0)); + Input input_min = op.input(1); + Input input_max = op.input(2); + int64 axis; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis)); + auto qdq_v4_grad = QuantizeAndDequantizeV4Grad( + scope, grad_inputs[0], input, input_min, input_max, + QuantizeAndDequantizeV4Grad::Axis(axis)); + grad_outputs->push_back(qdq_v4_grad.input_backprop); + grad_outputs->push_back(qdq_v4_grad.input_min_backprop); + grad_outputs->push_back(qdq_v4_grad.input_max_backprop); return scope.status(); } -REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad); +REGISTER_GRADIENT_OP("QuantizeAndDequantizeV4", + QuantizeAndDequantizeV4GradHelper); Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 243f86ed787..056c99eed8e 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -21,10 +21,7 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files([ - "LICENSE", - "loader.h", -]) +exports_files(["loader.h"]) cc_library( name = "constants", @@ -45,13 +42,15 @@ cc_library( name 
= "reader", srcs = ["reader.cc"], hdrs = ["reader.h"], - deps = [":constants"] + if_not_mobile([ + deps = [ + ":constants", + "//tensorflow/core:protos_all_cc", + ] + if_not_mobile([ # TODO(b/111634734): :lib and :protos_all contain dependencies that # cannot be built on mobile platforms. Instead, include the appropriate # tf_lib depending on the build platform. "@com_google_absl//absl/memory:memory", "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", ]), ) diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD index 73d84483199..e8e128f9a16 100644 --- a/tensorflow/cc/tools/BUILD +++ b/tensorflow/cc/tools/BUILD @@ -12,8 +12,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - cc_library( name = "freeze_saved_model", srcs = ["freeze_saved_model.cc"], diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 3552f96578c..4a41caf1d40 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -75,7 +75,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", "@llvm-project//llvm:X86CodeGen", # fixdeps: keep - "//tensorflow/core:regexp_internal", + "//tensorflow/core/platform:regexp", ] + if_llvm_system_z_available([ "@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep ]) + if_llvm_aarch64_available([ diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 95d20805ab9..06745de647b 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -336,9 +336,9 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo_profile_printer", "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:regexp", "//third_party/eigen3", "@com_google_absl//absl/strings", ], @@ -559,9 +559,9 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo_profile_printer", "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:regexp", "//third_party/eigen3", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 1a61930a23d..742cb308b3c 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -127,7 +127,7 @@ def tf_library( "$(location " + tfcompile_tool + ")" + " --config=$(location " + config + ")" + " --dump_fetch_nodes > $@"), - tools = [tfcompile_tool], + exec_tools = [tfcompile_tool], # Run tfcompile on the build host, rather than forge, since it's # typically way faster on the local machine. 
local = 1, @@ -242,7 +242,7 @@ def tf_library( " --out_function_object=$(@D)/" + function_object_file + " " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag ), - tools = [tfcompile_tool], + exec_tools = [tfcompile_tool], visibility = visibility, testonly = testonly, # Run tfcompile on the build host since it's typically faster on the @@ -281,7 +281,7 @@ def tf_library( " --out_session_module=$(@D)/" + session_module_pb + " " + flags ), - tools = [tfcompile_tool], + exec_tools = [tfcompile_tool], visibility = visibility, testonly = testonly, local = 1, diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index f6d376e01f4..7d87b5f0715 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -4,7 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_test") # buildifier: disable=same-origin-load -load("//tensorflow:tensorflow.bzl", "if_tpu", "tf_copts") +load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts") load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm") # buildifier: disable=same-origin-load @@ -77,7 +77,7 @@ cc_library( "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", - ] + if_tpu( + ] + if_libtpu( if_false = ["//tensorflow/compiler/xla/service:cpu_plugin"], if_true = [], ), @@ -114,7 +114,7 @@ cc_library( "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", - ] + if_tpu( + ] + if_libtpu( if_false = [ "//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep ], @@ -141,7 +141,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core/common_runtime/gpu:gpu_init", - ] + if_tpu( + ] + if_libtpu( if_false = [ "//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep ], @@ -204,7 +204,7 @@ XLA_DEVICE_DEPS = [ "//tensorflow/core:resource_variable_ops_op_lib", "//tensorflow/core:sendrecv_ops_op_lib", "//tensorflow/core:state_ops_op_lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:fifo_queue", "//tensorflow/core/kernels:function_ops", @@ -375,7 +375,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:logging", - ] + if_tpu( + ] + if_libtpu( if_false = [ "//tensorflow/compiler/mlir:array_container_utils", "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", @@ -435,6 +435,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/common_runtime:core_cpu_internal", + "//tensorflow/core/common_runtime/eager:tensor_handle", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -1022,10 +1023,10 @@ tf_cc_test( "//tensorflow/cc:ops", "//tensorflow/core:all_kernels", "//tensorflow/core:core_cpu", - "//tensorflow/core:direct_session_internal", "//tensorflow/core:framework", "//tensorflow/core:ops", "//tensorflow/core:test", + "//tensorflow/core/common_runtime:direct_session_internal", "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:matmul_op", "//tensorflow/core/kernels:partitioned_function_ops", diff --git a/tensorflow/compiler/jit/compilability_check_util.cc 
b/tensorflow/compiler/jit/compilability_check_util.cc index 20efbe248d7..62e121420c3 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -84,6 +84,23 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name, return Status::OK(); } +xla::StatusOr> MakeCallNodesFromAttribute( + const Node& node, absl::string_view attr_name, + absl::string_view call_name) { + std::vector attr_lists; + TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &attr_lists)); + + std::vector out; + for (int i = 0; i < attr_lists.size(); i++) { + out.emplace_back(); + NodeDef& inserted = out.back(); + inserted.set_name(absl::StrCat(call_name, "_", i)); + inserted.set_op(attr_lists[i].name()); + *inserted.mutable_attr() = attr_lists[i].attr(); + } + return out; +} + // Utility which searches for values in a sorted list by scanning over it once. // No matter how many times ScanForValue is called, the list is scanned at most // once. However, if a call to ScanForValue skips over a value, that value is @@ -227,6 +244,30 @@ bool RecursiveCompilabilityChecker::IsCompilableIf( return is_compilable; } +bool RecursiveCompilabilityChecker::IsCompilableCase( + const Node& case_node, FunctionLibraryRuntime* lib_runtime, + std::vector* stack_trace, + NameAttrList* encapsulating_function, + RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes) + const { + xla::StatusOr> calls = + MakeCallNodesFromAttribute(case_node, "branches", "branch"); + if (!calls.ok()) { + VLOG(2) << "Rejecting node " << case_node.name() << ": " + << "missing attribute 'branches'"; + return false; + } + + bool is_compilable = true; + + for (const NodeDef& call : *calls) { + is_compilable &= + IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function, + uncompilable_nodes); + } + return is_compilable; +} + // Tests whether 'while_node' is a completely compilable loop. // Every operator in the condition and body functions must be compilable for a // while loop to be compilable. @@ -417,6 +458,13 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( return false; } + if (op_filter_.require_always_compilable && node.IsCaseNode() && + !IsCompilableCase(node, lib_runtime, stack_trace, encapsulating_function, + uncompilable_nodes)) { + LogNotCompilable(node, "unsupported case"); + return false; + } + if (!op_filter_.allow_stateful_rng_ops && IsStatefulRandomOp(node.type_string())) { absl::string_view uncompilable_reason = "stateful random op"; diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 3c1378bf764..65da072483b 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -124,6 +124,10 @@ class RecursiveCompilabilityChecker { // Whether ops known to have numerical accuracy issues should be considered // compilable.. bool allow_inaccurate_ops = false; + + // Require the function to be always compilable, regardless whether some + // control flow branches might be dead for a given input. + bool require_always_compilable = false; }; RecursiveCompilabilityChecker(OperationFilter op_filter, @@ -211,6 +215,14 @@ class RecursiveCompilabilityChecker { NameAttrList* encapsulating_function, UncompilableNodesMap* uncompilable_nodes) const; + // Tests whether 'case_node' is compilable. Every operator in all branches + // must be compilable. 
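+  // Returns false if the 'branches' attribute cannot be read from +  // 'case_node'.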
+ bool IsCompilableCase(const Node& case_node, + FunctionLibraryRuntime* lib_runtime, + std::vector* stack_trace, + NameAttrList* encapsulating_function, + UncompilableNodesMap* uncompilable_nodes) const; + // Returns compilability of node def retrieved from `node`'s attribute with // name `attr_name`. bool ExtractNodeDefAndCheckCompilability( diff --git a/tensorflow/compiler/jit/compilability_check_util_test.cc b/tensorflow/compiler/jit/compilability_check_util_test.cc index 3851c66ba1a..9058b129589 100644 --- a/tensorflow/compiler/jit/compilability_check_util_test.cc +++ b/tensorflow/compiler/jit/compilability_check_util_test.cc @@ -34,7 +34,16 @@ limitations under the License. namespace tensorflow { namespace { +AttrValue FuncListAttr(const absl::Span names) { + AttrValue attr; + for (const char* name : names) { + attr.mutable_list()->add_func()->set_name(name); + } + return attr; +} + constexpr char kFunctionalIfNodeName[] = "If"; +constexpr char kFunctionalCaseNodeName[] = "Case"; constexpr char kFunctionalWhileNodeName[] = "While"; constexpr char kCompilableFunctionName[] = "CompilableFn"; constexpr char kCompilableFunctionNodeName[] = "n_c"; @@ -76,8 +85,12 @@ class CompilabilityCheckUtilTest : public ::testing::Test { op_filter_.allow_inaccurate_ops = false; op_filter_.allow_slow_ops = false; - checker_ = absl::make_unique(op_filter_, - device_type_); + checker_ = CreateCompilabilityChecker(); + } + + std::unique_ptr CreateCompilabilityChecker() { + return absl::make_unique(op_filter_, + device_type_); } FunctionLibraryRuntime* GetFunctionLibraryRuntime() { @@ -355,6 +368,57 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) { "unsupported op")); } +TEST_F(CompilabilityCheckUtilTest, CheckFunctionalCaseNode) { + FunctionDefLibrary flib; + *flib.add_function() = FunctionDefHelper::Define( + /*Function*/ kUncompilableFunctionName, + /*Inputs*/ {"n_a:float"}, + /*Outputs*/ {"n_c_uncompilable:float"}, + /*Attributes*/ {}, + // Node info + {{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}}); + *flib.add_function() = FunctionDefHelper::Define( + /*Function*/ kUncompilableFunctionTwoName, + /*Inputs*/ {"n_a:float"}, + /*Outputs*/ {"n_d_uncompilable:float"}, + /*Attribute*/ {}, + // Node info + {{{kUncompilableFunctionNodeTwoName}, "MissingKernel", {"n_a"}}}); + + Scope root = Scope::NewRootScope().ExitOnError(); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib)); + auto branch_index = ops::Placeholder(root.WithOpName("pred"), DT_INT32); + auto placeholder = ops::Placeholder(root.WithOpName("A"), DT_INT32); + std::vector inputes( + {NodeBuilder::NodeOut(placeholder.node())}); + Node* case_node; + TF_ASSERT_OK( + NodeBuilder(kFunctionalCaseNodeName, "Case", &root.graph()->flib_def()) + .Input(branch_index.node()) + .Input(inputes) + .Attr("branches", FuncListAttr({kUncompilableFunctionName, + kUncompilableFunctionTwoName})) + .Attr("Tout", {DT_INT32}) + .Finalize(root.graph(), &case_node)); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib)); + + auto case_node_it = std::find_if( + graph->nodes().begin(), graph->nodes().end(), + [&](const Node* n) { return n->name() == kFunctionalCaseNodeName; }); + EXPECT_NE(case_node_it, graph->nodes().end()); + auto* flib_runtime = GetFunctionLibraryRuntime(); + + op_filter_.require_always_compilable = false; + checker_ = CreateCompilabilityChecker(); + 
EXPECT_TRUE(checker_->IsCompilableNode(**case_node_it, flib_runtime)); + op_filter_.require_always_compilable = true; + checker_ = CreateCompilabilityChecker(); + EXPECT_FALSE(checker_->IsCompilableNode(**case_node_it, flib_runtime)); +} + TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) { GraphDefBuilder b(GraphDefBuilder::kFailImmediately); Scope root = Scope::NewRootScope().ExitOnError(); diff --git a/tensorflow/compiler/jit/get_compiler_ir.cc b/tensorflow/compiler/jit/get_compiler_ir.cc index 21c3194eadc..08b3bea1084 100644 --- a/tensorflow/compiler/jit/get_compiler_ir.cc +++ b/tensorflow/compiler/jit/get_compiler_ir.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/lib/core/status.h" @@ -47,8 +48,8 @@ static xla::StatusOr GetLocalExecutable( xla::StatusOr GetCompilerIr( IrExportStage stage, ProcessFunctionLibraryRuntime* pflr, - absl::string_view func_name, Device* dev, - absl::Span inputs) { + absl::string_view func_name, Device* dev, EagerContext* context, + absl::Span inputs_handles) { NameAttrList function; function.set_name(std::string{func_name}); @@ -65,6 +66,25 @@ xla::StatusOr GetCompilerIr( GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices); MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody); + std::deque inputs_storage; + std::vector inputs; + inputs.reserve(inputs_handles.size()); + for (int i = 0; i < inputs_handles.size(); i++) { + const TensorHandle* th = inputs_handles[i]; + const Tensor* t; + // Handle owns the tensor. + TF_RETURN_IF_ERROR(th->Tensor(&t)); + if (absl::c_binary_search(constant_arg_indices, i)) { + // Need to make sure it's on the host. + inputs_storage.emplace_back(t->dtype(), t->shape()); + TF_RETURN_IF_ERROR( + th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back())); + inputs.push_back(&inputs_storage.back()); + } else { + inputs.push_back(t); + } + } + std::vector variable_infos; TF_RETURN_IF_ERROR(GetVariableInfosFromInputs( rmgr, dev, inputs, resource_arg_indices, &variable_infos)); diff --git a/tensorflow/compiler/jit/get_compiler_ir.h b/tensorflow/compiler/jit/get_compiler_ir.h index 0ec69fd9b38..0a0a1a44271 100644 --- a/tensorflow/compiler/jit/get_compiler_ir.h +++ b/tensorflow/compiler/jit/get_compiler_ir.h @@ -24,6 +24,8 @@ namespace tensorflow { class ProcessFunctionLibraryRuntime; class Device; class Tensor; +class TensorHandle; +class EagerContext; enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT }; @@ -31,8 +33,8 @@ enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT }; // `runtime` on a device `dev` with given `inputs`. 
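+// Inputs that feed compile-time constant arguments are copied to host memory +// before compilation; the remaining input handles are used as-is.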
xla::StatusOr GetCompilerIr( IrExportStage stage, ProcessFunctionLibraryRuntime* pflr, - absl::string_view func_name, Device* dev, - absl::Span inputs); + absl::string_view func_name, Device* dev, EagerContext* context, + absl::Span inputs); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index be9e1d31f3b..1f400137f5b 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -34,7 +34,7 @@ XLA_OPS_DEPS = [ "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:state_ops_op_lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor:tf_allocator_adapter", ] diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 41f80203353..ada7766fcbb 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1196,10 +1196,14 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() { continue; } - if (!RecursiveCompilabilityChecker{ - CreateOperationFilter(*registration), - DeviceType{registration->compilation_device_name}} - .IsCompilableNode(*node, lib_runtime)) { + RecursiveCompilabilityChecker::OperationFilter filter = + CreateOperationFilter(*registration); + filter.require_always_compilable = true; + + RecursiveCompilabilityChecker checker( + filter, DeviceType{registration->compilation_device_name}); + + if (!checker.IsCompilableNode(*node, lib_runtime)) { continue; } @@ -2062,6 +2066,7 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "XlaSpmdFullToShardShape", "XlaSpmdShardToFullShape", "XlaSvd", + "XlaVariadicReduce", "XlaWhile", "Zeta", "_Arg", diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 3ab3c19e439..d7d5ee02265 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -47,7 +47,7 @@ limitations under the License. #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/dump_graph.h" -#if !defined(LIBTFTPU) +#if !defined(LIBTPU_ON_GCE) #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "tensorflow/compiler/mlir/utils/array_container_utils.h" #endif @@ -289,7 +289,7 @@ Status XlaCompilationCache::CompileSingleOp( }); const ConfigProto* config = ctx->function_library()->config_proto(); bool use_mlir = config && config->experimental().enable_mlir_bridge(); -#ifdef LIBTFTPU +#ifdef LIBTPU_ON_GCE if (use_mlir && has_tensor_list_arg) { LOG(WARNING) << "MLIR is not supported in this environment."; } @@ -303,8 +303,12 @@ Status XlaCompilationCache::CompileSingleOp( } GraphDebugInfo debug_info; + std::vector control_rets; + if (result_dtypes.empty()) { + control_rets.push_back(node_def.name()); + } return CompileGraphToXlaHlo( - *graph, mlir::SpanToArrayRef(args), + *graph, mlir::SpanToArrayRef(args), control_rets, options.device_type.type_string(), compile_options.use_tuple_arg, *options.flib_def, debug_info, options.shape_representation_fn, result); #endif diff --git a/tensorflow/compiler/mlir/README.md b/tensorflow/compiler/mlir/README.md index cbb0b08503a..c415edceb8c 100644 --- a/tensorflow/compiler/mlir/README.md +++ b/tensorflow/compiler/mlir/README.md @@ -9,3 +9,31 @@ dialects and utilities for 3. 
TF Lite See [MLIR's website](https://mlir.llvm.org) for complete documentation. + +## Getting started + +Building the dialects and utilities here follows the standard `bazel` approach +used by the rest of TensorFlow. + +### Using local LLVM repo + +To develop across MLIR core and TensorFlow, it is useful to override the repo +to use a local version instead of fetching from head. This can be achieved as +shown below. Note that the BUILD files are not generated automatically (nor is +CMake used), so if your change requires a BUILD file change, or if you are using +a different version of LLVM than the LLVM_COMMIT set in tensorflow/workspace.bzl, +then manual BUILD file changes may be required. + +```sh +LLVM_SRC=... + +# Create basic workspace file +echo 'workspace(name = "llvm-project")' > $LLVM_SRC/WORKSPACE +# and copy over the bazel BUILD files. +cp third_party/llvm/llvm.autogenerated.BUILD $LLVM_SRC/llvm/BUILD +cp third_party/mlir/BUILD $LLVM_SRC/mlir +cp third_party/mlir/test.BUILD $LLVM_SRC/mlir/test/BUILD + +bazel build --override_repository=llvm-project=$LLVM_SRC \ + -c opt tensorflow/compiler/mlir:tf-opt +``` diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD index 7550c7e7ba6..30e8d8a86a7 100644 --- a/tensorflow/compiler/mlir/hlo/BUILD +++ b/tensorflow/compiler/mlir/hlo/BUILD @@ -48,6 +48,7 @@ filegroup( "include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td", "include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", @@ -539,6 +540,8 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Shape", + "@llvm-project//mlir:ShapeTransforms", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index 92ae4d29c25..13d5f02368b 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -360,6 +360,19 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [], }]; } +def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [], + HLO_FpOrComplexTensor> { + let summary = "Atan operator"; + + let description = [{ + Returns `Atan(operand)` element-wise.
+ + $$ + \atan(x) = \atan2(x, 1) + $$ + }]; +} + def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [], HLO_FpOrComplexTensor> { let summary = "Sinh operation"; diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index d545c2a1c04..507f7c11d63 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -157,6 +157,9 @@ def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", >]; } +def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt", + [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CbrtOp; + def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil", [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CeilOp; @@ -193,12 +196,10 @@ def HLO_Expm1Op: HLO_UnaryElementwiseOp<"exponential_minus_one", def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor", [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp; -def HLO_ImagOp: HLO_Op< - "imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp { - let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value val">]; - - let arguments = (ins HLO_ComplexTensor); +def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag", + [NoSideEffect, SameOperandsAndResultShape, + DeclareOpInterfaceMethods], + HLO_ComplexTensor>, BASE_HLO_ImagOp { let results = (outs HLO_FpTensor); let hasFolder = 1; } @@ -237,12 +238,10 @@ def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt", [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>, BASE_HLO_PopulationCountOp; -def HLO_RealOp: HLO_Op< - "real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp { - let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value val">]; - - let arguments = (ins HLO_ComplexTensor); +def HLO_RealOp: HLO_UnaryElementwiseOp<"real", + [NoSideEffect, SameOperandsAndResultShape, + DeclareOpInterfaceMethods], + HLO_ComplexTensor>, BASE_HLO_RealOp { let results = (outs HLO_FpTensor); let hasFolder = 1; } @@ -321,12 +320,10 @@ def HLO_AddOp : HLO_BinaryElementwiseOp<"add", def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op; -def HLO_ComplexOp: HLO_Op<"complex", - [NoSideEffect, SameOperandsAndResultShape]>, +def HLO_ComplexOp: HLO_BinaryElementwiseOp<"complex", + [NoSideEffect, SameOperandsAndResultShape, + DeclareOpInterfaceMethods]>, BASE_HLO_ComplexOp { - let builders = [OpBuilder< - "OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">]; - let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs); let results = (outs HLO_ComplexTensor); let hasFolder = 1; @@ -356,7 +353,9 @@ def HLO_PowOp : HLO_BinaryElementwiseOp<"power", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp; def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder", - [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp { + let hasFolder = 1; +} def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp; @@ -913,39 +912,12 @@ def HLO_CollectivePermuteOp: HLO_Op<"collective_permute", let results = (outs HLO_Tensor); } -// TODO(hinsu): Make this struct dialect independent so that it can be shared -// between HLO and LHLO dialect. 
-def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [ - StructFieldAttr<"input_batch_dimension",I64Attr>, - StructFieldAttr<"input_feature_dimension", I64Attr>, - StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"output_batch_dimension", I64Attr>, - StructFieldAttr<"output_feature_dimension", I64Attr>, - StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { - - let description = "Structure of dimension information for conv op"; -} - def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp { - let arguments = (ins - HLO_Tensor:$lhs, - HLO_Tensor:$rhs, - // Default value: one for each of the spatial dimension. - OptionalAttr:$window_strides, - // Default value: zero for each of the spatial dimension. - OptionalAttr:$padding, - // Default value: one for each of the spatial dimension. - OptionalAttr:$lhs_dilation, - // Default value: one for each of the spatial dimension. - OptionalAttr:$rhs_dilation, - ConvDimensionNumbers:$dimension_numbers, - I64Attr:$feature_group_count, - I64Attr:$batch_group_count, - HLO_PrecisionConfigAttr:$precision_config - ); + let arguments = !con( + (ins + HLO_Tensor:$lhs, + HLO_Tensor:$rhs), + ConvolutionAttributes.attributes); let results = (outs HLO_Tensor); } @@ -1198,14 +1170,14 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>, let results = (outs HLO_Tensor); } -def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects]>, BASE_HLO_SortOp { +def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp { let arguments = (ins Variadic:$operands, DefaultValuedAttr:$dimension, DefaultValuedAttr:$is_stable ); - let results = (outs HLO_TensorOrTuple); + let results = (outs Variadic); let regions = (region SizedRegion<1>:$comparator); @@ -1429,4 +1401,21 @@ def HLO_FusionOp : HLO_Op<"fusion", []> { let hasCustomHLOConverter = 1; } +// This is an op for purposes internal to XLA/GPU. +def HLO_BitcastOp : HLO_Op<"bitcast", [NoSideEffect]>, BASE_HLO_BitcastOp { + let arguments = (ins HLO_Tensor:$operand); + let results = (outs HLO_Tensor); + let hasCustomHLOConverter = 1; +} + +def HLO_ReducePrecisionOp: HLO_Op<"reduce_precision", [SameOperandsAndResultShape]>, + BASE_HLO_ReducePrecisionOp { + let arguments = (ins + HLO_FpTensor:$operand, + I32Attr:$exponent_bits, + I32Attr:$mantissa_bits + ); + let results = (outs HLO_FpTensor:$output); +} + #endif // HLO_OPS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index b8378fed01a..cba2dc370f0 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -127,6 +127,17 @@ class BASE_HLO_AbsOp { }]; } +class BASE_HLO_CbrtOp { + string summary = "Cubic root operator"; + + string description = [{ + Returns element-wise cubic root of the operand. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. 
+ }]; +} + class BASE_HLO_CeilOp { string summary = "Ceil operator"; @@ -996,6 +1007,42 @@ class BASE_HLO_ConcatenateOp { }]; } +//===----------------------------------------------------------------------===// +// Common convolution attributes +//===----------------------------------------------------------------------===// + +class ConvDimensionNumbersBase + : StructAttr<"ConvDimensionNumbers", dialect, [ + StructFieldAttr<"input_batch_dimension",I64Attr>, + StructFieldAttr<"input_feature_dimension", I64Attr>, + StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, + StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, + StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, + StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, + StructFieldAttr<"output_batch_dimension", I64Attr>, + StructFieldAttr<"output_feature_dimension", I64Attr>, + StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { + + let description = "Structure of dimension information for conv op"; +} + +class ConvolutionAttributes { + dag attributes = (ins + // Default value: one for each of the spatial dimension. + OptionalAttr:$window_strides, + // Default value: zero for each of the spatial dimension. + OptionalAttr:$padding, + // Default value: one for each of the spatial dimension. + OptionalAttr:$lhs_dilation, + // Default value: one for each of the spatial dimension. + OptionalAttr:$rhs_dilation, + ConvDimensionNumbersBase:$dimension_numbers, + I64Attr:$feature_group_count, + I64Attr:$batch_group_count, + HLO_PrecisionConfigAttr:$precision_config + ); +} + class BASE_HLO_ConvOp { string summary = "Convolution operator"; @@ -1336,4 +1383,17 @@ class BASE_HLO_WhileOp { }]; } +class BASE_HLO_BitcastOp { + string summary = "Bitcast operator"; + + string description = [{ + This op changes the shape of the input in such a way that the physical + arrangement of elements is unchanged. + + However, the op needs layout information to make sense of "physical + arrangement of elements". Layout support in MHLO is currently under + exploration. + }]; +} + #endif // HLO_OPS_BASE diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 37757fa0033..c013939c544 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -37,38 +37,13 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td" def LHLO_Dialect : Dialect { let name = "lmhlo"; let cppNamespace = "::mlir::lmhlo"; } -//===----------------------------------------------------------------------===// -// LMHLO type definitions.
-//===----------------------------------------------------------------------===// - -// Any integer tensor types -def LHLO_IntBuffer : MemRefOf<[HLO_Int]>; - -// Any floating-point tensor types -def LHLO_FpBuffer : MemRefOf<[AnyFloat]>; - -def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>; - -def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>; - -def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>; - -// Any integer or floating-point tensor types -def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>; - -def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>; - -def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; - -def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>; - //===----------------------------------------------------------------------===// // LMHLO nullary op definitions. //===----------------------------------------------------------------------===// @@ -289,6 +264,16 @@ def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>, let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body); } +def LHLO_CustomCallOp : LHLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp { + let arguments = (ins + Arg, "", [MemRead]>:$args, + Arg:$output, + StrAttr:$call_target_name, + DefaultValuedAttr:$has_side_effect, + DefaultValuedAttr:$backend_config + ); +} + //===----------------------------------------------------------------------===// // LMHLO tuple op definitions. //===----------------------------------------------------------------------===// @@ -335,10 +320,11 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> { def HLO_StaticMemRefCastOp: Op]> { let summary = [{ - "modifies the offset, sizes and strides of a statically shaped memref. + modifies the offset, sizes and strides of a statically shaped memref }]; let description = [{ - Allows to modify the offset, sizes and strides of a statically shaped memref. + Casts the statically shaped memref operand to a memref with optionally + modified offsets, sizes and strides. Example: ```mlir @@ -354,12 +340,11 @@ def HLO_StaticMemRefCastOp: Op:$operand); let results = (outs Res:$result); - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, MemRefType resultType, " # - "Value operand", [{ - result.addOperands(operand); - result.types.push_back(resultType); - }]>]; + let builders = [OpBuilder<"MemRefType resultType, Value operand", + [{ + $_state.addOperands(operand); + $_state.types.push_back(resultType); + }]>]; let extraClassDeclaration = [{ MemRefType getType() { return getResult().getType().cast(); } @@ -400,13 +385,13 @@ def HLO_DynamicMemRefCastOp: Op:$result); - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, MemRefType resultType, " # - "Value operand, ValueRange sizes, ValueRange strides", [{ - result.addOperands(operand); - result.addOperands(sizes); - result.addOperands(strides); - result.types.push_back(resultType); + let builders = [ + OpBuilder<"MemRefType resultType, Value operand, ValueRange sizes, " + "ValueRange strides", [{ + $_state.addOperands(operand); + $_state.addOperands(sizes); + $_state.addOperands(strides); + $_state.types.push_back(resultType); }]>]; let extraClassDeclaration = [{ @@ -582,40 +567,13 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp { ); } -// TODO(bondhugula): Make this struct dialect independent so that it can be -// shared between the HLO and LHLO dialects. 
-def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [ - StructFieldAttr<"input_batch_dimension",I64Attr>, - StructFieldAttr<"input_feature_dimension", I64Attr>, - StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, - StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, - StructFieldAttr<"output_batch_dimension", I64Attr>, - StructFieldAttr<"output_feature_dimension", I64Attr>, - StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { - - let description = "Structure of dimension information for conv op"; -} - def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp { - let arguments = (ins - Arg:$lhs, - Arg:$rhs, - Arg:$output, - // Default value: one for each of the spatial dimension. - OptionalAttr:$window_strides, - // Default value: zero for each of the spatial dimension. - OptionalAttr:$padding, - // Default value: one for each of the spatial dimension. - OptionalAttr:$lhs_dilation, - // Default value: one for each of the spatial dimension. - OptionalAttr:$rhs_dilation, - ConvDimensionNumbers:$dimension_numbers, - I64Attr:$feature_group_count, - I64Attr:$batch_group_count, - HLO_PrecisionConfigAttr:$precision_config - ); + let arguments = !con( + (ins + Arg:$lhs, + Arg:$rhs, + Arg:$output), + ConvolutionAttributes.attributes); } def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp { @@ -856,9 +814,8 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">] let skipDefaultBuilders = 1; let builders = [ - OpBuilder<"OpBuilder &builder, OperationState &result, " - "ArrayRef attributes"> - ]; + OpBuilder<"ArrayRef attributes"> + ]; } def TerminatorOp : @@ -867,9 +824,8 @@ def TerminatorOp : let description = [{ Terminator operation for the LHLO dialect. }]; - let builders = [OpBuilder< - "OpBuilder &b, OperationState &result, ValueRange operands", - [{ build(b, result, llvm::None, operands, llvm::None); }] + let builders = [OpBuilder<"ValueRange operands", + [{ build($_builder, $_state, llvm::None, operands, llvm::None); }] >]; } diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td new file mode 100644 index 00000000000..9cd77417ffd --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef LHLO_OPS_BASE +#define LHLO_OPS_BASE + +include "mlir/IR/OpBase.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" + +//===----------------------------------------------------------------------===// +// LMHLO type definitions. 
+//===----------------------------------------------------------------------===// + +// Any integer tensor types +def LHLO_IntBuffer : MemRefOf<[HLO_Int]>; + +// Any floating-point tensor types +def LHLO_FpBuffer : MemRefOf<[AnyFloat]>; + +def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>; + +def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>; + +def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>; + +// Any integer or floating-point tensor types +def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>; + +def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>; + +def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; + +def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>; + +#endif // LHLO_OPS_BASE diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h index d2621759213..ac67619e6e3 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h @@ -50,6 +50,7 @@ MAP_HLO_TO_LHLO(ConvOp); MAP_HLO_TO_LHLO(ConvertOp); MAP_HLO_TO_LHLO(CopyOp); MAP_HLO_TO_LHLO(CosOp); +MAP_HLO_TO_LHLO(CustomCallOp); MAP_HLO_TO_LHLO(DivOp); MAP_HLO_TO_LHLO(DotOp); MAP_HLO_TO_LHLO(ExpOp); @@ -57,11 +58,13 @@ MAP_HLO_TO_LHLO(FloorOp); MAP_HLO_TO_LHLO(GatherOp); MAP_HLO_TO_LHLO(ImagOp); MAP_HLO_TO_LHLO(IotaOp); +MAP_HLO_TO_LHLO(IsFiniteOp); MAP_HLO_TO_LHLO(LogOp); MAP_HLO_TO_LHLO(MaxOp); MAP_HLO_TO_LHLO(MinOp); MAP_HLO_TO_LHLO(MulOp); MAP_HLO_TO_LHLO(NegOp); +MAP_HLO_TO_LHLO(NotOp); MAP_HLO_TO_LHLO(RealOp); MAP_HLO_TO_LHLO(ReduceOp); MAP_HLO_TO_LHLO(ReshapeOp); diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 1199dae1ab2..9cf1b6ce57e 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -149,6 +149,15 @@ inline Value MapLhloOpToStdScalarOp(Location loc, loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + template inline Optional getCmpPredicate(StringRef comparison_direction) { return llvm::None; @@ -345,6 +354,22 @@ inline Value MapLhloOpToStdScalarOp(Location loc, loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + if (args[0].getType().isa()) { + auto pos_inf = APFloat::getInf( + args[0].getType().cast().getFloatSemantics()); + auto const_pos_inf = + b->create(loc, b->getFloatAttr(args[0].getType(), pos_inf)); + Value abs_x = b->create<::mlir::AbsFOp>(loc, args[0]); + return b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, abs_x, + const_pos_inf); + } + return nullptr; +} + /// Implements the conversion of HLO op to scalar op (to use within region of a /// linalg.generic op) for compare-select style operations like min/max. 
template @@ -431,6 +456,21 @@ inline Value MapLhloOpToStdScalarOp(Location loc, return nullptr; } +template <> +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { + Type element_type = args.front().getType(); + if (auto integer_type = element_type.dyn_cast()) { + // lmhlo.not(x) -> x ^ -1 + auto all_ones = + b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth()); + return b->create<::mlir::XOrOp>(loc, all_ones, args[0]); + } + return nullptr; +} + template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, @@ -454,11 +494,27 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef args, OpBuilder* b) { Type element_type = args.front().getType(); - if (element_type.isa()) { - FloatType float_type = element_type.cast(); - APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0); - Value one = b->create(loc, const_value, float_type); + if (auto float_type = element_type.dyn_cast()) { + bool ignored; + APFloat one_apfloat(1.0f); + one_apfloat.convert(float_type.getFloatSemantics(), + APFloat::rmNearestTiesToEven, &ignored); + Value one = b->create(loc, one_apfloat, float_type); return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]); + } else if (auto integer_type = element_type.dyn_cast()) { + // sign(x) = x == 0 ? 0 : ((x s>> 31) | 1) + Value zero = + b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); + Value cmp = + b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero); + Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>( + loc, integer_type.getWidth() - 1, integer_type.getWidth()); + Value ashr = + b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one); + Value one = + b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth()); + Value or_op = b->create<::mlir::OrOp>(loc, ashr, one); + return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op); } return nullptr; } diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td index 6c623ec859d..4348464fa74 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td @@ -15,9 +15,9 @@ limitations under the License. include "mlir/Pass/PassBase.td" -def TestChloLegalizeToHloPass : Pass<"mhlo-test-chlo-legalize-to-hlo", "FuncOp"> { - let summary = "Test pass for applying chlo -> hlo legalization patterns."; - let constructor = "createTestChloLegalizeToHloPass()"; +def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "FuncOp"> { + let summary = "Legalize CHLO to HLO."; + let constructor = "createChloLegalizeToHloPass()"; } def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> { diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h index 2c0735a33fa..b1933f6686b 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h @@ -44,6 +44,9 @@ std::unique_ptr> createControlFlowToScfPass(); /// Lowers from HLO dialect to Standard dialect. std::unique_ptr> createLegalizeToStdPass(); +/// Lowers from the CHLO dialect to the HLO dialect. 
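+/// CHLO client ops (for example composite unary ops such as `chlo.acos` and +/// `chlo.atan`) are rewritten in terms of MHLO ops by this pass.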
+std::unique_ptr createChloLegalizeToHloPass(); + /// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary /// buffers if necessary. If `results_escape_functions` is set to true, /// allocated buffers for function results will be returned and escape the @@ -63,7 +66,7 @@ std::unique_ptr> createSinkConstantsToControlFlowPass(); std::unique_ptr> createMhloFusionPass(); /// Lowers trigonometric operations from the standard dialect to approximations -// that do not use intrinsics. +/// that do not use intrinsics. std::unique_ptr> createLegalizeTrigonometricToApproximationPass(); diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h index 8f70f64359b..e9418f0e7a0 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h @@ -22,7 +22,6 @@ limitations under the License. namespace mlir { namespace mhlo { -std::unique_ptr createTestChloLegalizeToHloPass(); std::unique_ptr createTestInferShapedTypeMethodsPass(); std::unique_ptr createTestMaterializeBroadcastsPass(); std::unique_ptr createTestUnfuseBatchNormPass(); diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index b0a382d6b0f..a58d0c6304c 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -20,7 +20,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/BufferPlacement.h" +#include "mlir/Transforms/Bufferize.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc index 1ebec669ea5..0a5bb0e018a 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -185,8 +185,7 @@ struct GatherSlice : public OpRewritePattern { return failure(); const auto& dnums = gather.dimension_numbers(); - if (dnums.collapsed_slice_dims().getNumElements() != 0 || - dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1) + if (dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1) return failure(); // TODO(tberghammer): Remove when the verifier catches this case what is @@ -206,11 +205,35 @@ struct GatherSlice : public OpRewritePattern { } llvm::SmallVector slice_stride(slice_end.size(), 1); - rewriter.replaceOpWithNewOp( - gather, gather.getType(), gather.getOperand(0), + llvm::SmallVector slice_shape(slice_end.size()); + for (int64_t i = 0; i < slice_end.size(); ++i) { + slice_shape[i] = slice_end[i] - slice_start[i]; + } + Type element_type = gather.getType().cast().getElementType(); + auto slice_type = RankedTensorType::get(slice_shape, element_type); + Value result = rewriter.create( + gather.getLoc(), slice_type, gather.getOperand(0), GetI64ElementsAttr(slice_start, &rewriter), GetI64ElementsAttr(slice_end, &rewriter), GetI64ElementsAttr(slice_stride, &rewriter)); + + if (dnums.collapsed_slice_dims().getNumElements() > 0) { + auto collapsed_slice_dims = 
llvm::to_vector<8>(llvm::map_range( + dnums.collapsed_slice_dims().getIntValues(), + [](const llvm::APInt& i) { return i.getSExtValue(); })); + llvm::SmallVector reshape_shape; + for (int64_t i = 0; i < slice_shape.size(); ++i) { + if (llvm::count(collapsed_slice_dims, i) == 0) { + reshape_shape.push_back(slice_shape[i]); + } + } + auto reshape_type = RankedTensorType::get(reshape_shape, element_type); + result = + rewriter.create(gather.getLoc(), reshape_type, result); + } + + result.setType(gather.getType()); + rewriter.replaceOp(gather, result); return success(); } }; @@ -889,9 +912,10 @@ static LogicalResult Verify(ClampOp op) { // ComplexOp //===----------------------------------------------------------------------===// -void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs, - Value rhs) { - auto type = lhs.getType(); +LogicalResult ComplexOp::inferReturnTypes( + MLIRContext*, Optional, ValueRange operands, DictionaryAttr, + RegionRange, SmallVectorImpl& inferredReturnTypes) { + auto type = operands[0].getType(); auto element_ty = ComplexType::get(getElementTypeOrSelf(type)); Type result_ty; if (auto ranked_type = type.dyn_cast()) { @@ -901,8 +925,8 @@ void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs, } else { result_ty = element_ty; } - - build(builder, state, result_ty, lhs, rhs); + inferredReturnTypes.push_back(result_ty); + return success(); } OpFoldResult ComplexOp::fold(ArrayRef operands) { @@ -932,8 +956,11 @@ Type CreateRealType(Type type) { } } // namespace -void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) { - build(builder, state, CreateRealType(val.getType()), val); +LogicalResult ImagOp::inferReturnTypes( + MLIRContext*, Optional, ValueRange operands, DictionaryAttr, + RegionRange, SmallVectorImpl& inferredReturnTypes) { + inferredReturnTypes.push_back(CreateRealType(operands[0].getType())); + return success(); } OpFoldResult ImagOp::fold(ArrayRef operands) { @@ -945,8 +972,11 @@ OpFoldResult ImagOp::fold(ArrayRef operands) { return {}; } -void RealOp::build(OpBuilder& builder, OperationState& state, Value val) { - build(builder, state, CreateRealType(val.getType()), val); +LogicalResult RealOp::inferReturnTypes( + MLIRContext*, Optional, ValueRange operands, DictionaryAttr, + RegionRange, SmallVectorImpl& inferredReturnTypes) { + inferredReturnTypes.push_back(CreateRealType(operands[0].getType())); + return success(); } OpFoldResult RealOp::fold(ArrayRef operands) { @@ -1971,6 +2001,23 @@ struct divide { APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); } }; +template +struct remainder : std::modulus {}; + +template <> +struct remainder { + APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); } +}; + +template <> +struct remainder { + APFloat operator()(const APFloat& a, const APFloat& b) const { + APFloat result(a); + result.remainder(b); + return result; + } +}; + template struct max { T operator()(const T& a, const T& b) const { return std::max(a, b); } @@ -2012,6 +2059,7 @@ BINARY_FOLDER(AddOp, std::plus); BINARY_FOLDER(SubOp, std::minus); BINARY_FOLDER(MulOp, std::multiplies); BINARY_FOLDER(DivOp, divide); +BINARY_FOLDER(RemOp, remainder); BINARY_FOLDER(MaxOp, max); BINARY_FOLDER(MinOp, min); @@ -2261,10 +2309,7 @@ void SortOp::build(OpBuilder& builder, OperationState& state, state.addAttribute("dimension", builder.getI64IntegerAttr(dimension)); state.addAttribute("is_stable", builder.getBoolAttr(dimension)); - SmallVector element_types; - 
element_types.reserve(operands.size()); - for (Value operand : operands) element_types.push_back(operand.getType()); - state.addTypes(builder.getTupleType(element_types)); + for (Value operand : operands) state.addTypes(operand.getType()); state.addRegion(); } diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index 626b5d3bd59..42d6d70b524 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -283,7 +283,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp auto if_op = rewriter.create( loc, result_type, IsScalarTensor(rewriter, op, lhs), true); OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder(); - Value reshaped_lhs = if_lhs_scalar_builder.create( + Value reshaped_lhs = if_lhs_scalar_builder.create( loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs); Value if_lhs_scalar_result = if_lhs_scalar_builder.create( loc, ArrayRef{result_type}, ArrayRef{reshaped_lhs, rhs}, @@ -300,7 +300,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp else_lhs_scalar_builder.create(loc, if_rhs_scalar_op.getResult(0)); OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder(); - Value reshaped_rhs = if_rhs_scalar_builder.create( + Value reshaped_rhs = if_rhs_scalar_builder.create( loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs); Value if_rhs_scalar_result = if_rhs_scalar_builder.create( loc, ArrayRef{result_type}, ArrayRef{lhs, reshaped_rhs}, @@ -516,7 +516,7 @@ struct HloCompareAdaptor { void PopulateLegalizeChloToHloPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { - populateWithGenerated(context, patterns); + populateWithGenerated(context, *patterns); // Instantiate conversion templates for conforming binary elementwise ops // that do not have different dtypes between operands and results and do diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc index 263b6cdd1c3..d2f415d91f9 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc @@ -27,8 +27,8 @@ namespace mhlo { namespace { -struct TestChloLegalizeToHloPass - : public PassWrapper { +struct ChloLegalizeToHloPass + : public PassWrapper { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } @@ -36,11 +36,12 @@ struct TestChloLegalizeToHloPass void runOnFunction() override { ConversionTarget conversionTarget(getContext()); OwningRewritePatternList conversionPatterns; - conversionTarget.addIllegalDialect(); + // Consider the mhlo dialect legal for tests. conversionTarget.addLegalDialect(); - // The conversion uses helpers from the Standard dialect. + + // The conversion uses helpers from the standard dialect. 
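+    // Dialects whose ops may be produced by the lowering patterns must be
+    // declared legal here, otherwise the partial conversion would reject the
+    // rewritten IR.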
conversionTarget.addLegalDialect(); conversionTarget.addLegalDialect(); conversionTarget.addLegalDialect(); @@ -56,8 +57,8 @@ struct TestChloLegalizeToHloPass } // namespace -std::unique_ptr createTestChloLegalizeToHloPass() { - return std::make_unique(); +std::unique_ptr createChloLegalizeToHloPass() { + return std::make_unique(); } } // namespace mhlo diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td index 9d9e6d978b2..a48abb6190c 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td @@ -24,16 +24,17 @@ include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td" //===----------------------------------------------------------------------===// // Expand acos to MHLO dialect as follows: -// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1 +// acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1 // = pi if x == -1 def : Pat<(HLOClient_AcosOp $input), (HLO_SelectOp - (HLO_CompareOp $input, - (HLO_ConstantLike<"0"> $input), + (HLO_CompareOp + $input, + (HLO_ConstantLike<"-1"> $input), HLO_COMPARISON_DIRECTION_NE ), (HLO_MulOp - (HLO_ConstantLike<"2.0f"> $input), + (HLO_ConstantLike<"2"> $input), (HLO_Atan2Op (HLO_SqrtOp (HLO_SubOp @@ -47,7 +48,16 @@ def : Pat<(HLOClient_AcosOp $input), ) ) ), - (HLO_ConstantLike<"M_PI"> $input))>; + (HLO_ConstantLike<"M_PI"> $input) + )>; + +// Express `atan` as +// atan(x) = atan2(x, 1) +def : Pat<(HLOClient_AtanOp $input), + (HLO_Atan2Op + $input, + (HLO_ConstantLike<"1"> $input) + )>; // Express `sinh` as // sinh(x) = (e^x - e^-x) / 2 if |x| < 1 @@ -95,4 +105,3 @@ def : Pat<(HLOClient_TanOp $input), (HLO_SinOp $input), (HLO_CosOp $input) )>; - diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 0f1a3d034eb..a2a940e2c58 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -20,6 +20,8 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Shape/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" @@ -32,7 +34,7 @@ limitations under the License. 
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/BufferPlacement.h" +#include "mlir/Transforms/Bufferize.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -45,7 +47,7 @@ using BaseOpConversion = BufferAssignmentOpConversionPattern; Value InsertDynamicAllocAndDealloc(Location loc, Value result, Value shape_operand, ConversionPatternRewriter* rewriter) { - auto result_type = result.getType().dyn_cast(); + auto result_type = result.getType().dyn_cast(); if (!result_type) { result.getDefiningOp()->emitOpError() << "tensor to buffer conversion expects ranked results"; @@ -53,17 +55,13 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result, auto memref_type = MemRefType::get(result_type.getShape(), result_type.getElementType()); - Operation* op = result.getDefiningOp(); - // Extract the required element out of the vector. SmallVector dynamic_operands; for (auto shape_element : llvm::enumerate(result_type.getShape())) { if (shape_element.value() != ShapedType::kDynamicSize) continue; - Value index = rewriter->create( - loc, rewriter->getIntegerAttr(rewriter->getIndexType(), - shape_element.index())); - Value alloc_operand = rewriter->create(loc, shape_operand, - ValueRange{index}); + Value index = rewriter->create(loc, shape_element.index()); + Value alloc_operand = + rewriter->create(loc, shape_operand, index); if (!alloc_operand.getType().isIndex()) { alloc_operand = rewriter->create(loc, alloc_operand, rewriter->getIndexType()); @@ -71,15 +69,12 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result, dynamic_operands.push_back(alloc_operand); } - // Insert in front of op to ensure sizes are available. - OpBuilder allocBuilder(op); - auto alloc = allocBuilder.create(loc, memref_type, dynamic_operands); - return alloc; + return rewriter->create(loc, memref_type, dynamic_operands); } Value InsertAlloc(Location loc, OpResult result, ConversionPatternRewriter* rewriter) { - auto result_type = result.getType().dyn_cast(); + auto result_type = result.getType().dyn_cast(); if (!result_type || !result_type.hasStaticShape()) { result.getDefiningOp()->emitOpError() << "tensor to buffer conversion expects statically shaped results"; @@ -112,19 +107,21 @@ class HloToLhloOpConverter : public BaseOpConversion { buffer_args.push_back( InsertAlloc(op->getLoc(), result.value(), &rewriter)); } else { - SmallVector results_shape; auto shape_type_op = dyn_cast(op); if (!shape_type_op) return failure(); - if (failed( - shape_type_op.reifyReturnTypeShapes(rewriter, results_shape))) - return failure(); + + SmallVector results_shape; + auto status = + shape_type_op.reifyReturnTypeShapes(rewriter, results_shape); + if (failed(status)) return failure(); buffer_args.push_back(InsertDynamicAllocAndDealloc( op->getLoc(), result.value(), results_shape.front(), &rewriter)); } } rewriter.create>(op->getLoc(), llvm::None, buffer_args, op->getAttrs()); - rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); + rewriter.replaceOp( + op, llvm::makeArrayRef(buffer_args).drop_front(operands.size())); return success(); } }; @@ -453,6 +450,10 @@ struct HloLegalizeToLhlo return std::all_of(op.operand_type_begin(), op.operand_type_end(), isMemRefType); }); + target.addDynamicallyLegalOp([&](shape::AssumingOp op) { + return std::all_of(op.result_type_begin(), op.result_type_end(), + isMemRefType); + }); auto kind = results_escape_function ? 
BufferAssignmentTypeConverter::KeepAsFunctionResult @@ -463,8 +464,9 @@ struct HloLegalizeToLhlo populateHLOToLHLOConversionPattern(&context, &converter, &patterns); populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, &converter, - &patterns); + mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, converter, + patterns); + populateShapeTypeConversionPatterns(&context, converter, patterns); if (failed(applyPartialConversion(getOperation(), target, patterns))) signalPassFailure(); } @@ -498,6 +500,7 @@ void populateHLOToLHLOConversionPattern( HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, @@ -505,11 +508,13 @@ void populateHLOToLHLOConversionPattern( HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, @@ -526,7 +531,7 @@ void populateHLOToLHLOConversionPattern( HloToLhloReturnOpConverter, HloToLhloTensorLoadOpConverter, HloToLhloTensorStoreOpConverter - >(context, converter); + >(context, *converter); // clang-format on } diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc index b6e23a6b131..adf2a398a00 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc @@ -32,8 +32,6 @@ limitations under the License. 
#include "mlir/Pass/PassRegistry.h" #include "mlir/Support/LogicalResult.h" -using mlir::PassRegistration; - namespace mlir { namespace mhlo { namespace { diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 1e8442ec95c..57859b64bed 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -822,6 +822,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -838,6 +839,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -847,6 +849,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, ReshapeOpConverter, ReverseConverter, ScalarPointwiseToStandardConverter, @@ -930,6 +933,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -945,6 +949,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -953,6 +958,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, ReshapeOpConverter, ReverseConverter, TransposeConverter>(context); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc index d2d4bab45ab..84255c2810e 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc @@ -193,7 +193,7 @@ std::unique_ptr> createLegalizeToStdPass() { void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, mlir::MLIRContext *ctx) { - mlir::populateWithGenerated(ctx, patterns); + mlir::populateWithGenerated(ctx, *patterns); patterns->insert(ctx); } diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc index 7021fb8bf4a..10030866d0f 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -// This file implements logic for lowering the tanh standard ops to an -// approximation. +// This file implements the lowering for trigonometric standard ops to +// approximations. #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" @@ -27,102 +27,228 @@ namespace mlir { namespace mhlo { namespace { -/// Emits the fast tanh approximation that is also used by XLA. -Value EmitTanhApproximation(Value input, Location loc, - PatternRewriter &rewriter) { - // For small values of x, we can approximate tanh(x)=x. For extremely small - // values of x (|x| < 1e-37), the other approximation would evaluate - // tanh(x) = 0. - constexpr float kCanUseApprox = 0.0004; - Value abs_value = rewriter.create(loc, input); - Value can_use_approx = - rewriter.create(loc, rewriter.getF32FloatAttr(kCanUseApprox)); - Value return_input = rewriter.create(loc, CmpFPredicate::OLT, - abs_value, can_use_approx); - // Clamp the input to [-c, c]. - Value max_clamp = rewriter.create( - loc, rewriter.getF32FloatAttr(7.90531110763549805f)); - Value smaller_than_max = - rewriter.create(loc, CmpFPredicate::ULE, input, max_clamp); - Value clamped_half = - rewriter.create(loc, smaller_than_max, input, max_clamp); - Value min_clamp = rewriter.create( - loc, rewriter.getF32FloatAttr(-7.90531110763549805f)); - Value larger_than_min = - rewriter.create(loc, CmpFPredicate::UGE, clamped_half, min_clamp); - Value input_clamped = - rewriter.create(loc, larger_than_min, clamped_half, min_clamp); - - static constexpr std::array numerator_coeffs{ - -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, - 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, - 4.89352455891786e-03f}; - - static constexpr std::array denominator_coeffs{ - 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, - 4.89352518554385e-03f}; - - Value input_squared = - rewriter.create(loc, input_clamped, input_clamped); - Value numerator = rewriter.create( - loc, rewriter.getF32FloatAttr(numerator_coeffs[0])); - for (int i = 1; i < numerator_coeffs.size(); i++) { - numerator = rewriter.create( - loc, rewriter.create(loc, input_squared, numerator), - rewriter.create( - loc, rewriter.getF32FloatAttr(numerator_coeffs[i]))); - } - - numerator = rewriter.create(loc, input_clamped, numerator); - - Value denominator = rewriter.create( - loc, rewriter.getF32FloatAttr(denominator_coeffs[0])); - for (int i = 1; i < denominator_coeffs.size(); i++) { - denominator = rewriter.create( - loc, rewriter.create(loc, input_squared, denominator), - rewriter.create( - loc, rewriter.getF32FloatAttr(denominator_coeffs[i]))); - } - - Value approx = rewriter.create(loc, numerator, denominator); - - return rewriter.create(loc, return_input, input, approx); -} - -class ApproximateTanhLowering : public OpRewritePattern { +template +class ApproximateOnExtendedF32Lowering : public OpRewritePattern { public: - explicit ApproximateTanhLowering(MLIRContext *ctx) - : OpRewritePattern(ctx, 100) {} + explicit ApproximateOnExtendedF32Lowering(MLIRContext *ctx) + : OpRewritePattern(ctx, /*benefit=*/100) {} - LogicalResult matchAndRewrite(TanhOp tanhOp, + virtual Value emitApproximation(ValueRange, Location, + PatternRewriter &) const = 0; + + LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - Type operand_type = tanhOp.getType(); + Location loc = op.getLoc(); + auto raw_args = 
op.getOperation()->getOperands(); - if (operand_type.isF64()) { + // Supports only f16 and f32 for now. + if (!op.getType().isF16() && !op.getType().isF32()) return failure(); + + // Extend operands to f32 if needed and possible. + SmallVector f32_args; + f32_args.reserve(raw_args.size()); + for (Value arg : raw_args) { // Similar to XLA, do not rewrite f64 as precision might matter. - return failure(); + Type arg_ty = arg.getType(); + if (arg_ty.isF64()) return failure(); + + if (arg_ty.isF16()) + arg = rewriter.create(loc, arg, rewriter.getF32Type()); + + // If we still do not have f32, fail. + if (!arg.getType().isF32()) return failure(); + + f32_args.push_back(arg); } - Location loc = tanhOp.getLoc(); - Value input = tanhOp.operand(); - if (operand_type.isF16()) { - input = rewriter.create(loc, input, rewriter.getF32Type()); - } - - // If we still do not have f32, fail. - if (!input.getType().isF32()) { - return failure(); - } - - Value result = EmitTanhApproximation(input, loc, rewriter); + Value result = emitApproximation(f32_args, loc, rewriter); + assert(result.getType().isF32() && "Expect f32 intermediate result."); // Truncate back if needed. - if (operand_type.isF16()) { + if (op.getType().isF16()) result = rewriter.create(loc, result, rewriter.getF16Type()); + + rewriter.replaceOp(op, {result}); + return success(); + } +}; + +class ApproximateTanhLowering + : public ApproximateOnExtendedF32Lowering { + public: + explicit ApproximateTanhLowering(MLIRContext *ctx) + : ApproximateOnExtendedF32Lowering(ctx) {} + + // Emits the fast tanh approximation that is also used by XLA. + Value emitApproximation(ValueRange args, Location loc, + PatternRewriter &rewriter) const override { + // For small values of x, we can approximate tanh(x) = x. For extremely + // small values of x (|x| < 1e-37), the other approximation would evaluate + // tanh(x) = 0. + Value input = args.front(); + assert(input.getType().isF32()); + constexpr float kCanUseApprox = 0.0004; + Value abs_value = rewriter.create(loc, input); + Value can_use_approx = rewriter.create( + loc, rewriter.getF32FloatAttr(kCanUseApprox)); + Value return_input = rewriter.create(loc, CmpFPredicate::OLT, + abs_value, can_use_approx); + // Clamp the input to [-c, c]. 
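+    // Inputs outside the clamp range are mapped to the nearest endpoint
+    // before the rational approximation below is evaluated; the bound is the
+    // same constant XLA uses for its fast tanh.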
+ Value max_clamp = rewriter.create( + loc, rewriter.getF32FloatAttr(7.90531110763549805f)); + Value smaller_than_max = + rewriter.create(loc, CmpFPredicate::ULE, input, max_clamp); + Value clamped_half = + rewriter.create(loc, smaller_than_max, input, max_clamp); + Value min_clamp = rewriter.create( + loc, rewriter.getF32FloatAttr(-7.90531110763549805f)); + Value larger_than_min = rewriter.create(loc, CmpFPredicate::UGE, + clamped_half, min_clamp); + Value input_clamped = rewriter.create(loc, larger_than_min, + clamped_half, min_clamp); + + static constexpr std::array numerator_coeffs{ + -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, + 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, + 4.89352455891786e-03f}; + + static constexpr std::array denominator_coeffs{ + 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, + 4.89352518554385e-03f}; + + Value input_squared = + rewriter.create(loc, input_clamped, input_clamped); + Value numerator = rewriter.create( + loc, rewriter.getF32FloatAttr(numerator_coeffs[0])); + for (int i = 1; i < numerator_coeffs.size(); i++) { + numerator = rewriter.create( + loc, rewriter.create(loc, input_squared, numerator), + rewriter.create( + loc, rewriter.getF32FloatAttr(numerator_coeffs[i]))); } - rewriter.replaceOp(tanhOp, {result}); - return success(); + numerator = rewriter.create(loc, input_clamped, numerator); + + Value denominator = rewriter.create( + loc, rewriter.getF32FloatAttr(denominator_coeffs[0])); + for (int i = 1; i < denominator_coeffs.size(); i++) { + denominator = rewriter.create( + loc, rewriter.create(loc, input_squared, denominator), + rewriter.create( + loc, rewriter.getF32FloatAttr(denominator_coeffs[i]))); + } + + Value approx = rewriter.create(loc, numerator, denominator); + + return rewriter.create(loc, return_input, input, approx); + } +}; + +class ApproximateAtan2Lowering + : public ApproximateOnExtendedF32Lowering { + public: + explicit ApproximateAtan2Lowering(MLIRContext *ctx) + : ApproximateOnExtendedF32Lowering(ctx) {} + + // Reduces atan2 to atan in the same way XLA does it. + Value emitApproximation(ValueRange args, Location loc, + PatternRewriter &rewriter) const override { + Value y = args[0]; + Value x = args[1]; + assert(x.getType().isF32() && y.getType().isF32() && + "expect f32 arguments"); + Value ax = rewriter.create(loc, x); + Value ay = rewriter.create(loc, y); + Value le_ax_ay = rewriter.create(loc, CmpFPredicate::OLE, ax, ay); + Value min_ax_ay = rewriter.create(loc, le_ax_ay, ax, ay); + Value max_ax_ay = rewriter.create(loc, le_ax_ay, ay, ax); + Value zero_to_one = rewriter.create(loc, min_ax_ay, max_ax_ay); + Value a = emitAtanCoreApproximation(zero_to_one, loc, rewriter); + + Value pi_over_2 = + rewriter.create(loc, rewriter.getF32FloatAttr(1.57079637f)); + a = rewriter.create( + loc, le_ax_ay, rewriter.create(loc, pi_over_2, a), a); + + Value zero = rewriter.create(loc, rewriter.getF32FloatAttr(0)); + Value lt_x_0 = rewriter.create(loc, CmpFPredicate::OLT, x, zero); + Value pi = + rewriter.create(loc, rewriter.getF32FloatAttr(3.14159274f)); + a = rewriter.create(loc, lt_x_0, + rewriter.create(loc, pi, a), a); + + Value t = rewriter.create(loc, lt_x_0, pi, zero); + Value eq_y_0 = rewriter.create(loc, CmpFPredicate::OEQ, y, zero); + a = rewriter.create(loc, eq_y_0, t, a); + + // Propagate nan. 
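+    // CmpFPredicate::UNO is true iff at least one operand is NaN, so the
+    // select below replaces the result with a quiet NaN in that case.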
+ Value is_nan = rewriter.create(loc, CmpFPredicate::UNO, y, x); + Value nan = rewriter.create( + loc, rewriter.getF32FloatAttr(std::numeric_limits::quiet_NaN())); + a = rewriter.create(loc, is_nan, nan, a); + + // x and y are +- inf. + Value three_pi_over_4 = + rewriter.create(loc, rewriter.getF32FloatAttr(2.3561945f)); + Value pi_over_4 = rewriter.create( + loc, rewriter.getF32FloatAttr(0.785398185f)); + t = rewriter.create(loc, lt_x_0, three_pi_over_4, + pi_over_4); + Value inf = rewriter.create( + loc, rewriter.getF32FloatAttr(std::numeric_limits::infinity())); + Value eq_x_inf = rewriter.create(loc, CmpFPredicate::OEQ, x, inf); + Value eq_y_inf = rewriter.create(loc, CmpFPredicate::OEQ, y, inf); + Value all_inf = rewriter.create(loc, eq_x_inf, eq_y_inf); + a = rewriter.create(loc, all_inf, t, a); + + return rewriter.create(loc, a, y); + } + + private: + // The core atan reduction derives from the heuristic described in + // https://arxiv.org/abs/1508.03211 and has a < 0.95 ulp error in the [-1, 1] + // range (though that assumed FMA was available, and it is not here). This is + // the same approximation that is also used by XLA. + Value emitAtanCoreApproximation(Value x, Location loc, + PatternRewriter &rewriter) const { + auto constant = [&](float c) { + return rewriter.create(loc, rewriter.getF32FloatAttr(c)); + }; + + // Computes ab + c. + auto mul_add = [&](Value a, Value b, Value c) { + Value prod = rewriter.create(loc, a, b); + return rewriter.create(loc, prod, c); + }; + + Value s = rewriter.create(loc, x, x); + Value r = constant(0.0027856871f); + r = mul_add(r, s, constant(-0.0158660002f)); + r = mul_add(r, s, constant(0.042472221f)); + r = mul_add(r, s, constant(-0.0749753043f)); + r = mul_add(r, s, constant(0.106448799f)); + r = mul_add(r, s, constant(-0.142070308f)); + r = mul_add(r, s, constant(0.199934542f)); + r = mul_add(r, s, constant(-0.333331466f)); + r = rewriter.create(loc, r, s); + return mul_add(r, x, x); + } +}; + +class ApproximateAtanLowering + : public ApproximateOnExtendedF32Lowering { + public: + explicit ApproximateAtanLowering(MLIRContext *ctx) + : ApproximateOnExtendedF32Lowering(ctx) {} + + // Reduce atan(x) to atan2(x, 1) to subsequently rely on an atan approximation + // for the argument range [-1, 1]. + Value emitApproximation(ValueRange args, Location loc, + PatternRewriter &rewriter) const override { + Value x = args.front(); + assert(x.getType().isF32()); + Value one = rewriter.create(loc, rewriter.getF32FloatAttr(1)); + return rewriter.create(loc, x, one); } }; @@ -146,7 +272,12 @@ createLegalizeTrigonometricToApproximationPass() { void PopulateTrigonometricToApproximationPatterns( mlir::MLIRContext *context, OwningRewritePatternList *patterns) { - patterns->insert(context); + // clang-format off + patterns->insert< + ApproximateAtanLowering, + ApproximateAtan2Lowering, + ApproximateTanhLowering>(context); + // clang-format on } } // namespace mhlo diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex.cc index 9f7c946577d..491f1c01cf7 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex.cc @@ -37,7 +37,6 @@ limitations under the License. 
using mlir::FunctionPass; using mlir::OwningRewritePatternList; -using mlir::PassRegistration; using mlir::PassWrapper; namespace { @@ -60,7 +59,7 @@ namespace { void PopulateComplexLoweringPatterns(MLIRContext* context, OwningRewritePatternList* patterns) { - populateWithGenerated(context, patterns); + populateWithGenerated(context, *patterns); } } // end namespace mhlo } // end namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc index 8dec4bba0fb..ada30a289a4 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc @@ -38,7 +38,6 @@ using mlir::LogicalResult; using mlir::MLIRContext; using mlir::OpRewritePattern; using mlir::OwningRewritePatternList; -using mlir::PassRegistration; using mlir::PassWrapper; using mlir::PatternRewriter; using mlir::RankedTensorType; diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc index 32a846e79ef..febd4423bf2 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc @@ -24,7 +24,6 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" using mlir::FunctionPass; -using mlir::PassRegistration; using mlir::PassWrapper; namespace { diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 0be4c6819f7..7c01fa22372 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -48,7 +48,7 @@ namespace { // TODO(herhut): Generate these out of op definitions. 
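// Each fn(Op) entry below is expanded once per listed CHLO element-wise unary
// op, separated by sep; AtanOp is newly added to this list.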
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \ - fn(TanOp) sep fn(AcosOp) sep fn(SinhOp) + fn(AcosOp) sep fn(AtanOp) sep fn(SinhOp) sep fn(TanOp) template inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { diff --git a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir index 6d8814545d4..974585b12c5 100644 --- a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir @@ -63,6 +63,24 @@ func @divide_fold_float() -> tensor<4xf64> { return %2 : tensor<4xf64> } +// CHECK-LABEL: remainder_fold_int +func @remainder_fold_int() -> tensor<4xi32> { + %0 = mhlo.constant dense<[5, 66, 5, 1]> : tensor<4xi32> + %1 = mhlo.constant dense<[3, 5, 1, 2]> : tensor<4xi32> + // CHECK: mhlo.constant dense<[2, 1, 0, 1]> + %2 = "mhlo.remainder"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>) + return %2 : tensor<4xi32> +} + +// CHECK-LABEL: remainder_fold_float +func @remainder_fold_float() -> tensor<4xf32> { + %0 = mhlo.constant dense<[7.0, 66.5, 5.0, 3.1]> : tensor<4xf32> + %1 = mhlo.constant dense<[3.0, 5.0, 1.0, 2.6]> : tensor<4xf32> + // CHECK: mhlo.constant dense<[1.000000e+00, 1.500000e+00, 0.000000e+00, 5.000000e-01]> + %2 = "mhlo.remainder"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) + return %2 : tensor<4xf32> +} + // CHECK-LABEL: max_scalar_fold func @max_scalar_fold() -> tensor<4xi64> { %0 = mhlo.constant dense<7> : tensor<4xi64> @@ -951,6 +969,22 @@ func @gather_scalar_index_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<5x6x4xf32 // CHECK: return %[[RET]] : tensor<5x6x4xf32> } +// CHECK-LABEL: gather_to_slice_reshape +func @gather_to_slice_reshape(%arg0: tensor<5x6x7xf32>) -> tensor<3x6xf32> { + %0 = constant dense<[1, 2]> : tensor<2xi32> + %1 = "mhlo.gather"(%arg0, %0) { + dimension_numbers = {collapsed_slice_dims = dense<[2]> : tensor<1xi64>, + index_vector_dim = 0 : i64, + offset_dims = dense<[0, 1, 2]> : tensor<3xi64>, + start_index_map = dense<[0, 2]> : tensor<2xi64>}, + indices_are_sorted = false, + slice_sizes = dense<[3, 6, 1]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor<2xi32>) -> tensor<3x6xf32> + return %1 : tensor<3x6xf32> + // CHECK: %[[V0:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[4, 6, 3]> : tensor<3xi64>, start_indices = dense<[1, 0, 2]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<3x6x1xf32> + // CHECK: %[[V1:.*]] = "mhlo.reshape"(%[[V0]]) : (tensor<3x6x1xf32>) -> tensor<3x6xf32> + // CHECK: return %[[V1]] : tensor<3x6xf32> +} + // CHECK-LABEL: func @fold_and_same func @fold_and_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> { %0 = "mhlo.and"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> diff --git a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir index af19a9b5c1c..60ec26f48a1 100644 --- a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt -mhlo-test-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s +// RUN: mlir-hlo-opt -chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s // Check the non-broadcast case for each registered op, then just check a // representative op for detailed broadcast semantics. 
@@ -325,7 +325,7 @@ func @addUnrankedUnranked( // CHECK: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index // Handle scalar LHS case // CHECK: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) { -// CHECK: %[[SCALAR_LHS:.*]] = "mhlo.reshape"(%[[LHS]]) : (tensor<*xf32>) -> tensor +// CHECK: %[[SCALAR_LHS:.*]] = tensor_cast %[[LHS]] : tensor<*xf32> to tensor // CHECK: %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor, tensor<*xf32>) -> tensor<*xf32> // CHECK: scf.yield %[[VAL_10]] : tensor<*xf32> // CHECK: } else { @@ -334,7 +334,7 @@ func @addUnrankedUnranked( // CHECK: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index // Handle scalar RHS case // CHECK: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) { -// CHECK: %[[SCALAR_RHS:.*]] = "mhlo.reshape"(%[[RHS]]) : (tensor<*xf32>) -> tensor +// CHECK: %[[SCALAR_RHS:.*]] = tensor_cast %[[RHS]] : tensor<*xf32> to tensor // CHECK: %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor) -> tensor<*xf32> // CHECK: scf.yield %[[VAL_16]] : tensor<*xf32> // CHECK: } else { diff --git a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_mhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_mhlo.mlir index 371e730c30b..2bec91203f9 100644 --- a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_mhlo.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_mhlo.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt --mhlo-test-chlo-legalize-to-hlo --split-input-file %s | FileCheck %s +// RUN: mlir-hlo-opt --chlo-legalize-to-hlo --split-input-file %s | FileCheck %s // Lower statically shaped `constant_like` to constant. // CHECK-LABEL: @constant_like_static_shape diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir index 960a769c388..3caa4f0bd3b 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir @@ -236,6 +236,21 @@ func @complex(%real: memref<2x2xf32>, // ----- +// BOTH-LABEL: func @complex_dyn +func @complex_dyn(%real: memref, + %imag: memref, + %result: memref>) { + %tensor_real = tensor_load %real : memref + %tensor_imag = tensor_load %imag : memref + %tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag) + : (tensor, tensor) -> tensor> + // BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref> + return +} + +// ----- + // BOTH-LABEL: func @real func @real(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xcomplex> @@ -248,6 +263,18 @@ func @real(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { // ----- +// BOTH-LABEL: func @real_dyn +func @real_dyn(%operand: memref>, %result: memref) { + %tensor_operand = tensor_load %operand : memref> + %tensor_result = "mhlo.real"(%tensor_operand) + : (tensor>) -> tensor + // BOTH: "lmhlo.real"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref + return +} + +// ----- + // BOTH-LABEL: func @imag func @imag(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xcomplex> @@ -260,6 +287,18 @@ func @imag(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { // ----- +// BOTH-LABEL: func @imag_dyn +func @imag_dyn(%operand: memref>, %result: memref) { + %tensor_operand = tensor_load %operand : memref> + %tensor_result = "mhlo.imag"(%tensor_operand) + : 
(tensor>) -> tensor + // BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref + return +} + +// ----- + // BOTH-LABEL: func @iota func @iota(%result: memref<10xi32>) { %tensor_result = "mhlo.iota"() @@ -344,6 +383,18 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +// BOTH-LABEL: func @not +func @not(%operand: memref<2x2xi32>, %result: memref<2x2xi32>) { + %tensor_operand = tensor_load %operand : memref<2x2xi32> + %tensor_result = "mhlo.not"(%tensor_operand) + : (tensor<2x2xi32>) -> tensor<2x2xi32> + // BOTH: "lmhlo.not"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xi32> + return +} + +// ----- + // BOTH-LABEL: func @rsqrt func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> @@ -535,3 +586,50 @@ func @transpose(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { tensor_store %tensor_result, %result : memref<2x2xf32> return } + +// ----- + +// BOTH-LABEL: func @custom_call +// BOTH-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>, [[RESULT:%.*]]: memref<4x4xf16>) +func @custom_call(%arg0: memref<2x2xf32>, %arg1: memref<2x3xf32>, %result: memref<4x4xf16>) { + %arg0_tensor = tensor_load %arg0 : memref<2x2xf32> + %arg1_tensor = tensor_load %arg1 : memref<2x3xf32> + // BOTH: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false} + %result_tensor = "mhlo.custom_call"(%arg0_tensor, %arg1_tensor) + {backend_config = "", call_target_name = "foo", has_side_effect = false} + : (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<4x4xf16> + tensor_store %result_tensor, %result: memref<4x4xf16> + return +} + +// ---- + +// BOTH-LABEL: func @isfinite +func @isfinite(%arg0: memref<2x2xf32>, %result: memref<2x2xi1>) { + %arg0_tensor = tensor_load %arg0 : memref<2x2xf32> + // BOTH: "lmhlo.is_finite"(%{{.*}}, %{{.*}}) + %result_tensor = "mhlo.is_finite"(%arg0_tensor) : (tensor<2x2xf32>) -> tensor<2x2xi1> + tensor_store %result_tensor, %result: memref<2x2xi1> + return +} + +// ----- + +// Test that assuming ops propagate memref types. 
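+// The shape.assuming region below yields a tensor value; after the
+// hlo-to-lhlo conversion the region result and its assuming_yield are
+// expected to carry the corresponding memref type instead.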
+// BOTH-LABEL: func @shape_assuming_memref +func @shape_assuming_memref(%arg0: tensor) -> tensor { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = shape.const_witness true + // BOTH: shape.assuming %{{.*}} -> (memref) + %2 = shape.assuming %1 -> (tensor) { + %3 = shape.shape_of %arg0 : tensor -> tensor + %4 = tensor_cast %3 : tensor to tensor<1xindex> + %5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor + %6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor<1xindex>) -> tensor + // BOTH: "lmhlo.maximum"(%6, %9, %20) : (memref, memref, memref) -> () + %7 = mhlo.maximum %5, %6 : tensor + // BOTH: shape.assuming_yield %{{.*}} : memref + shape.assuming_yield %7 : tensor + } + return %2 : tensor +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir index 263ea1b4040..91490b43f95 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir @@ -252,6 +252,20 @@ func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> { // ----- +// CHECK-LABEL: func @is_finte +func @is_finte(%input: tensor<2x2xf32>) -> tensor<2x2xi1> { + %0 = "mhlo.is_finite"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi1> + return %0 : tensor<2x2xi1> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32 +// CHECK-NEXT: %[[POS_INF:.+]] = constant 0x7F800000 : f32 +// CHECK-NEXT: %[[ABS_X:.+]] = absf %[[OPERAND_IN]] : f32 +// CHECK-NEXT: %[[RESULT:.+]] = cmpf "one", %[[ABS_X]], %[[POS_INF]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + // CHECK-LABEL: func @select func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { diff --git a/tensorflow/compiler/mlir/hlo/tests/legalize-trigonometric-to-approximation.mlir b/tensorflow/compiler/mlir/hlo/tests/legalize-trigonometric-to-approximation.mlir index b138e1cec2e..c25545ca2bd 100644 --- a/tensorflow/compiler/mlir/hlo/tests/legalize-trigonometric-to-approximation.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/legalize-trigonometric-to-approximation.mlir @@ -119,3 +119,262 @@ func @tanh_f16(%arg0 : f16) -> f16 { // CHECK: %[[VAL_43:.*]] = select %[[VAL_17]], %[[VAL_15]], %[[VAL_42]] : f32 // CHECK: %[[VAL_44:.*]] = fptrunc %[[VAL_43]] : f32 to f16 // CHECK: return %[[VAL_44]] : f16 + +// ----- + +// CHECK-LABEL: @atan2_f64 +func @atan2_f64(%arg0 : f64, %arg1 : f64) -> f64 { + // CHECK: atan2 + %res = atan2 %arg0, %arg1 : f64 + return %res : f64 +} + +// ----- + +// CHECK-LABEL: func @atan2_f32 +// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) -> f32 +func @atan2_f32(%arg0 : f32, %arg1 : f32) -> f32 { + // CHECK: %[[CST:.*]] = constant 0.0027856871 : f32 + // CHECK: %[[CST_0:.*]] = constant -1.586600e-02 : f32 + // CHECK: %[[CST_1:.*]] = constant 0.042472221 : f32 + // CHECK: %[[CST_2:.*]] = constant -0.0749753043 : f32 + // CHECK: %[[CST_3:.*]] = constant 0.106448799 : f32 + // CHECK: %[[CST_4:.*]] = constant -0.142070308 : f32 + // CHECK: %[[CST_5:.*]] = constant 0.199934542 : f32 + // CHECK: %[[CST_6:.*]] = constant -0.333331466 : f32 + // CHECK: %[[CST_7:.*]] = constant 1.57079637 : f32 + // CHECK: %[[CST_8:.*]] = constant 0.000000e+00 : f32 + // CHECK: %[[CST_9:.*]] = constant 3.14159274 : f32 + // CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : 
f32 + // CHECK: %[[CST_11:.*]] = constant 2.3561945 : f32 + // CHECK: %[[CST_12:.*]] = constant 0.785398185 : f32 + // CHECK: %[[CST_13:.*]] = constant 0x7F800000 : f32 + // CHECK: %[[VAL_0:.*]] = absf %[[ARG1]] : f32 + // CHECK: %[[VAL_1:.*]] = absf %[[ARG0]] : f32 + // CHECK: %[[VAL_2:.*]] = cmpf "ole", %[[VAL_0]], %[[VAL_1]] : f32 + // CHECK: %[[VAL_3:.*]] = select %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : f32 + // CHECK: %[[VAL_4:.*]] = select %[[VAL_2]], %[[VAL_1]], %[[VAL_0]] : f32 + // CHECK: %[[VAL_5:.*]] = divf %[[VAL_3]], %[[VAL_4]] : f32 + // CHECK: %[[VAL_6:.*]] = mulf %[[VAL_5]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_7:.*]] = mulf %[[CST]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_8:.*]] = addf %[[VAL_7]], %[[CST_0]] : f32 + // CHECK: %[[VAL_9:.*]] = mulf %[[VAL_8]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_1]] : f32 + // CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_2]] : f32 + // CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_3]] : f32 + // CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_4]] : f32 + // CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_5]] : f32 + // CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_6]] : f32 + // CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_24:.*]] = subf %[[CST_7]], %[[VAL_23]] : f32 + // CHECK: %[[VAL_25:.*]] = select %[[VAL_2]], %[[VAL_24]], %[[VAL_23]] : f32 + // CHECK: %[[VAL_26:.*]] = cmpf "olt", %[[ARG1]], %[[CST_8]] : f32 + // CHECK: %[[VAL_27:.*]] = subf %[[CST_9]], %[[VAL_25]] : f32 + // CHECK: %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_25]] : f32 + // CHECK: %[[VAL_29:.*]] = select %[[VAL_26]], %[[CST_9]], %[[CST_8]] : f32 + // CHECK: %[[VAL_30:.*]] = cmpf "oeq", %[[ARG0]], %[[CST_8]] : f32 + // CHECK: %[[VAL_31:.*]] = select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : f32 + // CHECK: %[[VAL_32:.*]] = cmpf "uno", %[[ARG0]], %[[ARG1]] : f32 + // CHECK: %[[VAL_35:.*]] = select %[[VAL_32]], %[[CST_10]], %[[VAL_31]] : f32 + // CHECK: %[[VAL_36:.*]] = select %[[VAL_26]], %[[CST_11]], %[[CST_12]] : f32 + // CHECK: %[[VAL_37:.*]] = cmpf "oeq", %[[ARG1]], %[[CST_13]] : f32 + // CHECK: %[[VAL_38:.*]] = cmpf "oeq", %[[ARG0]], %[[CST_13]] : f32 + // CHECK: %[[VAL_39:.*]] = and %[[VAL_37]], %[[VAL_38]] : i1 + // CHECK: %[[VAL_40:.*]] = select %[[VAL_39]], %[[VAL_36]], %[[VAL_35]] : f32 + // CHECK: %[[VAL_41:.*]] = copysign %[[VAL_40]], %[[ARG0]] : f32 + // CHECK: return %[[VAL_41]] : f32 + %res = atan2 %arg0, %arg1 : f32 + return %res : f32 +} + +// ----- + +// CHECK-LABEL: @atan2_f16 +// CHECK-SAME: (%[[ARG0:.*]]: f16, %[[ARG1:.*]]: f16) -> f16 +func @atan2_f16(%arg0 : f16, %arg1 : f16) -> f16 { + // CHECK: %[[CST:.*]] = constant 0.0027856871 : f32 + // CHECK: %[[CST_0:.*]] = constant -1.586600e-02 : f32 + // CHECK: %[[CST_1:.*]] = constant 0.042472221 : f32 + // CHECK: %[[CST_2:.*]] = constant -0.0749753043 : f32 + // CHECK: %[[CST_3:.*]] = constant 0.106448799 : f32 + // CHECK: %[[CST_4:.*]] = constant -0.142070308 : f32 + // CHECK: %[[CST_5:.*]] = constant 0.199934542 : f32 + // CHECK: %[[CST_6:.*]] = constant -0.333331466 : f32 + // 
CHECK: %[[CST_7:.*]] = constant 1.57079637 : f32 + // CHECK: %[[CST_8:.*]] = constant 0.000000e+00 : f32 + // CHECK: %[[CST_9:.*]] = constant 3.14159274 : f32 + // CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32 + // CHECK: %[[CST_11:.*]] = constant 2.3561945 : f32 + // CHECK: %[[CST_12:.*]] = constant 0.785398185 : f32 + // CHECK: %[[CST_13:.*]] = constant 0x7F800000 : f32 + // CHECK: %[[VAL_0:.*]] = fpext %[[ARG0]] : f16 to f32 + // CHECK: %[[VAL_1:.*]] = fpext %[[ARG1]] : f16 to f32 + // CHECK: %[[VAL_2:.*]] = absf %[[VAL_1]] : f32 + // CHECK: %[[VAL_3:.*]] = absf %[[VAL_0]] : f32 + // CHECK: %[[VAL_4:.*]] = cmpf "ole", %[[VAL_2]], %[[VAL_3]] : f32 + // CHECK: %[[VAL_5:.*]] = select %[[VAL_4]], %[[VAL_2]], %[[VAL_3]] : f32 + // CHECK: %[[VAL_6:.*]] = select %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : f32 + // CHECK: %[[VAL_7:.*]] = divf %[[VAL_5]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_8:.*]] = mulf %[[VAL_7]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_9:.*]] = mulf %[[CST]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_0]] : f32 + // CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_1]] : f32 + // CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_2]] : f32 + // CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_3]] : f32 + // CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_4]] : f32 + // CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_5]] : f32 + // CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_22:.*]] = addf %[[VAL_21]], %[[CST_6]] : f32 + // CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_25:.*]] = addf %[[VAL_24]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_26:.*]] = subf %[[CST_7]], %[[VAL_25]] : f32 + // CHECK: %[[VAL_27:.*]] = select %[[VAL_4]], %[[VAL_26]], %[[VAL_25]] : f32 + // CHECK: %[[VAL_28:.*]] = cmpf "olt", %[[VAL_1]], %[[CST_8]] : f32 + // CHECK: %[[VAL_29:.*]] = subf %[[CST_9]], %[[VAL_27]] : f32 + // CHECK: %[[VAL_30:.*]] = select %[[VAL_28]], %[[VAL_29]], %[[VAL_27]] : f32 + // CHECK: %[[VAL_31:.*]] = select %[[VAL_28]], %[[CST_9]], %[[CST_8]] : f32 + // CHECK: %[[VAL_32:.*]] = cmpf "oeq", %[[VAL_0]], %[[CST_8]] : f32 + // CHECK: %[[VAL_33:.*]] = select %[[VAL_32]], %[[VAL_31]], %[[VAL_30]] : f32 + // CHECK: %[[VAL_34:.*]] = cmpf "uno", %[[VAL_0]], %[[VAL_1]] : f32 + // CHECK: %[[VAL_37:.*]] = select %[[VAL_34]], %[[CST_10]], %[[VAL_33]] : f32 + // CHECK: %[[VAL_38:.*]] = select %[[VAL_28]], %[[CST_11]], %[[CST_12]] : f32 + // CHECK: %[[VAL_39:.*]] = cmpf "oeq", %[[VAL_1]], %[[CST_13]] : f32 + // CHECK: %[[VAL_40:.*]] = cmpf "oeq", %[[VAL_0]], %[[CST_13]] : f32 + // CHECK: %[[VAL_41:.*]] = and %[[VAL_39]], %[[VAL_40]] : i1 + // CHECK: %[[VAL_42:.*]] = select %[[VAL_41]], %[[VAL_38]], %[[VAL_37]] : f32 + // CHECK: %[[VAL_43:.*]] = copysign %[[VAL_42]], %[[VAL_0]] : f32 + // CHECK: %[[VAL_44:.*]] = fptrunc %[[VAL_43]] : f32 to f16 + // CHECK: return %[[VAL_44]] : f16 + %res = atan2 %arg0, %arg1 : f16 + return %res : f16 +} + +// ----- + +// CHECK-LABEL: @atan_f64 +func @atan_f64(%arg : f64) -> f64 { + // CHECK: atan + %res = atan %arg : f64 + return %res : f64 +} + +// ----- + +// CHECK-LABEL: func @atan_f32 +// CHECK-SAME: 
(%[[ARG:.*]]: f32) -> f32 +func @atan_f32(%arg : f32) -> f32 { + // CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32 + // CHECK: %[[CST_0:.*]] = constant 0.0027856871 : f32 + // CHECK: %[[CST_1:.*]] = constant -1.586600e-02 : f32 + // CHECK: %[[CST_2:.*]] = constant 0.042472221 : f32 + // CHECK: %[[CST_3:.*]] = constant -0.0749753043 : f32 + // CHECK: %[[CST_4:.*]] = constant 0.106448799 : f32 + // CHECK: %[[CST_5:.*]] = constant -0.142070308 : f32 + // CHECK: %[[CST_6:.*]] = constant 0.199934542 : f32 + // CHECK: %[[CST_7:.*]] = constant -0.333331466 : f32 + // CHECK: %[[CST_8:.*]] = constant 1.57079637 : f32 + // CHECK: %[[CST_9:.*]] = constant 0.000000e+00 : f32 + // CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32 + // CHECK: %[[VAL_0:.*]] = absf %[[CST]] : f32 + // CHECK: %[[VAL_1:.*]] = absf %arg0 : f32 + // CHECK: %[[VAL_2:.*]] = cmpf "ole", %[[VAL_0]], %[[VAL_1]] : f32 + // CHECK: %[[VAL_3:.*]] = select %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : f32 + // CHECK: %[[VAL_4:.*]] = select %[[VAL_2]], %[[VAL_1]], %[[VAL_0]] : f32 + // CHECK: %[[VAL_5:.*]] = divf %[[VAL_3]], %[[VAL_4]] : f32 + // CHECK: %[[VAL_6:.*]] = mulf %[[VAL_5]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_7:.*]] = mulf %[[CST_0]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_8:.*]] = addf %[[VAL_7]], %[[CST_1]] : f32 + // CHECK: %[[VAL_9:.*]] = mulf %[[VAL_8]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_2]] : f32 + // CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_3]] : f32 + // CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_4]] : f32 + // CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_5]] : f32 + // CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_6]] : f32 + // CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_7]] : f32 + // CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_24:.*]] = subf %[[CST_8]], %[[VAL_23]] : f32 + // CHECK: %[[VAL_25:.*]] = select %[[VAL_2]], %[[VAL_24]], %[[VAL_23]] : f32 + // CHECK: %[[VAL_26:.*]] = cmpf "oeq", %arg0, %[[CST_9]] : f32 + // CHECK: %[[VAL_27:.*]] = select %[[VAL_26]], %[[CST_9]], %[[VAL_25]] : f32 + // CHECK: %[[VAL_28:.*]] = cmpf "uno", %arg0, %[[CST]] : f32 + // CHECK: %[[VAL_29:.*]] = select %[[VAL_28]], %[[CST_10]], %[[VAL_27]] : f32 + // CHECK: %[[VAL_30:.*]] = copysign %[[VAL_29]], %arg0 : f32 + // CHECK: return %[[VAL_30]] : f32 + %res = atan %arg : f32 + return %res : f32 +} + +// ----- + +// CHECK-LABEL: @atan_f16 +// CHECK-SAME: (%[[ARG:.*]]: f16) -> f16 +func @atan_f16(%arg : f16) -> f16 { + // CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32 + // CHECK: %[[CST_0:.*]] = constant 0.0027856871 : f32 + // CHECK: %[[CST_1:.*]] = constant -1.586600e-02 : f32 + // CHECK: %[[CST_2:.*]] = constant 0.042472221 : f32 + // CHECK: %[[CST_3:.*]] = constant -0.0749753043 : f32 + // CHECK: %[[CST_4:.*]] = constant 0.106448799 : f32 + // CHECK: %[[CST_5:.*]] = constant -0.142070308 : f32 + // CHECK: %[[CST_6:.*]] = constant 0.199934542 : f32 + // CHECK: %[[CST_7:.*]] = constant -0.333331466 : f32 + // CHECK: %[[CST_8:.*]] = constant 1.57079637 : f32 + // CHECK: %[[CST_9:.*]] = constant 0.000000e+00 : 
f32 + // CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32 + // CHECK: %[[VAL_0:.*]] = fpext %arg0 : f16 to f32 + // CHECK: %[[VAL_1:.*]] = absf %[[CST]] : f32 + // CHECK: %[[VAL_2:.*]] = absf %[[VAL_0]] : f32 + // CHECK: %[[VAL_3:.*]] = cmpf "ole", %[[VAL_1]], %[[VAL_2]] : f32 + // CHECK: %[[VAL_4:.*]] = select %[[VAL_3]], %[[VAL_1]], %[[VAL_2]] : f32 + // CHECK: %[[VAL_5:.*]] = select %[[VAL_3]], %[[VAL_2]], %[[VAL_1]] : f32 + // CHECK: %[[VAL_6:.*]] = divf %[[VAL_4]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_7:.*]] = mulf %[[VAL_6]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_8:.*]] = mulf %[[CST_0]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_9:.*]] = addf %[[VAL_8]], %[[CST_1]] : f32 + // CHECK: %[[VAL_10:.*]] = mulf %[[VAL_9]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_11:.*]] = addf %[[VAL_10]], %[[CST_2]] : f32 + // CHECK: %[[VAL_12:.*]] = mulf %[[VAL_11]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_13:.*]] = addf %[[VAL_12]], %[[CST_3]] : f32 + // CHECK: %[[VAL_14:.*]] = mulf %[[VAL_13]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_15:.*]] = addf %[[VAL_14]], %[[CST_4]] : f32 + // CHECK: %[[VAL_16:.*]] = mulf %[[VAL_15]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_17:.*]] = addf %[[VAL_16]], %[[CST_5]] : f32 + // CHECK: %[[VAL_18:.*]] = mulf %[[VAL_17]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_19:.*]] = addf %[[VAL_18]], %[[CST_6]] : f32 + // CHECK: %[[VAL_20:.*]] = mulf %[[VAL_19]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_21:.*]] = addf %[[VAL_20]], %[[CST_7]] : f32 + // CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_25:.*]] = subf %[[CST_8]], %[[VAL_24]] : f32 + // CHECK: %[[VAL_26:.*]] = select %[[VAL_3]], %[[VAL_25]], %[[VAL_24]] : f32 + // CHECK: %[[VAL_27:.*]] = cmpf "oeq", %[[VAL_0]], %[[CST_9]] : f32 + // CHECK: %[[VAL_28:.*]] = select %[[VAL_27]], %[[CST_9]], %[[VAL_26]] : f32 + // CHECK: %[[VAL_29:.*]] = cmpf "uno", %[[VAL_0]], %[[CST]] : f32 + // CHECK: %[[VAL_30:.*]] = select %[[VAL_29]], %[[CST_10]], %[[VAL_28]] : f32 + // CHECK: %[[VAL_31:.*]] = copysign %[[VAL_30]], %[[VAL_0]] : f32 + // CHECK: %[[VAL_32:.*]] = fptrunc %[[VAL_31]] : f32 to f16 + // CHECK: return %[[VAL_32]] : f16 + %res = atan %arg : f16 + return %res : f16 +} diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir index d0e19c16391..debb035328c 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir @@ -125,6 +125,20 @@ func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) { // ----- +// CHECK-LABEL: func @is_finte +func @is_finte(%input: memref<2x2xf32>, %result: memref<2x2xi1>) { + "lmhlo.is_finite"(%input, %result) : (memref<2x2xf32>, memref<2x2xi1>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[POS_INF:.+]] = constant 0x7F800000 : f32 +// CHECK-NEXT: %[[ABS_X:.+]] = absf %[[OPERAND_IN]] : f32 +// CHECK-NEXT: %[[RESULT:.+]] = cmpf "one", %[[ABS_X]], %[[POS_INF]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + // CHECK-LABEL: func @float_cmp func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) { @@ -534,6 +548,19 @@ func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // ----- +// CHECK-LABEL: func @not +func @not(%input: memref<2x2xi64>, %result: memref<2x2xi64>) { + 
"lmhlo.not"(%input, %result) : (memref<2x2xi64>, memref<2x2xi64>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i64, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[N1:.*]] = constant -1 : i64 +// CHECK-NEXT: %[[RESULT:.*]] = xor %[[N1]], %[[OPERAND_IN]] : i64 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i64 + +// ----- + // CHECK-LABEL: func @rem func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { @@ -573,6 +600,37 @@ func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +// CHECK-LABEL: func @sign_bf16 +func @sign_bf16(%input: memref<2x2xbf16>, %result: memref<2x2xbf16>) { + "lmhlo.sign"(%input, %result) : (memref<2x2xbf16>, memref<2x2xbf16>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: bf16, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[CST:.*]] = constant 1.000000e+00 : bf16 +// CHECK-NEXT: %[[RESULT:.*]] = copysign %[[CST]], %[[OPERAND_IN]] : bf16 +// CHECK-NEXT: linalg.yield %[[RESULT]] : bf16 + +// ----- + +// CHECK-LABEL: func @sign_i16 +func @sign_i16(%input: memref<2x2xi16>, %result: memref<2x2xi16>) { + "lmhlo.sign"(%input, %result) : (memref<2x2xi16>, memref<2x2xi16>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[C0:.*]] = constant 0 : i16 +// CHECK-NEXT: %[[CMP:.*]] = cmpi "eq", %[[OPERAND_IN]], %[[C0]] : i16 +// CHECK-NEXT: %[[C15:.*]] = constant 15 : i16 +// CHECK-NEXT: %[[ASHR:.*]] = shift_right_signed %[[OPERAND_IN]], %[[C15]] : i16 +// CHECK-NEXT: %[[C1:.*]] = constant 1 : i16 +// CHECK-NEXT: %[[OR:.*]] = or %[[ASHR]], %[[C1]] : i16 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[C0]], %[[OR]] : i16 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i16 + +// ----- + // CHECK-LABEL: func @sqrt func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { "lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () diff --git a/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir b/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir index a7bd21257a6..b9c91d61377 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s +// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s // CHECK-LABEL: @add func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { diff --git a/tensorflow/compiler/mlir/hlo/tests/ops.mlir b/tensorflow/compiler/mlir/hlo/tests/ops.mlir index aff2f7ff675..4462d9c45c6 100644 --- a/tensorflow/compiler/mlir/hlo/tests/ops.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/ops.mlir @@ -1010,34 +1010,34 @@ func @constant_invalid() -> () { func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // CHECK: mhlo.sort - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } // ----- func 
@sort_no_operands() { - // expected-error @+1 {{op requires at least one input}} - %0 = "mhlo.sort"() ( { + // expected-error @+1 {{expected named operation to have atleast 1 result}} + %0:0 = "mhlo.sort"() ( { ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): %7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : () -> tuple<> + }) {dimension = 1 : i64, is_stable = true} : () -> () return } // ----- func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } @@ -1045,23 +1045,23 @@ func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{comparator block argument #0 should be of type 'tensor' but got 'tensor'}} - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } // ----- func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) { - // expected-error @+1 {{op requires all inputs to have the same dimensions}} - %0 = "mhlo.sort"(%input0, %input1) ( { + // expected-error @+1 {{op requires the same shape for all operands and results}} + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } @@ -1069,11 +1069,11 @@ func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found 10}} - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } @@ -1081,11 
+1081,11 @@ func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi3 func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found -3}} - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = -3 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = -3 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } @@ -1093,11 +1093,11 @@ func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi3 func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{op comparator block should have 4 arguments}} - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } @@ -1105,11 +1105,11 @@ func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{op comparator block argument #3 should be of type 'tensor' but got 'tensor'}} - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } @@ -1193,3 +1193,24 @@ func @incompatible_shapes(%arg0: tensor, %shape: tensor<2xindex>) -> tens %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor return %0 : tensor } + +// ----- + +func @cbrt(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { + %0 = "mhlo.cbrt"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32> + return %0 : tensor<2x4xf32> +} + +// ----- + +func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { + %0 = "mhlo.bitcast"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32> + return %0 : tensor<2x4xf32> +} + +// ----- + +func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { + %0 = "mhlo.reduce_precision"(%arg) {exponent_bits=2 : i32, mantissa_bits=3 : i32} : (tensor<2x4xf32>) -> tensor<2x4xf32> + return %0 : tensor<2x4xf32> +} diff --git a/tensorflow/compiler/mlir/init_mlir.cc b/tensorflow/compiler/mlir/init_mlir.cc index 54f8a57d8a6..fac9f51d8ba 100644 --- a/tensorflow/compiler/mlir/init_mlir.cc +++ b/tensorflow/compiler/mlir/init_mlir.cc @@ -20,6 +20,11 @@ limitations under the License. 
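
The chlo-to-standard expectations earlier in this diff (atan_f32, atan_f16, atan2_f16) encode a range-reduced polynomial approximation of arctangent: reduce to t = min(|x|, 1) / max(|x|, 1), evaluate an odd polynomial in t, fold across pi/2 when |x| >= 1, then restore the sign of the input. The sketch below is an editorial illustration of that scalar recipe, not part of the patch; the coefficients are copied from the CHECK constants above, while the function name and the test harness are invented here.

// Editorial sketch (not part of the patch): scalar form of the atan
// expansion whose CHECK lines appear above. Coefficients come from the
// CHECK constants; ApproxAtan and main() are made up for illustration.
#include <cassert>
#include <cmath>
#include <cstdio>

float ApproxAtan(float x) {
  const float kCoeffs[] = {0.0027856871f,  -1.586600e-02f, 0.042472221f,
                           -0.0749753043f, 0.106448799f,   -0.142070308f,
                           0.199934542f,   -0.333331466f};
  float ax = std::fabs(x);
  bool swapped = 1.0f <= ax;          // cmpf "ole" 1.0, |x|
  float mn = swapped ? 1.0f : ax;
  float mx = swapped ? ax : 1.0f;
  float t = mn / mx;                  // t is in [0, 1]
  float q = t * t;
  float p = kCoeffs[0];
  for (int i = 1; i < 8; ++i) p = p * q + kCoeffs[i];  // Horner evaluation
  float r = p * q * t + t;            // odd polynomial: t + t*q*P(q)
  if (swapped) r = 1.57079637f - r;   // atan(x) = pi/2 - atan(1/x) for x >= 1
  if (ax == 0.0f) r = 0.0f;
  if (std::isnan(x)) r = std::nanf("");
  return std::copysign(r, x);
}

int main() {
  for (float x : {-10.0f, -1.0f, -0.5f, 0.0f, 0.25f, 1.0f, 3.0f, 100.0f}) {
    // Single-precision minimax polynomial; allow a small tolerance.
    assert(std::fabs(ApproxAtan(x) - std::atan(x)) < 1e-4f);
  }
  std::printf("atan approximation agrees with std::atan within 1e-4\n");
  return 0;
}

The atan2 expansion above follows the same pattern with |y| and |x| in place of min/max against 1, plus the additional selects visible in the CHECK lines for quadrant, zero, infinity, and NaN handling.
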
namespace tensorflow { InitMlir::InitMlir(int *argc, char ***argv) : init_llvm_(*argc, *argv) { + llvm::setBugReportMsg( + "TensorFlow crashed, please file a bug on " + "https://github.com/tensorflow/tensorflow/issues with the trace " + "below.\n"); + constexpr char kSeparator[] = "--"; // Find index of separator between two sets of flags. diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 8084ae73d90..eff591895e1 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -477,7 +477,6 @@ cc_library( "transforms/default_quant_params.cc", "transforms/generated_post_quantize.inc", "transforms/generated_quantize.inc", - "transforms/load_quantization_recipe.cc", "transforms/post_quantize.cc", "transforms/prepare_quantize.cc", "transforms/quantize.cc", @@ -670,6 +669,7 @@ cc_library( "//tensorflow/lite/delegates/flex:allowlisted_flex_ops_lib", "//tensorflow/lite/kernels/internal:kernel_utils", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "//tensorflow/lite/tools/versioning", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -706,6 +706,7 @@ cc_library( "//tensorflow/core/platform:status", "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 34200fb88b6..3f4f3997536 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -75,6 +75,7 @@ limitations under the License. #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h" #include "tensorflow/lite/kernels/internal/kernel_utils.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/string_util.h" #include "tensorflow/lite/tools/versioning/op_version.h" #include "tensorflow/lite/tools/versioning/runtime_version.h" diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 62eaffa8ed9..7d64e268063 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -75,6 +75,7 @@ limitations under the License. 
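
The flatbuffer export and import changes above start depending on tensorflow/lite/schema:schema_utils and read operator codes through tflite::GetBuiltinCode(&op_code) instead of touching op_code.builtin_code directly. The reason shows up in the FileCheck expectations later in this diff: serialized operator codes now carry both a legacy one-byte deprecated_builtin_code and a wider builtin_code, so a resolver has to pick whichever field actually holds the operator. The sketch below is an editorial illustration of that idea, not part of the patch; the struct layout, the kBeyondByteRange enumerator, and the max-of-both-fields rule are assumptions about what the real helper does.

// Editorial sketch (not part of the patch): resolving an operator code that
// may live in either the legacy byte-sized field or the newer wide field.
#include <algorithm>
#include <cassert>
#include <cstdint>

// kConv2D = 3 matches the FileCheck expectations below; kBeyondByteRange is
// a stand-in for any operator whose code no longer fits in the legacy byte.
enum BuiltinOperator : int32_t {
  kConv2D = 3,
  kBeyondByteRange = 128,
};

struct OperatorCode {
  int8_t deprecated_builtin_code = 0;  // legacy one-byte field
  int32_t builtin_code = 0;            // wider replacement field
};

// Old producers fill only the deprecated field; new producers fill the wide
// field and clamp the deprecated one, so the larger of the two is the real
// operator either way (assumed behaviour of tflite::GetBuiltinCode).
int32_t ResolveBuiltinCode(const OperatorCode& op) {
  return std::max<int32_t>(op.builtin_code, op.deprecated_builtin_code);
}

int main() {
  OperatorCode legacy;
  legacy.deprecated_builtin_code = kConv2D;  // written by an old exporter

  OperatorCode modern;
  modern.deprecated_builtin_code = 127;      // clamped legacy value
  modern.builtin_code = kBeyondByteRange;    // real operator code

  assert(ResolveBuiltinCode(legacy) == kConv2D);
  assert(ResolveBuiltinCode(modern) == kBeyondByteRange);
  return 0;
}

Emitting both fields keeps new flatbuffers readable by older runtimes as long as the operator still fits in the legacy byte, which is consistent with the mlir2flatbuffer expectations below that print deprecated_builtin_code and builtin_code side by side.
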
#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" using llvm::ArrayRef; using mlir::Builder; @@ -271,18 +272,18 @@ StatusOr GetMlirOpName(const tflite::OperatorT& op, return std::string("tfl.basic_lstm"); } - if (op_code.builtin_code == tflite::BuiltinOperator_CUSTOM) { + auto builtin_code = tflite::GetBuiltinCode(&op_code); + if (builtin_code == tflite::BuiltinOperator_CUSTOM) { return std::string("tfl.custom"); } - if (op_code.builtin_code == tflite::BuiltinOperator_IF) { + if (builtin_code == tflite::BuiltinOperator_IF) { return std::string("tf.If"); } - if (op_code.builtin_code == tflite::BuiltinOperator_WHILE) { + if (builtin_code == tflite::BuiltinOperator_WHILE) { return std::string("tf.While"); } - llvm::StringRef op_name( - tflite::EnumNameBuiltinOperator(op_code.builtin_code)); + llvm::StringRef op_name(tflite::EnumNameBuiltinOperator(builtin_code)); return llvm::Twine("tfl.", op_name.lower()).str(); } @@ -637,7 +638,8 @@ StatusOr ConvertOp( } llvm::SmallVector attrs; - if (op_code.builtin_code == tflite::BuiltinOperator_CUSTOM) { + auto builtin_code = tflite::GetBuiltinCode(&op_code); + if (builtin_code == tflite::BuiltinOperator_CUSTOM) { auto status = mlir::CustomOptionsToAttributes( op_code.custom_code, op.custom_options, builder, loc, &attrs); if (!status.ok()) { diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 93e85d9ced0..f7ee323957d 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -385,28 +385,27 @@ def BinaryOpSameElementTypeConstraint : //===----------------------------------------------------------------------===// def TFL_BroadcastableBinaryBuilder : OpBuilder< - "OpBuilder &builder, OperationState &result, Value lhs, Value rhs", + "Value lhs, Value rhs", [{ auto resultType = OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType()); if (!resultType) - mlir::emitError(result.location, "non-broadcastable operands"); - result.addOperands({lhs, rhs}); - result.types.push_back(resultType); + mlir::emitError($_state.location, "non-broadcastable operands"); + $_state.addOperands({lhs, rhs}); + $_state.types.push_back(resultType); }]>; def TFL_FusedBroadcastableBinaryBuilder : OpBuilder< - "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " - "StringAttr fusedActivationFunction", + "Value lhs, Value rhs, StringAttr fusedActivationFunction", [{ buildFusedBroadcastableBinOp( - &builder, result, lhs, rhs, fusedActivationFunction); + &$_builder, $_state, lhs, rhs, fusedActivationFunction); }]>; def TFL_ComparisonBinaryBuilder : OpBuilder< - "OpBuilder &builder, OperationState &result, Value lhs, Value rhs", + "Value lhs, Value rhs", [{ - buildComparisonBinOp(&builder, result, lhs, rhs); + buildComparisonBinOp(&$_builder, $_state, lhs, rhs); }]>; //===----------------------------------------------------------------------===// @@ -520,7 +519,11 @@ def TFL_AddOp : TFL_Op<"add", [ let hasOptions = 1; } -def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect, SameOperandsAndResultsScale]> { +def TFL_AddNOp : TFL_Op<"add_n", [ + Commutative, + NoSideEffect, + SameOperandsAndResultsScale, + NoQuantizableResult]> { let summary = "add_n operator"; let description = [{ @@ -536,7 +539,9 @@ def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect, SameOperandsAndResu ); } -def TFL_ReduceAnyOp : 
TFL_Op<"reduce_any", [NoSideEffect]> { +def TFL_ReduceAnyOp : TFL_Op<"reduce_any", [ + NoSideEffect, + NoQuantizableResult]> { let summary = [{ Computes the "logical or" of elements across dimensions of a tensor. }]; @@ -693,7 +698,8 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> { def TFL_CeilOp: TFL_Op<"ceil", [ NoSideEffect, SameOperandsAndResultShape, - SameOperandsAndResultType]> { + SameOperandsAndResultType, + NoQuantizableResult]> { let summary = "Ceil operator"; let description = [{ @@ -765,10 +771,10 @@ def TFL_ConstOp : Op ]; } @@ -817,13 +823,12 @@ def TFL_SparseConstOp : Op ]; } @@ -889,7 +894,7 @@ def TFL_DepthwiseConv2DOp : let extraClassDeclaration = [{ // AffineQuantizedOpInterface: int GetChannelDimIndex() { return 3; } - int GetQuantizationDimIndex() { return 3; } + int GetQuantizationDimIndex() { return 3; } // SparseOpInterface: std::vector GetSparseOperands() { return {1}; } std::vector> GetFloatBlockSize() { return {}; } @@ -1002,9 +1007,8 @@ def TFL_GatherOp : TFL_Op<"gather", [ let builders = [ - OpBuilder<"OpBuilder &builder, OperationState &result, " - "Value params, Value indices, IntegerAttr axis", - [{ BuildGatherOp(&builder, result, params, indices, axis); }]> + OpBuilder<"Value params, Value indices, IntegerAttr axis", + [{ BuildGatherOp(&$_builder, $_state, params, indices, axis); }]> ]; let results = (outs @@ -1093,7 +1097,8 @@ def TFL_LocalResponseNormalizationOp : TFL_Op<"local_response_normalization", [ TFL_OperandHasRank<0, 4>, SameOperandsAndResultShape, SameOperandsAndResultType, - NoSideEffect]> { + NoSideEffect, + NoQuantizableResult]> { let summary = "Local Response Normalization."; let description = [{ @@ -1220,7 +1225,8 @@ def TFL_NonMaxSuppressionV4Op : TFL_Op<"non_max_suppression_v4", [ TFL_OperandHasRank<1, 1>, // Other operands are scalar params. TFL_OperandHasRank<2, 0>, TFL_OperandHasRank<3, 0>, - TFL_OperandHasRank<4, 0>]> { + TFL_OperandHasRank<4, 0>, + NoQuantizableResult]> { let summary = [{ Greedily selects a subset of bounding boxes in descending order of score, }]; @@ -1269,7 +1275,8 @@ def TFL_NonMaxSuppressionV5Op : TFL_Op<"non_max_suppression_v5", [ TFL_OperandHasRank<1, 1>, // Other operands are scalar params. 
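
The dilated-conv tests further down in this diff change the tf.SpaceToBatchND paddings from 2 to 4, zero out the tf.BatchToSpaceND crops, and expect the fused convolution to use padding = "SAME" instead of "VALID". The shape check below is an editorial aside, not part of the patch: it verifies, for the 5x5 kernel, block size 2, stride 1, and 128x128 input used in those tests, that the new constants keep the decomposed path and the fused dilated convolution producing the same spatial size.

// Editorial sketch (not part of the patch): shape arithmetic behind the
// updated dilated-conv test constants (5x5 kernel, block/dilation 2,
// stride 1, 128x128 input).
#include <cassert>

int main() {
  const int input = 128;   // spatial size per dimension
  const int kernel = 5;
  const int block = 2;     // SpaceToBatchND block size == dilation rate
  const int pad = 4;       // per-side padding fed to SpaceToBatchND
  const int crop = 0;      // per-side crop fed to BatchToSpaceND

  // Decomposed path in the tests: SpaceToBatchND -> VALID conv -> BatchToSpaceND.
  int per_block = (input + 2 * pad) / block;   // 136 / 2 = 68
  int conv_out = per_block - kernel + 1;       // VALID: 68 - 5 + 1 = 64
  int restored = conv_out * block - 2 * crop;  // 64 * 2 - 0 = 128
  assert(per_block == 68 && conv_out == 64 && restored == 128);

  // Fused form: dilation-2 convolution with SAME padding keeps the size.
  int effective_kernel = (kernel - 1) * block + 1;                // 9
  int same_pad_total = effective_kernel - 1;                      // 8, i.e. 4 per side
  int fused_out = input + same_pad_total - effective_kernel + 1;  // 128
  assert(same_pad_total == 2 * pad && fused_out == restored);
  return 0;
}

For comparison, the previous constants (paddings of 2) give (128 + 4) / 2 = 66 rather than the 68 annotated on the intermediate tensors, which is consistent with those constants being corrected in this diff.
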
TFL_OperandHasRank<2, 0>, TFL_OperandHasRank<3, 0>, - TFL_OperandHasRank<4, 0>, TFL_OperandHasRank<5, 0>]> { + TFL_OperandHasRank<4, 0>, TFL_OperandHasRank<5, 0>, + NoQuantizableResult]> { let summary = [{ Greedily selects a subset of bounding boxes in descending order of score, }]; @@ -1336,10 +1343,9 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [ let builders = [ - OpBuilder< - "OpBuilder &builder, OperationState &result, Value lhs, Value rhs", + OpBuilder<"Value lhs, Value rhs", [{ - buildComparisonBinOp(&builder, result, lhs, rhs); + buildComparisonBinOp(&$_builder, $_state, lhs, rhs); }]> ]; @@ -1383,7 +1389,8 @@ def TFL_DivOp : TFL_Op<"div", [ def TFL_EluOp: TFL_Op<"elu", [ NoSideEffect, SameOperandsAndResultShape, - SameOperandsAndResultType]> { + SameOperandsAndResultType, + NoQuantizableResult]> { let summary = "Exponential Linear Unit operator"; let description = [{ Computes the exponential linear @@ -1441,8 +1448,10 @@ def TFL_EqualOp: TFL_Op<"equal", [ let builders = [TFL_ComparisonBinaryBuilder]; } -def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect, - SameOperandsAndResultType]> { +def TFL_ExpOp: TFL_Op<"exp", [ + NoSideEffect, + SameOperandsAndResultType, + NoQuantizableResult]> { let summary = "Natural exponentiation operator"; let description = [{ @@ -1546,7 +1555,8 @@ shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] def TFL_FillOp: TFL_Op<"fill", [ NoSideEffect, PredOpTrait<"input and result must have same element type", - TFL_TCresVTEtIsSameAsOp<0, 1>>]> { + TFL_TCresVTEtIsSameAsOp<0, 1>>, + NoQuantizableResult]> { let summary = "Fill the tensor with given value."; let description = [{ Fill the tensor with given value. @@ -1563,7 +1573,8 @@ def TFL_FillOp: TFL_Op<"fill", [ def TFL_FloorOp: TFL_Op<"floor", [ NoSideEffect, SameOperandsAndResultShape, - SameOperandsAndResultType]> { + SameOperandsAndResultType, + NoQuantizableResult]> { let summary = "Floor operator"; let description = [{ @@ -1581,7 +1592,8 @@ def TFL_FloorDivOp : TFL_Op<"floor_div", [ BinaryOpSameElementTypeConstraint, PredOpTrait<"lhs and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, - TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>]> { + TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>, + NoQuantizableResult]> { let summary = "Floor div operator"; let description = [{ @@ -1606,7 +1618,8 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [ BinaryOpSameElementTypeConstraint, PredOpTrait<"lhs and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, - TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>]> { + TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>, + NoQuantizableResult]> { let summary = "Division reminder"; let description = [{ @@ -1745,7 +1758,9 @@ def TFL_LessOp : TFL_Op<"less", [ let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }]; } -def TFL_LogicalAndOp : TFL_Op<"logical_and", [NoSideEffect]> { +def TFL_LogicalAndOp : TFL_Op<"logical_and", [ + NoSideEffect, + NoQuantizableResult]> { let summary = "Logical AND operator"; let description = [{ @@ -1778,7 +1793,9 @@ def TFL_LogicalNotOp : TFL_Op<"logical_not", [ let results = (outs TFL_BoolTensor:$output); } -def TFL_LogicalOrOp : TFL_Op<"logical_or", [NoSideEffect]> { +def TFL_LogicalOrOp : TFL_Op<"logical_or", [ + NoSideEffect, + NoQuantizableResult]> { let summary = "Logical OR operator"; let description = [{ @@ -2005,7 +2022,8 @@ def TFL_OneHotOp : TFL_Op<"one_hot", [NoSideEffect]> { def TFL_RoundOp: TFL_Op<"round", [ NoSideEffect, SameOperandsAndResultShape, - 
SameOperandsAndResultType]> { + SameOperandsAndResultType, + NoQuantizableResult]> { let summary = "Round operator"; let description = [{ @@ -2213,7 +2231,8 @@ def TFL_MulOp : TFL_Op<"mul", [ def TFL_NegOp: TFL_Op<"neg", [ NoSideEffect, SameOperandsAndResultShape, - SameOperandsAndResultType]> { + SameOperandsAndResultType, + NoQuantizableResult]> { let summary = "Negation operator"; let description = [{ @@ -2462,11 +2481,10 @@ def TFL_ReluOp: TFL_Op<"relu", [ // This builder doesn't work with quantized type, so it can only be used by // non-quantization tablegen patterns. Currently, it is used by the // elementwise-move reordering pattern in the optimize_patterns.td - let builders = [OpBuilder< - "OpBuilder &, OperationState &state, Value input", + let builders = [OpBuilder<"Value input", [{ - state.addOperands({input}); - state.addTypes(input.getType()); + $_state.addOperands({input}); + $_state.addTypes(input.getType()); }]> ]; } @@ -2490,11 +2508,10 @@ def TFL_Relu6Op: TFL_Op<"relu6", [ // This builder doesn't work with quantized type, so it can only be used by // non-quantization tablegen patterns. Currently, it is used by the // elementwise-move reordering pattern in the optimize_patterns.td - let builders = [OpBuilder< - "OpBuilder &, OperationState &state, Value input", + let builders = [OpBuilder<"Value input", [{ - state.addOperands({input}); - state.addTypes(input.getType()); + $_state.addOperands({input}); + $_state.addTypes(input.getType()); }]> ]; } @@ -2519,10 +2536,10 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [ // non-quantization tablegen patterns. Currently, it is used by the // elementwise-move reordering pattern in the optimize_patterns.td let builders = [OpBuilder< - "OpBuilder &, OperationState &state, Value input", + "Value input", [{ - state.addOperands({input}); - state.addTypes(input.getType()); + $_state.addOperands({input}); + $_state.addTypes(input.getType()); }]> ]; } @@ -2622,7 +2639,8 @@ def TFL_RangeOp: TFL_Op<"range", [ TFL_OperandHasRank<2, 0>, PredOpTrait<"operands and output must have same element type", And<[TCresVTEtIsSameAsOp<0, 0>, TCresVTEtIsSameAsOp<0, 1>, - TCresVTEtIsSameAsOp<0, 2>]>>]> { + TCresVTEtIsSameAsOp<0, 2>]>>, + NoQuantizableResult]> { let summary = "Range operator"; let description = [{ @@ -2701,12 +2719,11 @@ def TFL_SelectOp : TFL_Op<"select", [ TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$output); // TODO(jpienaar): autogenerate this. - let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, " - "Value condition, Value x, Value y", + let builders = [OpBuilder<"Value condition, Value x, Value y", [{ auto resultType = x.getType(); - result.addOperands({condition, x, y}); - result.types.push_back(resultType); + $_state.addOperands({condition, x, y}); + $_state.types.push_back(resultType); }]>]; let hasOptions = 1; @@ -2737,10 +2754,9 @@ def TFL_SelectV2Op : TFL_Op<"select_v2", [ let results = (outs TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$output); - let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, " - "Value cond, Value x, Value y", + let builders = [OpBuilder<"Value cond, Value x, Value y", [{ - BuildSelectV2Op(&builder, result, cond, x, y); + BuildSelectV2Op(&$_builder, $_state, cond, x, y); }]>]; let hasOptions = 1; @@ -2915,11 +2931,10 @@ def TFL_TanhOp: TFL_Op<"tanh", [ // This builder doesn't work with quantized type, so it can only be used by // non-quantization tablegen patterns. 
Currently, it is used by the // elementwise-move reordering pattern in the optimize_patterns.td - let builders = [OpBuilder< - "OpBuilder &, OperationState &state, Value input", + let builders = [OpBuilder<"Value input", [{ - state.addOperands({input}); - state.addTypes(input.getType()); + $_state.addOperands({input}); + $_state.addTypes(input.getType()); }]> ]; @@ -2989,9 +3004,8 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [ TFL_TensorOf<[F32, I8, I32, I64, UI8, QI8, QUI8]>:$values, TFL_I32Tensor:$indices); - let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, " - "Value input, Value k", - [{ BuildTopKOp(&builder, result, input, k); }]>]; + let builders = [OpBuilder<"Value input, Value k", + [{ BuildTopKOp(&$_builder, $_state, input, k); }]>]; let hasOptions = 1; } @@ -3066,7 +3080,8 @@ def TFL_ZerosLikeOp: TFL_Op<"zeros_like", [ TFL_TCresVTEtIsSameAsOp<0, 0>>, SameOperandsAndResultType, SameOperandsAndResultShape, - NoSideEffect]> { + NoSideEffect, + NoQuantizableResult]> { let summary = "ZerosLike operator"; let description = [{ @@ -3523,11 +3538,11 @@ def TFL_QConstOp : Op:$output); let builders = [OpBuilder< - "OpBuilder &, OperationState &state, TypeAttr qtype, Attribute value", + "TypeAttr qtype, Attribute value", [{ - state.addAttribute("qtype", qtype); - state.addAttribute("value", value); - state.addTypes(qtype.getValue()); + $_state.addAttribute("qtype", qtype); + $_state.addAttribute("value", value); + $_state.addTypes(qtype.getValue()); }]> ]; } @@ -3552,14 +3567,14 @@ def TFL_SparseQConstOp : Op:$output); let builders = [OpBuilder< - "OpBuilder &, OperationState &state, TypeAttr qtype, " - "Attribute value, SparsityParameterAttr s_param, Attribute compressed_data", + "TypeAttr qtype, Attribute value, SparsityParameterAttr s_param, " + "Attribute compressed_data", [{ - state.addTypes(qtype.getValue()); - state.addAttribute("qtype", qtype); - state.addAttribute("value", value); - state.addAttribute("s_param", s_param); - state.addAttribute("compressed_data", compressed_data); + $_state.addTypes(qtype.getValue()); + $_state.addAttribute("qtype", qtype); + $_state.addAttribute("value", value); + $_state.addAttribute("s_param", s_param); + $_state.addAttribute("compressed_data", compressed_data); }]> ]; } @@ -4240,7 +4255,8 @@ def TFL_SVDFOp : def TFL_SegmentSumOp: TFL_Op<"segment_sum", [ NoSideEffect, PredOpTrait<"input and output must have same element type", - TFL_TCresVTEtIsSameAsOp<0, 0>>]> { + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoQuantizableResult]> { let summary = "SegmentSum operator"; let description = [{ diff --git a/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir b/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir index 6b3c5b04aa4..d92bdc3f460 100644 --- a/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir +++ b/tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir @@ -2,7 +2,7 @@ func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<2> : tensor<2x2xi32> + %cst_0 = constant dense<4> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> @@ -29,8 +29,8 @@ 
func @testDilatedConvWithNonConstantPadAndCrops(%arg0: tensor<1x128x128x3xf32>, func @testDilatedConvWithNonZeroBasePadding(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<2> : tensor<2x2xi32> - %cst_1 = constant dense<1> : tensor<2x2xi32> + %cst_0 = constant dense<4> : tensor<2x2xi32> + %cst_1 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> @@ -44,10 +44,11 @@ func @testDilatedConvWithNonZeroBasePadding(%arg0: tensor<1x128x128x3xf32>, %arg func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<2> : tensor<2x2xi32> + %cst_0 = constant dense<4> : tensor<2x2xi32> + %cst_1 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", dilations = [1, 2, 2, 1], strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> - %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> return %2 : tensor<1x128x128x8xf32> // CHECK-LABEL: testDilatedConvWithNonTrivialDilations @@ -59,80 +60,85 @@ func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %ar func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<2> : tensor<2x2xi32> + %cst_0 = constant dense<4> : tensor<2x2xi32> + %cst_1 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.DepthwiseConv2dNative"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> - %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> return %2 : tensor<1x128x128x8xf32> // CHECK-LABEL: testDilatedDepthWiseConv // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>) - // CHECK-NEXT: [[RESULT:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> + // CHECK-NEXT: [[RESULT:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: return [[RESULT]] : 
tensor<1x128x128x8xf32> } func @testDilatedConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<2> : tensor<2x2xi32> + %cst_0 = constant dense<4> : tensor<2x2xi32> + %cst_1 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> %2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32> - %3 = "tf.BatchToSpaceND"(%2, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %3 = "tf.BatchToSpaceND"(%2, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> %4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> return %4 : tensor<1x128x128x8xf32> // CHECK-LABEL: testDilatedConvWithPad // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>) - // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> + // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> } func @testDilatedDepthWiseConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<2> : tensor<2x2xi32> + %cst_0 = constant dense<4> : tensor<2x2xi32> + %cst_1 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> %2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32> - %3 = "tf.BatchToSpaceND"(%2, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %3 = "tf.BatchToSpaceND"(%2, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> %4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> return %4 : tensor<1x128x128x8xf32> // CHECK-LABEL: testDilatedDepthWiseConvWithPad // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>) - // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> + // CHECK-NEXT: 
[[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> } func @testDilatedConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<2> : tensor<2x2xi32> + %cst_0 = constant dense<4> : tensor<2x2xi32> + %cst_1 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.Conv2D"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> - %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> %3 = "tf.BiasAdd"(%2, %arg2) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> return %3 : tensor<1x128x128x8xf32> // CHECK-LABEL: testDilatedConvWithBiasAdd // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>) - // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> + // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> } func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>, %arg2: tensor<8xf32>) -> tensor<1x128x128x8xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<2> : tensor<2x2xi32> + %cst_0 = constant dense<4> : tensor<2x2xi32> + %cst_1 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> %1 = "tf.DepthwiseConv2dNative"(%0, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32> - %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> + %2 = "tf.BatchToSpaceND"(%1, %cst, %cst_1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> %3 = "tf.BiasAdd"(%2, %arg2) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> return %3 : tensor<1x128x128x8xf32> // CHECK-LABEL: testDilatedDepthWiseConvWithBiasAdd // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>) - // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : 
(tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> + // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32> } @@ -140,12 +146,13 @@ func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %cst_1 = constant dense<2> : tensor<2x2xi32> + %cst_1 = constant dense<4> : tensor<2x2xi32> + %cst_2 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor) -> tensor<4x68x68x1xf32> %2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32> - %4 = "tf.BatchToSpaceND"(%3, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> + %4 = "tf.BatchToSpaceND"(%3, %cst, %cst_2) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> %5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> return %5 : tensor<1x128x128xf32> @@ -153,7 +160,7 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>) // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128x1xf32> - // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> + // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> @@ -162,12 +169,13 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor<128xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %cst_1 = constant dense<2> : tensor<2x2xi32> + %cst_1 = constant dense<4> : tensor<2x2xi32> + %cst_2 = constant dense<0> : tensor<2x2xi32> %0 = 
"tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor) -> tensor<4x68x68x1xf32> %2 = "tf.DepthwiseConv2dNative"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32> - %4 = "tf.BatchToSpaceND"(%3, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> + %4 = "tf.BatchToSpaceND"(%3, %cst, %cst_2) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> %5 = "tf.BiasAdd"(%4, %arg2) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> return %5 : tensor<1x128x128xf32> @@ -175,7 +183,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, % // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>) // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128x1xf32> - // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> + // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> @@ -184,20 +192,21 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, % func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %cst_1 = constant dense<2> : tensor<2x2xi32> + %cst_1 = constant dense<4> : tensor<2x2xi32> + %cst_2 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor) -> tensor<4x?x?x1xf32> %2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x?x?x1xf32>) -> tensor<4x?x?xf32> %4 = "tf.BiasAdd"(%3, %arg2) : (tensor<4x?x?xf32>, tensor) -> tensor<4x?x?xf32> - %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> + %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_2) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> return %5 : tensor<1x128x128xf32> // CHECK-LABEL: testDilatedConvWithExpandSqueeze2 // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor) // CHECK-NEXT: [[AXIS:%.*]] 
= "tf.Const"() {value = dense<3> : tensor} : () -> tensor // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128x1xf32> - // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> + // CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> @@ -206,20 +215,21 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>, %arg2: tensor) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %cst_1 = constant dense<2> : tensor<2x2xi32> + %cst_1 = constant dense<4> : tensor<2x2xi32> + %cst_2 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor) -> tensor<4x?x?x1xf32> %2 = "tf.DepthwiseConv2dNative"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x?x?x1xf32>) -> tensor<4x?x?xf32> %4 = "tf.BiasAdd"(%3, %arg2) : (tensor<4x?x?xf32>, tensor) -> tensor<4x?x?xf32> - %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> + %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_2) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> return %5 : tensor<1x128x128xf32> // CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze2 // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor) // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128x1xf32> - // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> + // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> @@ -228,7 +238,8 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, % func 
@testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %cst_1 = constant dense<2> : tensor<2x2xi32> + %cst_1 = constant dense<4> : tensor<2x2xi32> + %cst_2 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor) -> tensor<4x68x68x1xf32> %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> @@ -251,13 +262,14 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %cst_1 = constant dense<2> : tensor<2x2xi32> + %cst_1 = constant dense<4> : tensor<2x2xi32> + %cst_2 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor) -> tensor<4x68x68x1xf32> %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32> %4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32> - %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> + %5 = "tf.BatchToSpaceND"(%4, %cst, %cst_2) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32> %6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> return %6 : tensor<1x128x128xf32> @@ -265,7 +277,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, % // CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>) // CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor // CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor) -> tensor<1x128x128x1xf32> - // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> + // CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> // CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32> // CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> // CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32> @@ -274,12 +286,13 @@ func 
@testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, % func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32> { %cst = constant dense<[2, 2]> : tensor<2xi32> %cst_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %cst_1 = constant dense<2> : tensor<2x2xi32> + %cst_1 = constant dense<4> : tensor<2x2xi32> + %cst_2 = constant dense<0> : tensor<2x2xi32> %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32> %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor) -> tensor<4x68x68x1xf32> %2 = "tf.Conv2D"(%1, %arg1) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32> %3 = "tf.Squeeze"(%2) {squeeze_dims = [2]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64x1xf32> - %4 = "tf.BatchToSpaceND"(%3, %cst, %cst_1) : (tensor<4x64x64x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x1xf32> + %4 = "tf.BatchToSpaceND"(%3, %cst, %cst_2) : (tensor<4x64x64x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x1xf32> return %4 : tensor<1x128x128x1xf32> // CHECK-LABEL: testDilatedConvWithDifferentExpandSqueezeAxis diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt index adfcd93b4bc..117edd02beb 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/fake_quant_per_channel.pbtxt @@ -459,11 +459,13 @@ node { # CHECK-LABEL: { # CHECK: version: 3, # CHECK: operator_codes: [ { -# CHECK: builtin_code: CONV_2D, -# CHECK: version: 3 +# CHECK: deprecated_builtin_code: 3, +# CHECK: version: 3, +# CHECK: builtin_code: CONV_2D # CHECK: }, { -# CHECK: builtin_code: RESHAPE, +# CHECK: deprecated_builtin_code: 22, # CHECK: version: 1 +# CHECK: builtin_code: RESHAPE # CHECK: } ], # CHECK: subgraphs: [ { # CHECK: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index a576e559c00..4de278ee324 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1003,11 +1003,11 @@ func @batch_to_space_nd_unsupported(%arg0: tensor, %arg1: tensor< // CHECK: "tf.BatchToSpaceND" } -func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor { - %0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor - return %0 : tensor +func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<*xf32> { + %0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<*xf32> + return %0 : tensor<*xf32> // CHECK-LABEL: space_to_batch_nd - // CHECK: "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor + // CHECK: "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<*xf32> } func @split(%arg0: tensor, %arg1: tensor<1x4x3x3xf32>) -> tensor<1x4x3xf32> { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/basic_lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/basic_lstm.mlir index 8389045fc57..f5e5087d420 100644 --- 
a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/basic_lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/basic_lstm.mlir @@ -4,8 +4,9 @@ func @main(tensor<1x384xf32>, tensor<1x96xf32>, tensor<384x480xf32>, tensor<384x // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: LSTM, +// CHECK-NEXT: deprecated_builtin_code: 16, // CHECK-NEXT: version: 2 +// CHECK-NEXT: builtin_code: LSTM // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir index 2d906d6901e..02767ddceba 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir @@ -6,14 +6,17 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: MUL, +// CHECK-NEXT: deprecated_builtin_code: 18, // CHECK-NEXT: version: 1 +// CHECK-NEXT: builtin_code: MUL // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: CUSTOM, -// CHECK-NEXT: custom_code: "MyCustomOp" +// CHECK-NEXT: deprecated_builtin_code: 32, +// CHECK-NEXT: custom_code: "MyCustomOp", +// CHECK-NEXT: builtin_code: CUSTOM // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: EXP, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 47, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: EXP // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d.mlir index 98c3eb154e1..a576c84e207 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d.mlir @@ -5,11 +5,13 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { - // CHECK-NEXT: builtin_code: DEQUANTIZE, + // CHECK-NEXT: deprecated_builtin_code: 6, // CHECK-NEXT: version: 1 + // CHECK-NEXT: builtin_code: DEQUANTIZE // CHECK-NEXT: }, { - // CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D, + // CHECK-NEXT: deprecated_builtin_code: 4, // CHECK-NEXT: version: 1 + // CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d_v2.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d_v2.mlir index 86f27936946..7d2d84d242a 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d_v2.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/depthwise_conv2d_v2.mlir @@ -5,11 +5,13 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { - // CHECK-NEXT: builtin_code: DEQUANTIZE, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 6, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: DEQUANTIZE // CHECK-NEXT: }, { - // CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D, - // CHECK-NEXT: version: 2 + // CHECK-NEXT: deprecated_builtin_code: 4, + // CHECK-NEXT: version: 2, + // CHECK-NEXT: 
builtin_code: DEPTHWISE_CONV_2D // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex_enable_builtin.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex_enable_builtin.mlir index c034fa7e462..66ec5ed8a04 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex_enable_builtin.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex_enable_builtin.mlir @@ -5,11 +5,13 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: MUL, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 18, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: MUL // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: EXP, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 47, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: EXP // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir index 6d8c54b783a..dc8590b4a20 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fake_quant.mlir @@ -6,8 +6,9 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: FAKE_QUANT, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 80, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: FAKE_QUANT // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_exclusively.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_exclusively.mlir index 018d99fc74d..245260d994d 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_exclusively.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_exclusively.mlir @@ -4,8 +4,9 @@ func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: CUSTOM, +// CHECK-NEXT: deprecated_builtin_code: 32, // CHECK-NEXT: custom_code: "FlexAddV2" +// CHECK-NEXT: builtin_code: CUSTOM // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_complex128.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_complex128.mlir index a5e6d4aabb5..82ac52d2a64 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_complex128.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_complex128.mlir @@ -5,8 +5,9 @@ func @main(tensor<4xcomplex>, tensor<4xcomplex>) -> tensor<4xcomplex, tensor<4xf64>) -> tensor<4xf64> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: CUSTOM, -// CHECK-NEXT: custom_code: "FlexAdd" +// CHECK-NEXT: deprecated_builtin_code: 32, +// CHECK-NEXT: custom_code: "FlexAdd", +// CHECK-NEXT: builtin_code: CUSTOM // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_tflite_op.mlir 
b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_tflite_op.mlir index 8a9175b5c59..e15c7f23585 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_tflite_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_tflite_op.mlir @@ -5,14 +5,17 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { +// CHECK-NEXT: deprecated_builtin_code: 18, +// CHECK-NEXT: version: 1, // CHECK-NEXT: builtin_code: MUL -// CHECK-NEXT: version: 1 // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: CUSTOM, -// CHECK-NEXT: custom_code: "FlexDiv" +// CHECK-NEXT: deprecated_builtin_code: 32, +// CHECK-NEXT: custom_code: "FlexDiv", +// CHECK-NEXT: builtin_code: CUSTOM // CHECK-NEXT: }, { +// CHECK-NEXT: deprecated_builtin_code: 47, +// CHECK-NEXT: version: 1, // CHECK-NEXT: builtin_code: EXP -// CHECK-NEXT: version: 1 // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected.mlir index bbe4fdb8337..58f693c1b0a 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected.mlir @@ -5,8 +5,9 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { - // CHECK-NEXT: builtin_code: FULLY_CONNECTED, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 9, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: FULLY_CONNECTED // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected_v2.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected_v2.mlir index 0abe720ccba..913a69c4d46 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected_v2.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/fully_connected_v2.mlir @@ -5,8 +5,9 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { - // CHECK-NEXT: builtin_code: FULLY_CONNECTED, - // CHECK-NEXT: version: 2 + // CHECK-NEXT: deprecated_builtin_code: 9, + // CHECK-NEXT: version: 2, + // CHECK-NEXT: builtin_code: FULLY_CONNECTED // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/hashtable_resource.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/hashtable_resource.mlir index 3adee1dec77..2d5852dd83d 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/hashtable_resource.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/hashtable_resource.mlir @@ -3,8 +3,9 @@ // CHECK: { // CHECK: version: 3, // CHECK: operator_codes: [ { -// CHECK: builtin_code: CUSTOM, -// CHECK: custom_code: "HashTableV2" +// CHECK: deprecated_builtin_code: 32, +// CHECK: custom_code: "HashTableV2", +// CHECK: builtin_code: CUSTOM // CHECK: } ], // CHECK: subgraphs: [ { // CHECK: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir index 7290209cc4a..86794afdf4c 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir 
+++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir @@ -4,16 +4,19 @@ // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: LESS, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 58, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: LESS // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: IF, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 118, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: IF // CHECK-NEXT: }, { // CHECK-NEXT: version: 1 // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: MUL, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 18, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: MUL // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/logical.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/logical.mlir index 84cbf48c099..d24fa33fa13 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/logical.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/logical.mlir @@ -5,11 +5,13 @@ func @main(tensor<4xi1>) -> tensor<4xi1> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { - // CHECK-NEXT: builtin_code: LOGICAL_OR, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 84, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: LOGICAL_OR // CHECK-NEXT: }, { - // CHECK-NEXT: builtin_code: LOGICAL_AND, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 86, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: LOGICAL_AND // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir index 707bc926870..19bfc661425 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir @@ -4,8 +4,9 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: LSTM, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 16, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: LSTM // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_quantized.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_quantized.mlir index 5985ffaa446..a48e3c82e9d 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_quantized.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm_quantized.mlir @@ -7,8 +7,9 @@ func @main(%arg0: tensor<1x528x!quant.uniform> // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: LSTM, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 16, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: LSTM // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/math.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/math.mlir index 297a8b8cb59..c4cb910f17b 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/math.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/math.mlir 
@@ -5,20 +5,25 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { - // CHECK-NEXT: builtin_code: SQUARED_DIFFERENCE, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 99, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: SQUARED_DIFFERENCE // CHECK-NEXT: }, { - // CHECK-NEXT: builtin_code: MUL, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 18, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: MUL // CHECK-NEXT: }, { - // CHECK-NEXT: builtin_code: DIV, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 42, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: DIV // CHECK-NEXT: }, { - // CHECK-NEXT: builtin_code: EXP, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 47, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: EXP // CHECK-NEXT: }, { - // CHECK-NEXT: builtin_code: NEG, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 59, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: NEG // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v2.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v2.mlir index 15fce806a70..04ceb3855f2 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v2.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/mul_v2.mlir @@ -5,8 +5,9 @@ func @main(tensor<3x!quant.uniform>) -> tensor<3x!quant.uniform>) -> tensor<3x!quant.uniform) -> tensor<1x1x1x16xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { - // CHECK-NEXT: builtin_code: AVERAGE_POOL_2D, - // CHECK-NEXT: version: 1 + // CHECK-NEXT: deprecated_builtin_code: 1, + // CHECK-NEXT: version: 1, + // CHECK-NEXT: builtin_code: AVERAGE_POOL_2D // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir index 4f28ad327df..e39ded18b86 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/numeric_verify.mlir @@ -3,8 +3,9 @@ // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: CUSTOM, -// CHECK-NEXT: custom_code: "NumericVerify" +// CHECK-NEXT: deprecated_builtin_code: 32, +// CHECK-NEXT: custom_code: "NumericVerify", +// CHECK-NEXT: builtin_code: CUSTOM // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/quantization.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/quantization.mlir index dbe10a3f90c..81065798271 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/quantization.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/quantization.mlir @@ -4,20 +4,25 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: QUANTIZE, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 114, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: QUANTIZE // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: CONV_2D, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 3, +// 
CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: CONV_2D // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: RESHAPE, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 22, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: RESHAPE // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: SOFTMAX, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 25, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: SOFTMAX // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: DEQUANTIZE, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 6, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: DEQUANTIZE // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/reshape.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/reshape.mlir index 15defbc3957..129e037a2ee 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/reshape.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/reshape.mlir @@ -5,8 +5,9 @@ func @main(tensor<3x2xi32>) -> tensor<6xi32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: RESHAPE, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 22, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: RESHAPE // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir index 2182db1d39e..ea380f8f47d 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir @@ -7,8 +7,9 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: SUB, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 41, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: SUB // CHECK-NEXT: }, { // CHECK-NEXT: version: 1 // CHECK-NEXT: } ], diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir index 3d29823c93c..e2ad4c73baa 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir @@ -4,8 +4,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) - // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: SVDF, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 27, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: SVDF // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf_v2.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf_v2.mlir index 8dfa68798b8..c1d30a9b4d4 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf_v2.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf_v2.mlir @@ -4,8 +4,9 @@ func @main(tensor<4 x f32>, tensor<4 x i8>, tensor<4 x f32>, tensor<4 x f32>) -> // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: SVDF, -// CHECK-NEXT: version: 2 +// CHECK-NEXT: deprecated_builtin_code: 27, +// CHECK-NEXT: version: 2, +// CHECK-NEXT: builtin_code: SVDF // 
CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tfl_while_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tfl_while_op.mlir index 996543cc9c7..87e3ccf4688 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tfl_while_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/tfl_while_op.mlir @@ -3,14 +3,17 @@ // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: WHILE, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 119, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: WHILE // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: GREATER, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 61, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: GREATER // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: SUB, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 41, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: SUB // CHECK-NEXT: }, { // CHECK-NEXT: version: 1 // CHECK-NEXT: } ], diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir index ca335ebd000..5252dc21a59 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir @@ -4,8 +4,9 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: TRANSPOSE_CONV, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 67, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: TRANSPOSE_CONV // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/type_attr.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/type_attr.mlir index 01410d370d4..690331dec84 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/type_attr.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/type_attr.mlir @@ -3,8 +3,9 @@ // CHECK: { // CHECK: version: 3, // CHECK: operator_codes: [ { -// CHECK: builtin_code: CUSTOM, -// CHECK: custom_code: "SomeOperation" +// CHECK: deprecated_builtin_code: 32, +// CHECK: custom_code: "SomeOperation", +// CHECK: builtin_code: CUSTOM // CHECK: } ], // CHECK: subgraphs: [ { // CHECK: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir index 9b0315e1e20..38c1ed40b35 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir @@ -4,8 +4,9 @@ func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_LSTM, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 44, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_LSTM // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git 
a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir index 67349b857f7..575499c8d66 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir @@ -4,8 +4,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) - // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_RNN, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 35, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_RNN // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir index d69e8f40311..e58c54219ab 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir @@ -3,14 +3,17 @@ // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: WHILE, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 119, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: WHILE // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: GREATER, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 61, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: GREATER // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: SUB, -// CHECK-NEXT: version: 1 +// CHECK-NEXT: deprecated_builtin_code: 41, +// CHECK-NEXT: version: 1, +// CHECK-NEXT: builtin_code: SUB // CHECK-NEXT: }, { // CHECK-NEXT: version: 1 // CHECK-NEXT: } ], diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 8d64bc6ed0a..154e4fa8e0e 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -272,6 +272,22 @@ func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> { // CHECK: return %[[RES]] : tensor<4x2xf32> } +// CHECK-LABEL: @fuseBroadcastMulIntoFullyConnected +func @fuseBroadcastMulIntoFullyConnected(%arg0: tensor<1x10368xbf16>) -> tensor<32x1x256xbf16> { + %cst_0 = constant dense<2.0> : tensor<256x10368xbf16> + %cst_1 = constant unit + %cst_2 = constant dense<3.0> : tensor<32x1x256xbf16> + %0 = "tfl.fully_connected"(%arg0, %cst_0, %cst_1) { + fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT" + } : (tensor<1x10368xbf16>, tensor<256x10368xbf16>, none) -> tensor<1x256xbf16> + %1 = "tfl.mul"(%0, %cst_2) {fused_activation_function = "NONE"} : (tensor<1x256xbf16>, tensor<32x1x256xbf16>) -> tensor<32x1x256xbf16> + return %1 : tensor<32x1x256xbf16> + +// CHECK: %[[V0:.*]] = "tfl.fully_connected"(%arg0, {{.*}}) {{{.*}}} : (tensor<1x10368xbf16>, tensor<256x10368xbf16>, none) -> tensor<1x256xbf16> +// CHECK: %[[V1:.*]] = "tfl.mul"(%[[V0]], {{.*}}) {{{.*}}} : (tensor<1x256xbf16>, tensor<32x1x256xbf16>) -> tensor<32x1x256xbf16> +// CHECK: return %[[V1]] : tensor<32x1x256xbf16> +} + // CHECK-LABEL: @fuseAddIntoFollowingFullyConnectedWithQDQs func @fuseAddIntoFollowingFullyConnectedWithQDQs(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> { diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir 
b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index a0cc6cc1fdb..4e7c08945d4 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -666,4 +666,11 @@ func @xla_gather_to_slice(%arg0 : tensor<1x9x104x768xf32>) -> tensor<*xf32> { // CHECK: return %[[V0]] : tensor<*xf32> } +// CHECK-LABEL: DontMatchFusedBatchNormV3 +func @DontMatchFusedBatchNormV3(%arg0 :tensor, %arg1 : tensor<576xf32>, %arg2 : tensor<576xf32>, %arg3 : tensor<576xf32>,%arg4 : tensor<576xf32>) -> (tensor) { + %result:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {data_format = "NHWC", device = "", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = false} : (tensor, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>) -> (tensor, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>, tensor<*xf32>) + return %result : tensor + // CHECK: "tf.FusedBatchNormV3" +} + } diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index c3fc2bcdaaf..a2387e89483 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -139,6 +139,9 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( bool emit_select_tf_ops, bool emit_custom_ops, const mlir::TFL::QuantizationSpecs& quant_specs, std::string* result, mlir::PassManager* pass_manager) { + // Explicitly disable dumping Op details on failures. + module.getContext()->printOpOnDiagnostic(false); + // Register a warning handler only log to std out. mlir::ScopedDiagnosticHandler s( module.getContext(), [](mlir::Diagnostic& diag) { diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 6f7f3b88471..13c7a08a094 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -665,7 +665,7 @@ void LegalizeTF::runOnFunction() { auto func = getFunction(); // Add the generated patterns to the list. - populateWithGenerated(context, &patterns); + populateWithGenerated(context, patterns); patterns .insert(); OwningRewritePatternList patterns; - populateWithGenerated(context, &patterns); + populateWithGenerated(context, patterns); patterns.insert { LogicalResult matchAndRewrite(TFL::MulOp mul_op, PatternRewriter &rewriter) const override { + // If we are broadcasting on the lhs then don't fold the multiply as it + // would increase the amount of compute done by the fully connected op. + if (mul_op.lhs().getType() != mul_op.getType()) return failure(); + // Mul. DenseElementsAttr cst; Value constant_val = mul_op.rhs(); @@ -794,7 +798,7 @@ void Optimize::runOnFunction() { // Potentially the binary ops might be fused together, like hard_swish, thus // we explore these potentially first and then fuse the binary ops with the // following ops in a second pattern match. 
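
The new lhs-type check in optimize.cc (exercised by the fuseBroadcastMulIntoFullyConnected test above) bails out of the mul-into-fully-connected fold whenever the tfl.mul broadcasts the fully connected output, since folding the multiplier into the filter would multiply the matmul work by the broadcast factor. A minimal NumPy sketch of that shape argument, using the shapes from the test with float32 standing in for bf16 (illustrative only, not part of the patch):

```python
import numpy as np

# Shapes from the fuseBroadcastMulIntoFullyConnected test; bf16 is replaced by
# float32 purely for illustration.
x = np.random.rand(1, 10368).astype(np.float32)      # fully_connected input
w = np.full((256, 10368), 2.0, dtype=np.float32)     # filter constant (dense<2.0>)
m = np.full((32, 1, 256), 3.0, dtype=np.float32)     # broadcasting multiplier (dense<3.0>)

fc = x @ w.T      # tfl.fully_connected output: shape (1, 256)
out = fc * m      # tfl.mul broadcasts the lhs up to (32, 1, 256)

# Folding `m` into `w` is only shape-preserving when the multiplier does not
# broadcast the fully_connected result; here a folded filter would need one
# 256x10368 copy per leading slice of `m`, i.e. 32x the weights and 32x the
# matmul work, so the pass now leaves the tfl.mul in place.
assert fc.shape == (1, 256) and out.shape == (32, 1, 256)
```
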
- TFL::populateWithGenerated(ctx, &patterns); + TFL::populateWithGenerated(ctx, patterns); patterns.insert, FuseFullyConnectedAndReluX, diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc index 1c6550bc902..ca30b2f1fcf 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc @@ -144,7 +144,7 @@ void PostQuantizePass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); auto* ctx = func.getContext(); - TFL::populateWithGenerated(ctx, &patterns); + TFL::populateWithGenerated(ctx, patterns); patterns.insert>(ctx); applyPatternsAndFoldGreedily(func, patterns); diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td index 326b6b23398..5cfdb4b982d 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td @@ -40,42 +40,6 @@ def : Pat< (TF_MulOp $t, (TF_MulOp:$mul (TF_RsqrtOp (TF_AddOp $v, (TF_ConstOp $variance_epsilon))), $gamma)), (TF_SubOp $beta, (TF_MulOp $m, $mul)))>; -// Converts tf.FusedBatchNormV3 into a sequence of more primitive arithmetic -// operations. Specifically, performs the following calculation: -// -// (x - mean) * scale / sqrt(variance + epsilon) + offset -// -// Let multiplier = scale / sqrt(variance + epsilon), -// to compute -// (x - mean) * scale / sqrt(variance + epsilon) + offset, -// is then to compute -// (x * multiplier) + (offset - mean * multiplier). - -def : Pattern< - (TF_FusedBatchNormV3Op:$root - $x, $scale, $offset, $mean, $variance, - F32Attr:$epsilon, $exponential_avg_factor, - $data_format, FalseBoolAttr:$is_training), - [(TF_AddOp - (TF_MulOp - $x, - (TF_MulOp:$multiplier - $scale, - (TF_RsqrtOp - (TF_AddOp $variance, - (TF_ConstOp $epsilon))))), - (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))), - // We already guaranteed that the last five results have no use so it does - // not matter what value we provide here for replacement. - /*batch_mean=*/(replaceWithValue $x), - /*batch_variance=*/(replaceWithValue $x), - /*reserve_space_1=*/(replaceWithValue $x), - /*reserve_space_2=*/(replaceWithValue $x), - /*reserve_space_3=*/(replaceWithValue $x)], - [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2), - (HasNoUseOf:$root__3), (HasNoUseOf:$root__4), - (HasNoUseOf:$root__5)]>; - class TFi32 : ConstantAttr(v)>; // Matmul without transpose on b to matmul with explicit transpose op and diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 2b118d0b810..ecca3d38deb 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -765,6 +765,278 @@ struct ConvertFusedBatchNorm : public OpRewritePattern { } }; +// The below pattern is equivalent to the DRR rule below +// The checks are dependent on generated values, so we can't add +// the checks on intermediate values, ideally we should find equivalent +// checks that guarantees the resultant ops are valid. +// The extra conditions are the broadcasting conditions. +// +// The pattern lower FusedBatchNormV3 to arithmetic ops. 
+// Specifically, performs the following calculation: +// +// (x - mean) * scale / sqrt(variance + epsilon) + offset +// +// Let multiplier = scale / sqrt(variance + epsilon), +// to compute +// (x - mean) * scale / sqrt(variance + epsilon) + offset, +// is then to compute +// (x * multiplier) + (offset - mean * multiplier). +// +// def : Pattern< +// (TF_FusedBatchNormV3Op:$root +// $x, $scale, $offset, $mean, $variance, +// F32Attr:$epsilon, $exponential_avg_factor, +// $data_format, FalseBoolAttr:$is_training), +// [(TF_AddOp +// (TF_MulOp +// $x, +// (TF_MulOp:$multiplier +// $scale, +// (TF_RsqrtOp +// (TF_AddOp $variance, +// (TF_ConstOp $epsilon))))), +// (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))), +// // We already guaranteed that the last five results have no use so it does +// // not matter what value we provide here for replacement. +// /*batch_mean=*/(replaceWithValue $x), +// /*batch_variance=*/(replaceWithValue $x), +// /*reserve_space_1=*/(replaceWithValue $x), +// /*reserve_space_2=*/(replaceWithValue $x), +// /*reserve_space_3=*/(replaceWithValue $x)], +// [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2), +// (HasNoUseOf:$root__3), (HasNoUseOf:$root__4), +// (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>; + +struct FusedBatchNormV3Pat : public ::mlir::RewritePattern { + explicit FusedBatchNormV3Pat(::mlir::MLIRContext *context) + : ::mlir::RewritePattern( + "tf.FusedBatchNormV3", + {"tf.Add", "tf.Const", "tf.Mul", "tf.Rsqrt", "tf.Sub"}, 1, + context) {} + + ::mlir::LogicalResult matchAndRewrite( + ::mlir::Operation *fused_batch_norm, + ::mlir::PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used for creating ops + Operation::operand_range mean(fused_batch_norm->getOperands()); + ::mlir::FloatAttr exponential_avg_factor; + ::mlir::StringAttr data_format; + ::mlir::TF::FusedBatchNormV3Op root; + Operation::operand_range offset(fused_batch_norm->getOperands()); + Operation::operand_range x(fused_batch_norm->getOperands()); + Operation::operand_range scale(fused_batch_norm->getOperands()); + Operation::operand_range variance(fused_batch_norm->getOperands()); + ::mlir::FloatAttr epsilon; + ::mlir::BoolAttr is_training; + + // Match + auto fused_batch_norm_op = + dyn_cast_or_null<::mlir::TF::FusedBatchNormV3Op>(fused_batch_norm); + root = fused_batch_norm_op; + x = fused_batch_norm_op.getODSOperands(0); + scale = fused_batch_norm_op.getODSOperands(1); + offset = fused_batch_norm_op.getODSOperands(2); + mean = fused_batch_norm_op.getODSOperands(3); + variance = fused_batch_norm_op.getODSOperands(4); + { + epsilon = fused_batch_norm_op.getAttrOfType<::mlir::FloatAttr>("epsilon"); + if (!epsilon) + epsilon = rewriter.getFloatAttr(rewriter.getF32Type(), 0.0001f); + + if (!(((epsilon.isa<::mlir::FloatAttr>())) && + ((epsilon.cast<::mlir::FloatAttr>().getType().isF32())))) { + return rewriter.notifyMatchFailure( + fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { + diag << "op 'tf.FusedBatchNormV3' attribute 'epsilon' failed to " + "satisfy constraint: 32-bit float attribute"; + }); + } + } + { + exponential_avg_factor = + fused_batch_norm_op.getAttrOfType<::mlir::FloatAttr>( + "exponential_avg_factor"); + if (!exponential_avg_factor) + exponential_avg_factor = + rewriter.getFloatAttr(rewriter.getF32Type(), 1.0f); + } + { + data_format = + fused_batch_norm_op.getAttrOfType<::mlir::StringAttr>("data_format"); + if (!data_format) data_format = rewriter.getStringAttr("NHWC"); + } + { + is_training = + 
fused_batch_norm_op.getAttrOfType<::mlir::BoolAttr>("is_training"); + if (!is_training) is_training = rewriter.getBoolAttr(true); + + if (!((!is_training.getValue()))) { + return rewriter.notifyMatchFailure( + fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { + diag << "op 'tf.FusedBatchNormV3' attribute 'is_training' failed " + "to " + "satisfy constraint: FalseBoolAttr"; + }); + } + } + + if (!(((*root.getODSResults(1).begin()).use_empty()))) { + return rewriter.notifyMatchFailure( + fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { + diag << "entities '' failed to satisfy constraint: has no use"; + }); + } + + if (!(((*root.getODSResults(2).begin()).use_empty()))) { + return rewriter.notifyMatchFailure( + fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { + diag << "entities '' failed to satisfy constraint: has no use"; + }); + } + + if (!(((*root.getODSResults(3).begin()).use_empty()))) { + return rewriter.notifyMatchFailure( + fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { + diag << "entities '' failed to satisfy constraint: has no use"; + }); + } + + if (!(((*root.getODSResults(4).begin()).use_empty()))) { + return rewriter.notifyMatchFailure( + fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { + diag << "entities '' failed to satisfy constraint: has no use"; + }); + } + + if (!(((*root.getODSResults(5).begin()).use_empty()))) { + return rewriter.notifyMatchFailure( + fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { + diag << "entities '' failed to satisfy constraint: has no use"; + }); + } + // Rewrite + auto odsLoc = rewriter.getFusedLoc({fused_batch_norm->getLoc()}); + ::llvm::SmallVector<::mlir::Value, 4> replace_values; + ::mlir::TF::ConstOp epsilon_const_op; + { + epsilon_const_op = + rewriter.create<::mlir::TF::ConstOp>(odsLoc, + /*value=*/epsilon); + } + ::mlir::TF::AddOp add_op_1; + { + ::mlir::Value tblgen_value_0 = (*variance.begin()); + ::mlir::Value tblgen_value_1 = + (*epsilon_const_op.getODSResults(0).begin()); + add_op_1 = rewriter.create<::mlir::TF::AddOp>(odsLoc, + /*x=*/tblgen_value_0, + /*y=*/tblgen_value_1); + // We need to make sure the Add operands are broadcastable. + if (mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(add_op_1) + .value == LogicalResult::Failure) { + return failure(); + } + } + ::mlir::TF::RsqrtOp rsqrt_op; + { + ::mlir::SmallVector<::mlir::Value, 4> tblgen_values; + ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs; + tblgen_values.push_back((*add_op_1.getODSResults(0).begin())); + rsqrt_op = rewriter.create<::mlir::TF::RsqrtOp>(odsLoc, tblgen_values, + tblgen_attrs); + } + ::mlir::TF::MulOp multiplier; + { + ::mlir::Value tblgen_value_0 = (*scale.begin()); + ::mlir::Value tblgen_value_1 = (*rsqrt_op.getODSResults(0).begin()); + // We need to make sure the Add operands are broadcastable. + multiplier = rewriter.create<::mlir::TF::MulOp>(odsLoc, + /*x=*/tblgen_value_0, + /*y=*/tblgen_value_1); + if (mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(multiplier) + .value == LogicalResult::Failure) { + return failure(); + } + } + ::mlir::TF::MulOp mul_op_1; + { + ::mlir::Value tblgen_value_0 = (*x.begin()); + ::mlir::Value tblgen_value_1 = (*multiplier.getODSResults(0).begin()); + mul_op_1 = rewriter.create<::mlir::TF::MulOp>(odsLoc, + /*x=*/tblgen_value_0, + /*y=*/tblgen_value_1); + // We need to make sure the Mul operands are broadcastable. 
+ if (mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(mul_op_1) + .value == LogicalResult::Failure) { + return failure(); + } + } + ::mlir::TF::MulOp mul_op_2; + { + ::mlir::Value tblgen_value_0 = (*mean.begin()); + ::mlir::Value tblgen_value_1 = (*multiplier.getODSResults(0).begin()); + mul_op_2 = rewriter.create<::mlir::TF::MulOp>(odsLoc, + /*x=*/tblgen_value_0, + /*y=*/tblgen_value_1); + if (mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(mul_op_2) + .value == LogicalResult::Failure) { + return failure(); + } + } + ::mlir::TF::SubOp sub_op; + { + ::mlir::Value tblgen_value_0 = (*offset.begin()); + ::mlir::Value tblgen_value_1 = (*mul_op_2.getODSResults(0).begin()); + sub_op = rewriter.create<::mlir::TF::SubOp>(odsLoc, + /*x=*/tblgen_value_0, + /*y=*/tblgen_value_1); + if (mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(sub_op).value == + LogicalResult::Failure) { + return failure(); + } + } + ::mlir::TF::AddOp add_op_2; + { + ::mlir::SmallVector<::mlir::Value, 4> tblgen_values; + ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs; + tblgen_values.push_back((*mul_op_1.getODSResults(0).begin())); + tblgen_values.push_back((*sub_op.getODSResults(0).begin())); + ::mlir::SmallVector<::mlir::Type, 4> tblgen_types; + for (auto v : fused_batch_norm_op.getODSResults(0)) { + tblgen_types.push_back(v.getType()); + } + add_op_2 = rewriter.create<::mlir::TF::AddOp>( + odsLoc, tblgen_types, tblgen_values, tblgen_attrs); + // We need to make sure the Add operands are broadcastable. + if (mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(add_op_2) + .value == LogicalResult::Failure) { + return failure(); + } + } + for (auto v : + ::llvm::SmallVector<::mlir::Value, 4>{add_op_2.getODSResults(0)}) { + replace_values.push_back(v); + } + for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) { + replace_values.push_back(v); + } + for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) { + replace_values.push_back(v); + } + for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) { + replace_values.push_back(v); + } + for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) { + replace_values.push_back(v); + } + for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) { + replace_values.push_back(v); + } + rewriter.replaceOp(fused_batch_norm, replace_values); + return success(); + }; +}; + #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc" // Returns success if all the operations in the `op`'s regions including `op` @@ -927,11 +1199,11 @@ void PrepareTFPass::runOnFunction() { // This pattern will try to identify and optimize for dilated convolution. // e.g. Patterns like "SpaceToBatchND -> Conv2D -> BatchToSpaceND" will be // replaced with a single Conv op with dilation parameter. - patterns.insert, + patterns.insert, FusedBatchNormV3Pat, ConvertTFDilatedConvOp>(ctx); patterns.insert(ctx); - TFL::populateWithGenerated(ctx, &patterns); + TFL::populateWithGenerated(ctx, patterns); // TODO(karimnosseir): Split to separate pass probably after // deciding on long term plan for this optimization. // This will allow optimizing any TF_Mul->TF_Conv in the graph @@ -942,7 +1214,7 @@ void PrepareTFPass::runOnFunction() { // Load the generated pattern again, so new quantization pass-through // will be applied. 
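
For reference, the hand-written FusedBatchNormV3Pat above implements the same algebra as the DRR pattern removed from prepare_patterns.td: in inference mode (is_training = false) the op is rewritten as x * multiplier + (offset - mean * multiplier) with multiplier = scale / sqrt(variance + epsilon). A small NumPy check of that identity follows; the NHWC shapes are an assumed example, and this is a sketch of the math rather than the pass itself:

```python
import numpy as np

# Inference-mode FusedBatchNormV3 versus the decomposition emitted by the
# rewrite pattern; shapes are an assumed NHWC example with 3 channels.
x = np.random.rand(2, 4, 4, 3).astype(np.float32)
scale, offset, mean, variance = (np.random.rand(3).astype(np.float32) for _ in range(4))
epsilon = 1e-3

# Reference formula: (x - mean) * scale / sqrt(variance + epsilon) + offset
reference = (x - mean) * scale / np.sqrt(variance + epsilon) + offset

# Rewritten form: x * multiplier + (offset - mean * multiplier),
# with multiplier = scale * rsqrt(variance + epsilon).
multiplier = scale / np.sqrt(variance + epsilon)
rewritten = x * multiplier + (offset - mean * multiplier)

np.testing.assert_allclose(rewritten, reference, rtol=1e-5, atol=1e-6)
```
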
patterns.clear(); - TFL::populateWithGenerated(ctx, &patterns); + TFL::populateWithGenerated(ctx, patterns); if (unfold_batch_matmul_) { patterns.insert, TF::ConvertTFBatchMatMulOp>(ctx); diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc index ba25b5c897c..529e57780c3 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc @@ -84,7 +84,7 @@ void QuantizePass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); auto* ctx = func.getContext(); - TFL::populateWithGenerated(ctx, &patterns); + TFL::populateWithGenerated(ctx, patterns); patterns.insert( ctx, enable_numeric_verify, error_tolerance, enable_single_layer_verify); applyPatternsAndFoldGreedily(func, patterns); diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD index 31bce8d1bf6..47bff366311 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD @@ -20,8 +20,8 @@ tf_python_pybind_extension( "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/python:pybind11_lib", "//tensorflow/python:pybind11_status", + "@llvm-project//llvm:FileCheckLib", "@llvm-project//llvm:Support", - "@llvm-project//llvm:filecheck-lib", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:StandardOps", @@ -37,8 +37,8 @@ tf_python_pybind_extension( deps = [ "//tensorflow/python:pybind11_lib", "//tensorflow/python:pybind11_status", + "@llvm-project//llvm:FileCheckLib", "@llvm-project//llvm:Support", - "@llvm-project//llvm:filecheck-lib", "@pybind11", ], ) diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index e403a75d3b9..17410b4e5b2 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -74,8 +74,8 @@ tool_names = [ 'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt', 'xla-mlir-gpu-opt', 'xla-opt', - 'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_gpu_binary', 'xla-thunks-opt', - 'tfjs-opt' + 'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_kernel', 'tf_to_gpu_binary', + 'xla-thunks-opt', 'tfjs-opt' ] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 01f162cf8e5..1c740731acd 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -5,7 +5,7 @@ load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//third_party/mlir:tblgen.bzl", "gentbl") -load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_gen_op_wrapper_py", "tf_native_cc_binary") +load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_gen_op_wrapper_py") package( default_visibility = [":friends"], @@ -110,6 +110,8 @@ cc_library( deps = [ ":tensorflow_op_interfaces_inc_gen", ":tensorflow_structs", + "//tensorflow/core:framework", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", ], @@ -810,9 +812,8 @@ cc_library( ], deps = [ ":tensorflow", + ":tensorflow_op_interfaces", ":tensorflow_types", - 
"//tensorflow/core:framework", - "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", @@ -1046,6 +1047,7 @@ cc_library( "translate/import_model.h", ], deps = [ + ":convert_attr", ":convert_tensor", ":convert_type", ":dump_mlir_util", @@ -1171,7 +1173,6 @@ cc_library( cc_library( name = "export_tf_dialect_op", srcs = [ - "translate/derived_attr_populator.inc", "translate/export_tf_dialect_op.cc", ], hdrs = [ @@ -1181,6 +1182,7 @@ cc_library( ":convert_type", ":export_utils", ":tensorflow", + "//tensorflow/compiler/mlir:string_container_utils", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -1189,6 +1191,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", + "@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:IR", ], ) @@ -1262,6 +1265,24 @@ cc_library( ], ) +cc_library( + name = "convert_attr", + srcs = ["utils/convert_attr.cc"], + hdrs = ["utils/convert_attr.h"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":convert_tensor", + ":convert_type", + ":tensorflow_attributes", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", + "//tensorflow/stream_executor/lib", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "convert_type", srcs = ["utils/convert_type.cc"], @@ -1394,6 +1415,7 @@ cc_library( ":decode_constant_pass", ":eval_util", ":tensorflow", + ":tensorflow_traits", ":tensorflow_types", "//tensorflow/c:tf_status", "//tensorflow/c/eager:c_api", @@ -1545,38 +1567,6 @@ tf_cc_test( ], ) -tf_native_cc_binary( - name = "derived_attr_populator_gen", - srcs = [ - "translate/derived_attr_populator_gen.cc", - ], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//llvm:TableGen", - "@llvm-project//mlir:TableGen", - ], -) - -gentbl( - name = "derived_attr_populator_inc", - compatible_with = get_compatible_with_cloud(), - tbl_outs = [ - ("", "translate/derived_attr_populator.inc"), - ], - tblgen = ":derived_attr_populator_gen", - td_file = "ir/tf_ops.td", - td_srcs = [ - "@llvm-project//mlir:include/mlir/IR/OpBase.td", - "@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td", - "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", - "ir/tf_generated_ops.td", - "ir/tf_op_base.td", - "ir/tf_op_interfaces.td", - ], -) - filegroup( name = "tensorflow_optimize_td_files", srcs = [ @@ -1993,6 +1983,7 @@ cc_library( deps = [ ":convert_tensor", ":convert_type", + ":export_tf_dialect_op", ":export_utils", ":tensorflow", ":tensorflow_attributes", diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc index b25014fe09e..cdc9e33e368 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc @@ -18,21 +18,17 @@ limitations under the License. 
#include #include -#include "absl/strings/str_cat.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" #include "mlir/Analysis/CallGraph.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project @@ -43,8 +39,8 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/core/framework/resource_mgr.h" namespace mlir { namespace TF { @@ -231,48 +227,16 @@ BacktrackAnalysisInfo::BacktrackAnalysisInfo( backtracked_values_.push_back(backtrack_analysis.BacktrackValue(result)); } -namespace { - -//===----------------------------------------------------------------------===// -// ResourceAliasAnalysisInfo helper functions. -//===----------------------------------------------------------------------===// - -constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id"; - -// Returns if a VarHandleOp is anonymous, which means it always creates a new -// variable. -bool IsResourceHandleAnonymous(VarHandleOp handle) { - return handle.shared_name() == tensorflow::ResourceHandle::ANONYMOUS_NAME; -} - -// Returns a string unique identifier for a non-anonymous VarHandleOp. -std::string GetVarHandleStringId(VarHandleOp handle) { - auto device = handle.getAttrOfType("device"); - return absl::StrCat(handle.container().str(), "/", handle.shared_name().str(), - "/", device ? device.getValue().str() : std::string("")); -} - -// Finds a unique ID for a VarHandleOp's output. If it is anonymous, always -// creates a new ID; otherwise, tries to reuse the existing ID for the -// referenced variable if it exists, or creates a new one if not. -int64_t GetOrCreateIdForVarHandle(VarHandleOp handle, int64_t* next_id, - llvm::StringMap* name_id_map) { - // Always create a new ID for anonymous handle. - if (IsResourceHandleAnonymous(handle)) return (*next_id)++; - - auto name = GetVarHandleStringId(handle); - auto emplace_res = name_id_map->try_emplace(name, *next_id); - // New ID created, increment next_id. - if (emplace_res.second) ++(*next_id); - return emplace_res.first->second; -} - -} // namespace - //===----------------------------------------------------------------------===// // ResourceAliasAnalysisInfo //===----------------------------------------------------------------------===// +namespace { + +constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id"; + +} // namespace + constexpr int64_t ResourceAliasAnalysisInfo::kUnknownResourceId; // Constructs the analysis info by analyzing the given function. 
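
The helpers deleted here (IsResourceHandleAnonymous, GetVarHandleStringId, GetOrCreateIdForVarHandle) encode the ID scheme that now sits behind the GetResourceHandleValueAndId call used below: anonymous VarHandleOps always receive a fresh ID, while named handles are keyed by container/shared_name/device and reuse the first ID assigned to that key. A rough Python sketch of that behaviour, where the dict-based handle and the "<anonymous>" sentinel are stand-ins for illustration only:

```python
# Stand-in for tensorflow::ResourceHandle::ANONYMOUS_NAME; any sentinel works here.
ANONYMOUS_NAME = "<anonymous>"

def get_or_create_id(handle, state):
    """Assign a unique ID to a variable handle, mirroring the removed helper."""
    # Anonymous handles always create a new variable, so always mint a new ID.
    if handle["shared_name"] == ANONYMOUS_NAME:
        state["next_id"] += 1
        return state["next_id"] - 1
    # Named handles are keyed by container/shared_name/device and reuse the ID
    # of the first occurrence.
    key = "/".join([handle["container"], handle["shared_name"], handle.get("device", "")])
    if key not in state["ids"]:
        state["ids"][key] = state["next_id"]
        state["next_id"] += 1
    return state["ids"][key]

state = {"next_id": 0, "ids": {}}
v0 = get_or_create_id({"container": "", "shared_name": "v", "device": "CPU:0"}, state)
v1 = get_or_create_id({"container": "", "shared_name": "v", "device": "CPU:0"}, state)
a0 = get_or_create_id({"container": "", "shared_name": ANONYMOUS_NAME}, state)
a1 = get_or_create_id({"container": "", "shared_name": ANONYMOUS_NAME}, state)
assert v0 == v1      # the same named variable aliases to the same ID
assert a0 != a1      # anonymous handles never alias
```
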
@@ -338,13 +302,13 @@ ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo( } }); - llvm::StringMap var_handle_name_id_map; + llvm::SmallDenseMap resource_handle_id_map; func_op.walk([&](Operation* op) { - if (auto var_handle = dyn_cast(op)) { - AddValueUniqueIDMapping( - var_handle.resource(), - GetOrCreateIdForVarHandle(var_handle, &next_unique_id, - &var_handle_name_id_map)); + if (auto resource_alloc = dyn_cast(op)) { + ResourceHandleValueAndId resource = + resource_alloc.GetResourceHandleValueAndId(resource_handle_id_map, + next_unique_id); + AddValueUniqueIDMapping(resource.value, resource.id); } else if (llvm::isa(op)) { for (auto result : filter_resources(op->getResults())) PropagateInputToOutput(op->getOperand(result.getResultNumber()), diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index 4a2080c5951..4d2c237e9a0 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include "absl/strings/str_cat.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Optional.h" diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc index f97f59c67f8..32c51f2e2bd 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -512,7 +512,7 @@ Status MlirFunction::GetFunctionDef(tensorflow::FunctionDef** f) { return Status::OK(); } PassManager pm(func_.getContext()); - ::tensorflow::SetCrashReproducer(pm); + ::tensorflow::applyTensorflowAndCLOptions(pm); pm.addNestedPass(CreateFunctionalToExecutorDialectConversionPass()); pm.addPass(CreateBreakUpIslandsPass()); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc index 40cc2c99c27..746b34a018a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc @@ -15,8 +15,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" -#include "mlir/IR/Attributes.h" // from @llvm-project - namespace mlir { namespace TF { @@ -79,6 +77,14 @@ ShapeAttr ShapeAttr::get(mlir::MLIRContext* context, return Base::get(context, ArrayRef(), /*unranked=*/true); } +// Get or create a shape attribute. +ShapeAttr ShapeAttr::get(mlir::MLIRContext* context, ShapedType shaped_type) { + if (shaped_type.hasRank()) + return Base::get(context, shaped_type.getShape(), /*unranked=*/false); + + return Base::get(context, ArrayRef(), /*unranked=*/true); +} + llvm::Optional> ShapeAttr::getValue() const { if (hasRank()) return getShape(); return llvm::None; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h index 5a18b77ab5c..0927aefff68 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h @@ -20,6 +20,8 @@ limitations under the License. 
#include "llvm/ADT/StringRef.h" #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project namespace mlir { namespace TF { @@ -43,6 +45,9 @@ class ShapeAttr : public Attribute::AttrBase> shape); + // Get or create a shape attribute from a ShapedType type. + static ShapeAttr get(mlir::MLIRContext* context, ShapedType shaped_type); + llvm::Optional> getValue() const; bool hasRank() const; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 8d8e5d9896b..362c0081c77 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -29,6 +29,7 @@ limitations under the License. // Ops in this file are sorted alphabetically. include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" +include "mlir/Interfaces/InferTypeOpInterface.td" def TF_AbsOp : TF_Op<"Abs", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes the absolute value of a tensor."; @@ -223,7 +224,7 @@ retained with length 1. let verifier = [{ return Verify(*this); }]; } -def TF_AllToAllOp : TF_Op<"AllToAll", [NoSideEffect]> { +def TF_AllToAllOp : TF_Op<"AllToAll", [NoSideEffect, TF_NoConstantFold]> { let summary = "An Op to exchange data across TPU replicas."; let description = [{ @@ -296,6 +297,33 @@ Equivalent to np.angle. TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; } +def TF_AnonymousIteratorOp : TF_Op<"AnonymousIterator", []> { + let summary = "A container for an iterator resource."; + + let arguments = (ins + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + Res:$handle + ); +} + +def TF_AnonymousIteratorV2Op : TF_Op<"AnonymousIteratorV2", []> { + let summary = "A container for an iterator resource."; + + let arguments = (ins + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + Res:$handle, + TF_VariantTensor:$deleter + ); +} + def TF_AnonymousMemoryCacheOp : TF_Op<"AnonymousMemoryCache", []> { let summary = ""; @@ -307,6 +335,21 @@ def TF_AnonymousMemoryCacheOp : TF_Op<"AnonymousMemoryCache", []> { ); } +def TF_AnonymousMultiDeviceIteratorOp : TF_Op<"AnonymousMultiDeviceIterator", []> { + let summary = "A container for a multi device iterator resource."; + + let arguments = (ins + Confined]>:$devices, + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + Res:$handle, + TF_VariantTensor:$deleter + ); +} + def TF_AnonymousRandomSeedGeneratorOp : TF_Op<"AnonymousRandomSeedGenerator", []> { let summary = ""; @@ -2111,7 +2154,7 @@ of corresponding 3-element vectors is cross-multiplied independently. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>]> { +def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>, TF_NoConstantFold]> { let summary = "An Op to sum inputs across replicated TPU instances."; let description = [{ @@ -2484,6 +2527,17 @@ is the same, though it is cleaner to use `tf.io.decode_image`. 
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; } +def TF_DeleteIteratorOp : TF_Op<"DeleteIterator", []> { + let summary = "A container for an iterator resource."; + + let arguments = (ins + Arg:$handle, + TF_VariantTensor:$deleter + ); + + let results = (outs); +} + def TF_DeleteMemoryCacheOp : TF_Op<"DeleteMemoryCache", []> { let summary = ""; @@ -2495,6 +2549,20 @@ def TF_DeleteMemoryCacheOp : TF_Op<"DeleteMemoryCache", []> { let results = (outs); } +def TF_DeleteMultiDeviceIteratorOp : TF_Op<"DeleteMultiDeviceIterator", []> { + let summary = "A container for an iterator resource."; + + let arguments = (ins + Arg:$multi_device_iterator, + Arg, "", [TF_DatasetIteratorRead]>:$iterators, + TF_VariantTensor:$deleter + ); + + let results = (outs); + + TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<1>; +} + def TF_DeleteRandomSeedGeneratorOp : TF_Op<"DeleteRandomSeedGenerator", []> { let summary = ""; @@ -2718,6 +2786,19 @@ Computes the gradients of depthwise convolution with respect to the input. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; } +def TF_DeserializeIteratorOp : TF_Op<"DeserializeIterator", []> { + let summary = [{ +Converts the given variant tensor to an iterator and stores it in the given resource. + }]; + + let arguments = (ins + Arg:$resource_handle, + TF_VariantTensor:$serialized + ); + + let results = (outs); +} + def TF_DeviceIndexOp : TF_Op<"DeviceIndex", [NoSideEffect]> { let summary = "Return the index of device the op runs."; @@ -3242,8 +3323,7 @@ tf.math.equal(x, y) ==> array([True, True]) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"OpBuilder& builder, OperationState& result, Value x, " - "Value y, BoolAttr incompatible_shape_error"> + OpBuilder<"Value x, Value y, BoolAttr incompatible_shape_error"> ]; let verifier = [{ @@ -3389,8 +3469,7 @@ size 1. TF_DerivedOperandTypeAttr Tdim = TF_DerivedOperandTypeAttr<1>; let builders = [ - OpBuilder<"OpBuilder& builder, OperationState& result, Value condition, " - "Value dim"> + OpBuilder<"Value condition, Value dim"> ]; } @@ -3567,6 +3646,24 @@ Quantization is called fake since the output is still in floating point. }]; } +def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradient", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Compute gradients for a FakeQuantWithMinMaxArgs operation."; + + let arguments = (ins + F32Tensor:$gradients, + F32Tensor:$inputs, + + DefaultValuedAttr:$min, + DefaultValuedAttr:$max, + DefaultValuedAttr:$num_bits, + DefaultValuedAttr:$narrow_range + ); + + let results = (outs + F32Tensor:$backprops + ); +} + def TF_FakeQuantWithMinMaxVarsOp : TF_Op<"FakeQuantWithMinMaxVars", [NoSideEffect]> { let summary = [{ Fake-quantize the 'inputs' tensor of type float via global float scalars @@ -3617,6 +3714,26 @@ values. 
}]; } +def TF_FakeQuantWithMinMaxVarsGradientOp : TF_Op<"FakeQuantWithMinMaxVarsGradient", [NoSideEffect]> { + let summary = "Compute gradients for a FakeQuantWithMinMaxVars operation."; + + let arguments = (ins + F32Tensor:$gradients, + F32Tensor:$inputs, + F32Tensor:$min, + F32Tensor:$max, + + DefaultValuedAttr:$num_bits, + DefaultValuedAttr:$narrow_range + ); + + let results = (outs + F32Tensor:$backprops_wrt_input, + F32Tensor:$backprop_wrt_min, + F32Tensor:$backprop_wrt_max + ); +} + def TF_FakeQuantWithMinMaxVarsPerChannelOp : TF_Op<"FakeQuantWithMinMaxVarsPerChannel", [NoSideEffect]> { let summary = [{ Fake-quantize the 'inputs' tensor of type float via per-channel floats @@ -3711,9 +3828,9 @@ fill([2, 3], 9) ==> [[9, 9, 9] let hasFolder = 1; - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value dims, Value value" - >]; + let builders = [ + OpBuilder<"Value dims, Value value"> + ]; } def TF_FloorOp : TF_Op<"Floor", [NoSideEffect, SameOperandsAndResultType]> { @@ -3893,7 +4010,7 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. TF_Float32Tensor:$reserve_space_3, DefaultValuedAttr:$epsilon, - DefaultValuedAttr:$data_format, + DefaultValuedAttr, "NHWC">:$data_format, DefaultValuedAttr:$is_training ); @@ -3978,7 +4095,7 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. DefaultValuedAttr:$epsilon, DefaultValuedAttr:$exponential_avg_factor, - DefaultValuedAttr:$data_format, + DefaultValuedAttr, "NHWC">:$data_format, DefaultValuedAttr:$is_training ); @@ -4644,6 +4761,39 @@ tf.imag(input) ==> [4.75, 5.75] TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; } +def TF_InTopKV2Op : TF_Op<"InTopKV2", [NoSideEffect]> { + let summary = "Says whether the targets are in the top `K` predictions."; + + let description = [{ +This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the +prediction for the target class is among the top `k` predictions among +all predictions for example `i`. Note that the behavior of `InTopK` differs +from the `TopK` op in its handling of ties; if multiple classes have the +same prediction value and straddle the top-`k` boundary, all of those +classes are considered to be in the top `k`. + +More formally, let + + \\(predictions_i\\) be the predictions for all classes for example `i`, + \\(targets_i\\) be the target class for example `i`, + \\(out_i\\) be the output for example `i`, + +$$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$ + }]; + + let arguments = (ins + F32Tensor:$predictions, + TF_I32OrI64Tensor:$targets, + TF_I32OrI64Tensor:$k + ); + + let results = (outs + I1Tensor:$precision + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; +} + def TF_InfeedDequeueOp : TF_Op<"InfeedDequeue", []> { let summary = [{ A placeholder op for a value that will be fed into the computation. @@ -4895,11 +5045,58 @@ tf.math.is_nan(x) ==> [False, True, False, True, False] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_IteratorOp : TF_Op<"Iterator", []> { + let summary = "A container for an iterator resource."; + + let arguments = (ins + StrAttr:$shared_name, + StrAttr:$container, + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + Res:$handle + ); +} + +def TF_IteratorFromStringHandleOp : TF_Op<"IteratorFromStringHandle", []> { + let summary = [{ +Converts the given string representing a handle to an iterator to a resource. 
+ }]; + + let arguments = (ins + TF_StrTensor:$string_handle, + + DefaultValuedAttr:$output_types, + DefaultValuedAttr:$output_shapes + ); + + let results = (outs + Res:$resource_handle + ); +} + +def TF_IteratorFromStringHandleV2Op : TF_Op<"IteratorFromStringHandleV2", []> { + let summary = ""; + + let arguments = (ins + TF_StrTensor:$string_handle, + + DefaultValuedAttr:$output_types, + DefaultValuedAttr:$output_shapes + ); + + let results = (outs + Res:$resource_handle + ); +} + def TF_IteratorGetNextOp : TF_Op<"IteratorGetNext", []> { let summary = "Gets the next output from the given iterator ."; let arguments = (ins - TF_ResourceTensor:$iterator + Arg:$iterator ); let results = (outs @@ -4910,6 +5107,74 @@ def TF_IteratorGetNextOp : TF_Op<"IteratorGetNext", []> { TF_DerivedResultTypeListAttr output_types = TF_DerivedResultTypeListAttr<0>; } +def TF_IteratorGetNextAsOptionalOp : TF_Op<"IteratorGetNextAsOptional", []> { + let summary = [{ +Gets the next output from the given iterator as an Optional variant. + }]; + + let arguments = (ins + Arg:$iterator, + + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + TF_VariantTensor:$optional + ); +} + +def TF_IteratorGetNextSyncOp : TF_Op<"IteratorGetNextSync", []> { + let summary = "Gets the next output from the given iterator."; + + let description = [{ +This operation is a synchronous version IteratorGetNext. It should only be used +in situations where the iterator does not block the calling thread, or where +the calling thread is not a member of the thread pool used to execute parallel +operations (e.g. in eager mode). + }]; + + let arguments = (ins + Arg:$iterator + ); + + let results = (outs + Variadic:$components + ); + + TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>; + TF_DerivedResultTypeListAttr output_types = TF_DerivedResultTypeListAttr<0>; +} + +def TF_IteratorToStringHandleOp : TF_Op<"IteratorToStringHandle", []> { + let summary = [{ +Converts the given `resource_handle` representing an iterator to a string. + }]; + + let arguments = (ins + Arg:$resource_handle + ); + + let results = (outs + TF_StrTensor:$string_handle + ); +} + +def TF_IteratorV2Op : TF_Op<"IteratorV2", []> { + let summary = ""; + + let arguments = (ins + StrAttr:$shared_name, + StrAttr:$container, + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + Res:$handle + ); +} + def TF_L2LossOp : TF_Op<"L2Loss", [NoSideEffect]> { let summary = "L2 Loss."; @@ -5516,6 +5781,24 @@ A 2-D example: TF_DerivedResultTypeAttr out_type = TF_DerivedResultTypeAttr<0>; } +def TF_MakeIteratorOp : TF_Op<"MakeIterator", []> { + let summary = [{ +Makes a new iterator from the given `dataset` and stores it in `iterator`. + }]; + + let description = [{ +This operation may be executed multiple times. Each execution will reset the +iterator in `iterator` to the first element of `dataset`. + }]; + + let arguments = (ins + TF_VariantTensor:$dataset, + Arg:$iterator + ); + + let results = (outs); +} + def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> { let summary = [{ Multiply the matrix "a" by the matrix "b". @@ -6426,10 +6709,9 @@ retained with length 1. 
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value input, " - "Value reduction_indices, BoolAttr keep_dims" - >]; + let builders = [ + OpBuilder<"Value input, Value reduction_indices, BoolAttr keep_dims"> + ]; } def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInterface]> { @@ -6840,6 +7122,82 @@ Returns x * y element-wise. Returns zero if y is zero, even if x if infinite or TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_MultiDeviceIteratorOp : TF_Op<"MultiDeviceIterator", []> { + let summary = "Creates a MultiDeviceIterator resource."; + + let arguments = (ins + Confined]>:$devices, + StrAttr:$shared_name, + StrAttr:$container, + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + Res:$handle + ); +} + +def TF_MultiDeviceIteratorFromStringHandleOp : TF_Op<"MultiDeviceIteratorFromStringHandle", []> { + let summary = [{ +Generates a MultiDeviceIterator resource from its provided string handle. + }]; + + let arguments = (ins + TF_StrTensor:$string_handle, + + DefaultValuedAttr:$output_types, + DefaultValuedAttr:$output_shapes + ); + + let results = (outs + Res:$multi_device_iterator + ); +} + +def TF_MultiDeviceIteratorGetNextFromShardOp : TF_Op<"MultiDeviceIteratorGetNextFromShard", []> { + let summary = "Gets next element for the provided shard number."; + + let arguments = (ins + Arg:$multi_device_iterator, + TF_Int32Tensor:$shard_num, + TF_Int64Tensor:$incarnation_id + ); + + let results = (outs + Variadic:$components + ); + + TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>; + TF_DerivedResultTypeListAttr output_types = TF_DerivedResultTypeListAttr<0>; +} + +def TF_MultiDeviceIteratorInitOp : TF_Op<"MultiDeviceIteratorInit", []> { + let summary = "Initializes the multi device iterator with the given dataset."; + + let arguments = (ins + TF_VariantTensor:$dataset, + Arg:$multi_device_iterator, + TF_Int64Tensor:$max_buffer_size + ); + + let results = (outs + TF_Int64Tensor:$incarnation_id + ); +} + +def TF_MultiDeviceIteratorToStringHandleOp : TF_Op<"MultiDeviceIteratorToStringHandle", []> { + let summary = "Produces a string handle for the given MultiDeviceIterator."; + + let arguments = (ins + Arg:$multi_device_iterator + ); + + let results = (outs + TF_StrTensor:$string_handle + ); +} + def TF_MultinomialOp : TF_Op<"Multinomial", [TF_CannotDuplicate]> { let summary = "Draws samples from a multinomial distribution."; @@ -6972,6 +7330,34 @@ I.e., \\(y = -x\\). let hasCanonicalizer = 1; } +def TF_NextAfterOp : TF_Op<"NextAfter", [NoSideEffect, ResultsBroadcastableShape]>, + WithBroadcastableBinOpBuilder { + let summary = [{ +Returns the next representable value of `x1` in the direction of `x2`, element-wise. + }]; + + let description = [{ +This operation returns the same result as the C++ std::nextafter function. + +It can also return a subnormal number. + +@compatibility(cpp) +Equivalent to C++ std::nextafter function. +@end_compatibility + }]; + + let arguments = (ins + TF_F32OrF64Tensor:$x1, + TF_F32OrF64Tensor:$x2 + ); + + let results = (outs + TF_F32OrF64Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_NoOp : TF_Op<"NoOp", [NoSideEffect]> { let summary = "Does nothing. 
Only useful as a placeholder for control edges."; @@ -6980,6 +7366,49 @@ def TF_NoOp : TF_Op<"NoOp", [NoSideEffect]> { let results = (outs); } +def TF_NonMaxSuppressionV3Op : TF_Op<"NonMaxSuppressionV3", [NoSideEffect]> { + let summary = [{ +Greedily selects a subset of bounding boxes in descending order of score, + }]; + + let description = [{ +pruning away boxes that have high intersection-over-union (IOU) overlap +with previously selected boxes. Bounding boxes with score less than +`score_threshold` are removed. Bounding boxes are supplied as +[y1, x1, y2, x2], where (y1, x1) and (y2, x2) are the coordinates of any +diagonal pair of box corners and the coordinates can be provided as normalized +(i.e., lying in the interval [0, 1]) or absolute. Note that this algorithm +is agnostic to where the origin is in the coordinate system and more +generally is invariant to orthogonal transformations and translations +of the coordinate system; thus translating or reflections of the coordinate +system result in the same boxes being selected by the algorithm. +The output of this operation is a set of integers indexing into the input +collection of bounding boxes representing the selected boxes. The bounding +box coordinates corresponding to the selected indices can then be obtained +using the `tf.gather operation`. For example: + selected_indices = tf.image.non_max_suppression_v2( + boxes, scores, max_output_size, iou_threshold, score_threshold) + selected_boxes = tf.gather(boxes, selected_indices) + }]; + + let arguments = (ins + TensorOf<[TF_Float16, TF_Float32]>:$boxes, + TensorOf<[TF_Float16, TF_Float32]>:$scores, + TF_Int32Tensor:$max_output_size, + TensorOf<[TF_Float16, TF_Float32]>:$iou_threshold, + TensorOf<[TF_Float16, TF_Float32]>:$score_threshold + ); + + let results = (outs + TF_Int32Tensor:$selected_indices + ); + + TF_DerivedOperandTypeAttr T_threshold = TF_DerivedOperandTypeAttr<3>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let hasCanonicalizer = 1; +} + def TF_NonMaxSuppressionV4Op : TF_Op<"NonMaxSuppressionV4", [NoSideEffect]> { let summary = [{ Greedily selects a subset of bounding boxes in descending order of score, @@ -7096,8 +7525,7 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> { TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"OpBuilder& builder, OperationState& result, Value x, " - "Value y, BoolAttr incompatible_shape_error"> + OpBuilder<"Value x, Value y, BoolAttr incompatible_shape_error"> ]; let verifier = [{ @@ -7215,8 +7643,7 @@ output = TF_DerivedOperandTypeAttr TI = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"OpBuilder& builder, OperationState& result, Value indices, " - "Value depth, Value on_value, Value off_value, " + OpBuilder<"Value indices, Value depth, Value on_value, Value off_value, " "IntegerAttr axis"> ]; @@ -7225,6 +7652,44 @@ output = }]; } +def TF_OneShotIteratorOp : TF_Op<"OneShotIterator", []> { + let summary = [{ +Makes a "one-shot" iterator that can be iterated only once. + }]; + + let description = [{ +A one-shot iterator bundles the logic for defining the dataset and +the state of the iterator in a single op, which allows simple input +pipelines to be defined without an additional initialization +("MakeIterator") step. + +One-shot iterators have the following limitations: + +* They do not support parameterization: all logic for creating the underlying + dataset must be bundled in the `dataset_factory` function. +* They are not resettable. 
Once a one-shot iterator reaches the end of its + underlying dataset, subsequent "IteratorGetNext" operations on that + iterator will always produce an `OutOfRange` error. + +For greater flexibility, use "Iterator" and "MakeIterator" to define +an iterator using an arbitrary subgraph, which may capture tensors +(including fed values) as parameters, and which may be reset multiple +times by rerunning "MakeIterator". + }]; + + let arguments = (ins + SymbolRefAttr:$dataset_factory, + Confined]>:$output_types, + Confined]>:$output_shapes, + StrAttr:$container, + StrAttr:$shared_name + ); + + let results = (outs + Res:$handle + ); +} + def TF_OutfeedEnqueueTupleOp : TF_Op<"OutfeedEnqueueTuple", []> { let summary = "Enqueue multiple Tensor values on the computation outfeed."; @@ -8024,8 +8489,7 @@ tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"OpBuilder& builder, OperationState& result, Value start, " - "Value limit, Value delta"> + OpBuilder<"Value start, Value limit, Value delta"> ]; } @@ -8078,7 +8542,7 @@ of the tensor. Rank is also known as "order", "degree", or "ndims." TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let builders = [ - OpBuilder<"OpBuilder& builder, OperationState& result, Value input"> + OpBuilder<"Value input"> ]; let hasFolder = 1; @@ -8349,8 +8813,7 @@ reshape(t, []) ==> 7 TF_DerivedOperandTypeAttr Tshape = TF_DerivedOperandTypeAttr<1>; let builders = [ - OpBuilder< - "OpBuilder& builder, OperationState& result, Value tensor, Value shape"> + OpBuilder<"Value tensor, Value shape"> ]; let verifier = [{ @@ -10137,7 +10600,7 @@ def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect, ResultsBroadcastableShape]> TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; let builders = [ - OpBuilder<"OpBuilder& builder, OperationState& result, Value condition, Value e, Value t"> + OpBuilder<"Value condition, Value e, Value t"> ]; } @@ -10217,6 +10680,22 @@ Computes gradients for the scaled exponential linear (Selu) operation. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_SerializeIteratorOp : TF_Op<"SerializeIterator", []> { + let summary = [{ +Converts the given `resource_handle` representing an iterator to a variant tensor. + }]; + + let arguments = (ins + Arg:$resource_handle, + + DefaultValuedAttr:$external_state_policy + ); + + let results = (outs + TF_VariantTensor:$serialized + ); +} + def TF_ShapeOp : TF_Op<"Shape", [NoSideEffect]> { let summary = "Returns the shape of a tensor."; @@ -10247,7 +10726,7 @@ shape(t) ==> [2, 2, 3] }]; let builders = [ - OpBuilder<"OpBuilder& builder, OperationState& result, Value input, BoolAttr use32Bit"> + OpBuilder<"Value input, BoolAttr use32Bit"> ]; let hasFolder = 1; @@ -10694,7 +11173,7 @@ block size. TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>; } -def TF_SpaceToBatchNDOp : TF_Op<"SpaceToBatchND", [NoSideEffect]> { +def TF_SpaceToBatchNDOp : TF_Op<"SpaceToBatchND", [DeclareOpInterfaceMethods, NoSideEffect]> { let summary = "SpaceToBatch for N-D tensors of type T."; let description = [{ @@ -10723,6 +11202,12 @@ precise description. 
TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>; let verifier = [{ return Verify(*this); }]; + + let extraClassDeclaration = [{ + static bool isCompatibleReturnTypes(ArrayRef l, ArrayRef r) { + return ArraysAreCastCompatible(l, r); + } + }]; } def TF_SpaceToDepthOp : TF_Op<"SpaceToDepth", [NoSideEffect]> { @@ -11267,7 +11752,7 @@ def TF_StackV2Op : TF_Op<"StackV2", []> { ); } -def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect]> { +def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect, TF_NoConstantFold]> { let summary = "Draws samples from a multinomial distribution."; let arguments = (ins @@ -11285,7 +11770,82 @@ def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect]> { TF_DerivedResultTypeAttr output_dtype = TF_DerivedResultTypeAttr<0>; } -def TF_StatelessRandomNormalOp : TF_Op<"StatelessRandomNormal", [NoSideEffect]> { +def TF_StatelessParameterizedTruncatedNormalOp : TF_Op<"StatelessParameterizedTruncatedNormal", [NoSideEffect, TF_NoConstantFold]> { + let summary = ""; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TF_I32OrI64Tensor:$seed, + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$means, + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$stddevs, + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$minvals, + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$maxvals + ); + + let results = (outs + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$output + ); + + TF_DerivedOperandTypeAttr S = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; +} + +def TF_StatelessRandomBinomialOp : TF_Op<"StatelessRandomBinomial", [NoSideEffect, TF_NoConstantFold]> { + let summary = [{ +Outputs deterministic pseudorandom random numbers from a binomial distribution. + }]; + + let description = [{ +Outputs random values from a binomial distribution. + +The outputs are a deterministic function of `shape`, `seed`, `counts`, and `probs`. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TF_I32OrI64Tensor:$seed, + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$counts, + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$probs + ); + + let results = (outs + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$output + ); + + TF_DerivedOperandTypeAttr S = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; + TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + +def TF_StatelessRandomGammaV2Op : TF_Op<"StatelessRandomGammaV2", [NoSideEffect, TF_NoConstantFold]> { + let summary = [{ +Outputs deterministic pseudorandom random numbers from a gamma distribution. + }]; + + let description = [{ +Outputs random values from a gamma distribution. + +The outputs are a deterministic function of `shape`, `seed`, and `alpha`. 
+ }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TF_I32OrI64Tensor:$seed, + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$alpha + ); + + let results = (outs + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; +} + +def TF_StatelessRandomNormalOp : TF_Op<"StatelessRandomNormal", [NoSideEffect, TF_NoConstantFold]> { let summary = [{ Outputs deterministic pseudorandom values from a normal distribution. }]; @@ -11310,7 +11870,34 @@ The outputs are a deterministic function of `shape` and `seed`. TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; } -def TF_StatelessRandomUniformOp : TF_Op<"StatelessRandomUniform", [NoSideEffect]> { +def TF_StatelessRandomPoissonOp : TF_Op<"StatelessRandomPoisson", [NoSideEffect, TF_NoConstantFold]> { + let summary = [{ +Outputs deterministic pseudorandom random numbers from a Poisson distribution. + }]; + + let description = [{ +Outputs random values from a Poisson distribution. + +The outputs are a deterministic function of `shape`, `seed`, and `lam`. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TF_I32OrI64Tensor:$seed, + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$lam + ); + + let results = (outs + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; + TF_DerivedOperandTypeAttr Rtype = TF_DerivedOperandTypeAttr<2>; +} + +def TF_StatelessRandomUniformOp : TF_Op<"StatelessRandomUniform", [NoSideEffect, TF_NoConstantFold]> { let summary = [{ Outputs deterministic pseudorandom random values from a uniform distribution. }]; @@ -11336,7 +11923,32 @@ The outputs are a deterministic function of `shape` and `seed`. TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; } -def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect]> { +def TF_StatelessRandomUniformFullIntOp : TF_Op<"StatelessRandomUniformFullInt", [NoSideEffect, TF_NoConstantFold]> { + let summary = [{ +Outputs deterministic pseudorandom random integers from a uniform distribution. + }]; + + let description = [{ +The generated values are uniform integers covering the whole range of `dtype`. + +The outputs are a deterministic function of `shape` and `seed`. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$seed + ); + + let results = (outs + TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + +def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect, TF_NoConstantFold]> { let summary = [{ Outputs deterministic pseudorandom random integers from a uniform distribution. 
}]; @@ -11363,7 +11975,7 @@ The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxv TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; } -def TF_StatelessTruncatedNormalOp : TF_Op<"StatelessTruncatedNormal", [NoSideEffect]> { +def TF_StatelessTruncatedNormalOp : TF_Op<"StatelessTruncatedNormal", [NoSideEffect, TF_NoConstantFold]> { let summary = [{ Outputs deterministic pseudorandom values from a truncated normal distribution. }]; @@ -11711,10 +12323,9 @@ retained with length 1. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value input, " - "Value reduction_indices, BoolAttr keep_dims" - >]; + let builders = [ + OpBuilder<"Value input, Value reduction_indices, BoolAttr keep_dims"> + ]; } def TF_SymbolicGradientOp : TF_Op<"SymbolicGradient", [NoSideEffect]> { @@ -11789,6 +12400,30 @@ For internal use only. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_TPUEmbeddingActivationsOp : TF_Op<"TPUEmbeddingActivations", [NoSideEffect]> { + let summary = "An op enabling differentiation of TPU Embeddings."; + + let description = [{ +This op simply returns its first input, which is assumed to have been sliced +from the Tensors returned by TPUEmbeddingDequeueActivations. The presence of +this op, and its first argument being a trainable Variable, enables automatic +differentiation of graphs containing embeddings via the TPU Embedding Python +libraries. + }]; + + let arguments = (ins + TF_Float32Tensor:$embedding_variable, + TF_Float32Tensor:$sliced_activations, + + Confined]>:$table_id, + Confined]>:$lookup_id + ); + + let results = (outs + TF_Float32Tensor:$output + ); +} + def TF_TPUExecuteOp : TF_Op<"TPUExecute", []> { let summary = "Op that loads and executes a TPU program on a TPU device."; @@ -12666,10 +13301,8 @@ On GPU, if an out of bound index is found, the index is ignored. let verifier = [{ return Verify(*this); }]; let builders = [ - OpBuilder< - "OpBuilder& builder, OperationState& result, " - "Value tensor, Value indices, Value updates", - [{build(builder, result, tensor.getType(), tensor, indices, updates);}] + OpBuilder<"Value tensor, Value indices, Value updates", + [{build($_builder, $_state, tensor.getType(), tensor, indices, updates);}] > ]; } @@ -12813,8 +13446,7 @@ The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: TF_DerivedOperandTypeAttr Tperm = TF_DerivedOperandTypeAttr<1>; let builders = [ - OpBuilder< - "OpBuilder& builder, OperationState& result, Value x, Value perm"> + OpBuilder<"Value x, Value perm"> ]; let verifier = [{ @@ -13518,6 +14150,30 @@ Handling of out-of-bounds slice indices is implementation-defined. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_XlaEinsumOp : TF_Op<"XlaEinsum", [NoSideEffect]> { + let summary = [{ +An op which supports basic einsum op with 2 inputs and 1 output. + }]; + + let description = [{ +This op has better TPU performance since it doesn't have explicitly reshape and +transpose operations as tf.einsum does. 
+ }]; + + let arguments = (ins + TensorOf<[TF_Bfloat16, TF_Complex64, TF_Float32]>:$a, + TensorOf<[TF_Bfloat16, TF_Complex64, TF_Float32]>:$b, + + StrAttr:$equation + ); + + let results = (outs + TensorOf<[TF_Bfloat16, TF_Complex64, TF_Float32]>:$product + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_XlaGatherOp : TF_Op<"XlaGather", [NoSideEffect]> { let summary = "Wraps the XLA Gather operator documented at"; @@ -14110,4 +14766,4 @@ execution the transfer corresponds to.}]>:$dynamic_key, let results = (outs); TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; -} +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index 3c23958339c..15c0d7b10f7 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -73,6 +73,9 @@ def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">; // certain state around within their implementations. def TF_CannotDuplicate : NativeOpTrait<"TF::CannotDuplicate">; +// Trait to indicate an operation cannot be constant folded. +def TF_NoConstantFold : NativeOpTrait<"TF::NoConstantFold">; + // Coefficient wise binary operation with implicit broadcasting support, for // example tf.Sub operation. def TF_CwiseBinary : NativeOpTrait<"TF::CwiseBinary">; @@ -112,6 +115,7 @@ def TF_SummaryResource : TF_ResourceBase<"Summary">; def TF_LookupTableResource : TF_ResourceBase<"LookupTable">; def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">; def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">; +def TF_DatasetIteratorResource : TF_ResourceBase<"DatasetIterator">; def TF_VariableRead : MemRead; def TF_StackRead : MemRead; @@ -119,6 +123,7 @@ def TF_TensorArrayRead : MemRead; def TF_LookupTableRead : MemRead; def TF_DatasetSeedGeneratorRead : MemRead; def TF_DatasetMemoryCacheRead : MemRead; +def TF_DatasetIteratorRead : MemRead; def TF_VariableWrite : MemWrite; def TF_StackWrite : MemWrite; @@ -127,6 +132,7 @@ def TF_SummaryWrite : MemWrite; def TF_LookupTableWrite : MemWrite; def TF_DatasetSeedGeneratorWrite : MemWrite; def TF_DatasetMemoryCacheWrite : MemWrite; +def TF_DatasetIteratorWrite : MemWrite; def TF_VariableAlloc : MemAlloc; def TF_StackAlloc : MemAlloc; @@ -135,12 +141,14 @@ def TF_SummaryAlloc : MemAlloc; def TF_LookupTableAlloc : MemAlloc; def TF_DatasetSeedGeneratorAlloc : MemAlloc; def TF_DatasetMemoryCacheAlloc : MemAlloc; +def TF_DatasetIteratorAlloc : MemAlloc; def TF_StackFree : MemFree; def TF_TensorArrayFree : MemFree; def TF_SummaryFree : MemFree; def TF_DatasetSeedGeneratorFree : MemFree; def TF_DatasetMemoryCacheFree : MemFree; +def TF_DatasetIteratorFree : MemFree; //===----------------------------------------------------------------------===// // TensorFlow op definitions @@ -524,7 +532,7 @@ class TF_DerivedResultShapeListAttr : DerivedAttr< // A derived attribute that returns the shape of the first result type. def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType", "return (*getOperation()->result_type_begin()).cast();", - [{ TypeAttr::get($_self) }]>; + [{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>; // A derived attribute that returns the element type of the tensor held by a // named resource-type operand or result. 
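The derived-attribute change above swaps TypeAttr::get($_self) for the new ShapeAttr overload, so a derived result shape is materialized as a TF shape attribute rather than a full type attribute. A rough, illustrative C++ equivalent of what the generated accessor and attribute materialization do together, assuming the MLIR APIs of this vintage; the helper name is made up for the sketch.

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"

// Illustrative only: materialize the derived shape of an op's first result as
// a TF ShapeAttr. The new ShapeAttr::get overload preserves the rank when the
// type is ranked and produces an unranked attribute otherwise.
mlir::TF::ShapeAttr GetDerivedResultShape(mlir::Operation* op) {
  auto shaped_type = (*op->result_type_begin()).cast<mlir::ShapedType>();
  return mlir::TF::ShapeAttr::get(op->getContext(), shaped_type);
}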
@@ -544,7 +552,7 @@ class TF_DerivedOperandOrResultHandleShapeAttr : DerivedAttr< " .cast();\n" "assert(!resource_type.getSubtypes().empty() && \"unknown shape\");\n" "return resource_type.getSubtypes().begin()->cast();", - [{ TypeAttr::get($_self) }]>; + [{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>; def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> { let returnType = "Type"; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h index 1eb5c89f0fc..3a6a9336a24 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h @@ -16,10 +16,17 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_ +#include + +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/StringRef.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" +#include "tensorflow/core/framework/resource_mgr.h" namespace mlir { namespace TF { @@ -49,8 +56,80 @@ struct ContractionFusion { SmallVector additional_attributes; }; +//===----------------------------------------------------------------------===// +// TensorFlow Resource Handles. +//===----------------------------------------------------------------------===// + +inline bool IsResourceHandleAnonymous(StringRef name) { + return name == ::tensorflow::ResourceHandle::ANONYMOUS_NAME; +} + +// Helper struct representing an identifier for a resource handle. For resource +// handles created explicitly and shared across resource allocator ops, +// `container`, `name`, and `device` can be set. If an resource handle is tied +// to an instance of an operation (e.g. TensorFlow runtime operation caching), +// `op` can be set instead. +struct ResourceHandle { + ResourceHandle(StringRef container, StringRef name, StringRef device, + Operation* op) + : container(container), name(name), device(device), op(op) {} + + bool operator==(const ResourceHandle& rhs) const { + return container == rhs.container && name == rhs.name && + device == rhs.device && op == rhs.op; + } + + // Make ResourceHandle hashable. + friend ::llvm::hash_code hash_value(const ResourceHandle& resource_handle); + + std::string container; + std::string name; + std::string device; + Operation* op = nullptr; +}; + +// Make ResourceHandle hashable. +inline ::llvm::hash_code hash_value(const ResourceHandle& resource_handle) { + return ::llvm::hash_combine(resource_handle.container, resource_handle.name, + resource_handle.device, resource_handle.op); +} + +// Helper struct holding a resource handle value and unique id associated to the +// resource handle. 
+struct ResourceHandleValueAndId { + ResourceHandleValueAndId(Value value, int64_t id) : value(value), id(id) {} + + Value value; + int64_t id = -1; +}; + #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h.inc" } // namespace TF } // namespace mlir +namespace llvm { +template <> +struct DenseMapInfo { + static mlir::TF::ResourceHandle getEmptyKey() { + return {/*container=*/"", /*name=*/"", /*device=*/"", + /*op=*/DenseMapInfo::getEmptyKey()}; + } + + static mlir::TF::ResourceHandle getTombstoneKey() { + return {/*container=*/"", /*name=*/"", /*device=*/"", + /*op=*/DenseMapInfo::getTombstoneKey()}; + } + + static unsigned getHashValue( + const mlir::TF::ResourceHandle& resource_handle) { + return mlir::TF::hash_value(resource_handle); + } + + static bool isEqual(const mlir::TF::ResourceHandle& lhs, + const mlir::TF::ResourceHandle& rhs) { + return lhs == rhs; + } +}; +} // namespace llvm + #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td index 3c41c04a0d6..1ed30c89a77 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td @@ -125,4 +125,27 @@ def TF_ContractionFusableInterface : OpInterface<"ContractionFusableInterface"> ]; } +//===----------------------------------------------------------------------===// +// TensorFlow Resource Handle Interfaces. +//===----------------------------------------------------------------------===// + +def TF_ResourceHandleAllocatorInterface : OpInterface<"ResourceHandleAllocatorInterface"> { + let description = [{ + A resource handle allocator operation is one that creates a resource handle, + or looks up and reuses an existing resource handle. + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/[{Returns the resource handle value and unique id associated with + the resource handle. If a resource handle is reused, then an + existing id will be returned.}], + /*retTy=*/"ResourceHandleValueAndId", + /*methodName=*/"GetResourceHandleValueAndId", + /*args=*/(ins "llvm::SmallDenseMap&":$resource_handle_id_map, + "int64_t&":$next_id) + >, + ]; +} + #endif // TF_OP_INTERFACES diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h index 2755a62a3c9..9ebd59007e3 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h @@ -113,8 +113,7 @@ class TensorFlowDialect : public Dialect { // same interface. 
template void addOperations() { - (void)std::initializer_list{ - 0, (addOperation(AbstractOperation::get(*this)), 0)...}; + Dialect::addOperations(); } using ConstantFoldHook = LogicalResult (*)(Operation *, ArrayRef, diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 544f07f7075..67ad0fc4e70 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -787,7 +787,7 @@ This operation holds the metadata common to operations of a `tpu.replicate()` co let results = (outs); } -def TF_VarHandleOp : TF_Op<"VarHandleOp", []> { +def TF_VarHandleOp : TF_Op<"VarHandleOp", [TF_ResourceHandleAllocatorInterface]> { let summary = "Creates a handle to a Variable resource from its name."; let description = [{ @@ -816,6 +816,13 @@ Example: TF_DerivedOperandOrResultHandleTypeAttr<"resource">; TF_DerivedOperandOrResultHandleShapeAttr shape = TF_DerivedOperandOrResultHandleShapeAttr<"resource">; + + let extraClassDeclaration = [{ + // TF_ResourceHandleAllocatorInterface: + ResourceHandleValueAndId GetResourceHandleValueAndId( + llvm::SmallDenseMap &resource_handle_id_map, + int64_t &next_id); + }]; } // Multiple variadic operands with different sizes are not supported by the diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index c1a26392273..8742a0e2b71 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -27,6 +27,7 @@ limitations under the License. #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" @@ -1179,6 +1180,17 @@ static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) { // SpaceToBatchNDOp //===----------------------------------------------------------------------===// +int64_t SpaceToBatchNDBlockRank(const TensorType block_shape_type, + const TensorType paddings_type) { + if (block_shape_type.hasStaticShape()) { + return block_shape_type.getShape()[0]; + } else if (paddings_type.hasStaticShape()) { + return paddings_type.getShape()[0]; + } else { + return -1; + } +} + static LogicalResult Verify(SpaceToBatchNDOp op) { const auto input_type = op.input().getType().cast(); const auto block_shape_type = op.block_shape().getType().cast(); @@ -1209,12 +1221,8 @@ static LogicalResult Verify(SpaceToBatchNDOp op) { << "requires block_shape.shape[0] must equal paddings.shape[0]"; } - int64_t block_rank = -1; - if (block_shape_type.hasStaticShape()) { - block_rank = block_shape_type.getShape()[0]; - } else if (paddings_type.hasStaticShape()) { - block_rank = paddings_type.getShape()[0]; - } + const int64_t block_rank = + SpaceToBatchNDBlockRank(block_shape_type, paddings_type); // Further checks require block_rank to be known. if (block_rank == -1) { @@ -1282,7 +1290,68 @@ static LogicalResult Verify(SpaceToBatchNDOp op) { return success(); } -// TODO(b/157475606): Add SpaceToBatchNDOp::inferReturnTypes +// Infers returned rank if possible. Further, infers returned dimension sizes +// when possible. For all dimensions sizes to be inferred, the arguments +// block_shape and paddings must be constant. 
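The shape inference added below follows the SpaceToBatchND contract: the output batch is the input batch multiplied by the product of block_shape, each blocked spatial dimension becomes its padded size divided by the corresponding block_shape entry, and trailing dimensions are copied through unchanged. A small standalone sketch of that arithmetic with a worked example; the function and parameter names are illustrative, not part of the patch.

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Computes the SpaceToBatchND result shape from fully static inputs.
// input_shape is [batch, spatial..., remaining...]; block_shape has block_rank
// entries; paddings has block_rank (low, high) pairs.
std::vector<int64_t> SpaceToBatchNDShape(
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& block_shape,
    const std::vector<std::pair<int64_t, int64_t>>& paddings) {
  std::vector<int64_t> result = input_shape;  // Trailing dims copied through.
  int64_t batch = input_shape[0];
  for (std::size_t i = 0; i < block_shape.size(); ++i) {
    int64_t padded =
        input_shape[i + 1] + paddings[i].first + paddings[i].second;
    result[i + 1] = padded / block_shape[i];  // Must divide evenly.
    batch *= block_shape[i];
  }
  result[0] = batch;
  return result;
}

// Example: input [1, 4, 4, 3], block_shape [2, 2], zero paddings
// -> output [4, 2, 2, 3].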
+LogicalResult SpaceToBatchNDOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + const Value input = operands[0]; + const Value block_shape_val = operands[1]; + const Value paddings_val = operands[2]; + const auto input_type = input.getType().cast(); + const auto block_shape_type = block_shape_val.getType().cast(); + const auto paddings_type = paddings_val.getType().cast(); + + // The return is unranked when the input is unranked. + if (!input_type.hasRank()) { + inferredReturnTypes.assign( + {UnrankedTensorType::get(input_type.getElementType())}); + return success(); + } + + const int64_t input_rank = input_type.getRank(); + const ArrayRef input_shape = input_type.getShape(); + const int64_t block_rank = + SpaceToBatchNDBlockRank(block_shape_type, paddings_type); + SmallVector return_shape(input_rank, ShapedType::kDynamicSize); + + // The return has all dimension sizes unknown when block_rank is unknown. + if (block_rank == -1) { + inferredReturnTypes.assign( + {RankedTensorType::get(return_shape, input_type.getElementType())}); + return success(); + } + + // The return preserves the remaining dimensions after blocked dimensions. + for (uint64_t i = 1 + block_rank; i < input_rank; ++i) { + return_shape[i] = input_shape[i]; + } + + // The rest of the dimension sizes can be calculated when block_shape and + // paddings arguments are constant. + ElementsAttr block_shape_attr; + ElementsAttr paddings_attr; + if (matchPattern(block_shape_val, m_Constant(&block_shape_attr)) && + matchPattern(paddings_val, m_Constant(&paddings_attr))) { + int64_t return_batch = input_shape[0]; + for (uint64_t i = 0; i < block_rank; ++i) { + int64_t paddings_sum = + paddings_attr.getValue({i, 0}).cast().getInt() + + paddings_attr.getValue({i, 1}).cast().getInt(); + int64_t block_shape_i = + block_shape_attr.getValue({i}).cast().getInt(); + return_batch *= block_shape_i; + return_shape[1 + i] = (paddings_sum + input_shape[i + 1]) / block_shape_i; + } + return_shape[0] = return_batch; + } + + inferredReturnTypes.assign( + {RankedTensorType::get(return_shape, input_type.getElementType())}); + return success(); +} //===----------------------------------------------------------------------===// // SparseSoftmaxCrossEntropyWithLogitsOp @@ -1378,7 +1447,8 @@ static LogicalResult Verify(SplitVOp op) { if (!split_sizes_type) return success(); if (split_sizes_type.getRank() != 1 || - split_sizes_type.getDimSize(0) != op.getNumResults()) + (split_sizes_type.getDimSize(0) != ShapedType::kDynamicSize && + split_sizes_type.getDimSize(0) != op.getNumResults())) return op.emitOpError("split sizes should be a 1D tensor of ") << op.getNumResults() << " elements"; @@ -2213,6 +2283,45 @@ void TruncateDivOp::getCanonicalizationPatterns( results.insert(context); } +//===----------------------------------------------------------------------===// +// NonMaxSuppressionV3Op +//===----------------------------------------------------------------------===// + +namespace { + +// Canonicalize NonMaxSuppressionV3Op to NonMaxSuppressionV4Op. 
+class NMSV3ToNMSV4Op : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(NonMaxSuppressionV3Op nms_op, + PatternRewriter &rewriter) const override { + if (nms_op.getNumOperands() != 5) { + return failure(); + } + SmallVector new_result_types; + new_result_types.push_back(nms_op.getType()); + auto input_ty = nms_op.getType().template cast(); + // corresponds to the second result type of nmsv4 + RankedTensorType valid_output_type = + RankedTensorType::get({}, input_ty.getElementType()); + new_result_types.push_back(valid_output_type); + + auto nmsv4 = rewriter.create( + nms_op.getLoc(), new_result_types, nms_op.boxes(), nms_op.scores(), + nms_op.max_output_size(), nms_op.iou_threshold(), + nms_op.score_threshold()); + // Cannot replace the NMSv3 Op with NMSv4 since the outputs between the + // two are different (v4 expects two output values vs v3 requires only one. + nms_op.replaceAllUsesWith(nmsv4.getResult(0)); + return success(); + } +}; +} // namespace. + +void NonMaxSuppressionV3Op::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // UnpackOp //===----------------------------------------------------------------------===// @@ -2280,6 +2389,27 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) { return success(); } +//===----------------------------------------------------------------------===// +// VarHandleOp +//===----------------------------------------------------------------------===// + +ResourceHandleValueAndId VarHandleOp::GetResourceHandleValueAndId( + llvm::SmallDenseMap &resource_handle_id_map, + int64_t &next_id) { + // Always create a new ID for anonymous handle. + if (IsResourceHandleAnonymous(shared_name())) return {resource(), next_id++}; + + llvm::StringRef device; + if (auto device_attr = getAttrOfType("device")) + device = device_attr.getValue(); + + ResourceHandle handle(container(), shared_name(), device, /*op=*/nullptr); + auto emplace_res = resource_handle_id_map.try_emplace(handle, next_id); + // New ID created, increment next_id. + if (emplace_res.second) ++next_id; + return {resource(), emplace_res.first->second}; +} + //===----------------------------------------------------------------------===// // VarIsInitializedOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h index 7219fce7067..3c8ec1d38af 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h @@ -53,6 +53,10 @@ struct DatasetMemoryCache StringRef getName() final { return "DatasetMemoryCache"; } }; +struct DatasetIterator : ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "DatasetIterator"; } +}; + } // namespace ResourceEffects } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h index 412bf113a0f..aef3c538bc8 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h @@ -124,6 +124,10 @@ class CannotDuplicate : public TraitBase { } }; +// Trait to indicate an operation cannot be constant folded. 
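The NoConstantFold trait defined below carries no behavior of its own; folding hooks are expected to query it and bail out, which is what the DontFoldNoConstantFold test further down exercises. A minimal sketch of such a guard, assuming the trait sits in the mlir::OpTrait::TF namespace alongside the neighboring traits in this header; the actual fold hook lives elsewhere in the dialect and is not shown in this patch.

#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"

// Illustrative guard for a constant-folding hook: ops that opt out via the
// NoConstantFold trait are left untouched.
mlir::LogicalResult MaybeFold(mlir::Operation* op) {
  if (op->hasTrait<mlir::OpTrait::TF::NoConstantFold>())
    return mlir::failure();
  // ... attempt folding for the remaining ops ...
  return mlir::success();
}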
+template +class NoConstantFold : public TraitBase {}; + // Coefficient-wise binary operation with implicit broadcasting support, for // example tf.Sub operation. template diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc index 50f034e8ba1..86369b993be 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc @@ -360,6 +360,16 @@ bool AreCastCompatible(ArrayRef types) { return true; } +bool ArraysAreCastCompatible(ArrayRef lhs, ArrayRef rhs) { + if (lhs.size() != rhs.size()) return false; + for (auto pair : llvm::zip(lhs, rhs)) { + auto lhs_i = std::get<0>(pair); + auto rhs_i = std::get<1>(pair); + if (!AreCastCompatible({lhs_i, rhs_i})) return false; + } + return true; +} + // Assumes a function `GetDefaultTypeOf(ComposedType)` that returns the default // type for a composed type (such as a ref type or a type with subtypes). template diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h index 60a86f32920..1d3ca0c4a60 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h @@ -302,6 +302,10 @@ bool HasCompatibleElementTypes(Type lhs, Type rhs, // compatible. bool AreCastCompatible(ArrayRef types); +// Returns true if corresponding elements of lhs and rhs AreCastCompatible and +// lhs and rhs are the same length. +bool ArraysAreCastCompatible(ArrayRef lhs, ArrayRef rhs); + // If `ty` is a tensor type and its element type has subtypes, then returns a // new type of same shape but dropped subtypes for the element type. // Otherwise, if `ty` has subtypes, then returns corresponding type with dropped diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index ff90c6f4c5b..b0d96963088 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -1217,3 +1217,11 @@ func @testWhileDropOutputShapes(tensor<*xf32>) -> (tensor<*xf32>) { return %1 : tensor<*xf32> } + +// CHECK-LABEL: testNMSV3ToNMSV4 +func @testNMSV3ToNMSV4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor, %arg3: tensor) -> tensor<2xi32> { + %max_size = constant dense<2> : tensor + // CHECK: "tf.NonMaxSuppressionV4" + %0 = "tf.NonMaxSuppressionV3"(%arg0, %arg1, %max_size, %arg2, %arg3): (tensor<3x4xf32>, tensor<3xf32>, tensor, tensor, tensor) -> (tensor<2xi32>) + return %0 : tensor<2xi32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir index fff985efa6f..779065b94d5 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir @@ -502,3 +502,12 @@ func @fold_conv() -> tensor<1x520x520x1xf32> { // CHECK: tf.Const // CHECK-NOT: tf.DepthwiseConv2dNative } + +// CHECK-LABEL: DontFoldNoConstantFold +func @DontFoldNoConstantFold() -> tensor<8xf32> { + %0 = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<[2, 1]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: tf.StatelessRandomUniform + %2 = "tf.StatelessRandomUniform"(%0, %1) : (tensor<1xi32>, tensor<2xi32>) -> tensor<8xf32> + return %2 : tensor<8xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir 
b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir index e7430993755..c963147b855 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir @@ -7,6 +7,14 @@ func @einsum_basic(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> } +func @einsum_matmul(%arg0: tensor<7x9xf32>, %arg1: tensor<9x5xf32>) -> tensor<7x5xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ae,ed->ad"}: (tensor<7x9xf32>, tensor<9x5xf32>) -> tensor<7x5xf32> + return %0 : tensor<7x5xf32> + // CHECK-LABEL: einsum_matmul + // CHECK: %[[v0:.*]] = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<7x9xf32>, tensor<9x5xf32>) -> tensor<7x5xf32> + // CHECK: return %[[v0]] : tensor<7x5xf32> +} + func @einsum_broadcast(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<3x4x6xf32> { %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,km->ijm"}: (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32> return %0 : tensor<3x4x6xf32> @@ -14,18 +22,27 @@ func @einsum_broadcast(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tens // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32> } +func @einsum_broadcast4(%arg0: tensor<3x4x5x6x7xf32>, %arg1: tensor<7x8xf32>) -> tensor<3x4x5x6x8xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "abcdh,hg->abcdg"}: (tensor<3x4x5x6x7xf32>, tensor<7x8xf32>) -> tensor<3x4x5x6x8xf32> + return %0 : tensor<3x4x5x6x8xf32> + // CHECK-LABEL: einsum_broadcast4 + // CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<3x4x5x6x7xf32>, tensor<7x8xf32>) -> tensor<3x4x5x6x8xf32> +} + func @einsum_reducesum(%arg0: tensor<2x5x7xf32>, %arg1: tensor<5x2xf32>) -> tensor<5x7xf32> { %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "lbh,bl->bh"}: (tensor<2x5x7xf32>, tensor<5x2xf32>) -> tensor<5x7xf32> return %0 : tensor<5x7xf32> // CHECK-LABEL: einsum_reducesum // CHECK: %[[cst:.*]] = constant dense<[1, 2, 0]> : tensor<3xi32> - // CHECK: %[[cst_1:.*]] = constant dense<[5, 1, 2]> : tensor<3xi64> - // CHECK: %[[cst_2:.*]] = constant dense<2> : tensor<1xi32> + // CHECK: %[[cst_1:.*]] = constant dense<[5, 2, 1]> : tensor<3xi64> + // CHECK: %[[cst_2:.*]] = constant dense<[5, 7]> : tensor<2xi64> // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<3xi32>) -> tensor<5x7x2xf32> - // CHECK: %[[v1:.*]] = "tf.Reshape"(%arg1, %[[cst_1]]) : (tensor<5x2xf32>, tensor<3xi64>) -> tensor<5x1x2xf32> - // CHECK: %[[v2:.*]] = "tf.Mul"(%[[v0]], %[[v1]]) : (tensor<5x7x2xf32>, tensor<5x1x2xf32>) -> tensor<5x7x2xf32> - // CHECK: "tf.Sum"(%[[v2]], %[[cst_2]]) {keep_dims = false} : (tensor<5x7x2xf32>, tensor<1xi32>) -> tensor<5x7xf32> + // CHECK: %[[v1:.*]] = "tf.Reshape"(%arg1, %[[cst_1]]) : (tensor<5x2xf32>, tensor<3xi64>) -> tensor<5x2x1xf32> + // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<5x7x2xf32>, tensor<5x2x1xf32>) -> tensor<5x7x1xf32> + // CHECK: %[[v3:.*]] = "tf.Reshape"(%[[v2]], %[[cst_2]]) : (tensor<5x7x1xf32>, tensor<2xi64>) -> tensor<5x7xf32> + // CHECK: return %[[v3:.*]] : tensor<5x7xf32> } + func @einsum_transpose_matmul(%arg0: tensor<2x5x7xf32>, %arg1: tensor<5x3x2xf32>) -> tensor<5x3x7xf32> { %0 = 
"tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "lbh,bkl->bkh"}: (tensor<2x5x7xf32>, tensor<5x3x2xf32>) -> tensor<5x3x7xf32> return %0 : tensor<5x3x7xf32> @@ -88,12 +105,12 @@ func @einsum_transposereduceddim(%arg0: tensor<2x5x7xf32>, %arg1: tensor<2x5x3x7 %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bij,binj->bin"}: (tensor<2x5x7xf32>, tensor<2x5x3x7xf32>) -> tensor<2x5x3xf32> return %0 : tensor<2x5x3xf32> // CHECK-LABEL: einsum_transposereduceddim - // CHECK: %[[cst:.*]] = constant dense<[2, 5, 1, 7]> : tensor<4xi64> - // CHECK: %[[cst_1:.*]] = constant dense<[0, 1, 3, 2]> : tensor<4xi32> + // CHECK: %[[cst:.*]] = constant dense<[0, 1, 3, 2]> : tensor<4xi32> + // CHECK: %[[cst_1:.*]] = constant dense<[2, 5, 1, 7]> : tensor<4xi64> // CHECK: %[[cst_2:.*]] = constant dense<[2, 5, 3]> : tensor<3xi64> - // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x5x7xf32>, tensor<4xi64>) -> tensor<2x5x1x7xf32> - // CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_1]]) : (tensor<2x5x3x7xf32>, tensor<4xi32>) -> tensor<2x5x7x3xf32> - // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<2x5x1x7xf32>, tensor<2x5x7x3xf32>) -> tensor<2x5x1x3xf32> + // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg1, %[[cst]]) : (tensor<2x5x3x7xf32>, tensor<4xi32>) -> tensor<2x5x7x3xf32> + // CHECK: %[[v1:.*]] = "tf.Reshape"(%arg0, %[[cst_1]]) : (tensor<2x5x7xf32>, tensor<4xi64>) -> tensor<2x5x1x7xf32> + // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v1]], %[[v0]]) {adj_x = false, adj_y = false} : (tensor<2x5x1x7xf32>, tensor<2x5x7x3xf32>) -> tensor<2x5x1x3xf32> // CHECK: %[[v3:.*]] = "tf.Reshape"(%[[v2]], %[[cst_2]]) : (tensor<2x5x1x3xf32>, tensor<3xi64>) -> tensor<2x5x3xf32> // CHECK: return %[[v3]] : tensor<2x5x3xf32> } @@ -123,13 +140,26 @@ func @einsum_fourdtransposeall(%arg0: tensor<2x5x7x3xf32>, %arg1: tensor<2x11x7x // CHECK: return %[[v3]] : tensor<2x7x11x5xf32> } -func @einsum_no_match(%arg0: tensor<4x5xf32>, %arg1: tensor<5xf32>) -> tensor<4xf32> { - %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,j->i"}: (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32> +func @einsum_4d_1(%arg0: tensor<3x4x5x6xf32>, %arg1: tensor<3x7x5x6xf32>) -> tensor<3x5x4x7xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "jbki,jfki->jkbf"}: (tensor<3x4x5x6xf32>, tensor<3x7x5x6xf32>) -> tensor<3x5x4x7xf32> + return %0 : tensor<3x5x4x7xf32> + // CHECK-LABEL: einsum_4d_1 + // CHECK: %[[cst:.*]] = constant dense<[0, 2, 1, 3]> : tensor<4xi32> + // CHECK: %[[cst_1:.*]] = constant dense<[0, 2, 3, 1]> : tensor<4xi32> + // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst:.*]]) : (tensor<3x4x5x6xf32>, tensor<4xi32>) -> tensor<3x5x4x6xf32> + // CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_1]]) : (tensor<3x7x5x6xf32>, tensor<4xi32>) -> tensor<3x5x6x7xf32> + // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<3x5x4x6xf32>, tensor<3x5x6x7xf32>) -> tensor<3x5x4x7xf32> + // CHECK: return %[[v2]] : tensor<3x5x4x7xf32> +} + +func @einsum_no_match(%arg0: tensor<4x5x6xf32>, %arg1: tensor<5xf32>) -> tensor<4xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,j->i"}: (tensor<4x5x6xf32>, tensor<5xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> // CHECK-LABEL: einsum_no_match -// CHECK: %[[v0:.*]] = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,j->i"} : (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32> +// CHECK: %[[v0:.*]] = 
"tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,j->i"} : (tensor<4x5x6xf32>, tensor<5xf32>) -> tensor<4xf32> // CHECK: return %[[v0]] } + func @einsum_illegal_no_match(%arg0: tensor<4x5xf32>, %arg1: tensor<5xf32>) -> tensor<4xf32> { %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,?zw->kq->i"}: (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> @@ -137,10 +167,15 @@ func @einsum_illegal_no_match(%arg0: tensor<4x5xf32>, %arg1: tensor<5xf32>) -> t // CHECK: %[[v0:.*]] = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,?zw->kq->i"} : (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32> // CHECK: return %[[v0]] } -func @einsum_no_match5D(%arg0: tensor<4x5xf32>, %arg1: tensor<2x4x7x3x5xf32>) -> tensor<4xf32> { - %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,j->i"}: (tensor<4x5xf32>, tensor<2x4x7x3x5xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -// CHECK-LABEL: einsum_no_match5D -// CHECK: %[[v0:.*]] = "tf.Einsum" -// CHECK: return %[[v0]] + +func @batch_multilhs_einsum(%arg0: tensor<2x1x1x11xf32>, %arg1: tensor<2x11x2xf32>) -> tensor<2x1x1x2xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "BiNj,BjS->BiNS"} : (tensor<2x1x1x11xf32>, tensor<2x11x2xf32>) -> tensor<2x1x1x2xf32> + return %0 : tensor<2x1x1x2xf32> +// CHECK-LABEL: batch_multilhs_einsum +// CHECK: %[[cst:.*]] = constant dense<[2, 1, 11]> : tensor<3xi64> +// CHECK: %[[cst_1:.*]] = constant dense<[2, 1, 1, 2]> : tensor<4xi64> +// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x1x1x11xf32>, tensor<3xi64>) -> tensor<2x1x11xf32> +// CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%[[v0]], %arg1) {adj_x = false, adj_y = false} : (tensor<2x1x11xf32>, tensor<2x11x2xf32>) -> tensor<2x1x2xf32> +// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<2x1x2xf32>, tensor<4xi64>) -> tensor<2x1x1x2xf32> +// CHECK: return %[[v2]] : tensor<2x1x1x2xf32> } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index 67ff648aa2c..fb4b24037a4 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -86,10 +86,19 @@ func @mul_no_nan(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>) -> tensor<2x3xf32 return %0 : tensor<2x3xf32> } +// CHECK-LABEL: @is_inf +func @is_inf(%arg0: tensor<3x4xf32>) -> tensor<3x4xi1> { + // CHECK: %[[INF:.*]] = "tf.Const"() {value = dense<0x7F800000> : tensor} : () -> tensor + // CHECK: %[[ABS:.*]] = "tf.Abs"(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xf32> + // CHECK: %[[RESULT:.*]] = "tf.Equal"(%[[ABS]], %[[INF]]) {incompatible_shape_error = true} : (tensor<3x4xf32>, tensor) -> tensor<3x4xi1> + %0 = "tf.IsInf"(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xi1> + // CHECK: return %[[RESULT]] + return %0 : tensor<3x4xi1> +} + // CHECK-LABEL: @is_nan func @is_nan(%arg0: tensor<3x4xf32>) -> tensor<3x4xi1> { - // CHECK: %[[NAN:.*]] = "tf.Const"() {value = dense<0x7FC00000> : tensor} : () -> tensor - // CHECK: %[[RESULT:.*]] = "tf.Equal"(%arg0, %[[NAN]]) {incompatible_shape_error = true} : (tensor<3x4xf32>, tensor) -> tensor<3x4xi1> + // CHECK: %[[RESULT:.*]] = "tf.NotEqual"(%arg0, %arg0) {incompatible_shape_error = true} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xi1> %0 = "tf.IsNan"(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xi1> // CHECK: return %[[RESULT]] return %0 : tensor<3x4xi1> @@ -218,7 +227,7 @@ func @rsqrt_grad_unranked(%arg0: 
tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor< // %input has 1 batch dimension then 2 block dimensions then 1 remainder // dimension. // CHECK-LABEL: fourdim_SpaceToBatchND -func @fourdim_SpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2xi64>, %paddings: tensor<2x2xi64>) -> tensor<*xf32> { +func @fourdim_SpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2xi64>, %paddings: tensor<2x2xi64>) -> tensor { // CHECK-DAG: [[PAD00:%.+]] = "tf.Const"() {value = dense<0> : tensor<1x2xi64>} // CHECK-DAG: [[ZERO_I32:%.+]] = "tf.Const"() {value = dense<0> : tensor} // CHECK-DAG: [[ZERO_I64:%.+]] = "tf.Const"() {value = dense<0> : tensor} @@ -242,15 +251,15 @@ func @fourdim_SpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor< // CHECK-DAG: [[PERMUTED:%.+]] = "tf.Transpose"([[RESHAPED]], [[PERMUTATION]]) // CHECK-DAG: [[RESULT:%.+]] = "tf.Reshape"([[PERMUTED]], [[OUTPUT_SHAPE]]) // CHECK-DAG: return [[RESULT]] - %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<*xf32> - return %0 : tensor<*xf32> + %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor + return %0 : tensor } // %input has 1 batch dimension then 3 block dimensions then 2 remainder // dimensions. This checks only ops that are specific to the case with 3 block // dimension and 2 remainder dimensions. // CHECK-LABEL: sixdim_SpaceToBatchND -func @sixdim_SpaceToBatchND(%input: tensor<3x5x7x9x10x11xf32>, %block_shape: tensor<3xi64>, %paddings: tensor<3x2xi64>) -> tensor<*xf32> { +func @sixdim_SpaceToBatchND(%input: tensor<3x5x7x9x10x11xf32>, %block_shape: tensor<3xi64>, %paddings: tensor<3x2xi64>) -> tensor { // CHECK-DAG: [[PAD00:%.+]] = "tf.Const"() // CHECK-DAG: [[FULL_PADDINGS:%.+]] = "tf.ConcatV2"([[PAD00]], %arg2, [[PAD00]], [[PAD00]], {{.+}}) // CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 9, 10, 11]> : tensor<6xi64>} @@ -265,8 +274,8 @@ func @sixdim_SpaceToBatchND(%input: tensor<3x5x7x9x10x11xf32>, %block_shape: ten // CHECK-DAG: [[OUTPUT_BATCH_PART2:%.+]] = "tf.Mul"([[OUTPUT_BATCH_PART1]], [[BLOCK_SHAPE_SPLITS]]#1) // CHECK-DAG: [[OUTPUT_BATCH:%.+]] = "tf.Mul"([[OUTPUT_BATCH_PART2]], [[BLOCK_SHAPE_SPLITS]]#2) // CHECK-DAG: [[OUTPUT_SHAPE:%.+]] = "tf.ConcatV2"([[OUTPUT_BATCH]], [[OUTER_SHAPE_0]], [[OUTER_SHAPE_1]], [[OUTER_SHAPE_2]], [[PADDED_SHAPE_SPLITS]]#4, [[PADDED_SHAPE_SPLITS]]#5, {{.+}}) - %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x9x10x11xf32>, tensor<3xi64>, tensor<3x2xi64>) -> tensor<*xf32> - return %0 : tensor<*xf32> + %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x9x10x11xf32>, tensor<3xi64>, tensor<3x2xi64>) -> tensor + return %0 : tensor } func @fake_quant_with_min_max_args(%arg0 : tensor) -> tensor { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir index f37e25390cf..c8a6d5489c3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir @@ -138,17 +138,17 @@ func @op_string_operand_string_result(%arg0: tensor) -> tensor return %0 : tensor } -// Test that a tf.IfRegion op with a captured string operand is marked for outside compilation. 
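
The einsum tests above assert that tf.Einsum is lowered to Transpose/Reshape/BatchMatMulV2 sequences, including the updated einsum_reducesum expectations. As a sanity check of the algebra the tests encode (not of the pass itself), here is a small NumPy sketch using the same shapes as the tests:

```python
import numpy as np

# "ae,ed->ad" (einsum_matmul, 7x9 times 9x5) is a plain matrix multiply.
a = np.random.rand(7, 9)
b = np.random.rand(9, 5)
assert np.allclose(np.einsum("ae,ed->ad", a, b), a @ b)

# "lbh,bl->bh" (einsum_reducesum): transpose the lhs to (b, h, l), reshape the
# rhs to (b, l, 1), do a batched matmul, then drop the trailing unit dimension;
# this mirrors the Transpose/Reshape/BatchMatMulV2/Reshape CHECK lines.
lhs = np.random.rand(2, 5, 7)   # l=2, b=5, h=7
rhs = np.random.rand(5, 2)      # b=5, l=2
ref = np.einsum("lbh,bl->bh", lhs, rhs)
lowered = (lhs.transpose(1, 2, 0) @ rhs.reshape(5, 2, 1)).reshape(5, 7)
assert np.allclose(ref, lowered)
```
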
+// Test that operations inside tf.IfRegion op are correctly marked for outside
+// compilation.
 
-// CHECK-LABEL: func @if_region_captured_string
-func @if_region_captured_string(%arg0: tensor, %arg1: tensor) -> tensor {
+// CHECK-LABEL: func @ops_inside_tf_if_outside_compiled
+func @ops_inside_tf_if_outside_compiled(%arg0: tensor, %arg1: tensor) -> tensor {
   %0 = "tf_device.cluster"() ( {
-    // CHECK: "tf.Const"() {value = dense<1> : tensor}
-    // CHECK-NOT: _xla_outside_compilation
-    // CHECK: "tf.IfRegion"
-    // CHECK: "tf.StringToNumber"
-    // CHECK-NOT: _xla_outside_compilation
-    // CHECK: _xla_outside_compilation = "auto1", is_stateless = true
+    // CHECK: "tf.Const"() {value = dense<1> : tensor}
+    // CHECK-NOT: _xla_outside_compilation
+    // CHECK: "tf.IfRegion"
+    // CHECK: "tf.StringToNumber"
+    // CHECK-SAME: _xla_outside_compilation
     %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor
     %2 = "tf.IfRegion"(%arg0) ( {
       %3 = "tf.StringToNumber"(%arg1) {out_type = f32} : (tensor) -> tensor
@@ -163,7 +163,8 @@ func @if_region_captured_string(%arg0: tensor, %arg1: tensor) ->
   return %0 : tensor
 }
 
-// Test that ops with string results/operands inside a tf.IfRegion branch are marked for outside compilation.
+// Test that ops with string results/operands inside a tf.IfRegion branch are
+// marked for outside compilation.
 
 // CHECK-LABEL: func @if_region_string_op
 func @if_region_string_op(%arg0: tensor, %arg1: tensor) -> tensor {
@@ -191,7 +192,8 @@ func @if_region_string_op(%arg0: tensor, %arg1: tensor) -> tensor
 }
 
-// Test that ops with string results/operands inside a nested tf.IfRegion branch are marked for outside compilation.
+// Test that ops with string results/operands inside a nested tf.IfRegion branch
+// are marked for outside compilation.
 
 // CHECK-LABEL: func @nested_if_region_string_op
 func @nested_if_region_string_op(%arg0: tensor, %arg1: tensor) -> tensor {
@@ -231,16 +233,17 @@ func @nested_if_region_string_op(%arg0: tensor, %arg1: tensor) -> ten
   return %0 : tensor
 }
 
-// Test that a tf.WhileRegion op with a captured string operand is marked for outside compilation.
+// Test that ops inside tf.WhileRegion op are correctly marked for outside
+// compilation.
 
-// CHECK-LABEL: func @while_region_captured_string
-func @while_region_captured_string(%arg0: tensor, %arg1: tensor) -> tensor {
+// CHECK-LABEL: func @ops_inside_while_outside_compiled
+func @ops_inside_while_outside_compiled(%arg0: tensor, %arg1: tensor) -> tensor {
   %0 = "tf_device.cluster"() ( {
-    // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor}
+    // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor}
     // CHECK-NOT: _xla_outside_compilation
-    // CHECK: "tf.WhileRegion"
-    // CHECK: "tf.StringToNumber"
-    // CHECK: _xla_outside_compilation = "auto1", is_stateless = true
+    // CHECK: "tf.WhileRegion"
+    // CHECK: "tf.StringToNumber"
+    // CHECK-SAME: _xla_outside_compilation
     %1 = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor
     %2:2 = "tf.WhileRegion"(%1, %arg0) ( {
     ^bb0(%carg0: tensor, %carg1: tensor):
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir
index 3e6d4f37bac..0813ee8db90 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir
@@ -258,6 +258,19 @@ func @main(%arg0: tensor>>, %arg1: tensor) {
 
 // -----
 
+// Tests removal of dead local variables.
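
The updated outside-compilation tests above expect string-typed ops nested inside tf.IfRegion and tf.WhileRegion bodies (tf.StringToNumber here) to be tagged with _xla_outside_compilation, while TPU-friendly ops such as tf.Const stay untagged. A rough Python sketch of that marking decision only; this is a toy model, not the pass, and the attribute values it produces are made up:

```python
# An op is modeled as (name, operand_dtypes, result_dtypes); ops that touch
# string tensors cannot run on the TPU, so they are marked for outside
# compilation with a generated attribute value.
def mark_for_outside_compilation(ops):
    marked = {}
    next_index = 0
    for name, operand_dtypes, result_dtypes in ops:
        if "string" in operand_dtypes or "string" in result_dtypes:
            marked[name] = "auto%d" % next_index
            next_index += 1
    return marked

ops_in_cluster = [
    ("tf.Const", [], ["i32"]),
    ("tf.StringToNumber", ["string"], ["f32"]),
]
assert mark_for_outside_compilation(ops_in_cluster) == {"tf.StringToNumber": "auto0"}
```
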
+ +// CHECK-LABEL: func @main +func @main(%arg0: tensor<2xf32>) { + // CHECK-NOT: tf.MlirLocalVarOp + // CHECK-NOT: tf.AssignVariableOp + %0 = "tf.MlirLocalVarOp"() : () -> tensor>> + "tf.AssignVariableOp"(%0, %arg0) : (tensor>>, tensor<2xf32>) -> () + return +} + +// ----- + // Tests first read of one resource is used as a value to write to another // resource. @@ -272,6 +285,26 @@ func @main(%arg0: tensor>>, %arg1: tensor) -> tensor<2xf32> { + // CHECK-NOT: tf.MlirLocalVarOp + // CHECK-NOT: tf.AssignVariableOp + %0 = "tf.MlirLocalVarOp"() : () -> tensor>> + %1 = "tf._SomeOp"() : () -> tensor<2xf32> + "tf.AssignVariableOp"(%0, %1) : (tensor>>, tensor<2xf32>) -> () + %2 = "tf.PartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @callee} : (tensor>>) -> tensor<2xf32> + return %2 : tensor<2xf32> +} +func @callee(%arg0: tensor>>) -> tensor<2xf32> attributes {sym_visibility = "private"} { + %0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + + // ----- // Tests main function with multiple blocks. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir index b65e88c589a..0c4dc77cf69 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tensor_array_ops_decomposition.mlir @@ -54,6 +54,22 @@ func @main() -> tensor { // ----- +// Test inferring shape from the first scatter. + +// CHECK-LABEL: func @main +func @main() -> tensor { + %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + // CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<*>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor, tensor) + %indices = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + %values = "tf.Const"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %write = "tf.TensorArrayScatterV3"(%ta#0, %indices, %values, %ta#1) : (tensor, tensor<2xi32>, tensor<2x3xf32>, tensor) -> tensor + %size_out = "tf.TensorArraySizeV3"(%ta#0, %write) : (tensor, tensor) -> tensor + return %size_out : tensor +} + +// ----- + // Test tensor array concat and split. 
// CHECK-LABEL: func @main @@ -259,6 +275,13 @@ func @main() -> () { // CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor>>) -> tensor<5x3xf32> // CHECK: "tf.Slice"(%[[READ]], %read = "tf.TensorArrayReadV3"(%1, %index, %ta#1) : (tensor, tensor, tensor) -> tensor<3xf32> + // CHECK: %[[READ_GVAR1:.*]] = "tf.ReadVariableOp"(%[[GVAR1]]) + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ_GVAR1]], + // CHECK: "tf.AssignVariableOp"(%[[GVAR1]], %[[UPDATE]]) + %const = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %elem = "tf._SomeOp"() : () -> tensor<3xf32> + %grad:2 = "tf.TensorArrayGradV3"(%ta#0, %ta#1) {source = "a"} : (tensor, tensor) -> (tensor, tensor) + %gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const, %elem, %grad#1) : (tensor, tensor, tensor<3xf32>, tensor) -> tensor return } // CHECK: func @then_branch(%[[TARG0:.*]]: tensor>>, %[[TARG1:.*]]: tensor>>, %[[TARG2:.*]]: tensor>>) @@ -412,6 +435,32 @@ func @callee() -> tensor attributes {sym_visibility = "public"} { // ----- +// CHECK-LABEL: func @main +func @main() -> () { + // CHECK: "tf.PartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> tensor<*xf32> + %call = "tf.PartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> (tensor<*xf32>) + return +} +func @callee() -> (tensor<*xf32>) attributes {sym_visibility = "private"} { + %size = "tf.Const"() {value = dense<5> : tensor} : () -> tensor + // CHECK: %[[LOCAL_VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor>> + %ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<*>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor) -> (tensor>>, tensor) + %index = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + %value = "tf.Const"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf32>} : () -> tensor<3xf32> + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice" + // CHECK: "tf.AssignVariableOp"(%[[LOCAL_VAR]], %[[UPDATE]]) : (tensor>>, tensor<5x3xf32>) -> () + %flow = "tf.TensorArrayWriteV3"(%ta#0, %index, %value, %ta#1) : (tensor>>, tensor, tensor<3xf32>, tensor) -> tensor + // CHECK: %[[SLICE:.*]] = "tf.Slice" + // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<3> : tensor<1xi32>} + // CHECK: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) + %val = "tf.TensorArrayReadV3"(%ta#0, %index, %ta#1) : (tensor>>, tensor, tensor) -> tensor<*xf32> + // CHECK: %[[CAST:.*]] = tensor_cast %[[ELEM]] : tensor<3xf32> to tensor<*xf32> + // CHECK: return %[[CAST]] : tensor<*xf32> + return %val : tensor<*xf32> +} + +// ----- + // Test the pass reports failure on unknown size. 
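
The tensor-array decomposition tests above, and the tensor-list tests that follow, replace TensorArray and TensorList values with a dense buffer held in a local variable plus explicit index arithmetic: writes become dynamic-update-slices into the buffer and reads become slices. A minimal NumPy sketch of that representation, assuming the size-5 array with element shape [3] that the scatter test above infers; the helper names are invented:

```python
import numpy as np

# The 5x3 buffer stands in for the local variable created by the pass.
buffer = np.zeros((5, 3), dtype=np.float32)

def ta_write(buf, index, value):    # ~ TensorArrayWriteV3 -> XlaDynamicUpdateSlice
    out = buf.copy()
    out[index] = value
    return out

def ta_read(buf, index):            # ~ TensorArrayReadV3 -> Slice + Reshape
    return buf[index:index + 1].reshape(3)

buffer = ta_write(buffer, 2, np.array([1.0, 2.0, 3.0], dtype=np.float32))
assert np.allclose(ta_read(buffer, 2), [1.0, 2.0, 3.0])
```
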
func @main(%arg0: tensor) -> () { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir index 92cb0458bf9..09a2dcb6713 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir @@ -34,7 +34,7 @@ func @main() -> (tensor, tensor) { // CHECK-NEXT: %[[SCALAR_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} // CHECK-NEXT: %[[LENGTH:.*]] = "tf.Reshape"(%[[NEW_SIZE]], %[[SCALAR_SHAPE]]) %length = "tf.TensorListLength"(%push) : (tensor>>) -> tensor - // CHECK-NEXT: return %[[ELEM]], %[[LENGTH]] : tensor, tensor + // CHECK-NEXT: return %[[ELEM]], %[[LENGTH]] : tensor, tensor return %pop#1, %length: tensor, tensor } @@ -81,7 +81,7 @@ func @main(%arg0: tensor) -> (tensor, tensor<10xf32>, tensor) { %stack = "tf.TensorListStack"(%addn2, %elem_shape) : (tensor>>, tensor<0xi32>) -> tensor<10xf32> // CHECK-NEXT: %[[LEN:.*]] = "tf.Const"() {value = dense<10> : tensor} : () -> tensor %length = "tf.TensorListLength"(%addn2) : (tensor>>) -> tensor - // CHECK-NEXT: return %[[ELEM]], %[[ADDN2]], %[[LEN]] : tensor, tensor<10xf32>, tensor + // CHECK-NEXT: return %[[ELEM]], %[[ADDN2]], %[[LEN]] : tensor, tensor<10xf32>, tensor return %get, %stack, %length : tensor, tensor<10xf32>, tensor } @@ -104,7 +104,7 @@ func @main(%arg0: tensor, %arg1: tensor<10xf32>) -> tensor { // CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> // CHECK-NEXT: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor %get = "tf.TensorListGetItem"(%tl, %arg0, %elem_shape) : (tensor>>, tensor, tensor<0xi32>) -> tensor - // CHECK-NEXT: return %[[ELEM]] : tensor + // CHECK-NEXT: return %[[ELEM]] : tensor return %get: tensor } @@ -118,7 +118,7 @@ func @main(%arg0: tensor<10x8x9xf32>) -> tensor<2xi64> { %tl = "tf.TensorListFromTensor"(%arg0, %elem_shape) : (tensor<10x8x9xf32>, tensor<2xi32>) -> tensor>> // CHECK: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[8, 9]> : tensor<2xi64>} : () -> tensor<2xi64> %shape = "tf.TensorListElementShape"(%tl) : (tensor>>) -> tensor<2xi64> - // CHECK-NEXT: return %[[SHAPE]] : tensor<2xi64> + // CHECK-NEXT: return %[[SHAPE]] : tensor<2xi64> return %shape: tensor<2xi64> } @@ -135,7 +135,7 @@ func @main(%arg0: tensor<10x8x9xf32>, %arg1: tensor<3xi32>) -> tensor<3x8x9xf32> // CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK: %[[GATHER:.*]] = "tf.GatherV2"(%[[BUFFER]], %[[ARG1]], %[[AXIS]]) : (tensor<10x8x9xf32>, tensor<3xi32>, tensor) -> tensor<3x8x9xf32> %gather = "tf.TensorListGather"(%tl, %arg1, %elem_shape) : (tensor>>, tensor<3xi32>, tensor<2xi32>) -> tensor<3x8x9xf32> - // CHECK-NEXT: return %[[GATHER]] : tensor<3x8x9xf32> + // CHECK-NEXT: return %[[GATHER]] : tensor<3x8x9xf32> return %gather: tensor<3x8x9xf32> } @@ -173,7 +173,7 @@ func @main() -> () { : (tensor>>, tensor) -> (tensor>>, tensor) // CHECK: "tf.Slice" %pop:2 = "tf.TensorListPopBack"(%1#0, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) - // CHECK-NOT: tf.EmptyTensorList + // CHECK-NOT: tf.TensorListPopBack // CHECK: return return } @@ -242,7 +242,7 @@ func @if_else(%arg0: tensor>>) -> tensor, tensor<0xi32>) -> tensor // CHECK-NOT: "tf.TensorListPopBack" %pop:2 = "tf.TensorListPopBack"(%arg0, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) - // 
CHECK: return %[[COPY]], %[[SUB]] + // CHECK: return %[[COPY]], %[[SUB]] return %pop#0 : tensor>> } @@ -289,7 +289,7 @@ func @branch_1(%arg0: tensor>>) -> tensor, tensor<0xi32>) -> tensor // CHECK-NOT: "tf.TensorListPopBack" %pop:2 = "tf.TensorListPopBack"(%arg0, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) - // CHECK: return %[[COPY]], %[[SUB]] + // CHECK: return %[[COPY]], %[[SUB]] return %pop#0 : tensor>> } // CHECK: func @branch_2(%[[EARG0:.*]]: tensor<10xf32>, %[[EARG1:.*]]: tensor<1xi32>) -> (tensor<10xf32>, tensor<1xi32>) @@ -305,9 +305,145 @@ func @branch_2(%arg0: tensor>>) -> tensor, tensor<0xi32>) -> tensor // CHECK-NOT: "tf.TensorListPopBack" %pop:2 = "tf.TensorListPopBack"(%arg0, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) - // CHECK: return %[[COPY]], %[[SUB]] + // CHECK: return %[[COPY]], %[[SUB]] return %pop#0 : tensor>> } + +// ----- + +// CHECK-LABEL: func @main +func @main() -> tensor { + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> + %size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NOT: tf.EmptyTensorList + %tl = "tf.EmptyTensorList"(%elem_shape, %size) : (tensor<0xi32>, tensor) -> tensor>> + %while_op:2 = "tf.WhileRegion"(%tl, %size) ( { + // CHECK: ^bb0(%[[CARG0:.*]]: tensor<10xf32>, %[[CARG1:.*]]: tensor, %[[CARG2:.*]]: tensor<1xi32>): + ^bb0(%arg0: tensor>>, %arg1: tensor): // no predecessors + // CHECK: %[[PRED:.*]] = "tf._SomeOp"() + // CHECK: "tf.Yield"(%[[PRED]]) + %pred = "tf._SomeOp"() : () -> tensor + "tf.Yield"(%pred) : (tensor) -> () + }, { + // CHECK: ^bb0(%[[CARG0:.*]]: tensor<10xf32>, %[[CARG1:.*]]: tensor, %[[CARG2:.*]]: tensor<1xi32>): + ^bb0(%arg0: tensor>>, %arg1: tensor): // no predecessors + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[CARG1]], %[[CST]]) + // CHECK: %[[ELEM:.*]] = "tf._SomeOp"() : () -> tensor + %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %sub = "tf.Sub"(%arg1, %cst) : (tensor, tensor) -> tensor + %elem = "tf._SomeOp"() : () -> tensor + // CHECK-NOT: "tf.TensorListPushBack" + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[CARG0]] + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[CARG2]], %[[ONE]]) + // CHECK-NOT: "tf.TensorListPushBack" + // CHECK: "tf.Yield"(%[[UPDATE]], %[[SUB]], %[[ADD]]) + // CHECK: }) {is_stateless = false} + %push = "tf.TensorListPushBack"(%arg0, %elem) : (tensor>>, tensor) -> tensor>> + "tf.Yield"(%push, %sub) : (tensor>>, tensor) -> () + }) {is_stateless = false} : (tensor>>, tensor) -> (tensor>>, tensor) + // CHECK: "tf.Slice" + // CHECK-NOT: tf.TensorListPopBack + %pop:2 = "tf.TensorListPopBack"(%while_op#0, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) + // CHECK: return + return %pop#1 : tensor +} +// ----- + +// CHECK-LABEL: func @main +func @main(%arg0: tensor) -> () { + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor) -> tensor>> + // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor} + // CHECK: %[[ZERO_F32:.*]] = "tf.Cast"(%[[ZERO]]) + // CHECK: %[[MAX_SIZE:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>} + // CHECK: %[[BUFFER:.*]] = "tf.BroadcastTo"(%[[ZERO_F32]], %[[MAX_SIZE]]) + // CHECK: %[[BUFFER_SIZE:.*]] = 
"tf.Const"() {value = dense<0> : tensor<1xi32>} + // CHECK-NOT: tf.EmptyTensorList + %if_op = "tf.IfRegion"(%arg0) ({ + %elem = "tf._SomeOp"() : () -> tensor + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice" + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[BUFFER_SIZE]], %[[ONE]]) + // CHECK-NOT: "tf.TensorListPushBack" + %push = "tf.TensorListPushBack"(%tl, %elem) : (tensor>>, tensor) -> tensor>> + "tf.Yield" (%push) : (tensor>>) -> () + }, { + // CHECK: %[[COPY:.*]] = "tf.Identity"(%[[BUFFER]]) + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} + // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BUFFER_SIZE]], %[[ONE]]) + // CHECK: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} + // CHECK: %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]]) + // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} + // CHECK: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) + // CHECK-NOT: "tf.TensorListPopBack" + %pop:2 = "tf.TensorListPopBack"(%tl, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) + // CHECK: "tf.Yield"(%[[COPY]], %[[SUB]]) + "tf.Yield" (%pop#0) : (tensor>>) -> () + }) + {is_stateless = false} + : (tensor) -> tensor>> + // CHECK: "tf.Slice" + // CHECK-NOT: tf.TensorListPopBack + %pop:2 = "tf.TensorListPopBack"(%if_op, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) + return +} + +// ----- + +// CHECK-LABEL: func @main +func @main(%arg0: tensor) -> () { + %elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> + %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor} + // CHECK: %[[ZERO_F32:.*]] = "tf.Cast"(%[[ZERO]]) + // CHECK: %[[MAX_SIZE:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>} + // CHECK: %[[BUFFER:.*]] = "tf.BroadcastTo"(%[[ZERO_F32]], %[[MAX_SIZE]]) + // CHECK: %[[BUFFER_SIZE:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} + // CHECK-NOT: tf.EmptyTensorList + %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor) -> tensor>> + %case_op = "tf.CaseRegion"(%arg0) ({ + %elem = "tf._SomeOp"() : () -> tensor + // CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice" + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[BUFFER_SIZE]], %[[ONE]]) + // CHECK-NOT: "tf.TensorListPushBack" + %push = "tf.TensorListPushBack"(%tl, %elem) : (tensor>>, tensor) -> tensor>> + "tf.Yield" (%push) : (tensor>>) -> () + }, { + // CHECK: %[[COPY:.*]] = "tf.Identity"(%[[BUFFER]]) + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} + // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BUFFER_SIZE]], %[[ONE]]) + // CHECK: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} + // CHECK: %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]]) + // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} + // CHECK: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) + // CHECK-NOT: "tf.TensorListPopBack" + %pop:2 = "tf.TensorListPopBack"(%tl, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) + // CHECK: "tf.Yield"(%[[COPY]], %[[SUB]]) + "tf.Yield" (%pop#0) : (tensor>>) -> () + }, { + // CHECK: %[[COPY:.*]] = "tf.Identity"(%[[BUFFER]]) + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} + // CHECK: 
%[[SUB:.*]] = "tf.Sub"(%[[BUFFER_SIZE]], %[[ONE]]) + // CHECK: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} + // CHECK: %[[SLICE:.*]] = "tf.Slice"(%[[COPY]], %[[SUB]], %[[SLICE_SIZE]]) + // CHECK: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} + // CHECK: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) + // CHECK-NOT: "tf.TensorListPopBack" + %pop:2 = "tf.TensorListPopBack"(%tl, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) + // CHECK: "tf.Yield"(%[[COPY]], %[[SUB]]) + "tf.Yield" (%pop#0) : (tensor>>) -> () + }) {is_stateless = false} + : (tensor) -> tensor>> + // CHECK: "tf.Slice" + // CHECK-NOT: tf.TensorListPopBack + %pop:2 = "tf.TensorListPopBack"(%case_op, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) + return +} + // ----- // Tests PartitionedCall/StatefulPartitionedCall. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 87c2d31c24c..0f8137af672 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -1454,85 +1454,96 @@ func @testSoftmaxCrossEntropyWithLogits(%arg0: tensor<3xf32>, %arg1: tensor<3xf3 // Test valid tf.SpaceToBatchND // CHECK-LABEL: func @testSpaceToBatchND -func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2xi64>, %paddings: tensor<2x2xi64>) -> tensor<*xf32> { - %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<*xf32> - return %0 : tensor<*xf32> +func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2xi64>, %paddings: tensor<2x2xi64>) -> tensor { + %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor + return %0 : tensor +} + +// ----- + +// Test valid tf.SpaceToBatchND +// CHECK-LABEL: func @testSpaceToBatchND +func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>) -> tensor<36x2x3x10xf32> { + %block_shape = "tf.Const"() {value = dense<[4, 3]> : tensor<2xi64>} : () -> tensor<2xi64> + %paddings = "tf.Const"() {value = dense<[[1, 2], [1, 1]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64> + %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<36x2x3x10xf32> + return %0 : tensor<36x2x3x10xf32> } // ----- // Test invalid tf.SpaceToBatchND -func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2x2xi64>, %paddings: tensor<2x2xi64>) -> tensor<*xf32> { +func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2x2xi64>, %paddings: tensor<2x2xi64>) -> tensor { // expected-error @+1 {{requires rank of block_shape = 1; got 2}} - %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<*xf32> - return %0 : tensor<*xf32> + %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2x2xi64>, tensor<2x2xi64>) -> tensor + return %0 : tensor } // ----- // Test invalid tf.SpaceToBatchND -func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2xi64>, %paddings: tensor<2xi64>) -> tensor<*xf32> { +func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2xi64>, %paddings: tensor<2xi64>) -> tensor { // expected-error @+1 {{requires rank of paddings = 2; got 1}} - %0 = "tf.SpaceToBatchND"(%input, %block_shape, 
%paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32> - return %0 : tensor<*xf32> + %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor + return %0 : tensor } // ----- // Test invalid tf.SpaceToBatchND -func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2xi64>, %paddings: tensor<2x10xi64>) -> tensor<*xf32> { +func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2xi64>, %paddings: tensor<2x10xi64>) -> tensor { // expected-error @+1 {{requires paddings.shape[1] to be 2; got 10}} - %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x10xi64>) -> tensor<*xf32> - return %0 : tensor<*xf32> + %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x10xi64>) -> tensor + return %0 : tensor } // ----- // Test invalid tf.SpaceToBatchND -func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<4xi64>, %paddings: tensor<2x2xi64>) -> tensor<*xf32> { +func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<4xi64>, %paddings: tensor<2x2xi64>) -> tensor { // expected-error @+1 {{requires block_shape.shape[0] must equal paddings.shape[0]}} - %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<4xi64>, tensor<2x2xi64>) -> tensor<*xf32> - return %0 : tensor<*xf32> + %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<4xi64>, tensor<2x2xi64>) -> tensor + return %0 : tensor } // ----- // Test invalid tf.SpaceToBatchND -func @testSpaceToBatchND(%input: tensor<3x5xf32>, %block_shape: tensor<2xi64>, %paddings: tensor<2x2xi64>) -> tensor<*xf32> { +func @testSpaceToBatchND(%input: tensor<3x5xf32>, %block_shape: tensor<2xi64>, %paddings: tensor<2x2xi64>) -> tensor { // expected-error @+1 {{requires rank of input >= 1 + rank of block}} - %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<*xf32> - return %0 : tensor<*xf32> + %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor + return %0 : tensor } // ----- // Test invalid tf.SpaceToBatchND -func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %paddings: tensor<2x2xi64>) -> tensor<*xf32> { +func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %paddings: tensor<2x2xi64>) -> tensor { %block_shape = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> tensor<2xi64> // expected-error @+1 {{requires all values of block_shape to be >= 1; failed for dimension 1}} - %1 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<*xf32> - return %1 : tensor<*xf32> + %1 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor + return %1 : tensor } // ----- // Test invalid tf.SpaceToBatchND -func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2xi64>) -> tensor<*xf32> { +func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>, %block_shape: tensor<2xi64>) -> tensor { %paddings = "tf.Const"() {value = dense<[[1, 0], [-1, 0]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64> // expected-error @+1 {{requires all values of paddings to be >= 0; failed for dimension 1}} - %1 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : 
(tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<*xf32> - return %1 : tensor<*xf32> + %1 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor + return %1 : tensor } // ----- // Test invalid tf.SpaceToBatchND -func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>) -> tensor<*xf32> { +func @testSpaceToBatchND(%input: tensor<3x5x7x10xf32>) -> tensor<36x2x3x10xf32> { %block_shape = "tf.Const"() {value = dense<[4, 3]> : tensor<2xi64>} : () -> tensor<2xi64> %paddings = "tf.Const"() {value = dense<[[1, 2], [1, 2]]> : tensor<2x2xi64>} : () -> tensor<2x2xi64> // expected-error @+1 {{requires block_shape[i] divides input_shape[i + 1] + paddings[i, 0] + paddings[i, 1]; failed for i=1}} - %1 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<*xf32> - return %1 : tensor<*xf32> + %1 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings) : (tensor<3x5x7x10xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<36x2x3x10xf32> + return %1 : tensor<36x2x3x10xf32> } // ----- @@ -2841,6 +2852,13 @@ func @testSplitV2(%input: tensor<4x4xf32>) { // ----- +func @testSplitVDynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor) { + %0:2 = "tf.SplitV"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> (tensor, tensor) + return %0#0, %0#1 : tensor, tensor +} + +// ----- + //===--------------------------------------------------------------------===// // tf.All //===--------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/structured_output.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/structured_output.py new file mode 100644 index 00000000000..a6d78d4693b --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/structured_output.py @@ -0,0 +1,125 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# RUN: %p/structured_output | FileCheck %s + +# pylint: disable=missing-docstring,line-too-long +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v2 as tf +from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common + + +class TestModule(tf.Module): + # The fNNNN name prefixes in this file are such that the sorted order of the + # functions in the resulting MLIR output match the order in the source file, + # allowing us to conveniently co-locate the CHECK's with the code they are + # checking. + # + # Note: CHECK-DAG doesn't work with CHECK-SAME/CHECK-NEXT. + + # Check index paths for results. 
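
The SpaceToBatchND verifier tests above encode the static output-shape rule: the output batch is the input batch times the product of block_shape, and each blocked spatial dimension becomes (size + pad_start + pad_end) / block, which must divide evenly. A quick arithmetic check of the shapes used in the tests (3x5x7x10 input, block_shape [4, 3], paddings [[1, 2], [1, 1]]):

```python
# Expected output shape for the valid test case: 36x2x3x10.
batch, d1, d2, depth = 3, 5, 7, 10
block = [4, 3]
paddings = [(1, 2), (1, 1)]

out_batch = batch * block[0] * block[1]           # 3 * 4 * 3 = 36
out_d1 = (d1 + sum(paddings[0])) // block[0]      # (5 + 3) / 4 = 2
out_d2 = (d2 + sum(paddings[1])) // block[1]      # (7 + 2) / 3 = 3
assert (out_batch, out_d1, out_d2, depth) == (36, 2, 3, 10)

# The failing test pads dimension 1 to 7 + 1 + 2 = 10, which 3 does not divide,
# so the verifier reports "failed for i=1".
assert (7 + 1 + 2) % 3 != 0
```
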
+ # + # CHECK: func {{@[a-zA-Z_0-9]+}}() -> ( + # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = []}) + # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0000_single_return"] + @tf.function(input_signature=[]) + def f0000_single_return(self): + return tf.constant(1.0, shape=[1]) + + # Check index paths for results with multiple return values. + # Note that semantically in Python, multiple return values are equivalent + # to returning a tuple/list. + # + # CHECK: func {{@[a-zA-Z_0-9]+}}() -> ( + # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0]}, + # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [1]}) + # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0001_multiple_results_no_punctuation"] + @tf.function(input_signature=[]) + def f0001_multiple_results_no_punctuation(self): + return tf.constant(1.0, shape=[1]), tf.constant(1.0, shape=[2]) + + # Check index paths for results written explicitly with parentheses. + # This is semantically equivalent to the earlier test without parentheses, + # but this test serves as documentation of this behavior for the purposes + # of tf_saved_model users. + # + # CHECK: func {{@[a-zA-Z_0-9]+}}() -> ( + # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0]}, + # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [1]}) + # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0002_multiple_results_parentheses"] + @tf.function(input_signature=[]) + def f0002_multiple_results_parentheses(self): + return (tf.constant(1.0, shape=[1]), tf.constant(1.0, shape=[2])) + + # Check index paths for results written explicitly with brackets. + # This is semantically equivalent to the earlier test without parentheses, + # but this test serves as documentation of this behavior for the purposes + # of tf_saved_model users. + # + # CHECK: func {{@[a-zA-Z_0-9]+}}() -> ( + # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0]}, + # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [1]}) + # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0003_multiple_results_brackets"] + @tf.function(input_signature=[]) + def f0003_multiple_results_brackets(self): + return [tf.constant(1.0, shape=[1]), tf.constant(1.0, shape=[2])] + + # Check index paths for lists. + # + # CHECK: func {{@[a-zA-Z_0-9]+}}() -> ( + # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = [0, 0]}, + # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = [0, 1]}) + # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0004_list_2_elements"] + @tf.function(input_signature=[]) + def f0004_list_2_elements(self): + return [[tf.constant(1.0, shape=[1]), tf.constant(1.0, shape=[2])]] + + # Check index paths for dicts. + # Keys are linearized in sorted order, matching `tf.nest.flatten`. + # More thorough testing of this is in structured_input.py. The underlying code + # path for linearization is shared, so no need to replicate that testing here. + # + # CHECK: func {{@[a-zA-Z_0-9]+}}() -> ( + # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = ["x"]}, + # CHECK-SAME: tensor<2xf32> {tf_saved_model.index_path = ["y"]}) + # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0005_dict_2_keys"] + @tf.function(input_signature=[]) + def f0005_dict_2_keys(self): + return { + 'x': tf.constant(1.0, shape=[1]), + 'y': tf.constant(1.0, shape=[2]), + } + + # Check index paths for outputs are correctly handled in the presence of + # multiple return statements. 
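
The structured_output.py checks above rely on dict results being linearized in sorted key order, the same convention tf.nest.flatten uses, which is why the index path ["x"] precedes ["y"]. A small illustration, assuming a TF 2.x runtime is available:

```python
import tensorflow as tf

# Dict values are flattened in sorted key order regardless of insertion order.
assert tf.nest.flatten({"y": 2, "x": 1}) == [1, 2]
```
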
+ # + # CHECK: func {{@[a-zA-Z_0-9]+}}( + # CHECK-SAME: %arg0: tensor {tf._user_specified_name = "x", tf_saved_model.index_path = [0]} + # CHECK-SAME: ) -> ( + # CHECK-SAME: tensor<1xf32> {tf_saved_model.index_path = ["x"]}) + # CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["f0006_multiple_return_statements"] + @tf.function(input_signature=[tf.TensorSpec([], tf.float32)]) + def f0006_multiple_return_statements(self, x): + if x > 3.: + return {'x': tf.constant(1.0, shape=[1])} + else: + return {'x': tf.constant(1.0, shape=[1])} + + +if __name__ == '__main__': + common.do_test(TestModule) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir index 84b4f97d4eb..ea2ebc64a29 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_lift_variables.mlir @@ -59,3 +59,25 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} // CHECK: %arg0: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}, // CHECK: %arg1: tensor>> {tf_saved_model.bound_input = @"dense/bias"}) } + +// ----- + +module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} { + + // Test case: Fix bound_inputs' types. + + func @serving_default(%arg0: tensor>> {tf.resource_name = "dense/kernel"}, %arg1: tensor>> {tf.resource_name = "dense/bias"}) -> (tensor<*xf32> {tf_saved_model.index_path = ["dense_2"]}) + attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor>>) -> tensor<*xf32> + %1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor>>) -> tensor<*xf32> + %2 = "tf.Add"(%0, %1) {device = ""} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %2 : tensor<*xf32> + } + // CHECK: "tf_saved_model.global_tensor"() + // CHECK: sym_name = "dense/kernel" + // CHECK: "tf_saved_model.global_tensor"() + // CHECK: sym_name = "dense/bias" + // CHECK: func @serving_default( + // CHECK: %arg0: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}, + // CHECK: %arg1: tensor>> {tf_saved_model.bound_input = @"dense/bias"}) +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir index 2be2e2aab50..e2cfd6c82b2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir @@ -145,12 +145,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf_device.launch" // CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK: "tf.B"(%[[RECV_OUTPUT]]) // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]]) - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { %3 = "tf.A"() : () 
-> (tensor) @@ -200,15 +200,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf_device.launch" // CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"() // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"() - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" // CHECK: "tf.C"(%[[HOST_OUTPUT]]) %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { @@ -235,11 +235,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]]) // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" // CHECK: tf_device.return %[[HOST_OUTPUT]] %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { @@ -266,11 +266,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]]) // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf.C"(%[[HOST_OUTPUT]]) %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { @@ -333,12 +333,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) // CHECK: %[[B_OUTPUT:[0-9]*]]:2 = "tf.C"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1) // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]]#0, %[[B_OUTPUT]]#1, %[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: %[[HOST_OUTPUT:[0-9]*]]:2 = 
"tf._XlaHostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf.D"(%[[HOST_OUTPUT]]#0) // CHECK: "tf.E"(%[[HOST_OUTPUT]]#1) %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { @@ -368,20 +368,20 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT2]]) // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT2]]) // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster2_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster2_0_retvals" // CHECK: "tf_device.launch" // CHECK: %[[PROGRAM_OUTPUT1:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT1]]) // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT1]]) // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[HOST_OUTPUT1:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[HOST_OUTPUT1]]) // CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[C_OUTPUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster2_retvals" + // CHECK-SAME: recv_key = "host_compute_channel_cluster2_0_retvals" // CHECK: "tf.E"(%[[HOST_OUTPUT2]]) %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { @@ -408,12 +408,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf_device.launch" // CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK: "tf.B"(%arg0, %[[RECV_OUTPUT]]) // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]]) - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { %3 = "tf.A"() : () -> (tensor) @@ -437,20 +437,20 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf_device.launch" // CHECK: %[[PROGRAM_OUTPUT_2:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_2]]) - // CHECK-SAME: key = "host_compute_channel_cluster2_args" + // CHECK-SAME: key = "host_compute_channel_cluster2_0_args" // CHECK: "tf.D"(%[[RECV_OUTPUT_2]]) // CHECK: "tf_device.launch" // CHECK: %[[PROGRAM_OUTPUT_1:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_1]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = 
"host_compute_channel_cluster1_0_args" // CHECK: "tf.B"(%[[RECV_OUTPUT_1]]) // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]]) - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" // CHECK: "tf._XlaHostComputeMlir"(%[[C_OUTPUT]]) - // CHECK-SAME: send_key = "host_compute_channel_cluster2_args" + // CHECK-SAME: send_key = "host_compute_channel_cluster2_0_args" %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { %3 = "tf.A"() : () -> (tensor) @@ -475,14 +475,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf_device.launch" // CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK: "tf.C"(%[[RECV_OUTPUT]]#0) // CHECK: "tf.D"(%[[RECV_OUTPUT]]#1, %[[RECV_OUTPUT]]#0) // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]]) - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { %3 = "tf.A"() : () -> (tensor) @@ -531,25 +531,25 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() // CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-SAME: key = "if_predicate_channel_cluster1_0_0" // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) // CHECK-NEXT: %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1) // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK-NEXT: "tf.Yield"() : () -> () // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" - // CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0"} + // CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0_0"} // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) // CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" // CHECK-SAME: tpu_core = 0 // CHECK-NEXT: "tf.Yield"() : () -> () %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { @@ -587,20 +587,20 @@ module 
attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() // CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK-SAME: (tensor<2x!tf.string>) -> (tensor, tensor, tensor) // CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#2) // CHECK: "tf.D"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1, %[[F_OUT]]) // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" // CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]], %[[G_OUTPUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" // CHECK-SAME: tpu_core = 0 %0 = "tf.A"(%arg0) : (tensor) -> tensor %7 = "tf.F"() : () -> tensor @@ -641,12 +641,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() // CHECK-NEXT: %[[RECV_OUTPUT_PREDICATE:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-SAME: key = "if_predicate_channel_cluster1_0_0" // CHECK-SAME: (tensor<2x!tf.string>) -> tensor // CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT_PREDICATE]]) // CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK-SAME: (tensor<2x!tf.string>) -> (tensor, tensor) // CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#1) // CHECK-NEXT: "tf.H"(%[[RECV_OUTPUT]]#0, %[[F_OUT]]) @@ -654,20 +654,20 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: "tf.Yield"() : () -> () // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" // CHECK: "tf.XlaSendToHost"(%[[G_OUTPUT]]) - // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-SAME: key = "if_predicate_channel_cluster1_0_0" // CHECK-SAME: (tensor) -> () // CHECK-NEXT: "tf.IfRegion"(%[[G_OUTPUT]]) // CHECK: %[[D_OUT:[0-9]*]] = "tf.D" // CHECK-NEXT: %[[F_OUT:[0-9]*]] = "tf.F" // CHECK: "tf._XlaHostComputeMlir"(%[[D_OUT]], %[[F_OUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" // CHECK-SAME: tpu_core = 0 // CHECK: "tf.Yield"() : () -> () // 
CHECK: "tf.Yield"() : () -> () @@ -719,25 +719,25 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() // CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-SAME: key = "if_predicate_channel_cluster1_0_0" // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) // CHECK-NEXT: %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1) // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK-NEXT: "tf.Yield"() : () -> () // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" - // CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0"} + // CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0_0"} // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) // CHECK: %[[HOST_COMPUTE_OUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" // CHECK-SAME: tpu_core = 0 // CHECK-NEXT: "tf.Yield"(%[[HOST_COMPUTE_OUT]]) %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { @@ -776,7 +776,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() // CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-SAME: key = "if_predicate_channel_cluster1_0_0" // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) // CHECK: "tf.D" // CHECK-NEXT: "tf.Yield"() : () -> () @@ -784,7 +784,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" - // CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0"} + // CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0_0"} // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) // CHECK-NEXT: "tf.Yield"() : () -> () %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { @@ -821,30 +821,30 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() // CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-SAME: key = "if_predicate_channel_cluster1_0_0" // CHECK-NEXT: 
tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) // CHECK-NEXT: %[[PREDICATE2_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "if_predicate_channel_cluster1_1" + // CHECK-SAME: key = "if_predicate_channel_cluster1_0_1" // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE2_RECV_OUTPUT]]) // CHECK-NEXT: "tf.Yield"() : () -> () // CHECK: %[[ARG_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]]) // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK-NEXT: "tf.Yield"() : () -> () // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" - // CHECK: "tf.XlaSendToHost"(%[[G_OUTPUT]]) {key = "if_predicate_channel_cluster1_0"} + // CHECK: "tf.XlaSendToHost"(%[[G_OUTPUT]]) {key = "if_predicate_channel_cluster1_0_0"} // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"(%[[B_OUTPUT]]) - // CHECK: "tf.XlaSendToHost"(%[[H_OUTPUT]]) {key = "if_predicate_channel_cluster1_1"} + // CHECK: "tf.XlaSendToHost"(%[[H_OUTPUT]]) {key = "if_predicate_channel_cluster1_0_1"} // CHECK-NEXT: tf.IfRegion"(%[[H_OUTPUT]]) // CHECK-NEXT: "tf.Yield"() : () -> () // CHECK: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[H_OUTPUT]]) @@ -895,7 +895,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: tf.WhileRegion" // CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "while_condition_channel_cluster1_0" + // CHECK-SAME: key = "while_condition_channel_cluster1_0_0" // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]]) // CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D" @@ -954,7 +954,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf._XlaSendFromHost"(%[[I_OUTPUT]], %[[PLACEHOLDER_KEY]]) // CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "while_condition_channel_cluster1_0" + // CHECK-SAME: key = "while_condition_channel_cluster1_0_0" // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]]) // CHECK_NEXT: "tf.Yield" // CHECK: "tf_device.cluster" @@ -1009,7 +1009,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: tf.WhileRegion" // CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "while_condition_channel_cluster2_0" + // CHECK-SAME: key = "while_condition_channel_cluster2_0_0" // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]]) // CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D" @@ -1023,7 +1023,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf._XlaSendFromHost"(%[[I_OUTPUT]], %[[PLACEHOLDER_KEY]]) // CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: 
device_ordinal = 0 - // CHECK-SAME: key = "while_condition_channel_cluster1_0" + // CHECK-SAME: key = "while_condition_channel_cluster1_0_0" // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]]) // CHECK_NEXT: "tf.Yield" // CHECK: "tf_device.cluster" @@ -1079,7 +1079,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: tf.WhileRegion" // CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "while_condition_channel_cluster1_0" + // CHECK-SAME: key = "while_condition_channel_cluster1_0_0" // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]]) // CHECK: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) @@ -1144,7 +1144,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() // CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_args" // CHECK-SAME: (tensor<2x!tf.string>) -> (tensor, tensor, tensor) // CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#2) // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1, %[[F_OUT]]) @@ -1154,14 +1154,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"(%[[K_OUTPUT]]) // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) // CHECK-SAME: device_ordinal = 0 - // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: key = "host_compute_channel_cluster1_0_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" // CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]], %[[G_OUTPUT]]) - // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" - // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_0_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_0_args" // CHECK-SAME: tpu_core = 0 %0 = "tf.A"(%arg0) : (tensor) -> tensor %7 = "tf.F"() : () -> tensor @@ -1202,7 +1202,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor } // Tests extraction of an outside compiled tf.WhileRegion where the entire - // tf.WhileREgion op is outside compiled with a nested tf.IfRegion. + // tf.WhileRegion op is outside compiled with a nested tf.IfRegion. // CHECK-LABEL: func @outside_compiled_ops_tf_while_nested_if func @outside_compiled_ops_tf_while_nested_if(%arg0: tensor) -> tensor { @@ -1253,4 +1253,70 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor return %1 : tensor } + + // Tests extraction of an outside compiled cluster that contains ops wrapped + // inside multiple regions of nested tf.IfRegion and tf.WhileRegion. 
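+  // Roughly, the CHECK lines below verify that each value crossing the
+  // device/host boundary shows up as a matching pair of communication ops:
+  // "tf._XlaHostComputeMlir" or "tf.XlaSendToHost" on the device cluster and
+  // "tf._XlaRecvAtHost"/"tf._XlaSendFromHost" in the host launch, even though
+  // the outside compiled ops sit in different nested regions of one cluster.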
+ + // CHECK-LABEL: func @outside_compiled_ops_with_multiple_region_single_cluster + func @outside_compiled_ops_with_multiple_region_single_cluster(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B" + // CHECK-NEXT: "tf._XlaSendFromHost"(%[[B_OUT]], %[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: "tf.WhileRegion"() + // CHECK-NEXT: %[[WHILE_COND:.*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: "tf.Yield"(%[[WHILE_COND]]) + // CHECK: "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]]) + // CHECK-NEXT: "tf._XlaSendFromHost"(%[[C_OUT]], %[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: %[[IF_COND:.*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: "tf.IfRegion"(%[[IF_COND]]) + // CHECK-NEXT: "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-NEXT: %[[D_OUT:.*]] = "tf.D"(%[[C_OUT]]) + // CHECK: "tf_device.cluster" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: %[[B_OUT_DEVICE:.*]] = "tf._XlaHostComputeMlir"() + // CHECK-NEXT: %[[G_OUT:.*]] = "tf.G" + // CHECK-NEXT: "tf.WhileRegion"(%[[B_OUT_DEVICE]], %[[A_OUT]]) + // CHECK: %[[H_OUT:.*]] = "tf.H" + // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUT]]) + // CHECK-NEXT: "tf.Yield"(%[[H_OUT]]) + // CHECK: %[[C_OUT_DEVICE:.*]] = "tf._XlaHostComputeMlir"() + // CHECK-NEXT: "tf.XlaSendToHost"(%[[G_OUT]]) + // CHECK-NEXT: "tf.IfRegion"(%[[G_OUT]]) + // CHECK-NEXT: "tf._XlaHostComputeMlir"() + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() {_xla_outside_compilation="cluster0"} : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + "tf.WhileRegion"(%4, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = "tf.H"(%arg1) : (tensor) -> tensor + "tf.Yield"(%7) : (tensor) -> () + }, { + ^bb0(%arg1: tensor, %arg2: tensor): + %8 = "tf.C"(%4) {_xla_outside_compilation="cluster0"} : (tensor) -> tensor + %10 = "tf.IfRegion"(%6) ({ + %9 = "tf.D"(%8) {_xla_outside_compilation="cluster0"} : (tensor) -> tensor + "tf.Yield"(%9) : (tensor) -> () + }, { + "tf.Yield"(%arg2) : (tensor) -> () + }) { is_stateless = false} : (tensor) -> tensor + "tf.Yield"(%8, %10) : (tensor, tensor) -> () + }) {is_stateless = false} : (tensor, tensor) -> (tensor, tensor) + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_outside_compilation_cluster.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_outside_compilation_cluster.mlir index 269af51504f..183c7c34d41 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_outside_compilation_cluster.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_outside_compilation_cluster.mlir @@ -95,7 +95,6 @@ func @two_clusters_with_one_op_each() { // CHECK-NEXT: "tf.opC" // CHECK-NEXT: "tf.opD" // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER6]]" - // CHECK-SAME: _xla_outside_compilation = "{{[a-zA-Z_0-9]+}}" // CHECK-NEXT: "tf.opE" "tf_device.cluster"() ( { %a = "tf.opA"() : () -> tensor 
@@ -118,9 +117,8 @@ func @two_clusters_with_two_ops_each() { // CHECK-NEXT: "tf.opD" // CHECK-NEXT: "tf.opE" // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER8]]" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER9:[a-zA-Z_0-9]+]]" // CHECK-NEXT: "tf.opF" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER9]]" + // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER8]]" // CHECK-NEXT: "tf.opG" "tf_device.cluster"() ( { %a = "tf.opA"() : () -> tensor @@ -163,12 +161,11 @@ func @two_clusters_with_same_parent() { // CHECK-NEXT: "tf.opB" // CHECK-NEXT: "tf.opC" // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER10]]" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER11:[a-zA-Z_0-9]+]]" // CHECK-NEXT: "tf.opD" // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER12:[a-zA-Z_0-9]+]]" // CHECK-NEXT: "tf.opE" // CHECK-NEXT: "tf.opF" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER13:[a-zA-Z_0-9]+]]" + // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER12]]" // CHECK-NEXT: "tf.opG" "tf_device.cluster"() ( { %a = "tf.opA"() {_xla_outside_compilation = "0"} : () -> tensor @@ -189,10 +186,10 @@ func @two_clusters_with_same_outside_compiled_parent() { // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER12:[a-zA-Z_0-9]+]]" // CHECK-NEXT: "tf.opB" // CHECK-NEXT: "tf.opC" - // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER12]]" // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER13:[a-zA-Z_0-9]+]]" // CHECK-NEXT: "tf.opD" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER14:[a-zA-Z_0-9]+]]" + // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER12]]" + // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER13]]" // CHECK-NEXT: "tf.Identity" // CHECK-NEXT: "tf.opF" // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER13]]" @@ -234,8 +231,7 @@ func @outside_compile_with_block() { // CHECK-NEXT: "tf.opB" // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER15]]" // CHECK: "tf.opC" - // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER14]]" - // CHECK-SAME: _xla_outside_compilation = "{{[a-zA-Z_0-9]+}}" + // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER15]]" "tf_device.cluster"() ( { %a = "tf.opA"() {_xla_outside_compilation = "0"} : () -> tensor %b = "tf.opB"(%a) {_xla_outside_compilation = "0"} : (tensor) -> tensor @@ -257,7 +253,6 @@ func @two_clusters_with_one_op_each_with_indirect_dependency() { // CHECK-NEXT: "tf.opD" // CHECK-NEXT: "tf.opE" // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER16]]" - // CHECK-SAME: _xla_outside_compilation = "{{[a-zA-Z_0-9]+}}" // CHECK-NEXT: "tf.opF" "tf_device.cluster"() ( { %a = "tf.opA"() : () -> tensor @@ -326,27 +321,27 @@ func @check_op_inside_nested_region_clustered(%arg0 : tensor<*x!tf.resource>) { return } -// CHECK-LABEL: func @check_ops_inside_different_block_in_different_cluster -func @check_ops_inside_different_block_in_different_cluster(%arg0 : tensor<*x!tf.resource>) { +// CHECK-LABEL: func @check_ops_inside_different_block_clustered +func @check_ops_inside_different_block_clustered(%arg0 : tensor<*x!tf.resource>) { // CHECK: tf_device.cluster // CHECK-NEXT: "tf.Const" // CHECK-NEXT: "tf.B" // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17:[a-zA-Z_0-9]+]]" // CHECK-NEXT: "tf.C" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17:[a-zA-Z_0-9]+]]" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER18:[a-zA-Z_0-9]+]]" // CHECK: "tf.IfRegion" // CHECK-NEXT: "tf.Const" // CHECK-NEXT: "tf.Const" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER18:[a-zA-Z_0-9]+]]" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17]]" // 
CHECK-NEXT: "tf.Const" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER18]]" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17]]" // CHECK-NEXT: "tf.WriteSummary" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER18]]" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17]]" // CHECK: "tf.Const" // CHECK-NEXT: "tf.Const" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER19:[a-zA-Z_0-9]+]]" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER18]]" // CHECK-NEXT: "tf.D" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER19]]" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER18]]" "tf_device.cluster"() ( { %0 = "tf.Const"() {value = dense : tensor} : () -> tensor %2 = "tf.B"() {_xla_outside_compilation = "auto1"} : () -> (tensor) @@ -376,16 +371,16 @@ func @check_clustering_ops_inside_nested_control_flow(%arg0 : tensor<*x!tf.resou // CHECK-NEXT: "tf.B" // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17:[a-zA-Z_0-9]+]]" // CHECK-NEXT: "tf.C" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17:[a-zA-Z_0-9]+]]" + // CHECK: _xla_outside_compilation = "[[CLUSTER17]]" // CHECK: "tf.IfRegion" // CHECK: "tf.IfRegion" // CHECK-NEXT: "tf.Const" // CHECK-NEXT: "tf.Const" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER18:[a-zA-Z_0-9]+]]" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17]]" // CHECK-NEXT: "tf.Const" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER18]]" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17]]" // CHECK-NEXT: "tf.WriteSummary" - // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER18]]" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER17]]" "tf_device.cluster"() ( { %0 = "tf.Const"() {value = dense : tensor} : () -> tensor %2 = "tf.B"() {_xla_outside_compilation = "auto1"} : () -> (tensor) @@ -411,3 +406,136 @@ func @check_clustering_ops_inside_nested_control_flow(%arg0 : tensor<*x!tf.resou }) {cluster_attr = "cluster_attr"} : () -> () return } + +// CHECK-LABEL: func @single_variant_input +func @single_variant_input() { + // CHECK: "tf.opA" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1:[a-zA-Z_0-9]+]]" + // CHECK: "tf.opB" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1]]" + // CHECK: "tf.opC" + "tf_device.cluster"() ( { + %1= "tf.opA"() : () -> tensor>> + "tf.opB"(%1) {_xla_outside_compilation = "0"} : (tensor>>) -> () + "tf.opC"() : () -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @chained_variant_input +func @chained_variant_input() { + // CHECK: "tf.opA" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1:[a-zA-Z_0-9]+]]" + // CHECK: "tf.opB" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1]]" + // CHECK: "tf.opC" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1]]" + "tf_device.cluster"() ( { + %1 = "tf.opA"() : () -> tensor>> + %2 = "tf.opB"(%1) : (tensor>>) -> (tensor>>) + "tf.opC"(%2) {_xla_outside_compilation = "0"} : (tensor>>) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @single_variant_output +func @single_variant_output() { + // CHECK: "tf.opA" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1:[a-zA-Z_0-9]+]]" + // CHECK: "tf.opB" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1]]" + // CHECK: "tf.opC" + "tf_device.cluster"() ( { + %1= "tf.opA"() {_xla_outside_compilation = "0"} : () -> tensor>> + "tf.opB"(%1) : (tensor>>) -> () + "tf.opC"() : () -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + 
return +} + +// CHECK-LABEL: func @chained_variant_output +func @chained_variant_output() { + // CHECK: "tf.opA" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1:[a-zA-Z_0-9]+]]" + // CHECK: "tf.opB" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1]]" + // CHECK: "tf.opC" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1]]" + "tf_device.cluster"() ( { + %1 = "tf.opA"() {_xla_outside_compilation = "0"} : () -> tensor>> + %2 = "tf.opB"(%1) : (tensor>>) -> (tensor>>) + "tf.opC"(%2) : (tensor>>) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @variant_input_output +func @variant_input_output() { + // CHECK: "tf.opA" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1:[a-zA-Z_0-9]+]]" + // CHECK: "tf.opB" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1]]" + // CHECK: "tf.opC" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1]]" + "tf_device.cluster"() ( { + %1 = "tf.opA"() : () -> tensor>> + %2 = "tf.opB"(%1) {_xla_outside_compilation = "0"} : (tensor>>) -> (tensor>>) + "tf.opC"(%2) : (tensor>>) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @variant_input_nested +func @variant_input_nested(%arg0 : tensor<*x!tf.resource>) { + // CHECK: tf_device.cluster + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: "tf.C" + // CHECK-SAME: _xla_outside_compilation = "[[CLUSTER1:[a-zA-Z_0-9]+]]" + // CHECK: "tf.IfRegion" + // CHECK: "tf.opD" + // CHECK-NOT: _xla_outside_compilation = "[[CLUSTER1]]" + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + %2 = "tf.C"() {_xla_outside_compilation = "auto0"} : () -> (tensor>>) + "tf.IfRegion"(%0) ( { + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.opD"(%2) : (tensor>>) -> () + "tf.Yield"(%1) : (tensor) -> () + }, { + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.Yield"(%1) : (tensor) -> () + }) { is_stateless = true, _xla_outside_compilation = "auto1" } : (tensor) -> tensor + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @variant_output_nested +func @variant_output_nested(%arg0 : tensor<*x!tf.resource>) { + // CHECK: tf_device.cluster + // CHECK: "tf.IfRegion" + // CHECK: "tf.C" + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.D" + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.Yield" + // CHECK: _xla_outside_compilation + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + %1 = "tf.IfRegion"(%0) ( { + %2 = "tf.C"() : () -> (tensor>>) + "tf.Yield"(%2) : (tensor>>) -> () + }, { + %2 = "tf.D"() : () -> (tensor>>) + "tf.Yield"(%2) : (tensor>>) -> () + }) { is_stateless = true, _xla_outside_compilation = "auto1" } : (tensor) -> tensor>> + "tf.E"(%1) {_xla_outside_compilation = "auto0"} : (tensor>>) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index bb159f68eae..eccbe5feaec 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -58,7 +58,7 @@ tensorflow::Status RunTPUBridge( ModuleOp module, bool enable_logging, llvm::function_ref pipeline_builder) { PassManager bridge(module.getContext()); - ::tensorflow::SetCrashReproducer(bridge); + ::tensorflow::applyTensorflowAndCLOptions(bridge); if 
(enable_logging) EnableLogging(&bridge); // Populate a passmanager with the list of passes that implement the bridge. @@ -106,6 +106,7 @@ void CreateTPUBridgePipeline(OpPassManager &pm) { pm.addPass(TFDevice::CreateResourceOpLiftingPass()); pm.addPass(TFDevice::CreateMarkOpsForOutsideCompilationPass()); pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass()); + pm.addPass(CreateTPUOutsideCompilationClusterPass()); pm.addPass(CreateTPUExtractOutsideCompilationPass()); pm.addNestedPass(tf_executor::CreateTFExecutorConstantSinkingPass()); @@ -119,7 +120,6 @@ void CreateTPUBridgePipeline(OpPassManager &pm) { pm.addPass(createSymbolDCEPass()); pm.addNestedPass(TFDevice::CreateReplicateInvariantOpHoistingPass()); pm.addNestedPass(CreateTPUDynamicLayoutPass()); - pm.addNestedPass(CreateTPUParallelExecuteSinkResourceWritePass()); pm.addNestedPass(CreateTPUMergeVariablesWithExecutePass()); pm.addNestedPass(CreateTPUColocateCompositeResourceOps()); pm.addPass(CreateTPUVariableReformattingPass()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index 3005c78c54f..31cfc5ebf9c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/eval_util.h" #include "tensorflow/core/platform/mutex.h" @@ -72,7 +73,8 @@ LogicalResult ConstantFoldFallbackHook( SmallVectorImpl& results) { // NOLINT // Instructions with side effects should not be constant folded to preserve // the original semantics. - if (inst->getNumRegions() != 0 || !MemoryEffectOpInterface::hasNoEffect(inst)) + if (inst->hasTrait() || + inst->getNumRegions() != 0 || !MemoryEffectOpInterface::hasNoEffect(inst)) return failure(); // If any of the result types are variants, don't try to constant fold them. @@ -87,7 +89,7 @@ LogicalResult ConstantFoldFallbackHook( } // Do not execute function calls. - if (llvm::isa(inst)) { + if (llvm::isa(inst)) { return failure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc index 4737f44ae1e..28a5c583919 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc @@ -73,7 +73,7 @@ static Type GetResourceSubtype(Value resource) { void PopulateDecomposeResourceOpsPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { - populateWithGenerated(context, patterns); + populateWithGenerated(context, *patterns); } } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc index 69dab58c3f5..c3d43c27ac5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -16,12 +16,16 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" #include +#include #include #include #include "absl/memory/memory.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" @@ -43,130 +47,6 @@ namespace TF { namespace { -// All supported Einsum equations. -enum EinsumEquation { - BatchMatMul, - FourDMatrixDotProd, - ThreeDReshapeTail, - FourDBatchMatMul, - BroadcastMatMul, - ReduceSum, - TransposeMatMul, - BatchMatMulReducedDim, - TransposeReducedDim, - FourDReduceLast, - FourDTransposeAll, - UnsupportedEquation -}; - -// Tokens for parsing the given equation string. -enum EquationToken { - A, - B, - C, - D, - E, - COMMA, - ARROW, -}; -constexpr int kNumSupportedEquationVariables = 5; // A - E for now. - -bool tokenizeEquation(const llvm::StringRef& equation, - std::vector* tokens) { - std::map label_axis_mapping; - size_t index = 0; - int variable_count = 0; - llvm::Regex r("[[:alpha:]]"); - while (index < equation.size()) { - if (r.match(equation.substr(index, 1))) { - const char ltr = equation[index]; - auto itr = label_axis_mapping.find(ltr); - if (itr == label_axis_mapping.end() && - variable_count < kNumSupportedEquationVariables) { - label_axis_mapping[ltr] = EquationToken(variable_count); - tokens->push_back(EquationToken(variable_count)); - variable_count++; - } else if (itr != label_axis_mapping.end()) { - tokens->push_back(itr->second); - } else { - // Ran out of equation variables. - return false; - } - } else if (equation.substr(index, 1).contains(",")) { - tokens->push_back(COMMA); - } else if ((index < (equation.size() - 1)) && - (equation.substr(index, 2).contains("->"))) { - tokens->push_back(ARROW); - index++; - } else { - // Unallowed character encountered. 
- return false; - } - index++; - } - return true; -} - -EinsumEquation parseEquation(const std::vector& eqn) { - auto is_equal = [](const std::vector& eqn1, - const std::initializer_list& eqn2) { - return std::equal(eqn1.begin(), eqn1.end(), eqn2.begin(), eqn2.end()); - }; - // IJK,IKM->IJM - if (is_equal(eqn, {A, B, C, COMMA, A, C, D, ARROW, A, B, D})) { - return EinsumEquation::BatchMatMul; - } - // BFND,NDH->BFH - if (is_equal(eqn, {A, B, C, D, COMMA, C, D, E, ARROW, A, B, E})) { - return EinsumEquation::FourDMatrixDotProd; - } - // BFNH,BTNH->BNFT - if (is_equal(eqn, {A, B, C, D, COMMA, A, E, C, D, ARROW, A, C, B, E})) { - return EinsumEquation::FourDBatchMatMul; - } - // BFD,DNH->BFNH - if (is_equal(eqn, {A, B, C, COMMA, C, D, E, ARROW, A, B, D, E})) { - return EinsumEquation::ThreeDReshapeTail; - } - // BFH,HO->BFO - if (is_equal(eqn, {A, B, C, COMMA, C, D, ARROW, A, B, D})) { - return EinsumEquation::BroadcastMatMul; - } - // LBH,BL->BH - if (is_equal(eqn, {A, B, C, COMMA, B, A, ARROW, B, C})) { - return EinsumEquation::ReduceSum; - } - // LBH,BKL->BKH - if (is_equal(eqn, {A, B, C, COMMA, B, D, A, ARROW, B, D, C})) { - return EinsumEquation::TransposeMatMul; - } - // BIN,BINJ->BIJ - if (is_equal(eqn, {A, B, C, COMMA, A, B, C, D, ARROW, A, B, D})) { - return EinsumEquation::BatchMatMulReducedDim; - } - // BIJ,BINJ->BIN - if (is_equal(eqn, {A, B, C, COMMA, A, B, D, C, ARROW, A, B, D})) { - return EinsumEquation::TransposeReducedDim; - } - // ABCD,ADBE->ACBE - if (is_equal(eqn, {A, B, C, D, COMMA, A, D, B, E, ARROW, A, C, B, E})) { - return EinsumEquation::FourDReduceLast; - } - // ABCD,AECD->ACEB - if (is_equal(eqn, {A, B, C, D, COMMA, A, E, C, D, ARROW, A, C, E, B})) { - return EinsumEquation::FourDTransposeAll; - } - return EinsumEquation::UnsupportedEquation; -} - -EinsumEquation tokenizeAndParse(const llvm::StringRef& equation) { - std::vector tokens; - if (tokenizeEquation(equation, &tokens)) { - return parseEquation(tokens); - } - return EinsumEquation::UnsupportedEquation; -} - TF::TransposeOp createTransposeOp(Value value, Location loc, llvm::ArrayRef permutation, PatternRewriter* rewriter) { @@ -186,28 +66,6 @@ TF::TransposeOp createTransposeOp(Value value, Location loc, perm_op); } -TF::SumOp createSumOp(Value value, Location loc, - llvm::ArrayRef redux_axes, - PatternRewriter* rewriter) { - auto value_type = value.getType().cast(); - auto shape = value_type.getShape(); - auto redux_type = RankedTensorType::get( - {static_cast(redux_axes.size())}, rewriter->getIntegerType(32)); - auto redux_attr = DenseElementsAttr::get(redux_type, redux_axes); - auto redux_op = rewriter->create(loc, redux_type, redux_attr); - std::vector sum_shape(shape.size() - redux_axes.size()); - int count = 0; - for (int i = 0, end = shape.size(); i < end; ++i) { - if (std::find(redux_axes.begin(), redux_axes.end(), i) == - redux_axes.end()) { - sum_shape[count] = shape[i]; - count++; - } - } - auto sum_type = RankedTensorType::get(sum_shape, value_type.getElementType()); - return rewriter->create(loc, sum_type, value, redux_op); -} - TF::ReshapeOp createReshapeOp(Value value, ArrayRef shape, Type element_type, Location loc, PatternRewriter* rewriter) { @@ -222,241 +80,277 @@ TF::ReshapeOp createReshapeOp(Value value, ArrayRef shape, /*shape=*/shape_tensor); } +struct EinsumDimensionNumbers { + // Each field contains the list of dimensions appearing only in the specifed + // arguments of the einsum op with natural ordering. 
For example `rhs_out` + // contains the dimensions appearing in the RHS and the OUTPUT of the einsum + // but not in the LHS. + std::vector lhs; + std::vector rhs; + std::vector> lhs_rhs; + std::vector> lhs_out; + std::vector> rhs_out; + std::vector> lhs_rhs_out; +}; + +llvm::Optional> EquationToMap( + llvm::StringRef equation) { + llvm::SmallDenseMap map; + for (int64_t i = 0; i < equation.size(); ++i) { + if (!std::isalpha(equation[i])) { + // Unsupported character in the equation. + return llvm::None; + } + if (map.count(equation[i])) { + // Duplicate character in the equation. + return llvm::None; + } + map.try_emplace(equation[i], i); + } + return map; +} + +llvm::Optional GetEinsumDimensionNumbers( + llvm::StringRef equation) { + llvm::StringRef lhs_rhs; + llvm::StringRef out; + std::tie(lhs_rhs, out) = equation.split("->"); + if (lhs_rhs.empty() || out.empty()) return llvm::None; + + llvm::StringRef lhs; + llvm::StringRef rhs; + std::tie(lhs, rhs) = lhs_rhs.split(','); + if (lhs.empty() || rhs.empty()) return llvm::None; + + auto lhs_map_or = EquationToMap(lhs); + if (!lhs_map_or.hasValue()) return llvm::None; + auto lhs_map = lhs_map_or.getValue(); + + auto rhs_map_or = EquationToMap(rhs); + if (!rhs_map_or.hasValue()) return llvm::None; + auto rhs_map = rhs_map_or.getValue(); + + auto out_map_or = EquationToMap(out); + if (!out_map_or.hasValue()) return llvm::None; + auto out_map = out_map_or.getValue(); + + EinsumDimensionNumbers dnums; + for (int64_t i = 0, e = lhs.size(); i < e; ++i) { + auto rhs_index = rhs_map.find(lhs[i]); + auto out_index = out_map.find(lhs[i]); + if (rhs_index == rhs_map.end() && out_index == out_map.end()) { + dnums.lhs.emplace_back(i); + } else if (rhs_index == rhs_map.end()) { + dnums.lhs_out.emplace_back(i, out_index->second); + } else if (out_index == out_map.end()) { + dnums.lhs_rhs.emplace_back(i, rhs_index->second); + } else { + dnums.lhs_rhs_out.emplace_back(i, rhs_index->second, out_index->second); + } + } + for (int64_t i = 0, e = rhs.size(); i < e; ++i) { + auto lhs_index = lhs_map.find(rhs[i]); + auto out_index = out_map.find(rhs[i]); + if (lhs_index == lhs_map.end()) { + if (out_index == out_map.end()) { + dnums.rhs.emplace_back(i); + } else { + dnums.rhs_out.emplace_back(i, out_index->second); + } + } + } + for (int64_t i = 0, e = out.size(); i < e; ++i) { + auto lhs_index = lhs_map.find(out[i]); + auto rhs_index = rhs_map.find(out[i]); + if (lhs_index == lhs_map.end() && rhs_index == rhs_map.end()) { + // out only isn't supported + return llvm::None; + } + } + return dnums; +} + +std::vector inverseTransposeVector( + llvm::ArrayRef input, llvm::ArrayRef permutation) { + std::vector output(input.size()); + for (int64_t i = 0; i < input.size(); ++i) { + output[permutation[i]] = input[i]; + } + return output; +} + +// Computes the transpositions required to convert dnums to one supported by +// tf.BatchMatmulV2 and returns the new set of dimension numbers with them. 
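+// As a concrete illustration (one possible case, not exhaustive): for the
+// equation "bji,bjk->bik", GetEinsumDimensionNumbers produces
+// lhs_rhs_out = {(0, 0, 0)} for `b`, lhs_rhs = {(1, 1)} for `j`,
+// lhs_out = {(2, 1)} for `i` and rhs_out = {(2, 2)} for `k`. The LHS is then
+// transposed with permutation {0, 2, 1} into b,i,j order, while the RHS and
+// output permutations work out to the identity, leaving a layout that
+// tf.BatchMatMulV2 accepts directly.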
+LogicalResult transposeForBatchMatmul( + const Location& loc, EinsumDimensionNumbers& dnums, Value* lhs, Value* rhs, + std::vector* out_inverse_transpose, PatternRewriter* rewriter) { + std::vector lhs_transpose; + std::vector rhs_transpose; + std::vector out_transpose; + lhs_transpose.reserve(dnums.lhs_rhs_out.size() + dnums.lhs_out.size() + + dnums.lhs_rhs.size()); + rhs_transpose.reserve(dnums.lhs_rhs_out.size() + dnums.rhs_out.size() + + dnums.lhs_rhs.size()); + out_transpose.reserve(dnums.lhs_rhs_out.size() + dnums.lhs_out.size() + + dnums.rhs_out.size()); + for (int64_t i = 0, e = dnums.lhs_rhs_out.size(); i < e; ++i) { + lhs_transpose.push_back(std::get<0>(dnums.lhs_rhs_out[i])); + rhs_transpose.push_back(std::get<1>(dnums.lhs_rhs_out[i])); + out_transpose.push_back(std::get<2>(dnums.lhs_rhs_out[i])); + dnums.lhs_rhs_out[i] = std::make_tuple(i, i, i); + } + + for (int64_t i = 0, e = dnums.lhs_out.size(); i < e; ++i) { + lhs_transpose.push_back(std::get<0>(dnums.lhs_out[i])); + out_transpose.push_back(std::get<1>(dnums.lhs_out[i])); + dnums.lhs_out[i] = + std::make_tuple(lhs_transpose.size() - 1, out_transpose.size() - 1); + } + for (int64_t i = 0, e = dnums.lhs_rhs.size(); i < e; ++i) { + lhs_transpose.push_back(std::get<0>(dnums.lhs_rhs[i])); + rhs_transpose.push_back(std::get<1>(dnums.lhs_rhs[i])); + dnums.lhs_rhs[i] = + std::make_tuple(lhs_transpose.size() - 1, rhs_transpose.size() - 1); + } + for (int64_t i = 0, e = dnums.rhs_out.size(); i < e; ++i) { + rhs_transpose.push_back(std::get<0>(dnums.rhs_out[i])); + out_transpose.push_back(std::get<1>(dnums.rhs_out[i])); + dnums.rhs_out[i] = + std::make_tuple(rhs_transpose.size() - 1, out_transpose.size() - 1); + } + + out_inverse_transpose->resize(out_transpose.size()); + for (int64_t i = 0, e = out_transpose.size(); i < e; ++i) { + out_inverse_transpose->at(out_transpose[i]) = i; + } + + *lhs = createTransposeOp(*lhs, loc, lhs_transpose, rewriter); + *rhs = createTransposeOp(*rhs, loc, rhs_transpose, rewriter); + return success(); +} + +// Reshapes LHS and RHS to have B0,...,Bn,L,C and B0,...,Bn,C,R shape +// respectively while assuming that the initial shape for them is +// B0,...,Bn,L0,...,Ln,C0,...,Cn and B0,...,Bn,C0,...,Cn,R0,...,Rn respectively. 
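+// A small worked example (assuming ranked, static shapes): with one batch
+// dimension of size 2, LHS-only dims L0 = 3 and L1 = 4, a single contracting
+// dim C = 5 and an RHS-only dim R = 7, an LHS of shape [2, 3, 4, 5] is
+// reshaped to [2, 12, 5] and an RHS of shape [2, 5, 7] keeps that shape, so
+// one tf.BatchMatMulV2 yields [2, 12, 7]; the caller later reshapes (and, if
+// needed, transposes) that result back to the expected einsum output shape.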
+LogicalResult reshapeForBatchMatmul(const Location& loc, + EinsumDimensionNumbers& dnums, Value* lhs, + Value* rhs, std::vector* out_shape, + PatternRewriter* rewriter) { + RankedTensorType lhs_type = lhs->getType().cast(); + RankedTensorType rhs_type = rhs->getType().cast(); + + std::vector lhs_shape; + std::vector rhs_shape; + lhs_shape.reserve(dnums.lhs_rhs_out.size() + dnums.lhs_out.size() + 1); + rhs_shape.reserve(dnums.lhs_rhs_out.size() + 2); + for (auto i : dnums.lhs_rhs_out) { + int64_t b = lhs_type.getShape()[std::get<0>(i)]; + lhs_shape.push_back(b); + rhs_shape.push_back(b); + out_shape->push_back(b); + } + + if (dnums.lhs_out.empty()) { + lhs_shape.push_back(1); + out_shape->push_back(1); + dnums.lhs_out.emplace_back(lhs_shape.size() - 1, out_shape->size() - 1); + } else if (dnums.lhs_rhs_out.empty()) { + for (auto i : dnums.lhs_out) { + int64_t b = lhs_type.getShape()[std::get<0>(i)]; + lhs_shape.push_back(b); + out_shape->push_back(b); + } + } else { + int64_t lhs_out_size = 1; + for (auto i : dnums.lhs_out) { + lhs_out_size *= lhs_type.getShape()[std::get<0>(i)]; + } + lhs_shape.push_back(lhs_out_size); + out_shape->push_back(lhs_out_size); + } + + int64_t lhs_rhs_size = 1; + for (auto i : dnums.lhs_rhs) { + lhs_rhs_size *= lhs_type.getShape()[std::get<0>(i)]; + } + lhs_shape.push_back(lhs_rhs_size); + rhs_shape.push_back(lhs_rhs_size); + + int64_t rhs_size = 1; + for (auto i : dnums.rhs_out) { + rhs_size *= rhs_type.getShape()[std::get<0>(i)]; + } + rhs_shape.push_back(rhs_size); + out_shape->push_back(rhs_size); + + *lhs = createReshapeOp(*lhs, lhs_shape, lhs_type.getElementType(), loc, + rewriter); + *rhs = createReshapeOp(*rhs, rhs_shape, rhs_type.getElementType(), loc, + rewriter); + + dnums.lhs_rhs.assign( + {std::make_tuple(dnums.lhs_rhs_out.size() + dnums.lhs_out.size(), + dnums.lhs_rhs_out.size())}); + dnums.rhs_out.assign( + {std::make_tuple(dnums.lhs_rhs_out.size() + dnums.lhs_out.size(), + dnums.lhs_rhs_out.size() + dnums.lhs_out.size())}); + return success(); +} + +LogicalResult rewriteToBatchMatmul(TF::EinsumOp op, + EinsumDimensionNumbers dnums, + PatternRewriter& rewriter) { + if (!dnums.lhs.empty() || !dnums.rhs.empty()) return failure(); + + auto inputs = op.inputs(); + if (inputs.size() != 2) return failure(); + Value lhs = inputs.front(); + Value rhs = inputs.back(); + + RankedTensorType original_type = + op.getResult().getType().dyn_cast_or_null(); + if (!original_type) return failure(); + + std::vector out_transpose; + if (failed(transposeForBatchMatmul(op.getLoc(), dnums, &lhs, &rhs, + &out_transpose, &rewriter))) + return failure(); + + std::vector matmul_shape; + if (failed(reshapeForBatchMatmul(op.getLoc(), dnums, &lhs, &rhs, + &matmul_shape, &rewriter))) + return failure(); + + auto matmul_type = + RankedTensorType::get(matmul_shape, original_type.getElementType()); + Value out = rewriter.create( + op.getLoc(), matmul_type, lhs, rhs, rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)); + + out = createReshapeOp( + out, inverseTransposeVector(original_type.getShape(), out_transpose), + original_type.getElementType(), op.getLoc(), &rewriter); + out = createTransposeOp(out, op.getLoc(), out_transpose, &rewriter); + + rewriter.replaceOp(op, out); + return success(); +} + } // namespace LogicalResult ConvertTFEinsumOp::matchAndRewrite( TF::EinsumOp op, PatternRewriter& rewriter) const { - Type output_type = op.getResult().getType(); - Value lhs = op.getOperand(0); - Value rhs = op.getOperand(1); - Location loc = op.getLoc(); - if 
(!lhs.getType().isa()) { - // LHS must be a ranked tensor type - return failure(); - } - if (!rhs.getType().isa()) { - // RHS must be a ranked tensor type - return failure(); - } + const auto dnums_or = GetEinsumDimensionNumbers(op.equation()); + if (!dnums_or.hasValue()) return failure(); + const auto& dnums = dnums_or.getValue(); - auto lhs_type = lhs.getType().cast(); - auto rhs_type = rhs.getType().cast(); - auto lhs_shape = lhs_type.getShape(); - auto rhs_shape = rhs_type.getShape(); + RankedTensorType lhs = + op.getOperand(0).getType().dyn_cast_or_null(); + RankedTensorType rhs = + op.getOperand(1).getType().dyn_cast_or_null(); + if (!lhs || !rhs) return failure(); - // Currently only support static shapes. - if (!(lhs_type.hasStaticShape() && rhs_type.hasStaticShape())) { - return failure(); - } - - // Currently support use cases of LHS dims \in {3,4} RHS dims \in {2, 3, 4} - const int dims_lhs = lhs_shape.size(); - const int dims_rhs = rhs_shape.size(); - if (dims_lhs < 3 || dims_lhs > 4 || dims_rhs < 2 || dims_rhs > 4) { - return failure(); - } - - EinsumEquation einsum_eqn = tokenizeAndParse(op.equation()); - if (einsum_eqn == EinsumEquation::BatchMatMul) { - // Case "IJK,IKM->IJM" - auto bmm_op = rewriter.create( - loc, ArrayRef{output_type}, lhs, rhs, rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)); - rewriter.replaceOp(op, bmm_op.getResult()); - return success(); - } - if (einsum_eqn == EinsumEquation::BroadcastMatMul) { - // Case "BFH,HO->BFO" - auto bmm_op = rewriter.create( - loc, ArrayRef{output_type}, lhs, rhs, rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)); - rewriter.replaceOp(op, bmm_op.getResult()); - return success(); - } - if (einsum_eqn == EinsumEquation::ReduceSum) { - // Case "LBH,BL->BH" - // Transpose LHS - lhs = createTransposeOp(lhs, loc, {1, 2, 0}, &rewriter); - // Reshape RHS - auto rhs_element_type = rhs_type.getElementType(); - const int rhs_dim0 = rhs_shape[0]; - const int rhs_dim1 = rhs_shape[1]; - auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0, 1, rhs_dim1}, - rhs_element_type, loc, &rewriter); - auto mul_op = rewriter.create(loc, lhs, reshaped_rhs); - - auto sum_op = createSumOp(mul_op, loc, {2}, &rewriter); - rewriter.replaceOp(op, {sum_op.getResult()}); - return success(); - } - if (einsum_eqn == EinsumEquation::TransposeMatMul) { - // Case "LBH,BKL->BKH" - // Transpose LHS - lhs = createTransposeOp(lhs, loc, {1, 2, 0}, &rewriter); - // Transpose RHS - rhs = createTransposeOp(rhs, loc, {0, 2, 1}, &rewriter); - std::vector bmm_shape = {lhs_shape[1], lhs_shape[2], rhs_shape[1]}; - auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); - auto bmm_op = rewriter.create( - loc, ArrayRef{bmm_type}, lhs, rhs, rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)); - - auto trans_bmm = createTransposeOp(bmm_op, loc, {0, 2, 1}, &rewriter); - rewriter.replaceOp(op, {trans_bmm.getResult()}); - return success(); - } - if (einsum_eqn == EinsumEquation::ThreeDReshapeTail) { - // Case "BFD,DNH->BFNH" - auto lhs_type = lhs.getType().cast(); - auto lhs_shape = lhs_type.getShape(); - const int lhs_dim0 = lhs_shape[0]; - const int lhs_dim1 = lhs_shape[1]; - // Reshape RHS - auto rhs_type = rhs.getType().cast(); - auto rhs_shape = rhs_type.getShape(); - auto rhs_element_type = rhs_type.getElementType(); - const int rhs_dim0 = rhs_shape[0]; - const int rhs_dim1 = rhs_shape[1]; - const int rhs_dim2 = rhs_shape[2]; - auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0, rhs_dim1 * rhs_dim2}, - rhs_element_type, loc, 
&rewriter); - - std::vector bmm_shape = {lhs_dim0, lhs_dim1, rhs_dim1 * rhs_dim2}; - auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); - auto bmm_op = rewriter.create( - loc, ArrayRef{bmm_type}, lhs, reshaped_rhs, - rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)); - auto bmm_element_type = bmm_type.getElementType(); - auto final_reshape = - createReshapeOp(bmm_op, {lhs_dim0, lhs_dim1, rhs_dim1, rhs_dim2}, - bmm_element_type, loc, &rewriter); - rewriter.replaceOp(op, {final_reshape.getResult()}); - return success(); - } - if (einsum_eqn == EinsumEquation::FourDMatrixDotProd) { - // Case "BFND,NDH->BFH" - // Reshape LHS - auto lhs_element_type = lhs_type.getElementType(); - const int lhs_dim0 = lhs_shape[0]; - const int lhs_dim1 = lhs_shape[1]; - const int lhs_dim2 = lhs_shape[2]; - const int lhs_dim3 = lhs_shape[3]; - auto reshaped_lhs = - createReshapeOp(lhs, {lhs_dim0, lhs_dim1, lhs_dim2 * lhs_dim3}, - lhs_element_type, loc, &rewriter); - // Reshape RHS - auto rhs_element_type = rhs_type.getElementType(); - const int rhs_dim0 = rhs_shape[0]; - const int rhs_dim1 = rhs_shape[1]; - const int rhs_dim2 = rhs_shape[2]; - auto reshaped_rhs = createReshapeOp(rhs, {rhs_dim0 * rhs_dim1, rhs_dim2}, - rhs_element_type, loc, &rewriter); - auto bmm_op = rewriter.create( - loc, ArrayRef{output_type}, reshaped_lhs, reshaped_rhs, - rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)); - rewriter.replaceOp(op, {bmm_op.getResult()}); - return success(); - } - if (einsum_eqn == EinsumEquation::FourDBatchMatMul) { - // Case "BFNH,BTNH->BNFT" - // Transpose LHS - lhs = createTransposeOp(lhs, loc, {0, 2, 1, 3}, &rewriter); - // Transpose RHS - rhs = createTransposeOp(rhs, loc, {0, 2, 3, 1}, &rewriter); - auto bmm_op = rewriter.create( - loc, ArrayRef{output_type}, lhs, rhs, rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)); - rewriter.replaceOp(op, {bmm_op.getResult()}); - return success(); - } - if (einsum_eqn == EinsumEquation::BatchMatMulReducedDim) { - // Case "BIN,BINJ->BIJ" - // Reshape LHS - auto lhs_element_type = lhs_type.getElementType(); - const int lhs_dim0 = lhs_shape[0]; - const int lhs_dim1 = lhs_shape[1]; - const int lhs_dim2 = lhs_shape[2]; - const int rhs_dim3 = rhs_shape[3]; - - auto reshaped_lhs = createReshapeOp(lhs, {lhs_dim0, lhs_dim1, 1, lhs_dim2}, - lhs_element_type, loc, &rewriter); - std::vector bmm_shape = {lhs_dim0, lhs_dim1, 1, rhs_dim3}; - auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); - auto bmm_op = rewriter.create( - loc, ArrayRef{bmm_type}, reshaped_lhs, rhs, - rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)); - - auto bmm_element_type = bmm_type.getElementType(); - auto final_reshape = createReshapeOp(bmm_op, {lhs_dim0, lhs_dim1, rhs_dim3}, - bmm_element_type, loc, &rewriter); - rewriter.replaceOp(op, {final_reshape.getResult()}); - return success(); - } - if (einsum_eqn == EinsumEquation::TransposeReducedDim) { - // Case "BIJ,BINJ->BIN" - // Reshape LHS - auto lhs_element_type = lhs_type.getElementType(); - const int lhs_dim0 = lhs_shape[0]; - const int lhs_dim1 = lhs_shape[1]; - const int lhs_dim2 = lhs_shape[2]; - const int rhs_dim2 = rhs_shape[2]; - - auto reshaped_lhs = createReshapeOp(lhs, {lhs_dim0, lhs_dim1, 1, lhs_dim2}, - lhs_element_type, loc, &rewriter); - // Transpose RHS - rhs = createTransposeOp(rhs, loc, {0, 1, 3, 2}, &rewriter); - std::vector bmm_shape = {lhs_dim0, lhs_dim1, 1, rhs_dim2}; - auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); - 
auto bmm_op = rewriter.create( - loc, ArrayRef{bmm_type}, reshaped_lhs, rhs, - rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)); - - auto bmm_element_type = bmm_type.getElementType(); - auto final_reshape = createReshapeOp(bmm_op, {lhs_dim0, lhs_dim1, rhs_dim2}, - bmm_element_type, loc, &rewriter); - rewriter.replaceOp(op, {final_reshape.getResult()}); - return success(); - } - if (einsum_eqn == EinsumEquation::FourDReduceLast) { - // Case "acbe,aecd->abcd" - const int lhs_dim2 = lhs_shape[2]; - const int rhs_dim0 = rhs_shape[0]; - const int rhs_dim2 = rhs_shape[2]; - const int rhs_dim3 = rhs_shape[3]; - // Transpose RHS - rhs = createTransposeOp(rhs, loc, {0, 2, 1, 3}, &rewriter); - std::vector bmm_shape = {rhs_dim0, rhs_dim2, lhs_dim2, rhs_dim3}; - auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); - auto bmm_op = rewriter.create( - loc, ArrayRef{bmm_type}, lhs, rhs, rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)); - - auto trans_bmm = createTransposeOp(bmm_op, loc, {0, 2, 1, 3}, &rewriter); - rewriter.replaceOp(op, {trans_bmm.getResult()}); - return success(); - } - if (einsum_eqn == EinsumEquation::FourDTransposeAll) { - // Case "aecd,abcd->acbe" - const int lhs_dim0 = lhs_shape[0]; - const int lhs_dim1 = lhs_shape[1]; - const int lhs_dim2 = lhs_shape[2]; - const int rhs_dim1 = rhs_shape[1]; - // Transpose LHS - lhs = createTransposeOp(lhs, loc, {0, 2, 1, 3}, &rewriter); - // Transpose RHS - rhs = createTransposeOp(rhs, loc, {0, 2, 3, 1}, &rewriter); - std::vector bmm_shape = {lhs_dim0, lhs_dim2, lhs_dim1, rhs_dim1}; - auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); - auto bmm_op = rewriter.create( - loc, ArrayRef{bmm_type}, lhs, rhs, rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)); - - auto trans_bmm = createTransposeOp(bmm_op, loc, {0, 1, 3, 2}, &rewriter); - rewriter.replaceOp(op, {trans_bmm.getResult()}); - return success(); - } - - return failure(); + return rewriteToBatchMatmul(op, dnums, rewriter); } // Transform Einsum to other TF Ops for the supported variants. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc index 2677f5bdfac..a18d893fac7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc @@ -40,7 +40,7 @@ Status MlirGraphOptimizationPass::Run(const ConfigProto& config_proto, VLOG(1) << "Run MLIR Graph Optimization Passes"; PassManager pm(module.getContext()); - ::tensorflow::SetCrashReproducer(pm); + ::tensorflow::applyTensorflowAndCLOptions(pm); // Run island coarsening before shape inference to allow more exact shape // inference using constant folding within islands. 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 8e93a7e7470..8ab348c1e5b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -821,7 +821,7 @@ static PassRegistration pass( void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList *patterns, MLIRContext *context) { - populateWithGenerated(context, patterns); + populateWithGenerated(context, *patterns); patterns->insert(context); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc index 6686b340be9..6c1e6a827c7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc @@ -176,7 +176,48 @@ LogicalResult LiftVariables(ModuleOp module, Session* session) { if (resource_names.empty()) return success(); - return LiftVariablesFromSession(module, session, resource_names); + if (failed(LiftVariablesFromSession(module, session, resource_names))) + return failure(); + + // Now that we have all global tensors created, we set the corresponding + // bound_inputs' types correctly. + SymbolTable symbol_table(module); + for (auto func : module.getOps()) { + for (auto arg : func.getArguments()) { + unsigned arg_number = arg.getArgNumber(); + auto global_tensor = LookupBoundInputOfType( + func, arg_number, symbol_table); + if (!global_tensor) continue; + + auto arg_type = arg.getType().cast(); + assert(arg_type.getRank() == 0); + llvm::ArrayRef underlying_type = + arg_type.getElementType().cast().getSubtypes(); + + // If the arg type already matches the global_tensor type, we don't need + // to do anything. + if (!underlying_type.empty() && + underlying_type[0] == global_tensor.type()) { + assert(underlying_type.size() == 1); + continue; + } + + // Otherwise, set this argument's type to the global_tensor's type. + auto new_arg_type = mlir::RankedTensorType::get( + /*shape=*/{}, + mlir::TF::ResourceType::get( + /*subtypes=*/{global_tensor.type().cast()}, + module.getContext())); + + arg.setType(new_arg_type); + } + + // Update the function type. 
+ func.setType(mlir::FunctionType::get(func.getArgumentTypes(), + func.getType().getResults(), + module.getContext())); + } + return success(); } } // namespace tf_saved_model diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index 1ad6bde1013..002117f15d8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -780,7 +780,7 @@ void PopulateLoweringTFPatterns(MLIRContext *context, LowerDynamicStitchOp, LowerInvertPermutationOp, LowerPackOp, LowerSpaceToBatchNDOp, LowerSparseMatMulOp, Lower_UnaryOpsComposition>(context); - populateWithGenerated(context, patterns); + populateWithGenerated(context, *patterns); } } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td index 2107f84f26e..bddc863ee60 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td @@ -24,6 +24,10 @@ class GetScalarOfType : NativeCodeCall< class GetScalarOfFloatType : NativeCodeCall< "GetScalarOfFloatType(getElementTypeOrSelf($0)," # value # ")">; +def GetScalarInfOfType : NativeCodeCall< + "GetScalarOfFloatType(getElementTypeOrSelf($0), " + "std::numeric_limits::infinity())">; + def GetScalarNanOfType : NativeCodeCall< "GetScalarOfFloatType(getElementTypeOrSelf($0), " "std::numeric_limits::quiet_NaN())">; @@ -154,13 +158,22 @@ foreach fromToBinPair = [[TF_DivNoNanOp, TF_DivOp], def LowerFillOp : Pat<(TF_FillOp $dims, $value), (TF_BroadcastToOp $value, $dims)>; +//===----------------------------------------------------------------------===// +// Inf op patterns. +//===----------------------------------------------------------------------===// + +def LowerIsInfOp : Pat<(TF_IsInfOp $x), + (TF_EqualOp (TF_AbsOp:$abs $x), + (TF_ConstOp:$inf (GetScalarInfOfType $x)), + /*incompatible_shape_error*/ConstBoolAttrTrue)>; + //===----------------------------------------------------------------------===// // NaN op patterns. //===----------------------------------------------------------------------===// def LowerIsNanOp : Pat<(TF_IsNanOp $x), - (TF_EqualOp $x, (TF_ConstOp:$nan (GetScalarNanOfType $x)), - /*incompatible_shape_error*/ConstBoolAttrTrue)>; + (TF_NotEqualOp $x, $x, + /*incompatible_shape_error*/ConstBoolAttrTrue)>; //===----------------------------------------------------------------------===// // L2Loss op patterns. 
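The lower_tf.td hunk above adds LowerIsInfOp and rewrites LowerIsNanOp: IsInf lowers to Equal(Abs(x), +inf), while IsNan lowers to NotEqual(x, x), relying on NaN being the only value that compares unequal to itself (which is presumably also why the old Equal-against-a-NaN-constant pattern was dropped). A minimal standalone check of the two identities, not TensorFlow code:

#include <cassert>
#include <cmath>
#include <limits>

int main() {
  const float inf = std::numeric_limits<float>::infinity();
  const float nan = std::numeric_limits<float>::quiet_NaN();
  for (float x : {1.5f, -0.0f, inf, -inf, nan}) {
    // LowerIsInfOp: IsInf(x) <=> Abs(x) == +inf.
    assert(std::isinf(x) == (std::fabs(x) == inf));
    // LowerIsNanOp: IsNan(x) <=> x != x.
    assert(std::isnan(x) == (x != x));
  }
  return 0;
}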
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc index 47bd0027236..4abebc4475d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc @@ -199,17 +199,6 @@ LogicalResult MarkUncompilableOps( op->getContext())); outside_compiled_cluster_counter++; } - if (llvm::isa(op)) { - if (HasCapturedStringOperand(op)) { - op->setAttr( - kXlaOutsideCompilationAttr, - StringAttr::get( - llvm::formatv("auto{0}", outside_compiled_cluster_counter) - .str(), - op->getContext())); - outside_compiled_cluster_counter++; - } - } }); return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc index 24e77d31e7c..29ecc38de0b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc @@ -37,7 +37,7 @@ struct TFOptimizePass : public PassWrapper { void runOnFunction() override { OwningRewritePatternList patterns; auto func = getFunction(); - populateWithGenerated(&getContext(), &patterns); + populateWithGenerated(&getContext(), patterns); applyPatternsAndFoldGreedily(func, patterns); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 1562e9bb0a5..3cd316cf92d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -1340,9 +1340,15 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) { llvm::SmallDenseMap lifted_partitioned_call_callees; - return HoistForControlFlow(&function.front(), - cast(function.getParentOp()), - &lifted_partitioned_call_callees); + if (failed(HoistForControlFlow(&function.front(), + cast(function.getParentOp()), + &lifted_partitioned_call_callees))) + return failure(); + + // Clean up and canonicalize to remove dead local variables as some local + // variables might be dead after hoisting resource loads/stores from control + // flow ops. 
+ return TF::CleanupAndCanonicalizeForResourceOpLifting(function); } } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc index fb232c7d355..b635096cc9b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc @@ -52,11 +52,11 @@ void RemoveDeadLocalVariables(Block &block) { } } for (auto local_var : local_vars) { - if (llvm::all_of(local_var.resource().getUsers(), - [](const Operation *user) { - return isa(user); - })) { - for (auto user : local_var.resource().getUsers()) user->erase(); + auto users = local_var.resource().getUsers(); + if (llvm::all_of(users, [](const Operation *user) { + return isa(user); + })) { + for (auto user : llvm::make_early_inc_range(users)) user->erase(); local_var.erase(); } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index 01de6a89c83..680d5334ceb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ -132,6 +132,15 @@ llvm::Optional> GetTensorArrayElementShape( return llvm::None; } return t; + } else if (auto scatter = + llvm::dyn_cast(user)) { + // TensorArrayScatter writes vector of tensors to TensorArray. We can + // deduce the shape of TensorArray by dropping the 0th dim of + // TensorArrayScatter `value`. + auto t = scatter.value().getType().dyn_cast(); + if (!t || t.getShape().empty()) return llvm::None; + return RankedTensorType::get(t.getShape().drop_front(), + t.getElementType()); } return llvm::None; }); @@ -139,6 +148,26 @@ llvm::Optional> GetTensorArrayElementShape( return llvm::to_vector<8>(elem_type->getShape()); } +void ReplaceAllUsesWithCast(Value old_val, Value new_val) { + if (old_val.use_empty()) return; + auto cast_op = + OpBuilder(old_val.getDefiningOp()) + .create(old_val.getLoc(), old_val.getType(), new_val); + old_val.replaceAllUsesWith(cast_op); +} + +void ReplaceAllUsesExceptTerminator(Value old_val, Value new_val) { + if (old_val.getType() == new_val.getType()) { + old_val.replaceAllUsesWith(new_val); + return; + } + Operation* old_op = old_val.getDefiningOp(); + Operation* terminator_op = + old_op->getParentOfType().front().getTerminator(); + llvm::SmallPtrSet exceptions = {terminator_op}; + old_val.replaceAllUsesExcept(new_val, exceptions); +} + struct TensorArrayStats { // Whether a write op should accumulate with the old value. Set to true if // this is a gradient. @@ -195,7 +224,8 @@ LogicalResult HandleTensorArrayReadV3Op( auto index_reshape = cutil::ReshapeScalarToSizeType(builder, read.index(), read.getLoc()); auto elem = cutil::GetElement(index_reshape, buffer, builder, read.getLoc()); - read.value().replaceAllUsesWith(elem); + ReplaceAllUsesExceptTerminator(read.value(), elem); + ReplaceAllUsesWithCast(read.value(), elem); read.erase(); // The clear_after_read attribute does not mean setting the tensor to 0 after // read; instead it does not allow a second read before the next write. 
We @@ -260,7 +290,8 @@ LogicalResult HandleTensorArrayConcatV3Op( RankedTensorType::get(shape, buffer_type.getElementType())}, ArrayRef{buffer, cutil::GetR1Const(shape, builder, concat.getLoc())}); - concat.value().replaceAllUsesWith(buffer); + ReplaceAllUsesExceptTerminator(concat.value(), buffer); + ReplaceAllUsesWithCast(concat.value(), buffer); // Create the lengths as a list of the same value (element size). tensorflow::Tensor lengths_tensor(tensorflow::DT_INT64, @@ -389,7 +420,8 @@ LogicalResult HandleTensorArrayGatherV3Op( auto buffer = cutil::ReadLocalVariable(local_var, builder, gather.getLoc()); auto result = cutil::GatherElements(gather.indices(), buffer, builder, gather.getLoc()); - gather.value().replaceAllUsesWith(result); + ReplaceAllUsesExceptTerminator(gather.value(), result); + ReplaceAllUsesWithCast(gather.value(), result); gather.erase(); return success(); } @@ -570,6 +602,7 @@ LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module, } stat.grads[source] = grad_var; operands.push_back(grad_var); + (*stats)[grad_var].accumulate_on_write = true; } } } @@ -636,6 +669,7 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module, } stat.grads[source] = grad_var; operands.push_back(grad_var); + (*stats)[grad_var].accumulate_on_write = true; } } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc index f74ee5e9c1f..f7c0357a212 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc @@ -124,26 +124,39 @@ LogicalResult DecomposeTensorListOpsInternal( Block*, ModuleOp, llvm::SmallDenseMap*, llvm::StringMap*); +// Adds the corresponding sizes of tensor list buffers in block's terminator +// to the list of return values. Returns the mapping from the buffer +// indices to the added size indices, which is a list of tuples +// (buffer_return_index, size_return_index, fixed_size). +template +llvm::SmallVector, 8> +AddTensorListSizesToTerminator( + Block& block, const llvm::SmallDenseMap& buffer_to_size) { + auto old_terminator = block.getTerminator(); + auto new_outputs = llvm::to_vector<8>(old_terminator->getOperands()); + llvm::SmallVector, 8> + output_buffer_to_size; + for (auto retval : llvm::enumerate(old_terminator->getOperands())) { + auto it = buffer_to_size.find(retval.value()); + if (it == buffer_to_size.end()) continue; + output_buffer_to_size.emplace_back(retval.index(), new_outputs.size(), + it->getSecond().fixed); + new_outputs.push_back(it->getSecond().size); + } + OpBuilder(old_terminator) + .create(old_terminator->getLoc(), new_outputs); + old_terminator->erase(); + return output_buffer_to_size; +} + // Adds the corresponding sizes of tensor list buffers in func's return values // to the list of return values. Returns the mapping from the buffer indices to // the added size indices, which is a list of tuples (buffer_return_index, // size_return_index, fixed_size). 
-llvm::SmallVector, 8> -AddTensorListSizesToReturn( +llvm::SmallVector, 8> ModifyFunctionReturn( FuncOp func, const llvm::SmallDenseMap& buffer_to_size) { - auto old_return = func.front().getTerminator(); - auto new_returns = llvm::to_vector<8>(old_return->getOperands()); - llvm::SmallVector, 8> - output_buffer_to_size; - for (auto retval : llvm::enumerate(old_return->getOperands())) { - auto it = buffer_to_size.find(retval.value()); - if (it == buffer_to_size.end()) continue; - output_buffer_to_size.emplace_back(retval.index(), new_returns.size(), - it->getSecond().fixed); - new_returns.push_back(it->getSecond().size); - } - OpBuilder(old_return).create(old_return->getLoc(), new_returns); - old_return->erase(); + auto output_buffer_to_size = + AddTensorListSizesToTerminator(func.front(), buffer_to_size); UpdateFuncType(func); return output_buffer_to_size; } @@ -172,7 +185,7 @@ LogicalResult HandleWhileOp( decomposed_partitioned_call_callees))) { return failure(); } - auto output_buffer_to_size = AddTensorListSizesToReturn(body, body_map); + auto output_buffer_to_size = ModifyFunctionReturn(body, body_map); // Rewrite cond. auto cond = while_op.cond_function(); @@ -240,9 +253,9 @@ LogicalResult HandleCaseOrIfOp( const bool arg_no_changed = branch_maps.front().empty(); auto output_buffer_to_size = - AddTensorListSizesToReturn(branches.front(), branch_maps.front()); + ModifyFunctionReturn(branches.front(), branch_maps.front()); for (const auto& pair : llvm::drop_begin(llvm::zip(branches, branch_maps), 1)) - AddTensorListSizesToReturn(std::get<0>(pair), std::get<1>(pair)); + ModifyFunctionReturn(std::get<0>(pair), std::get<1>(pair)); if (output_buffer_to_size.empty() && arg_no_changed) return success(); @@ -266,6 +279,158 @@ LogicalResult HandleCaseOrIfOp( return success(); } +LogicalResult HandleWhileRegionOp( + TF::WhileRegionOp while_op, ModuleOp module, + llvm::SmallDenseMap* buffer_to_size, + llvm::StringMap* + decomposed_partitioned_call_callees) { + OpBuilder builder(while_op); + auto modify_region_arguments = [&](Region& region) { + int64_t original_arg_count = region.getNumArguments(); + for (int64_t i = 0; i < original_arg_count; ++i) { + auto operand = while_op.getOperand(i); + auto it = buffer_to_size->find(operand); + if (it == buffer_to_size->end()) continue; + auto buffer_type = it->getFirst().getType(); + region.getArgument(i).setType(buffer_type); + auto size_arg = region.addArgument(cutil::GetSizeType(builder)); + (*buffer_to_size)[region.getArgument(i)] = {size_arg, + it->getSecond().fixed}; + } + }; + + // Rewrite body. + Region& body_region = while_op.body(); + modify_region_arguments(body_region); + if (failed(DecomposeTensorListOpsInternal( + &body_region.front(), module, buffer_to_size, + decomposed_partitioned_call_callees))) { + return failure(); + } + auto output_buffer_to_size = AddTensorListSizesToTerminator( + body_region.front(), *buffer_to_size); + + // Rewrite cond. + Region& cond_region = while_op.cond(); + modify_region_arguments(cond_region); + if (failed(DecomposeTensorListOpsInternal( + &cond_region.front(), module, buffer_to_size, + decomposed_partitioned_call_callees))) { + return failure(); + } + + if (output_buffer_to_size.empty()) return success(); + + // Create the new while op. 
+ auto new_while_operands = llvm::to_vector<8>(while_op.getOperands()); + for (int64_t i = 0; i < while_op.getNumResults(); ++i) { + auto it = buffer_to_size->find(while_op.getOperand(i)); + if (it == buffer_to_size->end()) continue; + new_while_operands.push_back(it->getSecond().size); + } + auto new_while = builder.create( + while_op.getLoc(), body_region.front().getTerminator()->getOperandTypes(), + new_while_operands, while_op.getAttrs()); + new_while.body().takeBody(body_region); + new_while.cond().takeBody(cond_region); + for (const auto& entry : output_buffer_to_size) { + (*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = { + new_while.getResult(std::get<1>(entry)), std::get<2>(entry)}; + } + while_op.replaceAllUsesWith( + new_while.getResults().take_front(while_op.getNumResults())); + while_op.erase(); + return success(); +} + +LogicalResult HandleIfRegionOp( + TF::IfRegionOp if_op, ModuleOp module, + llvm::SmallDenseMap* buffer_to_size, + llvm::StringMap* + decomposed_partitioned_call_callees) { + // Rewrite the branches. + Region& then_branch = if_op.then_branch(); + Region& else_branch = if_op.else_branch(); + if (failed(DecomposeTensorListOpsInternal( + &then_branch.front(), module, buffer_to_size, + decomposed_partitioned_call_callees))) + return failure(); + if (failed(DecomposeTensorListOpsInternal( + &else_branch.front(), module, buffer_to_size, + decomposed_partitioned_call_callees))) + return failure(); + + auto output_buffer_to_size = AddTensorListSizesToTerminator( + then_branch.front(), *buffer_to_size); + AddTensorListSizesToTerminator(else_branch.front(), + *buffer_to_size); + + if (output_buffer_to_size.empty()) return success(); + + // Recreate the op. + auto new_op = OpBuilder(if_op).create( + if_op.getLoc(), then_branch.front().getTerminator()->getOperandTypes(), + if_op.getOperand(), if_op.getAttrs()); + for (const auto& entry : output_buffer_to_size) { + (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = { + new_op.getResult(std::get<1>(entry)), std::get<2>(entry)}; + } + + new_op.then_branch().takeBody(if_op.then_branch()); + new_op.else_branch().takeBody(if_op.else_branch()); + + if_op.replaceAllUsesWith( + new_op.getResults().take_front(if_op.getNumResults())); + if_op.erase(); + return success(); +} + +LogicalResult HandleCaseRegionOp( + TF::CaseRegionOp case_op, ModuleOp module, + llvm::SmallDenseMap* buffer_to_size, + llvm::StringMap* + decomposed_partitioned_call_callees) { + // Rewrite the branches. + RegionRange branches = case_op.getRegions(); + + for (Region* branch : branches) { + if (failed(DecomposeTensorListOpsInternal( + &branch->front(), module, buffer_to_size, + decomposed_partitioned_call_callees))) + return failure(); + } + + // Get the output buffer index to size index mapping one of the branches. It + // should be same for all the branches so we only get it for the first branch. + Region* first_branch = branches.front(); + auto output_buffer_to_size = AddTensorListSizesToTerminator( + first_branch->front(), *buffer_to_size); + for (Region* branch : branches.drop_front()) { + AddTensorListSizesToTerminator(branch->front(), + *buffer_to_size); + } + + if (output_buffer_to_size.empty()) return success(); + + // Recreate the op. 
+ auto new_op = OpBuilder(case_op).create( + case_op.getLoc(), + first_branch->front().getTerminator()->getOperandTypes(), + case_op.getOperand(), case_op.getAttrs(), case_op.getNumRegions()); + for (const auto& entry : output_buffer_to_size) { + (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = { + new_op.getResult(std::get<1>(entry)), std::get<2>(entry)}; + } + + for (auto pair : llvm::zip(new_op.getRegions(), case_op.getRegions())) { + std::get<0>(pair)->takeBody(*std::get<1>(pair)); + } + case_op.replaceAllUsesWith( + new_op.getResults().take_front(case_op.getNumResults())); + case_op.erase(); + return success(); +} + template LogicalResult HandlePartitionedCallOp( CallOp call, FuncOp callee, ModuleOp module, @@ -336,7 +501,7 @@ LogicalResult HandlePartitionedCallOp( return failure(); } info.buffer_ret_to_size_ret = - AddTensorListSizesToReturn(lowered_callee, callee_map); + ModifyFunctionReturn(lowered_callee, callee_map); info.decomposed_callee = lowered_callee; if (args_no_changed && info.buffer_ret_to_size_ret.empty()) { // Signature is not modified. We do not need to keep two copies. @@ -730,6 +895,21 @@ LogicalResult DecomposeTensorListOpsInternal( decomposed_partitioned_call_callees))) { return failure(); } + } else if (auto while_op = llvm::dyn_cast(&op)) { + if (failed(HandleWhileRegionOp(while_op, module, buffer_to_size, + decomposed_partitioned_call_callees))) { + return failure(); + } + } else if (auto if_op = llvm::dyn_cast(&op)) { + if (failed(HandleIfRegionOp(if_op, module, buffer_to_size, + decomposed_partitioned_call_callees))) { + return failure(); + } + } else if (auto case_op = llvm::dyn_cast(&op)) { + if (failed(HandleCaseRegionOp(case_op, module, buffer_to_size, + decomposed_partitioned_call_callees))) { + return failure(); + } } } return success(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc index 786c4b74b34..f2321df9823 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc @@ -58,7 +58,7 @@ struct FuseParallelMapAndBatch : public OpRewritePattern { void PopulateTFDataOptimizationPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { patterns->insert(context); - populateWithGenerated(context, patterns); + populateWithGenerated(context, *patterns); } } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc index 7cc31d82324..65490716cf0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc @@ -124,6 +124,10 @@ class ControlFlowStackInfo { Operation* GetCallSiteOp() const { return callsite_op_; } + bool operator==(const ControlFlowStackInfo& other) const { + return type_ == other.type_ && callsite_op_ == other.callsite_op_; + } + private: ControlFlowBranchType type_; @@ -280,17 +284,56 @@ Value CreateCompilationKeyPlaceholder(Location loc, OpBuilder* builder) { loc, /*program=*/result_type, llvm::ArrayRef{}); } +// Retrieves terminator of branch specified by `control_flow_info` of replicated +// control flow op. 
+Operation* GetControlFlowBranchRegionTerminator( + const ControlFlowStackInfo& controlflow_info, Operation* controlflow_op) { + if (auto inner_most_if = llvm::dyn_cast(controlflow_op)) { + if (controlflow_info.GetBranchType() == ControlFlowStackInfo::kIfThen) { + return inner_most_if.then_branch().front().getTerminator(); + } else { + return inner_most_if.else_branch().front().getTerminator(); + } + } else if (auto inner_most_while = + llvm::dyn_cast(controlflow_op)) { + if (controlflow_info.GetBranchType() == ControlFlowStackInfo::kWhileCond) { + return &inner_most_while.cond().front().front(); + } else { + return inner_most_while.body().front().getTerminator(); + } + } + assert(false); + return nullptr; +} + // Replicates the control flow operations that wraps outside compiled ops to // `destination_block`. -Operation* ReplicateControlFlowStack( +Operation* GetOrReplicateControlFlowStack( llvm::StringRef outside_cluster_name, const llvm::SmallVectorImpl& stack_info, tf_device::ClusterOp tpu_cluster, ModuleOp module, Value compilation_key, - Block* destination_block, int* send_recv_counter) { - assert(stack_info.size()); + Block* destination_block, int* send_recv_counter, + llvm::SmallDenseMap* replicated_controlflow_map) { + assert(!stack_info.empty()); + const auto& controlflow_info = stack_info.back(); + auto it = replicated_controlflow_map->find(controlflow_info.GetCallSiteOp()); + if (it != replicated_controlflow_map->end()) + return GetControlFlowBranchRegionTerminator(controlflow_info, it->second); + OpBuilder builder = OpBuilder::atBlockTerminator(destination_block); Operation* previous_replicated_controlflow_op = nullptr; for (const auto& controlflow_stack_info : stack_info) { + // If controlflow operation has already been created, reuse the cached + // controlflow operation. + auto it = replicated_controlflow_map->find( + controlflow_stack_info.GetCallSiteOp()); + if (it != replicated_controlflow_map->end()) { + previous_replicated_controlflow_op = it->second; + builder.setInsertionPoint(GetControlFlowBranchRegionTerminator( + controlflow_stack_info, previous_replicated_controlflow_op)); + continue; + } + // Create control flow op given provided insertion point and // ControlFlowStackInfo. if (auto control_flow_op = llvm::dyn_cast( @@ -325,28 +368,13 @@ Operation* ReplicateControlFlowStack( } } + replicated_controlflow_map->try_emplace(stack_info.back().GetCallSiteOp(), + previous_replicated_controlflow_op); + // Return operation which should be used to as the insertion point to create // send/recv ops. 
- if (auto inner_most_if = - llvm::dyn_cast(previous_replicated_controlflow_op)) { - auto inner_most_controlflow_stack = stack_info.back(); - if (inner_most_controlflow_stack.GetBranchType() == - ControlFlowStackInfo::kIfThen) { - return inner_most_if.then_branch().front().getTerminator(); - } else { - return inner_most_if.else_branch().front().getTerminator(); - } - } else if (auto inner_most_while = llvm::dyn_cast( - previous_replicated_controlflow_op)) { - auto inner_most_controlflow_stack = stack_info.back(); - if (inner_most_controlflow_stack.GetBranchType() == - ControlFlowStackInfo::kWhileCond) { - return &inner_most_while.cond().front().front(); - } else { - return inner_most_while.body().front().getTerminator(); - } - } - return destination_block->getTerminator(); + return GetControlFlowBranchRegionTerminator( + stack_info.back(), previous_replicated_controlflow_op); } // Collects and clusters ops in `block` with the same `_xla_outside_compilation` @@ -426,6 +454,7 @@ llvm::SmallSetVector GetExternalOperands( bool is_external = false; if (defining_op) { if (!tpu_cluster.getOperation()->isAncestor(defining_op)) continue; + is_external = llvm::none_of(host_cluster_ops, [&](Operation* cluster_op) { return defining_op == cluster_op; @@ -474,7 +503,7 @@ llvm::SmallSetVector GetExternalOperands( } // Extracts all externally used outputs of `cluster_ops`. -llvm::SmallVector GetExternalOutputs( +llvm::SmallSetVector GetExternalOutputs( llvm::ArrayRef cluster_ops) { llvm::SmallSetVector external_outputs; llvm::SmallPtrSet host_cluster_ops_set; @@ -496,7 +525,7 @@ llvm::SmallVector GetExternalOutputs( } } - return external_outputs.takeVector(); + return external_outputs; } // Sets the insertion point on `builder` for HostCompute op. Sets insertion @@ -543,52 +572,62 @@ TF::_XlaHostComputeMlirOp CreateHostCompute( return host_compute; } -void MoveOutsideCompiledOps( +// Represents a set of ops inside host computation that is wrapped inside the +// same control flow. +struct HostCluster { + // List of control flow that wraps host computation operations. May be empty. + llvm::SmallVector controlflow_stack; + + // Set of operations that will run on host wrapped around same stack of + // control flow. + llvm::SmallVector section_ops; +}; + +HostCluster* FindHostCluster( + llvm::SmallVectorImpl& host_cluster_sections, + const llvm::SmallVector& control_flows) { + for (auto& section : host_cluster_sections) + if (control_flows == section.controlflow_stack) return §ion; + return nullptr; +} + +void MoveOutsideCompiledOpsInsideControlFlow( ModuleOp module, tf_device::ClusterOp tpu_cluster, - llvm::StringRef outside_cluster_name, tf_device::LaunchOp host_launch_op, - llvm::ArrayRef cluster_ops, - const llvm::SmallSetVector& external_inputs, - llvm::ArrayRef external_outputs) { - // Since ops in `cluster_ops` do not cross function/control flow boundary, it - // is sufficient to identify the control flow that wraps `cluster_ops` by - // looking at any arbitary op inside `cluster_ops`. 
- auto controlflow_stack = - GetControlFlowStackForOp(tpu_cluster, cluster_ops.front()); - - Value compilation_key; - if (!controlflow_stack.empty() || !external_inputs.empty() || - !external_outputs.empty()) { - OpBuilder builder(&host_launch_op.GetBody().front()); - compilation_key = - CreateCompilationKeyPlaceholder(tpu_cluster.getLoc(), &builder); - } - + llvm::StringRef host_cluster_section_name, + tf_device::LaunchOp host_launch_op, Value compilation_key, + llvm::ArrayRef cluster_section_ops, + const llvm::SmallVectorImpl& controlflow_stack, + const llvm::SmallSetVector& section_external_inputs, + llvm::ArrayRef section_external_outputs, + llvm::SmallDenseMap* replicated_controlflow_map) { Operation* insertion_op = nullptr; if (controlflow_stack.empty()) { insertion_op = host_launch_op.GetBody().getTerminator(); } else { int send_recv_counter = 0; - insertion_op = ReplicateControlFlowStack( - outside_cluster_name, controlflow_stack, tpu_cluster, module, - compilation_key, &host_launch_op.GetBody(), &send_recv_counter); + insertion_op = GetOrReplicateControlFlowStack( + host_cluster_section_name, controlflow_stack, tpu_cluster, module, + compilation_key, &host_launch_op.GetBody(), &send_recv_counter, + replicated_controlflow_map); } MLIRContext* context = host_launch_op.getContext(); - if (external_inputs.empty() && external_outputs.empty()) { - MoveOutsideClusterOpsBeforeOp(insertion_op, cluster_ops, context); + if (section_external_inputs.empty() && section_external_outputs.empty()) { + MoveOutsideClusterOpsBeforeOp(insertion_op, cluster_section_ops, context); return; } OpBuilder builder(insertion_op); llvm::SmallVector host_output_types; - for (const auto& external_input : external_inputs) + for (const auto& external_input : section_external_inputs) host_output_types.push_back(external_input.getType()); std::string args_communication_key = - llvm::formatv("host_compute_channel_{0}_args", outside_cluster_name) + llvm::formatv("host_compute_channel_{0}_args", host_cluster_section_name) .str(); std::string retvals_communication_key = - llvm::formatv("host_compute_channel_{0}_retvals", outside_cluster_name) + llvm::formatv("host_compute_channel_{0}_retvals", + host_cluster_section_name) .str(); auto recv_at_host = builder.create( @@ -597,25 +636,105 @@ void MoveOutsideCompiledOps( builder.getStringAttr(args_communication_key), /*device_ordinal=*/builder.getI64IntegerAttr(0)); - auto host_compute = CreateHostCompute( - &builder, tpu_cluster, cluster_ops, external_inputs, external_outputs, - args_communication_key, retvals_communication_key); - MoveOutsideClusterOpsBeforeOp(insertion_op, cluster_ops, context); + auto host_compute = + CreateHostCompute(&builder, tpu_cluster, cluster_section_ops, + section_external_inputs, section_external_outputs, + args_communication_key, retvals_communication_key); + MoveOutsideClusterOpsBeforeOp(insertion_op, cluster_section_ops, context); builder.setInsertionPoint(insertion_op); builder.create( - tpu_cluster.getLoc(), external_outputs, + tpu_cluster.getLoc(), section_external_outputs, /*dynamic_key=*/compilation_key, builder.getStringAttr(retvals_communication_key), /*device_ordinal=*/builder.getI64IntegerAttr(0)); - for (auto result : llvm::zip(external_inputs, recv_at_host.getResults())) + for (auto result : + llvm::zip(section_external_inputs, recv_at_host.getResults())) { mlir::replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result), - host_launch_op.body()); + *insertion_op->getParentRegion()); + } - for (auto result : 
llvm::zip(external_outputs, host_compute.getResults())) - mlir::replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result), - tpu_cluster.body()); + for (auto result : + llvm::zip(section_external_outputs, host_compute.getResults())) { + for (auto& result_use : std::get<0>(result).getUses()) { + Operation* result_using_op = result_use.getOwner(); + const bool inside_device_cluster = + tpu_cluster.body().isAncestor(result_using_op->getParentRegion()); + if (inside_device_cluster) result_use.set(std::get<1>(result)); + } + } +} + +void MoveOutsideCompiledOps( + ModuleOp module, tf_device::ClusterOp tpu_cluster, + llvm::StringRef outside_cluster_name, tf_device::LaunchOp host_launch_op, + llvm::ArrayRef cluster_ops, + const llvm::SmallSetVector& external_inputs, + const llvm::SmallSetVector& external_outputs) { + // Identify and groups ops in `cluster_ops` by ops wrapped inside the same + // control flows. + llvm::SmallVector host_cluster_sections; + for (Operation* host_cluster_op : cluster_ops) { + auto controlflow_stack = + GetControlFlowStackForOp(tpu_cluster, host_cluster_op); + auto host_cluster_section = + FindHostCluster(host_cluster_sections, controlflow_stack); + if (!host_cluster_section) { + host_cluster_sections.emplace_back( + HostCluster{controlflow_stack, {host_cluster_op}}); + } else { + host_cluster_section->section_ops.emplace_back(host_cluster_op); + } + } + + const bool has_control_flow = + llvm::any_of(host_cluster_sections, [](const auto host_cluster_section) { + return !host_cluster_section.controlflow_stack.empty(); + }); + + Value compilation_key; + if (has_control_flow || !external_inputs.empty() || + !external_outputs.empty()) { + OpBuilder builder(&host_launch_op.GetBody().front()); + compilation_key = + CreateCompilationKeyPlaceholder(tpu_cluster.getLoc(), &builder); + } + + // Maintains a map of control flow callsite operation in TPU device side + // and an replicated control flow operation on host cluster. + llvm::SmallDenseMap replicated_controlflows; + + // Move `cluster_op` to host cluster, replicating control flow if ops are + // wrapped inside a control flow. 
+ for (const auto& host_cluster_section_and_index : + llvm::enumerate(host_cluster_sections)) { + const auto& host_cluster_section = host_cluster_section_and_index.value(); + const int index = host_cluster_section_and_index.index(); + + const auto& controlflow_stack = host_cluster_section.controlflow_stack; + const auto& cluster_section_ops = host_cluster_section.section_ops; + auto section_external_inputs = + GetExternalOperands(tpu_cluster, cluster_section_ops); + for (auto input : section_external_inputs) { + if (!external_inputs.contains(input)) + section_external_inputs.remove(input); + } + auto section_external_outputs = GetExternalOutputs(cluster_section_ops); + for (auto output : section_external_outputs) { + if (!external_outputs.contains(output)) + section_external_outputs.remove(output); + } + + const std::string host_cluster_section_name = + llvm::formatv("{0}_{1}", outside_cluster_name, index).str(); + + MoveOutsideCompiledOpsInsideControlFlow( + module, tpu_cluster, host_cluster_section_name, host_launch_op, + compilation_key, cluster_section_ops, controlflow_stack, + section_external_inputs, section_external_outputs.takeVector(), + &replicated_controlflows); + } } // Creates a `parallel_execute` op in place of launch with 'clusters` and diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc index 900bdf6f519..63bb53f52b5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_outside_compilation_cluster.cc @@ -41,6 +41,20 @@ struct TPUOutsideCompilationCluster const TF::SideEffectAnalysis::Info& side_effect_analysis); }; +bool IsVariant(Value value) { + return getElementTypeOrSelf(value.getType()).isa(); +} + +bool HasOutsideCompiledAncestor(Operation* op) { + Operation* parent = op->getParentOp(); + while (parent) { + if (parent->getAttrOfType(kXlaOutsideCompilationAttr)) + return true; + parent = parent->getParentOp(); + } + return false; +} + // Represents an outside compiled cluster. All ops that are added to the same // cluster will be extracted together in a later pass. class OutsideCompiledCluster { @@ -62,6 +76,53 @@ class OutsideCompiledCluster { return false; } + // If any tf.variants are inputs/outputs to the cluster, add them to the + // cluster unless they are already marks with outside compilation attribute. + bool AddVariantInputsOutputs() { + bool added_op = false; + llvm::SmallPtrSet expanded_cluster_ops(host_cluster_ops_); + for (Operation* cluster_op : host_cluster_ops_) { + // Walk the clustered operations to handle nested ops. + cluster_op->walk([&](Operation* op) { + // Add any operations that provide variant inputs to the cluster. + for (auto value : op->getOperands()) { + auto input_defining_op = value.getDefiningOp(); + if (IsVariant(value) && input_defining_op && + !HasOutsideCompiledAncestor(input_defining_op) && + !input_defining_op->getAttrOfType( + kXlaOutsideCompilationAttr)) { + expanded_cluster_ops.insert(input_defining_op); + input_defining_op->setAttr( + kXlaOutsideCompilationAttr, + StringAttr::get(cluster_name_, + input_defining_op->getContext())); + added_op = true; + } + } + // Add any operations that consume variant outputs to the cluster. 
+ for (auto value : op->getResults()) { + if (IsVariant(value)) { + for (auto user : value.getUsers()) { + if (!host_cluster_ops_.contains(user) && + !HasOutsideCompiledAncestor(user) && + !user->getAttrOfType( + kXlaOutsideCompilationAttr)) { + expanded_cluster_ops.insert(user); + user->setAttr( + kXlaOutsideCompilationAttr, + StringAttr::get(cluster_name_, user->getContext())); + added_op = true; + } + } + } + } + }); + } + host_cluster_ops_.swap(expanded_cluster_ops); + + return added_op; + } + private: // Checks if it is safe for an op to be merged into this cluster. bool IsSafeToAdd(Operation* op, @@ -88,12 +149,7 @@ class OutsideCompiledCluster { op->getUsers(), [&](Operation* user) { return host_cluster_ops_.contains(user); }); - const bool inside_same_block = - llvm::all_of(host_cluster_ops_, [&](Operation* op_in_cluster) { - return op_in_cluster->getBlock() == op->getBlock(); - }); - - return inside_same_block && contains_data_dependency; + return contains_data_dependency; } // `host_cluster_op_` stores a set of ops that will be grouped and computed @@ -132,6 +188,10 @@ void TPUOutsideCompilationCluster::runOnFunction( } } }); + for (auto& cluster : clusters) { + bool variants_to_add = true; + while (variants_to_add) variants_to_add = cluster.AddVariantInputsOutputs(); + } } } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator_gen.cc b/tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator_gen.cc deleted file mode 100644 index e4c965b6cb1..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator_gen.cc +++ /dev/null @@ -1,196 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "llvm/ADT/StringExtras.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/InitLLVM.h" -#include "llvm/Support/Signals.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/TableGen/Error.h" -#include "llvm/TableGen/Main.h" -#include "llvm/TableGen/Record.h" -#include "llvm/TableGen/TableGenBackend.h" -#include "mlir/TableGen/Operator.h" // from @llvm-project - -using llvm::LessRecord; -using llvm::raw_ostream; -using llvm::Record; -using llvm::RecordKeeper; -using mlir::tblgen::Operator; - -// Helper macro that returns indented os. -#define OUT(X) os.indent((X)) - -// Emits TensorFlow derived attribute populator functions for each of the ops. -static void EmitOpAttrPopulators(const std::vector &ops, - raw_ostream *ostream) { - raw_ostream &os = *ostream; - - for (const auto &op : ops) { - // TODO(hinsu): Introduce a derived attribute property for ops with no - // type attributes. That way an error can be generated if no derived type - // attribute or the property is set. This will make sure derived type - // attributes are not omitted by mistake. - - // Emit function signature. 
- auto op_name = op.getCppClassName(); - OUT(0) << "static Status Populate" << op_name - << "DerivedAttrs(mlir::TF::" << op_name - << "& op, AttrValueMap *values) {\n"; - - for (const auto &named_attr : op.getAttributes()) { - auto attr_name = named_attr.name; - const auto &attr = named_attr.attr; - if (!attr.isDerivedAttr()) continue; - auto retType = attr.getReturnType(); - if (retType == "ShapedType" || retType == "mlir::TF::OperandShapeRange" || - retType == "mlir::TF::ResultShapeRange") { - OUT(2) << "TF_RETURN_IF_ERROR(SetShapeAttribute(\"" << attr_name - << "\", op." << attr_name << "(), values));\n"; - } else if (retType == "Type" || - retType == "mlir::OperandElementTypeRange" || - retType == "mlir::ResultElementTypeRange") { - OUT(2) << "TF_RETURN_IF_ERROR(SetTypeAttribute(\"" << attr_name - << "\", op." << attr_name << "(), values));\n"; - } else if (attr.isSubClassOf("TF_DerivedOperandSizeAttr") || - attr.isSubClassOf("TF_DerivedResultSizeAttr")) { - OUT(2) << "TF_RETURN_IF_ERROR(SetSizeAttribute(\"" << attr_name - << "\", op." << attr_name << "(), values));\n"; - } else { - PrintFatalError(op.getLoc(), - "NYI: attribute populator for derived attributes"); - } - } - - OUT(2) << "return Status::OK();\n"; - OUT(0) << "}\n\n"; - } -} - -// Emits TensorFlow derived attribute populator function taking an Operation -// as argument. -static void EmitInstAttrPopulator(const std::vector &ops, - raw_ostream *ostream) { - raw_ostream &os = *ostream; - - // Emit function signature. - OUT(0) << "static Status PopulateDerivedAttrs(mlir::Operation* op, " - "AttrValueMap* values) {\n"; - - for (const auto &op : ops) { - auto op_name = op.getCppClassName(); - - // Emit conditional for the op and then call populator for the op on match. - OUT(2) << "if (auto tfOp = llvm::dyn_cast(op)) {\n"; - OUT(4) << "TF_RETURN_IF_ERROR(Populate" << op_name - << "DerivedAttrs(tfOp, values));\n"; - OUT(2) << "}\n"; - } - OUT(2) << "return Status::OK();\n"; - OUT(0) << "}\n\n"; -} - -// Emits TensorFlow derived attribute name collector functions for each of the -// ops. -static void EmitOpAttrNameCollector(const std::vector &ops, - raw_ostream *ostream) { - raw_ostream &os = *ostream; - - for (const auto &op : ops) { - // Emit function signature. - auto op_name = op.getCppClassName(); - OUT(0) << "static void Collect" << op_name - << "DerivedAttrsName(mlir::TF::" << op_name - << "& op, llvm::SmallDenseSet* values) {\n"; - - // Insert the name for each derived attribute in the set. - for (const auto &named_attr : op.getAttributes()) { - auto attr_name = named_attr.name; - const auto &attr = named_attr.attr; - if (!attr.isDerivedAttr()) continue; - OUT(2) << "values->insert(\"" << attr_name << "\");\n"; - } - - OUT(2) << "return;\n"; - OUT(0) << "}\n\n"; - } -} - -// Emits TensorFlow derived attribute name collector function taking an -// Operation as argument. -static void EmitInstAttrNameCollector(const std::vector &ops, - raw_ostream *ostream) { - raw_ostream &os = *ostream; - - // Emit function signature. - OUT(0) << "static void CollectDerivedAttrsName(mlir::Operation* op, " - "llvm::SmallDenseSet* values) {\n"; - - for (const auto &op : ops) { - auto op_name = op.getCppClassName(); - - // Emit conditional for the op and then call collect for the op on match. 
- OUT(2) << "if (auto tf_op = llvm::dyn_cast(op)) {\n"; - OUT(4) << "Collect" << op_name << "DerivedAttrsName(tf_op, values);\n"; - OUT(2) << "}\n"; - } - OUT(2) << "return;\n"; - OUT(0) << "}\n\n"; -} - -// The function below has a non-constant reference as that is required by LLVM's -// TableGenMain. -// NOLINTNEXTLINE -static bool DerivedAttrWritersMain(raw_ostream &os, RecordKeeper &records) { - emitSourceFileHeader("MLIR Derived TensorFlow Attributes Populators", os); - - // Retrieve all the definitions derived from TF_Op and sort by record name. - std::vector defs = records.getAllDerivedDefinitions("TF_Op"); - llvm::sort(defs, LessRecord()); - - std::vector ops; - ops.reserve(defs.size()); - - // Wrap TensorFlow op definitions into tblgen Operator wrapper and verify - // them. - for (const auto *def : defs) { - ops.emplace_back(Operator(def)); - - const Operator &op = ops.back(); - if (op.getDialectName() != "tf") - PrintFatalError(op.getLoc(), - "unexpected op name format: 'TF_' prefix missing"); - if (!op.getCppClassName().endswith("Op")) - PrintFatalError(op.getLoc(), - "unexpected op name format: 'Op' suffix missing"); - } - - EmitOpAttrPopulators(ops, &os); - EmitInstAttrPopulator(ops, &os); - - EmitOpAttrNameCollector(ops, &os); - EmitInstAttrNameCollector(ops, &os); - - return false; -} - -int main(int argc, char **argv) { - llvm::InitLLVM y(argc, argv); - llvm::cl::ParseCommandLineOptions(argc, argv); - return TableGenMain(argv[0], &DerivedAttrWritersMain); -} diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 0445dbb698a..c69e802994d 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -244,6 +244,7 @@ StatusOr> Exporter::GetArgumentNode( func.getArgAttrs(index); absl::flat_hash_set attrs_to_ignore = {kDeviceAttr}; TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore, + /*remove_ref_type=*/false, node_def->mutable_attr())); return node_def; @@ -661,8 +662,9 @@ Status Exporter::ConvertLibFunction(const GraphExportConfig& configs, grad_string.data(), stateful_string.data()}; llvm::SmallVector funcAttrs( function.getDialectAttrs()); - TF_RETURN_IF_ERROR( - ConvertAttributes(funcAttrs, attrs_to_ignore, func_def.mutable_attr())); + TF_RETURN_IF_ERROR(ConvertAttributes(funcAttrs, attrs_to_ignore, + /*remove_ref_type=*/false, + func_def.mutable_attr())); for (int i = 0, e = function.getNumArguments(); i < e; ++i) { if (auto resource_arg_unique_id_attr = @@ -679,6 +681,7 @@ Status Exporter::ConvertLibFunction(const GraphExportConfig& configs, kDeviceAttr, kResourceArgUniqueIdAttr}; FunctionDef::ArgAttrs func_def_arg_i_attrs; TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore, + /*remove_ref_type=*/false, func_def_arg_i_attrs.mutable_attr())); if (func_def_arg_i_attrs.attr().empty()) continue; (*func_def.mutable_arg_attr())[i] = std::move(func_def_arg_i_attrs); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc index 727831a6055..0057e498cea 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc @@ -21,10 +21,13 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/StringSet.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" +#include "tensorflow/compiler/mlir/utils/string_container_utils.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -84,22 +87,12 @@ Status SetShapeAttribute(absl::string_view name, ContainerT shapes, return Status::OK(); } -// Include the auto generated derived attribute populator function taking -// TensorFlow dialect operation as an argument. This file contains the function -// definitions and isn't a header file. -#include "tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator.inc" - // Collects all the unregistered attributes for an TF dialect operation. // Attributes "name" and "device" are not included because they are not part // of an TF op attributes. Status GetUnregisteredAttrs( - mlir::Operation* inst, + mlir::Operation* inst, const tensorflow::OpRegistrationData* op_reg_data, absl::flat_hash_set* attrs_to_ignore) { - TF_ASSIGN_OR_RETURN(auto op_name, - GetTensorFlowOpName(inst->getName().getStringRef())); - - const tensorflow::OpRegistrationData* op_reg_data = - tensorflow::OpRegistry::Global()->LookUp(std::string(op_name)); if (!op_reg_data) { // This is likely a function call node, so we should continue. return Status::OK(); @@ -125,19 +118,24 @@ Status GetUnregisteredAttrs( // Collects all attribute names to ignore in an MLIR operation when exporting to // a TensorFlow NodeDef. StatusOr> GetAttributesToIgnore( - mlir::Operation* inst, bool ignore_unregistered_attrs) { + mlir::Operation* inst, mlir::DictionaryAttr derived_attrs, + const tensorflow::OpRegistrationData* op_reg_data, + bool ignore_unregistered_attrs) { // The elements are owned by the MLIRContext. absl::flat_hash_set attrs_to_ignore; - if (inst->isRegistered()) { - // We ignore attributes attached to the operation when there is already a - // derived attribute defined in ODS. - llvm::SmallDenseSet derived_attrs; - CollectDerivedAttrsName(inst, &derived_attrs); - for (auto name : derived_attrs) attrs_to_ignore.insert(name.data()); + + // We ignore attributes attached to the operation when there is already a + // derived attribute defined in ODS. + if (derived_attrs) { + for (auto derived_attr : derived_attrs) { + attrs_to_ignore.insert( + mlir::StringRefToView(derived_attr.first.strref())); + } } if (ignore_unregistered_attrs) { - TF_RETURN_IF_ERROR(GetUnregisteredAttrs(inst, &attrs_to_ignore)); + TF_RETURN_IF_ERROR( + GetUnregisteredAttrs(inst, op_reg_data, &attrs_to_ignore)); } if (inst->hasTrait()) { @@ -154,25 +152,24 @@ StatusOr> GetAttributesToIgnore( attrs_to_ignore.insert(attr_name.data()); } + if (llvm::isa(inst)) + attrs_to_ignore.insert("is_stateless"); + return attrs_to_ignore; } // Populates all derived attributes of a MLIR operation in a proto // map. 
-Status PopulateDerivedAttributes(mlir::Operation* inst, +Status PopulateDerivedAttributes(mlir::Operation* inst, llvm::StringRef name, + mlir::DictionaryAttr derived_attrs, bool ignore_unregistered_attrs, AttrValueMap* attributes) { - // Use auto generated function to populate derived attribute. - // - // Note: This only populates derived attributes for TensorFlow ops that are - // generated using the TableGen. Manually defined ops should have all the - // attributes present as native MLIR op attributes. - - // If the operation is not registered, we won't be able to infer any attribute - if (inst->isRegistered()) { - TF_RETURN_WITH_CONTEXT_IF_ERROR(PopulateDerivedAttrs(inst, attributes), - "When populating derived attrs for ", - inst->getName().getStringRef().str()); + if (derived_attrs) { + TF_RETURN_WITH_CONTEXT_IF_ERROR( + ConvertAttributes(derived_attrs.getValue(), /*attrs_to_ignore=*/{}, + /*remove_ref_type=*/true, attributes), + "while converting derived attributes for node: ", + mlir::StringRefToView(name)); } // Here we only add the shapes for the leading values with ShapedType, @@ -197,28 +194,36 @@ Status PopulateDerivedAttributes(mlir::Operation* inst, } // namespace -Status GetAttrValuesFromOperation(mlir::Operation* inst, llvm::StringRef name, - bool ignore_unregistered_attrs, - AttrValueMap* attributes) { +Status GetAttrValuesFromOperation( + mlir::Operation* inst, llvm::StringRef name, + const tensorflow::OpRegistrationData* op_reg_data, + bool ignore_unregistered_attrs, AttrValueMap* attributes) { + mlir::DictionaryAttr derived_attrs = nullptr; + if (auto interface = llvm::dyn_cast(inst)) + derived_attrs = interface.materializeDerivedAttributes(); TF_ASSIGN_OR_RETURN(auto attrs_to_ignore, - GetAttributesToIgnore(inst, ignore_unregistered_attrs)); + GetAttributesToIgnore(inst, derived_attrs, op_reg_data, + ignore_unregistered_attrs)); TF_RETURN_WITH_CONTEXT_IF_ERROR( - ConvertAttributes(inst->getAttrs(), attrs_to_ignore, attributes), - "while converting attributes for node: ", name.str()); - TF_RETURN_IF_ERROR( - PopulateDerivedAttributes(inst, ignore_unregistered_attrs, attributes)); + ConvertAttributes(inst->getAttrs(), attrs_to_ignore, + /*remove_ref_type=*/false, attributes), + "while converting attributes for node: ", mlir::StringRefToView(name)); + TF_RETURN_IF_ERROR(PopulateDerivedAttributes( + inst, name, derived_attrs, ignore_unregistered_attrs, attributes)); return Status::OK(); } StatusOr> ConvertTFDialectOpToNodeDef( mlir::Operation* inst, llvm::StringRef name, bool ignore_unregistered_attrs) { - TF_ASSIGN_OR_RETURN(auto attrs_to_ignore, - GetAttributesToIgnore(inst, ignore_unregistered_attrs)); - TF_ASSIGN_OR_RETURN(auto node_def, - GetOperationNodeDef(attrs_to_ignore, inst, name)); - TF_RETURN_IF_ERROR(PopulateDerivedAttributes(inst, ignore_unregistered_attrs, - node_def->mutable_attr())); + TF_ASSIGN_OR_RETURN(auto node_def, GetOperationNodeDef(inst, name)); + TF_ASSIGN_OR_RETURN(auto op_name, + GetTensorFlowOpName(inst->getName().getStringRef())); + const tensorflow::OpRegistrationData* op_reg_data = + tensorflow::OpRegistry::Global()->LookUp(op_name.str()); + TF_RETURN_IF_ERROR(GetAttrValuesFromOperation(inst, name, op_reg_data, + ignore_unregistered_attrs, + node_def->mutable_attr())); return node_def; } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h index bd260171a86..6341b14fe7b 100644 --- 
a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/stream_executor/lib/statusor.h" @@ -29,9 +30,10 @@ namespace tensorflow { // Extracts the attributes of a MLIR operation and populates the converted // attributes in a proto map. -Status GetAttrValuesFromOperation(mlir::Operation* inst, llvm::StringRef name, - bool ignore_unregistered_attrs, - AttrValueMap* attributes); +Status GetAttrValuesFromOperation( + mlir::Operation* inst, llvm::StringRef name, + const tensorflow::OpRegistrationData* op_reg_data, + bool ignore_unregistered_attrs, AttrValueMap* attributes); // Converts a MLIR operation to TensorFlow NodeDef with given node name. This // name should be unique to the graph it is being inserted to. If the diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 0b6500bece3..42ce5c533a2 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -74,6 +74,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" @@ -310,21 +311,6 @@ class ImporterBase { return ::tensorflow::ConvertTensorProto(value, &builder_); } - // Converts the tensor shape proto into an MLIR shape attribute. - StatusOr ConvertTensorShapeProto( - const TensorShapeProto& shape) { - if (shape.unknown_rank()) - return mlir::TF::ShapeAttr::get(builder_.getContext(), llvm::None); - - llvm::SmallVector dims; - dims.reserve(shape.dim().size()); - for (const auto& dim : shape.dim()) { - dims.push_back(dim.size()); - } - return mlir::TF::ShapeAttr::get(builder_.getContext(), - llvm::makeArrayRef(dims)); - } - // Converts func name in graphdef to mlir::SymbolRefAttribute. 
StatusOr ConvertFunctionCallName( const std::string& func_name); @@ -1142,74 +1128,36 @@ StatusOr ImporterBase::ConvertFunctionCallName( StatusOr ImporterBase::ConvertAttributeValue( const AttrValue& value) { switch (value.value_case()) { - case AttrValue::kI: - return builder_.getI64IntegerAttr(value.i()); - case AttrValue::kS: - return builder_.getStringAttr(value.s()); - case AttrValue::kF: - return builder_.getFloatAttr(builder_.getF32Type(), value.f()); - case AttrValue::kB: - return builder_.getBoolAttr(value.b()); - case AttrValue::kType: { - mlir::Type type; - TF_RETURN_IF_ERROR(ConvertDataType(value.type(), builder_, &type)); - return mlir::TypeAttr::get(type); - } - case AttrValue::kShape: - return ConvertTensorShapeProto(value.shape()); - case AttrValue::kTensor: - return ConvertTensorProto(value.tensor()); - case AttrValue::kList: { - absl::InlinedVector attrs; - for (const auto& item : value.list().i()) - attrs.push_back(builder_.getI64IntegerAttr(item)); - for (const auto& item : value.list().s()) - attrs.push_back(builder_.getStringAttr(item)); - for (const auto& item : value.list().f()) - attrs.push_back(builder_.getFloatAttr(builder_.getF32Type(), item)); - for (const auto& item : value.list().b()) - attrs.push_back(builder_.getBoolAttr(item)); - for (const auto& item : value.list().type()) { - mlir::Type type; - TF_RETURN_IF_ERROR(ConvertDataType(DataType(item), builder_, &type)); - attrs.push_back(mlir::TypeAttr::get(type)); - } - for (const auto& item : value.list().shape()) { - TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorShapeProto(item)); - attrs.push_back(attr); - } - for (const auto& item : value.list().tensor()) { - TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorProto(item)); - attrs.push_back(attr); - } - for (const auto& item : value.list().func()) { - TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name())); - if (item.attr_size() != 0) - return errors::Unimplemented( - "func attributes with non-zero attr.size()"); - attrs.push_back(attr); - } - return builder_.getArrayAttr( - llvm::makeArrayRef(attrs.begin(), attrs.end())); - } case AttrValue::kFunc: { // TODO(b/156546237): Unify kFunc/NameAttrList attribute representation. // Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue // will not use this representation. NamedAttrList attrs; for (const auto& func_attr : value.func().attr()) { - TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(func_attr.second)); + TF_ASSIGN_OR_RETURN( + auto attr, ImporterBase::ConvertAttributeValue(func_attr.second)); attrs.push_back(builder_.getNamedAttr(func_attr.first, attr)); } auto func_attrs = builder_.getDictionaryAttr(attrs); return mlir::TF::FuncAttr::get(context_, value.func().name(), func_attrs); } - case AttrValue::VALUE_NOT_SET: - return builder_.getUnitAttr(); - // kPlaceholder is not implemented. 
+ case AttrValue::kList: { + if (!value.list().func().empty()) { + absl::InlinedVector attrs; + for (const auto& item : value.list().func()) { + TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name())); + if (item.attr_size() != 0) + return errors::Unimplemented( + "func attributes with non-zero attr.size()"); + attrs.push_back(attr); + } + return builder_.getArrayAttr( + llvm::makeArrayRef(attrs.begin(), attrs.end())); + } + return ConvertNonFuncAttributeValue(value, &builder_); + } default: - return errors::Unimplemented( - absl::StrCat("Attribute ", value.DebugString())); + return ConvertNonFuncAttributeValue(value, &builder_); } } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 824828052c0..b55a5aa5243 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -274,9 +274,13 @@ void CreateConvertMlirToXlaHloPipeline( mlir::OpPassManager& pm, llvm::StringRef device_type, llvm::MutableArrayRef> custom_legalization_passes) { - pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional()); + pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions()); pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::TF::CreateTensorListOpsDecompositionPass()); + + // TODO(b/159127949): Stack and TensorArray decomposition passes do not handle + // region based control flow yet. So convert back to functional control flow. + pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional()); pm.addPass(mlir::TF::CreateStackOpsDecompositionPass()); pm.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass()); pm.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass()); @@ -324,7 +328,7 @@ Status ConvertMLIRToXlaComputation( llvm::MutableArrayRef> custom_legalization_passes) { mlir::PassManager tf2xla(module_op.getContext()); - SetCrashReproducer(tf2xla); + applyTensorflowAndCLOptions(tf2xla); CreateConvertMlirToXlaHloPipeline(tf2xla, device_type, custom_legalization_passes); @@ -513,7 +517,7 @@ Status CompileGraphToXlaHlo( } mlir::PassManager pm(module_op.getContext()); - SetCrashReproducer(pm); + applyTensorflowAndCLOptions(pm); mlir::TF::StandardPipelineOptions tf_options; mlir::TF::CreateTFStandardPipeline(pm, tf_options); { @@ -530,8 +534,9 @@ Status CompileGraphToXlaHlo( Status CompileGraphToXlaHlo( const Graph& graph, llvm::ArrayRef args, - llvm::StringRef device_type, bool use_tuple_args, - const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, + llvm::ArrayRef control_rets, llvm::StringRef device_type, + bool use_tuple_args, const FunctionLibraryDefinition& flib_def, + const GraphDebugInfo& debug_info, const XlaHelpers::ShapeRepresentationFn shape_representation_fn, XlaCompilationResult* compilation_result, llvm::MutableArrayRef> @@ -540,6 +545,7 @@ Status CompileGraphToXlaHlo( RegisterDialects(context.getDialectRegistry()); GraphImportConfig config; config.graph_as_function = true; + config.control_outputs = control_rets; // Disable shape inference during import as some TensorFlow op fails during // shape inference with dynamic shaped operands. This in turn causes the // import to fail. 
Shape inference during import is going to be removed and diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index dac1c994d03..40230de406b 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -116,11 +116,11 @@ Status CompileGraphToXlaHlo( // Compiles a TensorFlow Graph into XLA HLO, generates all accompanying metadata // and stores them in CompilationResult. -// TODO(lyandy): Allow populating of targets/control outputs. Status CompileGraphToXlaHlo( const Graph& graph, llvm::ArrayRef args, - llvm::StringRef device_type, bool use_tuple_args, - const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, + llvm::ArrayRef control_rets, llvm::StringRef device_type, + bool use_tuple_args, const FunctionLibraryDefinition& flib_def, + const GraphDebugInfo& debug_info, const XlaHelpers::ShapeRepresentationFn shape_representation_fn, XlaCompilationResult* compilation_result, llvm::MutableArrayRef> diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.cc new file mode 100644 index 00000000000..98bfbbe608a --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.cc @@ -0,0 +1,113 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h" + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { + +// Converts a non-func AttrValue proto into an MLIR attribute. Func attribute is +// excluded in this function because the function might be renamed when the +// function definition is imported.
+StatusOr ConvertNonFuncAttributeValue(const AttrValue& value, + mlir::Builder* builder) { + switch (value.value_case()) { + case AttrValue::kI: + return builder->getI64IntegerAttr(value.i()); + case AttrValue::kS: + return builder->getStringAttr(value.s()); + case AttrValue::kF: + return builder->getFloatAttr(builder->getF32Type(), value.f()); + case AttrValue::kB: + return builder->getBoolAttr(value.b()); + case AttrValue::kType: { + mlir::Type type; + TF_RETURN_IF_ERROR(ConvertDataType(value.type(), *builder, &type)); + return mlir::TypeAttr::get(type); + } + case AttrValue::kShape: + return ConvertTensorShapeProto(value.shape(), builder->getContext()); + case AttrValue::kTensor: + return ConvertTensorProto(value.tensor(), builder); + case AttrValue::kList: { + absl::InlinedVector attrs; + for (const auto& item : value.list().i()) + attrs.push_back(builder->getI64IntegerAttr(item)); + for (const auto& item : value.list().s()) + attrs.push_back(builder->getStringAttr(item)); + for (const auto& item : value.list().f()) + attrs.push_back(builder->getFloatAttr(builder->getF32Type(), item)); + for (const auto& item : value.list().b()) + attrs.push_back(builder->getBoolAttr(item)); + for (const auto& item : value.list().type()) { + mlir::Type type; + TF_RETURN_IF_ERROR(ConvertDataType(DataType(item), *builder, &type)); + attrs.push_back(mlir::TypeAttr::get(type)); + } + for (const auto& item : value.list().shape()) { + TF_ASSIGN_OR_RETURN( + auto attr, ConvertTensorShapeProto(item, builder->getContext())); + attrs.push_back(attr); + } + for (const auto& item : value.list().tensor()) { + TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorProto(item, builder)); + attrs.push_back(attr); + } + if (!value.list().func().empty()) { + return tensorflow::errors::Unimplemented( + absl::StrCat("Attribute ", value.DebugString())); + } + return builder->getArrayAttr( + llvm::makeArrayRef(attrs.begin(), attrs.end())); + } + case AttrValue::VALUE_NOT_SET: + return builder->getUnitAttr(); + // kPlaceholder is not implemented. + default: + return tensorflow::errors::Unimplemented( + absl::StrCat("Attribute ", value.DebugString())); + } +} + +StatusOr ConvertAttributeValue(const AttrValue& value, + mlir::Builder* builder) { + switch (value.value_case()) { + case AttrValue::kFunc: { + // TODO(b/156546237): Unify kFunc/NameAttrList attribute representation. + // Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue + // will not use this representation. + mlir::NamedAttrList attrs; + for (const auto& func_attr : value.func().attr()) { + TF_ASSIGN_OR_RETURN(auto attr, + ConvertAttributeValue(func_attr.second, builder)); + attrs.push_back(builder->getNamedAttr(func_attr.first, attr)); + } + auto func_attrs = builder->getDictionaryAttr(attrs); + return mlir::TF::FuncAttr::get(builder->getContext(), value.func().name(), + func_attrs); + } + default: + return ConvertNonFuncAttributeValue(value, builder); + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h new file mode 100644 index 00000000000..c95ed60273d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_ATTR_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_ATTR_H_ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { + +using stream_executor::port::StatusOr; + +// Converts a non-func AttrValue proto into an MLIR attribute. Func attribute is +// excluded in this function because the function might be renamed when the +// function definition is imported. +StatusOr ConvertNonFuncAttributeValue(const AttrValue& value, + mlir::Builder* builder); + +// Converts all kinds of AttrValue proto into an MLIR attribute. +StatusOr ConvertAttributeValue(const AttrValue& value, + mlir::Builder* builder); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_ATTR_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index 05e1f059029..98328212c88 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -214,6 +214,20 @@ mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) { return mlir::TF::ShapeAttr::get(type.getContext(), ArrayRef()); } +// Converts the tensor shape proto into an MLIR shape attribute. +StatusOr ConvertTensorShapeProto(const TensorShapeProto& shape, + mlir::MLIRContext* context) { + if (shape.unknown_rank()) + return mlir::TF::ShapeAttr::get(context, llvm::None); + + llvm::SmallVector dims; + dims.reserve(shape.dim().size()); + for (const auto& dim : shape.dim()) { + dims.push_back(dim.size()); + } + return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(dims)); +} + // Converts an MLIR dense string elements attribute to a TensorFlow tensor // proto. void ConvertStringElementsAttr( diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h index e7cde4db936..294453ebcfd 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h @@ -48,6 +48,10 @@ PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type); // Converts an MLIR shaped type to a TensorFlow shape attribute. mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type); +// Converts a TensorFlow shape proto into an MLIR shape attribute. +StatusOr ConvertTensorShapeProto(const TensorShapeProto& shape, + mlir::MLIRContext* context); + // Converts an MLIR elements attribute to a TensorFlow tensor proto.
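A minimal usage sketch for the relocated ConvertTensorShapeProto helper (illustrative only, not part of the patch; the locally constructed MLIRContext and the ValueOrDie() accessor are assumptions made for the example):

// Build a partial shape proto and convert it to a TF dialect shape attribute.
tensorflow::TensorShapeProto shape_proto;
shape_proto.add_dim()->set_size(2);
shape_proto.add_dim()->set_size(-1);  // -1 marks an unknown dimension.
mlir::MLIRContext context;            // hypothetical context for the example
auto shape_attr_or =
    tensorflow::ConvertTensorShapeProto(shape_proto, &context);
if (shape_attr_or.ok()) {
  mlir::TF::ShapeAttr attr = shape_attr_or.ValueOrDie();  // shape [2, ?]
}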
Status ConvertToTensorProto(mlir::ElementsAttr attr, TensorProto* output_tensor); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index b9d26c9849a..6c1cab435d3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -227,4 +227,10 @@ void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path) { pm.enableCrashReproducerGeneration(path, /*genLocalReproducer=*/false); } +void applyTensorflowAndCLOptions(mlir::PassManager& pm, + llvm::StringRef dir_path) { + mlir::applyPassManagerCLOptions(pm); + SetCrashReproducer(pm, dir_path); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h index 55eaeadc43a..133285864f6 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h @@ -72,6 +72,15 @@ std::string DumpRawStringToFile(llvm::StringRef name, llvm::StringRef content, // Files" by looking up the environment to infer the directory path. void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path = ""); +// This applies both the PassManagerCLOptions provided by MLIR along with any +// tensorflow specific options. +// +// Note that this function should be in a more appropriate file, but it is +// unclear what a proper file would be as no other functions would currently be +// in the file also. +void applyTensorflowAndCLOptions(mlir::PassManager& pm, + llvm::StringRef dir_path = ""); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_MLIR_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index 67c2aebf121..cad5f2bae98 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -127,11 +127,12 @@ Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) { return Status::OK(); } -Status ConvertAttribute(const mlir::TF::FuncAttr& attr, AttrValue* value) { +Status ConvertAttribute(const mlir::TF::FuncAttr& attr, bool remove_ref_type, + AttrValue* value) { TF_RETURN_IF_ERROR( ConvertAttribute(attr.GetName().cast(), value)); TF_RETURN_IF_ERROR(ConvertAttributes(attr.GetAttrs().getValue(), - /*attrs_to_ignore=*/{}, + /*attrs_to_ignore=*/{}, remove_ref_type, value->mutable_func()->mutable_attr())); return Status::OK(); } @@ -159,15 +160,18 @@ Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) { return Status::OK(); } -Status ConvertAttribute(mlir::Type type, AttrValue* value) { +Status ConvertAttribute(mlir::Type type, bool remove_ref_type, + AttrValue* value) { DataType dtype; TF_RETURN_IF_ERROR(ConvertToDataType(type, &dtype)); + if (tensorflow::IsRefType(dtype)) dtype = tensorflow::RemoveRefType(dtype); value->set_type(dtype); return Status::OK(); } -Status ConvertAttribute(const mlir::TypeAttr& type, AttrValue* value) { - return ConvertAttribute(type.getValue(), value); +Status ConvertAttribute(const mlir::TypeAttr& type, bool remove_ref_type, + AttrValue* value) { + return ConvertAttribute(type.getValue(), remove_ref_type, value); } Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) { @@ -175,7 +179,8 @@ Status ConvertAttribute(const mlir::UnitAttr& attr, 
AttrValue* value) { return Status::OK(); } -Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) { +Status ConvertAttribute(const mlir::ArrayAttr& attr, bool remove_ref_type, + AttrValue* value) { auto* list = value->mutable_list(); for (mlir::Attribute a : attr.getValue()) { if (auto attr = a.dyn_cast()) { @@ -215,7 +220,8 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) { if (auto shaped_type = elt_type.dyn_cast()) { elt_type = shaped_type.getElementType(); } - TF_RETURN_IF_ERROR(ConvertAttribute(elt_type, &attr_val)); + TF_RETURN_IF_ERROR( + ConvertAttribute(elt_type, remove_ref_type, &attr_val)); list->add_type(attr_val.type()); } else if (auto attr = a.dyn_cast()) { AttrValue attr_val; @@ -228,18 +234,6 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) { return Status::OK(); } -// Updates NodeDef constructed out of an MLIR Case/IfW/While op to map it to -// either TensorFlow StatelessX or X op depending on the additional attribute. -void UpdateCompositeOp(NodeDef* node_def) { - auto it = node_def->mutable_attr()->find("is_stateless"); - if (it != node_def->attr().end()) { - if (it->second.b()) { - *node_def->mutable_op() = "Stateless" + node_def->op(); - } - node_def->mutable_attr()->erase(it); - } -} - // Returns true if the executor/control dialect op should map to Ref node in // TensorFlow Graph. For control dialect NextIteration it uses the 1st operand // type. For executor dialect NextIteration it uses the 2nd operand type. For @@ -291,7 +285,6 @@ StatusOr GetTensorFlowOpName(llvm::StringRef op_name) { } StatusOr> GetOperationNodeDef( - const absl::flat_hash_set& attrs_to_ignore, mlir::Operation* inst, llvm::StringRef name) { auto node_def = absl::make_unique(); // Note: we do not use NodeBuilder or NodeDefBuilder as that would require @@ -321,6 +314,14 @@ StatusOr> GetOperationNodeDef( node_def->set_name(name.str()); node_def->set_op(std::string(op_name.str())); + // Update NodeDef constructed out of an MLIR Case/If/While op to map it to + // either TensorFlow StatelessX or X op depending on the additional attribute. + if (llvm::isa(inst)) { + auto stateless = inst->getAttrOfType("is_stateless"); + if (stateless && stateless.getValue()) + *node_def->mutable_op() = "Stateless" + node_def->op(); + } + // Add inputs to the NodeDef based on the number of operands. This is required // as later when edges are added to the Node using Graph::AddEdge the // associated NodeDef is not updated. @@ -331,27 +332,17 @@ StatusOr> GetOperationNodeDef( node_def->set_device(std::string(attr.getValue())); } - // Add the node attributes. - TF_RETURN_WITH_CONTEXT_IF_ERROR( - ConvertAttributes(inst->getAttrs(), attrs_to_ignore, - node_def->mutable_attr()), - "while converting attributes for node: ", name.str()); - // Add the node debug info. 
TF_RETURN_IF_ERROR(ConvertLocation( inst->getLoc(), node_def->mutable_experimental_debug_info())); - if (node_def->op() == "Case") UpdateCompositeOp(node_def.get()); - if (node_def->op() == "If") UpdateCompositeOp(node_def.get()); - if (node_def->op() == "While") UpdateCompositeOp(node_def.get()); - return node_def; } Status ConvertAttributes( const llvm::ArrayRef attrs, const absl::flat_hash_set& attrs_to_ignore, - AttrValueMap* values) { + bool remove_ref_type, AttrValueMap* values) { AttrValueMap func_call_attrs; for (const mlir::NamedAttribute& named_attr : attrs) { auto name_strref = named_attr.first.str(); @@ -376,7 +367,7 @@ Status ConvertAttributes( continue; } if (auto func_attr = attr.dyn_cast()) { - TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, &value)); + TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, remove_ref_type, &value)); func_call_attrs[string(name)] = value; continue; } @@ -388,11 +379,13 @@ Status ConvertAttributes( TF_RETURN_IF_ERROR( llvm::TypeSwitch(attr) .Case( - [&](auto derived_attr) { - return ConvertAttribute(derived_attr, &value); - }) + mlir::StringAttr, mlir::ElementsAttr, mlir::UnitAttr, + mlir::TF::ShapeAttr>([&](auto derived_attr) { + return ConvertAttribute(derived_attr, &value); + }) + .Case([&](auto derived_attr) { + return ConvertAttribute(derived_attr, remove_ref_type, &value); + }) .Default([&](mlir::Attribute) { return errors::Unimplemented( "Unhandled attribute kind for attribute '", name_strref, @@ -419,28 +412,6 @@ Status ConvertAttributes( return Status::OK(); } -// Sets type attribute with the given name. If the attribute already exists with -// a different value, returns an error. -Status SetTypeAttribute(absl::string_view name, mlir::Type type, - AttrValueMap* values) { - DataType dtype; - TF_RETURN_IF_ERROR(ConvertScalarTypeToDataType(type, &dtype)); - if (tensorflow::IsRefType(dtype)) dtype = tensorflow::RemoveRefType(dtype); - AttrValue value; - value.set_type(dtype); - - auto result = values->insert({string(name), value}); - if (!result.second) { - DataType actual_dtype = result.first->second.type(); - if (actual_dtype != dtype) { - return errors::InvalidArgument("Expected ", DataType_Name(dtype), " '", - name, "' attribute but found ", - DataType_Name(actual_dtype)); - } - } - return Status::OK(); -} - Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shaped_type, AttrValueMap* values) { tensorflow::TensorShapeProto tshape; @@ -469,26 +440,6 @@ Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shaped_type, return Status::OK(); } -Status SetSizeAttribute(absl::string_view name, size_t size, - AttrValueMap* values) { - AttrValue value; - value.set_i(size); - - auto result = values->insert({string(name), value}); - if (!result.second) { - // This should be extremely rare as it means we are adding the same - // attribute multiple times/have some redundancy in representing this - // attribute. - size_t actual_size = result.first->second.i(); - // Just check via string output as we shouldn't get here and if we do they - // should be trivially the same, else fail. 
- if (actual_size != size) - return errors::InvalidArgument("Expected '", name, "' attribute to be ", - size, " but found ", actual_size); - } - return Status::OK(); -} - bool IsLegacyCallInstruction(mlir::Operation* inst) { return llvm::dyn_cast(inst); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h index 58fe39fa4e8..d1e0fd12f26 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h @@ -50,12 +50,8 @@ Status AddTensorFlowOpPrefix(std::string); StatusOr GetTensorFlowOpName(llvm::StringRef); // Converts an MLIR operation to TensorFlow NodeDef with given node name. This -// name should be unique to the graph it is being inserted into. `op_name_func` -// is to map the op name of `inst` to its op name in TensorFlow. "name" and -// "device" attributes are ignored by default. Use attrs_to_ignore to specify -// any other attributes that should be ignored. +// name should be unique to the graph it is being inserted into. StatusOr> GetOperationNodeDef( - const absl::flat_hash_set& attrs_to_ignore, mlir::Operation* inst, llvm::StringRef name); // Converts MLIR attributes with values to their tensorflow equivalent. @@ -64,23 +60,13 @@ StatusOr> GetOperationNodeDef( Status ConvertAttributes( const llvm::ArrayRef attrs, const absl::flat_hash_set& attrs_to_ignore, - AttrValueMap* values); - -// Sets type attribute with the given name. If the attribute already exists with -// a different value, returns an error. -Status SetTypeAttribute(absl::string_view name, mlir::Type type, - AttrValueMap* values); + bool remove_ref_type, AttrValueMap* values); // Sets shape attribute with the given name. If the attribute already exists // with a different value, returns an error. Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shape, AttrValueMap* values); -// Sets the given size_t value as an integer attribute with the given name. -// If the attribute already exists with a different value, returns an error. -Status SetSizeAttribute(absl::string_view name, size_t size, - AttrValueMap* values); - // Returns true if the given instruction is an mlir::TF::LegacyCallOp or the // result of such an operation transformed by the // ExecutorToControlDialectConversion pass. diff --git a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc index d1815a4a88b..d82d61ecf9e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" @@ -65,7 +66,8 @@ namespace TF { namespace { -// Extracts attributes from a MLIR operation, including derived attributes. +// Extracts attributes from a MLIR operation, including derived attributes, into +// one NamedAttrList. 
NamedAttrList GetAllAttributesFromOperation(Operation* op) { NamedAttrList attr_list; attr_list.append(op->getAttrDictionary().getValue()); @@ -173,30 +175,30 @@ ShapedTypeComponents CreateShapedTypeComponents(InferenceContext& context, return ShapedTypeComponents(element_type); } -// Runs TensorFlow shape inference associated to the op type registered in the -// TensorFlow op registry based on Graph version, operands, and attributes. -// Invoking this shape function will invoke conversions of parameters to the -// TensorFlow Graph equivalent data structures and back to MLIR equivalent data -// structures. This does not use a natively implemented shape inference in MLIR, -// and instead is temporary until shape functions are reimplemented/migrated to -// being in MLIR instead of the TensorFlow op registry. -LogicalResult InferReturnTypeComponentsFallback( - MLIRContext* context, StringRef op_name, int64_t graph_version, - Optional location, ValueRange operands, - const NamedAttrList& attributes, OperandAsConstantFn operand_as_constant_fn, +} // namespace + +LogicalResult InferReturnTypeComponentsForTFOp( + Optional location, Operation* op, int64_t graph_version, + OperandAsConstantFn operand_as_constant_fn, OpResultAsShapeFn op_result_as_shape_fn, ResultElementTypeFn result_element_type_fn, SmallVectorImpl& inferred_return_shapes) { - assert(op_name.startswith(TensorFlowDialect::getDialectNamespace())); - // Drop the `tf.` prefix to query TF registry. - std::string op_type = - op_name.drop_front(TensorFlowDialect::getDialectNamespace().size() + 1) - .str(); + assert(op->getName().getDialect() == + TensorFlowDialect::getDialectNamespace()); + + auto op_name_or = + tensorflow::GetTensorFlowOpName(op->getName().getStringRef()); + if (!op_name_or.ok()) { + LLVM_DEBUG(llvm::dbgs() << "Skipping inference for unregistered op '" + << op->getName().getStringRef() << "'.\n"); + return emitOptionalError(location, "op is unregistered"); + } + llvm::StringRef op_name = op_name_or.ConsumeValueOrDie(); // Get information from the registry and check if we have a shape function for // this op. const tensorflow::OpRegistrationData* op_reg_data = - tensorflow::OpRegistry::Global()->LookUp(op_type); + tensorflow::OpRegistry::Global()->LookUp(op_name.str()); if (!op_reg_data) { LLVM_DEBUG(llvm::dbgs() << "Skipping inference for unregistered op '" << op_name << "'.\n"); @@ -211,32 +213,26 @@ LogicalResult InferReturnTypeComponentsFallback( // Convert the operation attributes to be able to use the InferenceContext // and the TensorFlow shape function. - tensorflow::AttrValueMap converted_attributes; - NamedAttrList attributes_to_convert; - // Filter out unregistered attributes. 
- for (const auto& attr_def : op_reg_data->op_def.attr()) - if (auto registered_attr = attributes.get(attr_def.name())) - attributes_to_convert.set(attr_def.name(), registered_attr); - - auto attrs_status = tensorflow::ConvertAttributes( - attributes_to_convert, /*attrs_to_ignore=*/{}, &converted_attributes); - if (!attrs_status.ok()) { + tensorflow::AttrValueMap attributes; + auto attr_status = tensorflow::GetAttrValuesFromOperation( + op, op_name, op_reg_data, /*ignore_unregistered_attrs=*/true, + &attributes); + if (!attr_status.ok()) { LLVM_DEBUG(llvm::dbgs() << "Error creating attribute map for '" << op_name - << "': " << attrs_status.error_message() << "\n"); - return emitOptionalError( - location, - "failed to convert attributes to proto map"); + << "': " << attr_status.error_message() << "\n"); + return emitOptionalError(location, attr_status.error_message()); } // Collect an array with input values for constant operands and input shapes // for all the operands. - std::vector input_tensors(operands.size()); - std::vector input_shapes(operands.size()); - std::vector tensors(operands.size()); + const int num_operands = op->getNumOperands(); + std::vector input_tensors(num_operands); + std::vector input_shapes(num_operands); + std::vector tensors(num_operands); std::vector>>> - handle_shapes_and_types(operands.size()); - for (auto it : llvm::enumerate(operands)) { + handle_shapes_and_types(num_operands); + for (auto it : llvm::enumerate(op->getOperands())) { Value operand = it.value(); size_t index = it.index(); @@ -265,8 +261,7 @@ LogicalResult InferReturnTypeComponentsFallback( // Perform the shape inference using an InferenceContext with the input // shapes. This object is abstracting the information that the ShapeInference // function operates on. - InferenceContext c(graph_version, - tensorflow::AttrSlice(&converted_attributes), + InferenceContext c(graph_version, tensorflow::AttrSlice(&attributes), op_reg_data->op_def, input_shapes, input_tensors, /*input_tensors_as_shapes=*/{}, handle_shapes_and_types); auto status = c.Run(op_reg_data->shape_inference_fn); @@ -288,7 +283,7 @@ LogicalResult InferReturnTypeComponentsFallback( if (c.requested_input_tensor_as_partial_shape(input) && !input_tensors[input]) { LLVM_DEBUG(llvm::dbgs() << "Requesting " << input << " as shape\n"); - auto op_result = operands[input].dyn_cast(); + auto op_result = op->getOperand(input).dyn_cast(); if (!op_result) continue; // Resize on first valid shape computed. 
input_tensors_as_shapes.resize(c.num_inputs()); @@ -325,7 +320,7 @@ LogicalResult InferReturnTypeComponentsFallback( auto handle_shapes_types = c.output_handle_shapes_and_types(output); if (handle_shapes_types) { SmallVector subtypes; - Builder b(context); + Builder b(op->getContext()); for (const auto& shape_n_type : *handle_shapes_types) { Type element_type; auto status = @@ -335,9 +330,9 @@ LogicalResult InferReturnTypeComponentsFallback( CreateTensorType(c, shape_n_type.shape, element_type)); } if (new_element_type.isa()) { - new_element_type = TF::ResourceType::get(subtypes, context); + new_element_type = TF::ResourceType::get(subtypes, op->getContext()); } else { - new_element_type = TF::VariantType::get(subtypes, context); + new_element_type = TF::VariantType::get(subtypes, op->getContext()); } } } @@ -348,21 +343,6 @@ LogicalResult InferReturnTypeComponentsFallback( return success(); } -} // namespace - -LogicalResult InferReturnTypeComponentsForTFOp( - Optional location, Operation* op, int64_t graph_version, - OperandAsConstantFn operand_as_constant_fn, - OpResultAsShapeFn op_result_as_shape_fn, - ResultElementTypeFn result_element_type_fn, - SmallVectorImpl& inferred_return_shapes) { - auto attributes = GetAllAttributesFromOperation(op); - return InferReturnTypeComponentsFallback( - op->getContext(), op->getName().getStringRef(), graph_version, location, - op->getOperands(), attributes, operand_as_constant_fn, - op_result_as_shape_fn, result_element_type_fn, inferred_return_shapes); -} - LogicalResult InferReturnTypeComponentsForTFOp( Optional location, Operation* op, int64_t graph_version, SmallVectorImpl& inferred_return_shapes) { diff --git a/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc b/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc index a3678f7d154..5d3ee121577 100644 --- a/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc @@ -50,7 +50,7 @@ void Optimize::runOnFunction() { auto *ctx = &getContext(); auto func = getFunction(); - populateWithGenerated(ctx, &patterns); + populateWithGenerated(ctx, patterns); applyPatternsAndFoldGreedily(func, patterns); } } // namespace diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index 1fc3137b953..fbdc2d7d93b 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -1,4 +1,5 @@ -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") load( "//third_party/mlir:tblgen.bzl", "gentbl", @@ -17,7 +18,8 @@ package_group( includes = ["//third_party/mlir:subpackages"], packages = [ "//learning/brain/experimental/mlir/tfr/...", - "//tensorflow/compiler/mlir/...", + "//tensorflow/c/...", + "//tensorflow/compiler/...", ], ) @@ -163,3 +165,191 @@ filegroup( "@llvm-project//llvm:not", ], ) + +cc_library( + name = "tfr_decompose_ctx", + srcs = ["integration/tfr_decompose_ctx.cc"], + hdrs = ["integration/tfr_decompose_ctx.h"], + deps = [ + ":passes", + ":tfr", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_attr", + "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + 
"//tensorflow/core:protos_all_cc", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Shape", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TransformUtils", + ], +) + +tf_cc_test( + name = "tfr_decompose_ctx_test", + srcs = ["integration/tfr_decompose_ctx_test.cc"], + deps = [ + ":tfr_decompose_ctx", + "//tensorflow/compiler/xla:test", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "graph_decompose_pass", + srcs = ["integration/graph_decompose_pass.cc"], + hdrs = ["integration/graph_decompose_pass.h"], + deps = [ + ":tfr_decompose_ctx", + "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", + "//tensorflow/stream_executor/lib", + "@llvm-project//mlir:IR", + ], + alwayslink = 1, +) + +tf_py_test( + name = "graph_decompose_test", + size = "small", + srcs = ["integration/graph_decompose_test.py"], + data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"], + python_version = "PY3", + tags = [ + "no_windows", # TODO(b/170752141) + "nomac", # TODO(b/170752141) + "notsan", # TODO(b/170752141) + ], + deps = [ + "//tensorflow/compiler/mlir/tfr/resources:composite_ops", + "//tensorflow/python/eager:def_function", + ], +) + +cc_library( + name = "node_expansion_pass", + srcs = ["integration/node_expansion_pass.cc"], + hdrs = ["integration/node_expansion_pass.h"], + deps = [ + ":tfr_decompose_ctx", + "//tensorflow/core/common_runtime/eager:core", + "//tensorflow/core/common_runtime/eager:eager_op_rewrite_registry", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +tf_py_test( + name = "node_expansion_test", + size = "small", + srcs = ["integration/node_expansion_test.py"], + data = ["//tensorflow/compiler/mlir/tfr/resources:decomposition_lib"], + python_version = "PY3", + tags = [ + "no_windows", # TODO(b/170752141) + "nomac", # TODO(b/170752141) + ], + deps = [ + "//tensorflow/compiler/mlir/tfr/resources:composite_ops", + ], +) + +tf_python_pybind_extension( + name = "tfr_wrapper", + srcs = ["python/tfr_wrapper.cc"], + module_name = "tfr_wrapper", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tfr", + "//tensorflow/python:pybind11_lib", + "//tensorflow/python:pybind11_status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Shape", + "@llvm-project//mlir:StandardOps", + "@pybind11", + ], +) + +py_library( + name = "composite", + srcs = ["python/composite.py"], + srcs_version = "PY2AND3", +) + +py_library( + name = "tfr_gen", + srcs = ["python/tfr_gen.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/compiler/mlir/tfr:tfr_wrapper", + "//tensorflow/python:op_def_registry", + "//tensorflow/python/autograph/pyct", + "//tensorflow/python/autograph/pyct/static_analysis", + "@gast_archive//:gast", + ], +) + +tf_py_test( + name = "tfr_gen_test", + size = "small", + srcs = ["python/tfr_gen_test.py"], + python_version = "PY3", + 
tags = ["no_pip"], + deps = [ + ":composite", + ":tfr_gen", + "//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper", + "//tensorflow/compiler/mlir/tfr/resources:test_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:math_ops", + ], +) + +py_library( + name = "op_reg_gen", + srcs = ["python/op_reg_gen.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python:op_def_registry", + "//tensorflow/python/autograph/pyct", + "@gast_archive//:gast", + ], +) + +py_test( + name = "op_reg_gen_test", + size = "small", + srcs = ["python/op_reg_gen_test.py"], + python_version = "PY3", + srcs_version = "PY2AND3", + deps = [ + ":composite", + ":op_reg_gen", + "//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc new file mode 100644 index 00000000000..c283b3b6986 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h" + +#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { + +bool GraphDecomposePass::IsEnabled(const ConfigProto& config_proto) const { + const char* tfr_lib_env_val = getenv(std::string(kTFRLibEnv).c_str()); + return tfr_lib_env_val != nullptr; +} + +Status GraphDecomposePass::Run(const ConfigProto& config_proto, + mlir::ModuleOp module) { + if (!IsEnabled(config_proto)) { + VLOG(1) << "Skipping Graph Decomposition Pass, decomposition library was " + "not found"; + return Status::OK(); + } + VLOG(1) << "Run Graph Decomposition Passes"; + TF_ASSIGN_OR_RETURN(ctx_, TFRDecomposeContext::Get(module.getContext())); + TF_RETURN_IF_ERROR(ctx_->Decompose(module)); + return ctx_->Destroy(); +} + +namespace { +constexpr int kMlirGraphDecomposePassPriority = -1; + +static mlir_pass_registration::MlirOptimizationPassRegistration + register_mlir_graph_decompose_pass(kMlirGraphDecomposePassPriority, + std::make_unique()); +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h new file mode 100644 index 00000000000..589865c0f7d --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_GRAPH_DECOMPOSE_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_GRAPH_DECOMPOSE_PASS_H_ + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" +#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { + +// An optimization pass that decomposes the composite ops in a module according +// to the decomposition library. Currently the decomposition library is loaded +// each time the pass runs. A special environment variable is set to locate the +// decomposition library. +class GraphDecomposePass : public MlirOptimizationPass { + public: + llvm::StringRef name() const override { return "tfr"; } + + // Whether to run this pass. If this is enabled, the GraphDef will be imported + // to MLIR even if no tf composition file is found. + bool IsEnabled(const ConfigProto& config_proto) const override; + + // This should be used as a thin mapper around mlir::ModulePass::runOnModule + // API integrated with the Tensorflow runtime. + Status Run(const ConfigProto& config_proto, mlir::ModuleOp module) override; + + private: + std::unique_ptr ctx_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_GRAPH_DECOMPOSE_PASS_H_ diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py new file mode 100644 index 00000000000..d11ad9eef81 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_test.py @@ -0,0 +1,90 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
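Because the decomposition pass above is opt-in, here is a small sketch (not part of the patch) of enabling it from client code before any graphs are built. The setenv call is an illustrative assumption; the variable name and resources path are the ones used elsewhere in this change, and the Python test that follows sets the same variable in setUp():

#include <cstdlib>

// Without TF_MLIR_TFR_LIB_DIR in the environment,
// GraphDecomposePass::IsEnabled() returns false and the pass is a no-op.
setenv("TF_MLIR_TFR_LIB_DIR",
       "tensorflow/compiler/mlir/tfr/resources", /*overwrite=*/1);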
+"""Tests for tensorflow.compiler.mlir.tfr.integrattion.graph_decompose.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.compiler.mlir.tfr.resources import gen_composite_ops +from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import load_library +from tensorflow.python.framework import ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + +_lib_dir = os.path.dirname(gen_composite_ops.__file__) +_lib_name = os.path.basename(gen_composite_ops.__file__)[4:].replace( + '.py', '.so') +load_library.load_op_library(os.path.join(_lib_dir, _lib_name)) + + +class GraphDecomposeTest(test.TestCase): + + def setUp(self): + os.environ['TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/resources' + super(GraphDecomposeTest, self).setUp() + + def tearDown(self): + del os.environ['TF_MLIR_TFR_LIB_DIR'] + super(GraphDecomposeTest, self).tearDown() + + def testAddN(self): + add = def_function.function(gen_composite_ops.my_add_n) + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + sq1 = add([t1]) + sq2 = add([t1, t2]) + sq3 = add([t1, t2, t3]) + self.assertAllEqual(sq1.numpy().reshape(-1), [1, 2, 3, 4]) + self.assertAllEqual(sq2.numpy().reshape(-1), [2, 4, 6, 8]) + self.assertAllEqual(sq3.numpy().reshape(-1), [3, 6, 9, 12]) + + def testBiasedDense(self): + biased_dense = def_function.function(gen_composite_ops.my_biased_dense) + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]]) + sq = biased_dense(t1, t2, t3) + self.assertAllEqual(sq.numpy().reshape(-1), [-3, 0, 5, 12]) + + def testBiasedDenseRelu(self): + biased_dense = def_function.function(gen_composite_ops.my_biased_dense) + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]]) + sq = biased_dense(t1, t2, t3, act='relu') + self.assertAllEqual(sq.numpy().reshape(-1), [0, 0, 5, 12]) + + def testWithKnownKernel(self): + + @def_function.function + def biasd_dense_elu(x, y, z): + dot = gen_composite_ops.my_biased_dense(x, y, z) + return nn_ops.elu(dot) # with known kernel, should not expand. + + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]]) + sq = biasd_dense_elu(t1, t2, t3) + self.assertAllClose(sq.numpy().reshape(-1), [-0.950213, 0, 5, 12]) + + +if __name__ == '__main__': + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc new file mode 100644 index 00000000000..81106690590 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h" + +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { + +Status CompositeOpExpansion::Run(EagerOperation* orig_op, + std::unique_ptr* out_op) { + if (!IsEnabled()) return Status::OK(); + if (orig_op->Device() != kVariantDeviceNull) return Status::OK(); + + VLOG(1) << "Run Node Expansion Passes"; + + // Get the FunctionDef and insert that into the context. + const NodeDef& ndef = orig_op->MutableAttrs()->BuildNodeDef(); + auto& ctx = orig_op->EagerContext(); + Fprint128 cache_key = + orig_op->MutableAttrs()->CacheKey(orig_op->DeviceName()); + // Include soft placement policy in cache key since the placement strategy + // can change and thus affect which kernel is picked. + auto x = FingerprintCat64(cache_key.high64, cache_key.low64); + std::string fname = + absl::StrCat("_expanded_", ndef.name(), "_", std::to_string(x)); + if (!ctx.FindFunctionByName(fname)) { + TF_ASSIGN_OR_RETURN(auto func, TFRDecomposeContext::Expand(ndef, fname)); + TF_RETURN_IF_ERROR(ctx.AddFunctionDef(func)); + } + + // Rewrite the out_op to be the call op. This is essentially a deep copy of + // the orig_op, except for the op name. + auto* new_op = new EagerOperation(&ctx); + TF_RETURN_IF_ERROR( + new_op->Reset(fname.c_str(), orig_op->DeviceName().c_str())); + for (auto input : orig_op->GetInputs()) { + TF_RETURN_IF_ERROR(new_op->AddInput(input)); + } + new_op->MutableAttrs()->CopyAttributes(orig_op->Attrs()); + out_op->reset(new_op); + + VLOG(1) << "Rewrite the op to call function: " << fname; + + return Status::OK(); +} + +REGISTER_REWRITE(EagerOpRewriteRegistry::POST_PLACEMENT, CompositeOpExpansion); + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h new file mode 100644 index 00000000000..7e3aa82026b --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_NODE_EXPANSION_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_NODE_EXPANSION_PASS_H_ + +#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" +#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { + +// An optimization pass that decomposes the composite ops in a module according +// to the decomposition library. Currently the decomposition library is loaded +// each time the pass runs. A special environment variable is set to locate the +// decomposition library. +class CompositeOpExpansion : public EagerOpRewrite { + public: + CompositeOpExpansion(string name, string file, string line) + : EagerOpRewrite(name, file, line) {} + + Status Run(EagerOperation* orig_op, + std::unique_ptr* out_op) override; + + private: + // Whether to run this pass. If this is enabled, the NodeDef will be imported + // to MLIR even if no tf composition file is found. + bool IsEnabled() { + const char* tfr_lib_env_val = getenv(string(kTFRLibEnv).c_str()); + return tfr_lib_env_val != nullptr; + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_NODE_EXPANSION_PASS_H_ diff --git a/tensorflow/compiler/mlir/tfr/integration/node_expansion_test.py b/tensorflow/compiler/mlir/tfr/integration/node_expansion_test.py new file mode 100644 index 00000000000..6bdaed303ef --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/integration/node_expansion_test.py @@ -0,0 +1,85 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
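For reference, a short annotated sketch of how the eager rewrite above derives the name of the expanded function it caches; this restates lines from node_expansion_pass.cc with added comments, and the uint64 type of the fingerprint is an assumption made for illustration:

// The cache key folds in the op's attributes and device placement, so each
// distinct attribute/placement combination expands to its own function.
Fprint128 cache_key = orig_op->MutableAttrs()->CacheKey(orig_op->DeviceName());
uint64 fp = FingerprintCat64(cache_key.high64, cache_key.low64);
// The resulting name has the form "_expanded_<node name>_<fingerprint>";
// later calls reuse the registered function via ctx.FindFunctionByName(fname).
std::string fname =
    absl::StrCat("_expanded_", ndef.name(), "_", std::to_string(fp));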
+"""Tests for tensorflow.compiler.mlir.tfr.integrattion.node_expansion.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.compiler.mlir.tfr.resources import gen_composite_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import load_library +from tensorflow.python.framework import ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + +_lib_dir = os.path.dirname(gen_composite_ops.__file__) +_lib_name = os.path.basename(gen_composite_ops.__file__)[4:].replace( + '.py', '.so') +load_library.load_op_library(os.path.join(_lib_dir, _lib_name)) + + +class NodeExpansionTest(test.TestCase): + + def setUp(self): + os.environ['TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/resources' + super(NodeExpansionTest, self).setUp() + + def tearDown(self): + del os.environ['TF_MLIR_TFR_LIB_DIR'] + super(NodeExpansionTest, self).tearDown() + + def testAddN(self): + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + sq1 = gen_composite_ops.my_add_n([t1]) + sq2 = gen_composite_ops.my_add_n([t1, t2]) + sq3 = gen_composite_ops.my_add_n([t1, t2, t3]) + self.assertAllEqual(sq1.numpy().reshape(-1), [1, 2, 3, 4]) + self.assertAllEqual(sq2.numpy().reshape(-1), [2, 4, 6, 8]) + self.assertAllEqual(sq3.numpy().reshape(-1), [3, 6, 9, 12]) + + def testBiasedDense(self): + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]]) + sq = gen_composite_ops.my_biased_dense(t1, t2, t3) + self.assertAllEqual(sq.numpy().reshape(-1), [-3, 0, 5, 12]) + + def testBiasedDenseRelu(self): + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]]) + sq = gen_composite_ops.my_biased_dense(t1, t2, t3, act='relu') + self.assertAllEqual(sq.numpy().reshape(-1), [0, 0, 5, 12]) + + def testWithKnownKernel(self): + + def biasd_dense_elu(x, y, z): + dot = gen_composite_ops.my_biased_dense(x, y, z) + return nn_ops.elu(dot) # with known kernel, should not expand. + + t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t3 = constant_op.constant([[-10.0, -10.0], [-10.0, -10.0]]) + sq = biasd_dense_elu(t1, t2, t3) + self.assertAllClose(sq.numpy().reshape(-1), [-0.950213, 0, 5, 12]) + + +if __name__ == '__main__': + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc new file mode 100644 index 00000000000..91fc40941e0 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc @@ -0,0 +1,219 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SMLoc.h" +#include "llvm/Support/SourceMgr.h" +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" +#include "tensorflow/compiler/mlir/tfr/passes/passes.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/util/env_var.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { + +const char* const kTFRLibEnv = "TF_MLIR_TFR_LIB_DIR"; + +StatusOr> TFRDecomposeContext::Get( + mlir::MLIRContext* mlir_ctx) { + Env* env = Env::Default(); + std::string tfr_lib_dir; + TF_RETURN_IF_ERROR(ReadStringFromEnvVar( + kTFRLibEnv, "tensorflow/compiler/mlir/tfr/resources", &tfr_lib_dir)); + string composite_mlir_dir = io::JoinPath(env->GetRunfilesDir(), tfr_lib_dir); + std::vector files; + TF_RETURN_IF_ERROR(env->GetChildren(composite_mlir_dir, &files)); + std::string tfr_raw_text; + for (const auto& file : files) { + string fullpath = io::JoinPath(composite_mlir_dir, file); + if (env->MatchPath(fullpath, io::JoinPath(composite_mlir_dir, "*.mlir"))) { + std::string text; + TF_RETURN_IF_ERROR(ReadFileToString(env, fullpath, &text)); + tfr_raw_text.append(text); + } + } + + auto ctx = TFRDecomposeContext::Get(tfr_raw_text, mlir_ctx); + if (!ctx) { + return errors::Internal(absl::StrCat( + "Failed to load the imported decomposition lib: ", tfr_raw_text)); + } + return ctx; +} + +std::unique_ptr TFRDecomposeContext::Get( + StringPiece tfr_raw_text, mlir::MLIRContext* mlir_ctx) { + mlir_ctx->allowUnregisteredDialects(/*allow=*/true); + // Load dialects involved in the conversion + mlir::DialectRegistry& registry = mlir_ctx->getDialectRegistry(); + // clang-format off + registry.insert(); + // clang-format on + 
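+  // The raw text is expected to be the concatenated contents of the *.mlir
+  // files in the decomposition library, i.e. a sequence of tfr.func
+  // definitions such as:
+  //   tfr.func @tf__my_add_n(%values: !tfr.tensor_list,
+  //                          %n: i64 {tfr.name="N"}) -> !tfr.tensor { ... }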
+ // Load the TFR functions in a mlir::ModuleOp + auto memory_buffer = llvm::MemoryBuffer::getMemBuffer( + llvm::StringRef(tfr_raw_text.data(), tfr_raw_text.size())); + llvm::SourceMgr source_mgr; + source_mgr.AddNewSourceBuffer(std::move(memory_buffer), llvm::SMLoc()); + mlir::OwningModuleRef module = mlir::parseSourceFile(source_mgr, mlir_ctx); + + // Create the context + return absl::make_unique(std::move(module)); +} + +StatusOr TFRDecomposeContext::Decompose(const NodeDef& node_def, + StringPiece func_name) { + const OpDef* op_def; + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def)); + DataTypeVector input_dtys, output_dtys; + TF_RETURN_IF_ERROR(InputTypesForNode(node_def, *op_def, &input_dtys)); + TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, *op_def, &output_dtys)); + + mlir::MLIRContext* context = tfr_module_->getContext(); + llvm::SmallVector input_tys, output_tys; + mlir::Builder builder(context); + for (auto ty : input_dtys) { + mlir::Type elt_ty; + TF_RETURN_IF_ERROR(ConvertDataType(ty, builder, &elt_ty)); + mlir::TensorType mlir_ty = mlir::UnrankedTensorType::get(elt_ty); + input_tys.push_back(mlir_ty); + } + for (auto ty : output_dtys) { + mlir::Type elt_ty; + TF_RETURN_IF_ERROR(ConvertDataType(ty, builder, &elt_ty)); + mlir::TensorType mlir_ty = mlir::UnrankedTensorType::get(elt_ty); + output_tys.push_back(mlir_ty); + } + llvm::SmallVector attrs; + for (const auto& attr : node_def.attr()) { + TF_ASSIGN_OR_RETURN(auto mlir_attr, + ConvertAttributeValue(attr.second, &builder)); + attrs.push_back({mlir::Identifier::get(attr.first, context), mlir_attr}); + } + + mlir::Location loc = mlir::UnknownLoc::get(context); + mlir::ModuleOp module = mlir::ModuleOp::create(loc); + mlir::FunctionType func_type = + mlir::FunctionType::get(input_tys, output_tys, context); + llvm::StringRef func_name_str(func_name.data(), func_name.size()); + auto func = mlir::FuncOp::create(loc, func_name_str, func_type, {}); + module.push_back(func); + func.addEntryBlock(); + mlir::OpBuilder op_builder(func.getBody()); + + // Create the TF op + const std::string tf_op_full_name = absl::StrCat("tf.", node_def.op()); + mlir::OperationState op_state(loc, tf_op_full_name); + op_state.addOperands(func.getArguments()); + op_state.addTypes(output_tys); + op_state.addAttributes(attrs); + mlir::Operation* tf_op = op_builder.createOperation(op_state); + op_builder.create(loc, tf_op->getResults()); + + if (failed(mlir::verify(module))) { + return errors::Internal(absl::StrCat( + "Failed to verify the imported NodeDef: ", node_def.DebugString())); + } + + // Call the decompose passes by using the external symbol table. + if (failed(pm_.run(module))) { + return errors::Internal("Failed to run the decompose passes."); + } + + // Export the result as a FunctionDef. + FunctionDef func_def; + TF_RETURN_IF_ERROR( + ConvertMlirFunctionToFunctionLibraryDef(func, export_confs_, &func_def)); + module.erase(); + return func_def; +} + +Status TFRDecomposeContext::Decompose(mlir::ModuleOp user_module) { + // Call the decompose passes by using the external symbol table. 
+ if (failed(pm_.run(user_module))) { + return errors::Internal("Failed to run the decompose passes."); + } + return Status::OK(); +} + +StatusOr TFRDecomposeContext::Expand(const NodeDef& node_def, + StringPiece func_name) { + mlir::MLIRContext mlir_ctx; + mlir_ctx.allowUnregisteredDialects(/*allow=*/true); + TF_ASSIGN_OR_RETURN(auto ctx, Get(&mlir_ctx)); + return ctx->Decompose(node_def, func_name); +} + +Status TFRDecomposeContext::Destroy() { + tfr_module_.release().erase(); + return Status::OK(); +} + +// Constructor of the decompose context. +TFRDecomposeContext::TFRDecomposeContext(mlir::OwningModuleRef tfr_module) + : tfr_module_(std::move(tfr_module)), pm_(tfr_module_->getContext()) { + mlir::OpPassManager& func_pm = pm_.nest(); + + // Prepare the imported graph. + func_pm.addPass(mlir::CreateExecutorDialectToFunctionalConversionPass()); + + // Run TFR lowering, inlining and raising to tf. + func_pm.addPass(mlir::TFR::CreateDecomposeTFOpsPass(tfr_module_.get())); + func_pm.addPass(mlir::TFR::CreateRaiseToTFOpsPass( + tfr_module_.get(), /*materialize_derived_attrs=*/true)); + + // Prepare to be exported. + func_pm.addPass(mlir::CreateFunctionalToExecutorDialectConversionPass()); + pm_.addPass(mlir::CreateBreakUpIslandsPass()); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h new file mode 100644 index 00000000000..548c3ab8c0c --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h @@ -0,0 +1,75 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_TFR_DECOMPOSE_CTX_H_ +#define TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_TFR_DECOMPOSE_CTX_H_ + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace tensorflow { + +extern const char* const kTFRLibEnv; + +using stream_executor::port::StatusOr; +using NodeAndType = std::pair; + +// An wrapper for all the objects used to decompose a module (graph mode) and +// node_def (eager mode). Note that this class owns the decomposition library. +class TFRDecomposeContext { + public: + // The entry function to get a decompose context. All the required passes have + // been initialized. + static StatusOr> Get( + mlir::MLIRContext* mlir_ctx); + + static std::unique_ptr Get(StringPiece tfr_raw_text, + mlir::MLIRContext* mlir_ctx); + + // Decompose the NodeDef to a set of primitive ops according to the decompose + // library loaded. Wrap the decomposed result in a FunctionDef. 
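+  // For example, with the library in tfr_decompose_ctx_test.cc, a "MyAddN"
+  // NodeDef with three inputs expands into a FunctionDef whose body contains
+  // two "RiscAdd" nodes.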
+  static StatusOr<FunctionDef> Expand(const NodeDef& node_def,
+                                      StringPiece func_name);
+
+  // Constructor of the decompose context. To share the decompose library, the
+  // whole decompose TFR function library is loaded.
+  explicit TFRDecomposeContext(mlir::OwningModuleRef tfr_module);
+
+  // Decompose the op in the NodeDef to a set of primitive ops according to the
+  // decompose library in the context. Wrap the decomposed result in a
+  // FunctionDef.
+  StatusOr<FunctionDef> Decompose(const NodeDef& node_def,
+                                  StringPiece func_name);
+
+  // Decompose the ops in the ModuleOp to a set of primitive ops according to
+  // the decompose library in the context.
+  Status Decompose(mlir::ModuleOp user_module);
+
+  // Release all the owned references.
+  Status Destroy();
+
+ private:
+  mlir::OwningModuleRef tfr_module_;
+  mlir::PassManager pm_;
+
+  GraphExportConfig export_confs_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_TFR_DECOMPOSE_CTX_H_
diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc
new file mode 100644
index 00000000000..614e46df15c
--- /dev/null
+++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc
@@ -0,0 +1,159 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ +#include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" + +#include +#include + +#include "absl/types/span.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +using testing::ElementsAreArray; +using testing::Test; + +namespace tensorflow { + +REGISTER_OP("MyAddN") + .Input("inputs: N * T") + .Output("sum: T") + .Attr("N: int >= 1") + .Attr("T: {numbertype, variant}") + .SetIsCommutative() + .SetIsAggregate() + .SetShapeFn(shape_inference::UnchangedShape); + +REGISTER_OP("RiscAdd") + .Input("x: T") + .Input("y: T") + .Output("z: T") + .Attr( + "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, " + "complex64, complex128, string}") + .SetShapeFn(shape_inference::UnchangedShape); + +namespace { + +constexpr char tfr_raw_text[] = R"( + +tfr.func @tf__my_add_n(%values: !tfr.tensor_list, + %n: i64 {tfr.name="N"}) -> !tfr.tensor { + %index = constant 0 : index + %cst = constant 1 : i64 + %eq = cmpi "eq", %n, %cst : i64 + %v1 = tfr.get_element %values[%index] : (!tfr.tensor_list, index) -> !tfr.tensor + %res = scf.if %eq -> !tfr.tensor { + scf.yield %v1 : !tfr.tensor + } else { + %step = index_cast %cst : i64 to index + %end = index_cast %n : i64 to index + %reduce = scf.for %i = %step to %end step %step iter_args(%reduce_iter=%v1) -> !tfr.tensor { + %v = tfr.get_element %values[%i] : (!tfr.tensor_list, index) -> !tfr.tensor + %reduce_next = tfr.call @tf__risc_add(%reduce_iter, %v) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor + scf.yield %reduce_next : !tfr.tensor + } + scf.yield %reduce : !tfr.tensor + } + tfr.return %res : !tfr.tensor +} + +tfr.func @tf__risc_add_(!tfr.tensor, !tfr.tensor) -> !tfr.tensor attributes{T} +)"; + +class TFRDecomposeContextTest : public Test { + protected: + void SetUp() override { + test_ctx_ = TFRDecomposeContext::Get(tfr_raw_text, &ctx_); + } + + mlir::MLIRContext ctx_; + std::unique_ptr test_ctx_; +}; + +std::vector NodesSequenceOf(const FunctionDef& graph) { + std::vector nodes; + for (auto& node : graph.node_def()) { + nodes.push_back({node.op(), node.attr().at("T").type()}); + } + return nodes; +} + +TEST_F(TFRDecomposeContextTest, FLOAT_1_ins) { + std::vector src_list; + src_list.emplace_back("input", 0, DT_FLOAT); + NodeDef test_node; + auto status = NodeDefBuilder("float_add", "MyAddN") + .Input(src_list) + .Finalize(&test_node); + EXPECT_TRUE(status.ok()); + auto decomposed = test_ctx_->Decompose(test_node, "test"); + EXPECT_TRUE(decomposed.ok()); + std::vector expected_results{{"Identity", DT_FLOAT}}; + EXPECT_THAT(NodesSequenceOf(decomposed.ValueOrDie()), + ElementsAreArray(expected_results)); +} + +TEST_F(TFRDecomposeContextTest, FLOAT_3_ins) { + std::vector src_list; + src_list.emplace_back("in0", 0, DT_FLOAT); + 
src_list.emplace_back("in1", 0, DT_FLOAT); + src_list.emplace_back("in2", 0, DT_FLOAT); + NodeDef test_node; + auto status = NodeDefBuilder("float_add_3", "MyAddN") + .Input(src_list) + .Finalize(&test_node); + EXPECT_TRUE(status.ok()); + auto decomposed = test_ctx_->Decompose(test_node, "test"); + EXPECT_TRUE(decomposed.ok()); + + std::vector expected_results{{"RiscAdd", DT_FLOAT}, + {"RiscAdd", DT_FLOAT}}; + EXPECT_THAT(NodesSequenceOf(decomposed.ValueOrDie()), + ElementsAreArray(expected_results)); +} + +TEST_F(TFRDecomposeContextTest, INT32_3_ins) { + std::vector src_list; + src_list.emplace_back("in0", 0, DT_INT32); + src_list.emplace_back("in1", 0, DT_INT32); + src_list.emplace_back("in2", 0, DT_INT32); + NodeDef test_node; + auto status = + NodeDefBuilder("int_add", "MyAddN").Input(src_list).Finalize(&test_node); + EXPECT_TRUE(status.ok()); + auto decomposed = test_ctx_->Decompose(test_node, "test"); + EXPECT_TRUE(decomposed.ok()); + + std::vector expected_results{{"RiscAdd", DT_INT32}, + {"RiscAdd", DT_INT32}}; + EXPECT_THAT(NodesSequenceOf(decomposed.ValueOrDie()), + ElementsAreArray(expected_results)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td index 2918336603b..562b3f79955 100644 --- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td @@ -265,12 +265,11 @@ def TFR_ConstOp : TFR_Op<"constant", [ConstantLike, NoSideEffect]> { let hasFolder = 1; - let builders = [OpBuilder< - "OpBuilder &, OperationState &state, Attribute value", + let builders = [OpBuilder<"Attribute value", [{ auto* ctx = value.getContext(); - state.addAttribute("value", value); - state.addTypes(TFRAttrType::get(ctx)); + $_state.addAttribute("value", value); + $_state.addTypes(TFRAttrType::get(ctx)); }]> ]; @@ -400,8 +399,8 @@ def TFR_TFRFuncOp : TFR_Op<"func", [HasParent<"ModuleOp">, let skipDefaultBuilders = 1; let builders = [ - OpBuilder<"OpBuilder &builder, OperationState &result, StringRef name, " - "FunctionType type, ArrayRef attrs = {}"> + OpBuilder<"StringRef name, FunctionType type, " + "ArrayRef attrs = {}"> ]; let extraClassDeclaration = [{ diff --git a/tensorflow/compiler/mlir/tfr/python/composite.py b/tensorflow/compiler/mlir/tfr/python/composite.py new file mode 100644 index 00000000000..7f558ce2fe7 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/python/composite.py @@ -0,0 +1,56 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Op composition registration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +# TODO(fengliuai): add the tf_export decrator +class Composite(object): + """A decorator to register a function as a composition for an TF operator. 
+ + The argument to the decorator must be the name of a TF raw operator the + function composites for. Decorated function must take positional arguments + which corresponds to the input and attributes in OpDef of the TF operation. + # TODO(fengliuai): more documents here. + + Example: + @composite.Composite('AddN') + def _compose_add_n(inputs, N): + if N == 1: + .... + """ + + # TODO(fengliuai): support input_binding and output_binding so the arguments + # are not positional. + def __init__(self, + op_name, + inputs=None, + attrs=None, + derived_attrs=None, + outputs=None): + self._op_name = op_name + self._inputs = inputs + self._attrs = attrs + self._derived_attrs = derived_attrs + self._outputs = outputs + + def __call__(self, compose_fn): + # TODO(fengliuai): more sanity check of the input function and make sure + # the bounded arguments of the function matches the 'inputs' and 'attrs'. + setattr(compose_fn, '_tfr_op_name', self._op_name) + return compose_fn diff --git a/tensorflow/compiler/mlir/tfr/python/op_reg_gen.py b/tensorflow/compiler/mlir/tfr/python/op_reg_gen.py new file mode 100644 index 00000000000..4876e17790b --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/python/op_reg_gen.py @@ -0,0 +1,148 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""op_reg_gen: Generate op registration code from composite op code.""" + +# pylint: disable=invalid-name +# pylint: disable=missing-function-docstring +# pylint: disable=g-direct-tensorflow-import + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast as ast + +from tensorflow.python.autograph.pyct import transformer +from tensorflow.python.autograph.pyct import transpiler +from tensorflow.python.framework import op_def_registry +from tensorflow.python.util import tf_inspect + +_COMPOSITE_ARG_LIST = ['op_name', 'inputs', 'attrs', 'derived_attrs', 'outputs'] + + +class OpRegGenImpl(transformer.CodeGenerator): + """Visit the AST and generate C++ op registration functions.""" + + def __init__(self, ctx): + super(OpRegGenImpl, self).__init__(ctx) + self.ctx = ctx + + def visit_Name(self, node): + return node.id + + def visit_Constant(self, node): + return node.value + + def visit_keyword(self, node): + return node.arg, self.visit(node.value) + + def visit_List(self, node): + return [self.visit(cst) for cst in node.elts] + + def visit_arguments(self, node): + return [self.visit(arg) for arg in node.args] + + def visit_FunctionDef(self, node): + # TODO(fengliuai): create one utility method to match different apis and + # shared it with the tfr_gen.py module. 
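+    # Both decorator spellings used by callers are matched below, e.g.
+    #   @composite.Composite('TestNoOp', ...)    -> the ast.Attribute case
+    #   @Composite('TestCompositeOp', ...)       -> the ast.Name case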
+ compose_dec = [] + for dec in node.decorator_list: + if isinstance(dec, ast.Call): + if isinstance(dec.func, ast.Attribute) and dec.func.attr == 'Composite': + compose_dec.append(dec) + if isinstance(dec.func, ast.Name) and dec.func.id == 'Composite': + compose_dec.append(dec) + + if not compose_dec: + # skip a non-composition function + return + elif len(compose_dec) > 1: + raise KeyError('More than one TF ops decomposes for.') + + all_dec_args = {} + for arg_name, arg_value in zip(_COMPOSITE_ARG_LIST, compose_dec[0].args): + all_dec_args[arg_name] = self.visit(arg_value) + + kw_dec_args = dict([self.visit(kw) for kw in compose_dec[0].keywords]) + + if all_dec_args.keys() & kw_dec_args.keys(): + raise KeyError('More arguments than expected.') + + all_dec_args.update(kw_dec_args) + + op_name = all_dec_args['op_name'] + op_def = op_def_registry.get(op_name) + if op_def: + if len(all_dec_args) > 1: + # Op has been registered, so it is a user error to specify op def. + raise ValueError('Op has been registered: ' + op_name) + else: + # Op has been registered, then we don't need to generate register code. + return + + # Validates the function inputs match what are in the decorator. + inputs = all_dec_args.get('inputs', []) + attrs = all_dec_args.get('attrs', []) + expected_args = [arg.split(':')[0] for arg in inputs + attrs] + all_func_args = self.visit(node.args) + + if len(expected_args) != len(all_func_args): + raise KeyError('Composition arguments do not match the registration.') + + cxx_reg_code = '\nREGISTER_OP("{0}")'.format(op_name) + for input_ in inputs: + cxx_reg_code += '\n .Input("{0}")'.format(input_) + for attr in attrs: + py_str = attr.replace('"', '\'') + cxx_reg_code += '\n .Attr("{0}")'.format(py_str) + for attr in all_dec_args.get('derived_attrs', []): + py_str = attr.replace('"', '\'') + cxx_reg_code += '\n .Attr("{0}")'.format(py_str) + for output_ in all_dec_args.get('outputs', []): + cxx_reg_code += '\n .Output("{0}")'.format(output_) + cxx_reg_code += ';\n' + self.emit(cxx_reg_code) + + +class OpRegGen(transpiler.GenericTranspiler): + """Transforms Python objects into TFR MLIR source code.""" + + def transform_ast(self, node, ctx): + gen = OpRegGenImpl(ctx) + gen.visit(node) + return gen.code_buffer + + +def op_reg_gen(func): + """Parse a function and emit the TFR functions.""" + op_reg_code, _ = OpRegGen().transform(func, None) + return op_reg_code + + +def gen_register_op(source, method_prefix=None): + """Parse a python code and emit the TFR functions from a target class.""" + mlir_funcs = [ + op_reg_gen(func) + for name, func in tf_inspect.getmembers(source, tf_inspect.isfunction) + if not method_prefix or name.startswith(method_prefix) + ] + headers = r""" +#include "third_party/tensorflow/core/framework/node_def_builder.h" +#include "third_party/tensorflow/core/framework/op.h" + +namespace tensorflow { + """ + code = '\n'.join(mlir_funcs) + return headers + code + '} // namespace tensorflow\n' diff --git a/tensorflow/compiler/mlir/tfr/python/op_reg_gen_test.py b/tensorflow/compiler/mlir/tfr/python/op_reg_gen_test.py new file mode 100644 index 00000000000..807e48a3dbf --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/python/op_reg_gen_test.py @@ -0,0 +1,82 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `op_reg_gen` module.""" + +# pylint: disable=missing-function-docstring +# pylint: disable=invalid-name +# pylint: disable=g-direct-tensorflow-import + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from tensorflow.compiler.mlir.python.mlir_wrapper import filecheck_wrapper as fw +from tensorflow.compiler.mlir.tfr.python import composite +from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op +from tensorflow.python.platform import test + + +Composite = composite.Composite + + +@composite.Composite( + 'TestNoOp', derived_attrs=['T: numbertype'], outputs=['o1: T']) +def _composite_no_op(): + pass + + +@Composite( + 'TestCompositeOp', + inputs=['x: T', 'y: T'], + attrs=['act: {"", "relu"}', 'trans: bool = true'], + derived_attrs=['T: numbertype'], + outputs=['o1: T', 'o2: T']) +def _composite_op(x, y, act, trans): + return x + act, y + trans + + +class TFRGenTensorTest(test.TestCase): + """MLIR Generation Tests for MLIR TFR Program.""" + + def test_op_reg_gen(self): + cxx_code = gen_register_op(sys.modules[__name__]) + cxx_code_exp = r""" + CHECK: #include "third_party/tensorflow/core/framework/node_def_builder.h" + CHECK-NEXT: #include "third_party/tensorflow/core/framework/op.h" + CHECK-EMPTY + CHECK-LABEL: namespace tensorflow { + CHECK-EMPTY + CHECK-LABEL: REGISTER_OP("TestNoOp") + CHECK-NEXT: .Attr("T: numbertype") + CHECK-NEXT: .Output("o1: T"); + CHECK-EMPTY + CHECK-LABEL: REGISTER_OP("TestCompositeOp") + CHECK-NEXT: .Input("x: T") + CHECK-NEXT: .Input("y: T") + CHECK-NEXT: .Attr("act: {'', 'relu'}") + CHECK-NEXT: .Attr("trans: bool = true") + CHECK-NEXT: .Attr("T: numbertype") + CHECK-NEXT: .Output("o1: T") + CHECK-NEXT: .Output("o2: T"); + CHECK-EMPTY + CHECK-LABEL: } // namespace tensorflow + """ + self.assertTrue(fw.check(str(cxx_code), cxx_code_exp), str(cxx_code)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/mlir/tfr/python/tfr_gen.py b/tensorflow/compiler/mlir/tfr/python/tfr_gen.py new file mode 100644 index 00000000000..f8622d11511 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/python/tfr_gen.py @@ -0,0 +1,1363 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""tfr_gen: Generate mlir tfr decomposition function from python code.""" + +# pylint: disable=invalid-name +# pylint: disable=missing-function-docstring +# pylint: disable=g-direct-tensorflow-import + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import enum +import os +import re +import types +from typing import List, Tuple +import gast as ast + +from tensorflow.compiler.mlir.tfr import tfr_wrapper as tfr +from tensorflow.core.framework import types_pb2 +from tensorflow.python.autograph.converters import control_flow +from tensorflow.python.autograph.converters import return_statements +from tensorflow.python.autograph.impl import api +from tensorflow.python.autograph.pyct import anno +from tensorflow.python.autograph.pyct import cfg +from tensorflow.python.autograph.pyct import qual_names +from tensorflow.python.autograph.pyct import transformer +from tensorflow.python.autograph.pyct import transpiler +from tensorflow.python.autograph.pyct.static_analysis import activity +from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions +from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs +from tensorflow.python.autograph.pyct.static_analysis import type_inference +from tensorflow.python.framework import load_library +from tensorflow.python.framework import op_def_registry +from tensorflow.python.util import tf_inspect + + +class TFRTypes(enum.Enum): + """All the supported types. + + 1-3: tfr types + 4-99: mlir built-in types + 100-199: TF related translator internal types + 200- : Python related translator internal types + """ + TENSOR = 1 + TENSOR_LIST = 2 + ATTR = 3 + NONE = 4 + SHAPE = 5 # shape -> !shape.shape + I1 = 21 + I32 = 22 + I64 = 23 + F32 = 24 + INDEX = 25 + AG_UNDEFINED_VAL = 100 + AG_BUILTIN_FUNC = 101 + TF_RAW_OP = 102 + TF_REGION = 103 + TF_TENSOR_SHAPE_FUNC = 104 # shape.as_list + TF_TENSOR_SHAPE_LIST = 105 # shape.as_list() + PY_BUILTIN_FUNC = 200 + + # As these are not real types, __getattribute__ helps them appear more like + # actual types (i.e. class definitions). + def __getattribute__(self, name): + if name == 'shape' and object.__getattribute__(self, 'value') == 1: + return TFRTypes.SHAPE + if name == 'as_list' and object.__getattribute__(self, 'value') == 5: + return TFRTypes.TF_TENSOR_SHAPE_FUNC + return object.__getattribute__(self, name) + + def __str__(self): + if self.value < 4: # pylint: disable=comparison-with-callable + return '!tfr.' + self.name.lower() + elif self.value < 10: # pylint: disable=comparison-with-callable + return '!shape.' 
+ self.name.lower() + else: + return self.name.lower() + + +_attribute_types = [ + TFRTypes.I1, TFRTypes.I32, TFRTypes.I64, TFRTypes.F32, TFRTypes.INDEX, + TFRTypes.ATTR +] + + +def _get_type_from_proto(arg_def=None, attr_def=None): + if not arg_def: + if attr_def.type == 'bool': + return TFRTypes.I1 + elif attr_def.type == 'int32': + return TFRTypes.I32 + elif attr_def.type == 'int' or attr_def.type == 'int64': + return TFRTypes.I64 + elif attr_def.type == 'float': + return TFRTypes.F32 + else: + return TFRTypes.ATTR + + if arg_def.number_attr or arg_def.type_list_attr: + return TFRTypes.TENSOR_LIST + else: + return TFRTypes.TENSOR + + +def _get_type_info_from_proto(arg_def=None, attr_def=None): + attr_type = _get_type_from_proto(arg_def, attr_def) + if not arg_def: + return '{}{{tfr.name="{}"}}'.format(attr_type, attr_def.name) + else: + attr_names = [] + if arg_def.number_attr: + attr_names.append(arg_def.number_attr) + if arg_def.type_attr: + attr_names.append(arg_def.type_attr) + if arg_def.type_list_attr: + attr_names.append(arg_def.type_list_attr) + + # TODO(fengliuai): currently we don't support backward type inference, so we + # have to store these non-derivable type in the signatures, and then they + # can be used to cast the values when raising to tf ops. + if arg_def.type == types_pb2.DT_FLOAT: + attr_names.append('f32_') + elif arg_def.type == types_pb2.DT_INT32: + attr_names.append('i32_') + elif arg_def.type == types_pb2.DT_INT64: + attr_names.append('i64_') + elif arg_def.type == types_pb2.DT_BOOL: + attr_names.append('i1_') + + if not attr_names: + return str(attr_type) + else: + return '{}<{}>'.format(attr_type, ','.join(attr_names)) + + +def _get_val_from_proto(attr_type, attr_val): + if attr_type == TFRTypes.I1: + return 'true' if attr_val.b else 'false' + elif attr_type == TFRTypes.I32 or attr_type == TFRTypes.I64: + return attr_val.i + elif attr_type == TFRTypes.F32: + return attr_val.f + elif attr_type == TFRTypes.ATTR: + # string + if attr_val.HasField('s'): + return '"{}"'.format(attr_val.s.decode()) + # type + if attr_val.HasField('type'): + if attr_val.type == types_pb2.DT_FLOAT: + return 'f32' + elif attr_val.type == types_pb2.DT_INT32: + return 'i32' + elif attr_val.type == types_pb2.DT_INT64: + return 'i64' + elif attr_val.type == types_pb2.DT_BOOL: + return 'i1' + # list + if attr_val.HasField('list'): + if attr_val.list.f: + elt_ty = TFRTypes.F32 + values = attr_val.list.f + elif attr_val.list.i: + elt_ty = TFRTypes.I64 + values = attr_val.list.i + else: + elt_ty = TFRTypes.NONE + values = [] + array_attr_elts = ['{}:{}'.format(val, elt_ty) for val in values] + return '[{}]'.format(','.join(array_attr_elts)) + raise NotImplementedError( + 'Proto AttrValue not recoganized. type: {}, value: {}'.format( + attr_type, attr_val)) + + +def _collect_derived_attrs_from_proto(op_def): + derived_attrs = set() + for arg in op_def.input_arg: + if arg.type_attr: + derived_attrs.add(arg.type_attr) + if arg.number_attr: + derived_attrs.add(arg.number_attr) + if arg.type_list_attr: + derived_attrs.add(arg.type_list_attr) + + # TODO(fengliuai): currently we don't support backward type inference, so we + # have to store these non-derivable type in the signatures, and then they + # can be used to cast the values when raising to tf ops. 
+ if arg.type == types_pb2.DT_FLOAT: + derived_attrs.add('f32_') + elif arg.type == types_pb2.DT_INT32: + derived_attrs.add('i32_') + elif arg.type == types_pb2.DT_INT64: + derived_attrs.add('i64_') + elif arg.type == types_pb2.DT_BOOL: + derived_attrs.add('i1_') + return derived_attrs + + +def _require_tensor_list(arg_def): + return arg_def.type_list_attr or arg_def.number_attr + + +def _camel_to_snake(name): + s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() + + +class OpDefCache(object): + """A Dict to cache the OpDef for the Python function name.""" + + def __init__(self): + self._op_defs = {} + + def lookup(self, f_name, func_def=None, optional=False): + if f_name in self._op_defs.keys(): + return self._op_defs[f_name] + + if isinstance(func_def, types.FunctionType): + if not hasattr(func_def, '_tfr_op_name'): + # skip a non-composition function + if optional: + return (None, None) + else: + raise KeyError('OpDef does not exist: ' + f_name) + op_name = getattr(func_def, '_tfr_op_name') + elif not func_def: + op_name = f_name + else: + # TODO(fengliuai): create one utility method to match different apis. + compose_dec = [] + for dec in func_def.decorator_list: + if isinstance(dec, ast.Call): + if isinstance(dec.func, + ast.Attribute) and dec.func.attr == 'Composite': + compose_dec.append(dec) + if isinstance(dec.func, ast.Name) and dec.func.id == 'Composite': + compose_dec.append(dec) + + if not compose_dec: + # skip a non-composition function + if optional: + return (None, None) + else: + raise KeyError('OpDef does not exist: ' + f_name) + elif len(compose_dec) > 1: + raise KeyError('More than one TF ops decomposes for.') + else: + op_name = compose_dec[0].args[0].value + + op_def = op_def_registry.get(op_name) + if not op_def: + raise ValueError('Not a registered op: ' + op_name) + derived_attrs = _collect_derived_attrs_from_proto(op_def) + self._op_defs[f_name] = (op_def, derived_attrs) + return (op_def, derived_attrs) + + def mlir_external_funcs(self): + tfr_funcs = [] + for _, (op_def, derived_attrs) in sorted(self._op_defs.items()): + tfr_func = '\ntfr.func @tf__{}_('.format(_camel_to_snake(op_def.name)) + + # tensor inputs + inputs = [ + _get_type_info_from_proto(arg_def) for arg_def in op_def.input_arg + ] + + # attribute inputs. The attribute with default values are moved backwards. + non_derived_attrs = [ + attr for attr in op_def.attr if attr.name not in derived_attrs + ] + attrs_no_default = [ + attr for attr in non_derived_attrs + if not attr.HasField('default_value') + ] + attrs_with_default = [ + attr for attr in non_derived_attrs if attr.HasField('default_value') + ] + attr_names = set() + for attr_def in attrs_no_default + attrs_with_default: + inputs.append(_get_type_info_from_proto(None, attr_def)) + attr_names.add(attr_def.name) + + # tensor outputs + outputs = [ + _get_type_info_from_proto(arg_def) for arg_def in op_def.output_arg + ] + + inputs = ','.join(inputs) + outputs = ','.join(outputs) + attrs = ','.join(sorted(derived_attrs.union(attr_names))) + tfr_funcs.append('{}{}) -> ({}) attributes {{{}}}'.format( + tfr_func, inputs, outputs, attrs)) + return tfr_funcs + + +_PY_TYPE_TO_TFR = { + bool: TFRTypes.I1, + int: TFRTypes.I64, + float: TFRTypes.F32, +} + +_AG_FIXED_RETURN_TYPE = { + 'for_stmt': type(None), + 'if_stmt': type(None), + 'Undefined': TFRTypes.AG_UNDEFINED_VAL, +} + +QN = qual_names.QN + +# TODO(mdan): Fix this with an importable module. 
+AG_MODULE = api._TRANSPILER._extra_locals['ag__'] # pylint:disable=protected-access + + +class TFRTypeResolver(type_inference.Resolver): + """Resolve types for the external names, calls and arguments.""" + + def __init__(self, op_defs): + super(TFRTypeResolver, self).__init__() + self._op_defs = op_defs + + # This pattern matching mechanism works with the functional form generated + # by autograph: + # + # for i in data: + # print(i) + # + # generates: + # + # def loop_body(itr): + # i = itr + # print(i) + # ag__.for_stmt(target) + # + # The mechanism lets us infer the type of the itr argument based on that of + # target. + self._for_loop_target_types = {} # Maps body function name to iterated. + self._for_loop_body_fns = {} # Used only to avoid collisions. + + def res_name(self, ns, types_ns, name): + name_str = str(name) + if name_str in ns: + ns_val = ns[name_str] + return {type(ns_val)}, ns_val + if name_str in __builtins__: + return {TFRTypes.PY_BUILTIN_FUNC}, __builtins__[name_str] + # This name is not in the namespace because the autograph transformation + # is not backloaded into Python. + if name_str == 'ag__': + return {type(AG_MODULE)}, AG_MODULE + + return None, None + + def res_value(self, ns, value): + if value is None: + return {TFRTypes.NONE} + if value in (TFRTypes.SHAPE, TFRTypes.TF_TENSOR_SHAPE_FUNC): + # See TFRTypes.__getattrbute__. + # TODO(mdan): Replacing the enum with classes would avoid this overlap. + return {value} + # TODO(mdan): Index more efficiently. Could do a name check instead. + if any(v is value for v in AG_MODULE.__dict__.values()): + return {TFRTypes.AG_BUILTIN_FUNC} + if getattr(value, '__name__', None) == 'tensorflow.raw_ops': + return {types.ModuleType} + if hasattr(value, '__module__'): + # All the imported operations, which are not autograph built-ins, are + # considered to be TF raw ops. + # TODO(fengliuai): refine the condition so we only matche tensorflow + # ops here. + return {TFRTypes.TF_RAW_OP} + # TODO(mdan): Is ATTR equivalent to string? + return {_PY_TYPE_TO_TFR.get(type(value), TFRTypes.ATTR)} + + def res_call(self, ns, types_ns, node, f_type, args, keywords): + name = anno.Basic.QN.of(node.func) + if f_type == (TFRTypes.AG_BUILTIN_FUNC,): + + if name == QN(QN('ag__'), attr='if_stmt'): + nouts = node.args[6].value + # TODO(mdan): Look at the actual types out of if_body. + side_effects = { + qual_names.QN(n.value): {TFRTypes.TENSOR} + for n in node.args[5].elts[:nouts] + } + return {type(None)}, side_effects + + if name == QN(QN('ag__'), attr='for_stmt'): + assert isinstance(node.args[2], ast.Name) + body_fn_name = str(anno.Basic.QN.of(node.args[2])) + assert body_fn_name not in self._for_loop_body_fns, ( + 'Previously used here: {}. Are you reusing the Resolver across ' + 'transformations?').format(self._for_loop_body_fns[body_fn_name]) + self._for_loop_body_fns[body_fn_name] = anno.Basic.ORIGIN.of(node) + + iterated_type = args[0] + assert iterated_type & { + TFRTypes.TENSOR_LIST, TFRTypes.TENSOR, List[int] + }, ( + iterated_type) + self._for_loop_target_types[body_fn_name] = iterated_type + + return {type(None)}, None + + # TODO(mdan): Actually resolve the type here instead. 
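+      # For a few known autograph builtins the return type is looked up in
+      # _AG_FIXED_RETURN_TYPE, e.g. 'for_stmt' and 'if_stmt' map to NoneType
+      # and 'Undefined' maps to TFRTypes.AG_UNDEFINED_VAL.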
+ ret_type = _AG_FIXED_RETURN_TYPE.get(name.qn[1], None) + if ret_type is not None: + return {ret_type}, None + raise NotImplementedError('return type of {}'.format(name)) + + elif f_type == (TFRTypes.TF_RAW_OP,): + op_name = name.qn[1] + op_def, _ = self._op_defs.lookup(op_name) + if len(op_def.output_arg) == 1: + return {_get_type_from_proto(op_def.output_arg[0])}, None + return ({tuple(_get_type_from_proto(arg) for arg in op_def.output_arg)}, + None) + + elif f_type == (TFRTypes.PY_BUILTIN_FUNC,): + assert name.is_simple() + if name == QN('range'): + return {List[int]}, None + + if name == QN('len'): + return {TFRTypes.INDEX}, None + + elif f_type == (TFRTypes.TF_TENSOR_SHAPE_FUNC,): + return {TFRTypes.TF_TENSOR_SHAPE_LIST}, None + + raise NotImplementedError('Function:', name, f_type) + + def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local): + if f_is_local: + f_name_str = str(f_name) + if f_name_str in self._for_loop_target_types: + # See autograph/converters/control_flow.py - the function has a single + # argument, the iterate before any expansion. + assert self._for_loop_target_types[f_name_str] & {List[int]} + # Assume all loops are TF loops. Then the iterates are autoboxed into + # Tensors. + return {TFRTypes.INDEX} + else: + return None + + func = ns[f_name] + + op_def, derived_attrs = self._op_defs.lookup(f_name, func) + if op_def is None: + return None + pos = tf_inspect.getfullargspec(func).args.index(str(name)) + + if pos < len(op_def.input_arg): + arg_def = op_def.input_arg[pos] + return {_get_type_from_proto(arg_def)} + elif pos < len(op_def.input_arg) + len(op_def.attr) - len(derived_attrs): + non_derived_attr_pos = pos - len(op_def.input_arg) + for attr_def in op_def.attr: + # derived attribute, skip this one and continue to the next one. + if attr_def.name in derived_attrs: + continue + if non_derived_attr_pos == 0: + return {_get_type_from_proto(None, attr_def)} + non_derived_attr_pos -= 1 + + raise ValueError('Argument is not defined in OpDef: ' + str(name)) + + def res_subscript(self, ns, types_ns, node_or_slice, value, slice_): + assert len(value) == 1 + value, = tuple(value) + if value == TFRTypes.TF_TENSOR_SHAPE_LIST: + # TODO(mdan): This is not entirely correct for multi-element slices. + return {int} + elif value in (TFRTypes.TENSOR_LIST, TFRTypes.TENSOR): + # TODO(mdan): This is not entirely correct for multi-element slices. + return {TFRTypes.TENSOR} + raise NotImplementedError('slice of {}'.format(value)) + + def res_compare(self, ns, types_ns, node, left, right): + # TODO(fengliuai): make sure left and right are compatible + return {TFRTypes.I1} + + def res_binop(self, ns, types_ns, node, left, right): + # TODO(fengliuai): make sure left and right are compatible + return left + + +class SymbolTable(object): + """Symbol Table for python code.""" + + def __init__(self): + self.symbols = [] + self.enter_scope() + self.scf_scope = 0 + # reserved key words + self.insert_symbol('len', 'len', TFRTypes.PY_BUILTIN_FUNC) + + def enter_scope(self, scf_scope=False): + """Enter a new scope - at function level.""" + self.symbols.append({'types': {}, 'symbols': {}}) + self.curr_table = self.symbols[len(self.symbols) - 1] + if scf_scope: + self.scf_scope += 1 + + def insert_symbol(self, name, value, type_): + self.curr_table['symbols'][name] = (value, type_) + # TODO(mdan): Use the inferred type rather than tracking it here. + # The following field is decrepcated. 
+ self.curr_table['types'][name] = type_ + return value + + def exit_scope(self): + self.symbols.pop() + self.curr_table = self.symbols[len(self.symbols) - 1] + if self.scf_scope > 0: + self.scf_scope -= 1 + + def in_scf_scope(self): + return self.scf_scope > 0 + + def lookup(self, name): + curr_idx = len(self.symbols) - 1 + while curr_idx >= 0 and (name not in self.symbols[curr_idx]['symbols']): + curr_idx -= 1 + if curr_idx < 0: + return None + return self.symbols[curr_idx]['symbols'][name] + + +class TFRGen(transformer.CodeGenerator): + """Visit the AST and generate MLIR TFR functions.""" + + def __init__(self, ctx, op_defs): + super(TFRGen, self).__init__(ctx) + self.ctx = ctx + self.symbol_table = SymbolTable() + self._op_defs = op_defs + + def _create_mlir_loc(self, loc): + """Creates mlir location from autograph ORIGIN value. + + Args: + loc: OriginInfo + + Returns: + A serialized mlir location string. + """ + if loc is not None and loc.loc.filename: + file_name = os.path.basename(loc.loc.filename) + return 'loc("{}":{}:{})'.format(file_name, loc.loc.lineno, + loc.loc.col_offset) + else: + return 'loc(unknown)' + + def _emit_with_loc(self, op_str, node=None): + """Emit the mlir operation with the location associated with the node. + + Args: + op_str: The mlir operation string to be emitted. + node: The node of the AST tree, the mlir operation translated from. + """ + loc = '' + if node: + loc = self._create_mlir_loc( + anno.getanno(node, anno.Basic.ORIGIN, default=None)) + self.emit(op_str + ' ' + loc) + + def _get_inferred_type(self, node, default=None): + types_ = anno.getanno(node, anno.Static.TYPES, None) + if not types_: + print('WARN: no Static.TYPES annotation. Fix the type inference pass: ') + self.debug_print(node) + return default + if types_ and len(types_) > 1: + raise ValueError('ambiguous inferred type for "{}": {}'.format( + node, types_)) + + type_, = types_ + # TODO(fengliuai): Tuple is added here to make return tuple work. + if type_ is list or type_ is Tuple: + # TODO(fengliuai): Seems like we need to move the followed list handling + # to the type inference and we shouldn't just put 'list' there. Otherwise + # we couldn't find out the right type for the Name node. + if not isinstance(node, ast.List): + return default + all_types = [ + anno.getanno(elt, anno.Static.TYPES, None) for elt in node.elts + ] + if (TFRTypes.TENSOR,) in all_types: + # For the elt which is not tfr.tensor, tfr.constant_tensor needs to be + # use to cast it to a tfr.tensor. + return TFRTypes.TENSOR_LIST + else: + return TFRTypes.ATTR + + if default is not None and type_ != default: + print('WARN: type annotation {}({}) does not match {}({})'.format( + type_, type(type_), default, type(default))) + self.debug_print(node) + + return type_ + + def _pack_tensor_list(self, value): + # This is packing a list of tensors, then the axis is 0. 
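+    # Roughly, the emitted IR is:
+    #   %zero = constant 0 : i64
+    #   %pack = tfr.call @tf__pack(%list, %zero) : (!tfr.tensor_list, i64) -> !tfr.tensor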
+ axis = self._ssa_name('zero') + self._emit_with_loc('\n{} = constant 0 : i64'.format(axis)) + casted = self._ssa_name('pack') + self.emit('\n{} = tfr.call @tf__pack({}, {})'.format(casted, value, axis)) + self._emit_with_loc(' : (!tfr.tensor_list, i64) -> !tfr.tensor') + # load the op def of tf.Pack + self._op_defs.lookup('Pack') + return casted, TFRTypes.TENSOR + + def _index_to_I64(self, value, ty): + if ty == TFRTypes.INDEX: + casted = self._ssa_name('casted') + self._emit_with_loc('\n{} = index_cast {} : index to i64'.format( + casted, value)) + return casted, TFRTypes.I64 + else: + return value, ty + + def _value_to_tensor(self, value, ty, node): + value, ty = self._index_to_I64(value, ty) + cst_tensor = self._ssa_name('cst') + self.emit('\n{} = "tfr.constant_tensor"({})'.format(cst_tensor, value)) + self._emit_with_loc(' : ({}) -> !tfr.tensor'.format(ty), node) + return cst_tensor, TFRTypes.TENSOR + + def _ssa_name(self, prefix): + if isinstance(prefix, qual_names.QN): + assert prefix.is_simple(), 'ANF transform should have cleaned this up' + prefix = prefix.ssf() + return '%' + self.ctx.namer.new_symbol(prefix, set()) + + def _op_def(self, op_name): + return op_def_registry.get(op_name) + + def visit_block(self, block): + return [self.visit(item) for item in block] + + def visit_Pass(self, node): + if self.symbol_table.in_scf_scope(): + self._emit_with_loc('\nscf.yield', node) + else: + self._emit_with_loc('\ntfr.return', node) + + def visit_Attribute(self, node): + node_type = self._get_inferred_type(node, None) + if isinstance(node.value, ast.Name): + if node.value.id == 'ag__': + # some variables are assigned with 'ag__.xxx' method, we should handle + # them following the autograph convensions. + return (node.attr, TFRTypes.AG_BUILTIN_FUNC) + + if node_type == TFRTypes.TF_RAW_OP: + # This branch is used when it is inside tensorflow + return (node.attr, TFRTypes.TF_RAW_OP) + + value, _ = self.visit(node.value) + tensor_type = self._get_inferred_type(node.value, None) + # TODO(fengliuai): use node_type once it + if node_type == TFRTypes.SHAPE: + print('TODO: use "node_type"') + if node.attr == 'shape' and tensor_type == TFRTypes.TENSOR: + ssa_value = self._ssa_name('shape') + self._emit_with_loc( + '\n{} = tfr.get_shape {} -> !shape.shape'.format(ssa_value, value), + node) + return (ssa_value, TFRTypes.SHAPE) + + if isinstance(node.value, ast.Attribute): + if isinstance(node.value.value, ast.Name): + if node.value.value.id == 'tf' and node.value.attr == 'raw_ops': + # This branch is used when it is outside tensorflow + return (node.attr, TFRTypes.TF_RAW_OP) + + value, ty = self.visit(node.value) + # TODO(fengliuai): use node_type once it + if node_type == TFRTypes.TF_TENSOR_SHAPE_FUNC: + print('TODO: use "node_type"') + if ty == TFRTypes.SHAPE and node.attr == 'as_list': + return (value, TFRTypes.TF_TENSOR_SHAPE_FUNC) + + raise NotImplementedError('Attribute kind not recoganized.') + + def visit_Assign(self, node): + values = self.visit(node.value) + if isinstance(node.targets[0], ast.Tuple): + targets = [elt.id for elt in node.targets[0].elts] + elif isinstance(node.targets[0], ast.Name): + targets = [node.targets[0].id] + else: + raise NotImplementedError('Assignment target type not recoganized.') + + if isinstance(values, list): + if len(targets) == len(values): + for key, value in zip(targets, values): + ssa_value, ty_ = value + ty = self._get_inferred_type(node.value, ty_) + self.symbol_table.insert_symbol(key, ssa_value, ty) + elif len(values) == 1: + n, ty = values[0] + 
assert ty == TFRTypes.TENSOR_LIST + # assign a tensor_list to multiple variables + for idx, key in enumerate(targets): + idx_name = self._ssa_name('idx') + self._emit_with_loc( + '\n{} = constant {} : index'.format(idx_name, idx), node) + elt_name = self._ssa_name('elt') + self.emit('\n{} = tfr.get_element {}[{}]'.format( + elt_name, n, idx_name)) + self._emit_with_loc(' : (!tfr.tensor_list, index) -> !tfr.tensor', + node) + self.symbol_table.insert_symbol(key, elt_name, TFRTypes.TENSOR) + elif len(targets) == 1: + ssa_names = [n for n, _ in values] + tys = [t for _, t in values] + self.symbol_table.insert_symbol(targets[0], ssa_names, tys) + else: + self.symbol_table.insert_symbol(targets[0], values[0], values[1]) + + def _emit_binary_op(self, op, lhs, lhs_ty, rhs, rhs_ty): + assert lhs_ty, rhs_ty + if isinstance(op, ast.Sub): + code = 'sub' + elif isinstance(op, ast.Add): + code = 'add' + else: + raise NotImplementedError('BinOp operator not recognized' + op) + + if lhs_ty == TFRTypes.I64: + suffix = 'i' + elif lhs_ty == TFRTypes.F32: + suffix = 'f' + else: + raise NotImplementedError('BinOp operand type not recognized' + op) + + ret = self._ssa_name(code) + self._emit_with_loc( + '\n{} = {}{} {}, {} : {}'.format(ret, code, suffix, lhs, rhs, lhs_ty), + op) + return ret, lhs_ty + + def visit_AugAssign(self, node): + lhs, lhs_ty = self.visit(node.target) + rhs, rhs_ty = self.visit(node.value) + ret, ret_ty = self._emit_binary_op(node.op, lhs, lhs_ty, rhs, rhs_ty) + self.symbol_table.insert_symbol(node.target.id, ret, ret_ty) + + def visit_BinOp(self, node): + lhs, lhs_ty = self.visit(node.left) + rhs, rhs_ty = self.visit(node.right) + return self._emit_binary_op(node.op, lhs, lhs_ty, rhs, rhs_ty) + + def visit_BoolOp(self, node): + values = [self.visit(value) for value in node.values] + # TODO(fengliuai): Handle more ast node types. + if isinstance(node.op, ast.Or): + raise NotImplementedError('Or operator not recognized') + elif isinstance(node.op, ast.And): + raise NotImplementedError('And operator not recognized') + + def visit_Call(self, node): + func_name, func_type = self.visit(node.func) + _ = self._get_inferred_type(node.func, func_type) + if func_type == TFRTypes.AG_BUILTIN_FUNC: + if func_name == 'if_stmt': + cond, _ = self.visit(node.args[0]) + body, _ = self.visit(node.args[1]) + orelse, _ = self.visit(node.args[2]) + get_state, _ = self.visit(node.args[3]) + nouts = int(node.args[6].value) + out_symbols = [] + # The out symbols are just a Tuple of names + for out in node.args[5].elts[:nouts]: + val, ty = self.symbol_table.lookup(out.value) + if ty != TFRTypes.AG_UNDEFINED_VAL: + raise ValueError('if stmt out symbol is not defined.') + out_symbols.append(out.value) + return self._visit_if_stmt(cond, body, orelse, get_state, out_symbols, + node) + elif func_name == 'for_stmt': + range_ = self._visit_iter(node.args[0]) + body, _ = self.visit(node.args[2]) + get_state, _ = self.visit(node.args[3]) + loop_carried = [out.value for out in node.args[5].elts] + # TODO(fengliuai): opt is not used here. 
+ return self._visit_for_stmt(range_, body, get_state, loop_carried, node) + elif func_name == 'Undefined': + val = self._ssa_name(node.args[0].value) + return (val, TFRTypes.AG_UNDEFINED_VAL) + elif func_name == 'UndefinedReturnValue': + val = self._ssa_name('return_val') + return (val, TFRTypes.AG_UNDEFINED_VAL) + + if func_type == TFRTypes.TF_RAW_OP: + return self._visit_tf_op(func_name, node.args, node.keywords, node) + + if func_type == TFRTypes.TF_TENSOR_SHAPE_FUNC: + return (func_name, TFRTypes.TF_TENSOR_SHAPE_LIST) + + if func_type == TFRTypes.PY_BUILTIN_FUNC: + if func_name == 'len': + arg, ty = self.visit(node.args[0]) + ty = self._get_inferred_type(node.args[0], ty) + assert ty == TFRTypes.TF_TENSOR_SHAPE_LIST, ty + len_value = self._ssa_name('len') + self._emit_with_loc( + '\n{} = shape.rank {} : !shape.shape -> !shape.size'.format( + len_value, arg), node) + size_value = self._ssa_name('len_size') + self._emit_with_loc( + '\n{} = shape.size_to_index {} : !shape.size'.format( + size_value, len_value), node) + return (size_value, TFRTypes.INDEX) + + raise NotImplementedError('call operator not recognized: {} {}'.format( + func_name, func_type)) + + def visit_Compare(self, node): + lhs, lhs_ty = self.visit(node.left) + for op, right in zip(node.ops, node.comparators): + rhs, _ = self.visit(right) + if isinstance(op, ast.Eq): + pred = 'eq' + elif isinstance(op, ast.Lt): + pred = 'ult' + elif isinstance(op, ast.LtE): + pred = 'ule' + elif isinstance(op, ast.Gt): + pred = 'ugt' + elif isinstance(op, ast.GtE): + pred = 'uge' + elif isinstance(op, ast.NotEq): + pred = 'ne' + else: + raise NotImplementedError('Compare operator not recognized') + + ret = self._ssa_name(pred) + if lhs_ty == TFRTypes.ATTR: + self._emit_with_loc( + '\n{} = tfr.equal {}, {} -> i1'.format(ret, lhs, rhs), node) + else: + if lhs_ty == TFRTypes.I64: + code = 'cmpi' + elif lhs_ty == TFRTypes.F32: + code = 'cmpf' + else: + raise NotImplementedError('Compare operand type not recognized') + self._emit_with_loc( + '\n{} = {} "{}", {}, {} : {}'.format(ret, code, pred, lhs, rhs, + lhs_ty), node) + + return ret, TFRTypes.I1 + + def visit_Constant(self, node): + cst_name = self._ssa_name('cst') + if node.value is None: + cst_ty = TFRTypes.NONE + elif isinstance(node.value, bool): + cst_ty = self._get_inferred_type(node) + cst_val = str(node.value).lower() + self._emit_with_loc('\n{} = constant {}'.format(cst_name, cst_val), node) + else: + cst_ty = self._get_inferred_type(node) + cst_val = node.value + if cst_ty == TFRTypes.ATTR: + self._emit_with_loc( + '\n{} = tfr.constant "{}" -> {}'.format(cst_name, cst_val, cst_ty), + node) + else: + self._emit_with_loc( + '\n{} = constant {} : {}'.format(cst_name, cst_val, cst_ty), node) + return cst_name, cst_ty + + def visit_FunctionDef(self, node): + op_def, derived_attrs = self._op_defs.lookup(node.name, node, True) + if op_def is None: + # Nested function. Insert it to symbol table for looking up later. 
+ self.symbol_table.insert_symbol(node.name, node, None) + return + op_name = op_def.name + if self.symbol_table.lookup(op_name): + raise LookupError('Composition has not been registered for op: ' + + op_name) + else: + self.symbol_table.insert_symbol(node.name, None, None) + + self.symbol_table.enter_scope() + self.emit('\ntfr.func @tf__{0}('.format(_camel_to_snake(op_name))) + + arg_list = [] + idx = 0 + max_idx = len(op_def.input_arg) + len(op_def.attr) + for arg in node.args.args: + arg_name = self._ssa_name(anno.getanno(arg, anno.Basic.QN)) + arg_type = anno.getanno(arg, anno.Static.TYPES)[0] + + arg_attr = '' + if idx >= len(op_def.input_arg): + attr_def = op_def.attr[idx - len(op_def.input_arg)] + # skip the derived attributes + while attr_def.name in derived_attrs and (idx + 1) < max_idx: + idx += 1 + attr_def = op_def.attr[idx - len(op_def.input_arg)] + if idx >= max_idx: + raise ValueError('Argument is not defined in OpDef: ' + arg_name) + + arg_attr += '{{tfr.name="{}"'.format(attr_def.name) + if attr_def.HasField('default_value'): + default_val = _get_val_from_proto(arg_type, attr_def.default_value) + arg_attr += ',tfr.default={}'.format(default_val) + arg_attr += '}' + + idx += 1 + arg_str = '{}: {}{}'.format(arg_name, arg_type, arg_attr) + arg_list.append(arg_str) + self.symbol_table.insert_symbol(arg.id, arg_name, arg_type) + + ret_type_list = [] + for ret_def in op_def.output_arg: + if ret_def.number_attr or ret_def.type_list_attr: + ret_type_list.append(str(TFRTypes.TENSOR_LIST)) + else: + ret_type_list.append(str(TFRTypes.TENSOR)) + + self.emit('{}) -> ({}) {{'.format(', '.join(arg_list), + ', '.join(ret_type_list))) + self.visit_block(node.body) + self._emit_with_loc('\n}', node) + self.symbol_table.exit_scope() + + def visit_arguments(self, node): + # TODO(fengliuai): return ordered the types and names. + # We need to order the arguments to match the assumption in the TFR dialect. + raise NotImplementedError('arguments not supported.') + + def visit_Lambda(self, node): + raise NotImplementedError('Lambda not supported.') + + def _get_mlir_ssa_values(self, name_prefix, out_types): + """Create MLIR convention SSA values.""" + out_ssa_values = [] + if not out_types: + return '', out_ssa_values + + out_name = self._ssa_name(name_prefix) + if len(out_types) == 1: + out_name_suffix = '' + out_ssa_values.append(out_name) + else: + # For multiple returns, MLIR uses '%s:i' when they are defined and + # '%s#i' when they are used. + out_name_suffix = ':{}'.format(len(out_types)) + for idx, _ in enumerate(out_types): + out_ssa_values.append('{}#{}'.format(out_name, idx)) + + return '{}{}'.format(out_name, out_name_suffix), out_ssa_values + + def _visit_if_stmt(self, cond, body_def, orelse_def, get_state, out_symbols, + node): + self.emit('\n') + ret_str, ret_ssa_values = self._get_mlir_ssa_values( + 'if_stmt', [TFRTypes.TENSOR] * len(out_symbols)) + if ret_ssa_values: + self.emit(ret_str + ' = ') + + # add ssa values to the symbol table + out_types = [] + for symbol, ssa_value in zip(out_symbols, ret_ssa_values): + self.symbol_table.insert_symbol(symbol, ssa_value, TFRTypes.TENSOR) + out_types.append(str(TFRTypes.TENSOR)) + + self.emit('scf.if {} -> ({}) {{'.format(cond, ', '.join(out_types))) + # Create a new scope in case the local variables are leaked. 
+ self.symbol_table.enter_scope(scf_scope=True) + self.visit_block(body_def.body) + self.visit_block(get_state.body) + self.symbol_table.exit_scope() + + self.emit('\n} else {') + + # Create a new scope in case the local variables are leaked. + self.symbol_table.enter_scope(scf_scope=True) + self.visit_block(orelse_def.body) + self.visit_block(get_state.body) + self.symbol_table.exit_scope() + + self._emit_with_loc('\n}', node) + return list(zip(ret_ssa_values, out_types)) + + def _visit_iter(self, node): + if isinstance(node, ast.Call): + f_name = anno.getanno(node.func, anno.Basic.QN) + if f_name == QN('range'): + args = [self.visit(arg) for arg in node.args] + begin = None + step = None + end = None + if len(args) == 1: + end, end_ty = args[0] + elif len(args) == 2: + begin, begin_ty = args[0] + end, end_ty = args[1] + elif len(args) == 3: + begin, begin_ty = args[0] + end, end_ty = args[1] + step, step_ty = args[2] + + if begin is None: + begin = self._ssa_name('begin') + self._emit_with_loc('\n{} = constant 0 : index'.format(begin), node) + elif begin_ty != TFRTypes.INDEX: + begin_ = self._ssa_name('begin') + self._emit_with_loc( + '\n{} = index_cast {} : {} to index'.format( + begin_, begin, begin_ty), node) + begin = begin_ + + if end_ty != TFRTypes.INDEX: + end_ = self._ssa_name('end') + self._emit_with_loc( + '\n{} = index_cast {} : {} to index'.format(end_, end, end_ty), + node) + end = end_ + + if step is None: + step = self._ssa_name('step') + self._emit_with_loc('\n{} = constant 1 : index'.format(step), node) + elif step_ty != TFRTypes.INDEX: + step_ = self._ssa_name('step') + self._emit_with_loc( + '\n{} = index_cast {} : {} to index'.format(step_, step, step_ty), + node) + step = step_ + + return begin, end, step + + raise NotImplementedError('Iterator entity not supported.' + node) + + def _visit_for_stmt(self, range_, body_def, get_state, loop_carried, node): + self.emit('\n') + ret_str, ret_ssa_values = self._get_mlir_ssa_values( + 'for_stmt', [TFRTypes.TENSOR] * len(loop_carried)) + if ret_ssa_values: + self.emit(ret_str + ' = ') + + # Before enter the loop, we use the original ssa values as the initial + # values to the loop iteration arguments. We also create new ssa values as + # the returns of the scf for statements. The symbol table needs to be + # updated to these new ssa values before it enters the scope of the loop. + out_types = [] + init_values = [] + for symbol, ssa_value in zip(loop_carried, ret_ssa_values): + init, ty = self.symbol_table.lookup(symbol) + self.symbol_table.insert_symbol(symbol, ssa_value, ty) + out_types.append(str(ty)) + init_values.append((init, ty)) + + # Create a new scope in case the local variables are leaked. 
+ self.symbol_table.enter_scope(scf_scope=True) + + # Create the iteration variable with index type + assert len(body_def.args.args) == 1 + it_name = body_def.args.args[0].id + it = self._ssa_name(it_name) + self.symbol_table.insert_symbol(it_name, it, TFRTypes.INDEX) + + self.emit('scf.for {} = {} to {} step {} '.format(it, range_[0], range_[1], + range_[2])) + if loop_carried: + iter_args = [] + for symbol, init in zip(loop_carried, init_values): + # create new ssa values for the loop carried variables + it_arg = self._ssa_name('it_arg') + self.symbol_table.insert_symbol(symbol, it_arg, init[1]) + iter_args.append('{} = {}'.format(it_arg, init[0])) + self.emit('iter_args({}) '.format(', '.join(iter_args))) + self.emit('-> ({}) {{'.format(', '.join(out_types))) + else: + self.emit(' {') + self.visit_block(body_def.body) + self.visit_block(get_state.body) + self.symbol_table.exit_scope() + self._emit_with_loc('\n}', node) + return list(zip(ret_ssa_values, out_types)) + + def _emit_default_constant_from_proto(self, attr_def): + """emit mlir constant statement from default value of the ArgDef proto.""" + name = self._ssa_name('cst') + cst_ty = _get_type_from_proto(None, attr_def) + cst_val = _get_val_from_proto(cst_ty, attr_def.default_value) + if cst_ty == TFRTypes.ATTR: + self._emit_with_loc('\n{} = tfr.constant {} -> {}'.format( + name, cst_val, cst_ty)) + elif cst_ty == TFRTypes.I1: + self._emit_with_loc('\n{} = constant {}'.format(name, cst_val)) + else: + self._emit_with_loc('\n{} = constant {} : {}'.format( + name, cst_val, cst_ty)) + return name, cst_ty + + def visit_keyword(self, node): + return node.arg, self.visit(node.value) + + def _visit_tf_op(self, op_name, args, keywords, node): + op_def, derived_attrs = self._op_defs.lookup(op_name) + ret_tys = [_get_type_from_proto(arg) for arg in op_def.output_arg] + + ret_str, ret_ssa_values = self._get_mlir_ssa_values(op_name, ret_tys) + + arg_strs = [] + ty_strs = [] + for arg in args: + value, ty = self.visit(arg) + arg_strs.append(value) + ty_strs.append(str(ty)) + + input_args = [arg for arg in op_def.input_arg] + attrs_no_default = [ + attr for attr in op_def.attr + if not attr.HasField('default_value') and attr.name not in derived_attrs + ] + attrs_with_default = [ + attr for attr in op_def.attr + if attr.HasField('default_value') and attr.name not in derived_attrs + ] + + kw_args = {} + for arg in keywords: + value, (ssa_name, ty) = self.visit(arg) + ty = self._get_inferred_type(arg.value, ty) + + # TODO(fengliuai): implement the "rename_to" for the customization in + # tensorflow/core/api_def/base_api/* + if value == 'axis': + value = 'split_dim' + + kw_args[value] = (ssa_name, ty) + + # tensor arguments and attribute arguments + ordered_args = input_args + attrs_no_default + attrs_with_default + for attr_def in ordered_args[len(args):]: + if attr_def.name in kw_args: + value, ty = kw_args[attr_def.name] + if attr_def in input_args: + if ty in _attribute_types: + # the argument shouldn't be used as tf op calls directly. 
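+            # Cast the attribute value to a tensor so it can be passed as a
+            # tensor input of the op.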
+ value, ty = self._value_to_tensor(value, ty, node) + if ty is TFRTypes.TENSOR_LIST and not _require_tensor_list(attr_def): + value, ty = self._pack_tensor_list(value) + else: + value, ty = self._emit_default_constant_from_proto(attr_def) + arg_strs.append(value) + ty_strs.append(str(ty)) + + if ret_ssa_values: + self.emit('\n{} = '.format(ret_str)) + + self.emit('tfr.call @tf__{}('.format(_camel_to_snake(op_name))) + arg_str = ', '.join(arg_strs) + arg_ty_str = ', '.join(ty_strs) + ret_ty_str = ', '.join([str(ty) for ty in ret_tys]) + self._emit_with_loc( + '{}) : ({}) -> ({})'.format(arg_str, arg_ty_str, ret_ty_str), node) + return list(zip(ret_ssa_values, ret_tys)) + + def visit_If(self, node): + raise NotImplementedError('If not supported.') + + def visit_Name(self, node): + val, lookup_type = self.symbol_table.lookup(node.id) + type_ = self._get_inferred_type(node, lookup_type) + return val, type_ + + def visit_Return(self, node): + values = self.visit(node.value) + if self.symbol_table.in_scf_scope(): + self.emit('\nscf.yield ') + else: + self.emit('\ntfr.return ') + if not values: + return + + if isinstance(values, list): + vals, tys = zip(*values) + else: + vals = values[0] + tys = values[1] + + if isinstance(tys, list) or isinstance(tys, tuple): + tys = [str(t) for t in tys] + self._emit_with_loc('{} : {}'.format(', '.join(vals), ', '.join(tys)), + node) + elif tys != TFRTypes.NONE: + # TODO(fengliuai): scf region yield uses this branch. Fix it. + self._emit_with_loc('{} : {}'.format(vals, tys), node) + + def visit_Subscript(self, node): + val, ty = self.visit(node.value) + type_ = self._get_inferred_type(node.value, ty) + + # TODO(fengliuai): Here we hardcode the node.slice here to get the index + # type. Use the visit method once the type inference is done. + # slice_val, slice_ty = self.visit(node.slice) + if isinstance(node.slice, ast.Index): + if isinstance(node.slice.value, ast.Constant): + # TODO(fengliuai): promote to an assignment + idx_val = self._ssa_name('cst') + self._emit_with_loc( + '\n{} = constant {} : index'.format(idx_val, + node.slice.value.value), node) + else: + idx_val, _ = self.visit(node.slice.value) + else: + raise NotImplementedError('non-index slice not supported.') + + elt = self._ssa_name('elt') + if type_ == TFRTypes.TENSOR_LIST: + self.emit('\n{} = tfr.get_element {}[{}] '.format(elt, val, idx_val)) + self._emit_with_loc(': (!tfr.tensor_list, index) -> !tfr.tensor', node) + return (elt, TFRTypes.TENSOR) + elif type_ == TFRTypes.TF_TENSOR_SHAPE_LIST: + size_ = self._ssa_name('size') + self.emit('\n{} = shape.get_extent {}, {}'.format(size_, val, idx_val)) + self._emit_with_loc(': !shape.shape, index -> !shape.size', node) + self._emit_with_loc( + '\n{} = shape.size_to_index {} : !shape.size'.format(elt, size_), + node) + return (elt, TFRTypes.INDEX) + + def visit_List(self, node): + out_type = self._get_inferred_type(node) + vals = [] + tys = [] + for elt in node.elts: + val, ty = self.visit(elt) + if ty in _attribute_types and out_type == TFRTypes.TENSOR_LIST: + # This list is a tensor list, then cast all the input values to tensors. + val, ty = self._value_to_tensor(val, ty, node) + else: + # We shouldn't use index type to build the list because list will be use + # as attribute. 
+ val, ty = self._index_to_I64(val, ty) + vals.append(val) + tys.append(str(ty)) + + list_val = self._ssa_name('list') + self.emit('\n{} = "tfr.build_list"({})'.format(list_val, ', '.join(vals))) + self._emit_with_loc(' : ({}) -> {}'.format(', '.join(tys), out_type), node) + return (list_val, out_type) + + def visit_Tuple(self, node): + return [self.visit(elt) for elt in node.elts] + + def visit_UnaryOp(self, node): + value, ty = self.visit(node.operand) + if isinstance(node.op, ast.USub): + zero_value = self._ssa_name('zero') + self._emit_with_loc('\n{} = constant 0 : {}'.format(zero_value, ty), node) + ssa_value = self._ssa_name('cst') + if ty == TFRTypes.I32 or ty == TFRTypes.I64: + self._emit_with_loc( + '\n{} = subi {}, {} : {}'.format(ssa_value, zero_value, value, ty), + node) + elif ty == TFRTypes.F32: + self._emit_with_loc( + '\n{} = subf {}, {} : {}'.format(ssa_value, zero_value, value, ty), + node) + else: + raise NotImplementedError('USub type not recognized: ' + str(ty)) + return ssa_value, ty + raise NotImplementedError('USub operator not recognized') + + def visit_For(self, node): + raise NotImplementedError('For operator not recognized') + + def visit_While(self, node): + raise NotImplementedError('While operator not recognized') + + def visit_Try(self, node): + # Only handles the body of the try statement. + self.visit_block(node.body) + + +def _apply_py_to_tf_passes(node, ctx): + """Apply transformations from PyToTF to match tf.function tracing.""" + # TODO(fengliuai): we don't know which passes are required, thus we evalute + # each one when the corresponding node is handled. + # copied from PyToTF.transform_ast + node = return_statements.transform(node, ctx, False) + node = control_flow.transform(node, ctx) + return node + + +class TfrGen(transpiler.GenericTranspiler): + """Transforms Python objects into TFR MLIR source code.""" + + def __init__(self, op_defs): + self._op_defs = op_defs + + def transform_ast(self, node, ctx): + node = _apply_py_to_tf_passes(node, ctx) + # TODO(mdan): Enable this. + # node = anf.transform(node, ctx) + + graphs = cfg.build(node) + node = qual_names.resolve(node) + node = activity.resolve(node, ctx) + node = reaching_definitions.resolve(node, ctx, graphs) + node = reaching_fndefs.resolve(node, ctx, graphs) + node = type_inference.resolve(node, ctx, graphs, + TFRTypeResolver(self._op_defs)) + + mlir_generator = TFRGen(ctx, self._op_defs) + mlir_generator.visit(node) + return mlir_generator.code_buffer + + +def tfr_gen(func, op_defs): + """Parse a function and emit the TFR functions.""" + mlir_code, _ = TfrGen(op_defs).transform(func, None) + assert tfr.verify(mlir_code), 'mlir code not verified: {}'.format(mlir_code) + return mlir_code + + +def tfr_gen_from_module(source, method_prefix=None, op_libraries=None): + """Parse a python code and emit the TFR functions from a target class.""" + op_defs = OpDefCache() + + if op_libraries: + for m in op_libraries: + lib_dir = os.path.dirname(m.__file__) + prefix_len = len('gen_') + lib_name = os.path.basename(m.__file__)[prefix_len:].replace('.py', '.so') + # Load the op library so the op is added to the op registry. This is + # required when the op cc_library couldn't be statically linked in open + # source. + # This is a no op if the op shared library couldn't be found in the same + # directory of the op Python API. 
+ load_library.load_op_library(os.path.join(lib_dir, lib_name)) + + mlir_funcs = [ + tfr_gen(func, op_defs) + for name, func in tf_inspect.getmembers(source, tf_inspect.isfunction) + if not method_prefix or name.startswith(method_prefix) + ] + + return '\n'.join(mlir_funcs + op_defs.mlir_external_funcs()) diff --git a/tensorflow/compiler/mlir/tfr/python/tfr_gen_test.py b/tensorflow/compiler/mlir/tfr/python/tfr_gen_test.py new file mode 100644 index 00000000000..88696490c4a --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/python/tfr_gen_test.py @@ -0,0 +1,563 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `tfr_gen` module.""" + +# pylint: disable=missing-function-docstring +# pylint: disable=invalid-name +# pylint: disable=g-direct-tensorflow-import + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from tensorflow.compiler.mlir.python.mlir_wrapper import filecheck_wrapper as fw +from tensorflow.compiler.mlir.tfr.python import composite +from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module as tfr_gen +from tensorflow.compiler.mlir.tfr.resources import gen_test_ops as test_ops +from tensorflow.python.ops import gen_array_ops as array_ops +from tensorflow.python.ops import gen_math_ops as math_ops +from tensorflow.python.platform import test + + +Composite = composite.Composite + +#--- test fn for mlir location --- + + +@Composite('TestInputNOp') +def _tfr_loc_test(x): + n = 10 + x_sum = x[0] + for i in range(1, n): + x_sum = math_ops.Add(x_sum, x[i]) + return x_sum + + +#--- test fn for tfr tensors --- + + +@composite.Composite('TestNoOp') +def _tfr_tensor_empty_arg(): + pass + + +@composite.Composite('TestIdentityOp') +def _tfr_tensor_tensor(x): + return x + + +@composite.Composite('TestIdentityNOp') +def _tfr_tensor_tensor_list(x): + return x + + +@composite.Composite('TestInputNOp') +def _tfr_tensor_tensor_list_get_elt(x): + return x[1] + + +@composite.Composite('TestOutputNOp') +def _tfr_tensor_tensor_list_output(x): + return [x, x] + + +@composite.Composite('TestTwoInputsOp') +def _tfr_tensor_tensor_list_split(x, y, pred): + z, _ = array_ops.Split(axis=0, value=x, num_split=2) + (y, pred) # pylint: disable=pointless-statement + return z + + +@composite.Composite('TestNumAttrsOp') +def _tfr_tensor_tensor_with_cst(x1, y1, x2, y2): + x = array_ops.OneHot( + indices=[0, 2, -1, x1], depth=y1, on_value=True, off_value=False) + (x, x2, y2) # pylint: disable=pointless-statement + return + + +@composite.Composite('TestTwoOutputsOp') +def _tfr_tensor_two_output(x): + z = array_ops.Split(axis=0, value=x, num_split=2) + return z[0], z[1] + + +#--- test fn for scf control flow --- + + +@composite.Composite('TestTwoInputsOp') +def _tfr_control_flow_if(x, y, pred): + if pred: + return x + else: + return y + + 
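+# The nested if/elif branches below lower to nested scf.if regions; see the
+# FileCheck expectations in TFRGenTensorTest.test_tfr_control_flow.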
+@composite.Composite('TestThreeInputsOp') +def _tfr_control_flow_nested_if(x, y, z, select): + if select == 'x': + return x + elif select == 'y': + return y + else: + return z + + +@composite.Composite('TestInputNOp') +def _tfr_control_flow_range_for(x): + # TODO(fengliuai): use len(x) instead + n = 10 + x_sum = x[0] + for i in range(1, n): + x_sum = math_ops.Add(x_sum, x[i]) + return x_sum + + +#--- test fn for tf ops --- + + +@composite.Composite('TestComplexTFOp') +def _tfr_tf_ops_complex(lhs, rhs): + left_padding, _ = array_ops.SplitV( + value=lhs, size_splits=[rhs, -1], axis=0, num_split=2) + _, right_padding = array_ops.SplitV( + value=lhs, size_splits=[rhs, rhs], axis=1, num_split=2) + return [left_padding, right_padding] + + +@composite.Composite('TestIdentityOp') +def _tfr_tf_ops_tensor(x): + return array_ops.Identity(x) + + +@composite.Composite('TestTwoInputsOp') +def _tfr_tf_ops_tensors(x, y, pred): + if pred: + return math_ops.Add(x, y) + else: + return array_ops.Concat(0, [x, y]) + + +@composite.Composite('TestInputNOp') +def _tfr_tf_ops_with_defaults(ins): + return test_ops.TestTwoInputsOp(ins[0], ins[1]) + + +#--- test fn for tfr attributes --- + + +@composite.Composite('TestNumAttrsOp') +def _tfr_attrs_num_type(x, y, x1, y1): + # int + z0 = [x, y] + z1 = x == y + z2 = x < y + z3 = x <= y + z4 = x > y + z5 = x >= y + z6 = x != y + z7 = x + y + z8 = x - y + z8 += x + z8 += 1 + (z0, z1, z2, z3, z4, z5, z6, z7, z8) # pylint: disable=pointless-statement + + # float + z9 = x1 > y1 + z10 = x1 + y1 + z11 = [x1, y1] + (z9, z10, z11) # pylint: disable=pointless-statement + return + + +@composite.Composite('TestNonNumAttrsOp') +def _tfr_attrs_tfr_type(x, y, z): + z1 = x == y + z2 = x == 'test' + z3 = y == z + (z1, z2, z3) # pylint: disable=pointless-statement + return + + +#--- test fn for shapes --- + + +@composite.Composite('TestIdentityOp') +def _tfr_shapes(x): + s1 = x.shape + s3 = x.shape.as_list() + + for i in range(len(s3)): + s3[i] # pylint: disable=pointless-statement + + for i in range(1, len(s3), 2): + s3[i] # pylint: disable=pointless-statement + + s5 = array_ops.Shape(x) + (s1, s3, s5) # pylint: disable=pointless-statement + return x + + +class TFRGenTestBase(test.TestCase): + + def _check_code(self, tfr_code, exp_tfr_code): + return self.assertTrue(fw.check(str(tfr_code), exp_tfr_code), str(tfr_code)) + + +class TFRGenTensorTest(TFRGenTestBase): + """MLIR Generation Tests for MLIR TFR Program.""" + + def test_tfr_loc(self): + mlir_code = tfr_gen(sys.modules[__name__], '_tfr_loc', [test_ops]) + mlir_code_exp = r""" + CHECK-LABEL: tfr.func @tf__test_input_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor) { + CHECK-NEXT: %[[n:.*]] = constant 10 : i64 + CHECK-SAME loc("tfr_gen_test.py":%{{.*}}:6) + CHECK-NEXT: %[[cst:.*]] = constant 0 : index + CHECK-SAME loc("tfr_gen_test.py":%[[sum_line:.*]]:10) + CHECK-NEXT: %[[elt:.*]] = tfr.get_element %x[%[[cst]]] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-SAME loc("tfr_gen_test.py":%[[sum_line]]:10) + CHECK-NEXT: %[[cst_1:.*]] = constant 1 : i64 + CHECK-SAME loc("tfr_gen_test.py":%[[for_line:.*]]:2) + CHECK-NEXT: %[[begin:.*]] = index_cast %[[cst_1]] : i64 to index + CHECK-SAME loc("tfr_gen_test.py":%[[for_line]]:2) + CHECK-NEXT: %[[end:.*]] = index_cast %[[n]] : i64 to index + CHECK-SAME loc("tfr_gen_test.py":%[[for_line]]:2) + CHECK-NEXT: %[[step:.*]] = constant 1 : index + CHECK-SAME loc("tfr_gen_test.py":%[[for_line]]:2) + CHECK-NEXT: %[[for_stmt:.*]] = scf.for %[[itr_1:.*]] = %[[begin]] to %[[end]] step %[[step]] + 
CHECK-SAME: iter_args(%[[it_arg:.*]] = %[[elt]]) -> (!tfr.tensor) { + CHECK-NEXT: %[[elt_1:.*]] = tfr.get_element %x[%itr_1] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-SAME loc("tfr_gen_test.py":%[[add_line:.*]]:34) + CHECK-NEXT: %[[Add:.*]] = tfr.call @tf__add(%[[it_arg]], %[[elt_1]]) : (!tfr.tensor, !tfr.tensor) -> (!tfr.tensor) + CHECK-SAME loc("tfr_gen_test.py":%[[add_line]]:12) + CHECK-NEXT: scf.yield %[[Add]] : !tfr.tensor + CHECK-SAME loc(unknown) + CHECK-NEXT: } + CHECK-SAME loc("tfr_gen_test.py":%[[for_line]]:2) + CHECK-NEXT: %{{.*}} = constant true + CHECK-SAME loc(unknown) + CHECK-NEXT: tfr.return %[[for_stmt]] : !tfr.tensor + CHECK-SAME loc(unknown) + CHECK-NEXT: } + CHECK-SAME loc("tfr_gen_test.py":%{{def_line:.*}}:0) + """ + self._check_code(mlir_code, mlir_code_exp) + + def test_tfr_tensors(self): + mlir_code = tfr_gen(sys.modules[__name__], '_tfr_tensor', [test_ops]) + mlir_code_exp = r""" + CHECK-LABEL: tfr.func @tf__test_no_op() -> () { + CHECK-NEXT: tfr.return + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) { + CHECK-NEXT: constant true + CHECK-NEXT: tfr.return %x : !tfr.tensor + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_identity_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor_list) { + CHECK-NEXT: constant true + CHECK-NEXT: tfr.return %x : !tfr.tensor_list + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_input_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor) { + CHECK-NEXT: constant true + CHECK-NEXT: %[[index:.*]] = constant 1 : index + CHECK-NEXT: %[[sub:.*]] = tfr.get_element %x[%cst_1] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: tfr.return %[[sub]] : !tfr.tensor + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_output_n_op(%x: !tfr.tensor) -> (!tfr.tensor_list) { + CHECK-NEXT: constant true + CHECK-NEXT: %[[list:.*]] = "tfr.build_list"(%x, %x) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list + CHECK-NEXT: tfr.return %[[list]] : !tfr.tensor_list + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_two_inputs_op(%x: !tfr.tensor, %y: !tfr.tensor, %pred: i1{tfr.name="pred",tfr.default=false}) -> (!tfr.tensor) { + CHECK-NEXT: %[[cst:.*]] = constant 0 : i64 + CHECK-NEXT: %[[cst_1:.*]] = constant 2 : i64 + CHECK-NEXT: %[[cst_2:.*]] = "tfr.constant_tensor"(%[[cst]]) : (i64) -> !tfr.tensor + CHECK-NEXT: %[[Split:.*]] = tfr.call @tf__split(%[[cst_2]], %x, %[[cst_1]]) : (!tfr.tensor, !tfr.tensor, i64) -> (!tfr.tensor_list) + CHECK-NEXT: %[[cst_4:.*]] = constant 0 : index + CHECK-NEXT: %[[elt:.*]] = tfr.get_element %[[Split]][%idx] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[cst_5:.*]] = constant 1 : index + CHECK-NEXT: %[[elt_1:.*]] = tfr.get_element %[[Split]][%idx_1] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: constant true + CHECK-NEXT: tfr.return %[[elt]] : !tfr.tensor + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_num_attrs_op(%x1: i64{tfr.name="x1",tfr.default=-10}, %y1: i64{tfr.name="y1",tfr.default=1}, %x2: f32{tfr.name="x2",tfr.default=0.0}, %y2: f32{tfr.name="y2",tfr.default=-3.0}) -> () { + CHECK-NEXT: %[[cst:.*]] = constant 0 : i64 + CHECK-NEXT: %[[cst_1:.*]] = constant 2 : i64 + CHECK-NEXT: %[[cst_2:.*]] = constant 1 : i64 + CHECK-NEXT: %[[zero:.*]] = constant 0 : i64 + CHECK-NEXT: %[[cst_3:.*]] = subi %zero, %cst_2 : i64 + CHECK-NEXT: %[[list:.*]] = "tfr.build_list"(%[[cst]], %[[cst_1]], %[[cst_3]], %x1) : (i64, i64, i64, i64) -> !tfr.attr + CHECK-NEXT: %[[cst_4:.*]] = constant true + CHECK-NEXT: %[[cst_5:.*]] = constant false + CHECK-NEXT: %[[cst_6:.*]] = 
"tfr.constant_tensor"(%[[list]]) : (!tfr.attr) -> !tfr.tensor + CHECK-NEXT: %[[cst_7:.*]] = "tfr.constant_tensor"(%y1) : (i64) -> !tfr.tensor + CHECK-NEXT: %[[cst_8:.*]] = "tfr.constant_tensor"(%[[cst_4]]) : (i1) -> !tfr.tensor + CHECK-NEXT: %[[cst_9:.*]] = "tfr.constant_tensor"(%[[cst_5]]) : (i1) -> !tfr.tensor + CHECK-NEXT: %[[cst_10:.*]] = constant -1 : i64 + CHECK-NEXT: %[[OneHot:.*]] = tfr.call @tf__one_hot(%[[cst_6]], %[[cst_7]], %[[cst_8]], %[[cst_9]], %[[cst_10]]) + CHECK-SAME: (!tfr.tensor, !tfr.tensor, !tfr.tensor, !tfr.tensor, i64) -> (!tfr.tensor) + CHECK-NEXT: constant true + CHECK-NEXT: tfr.return + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_two_outputs_op(%x: !tfr.tensor) -> (!tfr.tensor, !tfr.tensor) { + CHECK-NEXT: %[[cst:.*]] = constant 0 : i64 + CHECK-NEXT: %[[cst_1:.*]] = constant 2 : i64 + CHECK-NEXT: %[[cst_2:.*]] = "tfr.constant_tensor"(%[[cst]]) : (i64) -> !tfr.tensor + CHECK-NEXT: %[[Split:.*]] = tfr.call @tf__split(%[[cst_2]], %x, %[[cst_1]]) : (!tfr.tensor, !tfr.tensor, i64) -> (!tfr.tensor_list) + CHECK-NEXT: constant true + CHECK-NEXT: %[[cst_4:.*]] = constant 0 : index + CHECK-NEXT: %[[elt:.*]] = tfr.get_element %[[Split]][%cst_4] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[cst_5:.*]] = constant 1 : index + CHECK-NEXT: %[[elt_1:.*]] = tfr.get_element %[[Split]][%cst_5] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: tfr.return %[[elt]], %[[elt_1]] : !tfr.tensor, !tfr.tensor + CHECK-NEXT: } + """ + self._check_code(mlir_code, mlir_code_exp) + + def test_tfr_control_flow(self): + mlir_code = tfr_gen(sys.modules[__name__], '_tfr_control_flow', [test_ops]) + mlir_code_exp = r""" + CHECK-LABEL: tfr.func @tf__test_two_inputs_op(%x: !tfr.tensor, %y: !tfr.tensor, + CHECK-SAME: %pred: i1{tfr.name="pred",tfr.default=false}) -> (!tfr.tensor) { + CHECK-NEXT: %[[if:.*]] = scf.if %pred -> (!tfr.tensor) { + CHECK-NEXT: constant true + CHECK-NEXT: scf.yield %x : !tfr.tensor + CHECK-NEXT: } else { + CHECK-NEXT: constant true + CHECK-NEXT: scf.yield %y : !tfr.tensor + CHECK-NEXT: } + CHECK-NEXT: tfr.return %if_stmt : !tfr.tensor + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_three_inputs_op(%x: !tfr.tensor, %y: !tfr.tensor, %z: !tfr.tensor, + CHECK-SAME: %select: !tfr.attr{tfr.name="act",tfr.default="z"}) -> (!tfr.tensor) { + CHECK-NEXT: %[[cst:.*]] = tfr.constant "x" -> !tfr.attr + CHECK-NEXT: %[[eq:.*]] = tfr.equal %select, %[[cst]] -> i1 + CHECK-NEXT: %[[if_stmt:.*]] = scf.if %[[eq]] -> (!tfr.tensor) { + CHECK-NEXT: %[[cst_1:.*]] = constant true + CHECK-NEXT: scf.yield %x : !tfr.tensor + CHECK-NEXT: } else { + CHECK-NEXT: %[[cst_2:.*]] = tfr.constant "y" -> !tfr.attr + CHECK-NEXT: %[[eq_1:.*]] = tfr.equal %select, %[[cst_2]] -> i1 + CHECK-NEXT: %[[if_stmt1:.*]] = scf.if %[[eq_1]] -> (!tfr.tensor) { + CHECK-NEXT: %[[cst_3:.*]] = constant true + CHECK-NEXT: scf.yield %y : !tfr.tensor + CHECK-NEXT: } else { + CHECK-NEXT: %[[cst_4:.*]] = constant true + CHECK-NEXT: scf.yield %z : !tfr.tensor + CHECK-NEXT: } + CHECK-NEXT: scf.yield %[[if_stmt1]] : !tfr.tensor + CHECK-NEXT: } + CHECK-NEXT: tfr.return %[[if_stmt]] : !tfr.tensor + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_input_n_op(%x: !tfr.tensor_list) -> (!tfr.tensor) { + CHECK-NEXT: %[[n:.*]] = constant 10 : i64 + CHECK-NEXT: %[[cst:.*]] = constant 0 : index + CHECK-NEXT: %[[elt:.*]] = tfr.get_element %x[%[[cst]]] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[cst_1:.*]] = constant 1 : i64 + CHECK-NEXT: %[[begin:.*]] = index_cast %[[cst_1]] : i64 to index + CHECK-NEXT: 
%[[end:.*]] = index_cast %[[n]] : i64 to index + CHECK-NEXT: %[[step:.*]] = constant 1 : index + CHECK-NEXT: %[[for_stmt:.*]] = scf.for %[[itr_1:.*]] = %[[begin]] to %[[end]] step %[[step]] + CHECK-SAME: iter_args(%[[it_arg:.*]] = %[[elt]]) -> (!tfr.tensor) { + CHECK-NEXT: %[[elt_1:.*]] = tfr.get_element %x[%itr_1] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[Add:.*]] = tfr.call @tf__add(%[[it_arg]], %[[elt_1]]) : (!tfr.tensor, !tfr.tensor) -> (!tfr.tensor) + CHECK-NEXT: scf.yield %[[Add]] : !tfr.tensor + CHECK-NEXT: } + CHECK-NEXT: %{{.*}} = constant true + CHECK-NEXT: tfr.return %[[for_stmt]] : !tfr.tensor + CHECK-NEXT: } + """ + self._check_code(mlir_code, mlir_code_exp) + + def test_tfr_tf_ops(self): + mlir_code = tfr_gen(sys.modules[__name__], '_tfr_tf_ops', [test_ops]) + mlir_code_exp = r""" + CHECK-LABEL: tfr.func @tf__test_complex_tf_op(%lhs: !tfr.tensor, %rhs: !tfr.tensor) -> (!tfr.tensor_list) { + CHECK-NEXT: %[[cst:.*]] = constant 1 : i64 + CHECK-NEXT: %[[zero:.*]] = constant 0 : i64 + CHECK-NEXT: %[[cst_1:.*]] = subi %[[zero]], %cst : i64 + CHECK-NEXT: %[[cst_2:.*]] = "tfr.constant_tensor"(%[[cst_1]]) : (i64) -> !tfr.tensor + CHECK-NEXT: %[[list:.*]] = "tfr.build_list"(%rhs, %[[cst_2]]) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list + CHECK-NEXT: %[[cst_3:.*]] = constant 0 : i64 + CHECK-NEXT: %[[cst_4:.*]] = constant 2 : i64 + CHECK-NEXT: %[[zero_1:.*]] = constant 0 : i64 + CHECK-NEXT: %[[pack:.*]] = tfr.call @tf__pack(%[[list]], %[[zero_1]]) : (!tfr.tensor_list, i64) -> !tfr.tensor + CHECK-NEXT: %[[cst_5:.*]] = "tfr.constant_tensor"(%[[cst_3]]) : (i64) -> !tfr.tensor + CHECK-NEXT: %[[SplitV:.*]] = tfr.call @tf__split_v(%lhs, %[[pack]], %[[cst_5]], %[[cst_4]]) + CHECK-NEXT: %[[idx:.*]] = constant 0 : index + CHECK-NEXT: %[[elt:.*]] = tfr.get_element %SplitV[%idx] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[idx_1:.*]] = constant 1 : index + CHECK-NEXT: %[[elt_1:.*]] = tfr.get_element %SplitV[%idx_1] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[list_1:.*]] = "tfr.build_list"(%rhs, %rhs) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list + CHECK-NEXT: %[[cst_6:.*]] = constant 1 : i64 + CHECK-NEXT: %[[cst_7:.*]] = constant 2 : i64 + CHECK-NEXT: %[[zero_2:.*]] = constant 0 : i64 + CHECK-NEXT: %[[pack_1:.*]] = tfr.call @tf__pack(%[[list_1]], %[[zero_2]]) : (!tfr.tensor_list, i64) -> !tfr.tensor + CHECK-NEXT: %[[cst_8:.*]] = "tfr.constant_tensor"(%[[cst_6]]) : (i64) -> !tfr.tensor + CHECK-NEXT: %[[SplitV_1:.*]] = tfr.call @tf__split_v(%lhs, %[[pack_1]], %[[cst_8]], %[[cst_7]]) + CHECK-NEXT: %[[idx_2:.*]] = constant 0 : index + CHECK-NEXT: %[[elt_2:.*]] = tfr.get_element %SplitV_1[%idx_2] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[idx_3:.*]] = constant 1 : index + CHECK-NEXT: %[[elt_3:.*]] = tfr.get_element %SplitV_1[%idx_3] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[cst_9:.*]] = constant true + CHECK-NEXT: %[[list_2:.*]] = "tfr.build_list"(%[[elt]], %[[elt_3]]) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list + CHECK-NEXT: tfr.return %[[list_2]] : !tfr.tensor_list + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) { + CHECK-NEXT: %cst = constant true + CHECK-NEXT: %[[Id:.*]] = tfr.call @tf__identity(%x) : (!tfr.tensor) -> (!tfr.tensor) + CHECK-NEXT: tfr.return %[[Id]] : !tfr.tensor + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_two_inputs_op(%x: !tfr.tensor, %y: !tfr.tensor, + CHECK-SAME: %pred: i1{tfr.name="pred",tfr.default=false}) -> (!tfr.tensor) { + 
CHECK-NEXT: %[[if_stmt:.*]] = scf.if %pred -> (!tfr.tensor) { + CHECK-NEXT: %cst = constant true + CHECK-NEXT: %[[Add:.*]] = tfr.call @tf__add(%x, %y) : (!tfr.tensor, !tfr.tensor) -> (!tfr.tensor) + CHECK-NEXT: scf.yield %[[Add]] : !tfr.tensor + CHECK-NEXT: } else { + CHECK-NEXT: %cst_1 = constant true + CHECK-NEXT: %[[cst_2:.*]] = constant 0 : i64 + CHECK-NEXT: %[[list:.*]] = "tfr.build_list"(%x, %y) : (!tfr.tensor, !tfr.tensor) -> !tfr.tensor_list + CHECK-NEXT: %[[Concat:.*]] = tfr.call @tf__concat(%[[cst_2]], %[[list]]) : (i64, !tfr.tensor_list) -> (!tfr.tensor) + CHECK-NEXT: scf.yield %[[Concat]] : !tfr.tensor + CHECK-NEXT: } + CHECK-NEXT: tfr.return %[[if_stmt]] : !tfr.tensor + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_input_n_op(%ins: !tfr.tensor_list) -> (!tfr.tensor) { + CHECK-NEXT: %cst = constant true + CHECK-NEXT: %[[cst_1:.*]] = constant 0 : index + CHECK-NEXT: %[[elt:.*]] = tfr.get_element %ins[%cst_1] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[cst_2:.*]] = constant 1 : index + CHECK-NEXT: %[[elt_1:.*]] = tfr.get_element %ins[%cst_2] : (!tfr.tensor_list, index) -> !tfr.tensor + CHECK-NEXT: %[[cst_3:.*]] = constant false + CHECK-NEXT: %[[call:.*]] = tfr.call @tf__test_two_inputs_op( + CHECK-SAME: %[[elt]], %[[elt_1]], %[[cst_3]]) : (!tfr.tensor, !tfr.tensor, i1) -> (!tfr.tensor) + CHECK-NEXT: tfr.return %[[call]] : !tfr.tensor + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__add_(!tfr.tensor,!tfr.tensor) -> (!tfr.tensor) attributes {T} + + CHECK-LABEL: tfr.func @tf__concat_(!tfr.tensor,!tfr.tensor_list) -> (!tfr.tensor) attributes {N,T,i32_} + + CHECK-LABEL: tfr.func @tf__identity_(!tfr.tensor) -> (!tfr.tensor) attributes {T} + + CHECK-LABEL: tfr.func @tf__pack_(!tfr.tensor_list,i64{tfr.name="axis"}) -> (!tfr.tensor) attributes {N,T,axis} + + CHECK-LABEL: tfr.func @tf__split_v_(!tfr.tensor,!tfr.tensor,!tfr.tensor,i64{tfr.name="num_split"}) -> (!tfr.tensor_list) attributes {T,Tlen,i32_,num_split} + + CHECK-LABEL: tfr.func @tf__test_two_inputs_op_(!tfr.tensor,!tfr.tensor,i1{tfr.name="pred"}) -> (!tfr.tensor) attributes {T,pred} + + CHECK-LABEL: tfr.func @tf__test_complex_tf_op_(!tfr.tensor,!tfr.tensor,i64{tfr.name="N"}) -> (!tfr.tensor_list) attributes {N,T,Tlen} + + CHECK-LABEL: tfr.func @tf__test_identity_op_(!tfr.tensor) -> (!tfr.tensor) attributes {T} + + CHECK-LABEL: tfr.func @tf__test_two_inputs_op_(!tfr.tensor,!tfr.tensor,i1{tfr.name="pred"}) -> (!tfr.tensor) attributes {T,pred} + + CHECK-LABEL: tfr.func @tf__test_input_n_op_(!tfr.tensor_list) -> (!tfr.tensor) attributes {N,T} + """ + self._check_code(mlir_code, mlir_code_exp) + + def test_tfr_attrs(self): + mlir_code = tfr_gen(sys.modules[__name__], '_tfr_attrs', [test_ops]) + mlir_code_exp = r""" + CHECK-LABEL: tfr.func @tf__test_num_attrs_op( + CHECK-SAME: %x: i64{tfr.name="x1",tfr.default=-10}, + CHECK-SAME: %y: i64{tfr.name="y1",tfr.default=1}, + CHECK-SAME: %x1: f32{tfr.name="x2",tfr.default=0.0}, + CHECK-SAME: %y1: f32{tfr.name="y2",tfr.default=-3.0}) -> () { + CHECK-NEXT: %{{.*}} = "tfr.build_list"(%x, %y) : (i64, i64) -> !tfr.attr + CHECK-NEXT: %{{.*}} = cmpi "eq", %x, %y : i64 + CHECK-NEXT: %{{.*}} = cmpi "ult", %x, %y : i64 + CHECK-NEXT: %{{.*}} = cmpi "ule", %x, %y : i64 + CHECK-NEXT: %{{.*}} = cmpi "ugt", %x, %y : i64 + CHECK-NEXT: %{{.*}} = cmpi "uge", %x, %y : i64 + CHECK-NEXT: %{{.*}} = cmpi "ne", %x, %y : i64 + CHECK-NEXT: %{{.*}} = addi %x, %y : i64 + CHECK-NEXT: %{{.*}} = subi %x, %y : i64 + CHECK-NEXT: %[[add_1:.*]] = addi %sub, %x : i64 + CHECK-NEXT: %[[cst:.*]] = constant 1 : 
i64 + CHECK-NEXT: %{{.*}} = addi %[[add_1]], %[[cst]] : i64 + CHECK-NEXT: %{{.*}} = cmpf "ugt", %x1, %y1 : f32 + CHECK-NEXT: %{{.*}} = addf %x1, %y1 : f32 + CHECK-NEXT: %{{.*}} = "tfr.build_list"(%x1, %y1) : (f32, f32) -> !tfr.attr + CHECK-NEXT: %{{.*}} = constant true + CHECK-NEXT: tfr.return + CHECK-NEXT: } + + CHECK-LABEL: tfr.func @tf__test_non_num_attrs_op( + CHECK-SAME: %x: !tfr.attr{tfr.name="z"}, + CHECK-SAME: %y: !tfr.attr{tfr.name="x",tfr.default="hello"}, + CHECK-SAME: %z: !tfr.attr{tfr.name="y",tfr.default=f32}) -> () { + CHECK-NEXT: %{{.*}} = tfr.equal %x, %y -> i1 + CHECK-NEXT: %[[cst:.*]] = tfr.constant "test" -> !tfr.attr + CHECK-NEXT: %{{.*}} = tfr.equal %x, %[[cst]] -> i1 + CHECK-NEXT: %{{.*}} = tfr.equal %y, %z -> i1 + CHECK-NEXT: %{{.*}} = constant true + CHECK-NEXT: tfr.return + CHECK-NEXT: } + """ + self._check_code(mlir_code, mlir_code_exp) + + def test_tf_tensor_shape(self): + mlir_code = tfr_gen(sys.modules[__name__], '_tfr_shapes', [test_ops]) + mlir_code_exp = r""" + CHECK-LABEL: tfr.func @tf__test_identity_op(%x: !tfr.tensor) -> (!tfr.tensor) { + CHECK-NEXT: %[[shape:.*]] = tfr.get_shape %x -> !shape.shape + + CHECK-NEXT: %[[shape_1:.*]] = tfr.get_shape %x -> !shape.shape + CHECK-NEXT: %[[len:.*]] = shape.rank %[[shape_1]] : !shape.shape -> !shape.size + CHECK-NEXT: %[[index:.*]] = shape.size_to_index %[[len]] : !shape.size + CHECK-NEXT: %[[begin:.*]] = constant 0 : index + CHECK-NEXT: %[[step:.*]] = constant 1 : index + CHECK-NEXT: scf.for %[[itr_1:.*]] = %[[begin]] to %[[index]] step %[[step]] { + CHECK-NEXT: %[[size:.*]] = shape.get_extent %[[shape_1]], %[[itr_1]]: !shape.shape, index -> !shape.size + CHECK-NEXT: %[[elt:.*]] = shape.size_to_index %[[size]] : !shape.size + CHECK-NEXT: scf.yield + CHECK-NEXT: } + + CHECK-NEXT: %[[cst:.*]] = constant 1 : i64 + CHECK-NEXT: %[[len_1:.*]] = shape.rank %shape_1 : !shape.shape -> !shape.size + CHECK-NEXT: %[[len_size_1:.*]] = shape.size_to_index %[[len_1]] : !shape.size + CHECK-NEXT: %[[cst_1:.*]] = constant 2 : i64 + CHECK-NEXT: %[[begin_1:.*]] = index_cast %[[cst]] : i64 to index + CHECK-NEXT: %[[step_1:.*]] = index_cast %[[cst_1]] : i64 to index + CHECK-NEXT: scf.for %[[itr_3:.*]] = %[[begin_1]] to %[[len_size_1]] step %[[step_1]] + + CHECK: %[[cst:.*]] = tfr.constant i32 -> !tfr.attr + CHECK-NEXT: %[[Shape:.*]] = tfr.call @tf__shape(%x, %[[cst]]) : (!tfr.tensor, !tfr.attr) -> (!tfr.tensor) + CHECK-NEXT: %{{.*}} = constant true + CHECK-NEXT: tfr.return %x : !tfr.tensor + CHECK-NEXT: } + """ + self._check_code(mlir_code, mlir_code_exp) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc b/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc new file mode 100644 index 00000000000..b7372cffe2d --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/python/tfr_wrapper.cc @@ -0,0 +1,58 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/AsmState.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" +#include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tensorflow/python/lib/core/pybind11_status.h" + +PYBIND11_MODULE(tfr_wrapper, m) { + m.def("verify", [](std::string input) { + mlir::MLIRContext ctx(/*loadAllDialects=*/true); + auto& registry = ctx.getDialectRegistry(); + registry.insert(); + ctx.getDialectRegistry().loadAll(&ctx); + + llvm::SourceMgr source_mgr = llvm::SourceMgr(); + source_mgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input), + llvm::SMLoc()); + auto module = mlir::parseSourceFile(source_mgr, &ctx); + if (!module) { + return false; + } + + mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &ctx); + if (failed(mlir::verify(*module))) { + module->emitError("Invalid MLIR module: failed verification."); + return false; + } + return true; + }); +} diff --git a/tensorflow/compiler/mlir/tfr/resources/BUILD b/tensorflow/compiler/mlir/tfr/resources/BUILD index 6abb3917e57..62ca65c5b57 100644 --- a/tensorflow/compiler/mlir/tfr/resources/BUILD +++ b/tensorflow/compiler/mlir/tfr/resources/BUILD @@ -1,4 +1,5 @@ -load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") package( default_visibility = [ @@ -24,10 +25,6 @@ filegroup( cc_library( name = "composite_ops_cc", srcs = ["composite_ops.cc"], - copts = [ - "-Wno-unused-result", - "-Wno-unused-variable", - ], deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -36,22 +33,34 @@ cc_library( alwayslink = 1, ) +tf_custom_op_library( + name = "composite_ops.so", + srcs = [ + "composite_ops.cc", + ], +) + tf_gen_op_wrapper_py( - name = "composite_ops", - out = "composite_ops.py", + name = "gen_composite_ops", + out = "gen_composite_ops.py", deps = [ ":composite_ops_cc", ], ) +tf_custom_op_py_library( + name = "composite_ops", + dso = [":composite_ops.so"], + kernels = [":composite_ops_cc"], + visibility = ["//visibility:public"], + deps = [ + ":gen_composite_ops", + ], +) + cc_library( name = "test_ops_cc", - testonly = 1, srcs = ["test_ops.cc"], - copts = [ - "-Wno-unused-result", - "-Wno-unused-variable", - ], deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -59,3 +68,30 @@ cc_library( ], alwayslink = 1, ) + +tf_custom_op_library( + name = "test_ops.so", + srcs = [ + "test_ops.cc", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_test_ops", + out = "gen_test_ops.py", + deps = [ + ":test_ops_cc", + ], +) + +tf_custom_op_py_library( + name = "test_ops", + dso = ["test_ops.so"], + kernels = [ + ":test_ops_cc", + ], + srcs_version = "PY2AND3", + deps = [ + 
":gen_test_ops", + ], +) diff --git a/tensorflow/compiler/mlir/tfr/resources/test_ops.cc b/tensorflow/compiler/mlir/tfr/resources/test_ops.cc index bff7fe0c18c..3aaa0850805 100644 --- a/tensorflow/compiler/mlir/tfr/resources/test_ops.cc +++ b/tensorflow/compiler/mlir/tfr/resources/test_ops.cc @@ -50,6 +50,14 @@ REGISTER_OP("TestTwoInputsOp") .Attr("T: numbertype") .Attr("pred: bool = false"); +REGISTER_OP("TestComplexTFOp") + .Input("lhs: T") + .Input("rhs: Tlen") + .Output("output: N * T") + .Attr("N: int >= 1") + .Attr("T: numbertype") + .Attr("Tlen: {int32, int64} = DT_INT64"); + REGISTER_OP("TestNumAttrsOp") .Attr("x1: int = -10") .Attr("y1: int = 1") diff --git a/tensorflow/compiler/mlir/tfr/tfr.bzl b/tensorflow/compiler/mlir/tfr/tfr.bzl new file mode 100644 index 00000000000..cc1b617f932 --- /dev/null +++ b/tensorflow/compiler/mlir/tfr/tfr.bzl @@ -0,0 +1,120 @@ +"""BUILD extension for TF composition project.""" + +load("//tensorflow:tensorflow.bzl", "py_binary", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.google.bzl", "pytype_library") + +def gen_op_libraries( + name, + src, + deps, + tags = [], + test = False): + """gen_op_libraries() generates all cc and py libraries for composite op source. + + Args: + name: used as the name component of all the generated libraries. + src: File contains the composite ops. + deps: Libraries the 'src' depends on. + tags: + test: + """ + if not src.endswith(".py") or name == src[:-3]: + fail("'src' %s conflicts with op Python wrapper. Rename it to be different from 'name'." % src) + + gen_op_lib_exec = src[:-3] + py_binary( + name = gen_op_lib_exec, + srcs = [src], + srcs_version = "PY2AND3", + python_version = "PY3", + deps = [ + "//tensorflow/python:platform", + ] + deps, + ) + + register_op = "register_" + name + native.genrule( + name = register_op, + srcs = [], + outs = [name + ".inc.cc"], + cmd = "$(location %s) --output=$@ --gen_register_op=true" % gen_op_lib_exec, + exec_tools = [":" + gen_op_lib_exec], + local = 1, + tags = tags, + ) + + native.cc_library( + name = name + "_cc", + testonly = test, + srcs = [":" + register_op], + copts = [ + "-Wno-unused-result", + "-Wno-unused-variable", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, + ) + + tf_gen_op_wrapper_py( + name = name, + out = name + ".py", + deps = [ + ":%s_cc" % name, + ], + ) + + pytype_library( + name = name + "_grads", + srcs = [ + src, + ], + srcs_version = "PY2AND3", + deps = [ + "//third_party/py/numpy", + "//third_party/py/tensorflow", + ] + deps, + ) + + pytype_library( + name = name + "_lib", + srcs = [ + name + ".py", + ], + srcs_version = "PY2AND3", + deps = [ + ":%s" % name, + ":%s_cc" % name, + ":%s_grads" % name, + "//third_party/py/numpy", + "//third_party/py/tensorflow", + ] + deps, + ) + + # Link the register op and rebuild the binary + gen_tfr_lib_exec = gen_op_lib_exec + "_registered" + py_binary( + name = gen_tfr_lib_exec, + main = src, + srcs = [src], + srcs_version = "PY2AND3", + python_version = "PY3", + deps = [ + "//tensorflow/python:platform", + ":%s" % name + "_cc", + ] + deps, + ) + + op_tfr = "composite_" + name + native.genrule( + name = op_tfr, + srcs = [], + outs = [name + ".mlir"], + cmd = "$(location %s) --output=$@ --gen_register_op=false" % gen_tfr_lib_exec, + exec_tools = [":" + gen_tfr_lib_exec], + local = 1, + tags = tags, + ) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index 
f2e0e875de0..b4cbf765c79 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -1,8 +1,8 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", + "get_compatible_with_cloud", "tf_cc_binary", - "tf_cc_test", ) load( "//tensorflow/core/platform/default:cuda_build_defs.bzl", @@ -42,10 +42,9 @@ cc_library( "//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu", - "//tensorflow/compiler/mlir/hlo:materialize_broadcasts", # buildcleaner: keep "//tensorflow/compiler/mlir/hlo:transform_unranked_hlo", # buildcleaner: keep - "//tensorflow/compiler/mlir/hlo:unfuse_batch_norm", # buildcleaner: keep "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", @@ -58,11 +57,8 @@ cc_library( "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", "//tensorflow/compiler/xla/service/mlir_gpu:kernel_lowering", "//tensorflow/compiler/xla/service/mlir_gpu:passes", - "//tensorflow/core:cuda_libdevice_path", "//tensorflow/core:lib", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", + "//tensorflow/core/platform:cuda_libdevice_path", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineToStandard", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", @@ -72,9 +68,7 @@ cc_library( "@llvm-project//mlir:GPUTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:LLVMTransforms", "@llvm-project//mlir:LinalgOps", - "@llvm-project//mlir:LinalgToLLVM", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Parser", @@ -85,8 +79,6 @@ cc_library( "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", - "@llvm-project//mlir:TargetNVVMIR", - "@llvm-project//mlir:TargetROCDLIR", "@llvm-project//mlir:Transforms", ], ) @@ -110,40 +102,13 @@ tf_cc_binary( ], ) -# We can test tf_to_gpu_binary by depending on the kernels that are generated -# using it. This only tests that the headers containing the hex dump of the -# generated GPU binaries are built successfully, but this is better than -# nothing. -# If this test fails, use the build flag --verbose_failures to see a full -# command line of the failing build command, then copy the command line and run -# it for further debugging. -tf_cc_test( - name = "tf_to_gpu_binary_test", - srcs = [], - # Building the MLIR generated kernels requires hexdump on the ROCM - # platform, which might not be available. 
- tags = ["no_rocm"], - deps = [ - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ] + if_cuda_is_configured([ - "//tensorflow/core/kernels/mlir_generated:abs_kernels", - "//tensorflow/core/kernels/mlir_generated:ceil_kernels", - "//tensorflow/core/kernels/mlir_generated:cos_kernels", - "//tensorflow/core/kernels/mlir_generated:exp_kernels", - "//tensorflow/core/kernels/mlir_generated:floor_kernels", - "//tensorflow/core/kernels/mlir_generated:log_kernels", - "//tensorflow/core/kernels/mlir_generated:neg_kernels", - "//tensorflow/core/kernels/mlir_generated:rsqrt_kernels", - "//tensorflow/core/kernels/mlir_generated:sqrt_kernels", - "//tensorflow/core/kernels/mlir_generated:tanh_kernels", - ]), -) - tf_cc_binary( name = "tf_to_kernel", srcs = ["tf_to_kernel.cc"], - visibility = ["//tensorflow/core/kernels/mlir_generated:__pkg__"], + visibility = [ + "//tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel:__pkg__", + "//tensorflow/core/kernels/mlir_generated:__pkg__", + ], deps = [ ":kernel_creator", "//tensorflow/compiler/mlir:init_mlir", @@ -194,3 +159,16 @@ cc_library( "@llvm-project//mlir:mlir_runner_utils", ], ) + +cc_library( + name = "tf_cuda_runtime_wrappers", + srcs = ["tf_cuda_runtime_wrappers.cc"], + compatible_with = get_compatible_with_cloud(), + copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]), + deps = [ + "//tensorflow/core/platform/default/build_config:stream_executor_cuda", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:mlir_c_runner_utils", + "@local_config_cuda//cuda:cuda_headers", + ], +) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc index 690ee39113c..6c95323ed37 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc @@ -21,14 +21,12 @@ limitations under the License. //===----------------------------------------------------------------------===// #include "tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project -#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" // from @llvm-project #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project #include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/GPU/ParallelLoopMapper.h" // from @llvm-project #include "mlir/Dialect/GPU/Passes.h" // from @llvm-project @@ -42,7 +40,7 @@ limitations under the License. 
#include "mlir/Parser.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project +#include "mlir/Transforms/Bufferize.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" @@ -71,12 +69,11 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors) { mlir::PassManager pm(module.getContext()); - applyPassManagerCLOptions(pm); - // TODO(b/169357508): renable when pipeline is serializable. - // SetCrashReproducer(pm); + applyTensorflowAndCLOptions(pm); - pm.addPass(mlir::mhlo::createLegalizeTFPass(false)); if (gpu_binary_only) { + pm.addPass(mlir::mhlo::createLegalizeTFPass( + /*allow_partial_conversion=*/false, /*legalize_chlo=*/true)); pm.addNestedPass( mlir::kernel_gen::transforms::CreateMaterializeBroadcastsPass()); pm.addNestedPass( @@ -88,6 +85,9 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, pm.addNestedPass(mlir::createCopyRemovalPass()); pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass()); } else { + pm.addPass(mlir::mhlo::createLegalizeTFPass( + /*allow_partial_conversion=*/false, /*legalize_chlo=*/false)); + pm.addPass(mlir::mhlo::createChloLegalizeToHloPass()); pm.addPass(mlir::createTransformUnrankedHloPass()); pm.addPass(mlir::kernel_gen::transforms::CreateShapeToDescriptorsPass()); pm.addPass(mlir::kernel_gen::transforms::CreateBufferizePass()); @@ -158,8 +158,6 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, // Approximate Tanh using standard operations. pm.addNestedPass<::mlir::FuncOp>( ::mlir::mhlo::createLegalizeTrigonometricToApproximationPass()); - // Move scalar operations into the launch to ensure smaller signatures. - pm.addPass(xla::mlir_gpu::createMoveScalarComputationsIntoGpuLaunchPass()); // Take launches to launches with kernels. pm.addPass(::mlir::createGpuKernelOutliningPass()); @@ -178,23 +176,21 @@ Status LowerTFtoGPU(mlir::ModuleOp module, bool gpu_binary_only, Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only, llvm::ArrayRef same_shape, llvm::StringRef gpu_binary_attr_name, - int32_t architecture) { + llvm::ArrayRef architectures, + bool generate_fatbin) { mlir::PassManager pm(module.getContext()); - applyPassManagerCLOptions(pm); - // TODO(b/169357508): renable when pipeline is serializable. - // SetCrashReproducer(pm); + applyTensorflowAndCLOptions(pm); auto& kernel_pm = pm.nest(); if (gpu_binary_only) { // Grab the original signature from the single function. 
- auto func = *module.getBody()->op_begin(); kernel_pm.addNestedPass( mlir::kernel_gen::transforms::CreatePropagateTensorFlowABIKnowledgePass( - func.getType(), same_shape)); + same_shape)); } kernel_pm.addPass(mlir::createStripDebugInfoPass()); kernel_pm.addPass(mlir::kernel_gen::transforms::CreateGpuKernelToBlobPass( - gpu_binary_attr_name, architecture)); + gpu_binary_attr_name, architectures, generate_fatbin)); if (!gpu_binary_only) { pm.addPass(mlir::kernel_gen::transforms::CreateTFKernelToLLVMPass()); @@ -209,9 +205,9 @@ Status LowerGPUToLLVM(mlir::ModuleOp module, bool gpu_binary_only, StatusOr GenerateKernelForTfCode( mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only, - int32_t architecture, llvm::ArrayRef tile_sizes, - llvm::ArrayRef same_shape, - llvm::ArrayRef unroll_factors) { + llvm::ArrayRef architectures, + llvm::ArrayRef tile_sizes, llvm::ArrayRef same_shape, + llvm::ArrayRef unroll_factors, bool generate_fatbin) { mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry()); mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context); TF_RETURN_IF_ERROR( @@ -228,7 +224,8 @@ StatusOr GenerateKernelForTfCode( TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get())); #endif TF_RETURN_IF_ERROR(LowerGPUToLLVM(module.get(), gpu_binary_only, same_shape, - kGpuBinaryAttrName, architecture)); + kGpuBinaryAttrName, architectures, + generate_fatbin)); return module; } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h index b168ec815de..6767944d539 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h @@ -38,9 +38,10 @@ namespace kernel_gen { // false, lowers the host side to LLVM Dialect. xla::StatusOr GenerateKernelForTfCode( mlir::MLIRContext& context, llvm::StringRef tf_code, bool gpu_binary_only, - int32_t architecture = 75, llvm::ArrayRef tile_sizes = {16, 64}, + llvm::ArrayRef architectures = {"sm_75"}, + llvm::ArrayRef tile_sizes = {16, 64}, llvm::ArrayRef same_shape = {}, - llvm::ArrayRef unroll_factors = {}); + llvm::ArrayRef unroll_factors = {}, bool generate_fatbin = true); // Extracts gpu_binary from the converted module. 
xla::StatusOr ExtractGpuBinary(mlir::ModuleOp module); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir index 762962fc8e6..1a278365464 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir @@ -60,3 +60,19 @@ func @dynamic_tensor_from_elements(%arg : tensor<*xf32>) -> tensor { return %result : tensor } +// CHECK-LABEL: @assuming +// CHECK-SAME: (%[[WITNESS:.*]]: !shape.witness, %[[ARG:.*]]: memref) +// CHECK-SAME: -> memref +func @assuming(%witness: !shape.witness, %arg : memref) + -> tensor { + // CHECK-NEXT: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] + // CHECK-SAME: -> (memref) { + // CHECK-NEXT: shape.assuming_yield %[[ARG]] : memref + // CHECK-NEXT: } + // CHECK-NEXT: return %[[ASSUMING_RESULT]] : memref + %assuming_result = shape.assuming %witness -> (tensor) { + %result = tensor_load %arg : memref + shape.assuming_yield %result : tensor + } + return %assuming_result : tensor +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-mlhlo.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-mlhlo.mlir index 64c7ff28d6e..2fc585d9e9d 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-mlhlo.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf-legalize-to-mlhlo.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s --xla-legalize-tf='legalize-chlo=false' | mlir-hlo-opt --transform-unranked-hlo --mhlo-test-chlo-legalize-to-hlo | kernel-gen-opt --shape-to-descriptors --canonicalize --bufferize +// RUN: tf-opt %s --xla-legalize-tf='legalize-chlo=false' | mlir-hlo-opt --transform-unranked-hlo --chlo-legalize-to-hlo | kernel-gen-opt --shape-to-descriptors --canonicalize --bufferize func @acos(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "tf.Acos"(%arg0) { } : (tensor<*xf32>) -> tensor<*xf32> diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir index 0b2834accee..b943321e95b 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_framework_legalize_to_llvm.mlir @@ -55,7 +55,7 @@ func @alloc_raw(%ctx: !tf_framework.op_kernel_context, // ----- -// CHECK: llvm.func @_mlir_ciface_tf_dealloc_raw(!llvm.ptr) +// CHECK: llvm.func @_mlir_ciface_tf_dealloc_raw(!llvm.ptr, !llvm.ptr) // CHECK-LABEL: llvm.func @dealloc_raw( // CHECK-SAME: [[TF_CTX:%.*]]: !llvm.ptr, diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/abs.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/abs.mlir index bfcd1d5c931..edb023e5fe7 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/abs.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/abs.mlir @@ -1,4 +1,4 @@ -// RUN: tf_to_gpu_binary --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=70 +// RUN: tf_to_gpu_binary --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=sm_70 func @abs(%arg0: tensor) -> tensor { %0 = "tf.Abs"(%arg0) { } : (tensor) -> tensor diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/ceil.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/ceil.mlir index 27f9a046c42..25b79c47f4e 100644 
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/ceil.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/ceil.mlir @@ -1,4 +1,4 @@ -// RUN: tf_to_gpu_binary --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=70 +// RUN: tf_to_gpu_binary --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=sm_70 func @ceil(%arg0: tensor) -> tensor { %0 = "tf.Ceil"(%arg0) { } : (tensor) -> tensor diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir index e596c338b14..69632f498a9 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_gpu_binary/tanh.mlir @@ -1,6 +1,5 @@ -// RUN: tf_to_gpu_binary --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=70 +// RUN: tf_to_gpu_binary --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=sm_70 func @tanh(%arg0: tensor) -> tensor { - %0 = "tf.Tanh"(%arg0) { } - : (tensor) -> tensor + %0 = "tf.Tanh"(%arg0) : (tensor) -> tensor return %0 : tensor } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD new file mode 100644 index 00000000000..24e288c246c --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD @@ -0,0 +1,17 @@ +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +package(licenses = ["notice"]) + +glob_lit_tests( + data = [ + "//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_kernel", + "@llvm-project//mlir:run_lit.sh", + ], + default_tags = [ + # We need access to the CUDA SDK. + "gpu", + "no_rocm", + ], + driver = "//tensorflow/compiler/mlir:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/tanh.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/tanh.mlir new file mode 100644 index 00000000000..85bea1795a5 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/tanh.mlir @@ -0,0 +1,6 @@ +// RUN: tf_to_kernel --input=%s --output=%t --same_shape=0,1 --unroll_factors=4 --tile_sizes=256 --arch=sm_70,compute_75 + +func @tanh(%arg: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.Tanh"(%arg) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc new file mode 100644 index 00000000000..06d613e0599 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_cuda_runtime_wrappers.cc @@ -0,0 +1,113 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Implements C wrappers around the CUDA library for easy linking in ORC jit. +// Also adds some debugging helpers that are helpful when writing MLIR code to +// run on GPUs. + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/ExecutionEngine/CRunnerUtils.h" // from @llvm-project + +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" + +#define CUDA_REPORT_IF_ERROR(expr) \ + [](CUresult result) { \ + if (!result) \ + return; \ + const char *name = nullptr; \ + cuGetErrorName(result, &name); \ + if (!name) \ + name = ""; \ + llvm::errs() << "'" << #expr << "' failed with '" << name << "'\n"; \ + }(expr) + +extern "C" CUmodule mgpuModuleLoad(void *data) { + CUmodule module = nullptr; + CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data)); + return module; +} + +extern "C" CUfunction mgpuModuleGetFunction(CUmodule module, const char *name) { + CUfunction function = nullptr; + CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name)); + return function; +} + +// The wrapper uses intptr_t instead of CUDA's unsigned int to match +// the type of MLIR's index type. This avoids the need for casts in the +// generated MLIR code. +extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX, + intptr_t gridY, intptr_t gridZ, + intptr_t blockX, intptr_t blockY, + intptr_t blockZ, int32_t smem, CUstream stream, + void **params, void **extra) { + CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX, + blockY, blockZ, smem, stream, params, + extra)); +} + +extern "C" CUstream mgpuStreamCreate() { + static CUstream stream = []() { + // TODO(b/170649852): This is neither thread-safe nor handles + // creation/descruction of one stream per context. + CUstream stream = nullptr; + CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING)); + return stream; + }(); + return stream; +} + +extern "C" void mgpuStreamSynchronize(CUstream stream) { + CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream)); +} + +/// Helper functions for writing mlir example code + +// Allows to register byte array with the CUDA runtime. Helpful until we have +// transfer functions implemented. +extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) { + CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0)); +} + +// Allows to register a MemRef with the CUDA runtime. Helpful until we have +// transfer functions implemented. +extern "C" void +mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType *descriptor, + int64_t elementSizeBytes) { + + llvm::SmallVector denseStrides(rank); + llvm::ArrayRef sizes(descriptor->sizes, rank); + llvm::ArrayRef strides(sizes.end(), rank); + + std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(), + std::multiplies()); + auto sizeBytes = denseStrides.front() * elementSizeBytes; + + // Only densely packed tensors are currently supported. 
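+  // Worked example: for a 2x3x4 memref, sizes = [2, 3, 4] and the reversed
+  // partial products above yield denseStrides = [24, 12, 4], so sizeBytes is
+  // 24 * elementSizeBytes. The rotate below shifts this to [12, 4, 24] and the
+  // last element is reset to 1, giving the dense row-major strides [12, 4, 1]
+  // that the assert compares against.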
+ std::rotate(denseStrides.begin(), denseStrides.begin() + 1, + denseStrides.end()); + denseStrides.back() = 1; + assert(strides == llvm::makeArrayRef(denseStrides)); + + auto ptr = descriptor->data + descriptor->offset * elementSizeBytes; + mgpuMemHostRegister(ptr, sizeBytes); +} + +#endif diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc index c7cb92404f5..84c2bf46b55 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_gpu_binary.cc @@ -35,7 +35,7 @@ namespace kernel_gen { namespace { xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file, - int32_t architecture, llvm::ArrayRef tile_sizes, + std::string architecture, llvm::ArrayRef tile_sizes, llvm::ArrayRef same_shape, llvm::ArrayRef unroll_factors) { // Read TF code. @@ -48,7 +48,7 @@ xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file, mlir::OwningModuleRef module, GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/true, architecture, tile_sizes, same_shape, - unroll_factors)); + unroll_factors, /*generate_fatbin=*/false)); // Extract gpu_binary. TF_ASSIGN_OR_RETURN(std::string gpu_binary, ExtractGpuBinary(*module)); @@ -69,9 +69,9 @@ int main(int argc, char** argv) { llvm::cl::opt output_file( "output", llvm::cl::desc("output file"), llvm::cl::value_desc("filename"), llvm::cl::init("foo.bin")); - llvm::cl::opt architecture( - "arch", llvm::cl::desc("target architecture (e.g. 50 for sm_50)"), - llvm::cl::init(50)); + llvm::cl::opt architecture( + "arch", llvm::cl::desc("target architecture (e.g. sm_50)"), + llvm::cl::init("sm_50")); llvm::cl::list tile_sizes( "tile_sizes", llvm::cl::desc("tile sizes to use"), llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc index 2caa806551e..4cbe2e633f8 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc @@ -95,7 +95,8 @@ xla::StatusOr EmitToBinary(mlir::ModuleOp module) { } xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file, - int32_t architecture, llvm::ArrayRef tile_sizes, + llvm::ArrayRef architectures, + llvm::ArrayRef tile_sizes, llvm::ArrayRef same_shape, llvm::ArrayRef unroll_factors) { // Read TF code. @@ -107,7 +108,7 @@ xla::Status Run(llvm::StringRef input_file, llvm::StringRef output_file, TF_ASSIGN_OR_RETURN( mlir::OwningModuleRef module, GenerateKernelForTfCode(context, tf_code, /*gpu_binary_only=*/false, - architecture, tile_sizes, same_shape, + architectures, tile_sizes, same_shape, unroll_factors)); // Get binary. TF_ASSIGN_OR_RETURN(std::string binary, EmitToBinary(*module)); @@ -129,9 +130,9 @@ int main(int argc, char** argv) { llvm::cl::opt output_file( "output", llvm::cl::desc("output file"), llvm::cl::value_desc("filename"), llvm::cl::init("foo.bin")); - llvm::cl::opt architecture( - "arch", llvm::cl::desc("target architecture (e.g. 50 for sm_50)"), - llvm::cl::init(50)); + llvm::cl::list architectures( + "arch", llvm::cl::desc("target architectures (e.g. 
sm_70 or compute_75)"), + llvm::cl::OneOrMore, llvm::cl::CommaSeparated); llvm::cl::list tile_sizes( "tile_sizes", llvm::cl::desc("tile sizes to use"), llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); @@ -151,7 +152,7 @@ int main(int argc, char** argv) { llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n"); auto status = - tensorflow::kernel_gen::Run(input_file, output_file, architecture, + tensorflow::kernel_gen::Run(input_file, output_file, architectures, tile_sizes, same_shape, unroll_factors); if (!status.ok()) { LOG(ERROR) << status; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index ba0fc0525db..b2595d2ad3a 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -93,7 +93,7 @@ cc_library( "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla/service/gpu:target_constants", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", - "//tensorflow/core:cuda_libdevice_path", + "//tensorflow/core/platform:cuda_libdevice_path", "//tensorflow/core:lib", ":bufferize", ":embed_tf_framework", @@ -114,8 +114,10 @@ cc_library( "@llvm-project//mlir:SCFToStandard", "@llvm-project//mlir:ShapeTransforms", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", + "@llvm-project//llvm:TransformUtils", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo", "//tensorflow/compiler/mlir/hlo:lhlo", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc index 397889415f7..f6a47cca6d0 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc @@ -15,6 +15,8 @@ limitations under the License. // This file implements logic for translating mixed IR to buffer form. +#include "mlir/Transforms/Bufferize.h" // from @llvm-project + #include #include @@ -29,7 +31,6 @@ limitations under the License. 
#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir { @@ -140,7 +141,7 @@ class ExtractElementOpConversion ConversionPatternRewriter &rewriter) const final { ExtractElementOpAdaptor adaptor(operands); - if (!adaptor.aggregate().getType().isa()) { + if (!adaptor.aggregate().getType().isa()) { return failure(); } @@ -150,6 +151,23 @@ class ExtractElementOpConversion } }; +template +class SimpleOpResultConversion + : public BufferAssignmentOpConversionPattern { + public: + using BufferAssignmentOpConversionPattern< + OpTy>::BufferAssignmentOpConversionPattern; + using BufferAssignmentOpConversionPattern::converter; + + LogicalResult matchAndRewrite( + OpTy op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + rewriter.replaceOpWithNewOp(op, converter.convertType(op.getType()), + operands); + return success(); + } +}; + class TensorCastOpConverter : public BufferAssignmentOpConversionPattern { public: @@ -160,9 +178,9 @@ class TensorCastOpConverter TensorCastOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { Value arg = operands.front(); - if (!arg.getType().isa()) return failure(); + if (!arg.getType().isa()) return failure(); - auto result_ty = converter->convertType(op.getType()); + auto result_ty = converter.convertType(op.getType()); rewriter.replaceOpWithNewOp(op, arg, result_ty); return success(); @@ -175,8 +193,9 @@ void populateStandardBufferizePattern(MLIRContext *context, BufferAssignmentTypeConverter *converter, OwningRewritePatternList *patterns) { patterns->insert(context, converter); + DynamicTensorFromElementsOpConverter, + SimpleOpResultConversion, TensorLoadOpConversion, + TensorCastOpConverter>(context, *converter); } } // namespace transforms diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc index 8ddbb15219f..1ddc44123fa 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Shape/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -26,7 +28,7 @@ limitations under the License. 
#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project +#include "mlir/Transforms/Bufferize.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" @@ -79,8 +81,7 @@ struct BufferizePass : public BufferizePassBase { target.addLegalOp(); target.addIllegalDialect(); target.addIllegalOp(); + TensorFromElementsOp, TensorLoadOp, TensorCastOp>(); target.addDynamicallyLegalOp([&](TensorStoreOp op) { return !op.tensor().getType().isa(); }); @@ -96,15 +97,16 @@ struct BufferizePass : public BufferizePassBase { return converter.isLegal(inputs) && converter.isLegal(results) && converter.isLegal(&op.getBody()); }); - target.addDynamicallyLegalOp(typesAreLegal); - target.addDynamicallyLegalOp(typesAreLegal); + target.addDynamicallyLegalOp( + typesAreLegal); OwningRewritePatternList patterns; mhlo::populateHLOToLHLOConversionPattern(&context, &converter, &patterns); populateWithBufferAssignmentOpConversionPatterns( - &context, &converter, &patterns); + &context, converter, patterns); populateStandardBufferizePattern(&context, &converter, &patterns); + populateShapeTypeConversionPatterns(&context, converter, patterns); patterns.insert(&context); auto module = getOperation(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc index 773e12f2da3..46bf13b7d20 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "llvm/Transforms/Utils/Cloning.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Target/NVVMIR.h" // from @llvm-project #include "mlir/Target/ROCDLIR.h" // from @llvm-project @@ -49,9 +50,12 @@ using xla::InternalError; class GpuKernelToBlobPass : public GpuKernelToBlobPassBase { public: - GpuKernelToBlobPass(mlir::StringRef blob_annotation, int32_t arch) { - blob_annotation_ = blob_annotation; - arch_ = arch; + GpuKernelToBlobPass(mlir::StringRef blob_annotation, + llvm::ArrayRef architectures, + bool generate_fatbin) { + blob_annotation_ = blob_annotation.str(); + architectures_ = architectures; + generate_fatbin_ = generate_fatbin; } void runOnOperation() override { @@ -69,7 +73,17 @@ class GpuKernelToBlobPass xla::StatusOr> GetGpuBinaryBlob( mlir::gpu::GPUModuleOp gpu_module) { + if (architectures_.empty()) { + return InternalError("Expected at least one GPU architecture."); + } + if (!generate_fatbin_ && architectures_.size() > 1) { + return InternalError( + "Can only generate machine code for more than one architecture as a " + "fatbin."); + } + llvm::LLVMContext llvmContext; + #if TENSORFLOW_USE_ROCM auto llvmModule = mlir::translateModuleToROCDLIR(gpu_module, llvmContext); if (!llvmModule) { @@ -81,9 +95,24 @@ class GpuKernelToBlobPass xla::HloModuleConfig config; config.set_debug_options(xla::GetDebugOptionsFromFlags()); - std::string libdevice_dir = tensorflow::RocdlRoot(); + // TODO(b/169066682): Support fatbin on ROCm. + if (generate_fatbin_) { + return InternalError("Fatbins are not yet supported for ROCm."); + } - return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), arch_, config, + // Parse ROCm architecture. + absl::string_view consumable_arch(architectures_.front()); + if (!absl::ConsumePrefix(&consumable_arch, "gfx")) { + return InternalError( + "Could not parse ROCm architecture prefix (expected gfx)"); + } + uint32_t arch; + if (!absl::SimpleAtoi(consumable_arch, &arch)) { + return InternalError("Could not parse ROCm architecture number"); + } + + std::string libdevice_dir = tensorflow::RocdlRoot(); + return xla::gpu::amdgpu::CompileToHsaco(llvmModule.get(), arch, config, libdevice_dir); #elif GOOGLE_CUDA @@ -102,19 +131,65 @@ class GpuKernelToBlobPass target->Options.AllowFPOpFusion = llvm::FPOpFusion::FPOpFusionMode::Fast; }; - int32_t cc_major = arch_ / 10; - int32_t cc_minor = arch_ % 10; + // Compile and collect requested cubin and PTX images. + std::vector images; TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config)); - TF_ASSIGN_OR_RETURN( - std::string ptx, - xla::gpu::nvptx::CompileToPtx(llvmModule.get(), - std::make_pair(cc_major, cc_minor), - config, libdevice_dir, enable_fusion)); - VLOG(1) << ptx; + auto gpu_asm_opts = xla::gpu::PtxOptsFromConfig(config); + for (const std::string& arch_str : architectures_) { + // Parse CUDA architecture. 
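+      // For example, "sm_70" parses to arch = 70 (compute capability 7.0) and
+      // contributes only a cubin image, while "compute_75" parses to arch = 75
+      // and additionally records the PTX text so it ends up in the fatbin too.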
+ absl::string_view consumable_arch(arch_str); + bool is_compute_profile; + if (absl::ConsumePrefix(&consumable_arch, "compute_")) { + is_compute_profile = true; + } else if (absl::ConsumePrefix(&consumable_arch, "sm_")) { + is_compute_profile = false; + } else { + return InternalError( + "Could not parse cuda architecture prefix (expected sm_ or " + "compute_)"); + } + uint32_t arch; + if (!absl::SimpleAtoi(consumable_arch, &arch)) { + return InternalError("Could not parse cuda architecture number"); + } - return tensorflow::se::CompileGpuAsm(cc_major, cc_minor, ptx.c_str(), - xla::gpu::PtxOptsFromConfig(config)); + uint32_t cc_major = arch / 10; + uint32_t cc_minor = arch % 10; + // Module may be changed by CompileToPtx. + auto llvm_module_copy = llvm::CloneModule(*llvmModule); + TF_ASSIGN_OR_RETURN( + std::string ptx, + xla::gpu::nvptx::CompileToPtx(llvm_module_copy.get(), + std::make_pair(cc_major, cc_minor), + config, libdevice_dir, enable_fusion)); + VLOG(1) << ptx; + TF_ASSIGN_OR_RETURN(std::vector gpu_asm, + tensorflow::se::CompileGpuAsm( + cc_major, cc_minor, ptx.c_str(), gpu_asm_opts)); + + if (!generate_fatbin_) { + // Skip fatbin generation and return the first and only GPU machine + // code. This is currently only used for `tf_to_gpu_binary` and will + // eventually disappear. + return gpu_asm; + } + + // Collect cubin (and ptx image if requested). + images.push_back({absl::StrCat("sm_", arch), std::move(gpu_asm)}); + if (is_compute_profile) { + std::vector ptx_bytes; + std::copy(ptx.begin(), ptx.end(), std::back_inserter(ptx_bytes)); + images.push_back( + {absl::StrCat("compute_", arch), std::move(ptx_bytes)}); + } + } + + // TODO(b/169870789): Revisit the use of fatbins. + // Bundle cubin and PTX images into a single fatbin. + return tensorflow::se::BundleGpuAsm(images, + gpu_asm_opts.preferred_cuda_dir); #endif + return InternalError( "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined." " Did you specify either --config=rocm or --config=cuda ?"); @@ -141,8 +216,10 @@ class GpuKernelToBlobPass } // namespace std::unique_ptr> CreateGpuKernelToBlobPass( - mlir::StringRef blob_annotation, int32_t architecture) { - return std::make_unique(blob_annotation, architecture); + mlir::StringRef blob_annotation, ArrayRef architectures, + bool generate_fatbin) { + return std::make_unique(blob_annotation, architectures, + generate_fatbin); } } // namespace transforms diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h index 967473c3fc9..5fd4091b2c0 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h @@ -57,11 +57,12 @@ std::unique_ptr CreateParallelLoopsToSequential(); // Pass to propagate TF ABI knowledge, e.g. offsets, alignment. std::unique_ptr> CreatePropagateTensorFlowABIKnowledgePass( - mlir::FunctionType type = {}, llvm::ArrayRef same_shape = {}); + llvm::ArrayRef same_shape = {}); // Pass to annotate GPU Module with its PTX. std::unique_ptr> CreateGpuKernelToBlobPass( - mlir::StringRef blob_annotation = "", int32_t architecture = 0); + mlir::StringRef blob_annotation = "", + ArrayRef architectures = {}, bool generate_fatbin = true); // Pass to unfuse batch norm. 
std::unique_ptr CreateUnfuseBatchNormPass(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td index b06ace853ed..a8b2506bd1c 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td @@ -51,9 +51,12 @@ def UnfuseBatchNormPass : FunctionPass<"unfuse-batch-norm"> { def GpuKernelToBlobPass : Pass<"gpu-kernel-to-blob", "gpu::GPUModuleOp"> { let summary = "Pass to annotate GPU Module with its PTX"; let options = [ - Option<"blob_annotation_", "blob-annotation", "mlir::StringRef", + Option<"blob_annotation_", "blob-annotation", "std::string", /*default=*/"", "Blob attribute name">, - Option<"arch_", "arch", "int32_t", /*default=*/"0", "GPU architecture">, + ListOption<"architectures_", "arch", "std::string", "GPU architectures">, + Option<"generate_fatbin_", "generate-fatbin", "bool", /*default=*/"true", + "Bundle machine code for the different architectures in one " + "fatbin.">, ]; let constructor = "transforms::CreateGpuKernelToBlobPass()"; } @@ -67,8 +70,6 @@ def PropagateTensorFlowABIKnowledgePass : Pass<"propagate-tf-abi-knowledge", "LLVM::LLVMFuncOp"> { let summary = "Pass to propagate TF ABI knowledge, e.g. offsets, alignment"; let options = [ - Option<"func_type_", "func-type", "mlir::FunctionType", - /*default=*/"", "Function type">, ListOption<"same_shape_", "same-shape", "uint32_t", "List of same shape args">, ]; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/propagate_tf_abi_knowledge_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/propagate_tf_abi_knowledge_pass.cc index 57a5fec527a..3b568f5f25f 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/propagate_tf_abi_knowledge_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/propagate_tf_abi_knowledge_pass.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" @@ -28,8 +30,7 @@ struct PropagateTensorFlowABIKnowledgePass : public PropagateTensorFlowABIKnowledgePassBase< PropagateTensorFlowABIKnowledgePass> { explicit PropagateTensorFlowABIKnowledgePass( - mlir::FunctionType type, llvm::ArrayRef same_shape) { - func_type_ = type; + llvm::ArrayRef same_shape) { same_shape_ = same_shape; } @@ -49,6 +50,21 @@ struct PropagateTensorFlowABIKnowledgePass // This only works if the function is local and we can rewrite it. if (func.isExternal()) return; + auto function_list = + func.getParentOfType().getOps(); + if (function_list.empty()) { + func.emitError() << "No possible kernel function found"; + return signalPassFailure(); + } + auto func_iterator = function_list.begin(); + if (std::next(func_iterator) != function_list.end()) { + func.emitError() << "More than one possible kernel function detected"; + return signalPassFailure(); + } + // Note that this dereference is necessary to prevent a + // stack-use-after-return error. + auto func_type = (*func_iterator).getType(); + mlir::OpBuilder b(func.getBody()); // Steal the LLVM representation of the index type from the third argument. 
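    // (Under the standard memref-to-LLVM expansion the kernel receives, per
    // memref argument, the allocated pointer, the aligned pointer, the offset,
    // and then the sizes and strides, so argument 3 is the first size and is
    // already of the lowered index type.)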
auto index_type = func.getArgument(3).getType(); @@ -60,7 +76,7 @@ struct PropagateTensorFlowABIKnowledgePass std::vector positions; // Collect the agument and return types of the surrounding function. auto arg_types = llvm::to_vector<4>(llvm::concat( - func_type_.getInputs(), func_type_.getResults())); + func_type.getInputs(), func_type.getResults())); for (mlir::Type arg_type : arg_types) { if (!arg_type.isa()) { func.emitError() << "argument of surrounding func is not ranked memref"; @@ -112,10 +128,8 @@ struct PropagateTensorFlowABIKnowledgePass } // namespace std::unique_ptr> -CreatePropagateTensorFlowABIKnowledgePass(mlir::FunctionType type, - llvm::ArrayRef same_shape) { - return std::make_unique(type, - same_shape); +CreatePropagateTensorFlowABIKnowledgePass(llvm::ArrayRef same_shape) { + return std::make_unique(same_shape); } } // namespace transforms diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc index ab66c513e33..f5d01808c1b 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc @@ -49,6 +49,10 @@ struct ShapeToDescriptorsPass target.addIllegalDialect(); target.addLegalDialect(); target.addLegalDialect(); + // Don't mark the primary Cstr/Assuming ops as illegal, so they can be + // lowered at a later time to assertions. + target.addLegalOp(); // Setup conversion patterns. OwningRewritePatternList patterns; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc index 3ce111ff3ff..431919c2de7 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc @@ -171,7 +171,8 @@ class DeallocRawOpConverter : public ConvertToLLVMCallOpPattern { protected: StringRef GetFuncName() const override { return kCInterfaceDealloc; } LLVMType GetFuncType() const override { - return LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(), + return LLVM::LLVMType::getFunctionTy(getVoidType(), + {getVoidPtrType(), getVoidPtrType()}, /*isVarArg=*/false); } }; diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 689eb14e4af..1919446a365 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -136,6 +136,7 @@ cc_library( ":hlo_module_importer", ":hlo_utils", ":mlir_hlo_to_hlo", + ":translate_cl_options", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:lhlo", "//tensorflow/compiler/xla:debug_options_flags", diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 89c35bb2974..253156b44a5 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -140,31 +140,42 @@ tensorflow::Status HloFunctionImporter::ImportAsRegion( return ImportInstructions(computation, block); } -tensorflow::Status HloFunctionImporter::ImportInstructions( - const HloComputation& computation, mlir::Block* block) { +StatusOr HloFunctionImporter::ImportInstructionsImpl( + const xla::HloComputation& computation, + const llvm::SmallVectorImpl& arguments, mlir::OpBuilder* builder) { // 
Setup the input parameters. const int num_parameters = computation.num_parameters(); + + if (arguments.size() != num_parameters) + return InvalidArgument("Caller vs callee argument sizes do not match"); + for (int i = 0; i < num_parameters; i++) { auto hlo_parameter = computation.parameter_instruction(i); - instruction_value_map_[hlo_parameter] = block->getArgument(i); + instruction_value_map_[hlo_parameter] = arguments[i]; } - mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block); for (auto instruction : computation.MakeInstructionPostOrder()) { TF_ASSIGN_OR_RETURN(auto new_operation, - ImportInstruction(instruction, &builder)); + ImportInstruction(instruction, builder)); if (new_operation) { instruction_value_map_[instruction] = new_operation->getResult(0); } } + // Setup the return type (HLO only supports a single return value). + return GetMlirValue(computation.root_instruction()); +} + +Status HloFunctionImporter::ImportInstructions( + const HloComputation& computation, mlir::Block* block) { + llvm::SmallVector arguments(block->args_begin(), block->args_end()); + mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block); + TF_ASSIGN_OR_RETURN(Value result, + ImportInstructionsImpl(computation, arguments, &builder)); + // TODO(suderman): Add location tracking details. mlir::Location loc = builder.getUnknownLoc(); - // Setup the return type (HLO only supports a single return value). - TF_ASSIGN_OR_RETURN(auto result, - GetMlirValue(computation.root_instruction())); - // Create terminator op depending on the parent op of this region. if (llvm::isa(block->getParentOp())) { builder.create(loc, result); @@ -174,15 +185,29 @@ tensorflow::Status HloFunctionImporter::ImportInstructions( return tensorflow::Status::OK(); } -StatusOr HloFunctionImporter::ImportInstruction( +StatusOr HloFunctionImporter::ImportInstructions( + const xla::HloComputation& computation, + const llvm::SmallVectorImpl& arguments, mlir::OpBuilder* builder) { + mlir::Block* block = builder->getBlock(); + if (block == nullptr) + return InvalidArgument( + "ImportInstructions requires a valid block in the builder"); + + HloFunctionImporter importer( + block->getParent()->getParentOfType(), {}, builder); + return importer.ImportInstructionsImpl(computation, arguments, builder); +} + +StatusOr HloFunctionImporter::ImportInstructionImpl( HloInstruction* instruction, mlir::OpBuilder* func_builder) { TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction)); TF_ASSIGN_OR_RETURN(auto result_type, ConvertShapeToType( instruction->shape(), *builder_)); - llvm::SmallVector attributes = {builder_->getNamedAttr( - "name", builder_->getStringAttr(instruction->name()))}; - mlir::Location loc = func_builder->getUnknownLoc(); + mlir::Location loc = + mlir::NameLoc::get(func_builder->getIdentifier(instruction->name()), + func_builder->getContext()); + llvm::SmallVector attributes; switch (instruction->opcode()) { case HloOpcode::kParameter: { return nullptr; @@ -214,8 +239,8 @@ StatusOr HloFunctionImporter::ImportInstruction( return new_operation; \ } case HloOpcode::kBroadcast: { - // Note that the HLO broadcast is more powerful than the XLA broadcast op. - // BroadcastInDim offers a superset of the HLO op's functionality. + // Note that the HLO broadcast is more powerful than the XLA broadcast + // op. BroadcastInDim offers a superset of the HLO op's functionality. 
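      // For example, an HLO instruction f32[2,3] broadcast(f32[3]),
      // dimensions={1} imports roughly as "mhlo.broadcast_in_dim"(%operand)
      // {broadcast_dimensions = dense<1> : tensor<1xi64>}
      // : (tensor<3xf32>) -> tensor<2x3xf32>.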
attributes.push_back( builder_->getNamedAttr("broadcast_dimensions", ConvertDimensions(instruction->dimensions()))); @@ -417,13 +442,27 @@ StatusOr HloFunctionImporter::ImportInstruction( } case HloOpcode::kSort: { auto sort_instruction = Cast(instruction); + + llvm::SmallVector return_types = {result_type}; + if (mlir::TupleType tuple_ty = result_type.dyn_cast()) { + return_types = llvm::to_vector<6>(tuple_ty.getTypes()); + } + auto sort_op = func_builder->create( - loc, result_type, operands, + loc, return_types, operands, builder_->getI64IntegerAttr(sort_instruction->sort_dimension()), builder_->getBoolAttr(sort_instruction->is_stable())); TF_RETURN_IF_ERROR( ImportAsRegion(*sort_instruction->to_apply(), &sort_op.comparator())); - return sort_op.getOperation(); + + // Check if the output needs to be tupled. + if (return_types.size() == 1 && return_types.front() == result_type) { + return sort_op.getOperation(); + } + + return func_builder + ->create(loc, result_type, sort_op.getResults()) + .getOperation(); } case HloOpcode::kConditional: { llvm::SmallVector rets; @@ -444,7 +483,8 @@ StatusOr HloFunctionImporter::ImportInstruction( return op.getOperation(); } - // Otherwise, it is a indexed conditional and should be mapped to Case op. + // Otherwise, it is a indexed conditional and should be mapped to Case + // op. TF_RETURN_IF_ERROR(GetMlirTypes( {instruction->branch_computation(0)->root_instruction()}, &rets)); @@ -460,8 +500,8 @@ StatusOr HloFunctionImporter::ImportInstruction( return op.getOperation(); } case HloOpcode::kConcatenate: { - // TODO(b/132057942): Support taking an uint64_t instead of an IntegerAttr - // for concatenate dimension. + // TODO(b/132057942): Support taking an uint64_t instead of an + // IntegerAttr for concatenate dimension. return func_builder ->create( loc, result_type, operands, @@ -665,6 +705,7 @@ StatusOr HloFunctionImporter::ImportInstruction( NoAttributeCase(kAnd, AndOp); NoAttributeCase(kAtan2, Atan2Op); NoAttributeCase(kBitcastConvert, BitcastConvertOp); + NoAttributeCase(kCbrt, CbrtOp); NoAttributeCase(kConvert, ConvertOp); NoAttributeCase(kCeil, CeilOp); NoAttributeCase(kClamp, ClampOp); @@ -689,9 +730,9 @@ StatusOr HloFunctionImporter::ImportInstruction( NoAttributeCase(kReal, RealOp); NoAttributeCase(kRemainder, RemOp); NoAttributeCase(kReplicaId, ReplicaIdOp); - // The dimensions attribute is not present on the HLO Reshape instruction. - // If dimensions are non-default, the XLA builder implements it as a - // separate transpose. + // The dimensions attribute is not present on the HLO Reshape + // instruction. If dimensions are non-default, the XLA builder + // implements it as a separate transpose. NoAttributeCase(kReshape, ReshapeOp); NoAttributeCase(kRoundNearestAfz, RoundOp); NoAttributeCase(kRsqrt, RsqrtOp); @@ -706,9 +747,9 @@ StatusOr HloFunctionImporter::ImportInstruction( NoAttributeCase(kTanh, TanhOp); NoAttributeCase(kTuple, TupleOp); NoAttributeCase(kXor, XorOp); - // TODO(b/129422361) Copy needs special handling because it is not defined - // in tensorflow/compiler/xla/client/xla_builder.h. - // See operation semantics in + // TODO(b/129422361) Copy needs special handling because it is not + // defined in tensorflow/compiler/xla/client/xla_builder.h. 
See + // operation semantics in // g3doc/platforms/xla/g3doc/internal/hlo_semantics#copy NoAttributeCase(kCopy, CopyOp); #undef NoAttributeCase @@ -722,6 +763,20 @@ StatusOr HloFunctionImporter::ImportInstruction( &fusion.fused_computation())); return fusion.getOperation(); } + case HloOpcode::kBitcast: + return func_builder + ->create(loc, result_type, operands, + attributes) + .getOperation(); + case HloOpcode::kReducePrecision: { + auto op = func_builder->create( + loc, result_type, operands[0], attributes); + op.exponent_bitsAttr(func_builder->getIntegerAttr( + func_builder->getI32Type(), instruction->exponent_bits())); + op.mantissa_bitsAttr(func_builder->getIntegerAttr( + func_builder->getI32Type(), instruction->mantissa_bits())); + return op.getOperation(); + } case HloOpcode::kAddDependency: // Arbitrary op code that I suspect we will not implement for quite a // while and allows testing handling of unknown ops. Selected because it @@ -740,6 +795,28 @@ StatusOr HloFunctionImporter::ImportInstruction( } } +StatusOr HloFunctionImporter::ImportInstruction( + HloInstruction* instruction, mlir::OpBuilder* func_builder) { + TF_ASSIGN_OR_RETURN(mlir::Operation * op, + ImportInstructionImpl(instruction, func_builder)); + if (op == nullptr) return op; + + // See MlirToHloConversionOptions for more about layouts. + // + // Minor-to-major is a permutation of [0, rank), presenting tensor dimensions + // in physical minor-to-major order. + if (instruction->shape().IsArray() && + instruction->shape().layout() != + LayoutUtil::MakeDescendingLayout( + instruction->shape().dimensions().size())) { + llvm::SmallVector minor_to_major( + instruction->shape().layout().minor_to_major().begin(), + instruction->shape().layout().minor_to_major().end()); + op->setAttr("minor_to_major", builder_->getIndexTensorAttr(minor_to_major)); + } + return op; +} + StatusOr> HloFunctionImporter::GetOperands( HloInstruction* instruction) { llvm::SmallVector operands; diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h index e0cc89004cf..4a75b079d76 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h @@ -55,6 +55,13 @@ class HloFunctionImporter { static Status ImportAsRegion(const xla::HloComputation& computation, mlir::Region* region, mlir::Builder* builder); + // Imports the given computation to the given place specified by `builder`. + // `arguments` contains values for all parameters. + static StatusOr ImportInstructions( + const xla::HloComputation& computation, + const llvm::SmallVectorImpl& arguments, + mlir::OpBuilder* builder); + private: HloFunctionImporter(mlir::ModuleOp module, std::unordered_map ImportInstructionsImpl( + const xla::HloComputation& computation, + const llvm::SmallVectorImpl& arguments, + mlir::OpBuilder* builder); // Imports an instruction. StatusOr ImportInstruction(xla::HloInstruction* instruction, mlir::OpBuilder* func_builder); + StatusOr ImportInstructionImpl( + HloInstruction* instruction, mlir::OpBuilder* func_builder); // Gets the MLIR operand values from an HLO Instruction. 
StatusOr> GetOperands( diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index 3da893548bb..daea2d9b8f6 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -243,11 +243,22 @@ StatusOr MlirHloBuilder::SortInternal(const Shape& shape, int64 dimension, bool is_stable) { TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( shape, builder_)); + llvm::SmallVector sort_types = {ty}; + if (auto tuple_ty = ty.dyn_cast()) { + sort_types = llvm::to_vector<6>(tuple_ty.getTypes()); + } + auto op = builder_.create( - loc_, ty, GetValues(operands), builder_.getI64IntegerAttr(dimension), - builder_.getBoolAttr(is_stable)); + loc_, sort_types, GetValues(operands), + builder_.getI64IntegerAttr(dimension), builder_.getBoolAttr(is_stable)); TF_RETURN_IF_ERROR(ImportComputation(comparator.proto(), &op.comparator())); - return MakeXlaOp(op); + + if (ty.isa()) { + auto tuple = builder_.create(loc_, op.getResults()); + return MakeXlaOp(tuple); + } + + return MakeXlaOp(op.getResult(0)); } StatusOr MlirHloBuilder::WhileInternal(const Shape& shape, diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index d6ef39d03dd..ccfcebab60e 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -499,12 +499,14 @@ class ConvertToHloModule { // single value. explicit ConvertToHloModule( mlir::ModuleOp module, bool use_tuple_args, bool return_tuple, - tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn) + tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn, + MlirToHloConversionOptions options) : module_(module), module_builder_("main"), use_tuple_args_(use_tuple_args), return_tuple_(return_tuple), - shape_representation_fn_(shape_representation_fn) { + shape_representation_fn_(shape_representation_fn), + options_(options) { if (!shape_representation_fn_) shape_representation_fn_ = tensorflow::IdentityShapeRepresentationFn(); } @@ -585,6 +587,8 @@ class ConvertToHloModule { // Unique suffix to give to the name of the next lowered region. size_t region_id_ = 0; + + MlirToHloConversionOptions options_; }; } // namespace @@ -1008,9 +1012,14 @@ LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) { &comparator))) return failure(); + auto tupled = xla::Sort(GetTuple(op.operands(), ctx), comparator, + op.dimension(), op.is_stable()); + auto& value_map = *ctx.values; - value_map[op] = xla::Sort(GetTuple(op.operands(), ctx), comparator, - op.dimension(), op.is_stable()); + // MLIR's sort supports multiple returns, untuple all the results of XLA's. 
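+  // For a two-operand sort (e.g. keys and values) XLA returns a single tuple;
+  // result 0 of the mhlo.sort maps to get-tuple-element 0 and result 1 to
+  // get-tuple-element 1, so downstream users keep referring to plain values.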
+ for (auto it : llvm::enumerate(op.getResults())) { + value_map[it.value()] = xla::GetTupleElement(tupled, it.index()); + } return success(); } @@ -1059,7 +1068,7 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) { llvm::SmallVector operands; for (auto operand : op.operands()) operands.push_back(values[operand]); - xla::XlaOp fusion = xla::internal::XlaBuilderBuildFusion( + xla::XlaOp fusion = xla::internal::XlaBuilderFriend::BuildFusion( ctx.builder, operands, absl::string_view(op.fusion_kind()->data(), op.fusion_kind()->size()), fused_computation); @@ -1073,6 +1082,15 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) { return success(); } +LogicalResult ExportXlaOp(BitcastOp op, OpLoweringContext ctx) { + auto& value_map = *ctx.values; + xla::XlaOp operand; + if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); + value_map[op] = xla::internal::XlaBuilderFriend::BuildBitcast( + ctx.builder, operand, xla::TypeToShape(op.getType())); + return success(); +} + } // namespace } // namespace mhlo } // namespace mlir @@ -1082,18 +1100,19 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) { namespace mlir { namespace { -StatusOr CreateLiteralFromAttr(ElementsAttr attr) { +StatusOr CreateArrayLiteralFromAttr(ElementsAttr attr, + xla::Layout layout) { if (attr.isa()) return tensorflow::errors::Unimplemented( "Opaque elements attr not supported"); xla::Shape shape = xla::TypeToShape(attr.getType()); -#define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type) \ - case xla_type: { \ - xla::Array source_data(shape.dimensions()); \ - source_data.SetValues(attr.getValues()); \ - return xla::LiteralUtil::CreateFromArray(source_data); \ +#define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type) \ + case xla_type: { \ + xla::Array source_data(shape.dimensions()); \ + source_data.SetValues(attr.getValues()); \ + return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout); \ } switch (shape.element_type()) { @@ -1123,7 +1142,7 @@ StatusOr CreateLiteralFromAttr(ElementsAttr attr) { } xla::Array source_data(shape.dimensions()); source_data.SetValues(values); - return xla::LiteralUtil::CreateFromArray(source_data); + return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout); } case xla::PrimitiveType::BF16: { xla::Array source_data(shape.dimensions()); @@ -1140,7 +1159,7 @@ StatusOr CreateLiteralFromAttr(ElementsAttr attr) { } source_data.SetValues(values_double); return xla::LiteralUtil::ConvertF64ToBF16( - xla::LiteralUtil::CreateFromArray(source_data)); + xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout)); } default: return tensorflow::errors::Internal(absl::StrCat( @@ -1149,13 +1168,46 @@ StatusOr CreateLiteralFromAttr(ElementsAttr attr) { #undef ELEMENTS_ATTR_TO_LITERAL } +xla::Layout ExtractLayout(mlir::Operation* op, int rank) { + if (auto attr = + op->getAttrOfType("minor_to_major")) { + llvm::SmallVector minor_to_major; + minor_to_major.reserve(attr.size()); + for (const llvm::APInt& i : attr) { + minor_to_major.push_back(i.getZExtValue()); + } + return xla::LayoutUtil::MakeLayout(minor_to_major); + } + return xla::LayoutUtil::MakeDescendingLayout(rank); +} + LogicalResult ConvertToHloModule::Lower( mlir::Operation* inst, bool is_entry_function, llvm::ArrayRef> ret_shardings, xla::XlaBuilder* builder, ConvertToHloModule::ValueLoweringMap* value_lowering, xla::XlaComputation* result) { + // See MlirToHloConversionOptions for more about layouts. 
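+  // For example, an op imported from an HLO shape f32[3,2]{0,1} carries the
+  // attribute minor_to_major = dense<[0, 1]> (see the fusion_layouts.hlotxt
+  // test further down); ExtractLayout turns that back into the {0,1} layout,
+  // while ops without the attribute default to the descending layout.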
+ auto propagate_layouts = [this](mlir::Operation* inst, xla::XlaOp xla_op) { + if (options_.propagate_layouts) { + auto* shape = xla::internal::XlaBuilderFriend::GetInstruction(xla_op) + ->mutable_shape(); + if (shape->tuple_shapes().empty()) + *shape->mutable_layout() = + ExtractLayout(inst, shape->dimensions().size()).ToProto(); + } + }; + if (succeeded(ExportXlaOperator(inst, {value_lowering, this, builder}))) { + if (inst->getNumResults() == 1) { + auto iter = value_lowering->find(inst->getResult(0)); + if (iter == value_lowering->end()) { + inst->emitOpError( + "inst has a result, but it's not found in value_lowering"); + return failure(); + } + propagate_layouts(inst, iter->second); + } return success(); } @@ -1181,16 +1233,19 @@ LogicalResult ConvertToHloModule::Lower( if (failed(GetXlaOp(operand, value_map, &xla_operand, op))) return failure(); value_map[op.getResult()] = xla_operand; + propagate_layouts(inst, xla_operand); return success(); } - // TODO(jpienaar): This doesn't support layouts yet. if (matchPattern(inst, m_Constant(&const_attr))) { - auto literal_or = CreateLiteralFromAttr(const_attr); + xla::Layout layout; + layout = ExtractLayout(inst, const_attr.getType().getRank()); + auto literal_or = CreateArrayLiteralFromAttr(const_attr, layout); if (!literal_or.ok()) return inst->emitError(literal_or.status().ToString()); - value_map[inst->getResult(0)] = - xla::ConstantLiteral(builder, literal_or.ValueOrDie()); + auto constant = xla::ConstantLiteral(builder, literal_or.ValueOrDie()); + value_map[inst->getResult(0)] = constant; + return success(); } @@ -1643,22 +1698,24 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module, } // namespace Status ConvertRegionToComputation(mlir::Region* region, - xla::XlaComputation* func) { + xla::XlaComputation* func, + MlirToHloConversionOptions options) { mlir::ModuleOp module; - ConvertToHloModule converter(module, true, true, {}); + ConvertToHloModule converter(module, true, true, {}, options); if (failed(converter.LowerRegionAsComputation(region, func))) return tensorflow::errors::Internal( "failed to convert region to computation"); return Status::OK(); } -Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto, - bool use_tuple_args, bool return_tuple, - const tensorflow::XlaHelpers::ShapeRepresentationFn - shape_representation_fn) { +Status ConvertMlirHloToHlo( + mlir::ModuleOp module, xla::HloProto* hlo_proto, bool use_tuple_args, + bool return_tuple, + const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn, + MlirToHloConversionOptions options) { mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); ConvertToHloModule converter(module, use_tuple_args, return_tuple, - shape_representation_fn); + shape_representation_fn, options); if (failed(converter.Run())) return diag_handler.ConsumeStatus(); auto hlo_module = converter.ConsumeMainProto(); hlo_proto->mutable_hlo_module()->Swap(&hlo_module); diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h index 6f2b5a6db95..4ca3e586128 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h @@ -25,6 +25,18 @@ limitations under the License. namespace mlir { +struct MlirToHloConversionOptions { + // Best-effort propagation of the layouts. These layouts serve as performance + // hints to the backend. 
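+  // For example, a constant tagged with minor_to_major = dense<[0, 1]> is
+  // exported as an HLO literal with layout {0,1}; without the attribute the
+  // default descending layout is used.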
+ // + // Note that non-array shapes are not carrying layouts, and users have to + // figure out the proper layouts of them through context. This is one of the + // reasons why the attribute-based solution is temporary. + // + // TODO(timshen): Investigate the necessity of having layouts in MHLO. + bool propagate_layouts = false; +}; + // Converts a MLIR module in HLO dialect into a HloModuleProto. If // use_tuple_args is set, then the entry computations's arguments are converted // to a tuple and passed as a single parameter. @@ -32,15 +44,19 @@ namespace mlir { // are converted to a tuple even when there is only a single return value. // Multiple return values are always converted to a tuple and returned as a // single value. +// +// TODO(timshen): move other options into `options`. Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto, bool use_tuple_args, bool return_tuple, const tensorflow::XlaHelpers::ShapeRepresentationFn - shape_representation_fn = nullptr); + shape_representation_fn = nullptr, + MlirToHloConversionOptions options = {}); // Converts a region to a computation. It returns a standalone module that // contains the converted region as the entry computation. Status ConvertRegionToComputation(mlir::Region* region, - ::xla::XlaComputation* func); + ::xla::XlaComputation* func, + MlirToHloConversionOptions options = {}); // Creates XlaOp equivalent of a given MLIR operation using the operand info // from `value_lowering` map. diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/fusion_layouts.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/fusion_layouts.hlotxt new file mode 100644 index 00000000000..781e203510b --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/fusion_layouts.hlotxt @@ -0,0 +1,16 @@ +// RUN: tf-mlir-translate -hlo-text-to-lhlo -optimize-xla-hlo=false %s | FileCheck %s + +HloModule TestModule + +// CHECK: func @TestComputation + +FusedComputation { + // CHECK: tensor_load %arg0 {minor_to_major = dense<[0, 1]> : tensor<2xindex>} + x = f32[3, 2]{0,1} parameter(0) + ROOT y = f32[3, 2]{0,1} add(x, x) +} + +ENTRY TestComputation { + x = f32[3, 2]{0,1} parameter(0) + ROOT y = f32[3, 2]{0,1} fusion(x), kind=kLoop, calls=FusedComputation +} diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir index 2e1b63b0db7..e7312e2114c 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir @@ -316,12 +316,61 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW0:.*]] = std.view %[[ARG2]]{{.*}} : memref<100xi8> to memref<5x5xi32> // CHECK: %[[VIEW1:.*]] = std.view %[[ARG3]]{{.*}} : memref<100xi8> to memref<5x5xf32> // CHECK: "lmhlo.sort"(%[[ARG0]], %[[ARG1]], %[[VIEW0]], %[[VIEW1]]) -func @main(%key: tensor<5x5xi32>, %value: tensor<5x5xf32>) -> tuple, tensor<5x5xf32>> { - %res = "mhlo.sort"(%key, %value) ({ +func @main(%key: tensor<5x5xi32>, %value: tensor<5x5xf32>) -> (tensor<5x5xi32>, tensor<5x5xf32>) { + %res:2 = "mhlo.sort"(%key, %value) ({ ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): %ret = "mhlo.compare"(%c, %d) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%ret) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true}: (tensor<5x5xi32>, tensor<5x5xf32>) -> tuple, tensor<5x5xf32>> + }) {dimension = 1 : i64, is_stable = true}: (tensor<5x5xi32>, 
tensor<5x5xf32>) -> (tensor<5x5xi32>, tensor<5x5xf32>) - return %res : tuple, tensor<5x5xf32>> + return %res#0, %res#1 : tensor<5x5xi32>, tensor<5x5xf32> +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[ARG0:.*]]: memref {{.*}}lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref {{.*}}lmhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<4xi8> +// CHECK: "lmhlo.fusion"() ( { +// CHECK: %[[VAR0:.*]] = tensor_load %[[ARG0]] : memref +// CHECK: %[[VAR1:.*]] = tensor_load %[[ARG1]] : memref +// CHECK: %[[VAR2:.*]] = mhlo.add %[[VAR0]], %[[VAR1]] : tensor +// CHECK: tensor_store %[[VAR2]], %[[MEMREF:.*]] : memref +// CHECK: "lmhlo.terminator"() : () -> () +// CHECK: }) : () -> () +func @main(%arg0: tensor, %arg1: tensor) -> tensor { + %result = "mhlo.fusion"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3: tensor): + %result = "mhlo.add"(%arg2, %arg3): (tensor, tensor) -> tensor + "mhlo.return"(%result) : (tensor) -> () + }) { fusion_kind = "kLoop" } : (tensor, tensor) -> tensor + + return %result : tensor +} + +// ----- + +// CHECK-LABEL: func @main +// CHECK: "lmhlo.fusion"() ( { +// CHECK: %[[VAL0:.*]] = tensor_load %{{.*}} : memref +// CHECK: %[[VAL1:.*]] = tensor_load %{{.*}} : memref +// CHECK: %[[VAL2:.*]] = tensor_load %{{.*}} : memref +// CHECK: tensor_store %[[VAL0]], %{{.*}} : memref +// CHECK: tensor_store %[[VAL1]], %{{.*}} : memref +// CHECK: tensor_store %[[VAL2]], %{{.*}} : memref +// CHECK: "lmhlo.terminator"() : () -> () +// CHECK: }) : () -> () +func @main(%arg0: tuple>, tensor>, %arg1: tuple>) -> tuple, tensor, tensor> { + %result = "mhlo.fusion"(%arg0, %arg1) ( { + ^bb0(%arg2: tuple>, tensor>, %arg3: tuple>): + %0 = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple>, tensor>) -> tuple> + %1 = "mhlo.get_tuple_element"(%0) {index = 0 : i32} : (tuple>) -> tensor + %2 = "mhlo.get_tuple_element"(%arg2) {index = 1 : i32} : (tuple>, tensor>) -> tensor + %3 = "mhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple>) -> tensor + %4 = "mhlo.tuple"(%1, %2, %3) : (tensor, tensor, tensor) -> tuple, tensor, tensor> + "mhlo.return"(%4) : (tuple, tensor, tensor>) -> () + }) { fusion_kind = "kLoop" } : (tuple>, tensor>, tuple>) -> tuple, tensor, tensor> + + return %result : tuple, tensor, tensor> } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index ed86168d1fc..fcbc5ebd337 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -483,6 +483,57 @@ func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tenso return %0 : tensor } + +//===----------------------------------------------------------------------===// +// ClipByValue +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @clip +func @clip(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK: [[VAL:%.+]] = "mhlo.clamp"(%arg1, %arg0, %arg2) + + %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + // CHECK: return [[VAL]] + return %0 : tensor +} + +// CHECK-LABEL: @clip_dynamic +func @clip_dynamic(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK-DAG: [[CLAMP:%.+]] = "mhlo.clamp"(%arg1, %arg0, %arg2) + %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + + // CHECK: return [[CLAMP]] + return %0 : tensor +} + +// CHECK-LABEL: @clip_static_broadcast +func @clip_static_broadcast(%arg0 : tensor<5xf32>, %arg1 : tensor, %arg2 : tensor) 
-> tensor<5xf32> { + // CHECK-DAG: [[SHP:%.+]] = mhlo.constant dense<5> + // CHECK-DAG: [[SHPIDX:%.+]] = tensor_cast [[SHP]] + // CHECK-DAG: [[BROADCAST_MIN:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[BROADCAST_MAX:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[CLAMP:%.+]] = "mhlo.clamp"([[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]]) + %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<5xf32>, tensor, tensor) -> tensor<5xf32> + + // CHECK: return [[CLAMP]] + return %0 : tensor<5xf32> +} + + +// CHECK-LABEL: @clip_dynamic_broadcast +func @clip_dynamic_broadcast(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK-DAG: [[SHP:%.+]] = shape.shape_of %arg0 + // CHECK-DAG: [[EXT:%.+]] = shape.to_extent_tensor [[SHP]] + // CHECK-DAG: [[SHPIDX:%.+]] = index_cast %1 + // CHECK-DAG: [[BROADCAST_MIN:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[BROADCAST_MAX:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: [[CLAMP:%.+]] = "mhlo.clamp"([[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]]) + %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + + // CHECK: return [[CLAMP]] + return %0 : tensor +} + //===----------------------------------------------------------------------===// // DiagPart //===----------------------------------------------------------------------===// @@ -2125,6 +2176,13 @@ func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { return %0 : tensor<*xf32> } +// CHECK-LABEL: func @invert_op_unranked +func @invert_op_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> { + // CHECK: "mhlo.not"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> + %0 = "tf.Invert"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> + return %0 : tensor<*xi32> +} + // CHECK-LABEL: @is_finite func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> { // CHECK: "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> @@ -3824,15 +3882,13 @@ func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) %k = "tf.Const"() {value = dense<8> : tensor} : () -> tensor // CHECK: %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} - // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]], %[[IOTA]]) ( { + // CHECK-NEXT: %[[SORT:.*]]:2 = "mhlo.sort"(%[[INPUT]], %[[IOTA]]) ( { // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor, %{{.*}}: tensor, %{{.*}}: tensor): // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[LHS]], %[[RHS]]) {comparison_direction = "GT"} // CHECK-NEXT: "mhlo.return"(%[[CMP]]) - // CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> - // CHECK-NEXT: %[[TUPL0:.*]] = "mhlo.get_tuple_element"(%[[SORT]]) {index = 0 : i32} - // CHECK-NEXT: %[[TUPL1:.*]] = "mhlo.get_tuple_element"(%[[SORT]]) {index = 1 : i32} - // CHECK-NEXT: %[[VAL:.*]] = "mhlo.slice"(%[[TUPL0]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} - // CHECK-NEXT: %[[IDX:.*]] = "mhlo.slice"(%[[TUPL1]]) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + // CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> 
(tensor<16x16xf32>, tensor<16x16xi32>) + // CHECK-NEXT: %[[VAL:.*]] = "mhlo.slice"(%[[SORT]]#0) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + // CHECK-NEXT: %[[IDX:.*]] = "mhlo.slice"(%[[SORT]]#1) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} // CHECK-NEXT: return %[[VAL]], %[[IDX]] %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor) -> (tensor<16x8xf32>, tensor<16x8xi32>) return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32> @@ -4199,12 +4255,11 @@ func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> { // CHECK: [[LOWER:%.*]] = mhlo.constant dense<0> : tensor // CHECK: [[UPPER:%.*]] = mhlo.constant dense<-1> : tensor // CHECK: [[RNG:%.*]] = "mhlo.rng_uniform"([[LOWER]], [[UPPER]], [[SHAPE]]) - // CHECK: [[SORT:%.*]] = "mhlo.sort"([[RNG]], [[INPUT]]) ( { + // CHECK: [[SORT:%.*]]:2 = "mhlo.sort"([[RNG]], [[INPUT]]) ( { // CHECK: ^{{.*}}([[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor, {{.*}}: tensor, {{.*}}: tensor): // CHECK: "mhlo.compare"([[ARG1]], [[ARG2]]) {comparison_direction = "LT"} - // CHECK: }) {dimension = -1 : i64, is_stable = true} : (tensor<16xi32>, tensor<16xf32>) -> tuple, tensor<16xf32>> - // CHECK: [[RES:%.*]] = "mhlo.get_tuple_element"([[SORT]]) {index = 1 : i32} - // CHECK: return [[RES]] + // CHECK: }) {dimension = -1 : i64, is_stable = true} : (tensor<16xi32>, tensor<16xf32>) -> (tensor<16xi32>, tensor<16xf32>) + // CHECK: return [[SORT]]#1 %0 = "tf.RandomShuffle"(%input) : (tensor<16xf32>) -> (tensor<16xf32>) return %0: tensor<16xf32> } @@ -4213,10 +4268,8 @@ func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> { func @random_shuffle_1D_10240(%input: tensor<10240xf32>) -> tensor<10240xf32> { // CHECK: mhlo.rng_uniform // CHECK: mhlo.sort - // CHECK: mhlo.get_tuple_element // CHECK: mhlo.rng_uniform // CHECK: mhlo.sort - // CHECK: mhlo.get_tuple_element %0 = "tf.RandomShuffle"(%input) : (tensor<10240xf32>) -> (tensor<10240xf32>) return %0: tensor<10240xf32> } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/case.mlir b/tensorflow/compiler/mlir/xla/tests/translate/case.mlir index 1032bb723c5..cea0599adb0 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/case.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/case.mlir @@ -1,10 +1,10 @@ // RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FILECHECK_OPTS="" FileCheck %s func @main() -> tensor { - %cst = constant {name = "constant"} dense<1> : tensor - %cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor - %cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor - %cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor + %cst = constant dense<1> : tensor + %cst_0 = constant dense<5.600000e+01> : tensor + %cst_1 = constant dense<1.200000e+01> : tensor + %cst_2 = constant dense<1.300000e+01> : tensor %0 = "mhlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( { ^bb0(%arg0: tensor): %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor @@ -17,7 +17,7 @@ func @main() -> tensor { ^bb0(%arg0: tensor): %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor "mhlo.return"(%1) : (tensor) -> () - }) {name = "conditional"} : (tensor, tensor, tensor, tensor) -> tensor + }) : (tensor, tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -48,23 +48,23 @@ func @main() -> tensor { // ----- func @main() -> (tensor, tensor) { - %cst = constant {name = "constant"} dense<1> : tensor - 
%cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor - %cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor - %cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor + %cst = constant dense<1> : tensor + %cst_0 = constant dense<5.600000e+01> : tensor + %cst_1 = constant dense<1.200000e+01> : tensor + %cst_2 = constant dense<1.300000e+01> : tensor %0:2 = "mhlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( { ^bb0(%arg0: tensor): - %1 = "mhlo.negate"(%arg0) {name = "negate"} : (tensor) -> tensor + %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor "mhlo.return"(%1, %1) : (tensor, tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "mhlo.copy"(%arg0) {name = "copy"} : (tensor) -> tensor + %1 = "mhlo.copy"(%arg0) : (tensor) -> tensor "mhlo.return"(%1, %1) : (tensor, tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "mhlo.floor"(%arg0) {name = "floor"} : (tensor) -> tensor + %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor "mhlo.return"(%1, %1) : (tensor, tensor) -> () - }) {name = "conditional"} : (tensor, tensor, tensor, tensor) -> (tensor, tensor) + }) : (tensor, tensor, tensor, tensor) -> (tensor, tensor) return %0#0, %0#1 : tensor, tensor } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt index 62f0d7a59e4..1fa7367763e 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt @@ -26,21 +26,21 @@ ENTRY %indexed_conditional () -> f32[] { } // CHECK-LABEL: func @main() -> tensor -// CHECK: %[[INDEX:.*]] = constant {name = "constant"} dense<1> : tensor -// CHECK: %[[OPERAND_1:.*]] = constant {name = "{{.*}}"} dense<5.600000e+01> : tensor -// CHECK: %[[OPERAND_2:.*]] = constant {name = "{{.*}}"} dense<1.200000e+01> : tensor -// CHECK: %[[OPERAND_3:.*]] = constant {name = "{{.*}}"} dense<1.300000e+01> : tensor +// CHECK: %[[INDEX:.*]] = constant dense<1> : tensor +// CHECK: %[[OPERAND_1:.*]] = constant dense<5.600000e+01> : tensor +// CHECK: %[[OPERAND_2:.*]] = constant dense<1.200000e+01> : tensor +// CHECK: %[[OPERAND_3:.*]] = constant dense<1.300000e+01> : tensor // CHECK: %[[RESULT:.*]] = "mhlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( { // CHECK: ^bb0(%[[ARG_1:.*]]: tensor): -// CHECK: %[[RES_1:.*]] = "mhlo.negate"(%[[ARG_1]]) {name = "{{.*}}"} : (tensor) -> tensor +// CHECK: %[[RES_1:.*]] = "mhlo.negate"(%[[ARG_1]]) : (tensor) -> tensor // CHECK: "mhlo.return"(%[[RES_1]]) : (tensor) -> () // CHECK: }, { // CHECK: ^bb0(%[[ARG_2:.*]]: tensor): -// CHECK: %[[RES_2:.*]] = "mhlo.copy"(%[[ARG_2]]) {name = "{{.*}}"} : (tensor) -> tensor +// CHECK: %[[RES_2:.*]] = "mhlo.copy"(%[[ARG_2]]) : (tensor) -> tensor // CHECK: "mhlo.return"(%[[RES_2]]) : (tensor) -> () // CHECK: }, { // CHECK: ^bb0(%[[ARG_3:.*]]: tensor): -// CHECK: %[[RES_3:.*]] = "mhlo.floor"(%[[ARG_3]]) {name = "{{.*}}"} : (tensor) -> tensor +// CHECK: %[[RES_3:.*]] = "mhlo.floor"(%[[ARG_3]]) : (tensor) -> tensor // CHECK: "mhlo.return"(%[[RES_3]]) : (tensor) -> () -// CHECK: }) {name = "{{.*}}"} : (tensor, tensor, tensor, tensor) -> tensor +// CHECK: }) : (tensor, tensor, tensor, tensor) -> tensor // CHECK: return %[[RESULT]] : tensor diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index ff1bcadda7b..c078191d170 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ 
b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -439,7 +439,7 @@ func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10 // CHECK-SAME: index_vector_dim=1 // CHECK-SAME: slice_sizes={1,1,300} // CHECK-SAME: indices_are_sorted=true - %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = true, name = "gather", slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>} : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32> + %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = true, slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>} : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32> return %0 : tensor<10x300xf32> } @@ -502,7 +502,7 @@ func @main() -> tensor<1x10xf32> { func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): // no predecessors - %1 = mhlo.add %arg2, %arg3 {name = "add"} : tensor + %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> @@ -739,7 +739,7 @@ func @main(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> // CHECK: %[[ARG2:.*]] = s32[2,3] parameter(2) // CHECK: ROOT %[[RES:.*]] = s32[2,3] select(pred[2,3] %[[COND]], s32[2,3] %[[ARG1]], s32[2,3] %[[ARG2]]) - %0 = "mhlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -948,19 +948,20 @@ func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { // CHECK: HloModule func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } // CHECK: %[[SORT_CMP:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[], {{.*}}: s32[], {{.*}}: s32[]) -> pred[] { // CHECK: ROOT %compare.8 = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT -// CHECK: ENTRY %{{.*}} ([[MAIN_ARG0:.*]]: f32[16,16], [[MAIN_ARG1:.*]]: s32[16,16]) -> (f32[16,16], s32[16,16]) { -// CHECK: ROOT %{{.*}} = (f32[16,16], s32[16,16]) sort(f32[16,16] %[[MAIN_ARG0]], s32[16,16] %[[MAIN_ARG1]]), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]] +// CHECK: [[SORT:%.+]] = (f32[16,16], s32[16,16]) sort(f32[16,16] %Arg_0.1, s32[16,16] %Arg_1.2), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]] +// CHECK: [[GET0:%.+]] = f32[16,16] get-tuple-element((f32[16,16], s32[16,16]) [[SORT]]), index=0 +// CHECK: ROOT [[GET1:%.+]] = s32[16,16] 
get-tuple-element((f32[16,16], s32[16,16]) [[SORT]]), index=1 // ----- @@ -1101,3 +1102,33 @@ func @main(%arg: tensor<3xui64>) -> tuple, tensor<2x2xui32>> { %0 = "mhlo.rng_bit_generator"(%arg) {rng_algorithm = 2 : i32} : (tensor<3xui64>) -> tuple, tensor<2x2xui32>> return %0 : tuple, tensor<2x2xui32>> } + +// ----- + +// CHECK: HloModule +func @main(%arg: tensor<3x4xf32>) -> tensor<3x4xf32> { +// CHECK: %[[ARG0:.*]] = f32[3,4] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[3,4] cbrt(f32[3,4] %[[ARG0]]) + %0 = "mhlo.cbrt"(%arg) : (tensor<3x4xf32>) -> tensor<3x4xf32> + return %0 : tensor<3x4xf32> +} + +// ----- + +// CHECK: HloModule +func @main(%arg: tensor<3x4xf32>) -> tensor<3x4xf32> { +// CHECK: %[[ARG0:.*]] = f32[3,4] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[3,4] reduce-precision(f32[3,4] %[[ARG0]]), exponent_bits=8, mantissa_bits=10 + %0 = "mhlo.reduce_precision"(%arg) {exponent_bits = 8 : i32, mantissa_bits = 10 : i32} : (tensor<3x4xf32>) -> tensor<3x4xf32> + return %0 : tensor<3x4xf32> +} + +// ----- + +// CHECK: HloModule +func @main(%arg: tensor<3x4xf32>) -> tensor<3x4x1xf32> { +// CHECK: %[[ARG0:.*]] = f32[3,4] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[3,4,1] bitcast(f32[3,4] %[[ARG0]]) + %0 = "mhlo.bitcast"(%arg) : (tensor<3x4xf32>) -> tensor<3x4x1xf32> + return %0 : tensor<3x4x1xf32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt index 86adcf0710f..4cc70be0965 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt @@ -9,95 +9,95 @@ ENTRY %tfcompile.48 { %arg0.1 = f32[1,300] parameter(0) %arg1.2 = f32[1,300,3,1] parameter(1) - // CHECK-NEXT: %0 = "mhlo.reshape"(%arg0) {name = "reshape.3"} : (tensor<1x300xf32>) -> tensor<1x300xf32> + // CHECK-NEXT: %0 = "mhlo.reshape"(%arg0) : (tensor<1x300xf32>) -> tensor<1x300xf32> %reshape.3 = f32[1,300] reshape(%arg0.1) - // CHECK-NEXT: %1 = "mhlo.transpose"(%0) {name = "transpose.27", permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32> + // CHECK-NEXT: %1 = "mhlo.transpose"(%0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32> %transpose.27 = f32[300,1] transpose(%reshape.3), dimensions={1,0} - // CHECK-NEXT: %2 = "mhlo.reshape"(%1) {name = "reshape.28"} : (tensor<300x1xf32>) -> tensor<300x1x1xf32> + // CHECK-NEXT: %2 = "mhlo.reshape"(%1) : (tensor<300x1xf32>) -> tensor<300x1x1xf32> %reshape.28 = f32[300,1,1] reshape(%transpose.27) - // CHECK-NEXT: %3 = "mhlo.reshape"(%2) {name = "reshape.29"} : (tensor<300x1x1xf32>) -> tensor<300x1xf32> + // CHECK-NEXT: %3 = "mhlo.reshape"(%2) : (tensor<300x1x1xf32>) -> tensor<300x1xf32> %reshape.29 = f32[300,1] reshape(%reshape.28) - // CHECK-NEXT: %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.30"} : (tensor<300x1xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<300x1xf32>) -> tensor<300x1x5xf32> %broadcast.30 = f32[300,1,5] broadcast(%reshape.29), dimensions={0,1} - // CHECK-NEXT: %cst = constant {name = "constant.8"} dense<1.000000e+00> : tensor + // CHECK-NEXT: %cst = constant dense<1.000000e+00> : tensor %constant.8 = f32[] constant(1) - // CHECK-NEXT: %5 = "mhlo.broadcast_in_dim"(%cst) 
{broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.9"} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %5 = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<300x1x5xf32> %broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={} - // CHECK-NEXT: %6 = mhlo.multiply %4, %5 {name = "multiply.31"} : tensor<300x1x5xf32> + // CHECK-NEXT: %6 = mhlo.multiply %4, %5 : tensor<300x1x5xf32> %multiply.31 = f32[300,1,5] multiply(%broadcast.30, %broadcast.9) - // CHECK-NEXT: %cst_0 = constant {name = "constant.32"} dense<0.000000e+00> : tensor + // CHECK-NEXT: %cst_0 = constant dense<0.000000e+00> : tensor %constant.32 = f32[] constant(0) - // CHECK-NEXT: %7 = "mhlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.33"} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %7 = "mhlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<300x1x5xf32> %broadcast.33 = f32[300,1,5] broadcast(%constant.32), dimensions={} - // CHECK-NEXT: %8 = "mhlo.compare"(%6, %7) {comparison_direction = "GT", name = "compare.34"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1> + // CHECK-NEXT: %8 = "mhlo.compare"(%6, %7) {comparison_direction = "GT"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1> %compare.34 = pred[300,1,5] compare(%multiply.31, %broadcast.33), direction=GT - // CHECK-NEXT: %cst_1 = constant {name = "constant.10"} dense<0.000000e+00> : tensor + // CHECK-NEXT: %cst_1 = constant dense<0.000000e+00> : tensor %constant.10 = f32[] constant(0) - // CHECK-NEXT: %9 = "mhlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.11"} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %9 = "mhlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<300x1x5xf32> %broadcast.11 = f32[300,1,5] broadcast(%constant.10), dimensions={} - // CHECK-NEXT: %cst_2 = constant {name = "constant.40"} dense<0.000000e+00> : tensor + // CHECK-NEXT: %cst_2 = constant dense<0.000000e+00> : tensor %constant.40 = f32[] constant(0) - // CHECK-NEXT: %10 = "mhlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.41"} : (tensor) -> tensor<300x5xf32> + // CHECK-NEXT: %10 = "mhlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<300x5xf32> %broadcast.41 = f32[300,5] broadcast(%constant.40), dimensions={} - // CHECK-NEXT: %11 = "mhlo.copy"(%arg1) {name = "copy.1"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> + // CHECK-NEXT: %11 = "mhlo.copy"(%arg1) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> %copy.1 = f32[1,300,3,1] copy(%arg1.2) - // CHECK-NEXT: %12 = "mhlo.reshape"(%11) {name = "reshape.4"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> + // CHECK-NEXT: %12 = "mhlo.reshape"(%11) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> %reshape.4 = f32[1,300,3,1] reshape(%copy.1) - // CHECK-NEXT: %13 = "mhlo.reshape"(%12) {name = "reshape.24"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32> + // CHECK-NEXT: %13 = "mhlo.reshape"(%12) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32> %reshape.24 = f32[1,300,3] reshape(%reshape.4) - // CHECK-NEXT: %14 = "mhlo.transpose"(%13) {name = "transpose.25", permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32> + // CHECK-NEXT: %14 = "mhlo.transpose"(%13) {permutation = dense<[1, 0, 2]> : 
tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32> %transpose.25 = f32[300,1,3] transpose(%reshape.24), dimensions={1,0,2} - // CHECK-NEXT: %15 = "mhlo.reshape"(%14) {name = "reshape.26"} : (tensor<300x1x3xf32>) -> tensor<300x3xf32> + // CHECK-NEXT: %15 = "mhlo.reshape"(%14) : (tensor<300x1x3xf32>) -> tensor<300x3xf32> %reshape.26 = f32[300,3] reshape(%transpose.25) - // CHECK-NEXT: %cst_3 = constant {name = "constant.35"} dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32> + // CHECK-NEXT: %cst_3 = constant dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32> %constant.35 = f32[3,5] constant({ { -0.106023, 0.121505, 0.800239, -0.768885, 0.0966113 }, { 0.689014, -0.407056, -0.797853, 0.00378925, -0.208881 }, { -0.608529, 0.0276617, 0.268557, 0.577401, -0.428437 } }) // TODO(b/129709049) consider making this default precision config implied. - // CHECK-NEXT: %16 = "mhlo.dot"(%15, %cst_3) {name = "dot.36", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32> + // CHECK-NEXT: %16 = "mhlo.dot"(%15, %cst_3) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32> %dot.36 = f32[300,5] dot(%reshape.26, %constant.35), lhs_contracting_dims={1}, rhs_contracting_dims={0} - // CHECK-NEXT: %cst_4 = constant {name = "constant.37"} dense<0.000000e+00> : tensor<5xf32> + // CHECK-NEXT: %cst_4 = constant dense<0.000000e+00> : tensor<5xf32> %constant.37 = f32[5]{0} constant({0, 0, 0, 0, 0}) - // CHECK-NEXT: %17 = "mhlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.38"} : (tensor<5xf32>) -> tensor<300x5xf32> + // CHECK-NEXT: %17 = "mhlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<300x5xf32> %broadcast.38 = f32[300,5] broadcast(%constant.37), dimensions={1} - // CHECK-NEXT: %18 = mhlo.add %16, %17 {name = "add.39"} : tensor<300x5xf32> + // CHECK-NEXT: %18 = mhlo.add %16, %17 : tensor<300x5xf32> %add.39 = f32[300,5] add(%dot.36, %broadcast.38) - // CHECK-NEXT: %19 = mhlo.maximum %10, %18 {name = "maximum.42"} : tensor<300x5xf32> + // CHECK-NEXT: %19 = mhlo.maximum %10, %18 : tensor<300x5xf32> %maximum.42 = f32[300,5] maximum(%broadcast.41, %add.39) - // CHECK-NEXT: %20 = "mhlo.reshape"(%19) {name = "reshape.44"} : (tensor<300x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %20 = "mhlo.reshape"(%19) : (tensor<300x5xf32>) -> tensor<300x1x5xf32> %reshape.44 = f32[300,1,5] reshape(%maximum.42) - // CHECK-NEXT: %21 = "mhlo.select"(%8, %9, %20) {name = "select.45"} : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %21 = "mhlo.select"(%8, %9, %20) : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32> %select.45 = f32[300,1,5] select(%compare.34, %broadcast.11, %reshape.44) - // CHECK-NEXT: %22 = "mhlo.reshape"(%21) {name = "reshape.46"} : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %22 = "mhlo.reshape"(%21) : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32> %reshape.46 = f32[300,1,5] reshape(%select.45) - // 
CHECK-NEXT: %23 = "mhlo.tuple"(%22) {name = "tuple.47"} : (tensor<300x1x5xf32>) -> tuple> + // CHECK-NEXT: %23 = "mhlo.tuple"(%22) : (tensor<300x1x5xf32>) -> tuple> // CHECK-NEXT: return %23 : tuple> ROOT %tuple.47 = (f32[300,1,5]) tuple(%reshape.46) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index 4d4e0213da8..cce49b16c6c 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -13,12 +13,12 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { %Arg_0.1 = f32[4]{0} parameter(0) %Arg_1.2 = f32[4]{0} parameter(1) - // CHECK-NEXT: mhlo.add %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> - %add.3 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2) + // CHECK-NEXT: mhlo.add %arg0, %arg1 : tensor<4xf32> + %add.42 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2) // TODO(b/129709049) consider making this default precision config inferred. - // CHECK-NEXT: "mhlo.dot"(%0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor - ROOT %dot.4 = f32[] dot(f32[4]{0} %add.3, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} + // CHECK-NEXT: "mhlo.dot"(%0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor + ROOT %dot.4 = f32[] dot(f32[4]{0} %add.42, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} } // CHECK-LABEL: func @test_after_all @@ -26,7 +26,7 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { %test_after_all (token0: token[], token1: token[] ) -> token[] { token0 = token[] parameter(0) token1 = token[] parameter(1) - // CHECK-NEXT: "mhlo.after_all"([[VAL_0]], [[VAL_1]]) {name = "{{.*}}"} : (!mhlo.token, !mhlo.token) -> !mhlo.token + // CHECK-NEXT: "mhlo.after_all"([[VAL_0]], [[VAL_1]]) : (!mhlo.token, !mhlo.token) -> !mhlo.token ROOT after-all = token[] after-all(token0, token1) } @@ -75,10 +75,10 @@ add { %test_broadcast_in_dim { %Arg_0.1 = f32[1, 2] parameter(0) - // CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<1x2x3xf32> + // CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<1x2x3xf32> %broadcast.2 = f32[1,2,3] broadcast(%Arg_0.1), dimensions={0,1} - // CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<3x1x2xf32> + // CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<3x1x2xf32> ROOT broadcast.4 = f32[3,1,2] broadcast(%Arg_0.1), dimensions={1, 2} } @@ -113,7 +113,7 @@ add { // CHECK-SAME: ([[ARG:%.*]]: tensor<1x291x291xf32>) -> tensor<1x291x291xf32> %test_cholesky (a: f32[1,291,291]) -> f32[1,291,291] { %a = f32[1,291,291] parameter(0) - // CHECK-NEXT: "mhlo.cholesky"([[ARG]]) {lower = true, name = {{.*}}} : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32> + // CHECK-NEXT: "mhlo.cholesky"([[ARG]]) {lower = true} : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32> ROOT %out = f32[1,291,291] cholesky(f32[1,291,291] %a), lower=true } @@ -124,16 +124,16 @@ add { %Arg_1.2 = f32[4] parameter(1) %Arg_2.3 = f32[] parameter(2) - // CHECK-NEXT: "mhlo.clamp"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor, tensor<4xf32>, tensor) -> tensor<4xf32> 
+ // CHECK-NEXT: "mhlo.clamp"(%arg0, %arg1, %arg2) : (tensor, tensor<4xf32>, tensor) -> tensor<4xf32> ROOT %clamp.3 = f32[4] clamp(f32[] %Arg_0.1, f32[4] %Arg_1.2, f32[] %Arg_2.3) } // CHECK-LABEL: func @test_collective_permute // CHECK-SAME: ([[ARG:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> %test_collective_permute (input: f32[128,32]) -> f32[128,32] { - %input = f32[128,32]{0,1} parameter(0) - // CHECK-NEXT: "mhlo.collective_permute"([[ARG]]) {name = {{.*}}, source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>} : (tensor<128x32xf32>) -> tensor<128x32xf32> - ROOT root = f32[128,32]{0,1} collective-permute(%input), source_target_pairs={{0,1},{1,2},{2,3}} + %input = f32[128,32]{1,0} parameter(0) + // CHECK-NEXT: "mhlo.collective_permute"([[ARG]]) {source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>} : (tensor<128x32xf32>) -> tensor<128x32xf32> + ROOT root = f32[128,32]{1,0} collective-permute(%input), source_target_pairs={{0,1},{1,2},{2,3}} } @@ -143,14 +143,14 @@ add { %Arg_1.2 = f32[3] parameter(1) %Arg_2.3 = f32[3] parameter(2) - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> %compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> %compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE // Requires broadcast of compatible tensors. - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg2) {comparison_direction = "GT"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT } @@ -159,7 +159,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: "mhlo.complex"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> + // CHECK-NEXT: "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> ROOT %complex.3 = c64[4] complex(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -176,12 +176,12 @@ add { %test_constant { // Scalar/0D tensor constant - // CHECK-NEXT: %cst = constant {name = "{{.*}}"} dense<1> : tensor + // CHECK-NEXT: %cst = constant dense<1> : tensor %constant.0 = s64[] constant(1) // Note that double brackets "[[" have to be escaped as they denote variables // in FileCheck. 
The only way to do so is to drop into regex with "{{" - // CHECK-NEXT: constant {name = "{{.*}}"} dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32> + // CHECK-NEXT: constant dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32> %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} // CHECK: dense<[1, 2, 4, 8]> : tensor<4xui64> @@ -206,15 +206,15 @@ add { %test_conv { %arg0.1 = f32[256,32,32,6]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: %0 = "mhlo.copy"(%arg0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> + // CHECK-NEXT: %0 = "mhlo.copy"(%arg0) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> %copy.1 = f32[256,32,32,6]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"} - // CHECK-NEXT: %1 = "mhlo.reshape"(%0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> + // CHECK-NEXT: %1 = "mhlo.reshape"(%0) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> %reshape.2 = f32[256,32,32,6]{2,1,3,0} reshape(%copy.1) // Note that double brackets "[[" have to be escaped as they denote variables // in FileCheck. The only way to do so is to drop into regex with "{{" - // CHECK-NEXT: %cst = constant {name = "{{.*}}"} dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32> + // CHECK-NEXT: %cst = constant dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32> %constant.3 = f32[2,2,1,1]{3,2,1,0} constant({{{{0.5}}, {{-0.6}}}, {{{0.3}}, {{-0.1}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} // CHECK-NEXT: %2 = "mhlo.convolution"(%1, %cst) { @@ -241,10 +241,10 @@ add { %convolution.4 = f32[16,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} - // CHECK-NEXT: %3 = "mhlo.reshape"(%2) {name = "{{.*}}"} : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32> + // CHECK-NEXT: %3 = "mhlo.reshape"(%2) : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32> %reshape.5 = f32[256,30,30,16]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"} - // CHECK-NEXT: "mhlo.tuple"(%3) {name = "{{.*}}"} : (tensor<256x30x30x16xf32>) -> tuple> + // CHECK-NEXT: "mhlo.tuple"(%3) : (tensor<256x30x30x16xf32>) -> tuple> ROOT %tuple.6 = (f32[256,30,30,16]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"} } @@ -263,10 +263,10 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: %0 = "mhlo.convert"(%arg0) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64> + // CHECK-NEXT: %0 = "mhlo.convert"(%arg0) : (tensor<4xf32>) -> tensor<4xf64> %convert.3 = f64[4] convert(f32[4] %Arg_0.1) - // CHECK-NEXT: %1 = "mhlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64> + // CHECK-NEXT: %1 = "mhlo.convert"(%arg1) : (tensor<4xf32>) -> tensor<4xf64> %convert.4 = f64[4] convert(f32[4] %Arg_1.2) // CHECK-NEXT: mhlo.add %0, %1 @@ -277,7 
+277,7 @@ add { %test_cosine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: "mhlo.cosine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> + // CHECK-NEXT: "mhlo.cosine"(%arg0) : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> ROOT %cosine.3 = f32[1,16,16,3]{3,2,1,0} cosine(f32[1,16,16,3]{3,2,1,0} %arg0.1) } @@ -286,7 +286,7 @@ add { %test_custom_call (arg1: f32[2,3], arg2: f32[5,5]) -> f32[1,2,3] { %arg1 = f32[2,3] parameter(0) %arg2 = f32[5,5] parameter(1) -// CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {backend_config = "bar", call_target_name = "foo", has_side_effect = true, name = {{.*}}} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32> +// CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {backend_config = "bar", call_target_name = "foo", has_side_effect = true, minor_to_major = {{.*}}} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32> ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[2,3] %arg1, f32[5,5] %arg2), custom_call_target="foo", backend_config="bar", custom_call_has_side_effect=true } @@ -295,7 +295,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: mhlo.divide %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.divide %arg0, %arg1 : tensor<4xf32> ROOT %divide.3 = f32[4] divide(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -304,17 +304,17 @@ add { %Arg_0.1 = f32[1, 4] parameter(0) %Arg_1.2 = f32[4, 1] parameter(1) - // CHECK-NEXT: %0 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={high,highest} - // CHECK-NEXT: %1 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: %1 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,default} - // CHECK-NEXT: %2 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: %2 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor %dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={default,default} // TODO(b/129709049) consider making this default precision config inferred. 
- // CHECK-NEXT: "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} } @@ -325,17 +325,17 @@ add { %Arg_0.1 = f32[4, 1] parameter(0) %Arg_1.2 = f32[1, 4] parameter(1) - // CHECK-NEXT: [[R0:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]} + // CHECK-NEXT: [[R0:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGH", "HIGHEST"]} dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={high,highest} - // CHECK-NEXT: [[R1:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]} + // CHECK-NEXT: [[R1:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "DEFAULT"]} dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={highest,default} - // CHECK-NEXT: [[R2:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} + // CHECK-NEXT: [[R2:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} %dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={default,default} // TODO(b/129709049) consider making this default precision config inferred. 
- // CHECK-NEXT: "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} + // CHECK-NEXT: "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1} } @@ -376,7 +376,7 @@ add { %test_exponential (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK-NEXT: "mhlo.exponential"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: "mhlo.exponential"(%arg0) : (tensor<16xf32>) -> tensor<16xf32> ROOT %exp.2 = f32[16] exponential(f32[16] %arg0.1) } @@ -384,7 +384,7 @@ add { %test_expm1 (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK: "mhlo.exponential_minus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK: "mhlo.exponential_minus_one"(%arg0) : (tensor<16xf32>) -> tensor<16xf32> ROOT %expm1.2 = f32[16] exponential-minus-one(f32[16] %arg0.1) } @@ -400,7 +400,7 @@ add { %test_floor (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK-NEXT: "mhlo.floor"([[A0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: "mhlo.floor"([[A0]]) : (tensor<16xf32>) -> tensor<16xf32> ROOT %floor.2 = f32[16] floor(f32[16] %arg0.1) } @@ -430,7 +430,7 @@ add { // CHECK-SAME: ([[ARG:%.*]]: tensor<4x2xf32>) %test_get_dimension_size (Arg_0.1: f32[4,2]) -> s32[] { %Arg_0.1 = f32[4,2] parameter(0) - // CHECK-NEXT: "mhlo.get_dimension_size"([[ARG]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x2xf32>) -> tensor + // CHECK-NEXT: "mhlo.get_dimension_size"([[ARG]]) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor ROOT %get-dimension-size.2 = s32[] get-dimension-size(f32[4,2] %Arg_0.1), dimensions={1} } @@ -438,7 +438,7 @@ add { %test_imag (Arg_0.1: c64[4]) -> f32[4] { %Arg_0.1 = c64[4] parameter(0) - // CHECK-NEXT: "mhlo.imag"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex>) -> tensor<4xf32> + // CHECK-NEXT: "mhlo.imag"(%arg0) : (tensor<4xcomplex>) -> tensor<4xf32> ROOT %imag.3 = f32[4] imag(c64[4] %Arg_0.1) } @@ -468,7 +468,7 @@ add { %test_log (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK-NEXT: "mhlo.log"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: "mhlo.log"(%arg0) : (tensor<16xf32>) -> tensor<16xf32> ROOT %log.2 = f32[16] log(f32[16] %arg0.1) } @@ -476,7 +476,7 @@ add { %test_log1p (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK: "mhlo.log_plus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK: "mhlo.log_plus_one"(%arg0) : (tensor<16xf32>) -> tensor<16xf32> ROOT %log1p.2 = f32[16] log-plus-one(f32[16] %arg0.1) } @@ -507,7 +507,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: mhlo.maximum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.maximum %arg0, %arg1 : tensor<4xf32> ROOT %maximum.3 = f32[4] maximum(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -516,7 +516,7 @@ add { %Arg_0.1 = f32[4] 
parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: mhlo.minimum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.minimum %arg0, %arg1 : tensor<4xf32> ROOT %minimum.3 = f32[4] minimum(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -525,7 +525,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: %0 = mhlo.multiply %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: %0 = mhlo.multiply %arg0, %arg1 : tensor<4xf32> ROOT %multiply.3 = f32[4] multiply(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -533,7 +533,7 @@ add { %test_negate (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK-NEXT: "mhlo.negate"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: "mhlo.negate"(%arg0) : (tensor<16xf32>) -> tensor<16xf32> ROOT %negate.2 = f32[16] negate(f32[16] %arg0.1) } @@ -541,7 +541,7 @@ add { %test_not (arg0.1: pred[16]) -> pred[16] { %arg0.1 = pred[16] parameter(0) - // CHECK: "mhlo.not"(%arg0) {name = "{{.*}}"} : (tensor<16xi1>) -> tensor<16xi1> + // CHECK: "mhlo.not"(%arg0) : (tensor<16xi1>) -> tensor<16xi1> ROOT %not.2 = pred[16] not(pred[16] %arg0.1) } @@ -595,7 +595,7 @@ add { %test_popcnt (arg0.1: s32[16]) -> s32[16] { %arg0.1 = s32[16] parameter(0) - // CHECK: "mhlo.popcnt"(%arg0) {name = "{{.*}}"} : (tensor<16xi32>) -> tensor<16xi32> + // CHECK: "mhlo.popcnt"(%arg0) : (tensor<16xi32>) -> tensor<16xi32> ROOT %popcnt.2 = s32[16] popcnt(s32[16] %arg0.1) } @@ -604,7 +604,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: mhlo.power %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.power %arg0, %arg1 : tensor<4xf32> ROOT %power.3 = f32[4] power(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -632,7 +632,7 @@ add { %test_real (Arg_0.1: c64[4]) -> f32[4] { %Arg_0.1 = c64[4] parameter(0) - // CHECK-NEXT: "mhlo.real"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex>) -> tensor<4xf32> + // CHECK-NEXT: "mhlo.real"(%arg0) : (tensor<4xcomplex>) -> tensor<4xf32> ROOT %real.3 = f32[4] real(c64[4] %Arg_0.1) } @@ -687,7 +687,7 @@ add { // CHECK: {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor %reduce.4 = f32[] reduce(%reduce.2, %Arg_2.3), dimensions={0}, to_apply=%reduce_helper.3 - // CHECK: %4 = mhlo.subtract [[VAL2]], [[VAL4]] {name = "{{.*}}"} : tensor + // CHECK: %4 = mhlo.subtract [[VAL2]], [[VAL4]] : tensor %sub.5 = f32[] subtract(%reduce.3, %reduce.4) ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.1, %sub.5) @@ -741,7 +741,7 @@ add { %test_rsqrt (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK: "mhlo.rsqrt"([[ARG0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK: "mhlo.rsqrt"([[ARG0]]) : (tensor<16xf32>) -> tensor<16xf32> ROOT %rsqrt.2 = f32[16] rsqrt(f32[16] %arg0.1) } @@ -788,7 +788,7 @@ add { %Arg_1.2 = s32[2,3] parameter(1) %Arg_2.3 = s32[2,3] parameter(2) - // CHECK: "mhlo.select"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + // CHECK: "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3) } @@ -835,7 +835,7 @@ add { %test_set_dimension_size (Arg_0.1: f32[4,4], Arg_1.2: s32[]) -> f32[4,<=4] { %Arg_0.1 = f32[4,4] parameter(0) %Arg_1.2 = s32[] parameter(1) - // CHECK-NEXT: "mhlo.set_dimension_size"([[ARG]], [[SIZE]]) {dimension = 1 : i32, name = "{{.*}}"} : 
(tensor<4x4xf32>, tensor) -> tensor<4x4xf32> + // CHECK-NEXT: "mhlo.set_dimension_size"([[ARG]], [[SIZE]]) {dimension = 1 : i32} : (tensor<4x4xf32>, tensor) -> tensor<4x4xf32> ROOT %set-dimension-size.2 = f32[4,<=4] set-dimension-size(f32[4,4] %Arg_0.1, s32[] %Arg_1.2), dimensions={1} } @@ -843,7 +843,7 @@ add { %test_sine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: "mhlo.sine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> + // CHECK-NEXT: "mhlo.sine"(%arg0) : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> ROOT %sine.3 = f32[1,16,16,3]{3,2,1,0} sine(f32[1,16,16,3]{3,2,1,0} %arg0.1) } @@ -862,7 +862,7 @@ add { // CHECK-SAME: [[ARG:%.*]]: tensor<1024xf32>) -> tensor<1024xf32> // CHECK: "mhlo.sort"([[ARG]]) ( { // CHECK: ^bb0([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor): -// CHECK: [[CMP:%.*]] = "mhlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "LT", name = "lt"} : (tensor, tensor) -> tensor +// CHECK: [[CMP:%.*]] = "mhlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor // CHECK: "mhlo.return"([[CMP]]) : (tensor) -> () // CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<1024xf32>) -> tensor<1024xf32> @@ -871,7 +871,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: mhlo.subtract %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.subtract %arg0, %arg1 : tensor<4xf32> ROOT %subtract.3 = f32[4] subtract(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -879,7 +879,7 @@ add { %test_tanh (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: "mhlo.tanh"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> + // CHECK-NEXT: "mhlo.tanh"(%arg0) : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> ROOT %tanh.3 = f32[1,16,16,3]{3,2,1,0} tanh(f32[1,16,16,3]{3,2,1,0} %arg0.1), metadata={op_type="Tanh" op_name="embedded_inference/tanh_model/Tanh"} } @@ -887,7 +887,7 @@ add { %test_transpose { %Arg_0.1 = s32[1,2,3,4] parameter(0) - // CHECK: "mhlo.transpose"(%arg0) {name = "{{.*}}", permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + // CHECK: "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2} } @@ -909,10 +909,10 @@ add { %Arg_0.1 = s32[1] parameter(0) %Arg_1.2 = f32[1, 2] parameter(1) - // CHECK-NEXT: %0 = "mhlo.tuple"(%arg0) {name = "{{.*}}"} : (tensor<1xi32>) -> tuple> + // CHECK-NEXT: %0 = "mhlo.tuple"(%arg0) : (tensor<1xi32>) -> tuple> %tuple.3 = (s32[1]) tuple(%Arg_0.1) - // CHECK: "mhlo.tuple"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> + // CHECK: "mhlo.tuple"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> ROOT %tuple.4 = (s32[1], f32[1,2]) tuple(%Arg_0.1, %Arg_1.2) } @@ -934,11 +934,11 @@ add { %arg0.1 = s64[] parameter(0), metadata={op_name="HLO_Args"} // CHECK-NEXT: "mhlo.while"(%arg0) ( { // CHECK-NEXT: ^bb0(%arg1: tensor): // no predecessors - // CHECK-NEXT: [[CMP:%.*]] = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "{{.*}}"} : (tensor, tensor) -> tensor + // CHECK-NEXT: [[CMP:%.*]] = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, 
tensor) -> tensor // CHECK-NEXT: "mhlo.return"([[CMP]]) : (tensor) -> () // CHECK-NEXT: }, { // CHECK-NEXT: ^bb0(%arg1: tensor): // no predecessors - // CHECK-NEXT: [[ADD:%.*]] = mhlo.add %arg1, %arg1 {name = "{{.*}}"} : tensor + // CHECK-NEXT: [[ADD:%.*]] = mhlo.add %arg1, %arg1 : tensor // CHECK-NEXT: "mhlo.return"([[ADD]]) : (tensor) -> () // CHECK-NEXT: }) : (tensor) -> tensor ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond @@ -992,8 +992,8 @@ add { %Arg_1.2 = c128[2] parameter(1) %abs.4 = f64[2] abs(c128[2] %Arg_1.2) - // CHECK: "mhlo.abs"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<2xcomplex>) -> tensor<2xf32> - // CHECK: "mhlo.abs"(%[[ARG1]]) {name = "{{.*}}"} : (tensor<2xcomplex>) -> tensor<2xf64> + // CHECK: "mhlo.abs"(%[[ARG0]]) : (tensor<2xcomplex>) -> tensor<2xf32> + // CHECK: "mhlo.abs"(%[[ARG1]]) : (tensor<2xcomplex>) -> tensor<2xf64> ROOT %tuple.5 = (f32[2], f64[2]) tuple(f32[2] %abs.3, f64[2] %abs.4) } @@ -1002,7 +1002,7 @@ add { %unsigned_int(Arg_0.1: u16[4]) -> u16[4] { %Arg_0.1 = u16[4] parameter(0) - // CHECK: "mhlo.not"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<4xui16>) -> tensor<4xui16> + // CHECK: "mhlo.not"(%[[ARG0]]) : (tensor<4xui16>) -> tensor<4xui16> ROOT %not.2 = u16[4] not(u16[4] %Arg_0.1) } @@ -1014,3 +1014,26 @@ add { ROOT %rng-bit-generator.2 = (u64[3], u32[2,2]) rng-bit-generator(u64[3] %Arg_0.1), algorithm=rng_philox } +// CHECK-LABEL: func @cbrt +// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xf32>) +%cbrt (Arg_0.1: f32[3,4]) -> f32[3,4] { + %Arg_0.1 = f32[3,4] parameter(0) + // CHECK: "mhlo.cbrt"(%[[ARG0]]) : (tensor<3x4xf32>) -> tensor<3x4xf32> + ROOT %cbrt = f32[3,4] cbrt(f32[3,4] %Arg_0.1) +} + +// CHECK-LABEL: func @bitcast +// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xf32>) -> tensor<3x4x1xf32> +%bitcast (Arg_0.1: f32[3,4]) -> f32[3,4,1] { + %Arg_0.1 = f32[3,4] parameter(0) + // CHECK: "mhlo.bitcast"(%[[ARG0]]) : (tensor<3x4xf32>) -> tensor<3x4x1xf32> + ROOT %bitcast = f32[3,4,1] bitcast(f32[3,4] %Arg_0.1) +} + +// CHECK-LABEL: func @reduce_precision +// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xf32>) +%reduce_precision (Arg_0.1: f32[3,4]) -> f32[3,4] { + %Arg_0.1 = f32[3,4] parameter(0) + // CHECK: "mhlo.reduce_precision"(%[[ARG0]]) {exponent_bits = 8 : i32, mantissa_bits = 10 : i32} : (tensor<3x4xf32>) -> tensor<3x4xf32> + ROOT %reduce_precision = f32[3,4] reduce-precision(f32[3,4] %Arg_0.1), exponent_bits=8, mantissa_bits=10 +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.hlotxt new file mode 100644 index 00000000000..da07dc0a76b --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.hlotxt @@ -0,0 +1,11 @@ +// RUN: tf-mlir-translate -mlir-print-debuginfo -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule Test + +// CHECK-LABEL: func @main +ENTRY A { + %input = f16[128,224,224,4] parameter(0) + %filter = f16[64,7,7,4] parameter(1) + // %0 = "mhlo.convolution"{{.*}}minor_to_major = dense<[1, 3, 2, 0]> : tensor<4xindex>{{.*}} loc("root.42") + ROOT %root.42 = f16[128,64,112,112]{1,3,2,0} convolution(%input, %filter), dim_labels=b01f_o01i->bf01, window={size=7x7 stride=2x2 pad=3_3x3_3} +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir b/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir new file mode 100644 index 00000000000..2ef0aaf3f50 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir @@ -0,0 +1,34 @@ +// RUN: tf-mlir-translate 
-mlir-hlo-to-hlo-text-with-layouts %s | FileCheck %s + +// Checks exporting layouts + +// CHECK: HloModule +func @main(%arg0: tensor<128x224x224x4xf16>, %arg1: tensor<64x7x7x4xf16>) -> tensor<128x64x112x112xf16> { + // CHECK: %convolution.{{.*}} = f16[128,64,112,112]{1,3,2,0} convolution{{.*}}op_name="root.42" + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = { + input_batch_dimension = 0 : i64, + input_feature_dimension = 3 : i64, + input_spatial_dimensions = dense<[ 1, 2 ]> : tensor<2xi64>, + kernel_input_feature_dimension = 3 : i64, + kernel_output_feature_dimension = 0 : i64, + kernel_spatial_dimensions = dense<[ 1, 2 ]> : tensor<2xi64>, + output_batch_dimension = 0 : i64, + output_feature_dimension = 1 : i64, + output_spatial_dimensions = dense<[ 2, 3 ]> : tensor<2xi64> + }, + feature_group_count = 1 : i64, + lhs_dilations = dense<1> : tensor<2xi64>, + minor_to_major = dense<[ 1, 3, 2, 0 ]> : tensor<4xindex>, + padding = dense<3> : tensor<2x2xi64>, + precision_config = [ "DEFAULT", "DEFAULT" ], + rhs_dilations = dense<1> : tensor<2xi64>, + window_strides = dense<2> : tensor<2xi64> + } : (tensor<128x224x224x4xf16>, tensor<64x7x7x4xf16>)-> tensor<128x64x112x112xf16> loc("root.42") + + // CHECK: s32[1,1]{0,1} constant({ {42} }) + %cst_1 = "std.constant"() {value = dense<[[42]]> : tensor<1x1xi32>, minor_to_major = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<1x1xi32> + + return %0 : tensor<128x64x112x112xf16> +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo index d97c5150335..4c288aee956 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo +++ b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo @@ -139,8 +139,8 @@ dynamic_parameter_binding { } # CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor { -# CHECK-NEXT: %0 = mhlo.add %arg0, %arg1 {name = "add.3"} : tensor<4xf32> +# CHECK-NEXT: %0 = mhlo.add %arg0, %arg1 : tensor<4xf32> # TODO(b/129709049) consider making this default precision config inferred. 
-# CHECK-NEXT: %1 = "mhlo.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor +# CHECK-NEXT: %1 = "mhlo.dot"(%0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor # CHECK-NEXT: return %1 : tensor # CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt index 855b1c4bcd5..f7e1ba9ff15 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt @@ -4,25 +4,25 @@ HloModule tfcompile.1 // CHECK-LABEL: func @main() -> tensor { ENTRY %tfcompile.1 { - // CHECK-NEXT: %cst = constant {name = "constant.0"} dense<1.000000e+00> : tensor + // CHECK-NEXT: %cst = constant dense<1.000000e+00> : tensor %constant.0 = f32[] constant(1) - // CHECK-NEXT: %cst_0 = constant {name = "constant.1"} dense<1.000000e+00> : tensor + // CHECK-NEXT: %cst_0 = constant dense<1.000000e+00> : tensor %constant.1 = f64[] constant(1) - // CHECK-NEXT: %cst_1 = constant {name = "constant.2"} dense<1> : tensor + // CHECK-NEXT: %cst_1 = constant dense<1> : tensor %constant.2 = s8[] constant(1) - // CHECK-NEXT: %cst_2 = constant {name = "constant.3"} dense<1> : tensor + // CHECK-NEXT: %cst_2 = constant dense<1> : tensor %constant.3 = s16[] constant(1) - // CHECK-NEXT: %cst_3 = constant {name = "constant.4"} dense<1> : tensor + // CHECK-NEXT: %cst_3 = constant dense<1> : tensor %constant.4 = s32[] constant(1) - // CHECK-NEXT: %cst_4 = constant {name = "constant.5"} dense<1> : tensor + // CHECK-NEXT: %cst_4 = constant dense<1> : tensor %constant.5 = s64[] constant(1) - // CHECK-NEXT: %cst_5 = constant {name = "constant.6"} dense : tensor + // CHECK-NEXT: %cst_5 = constant dense : tensor // CHECK-NEXT: return %cst_5 : tensor ROOT %constant.6 = pred[] constant(1) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt index 126bc88ec7a..f989104323a 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt @@ -26,4 +26,4 @@ ENTRY %foo (arg0.1: s64[]) -> s64[] { // CHECK: "mhlo.return" // CHECK: }) : (tensor) -> tensor ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond -} \ No newline at end of file +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/while.mlir b/tensorflow/compiler/mlir/xla/tests/translate/while.mlir index 61d7aadb23f..f852ef06421 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/while.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/while.mlir @@ -10,11 +10,11 @@ module { // CHECK: %[[A0]] = s64[] parameter(0) // CHECK: ROOT %compare.7 = pred[] compare(s64[] %[[A0]], s64[] %[[A0]]), direction=LT ^bb0(%arg1: tensor): - %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor, tensor) -> tensor + %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg1: tensor): - %1 = mhlo.add %arg1, %arg1 {name = "compare.0"} : tensor + %1 = mhlo.add %arg1, %arg1 : tensor "mhlo.return"(%1) : (tensor) -> () }) : (tensor) -> tensor diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 4dacdbe2e6a..d541270f671 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ 
b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -4722,10 +4722,8 @@ class ConvertTopKV2Op : public OpRewritePattern { &rewriter); // Get the sorted input and index tuple element. - auto tuple_first_element = - rewriter.create(op.getLoc(), sort_op, 0); - auto tuple_second_element = - rewriter.create(op.getLoc(), sort_op, 1); + auto tuple_first_element = sort_op.getResult(0); + auto tuple_second_element = sort_op.getResult(1); SmallVector begin_indices(input_rank, 0); auto end_indices = llvm::to_vector<4>(input_type.getShape()); @@ -5011,8 +5009,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern { BuildSortComparisonBody({i32_type, input_type.getElementType()}, /*direction=*/"LT", &sorted.comparator(), &rewriter); - current = rewriter.create(op.getLoc(), - sorted.getResult(), 1); + current = sorted.getResult(1); } rewriter.replaceOp(op, current); return success(); @@ -5210,6 +5207,46 @@ class ConvertXlaDynamicUpdateSliceOp } }; +// Converts ClipByValue to XLA's clamp operation. Includes the broadcasting +// semantics for static and dynamic cases. +class ConvertClipByValueOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::ClipByValueOp op, + PatternRewriter &rewriter) const override { + Value input = op.t(); + Value min = op.clip_value_min(); + Value max = op.clip_value_max(); + + auto input_ty = input.getType().cast(); + auto min_ty = min.getType().cast(); + auto max_ty = max.getType().cast(); + + if (!input_ty.hasRank() || !min_ty.hasRank() || !max_ty.hasRank()) { + return failure(); + } + + auto shape = rewriter.create( + op.getLoc(), + RankedTensorType::get({input_ty.getRank()}, rewriter.getI32Type()), + input); + + if (min_ty != input_ty) { + min = + rewriter.create(op.getLoc(), input_ty, min, shape); + } + + if (max_ty != input_ty) { + max = + rewriter.create(op.getLoc(), input_ty, max, shape); + } + + rewriter.replaceOpWithNewOp(op, input_ty, min, input, max); + return success(); + } +}; + // Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by // setting appropriate window dimensions, with the given aggregation op as the // reduction function. 
The input tensor needs to have a static shape, and 'axis' @@ -6067,26 +6104,26 @@ LogicalResult legalizeTF( void PopulateLegalizeTfPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { - populateWithGenerated(context, patterns); + populateWithGenerated(context, *patterns); patterns->insert< ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op, ConvertBiasAddOp, ConvertBroadcastToOp, ConvertBF16FloorDivOp, - ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp, - ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp, - ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp, - ConvertCumprodOp, ConvertCumsumOp, ConvertDiagPartOp, - ConvertDynamicExpandDimsOp, ConvertDynamicReshapeOp, ConvertEinsumOp, - ConvertRFFTOp, ConvertIRFFTOp, ConvertFusedBatchNormGradOp, - ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op, - ConvertFusedBatchNormV2Op, ConvertFusedBatchNormV3Op, - ConvertInfeedDequeueTupleOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp, - ConvertMaxOp, ConvertMinOp, ConvertAvgPool2DOp, ConvertAvgPool3DOp, - ConvertAvgPool2DGradOp, ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, - ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, - ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, - ConvertProdOp, ConvertQrOp, ConvertDynamicRangeOp, - ConvertMatrixDiagPartV3Op, ConvertRangeOp, ConvertSelectV2Op, - ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp, + ConvertClipByValueOp, ConvertConv2DOp, ConvertConv3DOp, + ConvertDepthConv2DOp, ConvertConv2DBackpropFilterOp, + ConvertConv3DBackpropFilterOp, ConvertConv2DBackpropInputOp, + ConvertConv3DBackpropInputOp, ConvertCumprodOp, ConvertCumsumOp, + ConvertDiagPartOp, ConvertDynamicExpandDimsOp, ConvertDynamicReshapeOp, + ConvertEinsumOp, ConvertRFFTOp, ConvertIRFFTOp, + ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op, + ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op, + ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, + ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp, + ConvertAvgPool2DOp, ConvertAvgPool3DOp, ConvertAvgPool2DGradOp, + ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp, + ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp, + ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp, + ConvertDynamicRangeOp, ConvertMatrixDiagPartV3Op, ConvertRangeOp, + ConvertSelectV2Op, ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp, ConvertSoftmaxOp, ConvertSoftmaxOp, ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp, diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 7f3caef7426..52bbbf6f9da 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -572,12 +572,14 @@ def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (TF_ConstOp $axis)), foreach Mapping = [ [TF_AbsOp, HLO_AbsOp], [TF_AcosOp, HLOClient_AcosOp], + [TF_AtanOp, HLOClient_AtanOp], [TF_CeilOp, HLO_CeilOp], [TF_ComplexAbsOp, HLO_AbsOp], [TF_CosOp, HLO_CosOp], [TF_ExpOp, HLO_ExpOp], [TF_FloorOp, HLO_FloorOp], [TF_ImagOp, HLO_ImagOp], + [TF_InvertOp, HLO_NotOp], [TF_IsFiniteOp, HLO_IsFiniteOp], [TF_LogOp, HLO_LogOp], [TF_Log1pOp, HLO_Log1pOp], diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 
8cd67e5db08..6e4a2495a4f 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -151,6 +151,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -177,6 +178,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -227,6 +229,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -239,6 +242,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get() diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc index a731b5a4a50..9cf161fb2ae 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc @@ -29,7 +29,9 @@ limitations under the License. #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project @@ -40,6 +42,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" +#include "tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h" #include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" @@ -110,7 +113,7 @@ Status ConvertModule(std::unique_ptr hlo_module, ModuleOp module, // Run all HLO passes to produce an optimized module. auto result_or = backend->compiler()->RunHloPassesAndBufferAssignement( std::move(hlo_module), backend->default_stream_executor(), - backend->memory_allocator()); + backend->memory_allocator(), optimize_xla_hlo); TF_RETURN_WITH_CONTEXT_IF_ERROR(result_or.status(), "running XLA pass pipeline"); std::unique_ptr optimized_hlo_module = @@ -276,27 +279,138 @@ Status LhloDialectEmitter::HandleSort(HloInstruction* instr) { return EmitSortOp(instr).status(); } -Status LhloDialectEmitter::CreateView(const HloInstruction* instr, - const Shape& current_shape, - ::xla::ShapeIndex* current_shape_index, - SmallVectorImpl* values) { - if (current_shape.IsTuple()) { - for (int i = 0; i < current_shape.tuple_shapes().size(); i++) { - current_shape_index->push_back(i); - TF_RETURN_IF_ERROR(CreateView(instr, current_shape.tuple_shapes(i), - current_shape_index, values)); - current_shape_index->pop_back(); +// Walks MHLO::TupleOp recursively. 
+Status WalkTuplePostOrder(Value v, + const std::function& visitor) { + if (auto* op = v.getDefiningOp()) { + if (auto tuple = dyn_cast(op)) { + for (Value sub_v : tuple.val()) { + TF_RETURN_IF_ERROR(WalkTuplePostOrder(sub_v, visitor)); + } + return Status::OK(); } - return Status::OK(); } + return visitor(v); +} + +// This function removes all uses of a fused region argument, and rewire those +// uses to a `tensor_load %memref`, where %memref is caller argument. +// +// It also flattens all input/output tuples into more region arguments / +// results. +StatusOr LhloDialectEmitter::RewriteFusionOperand( + const HloInstruction* root, const Shape& shape, + ::xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) { + if (shape.IsTuple()) { + llvm::SmallVector values; + for (int i = 0; i < shape.tuple_shapes_size(); i++) { + shape_index->push_back(i); + TF_ASSIGN_OR_RETURN( + auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index, + b, loc)); + values.push_back(v); + shape_index->pop_back(); + } + return Value(b->create(loc, values)); + } + TF_ASSIGN_OR_RETURN(Value memref, + GetOrCreateArrayView(root, shape, *shape_index)); + auto load = b->create(loc, memref); + if (shape.layout() != + xla::LayoutUtil::MakeDescendingLayout(shape.dimensions().size())) { + llvm::SmallVector minor_to_major( + shape.layout().minor_to_major().begin(), + shape.layout().minor_to_major().end()); + load.setAttr("minor_to_major", b->getIndexTensorAttr(minor_to_major)); + } + return load.getResult(); +} + +StatusOr LhloDialectEmitter::EmitFusionOp( + HloInstruction* instr) { + Location loc = getLocation(instr); + + auto* fusion_instr = ::xla::Cast<::xla::HloFusionInstruction>(instr); + + auto fusion = builder_.create(getLocation(instr), + ArrayRef{}); + auto after_fusion = builder_.saveInsertionPoint(); + builder_ = mlir::OpBuilder(fusion); + + auto region_builder = OpBuilder::atBlockBegin(&fusion.region().front()); + + llvm::SmallVector arguments; + for (int i = 0; i < instr->operands().size(); i++) { + const HloInstruction* operand = instr->operand(i); + xla::ShapeIndex shape_index; + TF_ASSIGN_OR_RETURN( + auto arg, RewriteFusionOperand(operand, operand->shape(), &shape_index, + ®ion_builder, loc)); + arguments.push_back(arg); + } + + TF_ASSIGN_OR_RETURN(Value result, + ::xla::HloFunctionImporter::ImportInstructions( + *fusion_instr->fused_instructions_computation(), + arguments, ®ion_builder)); + + { + int i = 0; + llvm::SmallVector output; + TF_RETURN_IF_ERROR(GetOrCreateView(instr, &output)); + TF_RETURN_IF_ERROR(WalkTuplePostOrder(result, [&](Value v) mutable { + region_builder.create(loc, v, output[i++]); + return Status::OK(); + })); + if (i != output.size()) { + return ::xla::InternalError("output sizes don't match"); + } + } + + // Fold GTE/Tuple pairs. + // + // Since the fused region refers to values in its parent region, we can't + // call applyPatternAndFoldGreedily. We optimize it manually. + // + // Only walk once, because post-ordering is exactly what we need for GTE + // optimizations. + fusion.region().walk([](mhlo::GetTupleElementOp gte) { + SmallVector folded_values; + if (succeeded(OpBuilder(gte).tryFold(gte, folded_values))) { + gte.replaceAllUsesWith(folded_values[0]); + } + }); + + // Effectively a DCE on the region. + { + llvm::SmallVector ops; + fusion.region().walk([&](mlir::Operation* op) { ops.push_back(op); }); + // Visit the user first. 
+ std::reverse(ops.begin(), ops.end()); + for (auto op : ops) { + if (isOpTriviallyDead(op)) op->erase(); + } + } + + LOG(ERROR) << instr->GetModule()->ToString(); + builder_.restoreInsertionPoint(after_fusion); + return fusion; +} + +Status LhloDialectEmitter::HandleFusion(HloInstruction* instr) { + return EmitFusionOp(instr).status(); +} + +StatusOr LhloDialectEmitter::GetOrCreateArrayView( + const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape, + const ::xla::ShapeIndex& shape_index) { TF_ASSIGN_OR_RETURN(Type out_type, ::xla::ConvertShapeToType( current_shape, builder_)); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, - assignment_.GetUniqueSlice(instr, *current_shape_index)); + assignment_.GetUniqueSlice(instr, shape_index)); Value alloc = allocations_[slice.allocation()]; - if (alloc.getType() == out_type) { - values->push_back(alloc); - return Status::OK(); + if (alloc.getType() == out_type && slice.offset() == 0) { + return alloc; } auto out_memref_type = out_type.dyn_cast(); @@ -304,6 +418,13 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr, return tensorflow::errors::Internal( "Expected memref type when creating a view for leaf type of a tuple."); + // Cache generated ViewOp and StaticMemRefCastOp by (instruction, + // shape_index). + auto& cached_value = slices_[std::make_pair(instr, shape_index)]; + if (cached_value) { + return cached_value; + } + Value byte_shift = builder_.create(alloc.getLoc(), slice.offset()); @@ -327,7 +448,24 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr, if (physical_out_type != out_type) result = builder_.create(loc, out_memref_type, result); - values->push_back(result); + return cached_value = result; +} + +Status LhloDialectEmitter::GetOrCreateViewImpl( + const HloInstruction* instr, const Shape& current_shape, + ::xla::ShapeIndex* current_shape_index, SmallVectorImpl* values) { + if (current_shape.IsTuple()) { + for (int i = 0; i < current_shape.tuple_shapes().size(); i++) { + current_shape_index->push_back(i); + TF_RETURN_IF_ERROR(GetOrCreateViewImpl( + instr, current_shape.tuple_shapes(i), current_shape_index, values)); + current_shape_index->pop_back(); + } + return Status::OK(); + } + TF_ASSIGN_OR_RETURN( + auto v, GetOrCreateArrayView(instr, current_shape, *current_shape_index)); + values->push_back(v); return Status::OK(); } @@ -336,25 +474,8 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr, // create another view to adjust the slice for the shape of the instruction. Status LhloDialectEmitter::GetOrCreateView(const HloInstruction* instr, SmallVectorImpl* values) { - // Cache generated ViewOp and StaticMemRefCastOp by instruction. We could have - // gone fancier to do the following cacheing: - // %range = ViewOp(%allocation, %offset) : memref - // %typed_range = ViewOp(%range) : memref - // - // where %range is cached. This in theory gives easier time for alias - // analysis, since the identity of %range defines alias. However, - // %typed_range can't be cached, as different buffers with different types and - // shapes may still alias. Creating two ViewOps doesn't seem to worth the - // effort for a slightly easier aliasing, so we don't over optimize here. 
- auto result = slices_.try_emplace(instr, llvm::SmallVector{}); - llvm::SmallVectorImpl& new_values = result.first->second; - if (result.second) { - ::xla::ShapeIndex shape_index; - TF_RETURN_IF_ERROR( - CreateView(instr, instr->shape(), &shape_index, &new_values)); - } - values->insert(values->end(), new_values.begin(), new_values.end()); - return Status::OK(); + ::xla::ShapeIndex shape_index; + return GetOrCreateViewImpl(instr, instr->shape(), &shape_index, values); } Status LhloDialectEmitter::Initialize() { @@ -373,7 +494,7 @@ Status LhloDialectEmitter::Initialize() { if (computation_.IsEntryComputation()) { // Sort the rather arbitrarily ordered allocations to match the input/output - // parameters. Specifically We want to sort buffer allocations in the + // parameters. Specifically we want to sort buffer allocations in the // following order: // * Parameters always order before non-parameters. // * Different parameters order by parameter number. diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h index 89514116254..a57db3cb67e 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h @@ -43,6 +43,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { i8_type_(builder_.getIntegerType(8)) {} ::xla::StatusOr EmitSortOp(::xla::HloInstruction* instr); + ::xla::StatusOr EmitFusionOp(::xla::HloInstruction* instr); private: template @@ -57,21 +58,31 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { } tensorflow::Status HandleSort(::xla::HloInstruction* instr) final; + tensorflow::Status HandleFusion(::xla::HloInstruction* instr) final; // Helper function that recursively visits the tuple structure in // `current_shape`, and reconstruct a matching lmhlo::TupleOp. // Each leaf node is converted to an std.view op with corresponding offsets. // If no tuple presents, it simply returns a view of the buffer. - tensorflow::Status CreateView(const ::xla::HloInstruction* instr, - const ::xla::Shape& current_shape, - ::xla::ShapeIndex* current_shape_index, - SmallVectorImpl* values); + tensorflow::Status GetOrCreateViewImpl(const ::xla::HloInstruction* instr, + const ::xla::Shape& current_shape, + ::xla::ShapeIndex* current_shape_index, + SmallVectorImpl* values); // Helper function to create view/tuple of views to a buffer for a given // instruction result. tensorflow::Status GetOrCreateView(const ::xla::HloInstruction* instr, SmallVectorImpl* values); + ::xla::StatusOr GetOrCreateArrayView( + const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape, + const ::xla::ShapeIndex& current_shape_index); + + ::xla::StatusOr RewriteFusionOperand(const ::xla::HloInstruction* root, + const ::xla::Shape& shape, + ::xla::ShapeIndex* shape_index, + OpBuilder* b, Location loc); + // Return an MLIR location for an HLO instruction. Location getLocation(::xla::HloInstruction* inst) { return NameLoc::get(builder_.getIdentifier(inst->name()), @@ -102,7 +113,8 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { // // `slices_` is populated lazily in the `GetOrCreateView()` helper as we // process every instruction. - llvm::DenseMap> + absl::flat_hash_map, + Value> slices_; // The BufferAssignment computed by XLA ahead of time. 
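[Editor's note, not part of the patch] The fused-region cleanup added in EmitFusionOp above folds GetTupleElement/Tuple pairs and then erases trivially dead ops by walking the region's ops in reverse, because post-ordering already visits users before producers. The following is a minimal, self-contained sketch of that idea in plain Python with a hypothetical Op class; it is an illustration of the traversal order, not the emitter's actual code.

class Op:
    def __init__(self, name, operands=(), side_effecting=False):
        self.name = name
        self.operands = list(operands)
        self.side_effecting = side_effecting
        self.users = set()
        for operand in self.operands:
            operand.users.add(self)

def erase_trivially_dead(ops_in_post_order):
    """ops_in_post_order lists producers before users, like a post-order region walk."""
    live = []
    for op in reversed(ops_in_post_order):        # visit the user first
        if op.side_effecting or op.users:
            live.append(op)
        else:                                      # trivially dead: erase it
            for operand in op.operands:
                operand.users.discard(op)          # erasing an op frees its operands
    return list(reversed(live))

# "mul" has no users, so it dies; that frees "add", which then dies too.
a = Op("param")
b = Op("add", [a])
c = Op("mul", [b])
root = Op("copy", [a], side_effecting=True)
assert [op.name for op in erase_trivially_dead([a, b, c, root])] == ["param", "copy"]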
diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc index b725f56b455..3822e10089b 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc @@ -102,8 +102,7 @@ Shape TypeToShape(mlir::Type type) { if (ptype != PrimitiveType::PRIMITIVE_TYPE_INVALID) return ShapeUtil::MakeShape(ptype, {}); - if (type.isBF16() || type.isF32() || type.isF64() || - type.isa()) { + if (type.isIntOrFloat()) { auto* context = type.getContext(); mlir::emitError(mlir::UnknownLoc::get(context)) << "lowering should have been handled by primitive type lowering for " @@ -140,7 +139,8 @@ Shape TypeToShape(mlir::Type type) { for (const auto& e : llvm::enumerate(strides)) { strides_with_indices.push_back({e.value(), e.index()}); } - std::sort(strides_with_indices.begin(), strides_with_indices.end()); + std::stable_sort(strides_with_indices.begin(), + strides_with_indices.end()); llvm::SmallVector minor_to_major; int64_t stride = 1; @@ -149,7 +149,7 @@ Shape TypeToShape(mlir::Type type) { // Either the affine map is not perfectly strided, or the dimensions // recovered from strides don't match the actual dimensions in shapes. - if (stride != pr.first) return {}; + if (stride != pr.first && m.getShape()[pr.second] != 1) return {}; stride *= m.getShape()[pr.second]; } diff --git a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc index a4a2bc42d99..97417748b64 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc @@ -196,5 +196,22 @@ TEST(TypeToShapeTest, ConvertMemRefToShape) { EXPECT_TRUE(ShapeUtil::Equal(converted, shape)); } +TEST(TypeToShapeTest, ConvertMemRefToShape2) { + Shape shape = ShapeUtil::MakeShapeWithLayout(PrimitiveType::C64, {2, 4, 3, 3}, + {2, 3, 1, 0}); + MLIRContext context; + mlir::Builder builder(&context); + + StatusOr mlir_type = + ConvertShapeToType(shape, builder); + ASSERT_TRUE(mlir_type.ok()); + mlir::Type type = mlir_type.ConsumeValueOrDie(); + Shape converted = TypeToShape(type); + EXPECT_TRUE(ShapeUtil::Equal( + converted, ShapeUtil::MakeShapeWithLayout(PrimitiveType::C64, + {2, 4, 3, 3}, {2, 3, 1, 0}))); + EXPECT_TRUE(ShapeUtil::Equal(converted, shape)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc index 4ad44d1bd77..3ee70db1813 100644 --- a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc @@ -124,13 +124,16 @@ static StatusOr> HloModuleFromProto( return HloModule::CreateFromProto(module_proto, module_config); } -static mlir::LogicalResult MlirHloToHloTextTranslateFunction( - mlir::ModuleOp module, llvm::raw_ostream& output) { +static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl( + mlir::ModuleOp module, llvm::raw_ostream& output, bool with_layouts) { if (!module) return mlir::failure(); HloProto hloProto; + mlir::MlirToHloConversionOptions options; + options.propagate_layouts = with_layouts; Status status = mlir::ConvertMlirHloToHlo( - module, &hloProto, emit_use_tuple_arg, emit_return_tuple); + module, &hloProto, emit_use_tuple_arg, emit_return_tuple, + /*shape_representation_fn=*/nullptr, options); if (!status.ok()) { LOG(ERROR) << "Module conversion failed: " << status; return mlir::failure(); @@ -146,9 +149,8 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction( HloModule* 
hlo_module = statusOrHloModule.ValueOrDie().get(); - // We don't interpret or use layouts output << hlo_module->ToString( - HloPrintOptions().set_include_layout_in_shapes(false)); + HloPrintOptions().set_include_layout_in_shapes(with_layouts)); // Output alias information as comments in the HLO text. hlo_module->input_output_alias_config().ForEachAlias( @@ -162,6 +164,18 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction( return mlir::success(); } +static mlir::LogicalResult MlirHloToHloTextTranslateFunction( + mlir::ModuleOp module, llvm::raw_ostream& output) { + return MlirHloToHloTextTranslateFunctionImpl(module, output, + /*with_layouts=*/false); +} + +static mlir::LogicalResult MlirHloToHloTextWithLayoutsTranslateFunction( + mlir::ModuleOp module, llvm::raw_ostream& output) { + return MlirHloToHloTextTranslateFunctionImpl(module, output, + /*with_layouts=*/true); +} + } // namespace xla static void RegisterInputDialects(mlir::DialectRegistry& registry) { @@ -176,6 +190,10 @@ static mlir::TranslateFromMLIRRegistration MlirHloToHloTextTranslate( "mlir-hlo-to-hlo-text", xla::MlirHloToHloTextTranslateFunction, RegisterInputDialects); +static mlir::TranslateFromMLIRRegistration MlirHloToHloTextWithLayoutsTranslate( + "mlir-hlo-to-hlo-text-with-layouts", + xla::MlirHloToHloTextWithLayoutsTranslateFunction, RegisterInputDialects); + static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate( "hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction); diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.cc index bfe4ed3844f..7eb1fb40f5e 100644 --- a/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.cc +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.cc @@ -27,3 +27,9 @@ llvm::cl::opt emit_return_tuple( "emit-return-tuple", llvm::cl::desc("Emit HLO modules with entry computations returning tuple"), llvm::cl::init(false)); + +// NOLINTNEXTLINE +llvm::cl::opt optimize_xla_hlo( + "optimize-xla-hlo", + llvm::cl::desc("Enable optimizations when translating XLA HLO -> LHLO"), + llvm::cl::init(true)); diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h b/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h index 1d5a29a5fdb..14a2878dff8 100644 --- a/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate_cl.h @@ -24,5 +24,6 @@ limitations under the License. 
extern llvm::cl::opt emit_use_tuple_arg; extern llvm::cl::opt emit_return_tuple; +extern llvm::cl::opt optimize_xla_hlo; #endif // TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_CL_H_ diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index c57851e3dcf..1dfcf88e654 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1089,6 +1089,7 @@ tf_xla_py_test( name = "reduce_ops_test", size = "medium", srcs = ["reduce_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1281,7 +1282,9 @@ tf_xla_py_test( python_version = "PY3", shard_count = 10, tags = [ + "no_oss", # b/170479349 "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notap", # b/170479349 "optonly", ], deps = [ @@ -1351,7 +1354,6 @@ tf_xla_py_test( python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "notap", # b/162025277 ], deps = [ ":xla_test", @@ -1484,6 +1486,7 @@ tf_xla_py_test( "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", "//tensorflow/python:framework", + "//tensorflow/python:image_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", ], @@ -1545,6 +1548,7 @@ tf_xla_py_test( name = "sort_ops_test", size = "medium", srcs = ["sort_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 1, # Times out in fastbuild mode. diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 07a41d67520..59c8c544347 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -474,7 +474,6 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([1 << 32, 1 << 36, 1 << 32, 1 << 36], dtype=np.int64)) - @test_util.disable_mlir_bridge("Enable tf.NextAfter Compilation") def testNextAfter(self): for dtype in self.numeric_types: if dtype in [np.float32, np.float64]: diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index b5f82bcff12..f3f6fa8ae52 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -542,7 +542,7 @@ class UnaryOpsTest(xla_test.XLATestCase): for dtype in self.float_types: def quantize_and_dequantize_v2(x): - return array_ops.quantize_and_dequantize_v2( + return array_ops.quantize_and_dequantize( x, -127, 127, signed_input=True, num_bits=8) self._assertOpOutputMatchesExpected( @@ -551,7 +551,7 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype)) def quantize_and_dequantize_v2_round_half_up(x): - return array_ops.quantize_and_dequantize_v2( + return array_ops.quantize_and_dequantize( x, -1, 1.0, @@ -575,7 +575,7 @@ class UnaryOpsTest(xla_test.XLATestCase): dtype=dtype)) def quantize_and_dequantize_v2_round_half_to_even(x): - return array_ops.quantize_and_dequantize_v2( + return array_ops.quantize_and_dequantize( x, -1.0, 1.0, diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 3e9f5e8c5dd..b80b6263992 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -18,12 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import 
xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import function @@ -299,6 +302,78 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase): args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), expected=np.array([0, 45, 120, 231], dtype=dtype)) + @test_util.disable_mlir_bridge('Not supported yet') + def testVariadicReduce(self): + for dtype in set(self.numeric_types).intersection( + set([np.float32, np.complex64])): + + @def_function.function + def kahan_sum_reducer(t0, t1): + (s0, c0), (s1, c1) = t0, t1 + s0minusc = s0 - (c0 + c1) + t = s1 + s0minusc + c = (t - s1) - s0minusc + s = t + return s, c + + def kahan_sum_reduction(dims, output_idx): + + def fn(x): + arg = array_ops.zeros([], dtype) # pylint: disable=cell-var-from-loop + reducer = kahan_sum_reducer.get_concrete_function( + (arg, arg), (arg, arg)) + + return xla.variadic_reduce( + (x, array_ops.zeros_like(x)), + init_value=(arg, arg), + dimensions_to_reduce=dims, + reducer=reducer)[output_idx] + + return fn + + xs = np.array([1e5, np.pi, -1e5, np.exp(1.)]) + xs = np.array([xs, xs[::-1] / 3, xs / 7], dtype) + self._assertOpOutputMatchesExpected( + kahan_sum_reduction(dims=[], output_idx=0), + args=(xs,), expected=xs) + self._assertOpOutputMatchesExpected( + kahan_sum_reduction(dims=[], output_idx=1), + args=(xs,), expected=np.zeros_like(xs)) + shuffle_indices = np.argsort(np.random.randn(xs.shape[0])) + self._assertOpOutputMatchesExpected( + kahan_sum_reduction(dims=[0], output_idx=0), + args=(xs[shuffle_indices],), + expected=np.array([np.exp(1) / 3 + 1e5 * 8 / 7, + np.pi * 8 / 7 - 1e5 / 3, + -1e5 * 8 / 7 + np.pi / 3, + np.exp(1) * 8 / 7 + 1e5 / 3], dtype=dtype)) + error_term_equality = functools.partial(self.assertAllClose, atol=.005) + self._assertOpOutputMatchesExpected( + kahan_sum_reduction(dims=[0], output_idx=1), + args=(xs[shuffle_indices],), expected=np.zeros_like(xs[0]), + equality_fn=error_term_equality) + shuffle_indices = np.argsort(np.random.randn(xs.shape[1])) + self._assertOpOutputMatchesExpected( + kahan_sum_reduction(dims=[1], output_idx=0), + args=(xs[:, shuffle_indices],), + expected=np.array([np.pi + np.exp(1.), + (np.pi + np.exp(1.)) / 3, + (np.pi + np.exp(1.)) / 7], dtype=dtype)) + self._assertOpOutputMatchesExpected( + kahan_sum_reduction(dims=[1], output_idx=1), + args=(xs[:, shuffle_indices],), expected=np.zeros_like(xs[:, 0]), + equality_fn=error_term_equality) + # Now, shuffle both dims. 
+ xs = xs[np.argsort(np.random.randn(xs.shape[0]))] + xs = xs[:, np.argsort(np.random.randn(xs.shape[1]))] + self._assertOpOutputMatchesExpected( + kahan_sum_reduction(dims=[0, 1], output_idx=0), + args=(xs,), expected=dtype((np.pi + np.exp(1.)) * 31 / 21)) + self._assertOpOutputMatchesExpected( + kahan_sum_reduction(dims=[0, 1], output_idx=1), + args=(xs,), expected=dtype(0), + equality_fn=error_term_equality) + @test_util.disable_mlir_bridge('Not supported yet') def testSelectAndScatter(self): for dtype in set(self.numeric_types).intersection( diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index fff00cb3906..e7183cd7e18 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -36,8 +36,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - cc_library( name = "tensorrt_stub", srcs = if_tensorrt([ @@ -72,7 +70,7 @@ tf_cuda_cc_test( deps = [ "//tensorflow/core/common_runtime/gpu:gpu_init", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", "//tensorflow/core:test", "//tensorflow/core:test_main", ] + if_tensorrt([ @@ -115,7 +113,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", "//tensorflow/core:stream_executor_headers_lib", "//tensorflow/core/common_runtime:core_cpu_lib_no_ops", "//tensorflow/core/grappler/costs:graph_properties", @@ -497,15 +495,32 @@ tf_cuda_cc_test( # Library for the segmenting portion of TensorRT operation creation cc_library( - name = "segment", - srcs = ["segment/segment.cc"], + name = "union_find", + srcs = ["segment/union_find.cc"], hdrs = [ - "segment/segment.h", "segment/union_find.h", ], copts = tf_copts(), + deps = [ + ":utils", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "segment", + srcs = ["segment/segment.cc"], + hdrs = [ + "segment/segment.h", + ], + copts = tf_copts(), deps = [ ":common_utils", + ":union_find", ":utils", "//tensorflow/core:graph", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index c0c3f25177e..d09485c35c7 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -429,11 +430,52 @@ Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, return Status::OK(); } +std::string GetLayerNameSuffix(absl::string_view sub_op_name, + absl::optional sub_op_instance) { + std::string op_suffix(sub_op_name); + if (sub_op_instance.has_value()) { + op_suffix = + absl::StrCat(op_suffix, "_", std::to_string(sub_op_instance.value())); + } + return op_suffix; +} + +// Sets the name of an ILayer using the name of the node_def. 
If the operation +// represented by the ILayer is generated by the converter to support the +// conversion of node_def, callers need to specify a non-empty sub_op_name +// to be appended to the name of node_def to avoid layer name conflicts. If the +// operation is generated multiple times, callers also need to specify +// sub_op_instance to be appended to the name of the layers to avoid layer name +// conflicts. +void SetLayerName(nvinfer1::ILayer* layer, const NodeDef& node_def, + absl::string_view sub_op_name = "", + absl::optional sub_op_instance = absl::nullopt) { + std::string sub_op_suffix = GetLayerNameSuffix(sub_op_name, sub_op_instance); + if (sub_op_suffix.empty()) { + layer->setName(node_def.name().c_str()); + } else { + layer->setName(absl::StrCat(node_def.name(), "-", sub_op_suffix).c_str()); + } +} + +// Sets the name of an ILayer using the format of +// "main_op_name"_"sub_op_name"_"sub_op_instance". +void SetLayerName(nvinfer1::ILayer* layer, absl::string_view main_op_name, + absl::string_view sub_op_name, + absl::optional sub_op_instance = absl::nullopt) { + std::string layer_name_suffix = + GetLayerNameSuffix(sub_op_name, sub_op_instance); + layer->setName(absl::StrCat(main_op_name, "-", layer_name_suffix).c_str()); +} + nvinfer1::ITensor* Converter::CreateConstantLayer( const TRT_ShapedWeights& weights, const nvinfer1::Dims& dims) { nvinfer1::Weights trt_weights = weights.GetTrtWeights(); nvinfer1::IConstantLayer* layer = network()->addConstant(dims, trt_weights); if (!layer) return nullptr; + SetLayerName(layer, "_tftrt_constant_", + std::to_string(next_constant_layer_id_)); + next_constant_layer_id_++; nvinfer1::ITensor* trt_tensor = layer->getOutput(0); #if !IS_TRT_VERSION_GE(5, 1, 3, 0) // TODO(laigd): there is a bug in TensorRT 5.0 library that, if we don't set @@ -1313,6 +1355,7 @@ Status Converter::AddInputTensor(const string& name, nvinfer1::DataType dtype, Status Converter::RenameAndMarkOutputTensors( const std::vector& output_tensors) { + int output_index = 0; for (const auto& output : output_tensors) { TRT_TensorOrWeights tensor_or_weights; TF_RETURN_IF_ERROR( @@ -1341,6 +1384,7 @@ Status Converter::RenameAndMarkOutputTensors( nvinfer1::IShuffleLayer* layer = network()->addShuffle(*tensor); TFTRT_RETURN_ERROR_IF_NULLPTR( layer, StrCat("Output Copy for ", tensor->getName())); + SetLayerName(layer, tensor->getName(), "shuffle", output_index); MarkQuantizationRangesAsInferrable(tensor, layer->getOutput(0)); tensor = layer->getOutput(0); } @@ -1349,6 +1393,7 @@ Status Converter::RenameAndMarkOutputTensors( // Set type after marking as output. TRT only supports setType for engine // outputs and inputs (type is inferred otherwise). tensor->setType(output.trt_dtype); + output_index++; VLOG(1) << "Marking output TRT tensor " << output.source_tensor_name << " with data type " << DebugString(output.trt_dtype) << ", which feeds TF node " << output.dest_node_name; @@ -1475,8 +1520,9 @@ Status Converter::GetTensorOrWeights(const string& name, Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, const std::vector& order_with_batch_dim, - absl::string_view name, - nvinfer1::ITensor** output_tensor) { + nvinfer1::ITensor** output_tensor, + const NodeDef& node_def, + absl::string_view sub_op_name) { const auto dims = input_tensor->getDimensions(); const int order_size = use_implicit_batch_ ? 
order_with_batch_dim.size() - 1 : order_with_batch_dim.size(); @@ -1491,7 +1537,8 @@ Status Converter::TransposeTensor(nvinfer1::ITensor* input_tensor, nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Transpose"); - layer->setName(std::basic_string(name).c_str()); + SetLayerName(layer, node_def, sub_op_name); + MarkQuantizationRangesAsInferrable(input_tensor, layer->getOutput(0)); nvinfer1::Permutation permutation; @@ -1555,7 +1602,9 @@ Status Converter::GetWeightRange(const TRT_ShapedWeights& weights, Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, const nvinfer1::Dims& dims, const bool validation_only, - nvinfer1::ITensor** tensor) { + nvinfer1::ITensor** tensor, + const NodeDef& node_def, + absl::optional op_instance) { const nvinfer1::Dims input_dims = input.GetTrtDims(); // If one of input_dims and dims doesn't have static shape, it means some of // the dims are unknown or need to be inferred. And we don't do further checks @@ -1586,6 +1635,7 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input, nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input.tensor()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Reshape"); + SetLayerName(layer, node_def, "shuffle", op_instance); layer->setReshapeDimensions(dims); MarkQuantizationRangesAsInferrable(input.tensor(), layer->getOutput(0)); *tensor = layer->getOutput(0); @@ -2086,6 +2136,7 @@ Status Conv2DPaddingHelper(OpConverterParams* params, const TFAttrs& attrs, *tensor, nvinfer1::DimsHW((*padding)[0].first, (*padding)[1].first), nvinfer1::DimsHW((*padding)[0].second, (*padding)[1].second)); TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, params->node_def.name()); + SetLayerName(pad_layer, params->node_def, "pad"); params->converter->MarkQuantizationRangesAsInferrable( tensor, pad_layer->getOutput(0)); *padding = {{0, 0}, {0, 0}}; @@ -2186,7 +2237,7 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, const bool need_transpose = (data_format == "NHWC"); if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, {0, 3, 1, 2}, StrCat(node_def.name(), "_to_NCHW"), &tensor)); + tensor, {0, 3, 1, 2}, &tensor, node_def, "to_NCHW")); } // Dimensions of transposed tensor. const auto tensor_dim = tensor->getDimensions(); @@ -2252,7 +2303,6 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, #else layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); #endif - layer->setName(node_def.name().c_str()); layer->setNbGroups(num_groups); conv_layer = layer; } else { @@ -2269,11 +2319,11 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, #else layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); #endif - layer->setName(node_def.name().c_str()); layer->setNbGroups(num_groups); layer->setDilation(dilation); conv_layer = layer; } + SetLayerName(conv_layer, node_def, "conv"); nvinfer1::ITensor* output_tensor = conv_layer->getOutput(0); // Add an extra padding for Deconv because TRT doesn't accept the // argument output_shape and thus the TRT output shape could be wrong @@ -2306,13 +2356,13 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, params->converter->network()->addPadding(*output_tensor, pre_padding, post_padding); output_tensor = padding_layer->getOutput(0); + SetLayerName(padding_layer, node_def, "pad"); } } // Restore transpose. 
if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - output_tensor, {0, 2, 3, 1}, StrCat(node_def.name(), "_to_NHWC"), - &output_tensor)); + output_tensor, {0, 2, 3, 1}, &output_tensor, node_def, "to_NHWC")); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -2370,7 +2420,7 @@ Status ConvertTranspose(OpConverterParams* params) { // Start conversion. nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - input_tensor, perm, params->node_def.name(), &output_tensor)); + input_tensor, perm, &output_tensor, params->node_def)); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); } @@ -2401,6 +2451,7 @@ Status ConvertShape(OpConverterParams* params) { nvinfer1::IShapeLayer* shape_layer = params->converter->network()->addShape(*inputs.at(0).tensor()); TFTRT_RETURN_ERROR_IF_NULLPTR(shape_layer, params->node_def.name()); + SetLayerName(shape_layer, params->node_def, "shape"); params->outputs->push_back(TRT_TensorOrWeights(shape_layer->getOutput(0))); return Status::OK(); #else @@ -2471,7 +2522,7 @@ Status ConvertReshape(OpConverterParams* params) { nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( input_tensor, output_nonbatch_dims, params->validation_only, - &output_tensor)); + &output_tensor, params->node_def)); if (params->validation_only) return Status::OK(); // Record the conversion result. @@ -2514,7 +2565,8 @@ Status ConvertExpandDims(OpConverterParams* params) { nvinfer1::Dims new_dims; TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(input_dims, &new_dims)); TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - input_tensor, new_dims, /*validation_only=*/false, &output_tensor)); + input_tensor, new_dims, /*validation_only=*/false, &output_tensor, + params->node_def)); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -2524,7 +2576,8 @@ Status Converter::DynamicReshape(nvinfer1::ITensor* input, std::vector> slices, OpConverterParams* params, nvinfer1::ITensor** output, - std::vector size_for_added_dims) { + std::vector size_for_added_dims, + absl::optional op_instance) { *output = nullptr; // DynamicReshape relies on INetworkDefinition::addShape that was introduced // in TensorRT 6. @@ -2536,9 +2589,11 @@ Status Converter::DynamicReshape(nvinfer1::ITensor* input, nvinfer1::ITensor* shape = network()->addShape(*input)->getOutput(0); // Build new shape = shape[:trt_axis] + [1] + shape[trt_axis:] std::vector concat_inputs; - for (int i = 0; i < std::max(slices.size(), size_for_added_dims.size()); - i++) { + int max_num_slices = std::max(slices.size(), size_for_added_dims.size()); + int op_instance_value = op_instance.has_value() ? 
op_instance.value() : 0; + for (int i = 0; i < max_num_slices; i++) { nvinfer1::ITensor* tensor; + int slice_instance = i * max_num_slices + op_instance_value; // maybe_add_a_dimension(i); if (i < size_for_added_dims.size() && size_for_added_dims[i] >= 0) { TF_RETURN_IF_ERROR( @@ -2546,11 +2601,11 @@ Status Converter::DynamicReshape(nvinfer1::ITensor* input, concat_inputs.push_back(tensor); } if (i < slices.size()) { - concat_inputs.push_back( - network() - ->addSlice(*shape, {1, {slices[i].first}}, - {1, {slices[i].second - slices[i].first}}, {1, {1}}) - ->getOutput(0)); + nvinfer1::ISliceLayer* slice_layer = network()->addSlice( + *shape, {1, {slices[i].first}}, + {1, {slices[i].second - slices[i].first}}, {1, {1}}); + concat_inputs.push_back(slice_layer->getOutput(0)); + SetLayerName(slice_layer, params->node_def, "slice", slice_instance); } } nvinfer1::IConcatenationLayer* concat_layer = network()->addConcatenation( @@ -2560,6 +2615,7 @@ Status Converter::DynamicReshape(nvinfer1::ITensor* input, nvinfer1::ITensor* new_shape = concat_layer->getOutput(0); // Reshape input using new shape nvinfer1::IShuffleLayer* shuffle = network()->addShuffle(*input); + SetLayerName(shuffle, params->node_def, "shuffle", op_instance); shuffle->setInput(1, *new_shape); *output = shuffle->getOutput(0); return Status::OK(); @@ -2572,7 +2628,8 @@ Status Converter::DynamicReshape(nvinfer1::ITensor* input, Status Converter::DynamicExpandDims(nvinfer1::ITensor* input, const nvinfer1::Dims& dims, int axis, OpConverterParams* params, - nvinfer1::ITensor** output) { + nvinfer1::ITensor** output, + absl::optional op_instance) { if (params->validation_only) { *output = nullptr; return errors::Internal( @@ -2588,7 +2645,7 @@ Status Converter::DynamicExpandDims(nvinfer1::ITensor* input, if (axis != dims.nbDims) { slices.push_back(std::pair{axis, dims.nbDims}); } - return DynamicReshape(input, slices, params, output, extra_dims); + return DynamicReshape(input, slices, params, output, extra_dims, op_instance); } Status Converter::SqueezeTensor(nvinfer1::ITensor* input, @@ -2616,7 +2673,8 @@ Status Converter::SqueezeTensor(nvinfer1::ITensor* input, VLOG(2) << "input_dims" << input_dims; TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(*input_dims, &new_dims)); TF_RETURN_IF_ERROR(PrepareTensorForShape(TRT_TensorOrWeights(input), new_dims, - /*validation_only=*/false, output)); + /*validation_only=*/false, output, + params->node_def)); return Status::OK(); } @@ -2680,11 +2738,11 @@ Status ConvertSqueeze(OpConverterParams* params) { } template -Status ConvertStridedSliceHelper(OpConverterParams* params, - const TRT_TensorOrWeights& input, - Container begin, Container size, - const Container& stride, - const nvinfer1::Dims* final_shape = nullptr) { +Status ConvertStridedSliceHelper( + OpConverterParams* params, const TRT_TensorOrWeights& input, + Container begin, Container size, const Container& stride, + const nvinfer1::Dims* final_shape = nullptr, + absl::optional op_instance = absl::nullopt) { const auto& node_def = params->node_def; // Get input dims. nvinfer1::Dims dims = input.GetTrtDims(); @@ -2709,6 +2767,7 @@ Status ConvertStridedSliceHelper(OpConverterParams* params, node_def.op(), ", at ", node_def.name()); } } + // TRT 5.1 adds ISliceLayer. For older versions, we attempt to use the // padding layer with negative padding. 
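[Editor's note, not part of the patch] The comment above mentions that, before ISliceLayer existed, ConvertStridedSliceHelper emulated a unit-stride slice with a padding layer whose pre/post amounts are negative, i.e. they crop. A hedged NumPy sketch of that equivalence (illustrative only; the crop_with_negative_padding helper is made up, not converter code):

import numpy as np

def crop_with_negative_padding(x, pre_pad, post_pad):
    """Per-axis padding where negative amounts crop instead of pad."""
    out = x
    for axis, (pre, post) in enumerate(zip(pre_pad, post_pad)):
        if pre < 0 or post < 0:                      # negative padding == crop
            start = -pre if pre < 0 else 0
            stop = out.shape[axis] + post if post < 0 else out.shape[axis]
            out = np.take(out, np.arange(start, stop), axis=axis)
        pads = [(0, 0)] * out.ndim
        pads[axis] = (max(pre, 0), max(post, 0))
        out = np.pad(out, pads)                      # positive amounts pad with zeros
    return out

x = np.arange(36).reshape(6, 6)
# The slice x[1:4, 2:6] expressed as padding of (-1, -2) on rows and (-2, 0) on columns.
assert np.array_equal(crop_with_negative_padding(x, (-1, -2), (-2, 0)), x[1:4, 2:6])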
#if IS_TRT_VERSION_GE(5, 1, 3, 1) @@ -2723,12 +2782,13 @@ Status ConvertStridedSliceHelper(OpConverterParams* params, nvinfer1::ISliceLayer* layer = params->converter->network()->addSlice( *input.tensor(), begin_dims, size_dims, stride_dims); + SetLayerName(layer, params->node_def, "slice", op_instance); nvinfer1::ITensor* tensor = layer->getOutput(0); // Reshape for shrink_axis. if (final_shape) { TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(tensor), *final_shape, /*validation_only=*/false, - &tensor)); + &tensor, node_def, op_instance)); } params->outputs->push_back(TRT_TensorOrWeights(tensor)); return Status::OK(); @@ -2782,6 +2842,7 @@ Status ConvertStridedSliceHelper(OpConverterParams* params, if (params->validation_only) return Status::OK(); nvinfer1::IShuffleLayer* layer = params->converter->network()->addShuffle(*input.tensor()); + SetLayerName(layer, params->node_def, "shuffle", op_instance); params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); return Status::OK(); } else if (pad_dims.size() == 1) { @@ -2830,30 +2891,32 @@ Status ConvertStridedSliceHelper(OpConverterParams* params, nvinfer1::ITensor* tensor = input.tensor(); if (need_reshape) { TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - input, reshape_dims, /*validation_only=*/false, &tensor)); + input, reshape_dims, /*validation_only=*/false, &tensor, node_def, + op_instance)); } if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, transpose_order, StrCat(node_def.name(), "_for_pad"), &tensor)); + tensor, transpose_order, &tensor, node_def, "for_pad", op_instance)); } // Add padding layer nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding( *tensor, pre_padding, post_padding); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, params->node_def, "pad"); params->converter->MarkQuantizationRangesAsInferrable(tensor, layer->getOutput(0)); tensor = layer->getOutput(0); // Restore transpose if (need_transpose) { - TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, inv_transpose_order, StrCat(node_def.name(), "_after_pad"), - &tensor)); + TF_RETURN_IF_ERROR( + params->converter->TransposeTensor(tensor, inv_transpose_order, &tensor, + node_def, "after_pad", op_instance)); } // Reshape for shrink_axis. if (final_shape) { TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(tensor), *final_shape, /*validation_only=*/false, - &tensor)); + &tensor, node_def, op_instance)); } else if (need_reshape) { // Restore reshape. 
// Calculate output dimensions @@ -2876,7 +2939,7 @@ Status ConvertStridedSliceHelper(OpConverterParams* params, /*ignore_first_dim=*/true)); TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(tensor), new_dims, /*validation_only=*/false, - &tensor)); + &tensor, node_def, op_instance)); } params->outputs->push_back(TRT_TensorOrWeights(tensor)); @@ -3166,8 +3229,7 @@ Status ConvertConv3DHelper(OpConverterParams* params, int group, const bool need_transpose = is_ndhwc; if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, {0, 4, 1, 2, 3}, StrCat(node_def.name(), "_to_NCDHW"), - &tensor)); + tensor, {0, 4, 1, 2, 3}, &tensor, node_def, "to_NCDHW")); } // group == 0 signifies that this is a depthwise convolution, so set @@ -3206,7 +3268,6 @@ Status ConvertConv3DHelper(OpConverterParams* params, int group, layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); } - layer->setName(node_def.name().c_str()); layer->setNbGroups(num_groups); conv_layer = layer; } else { @@ -3222,18 +3283,17 @@ Status ConvertConv3DHelper(OpConverterParams* params, int group, layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); } - layer->setName(node_def.name().c_str()); layer->setNbGroups(num_groups); layer->setDilationNd(dilation_dhw); conv_layer = layer; } + SetLayerName(conv_layer, node_def, "conv"); nvinfer1::ITensor* output_tensor = conv_layer->getOutput(0); // Restore transpose. if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - output_tensor, {0, 2, 3, 4, 1}, StrCat(node_def.name(), "_to_NDHWC"), - &output_tensor)); + output_tensor, {0, 2, 3, 4, 1}, &output_tensor, node_def, "to_NDHWC")); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -3302,8 +3362,7 @@ Status ConvertPool3D(OpConverterParams* params) { if (data_format == "NDHWC") { // NDHWC => NCDHW TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, {0, 4, 1, 2, 3}, StrCat(node_def.name(), "_to_NCDHW"), - &tensor)); + tensor, {0, 4, 1, 2, 3}, &tensor, node_def, "to_NCDHW")); } const nvinfer1::Dims3 stride(tf_stride[d_index], tf_stride[h_index], @@ -3324,14 +3383,13 @@ Status ConvertPool3D(OpConverterParams* params) { // SAME_UPPER means that post padding is preferred. 
layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); } - layer->setName(node_def.name().c_str()); + SetLayerName(layer, node_def, "pooling"); nvinfer1::ITensor* output_tensor = layer->getOutput(0); if (data_format == "NDHWC") { // NCDHW => NDHWC TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - output_tensor, {0, 2, 3, 4, 1}, StrCat(node_def.name(), "_to_NDHWC"), - &output_tensor)); + output_tensor, {0, 2, 3, 4, 1}, &output_tensor, node_def, "to_NDHWC")); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); @@ -3426,7 +3484,7 @@ Status ConvertFusedConv2DBiasActivation(OpConverterParams* params) { const bool need_transpose = (data_format == "NHWC"); if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, {0, 3, 1, 2}, StrCat(node_def.name(), "_to_NCHW"), &tensor)); + tensor, {0, 3, 1, 2}, &tensor, node_def, "to_NCHW")); } nvinfer1::DimsHW kernel_size; @@ -3482,7 +3540,7 @@ Status ConvertFusedConv2DBiasActivation(OpConverterParams* params) { #else conv_layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); #endif - conv_layer->setName(node_def.name().c_str()); + SetLayerName(conv_layer, node_def, "conv"); conv_layer->setNbGroups(1); conv_layer->setDilation(dilation); nvinfer1::ITensor* output_tensor = conv_layer->getOutput(0); @@ -3493,13 +3551,13 @@ Status ConvertFusedConv2DBiasActivation(OpConverterParams* params) { params->converter->network()->addActivation(*output_tensor, op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(activation_layer, node_def.name()); + SetLayerName(activation_layer, node_def, "activation"); output_tensor = activation_layer->getOutput(0); } // Restore transpose. if (need_transpose) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - output_tensor, {0, 2, 3, 1}, StrCat(node_def.name(), "_to_NHWC"), - &output_tensor)); + output_tensor, {0, 2, 3, 1}, &output_tensor, node_def, "to_NHWC")); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -3541,7 +3599,7 @@ Status ConvertPool(OpConverterParams* params) { h_index = 1; w_index = 2; TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, {0, 3, 1, 2}, StrCat(node_def.name(), "_to_NCHW"), &tensor)); + tensor, {0, 3, 1, 2}, &tensor, node_def, "to_NCHW")); } const auto tf_stride = attrs.get>("strides"); @@ -3575,6 +3633,7 @@ Status ConvertPool(OpConverterParams* params) { *tensor, nvinfer1::DimsHW(padding[0].first, padding[1].first), nvinfer1::DimsHW(padding[0].second, padding[1].second)); TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name()); + SetLayerName(pad_layer, node_def, "pad"); params->converter->MarkQuantizationRangesAsInferrable( tensor, pad_layer->getOutput(0)); padding = {{0, 0}, {0, 0}}; @@ -3604,13 +3663,12 @@ Status ConvertPool(OpConverterParams* params) { #else layer->setPadding(nvinfer1::DimsHW{padding[0].first, padding[1].first}); #endif - layer->setName(node_def.name().c_str()); + SetLayerName(layer, node_def, "pooling"); nvinfer1::ITensor* output_tensor = layer->getOutput(0); if (data_format == "NHWC") { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - output_tensor, {0, 2, 3, 1}, StrCat(node_def.name(), "_to_NHWC"), - &output_tensor)); + output_tensor, {0, 2, 3, 1}, &output_tensor, node_def, "to_NHWC")); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -3633,6 +3691,7 @@ Status ConvertLeakyRelu(OpConverterParams* params) { params->converter->network()->addActivation( *inputs.at(0).tensor(), 
nvinfer1::ActivationType::kLEAKY_RELU); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def, "activation"); layer->setAlpha(alpha); params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); return Status::OK(); @@ -3655,12 +3714,14 @@ Status ConvertLeakyRelu(OpConverterParams* params) { params->converter->network()->addElementWise( *tensor, *const_alpha_tensor, nvinfer1::ElementWiseOperation::kPROD); TFTRT_RETURN_ERROR_IF_NULLPTR(mul_layer, node_def.name()); + SetLayerName(mul_layer, node_def, "mul"); // max(x, alpha * x) nvinfer1::IElementWiseLayer* max_layer = params->converter->network()->addElementWise( *tensor, *mul_layer->getOutput(0), nvinfer1::ElementWiseOperation::kMAX); TFTRT_RETURN_ERROR_IF_NULLPTR(max_layer, node_def.name()); + SetLayerName(max_layer, node_def, "max"); nvinfer1::ITensor* output_tensor = max_layer->getOutput(0); params->converter->MarkQuantizationRangesAsInferrable( output_tensor, mul_layer->getOutput(0)); @@ -3705,6 +3766,7 @@ Status ConvertClipByValue(OpConverterParams* params) { layer->setAlpha(clip_value_min); layer->setBeta(clip_value_max); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def, "activation"); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->converter->ProvideQuantizationRange(output_tensor, clip_value_min, clip_value_max); @@ -3748,7 +3810,7 @@ Status ConvertActivation(OpConverterParams* params) { params->converter->network()->addActivation(*inputs.at(0).tensor(), op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - layer->setName(node_def.name().c_str()); + SetLayerName(layer, node_def, "activation"); // Set parameters. #if IS_TRT_VERSION_GE(5, 1, 2, 0) if (node_def.op() == "Elu") { @@ -3852,7 +3914,7 @@ Status ConvertRelu6(OpConverterParams* params) { TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); layer->setAlpha(0.0f); layer->setBeta(6.0f); - layer->setName(node_def.name().c_str()); + SetLayerName(layer, node_def, "activation"); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 6.0f); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); @@ -3867,6 +3929,7 @@ Status ConvertRelu6(OpConverterParams* params) { params->converter->network()->addActivation( *tensor, nvinfer1::ActivationType::kRELU); TFTRT_RETURN_ERROR_IF_NULLPTR(relu_layer, node_def.name()); + SetLayerName(relu_layer, node_def, "activation"); // Large range of relu is problematic during quantization in INT8 precision // mode. Setting dynamic range of relu = [0.f, 6.0f] helps with quantization.
@@ -3888,6 +3951,7 @@ Status ConvertRelu6(OpConverterParams* params) { *relu_layer->getOutput(0), *const6_tensor, nvinfer1::ElementWiseOperation::kMIN); TFTRT_RETURN_ERROR_IF_NULLPTR(relu6_layer, node_def.name()); + SetLayerName(relu6_layer, node_def, "min"); nvinfer1::ITensor* output_tensor = relu6_layer->getOutput(0); params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 6.0f); @@ -3932,6 +3996,7 @@ Status ConvertBiasAddInt8WithoutCalibration(OpConverterParams* params) { nvinfer1::IShuffleLayer* shuffle_layer = params->converter->network()->addShuffle(*tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name()); + SetLayerName(shuffle_layer, node_def, "shuffle", /*op_instance=*/0); params->converter->MarkQuantizationRangesAsInferrable( tensor, shuffle_layer->getOutput(0)); @@ -3963,6 +4028,7 @@ Status ConvertBiasAddInt8WithoutCalibration(OpConverterParams* params) { *tensor, mode, weights.GetTrtWeights(), empty_weights.GetTrtWeights(), empty_weights.GetTrtWeights()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def, "scale"); nvinfer1::ITensor* output_tensor = layer->getOutput(0); @@ -3971,6 +4037,7 @@ Status ConvertBiasAddInt8WithoutCalibration(OpConverterParams* params) { nvinfer1::IShuffleLayer* shuffle_layer = params->converter->network()->addShuffle(*output_tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name()); + SetLayerName(shuffle_layer, node_def, "shuffle", /*op_instance=*/1); // NOTE: for same reason as mentioned above we need to apply the reshape // unconditionally. nvinfer1::Dims reshape_dims = original_dims; @@ -4055,13 +4122,16 @@ Status ConvertBiasAdd(OpConverterParams* params) { // Convert input to a TRT tensor nvinfer1::ITensor* input_tensor{nullptr}; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(0), input_shape, params->validation_only, &input_tensor)); + inputs.at(0), input_shape, params->validation_only, &input_tensor, + node_def, + /*op_instance=*/0)); // Finally, reshape bias. Since the bias is usually a constant, this will // normally happen at conversion-time. nvinfer1::ITensor* bias_tensor{nullptr}; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(1), bias_shape, params->validation_only, &bias_tensor)); + inputs.at(1), bias_shape, params->validation_only, &bias_tensor, node_def, + /*op_instance=*/1)); VLOG(2) << "Bias shape adjusted to " << DebugString(bias_shape); if (params->validation_only) return Status::OK(); @@ -4070,6 +4140,7 @@ Status ConvertBiasAdd(OpConverterParams* params) { params->converter->network()->addElementWise( *input_tensor, *bias_tensor, nvinfer1::ElementWiseOperation::kSUM); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def, "sum"); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); @@ -4298,16 +4369,18 @@ Status ConvertBinary(OpConverterParams* params) { nvinfer1::ITensor* tensor_r = nullptr; // This will also convert constants to tensors, and set quantization ranges. 
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - operand_l, broadcasted_dims_l, params->validation_only, &tensor_l)); + operand_l, broadcasted_dims_l, params->validation_only, &tensor_l, + node_def, /*op_instance=*/0)); TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - operand_r, broadcasted_dims_r, params->validation_only, &tensor_r)); + operand_r, broadcasted_dims_r, params->validation_only, &tensor_r, + node_def, /*op_instance=*/1)); if (params->validation_only) return Status::OK(); // Add ElementWise layer. nvinfer1::ILayer* layer = params->converter->network()->addElementWise( *tensor_l, *tensor_r, op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - layer->setName(node_def.name().c_str()); + SetLayerName(layer, node_def); nvinfer1::ITensor* trt_tensor = layer->getOutput(0); #if IS_TRT_VERSION_GE(5, 1, 0, 0) @@ -4315,6 +4388,7 @@ Status ConvertBinary(OpConverterParams* params) { layer = params->converter->network()->addUnary( *trt_tensor, nvinfer1::UnaryOperation::kFLOOR); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def, "floor"); trt_tensor = layer->getOutput(0); } #endif @@ -4353,10 +4427,12 @@ Status ConvertRsqrt(OpConverterParams* params) { nvinfer1::IUnaryLayer* sqrt_layer = params->converter->network()->addUnary( *tensor, nvinfer1::UnaryOperation::kSQRT); TFTRT_RETURN_ERROR_IF_NULLPTR(sqrt_layer, node_def.name()); + SetLayerName(sqrt_layer, node_def, "sqrt"); // Recip nvinfer1::IUnaryLayer* recip_layer = params->converter->network()->addUnary( *sqrt_layer->getOutput(0), nvinfer1::UnaryOperation::kRECIP); TFTRT_RETURN_ERROR_IF_NULLPTR(recip_layer, node_def.name()); + SetLayerName(recip_layer, node_def, "recip"); params->outputs->push_back(TRT_TensorOrWeights(recip_layer->getOutput(0))); return Status::OK(); } @@ -4408,7 +4484,7 @@ Status ConvertUnary(OpConverterParams* params) { nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary(*tensor, op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - layer->setName(node_def.name().c_str()); + SetLayerName(layer, node_def); nvinfer1::ITensor* output_tensor = layer->getOutput(0); // Set quantization ranges. 
@@ -4453,6 +4529,7 @@ Status ConvertSquare(OpConverterParams* params) { *inputs.at(0).tensor(), *const2_tensor, nvinfer1::ElementWiseOperation::kPOW); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); @@ -4511,6 +4588,7 @@ Status ConvertReduce(OpConverterParams* params) { nvinfer1::ILayer* layer = params->converter->network()->addReduce( *tensor, reduce_operation, axes, keep_dims); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); return Status::OK(); @@ -4585,21 +4663,25 @@ Status ConvertPack(OpConverterParams* params) { nvinfer1::Dims expanded_dims; TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(tensor_dims, &expanded_dims)); std::vector expanded_tensors; + int input_index = 0; for (const TRT_TensorOrWeights& input : inputs) { nvinfer1::ITensor* expanded_tensor = nullptr; if (input.is_tensor() && !params->use_implicit_batch && !HasStaticShape(dims)) { if (!params->validation_only) { TF_RETURN_IF_ERROR(params->converter->DynamicExpandDims( - input.tensor(), dims, trt_axis, params, &expanded_tensor)); + input.tensor(), dims, trt_axis, params, &expanded_tensor, + input_index)); } } else { TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - input, expanded_dims, params->validation_only, &expanded_tensor)); + input, expanded_dims, params->validation_only, &expanded_tensor, + node_def, input_index)); } if (!params->validation_only) { expanded_tensors.push_back(expanded_tensor); } + input_index++; } if (params->validation_only) return Status::OK(); @@ -4615,6 +4697,7 @@ Status ConvertPack(OpConverterParams* params) { const_cast(expanded_tensors.data()), expanded_tensors.size()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def, "concat"); // Note that trt_axis stays the same even after expanding tensors at the axis. 
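For example (an illustration, not from the patch, assuming implicit batch mode): packing three inputs of TF shape [8, 2, 3] at axis=1 maps to trt_axis=0 on the batch-stripped shape [2, 3]; each input is expanded to [1, 2, 3], and the concatenation along trt_axis=0 yields [3, 2, 3]. The same trt_axis value drives both the expansion and the concatenation, which is why the setAxis call below needs no adjustment.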
layer->setAxis(trt_axis); params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); @@ -4696,7 +4779,7 @@ Status ConvertPad(OpConverterParams* params) { if (pad_index[0] == 1) { legit_pad = false; TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, {0, 3, 2, 1}, StrCat(node_def.name(), "_to_pad"), &tensor)); + tensor, {0, 3, 2, 1}, &tensor, node_def, "to_pad")); permuted_pad_index[0] = 3; } @@ -4714,13 +4797,13 @@ Status ConvertPad(OpConverterParams* params) { nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding( *tensor, pre_padding, post_padding); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->converter->MarkQuantizationRangesAsInferrable(tensor, output_tensor); if (!legit_pad) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - output_tensor, {0, 3, 2, 1}, StrCat(node_def.name(), "_from_pad"), - &output_tensor)); + output_tensor, {0, 3, 2, 1}, &output_tensor, node_def, "from_pad")); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); @@ -4780,7 +4863,7 @@ Status ConvertSplitHelper(OpConverterParams* params, for (int i = 0; i < num_splits; ++i) { begin[trt_axis + 1] = i * split_size_on_axis; TF_RETURN_IF_ERROR(ConvertStridedSliceHelper( - params, input, begin, size, stride, final_shape_for_unpack_ptr)); + params, input, begin, size, stride, final_shape_for_unpack_ptr, i)); } return Status::OK(); } @@ -4854,6 +4937,7 @@ Status ConvertCast(OpConverterParams* params) { nvinfer1::ITensor* input = params->inputs.at(0).tensor(); nvinfer1::IIdentityLayer* layer = params->converter->network()->addIdentity(*input); + SetLayerName(layer, node_def); layer->setPrecision(nvinfer1::DataType::kFLOAT); if (layer->getOutput(0)->getType() != nvinfer1::DataType::kFLOAT) { @@ -4911,6 +4995,7 @@ Status ConvertConcat(OpConverterParams* params) { const_cast(input_tensors.data()), input_tensors.size()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); layer->setAxis(trt_axis); params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); return Status::OK(); @@ -5057,7 +5142,7 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) { combined_scale_weights.GetTrtWeights(), dummy_power_weights.GetTrtWeights()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); - layer->setName(node_def.name().c_str()); + SetLayerName(layer, node_def); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -5137,6 +5222,7 @@ Status ConvertGather(OpConverterParams* params) { nvinfer1::IGatherLayer* layer = params->converter->network()->addGather( *params_tensor, *indices_input.tensor(), trt_axis); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); nvinfer1::ITensor* output_tensor = layer->getOutput(0); nvinfer1::Dims trt_gather_output_dims = output_tensor->getDimensions(); @@ -5163,7 +5249,7 @@ Status ConvertGather(OpConverterParams* params) { TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(output_tensor), trt_gather_output_dims, - /*validation_only=*/false, &output_tensor)); + /*validation_only=*/false, &output_tensor, node_def)); } params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); @@ -5173,7 +5259,7 @@ Status ConvertGather(OpConverterParams* params) { Status ConvertFullyConnectedHelper(OpConverterParams* params, nvinfer1::ITensor* 
tensor_a, TRT_ShapedWeights weights_b, - bool transpose_b, const string& node_name) { + bool transpose_b, const NodeDef& node_def) { // Reshape input to 3D - this will be a no-op unless using int8 precision. auto input_dim = tensor_a->getDimensions(); while (input_dim.nbDims < 3) { @@ -5181,7 +5267,7 @@ Status ConvertFullyConnectedHelper(OpConverterParams* params, } TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(tensor_a), input_dim, /*validation_only=*/false, - &tensor_a)); + &tensor_a, node_def, /*op_instance=*/0)); // FC layer will transpose weights, so we need to pre-transpose. TRT_ShapedWeights weights(weights_b.TrtDType()); @@ -5197,7 +5283,8 @@ Status ConvertFullyConnectedHelper(OpConverterParams* params, params->converter->network()->addFullyConnected( *tensor_a, noutput, weights.GetTrtWeights(), biases.GetTrtWeights()); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_name); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); nvinfer1::ITensor* output_tensor = layer->getOutput(0); // Reshape output to 1D - this will be a no-op unless using int8 precision. @@ -5205,7 +5292,7 @@ Status ConvertFullyConnectedHelper(OpConverterParams* params, output_dim.nbDims = 1; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(output_tensor), output_dim, /*validation_only=*/false, - &output_tensor)); + &output_tensor, node_def, /*op_instance=*/1)); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -5214,7 +5301,7 @@ Status ConvertFullyConnectedHelper(OpConverterParams* params, Status ConvertMatMulHelper(OpConverterParams* params, TRT_TensorOrWeights input_a, TRT_TensorOrWeights input_b, bool transpose_a, - bool transpose_b, string node_name) { + bool transpose_b, const NodeDef& node_def) { // TODO: ReorderCKtoKC is currently not general enough to transpose weights // that are not 2D. 
if ((transpose_a && input_a.is_weights() && @@ -5252,7 +5339,7 @@ Status ConvertMatMulHelper(OpConverterParams* params, if (should_use_fc || (can_use_fc && params->converter->precision_mode() == TrtPrecisionMode::INT8)) { return ConvertFullyConnectedHelper( - params, input_a.tensor(), input_b.weights(), transpose_b, node_name); + params, input_a.tensor(), input_b.weights(), transpose_b, node_def); } const auto get_matrix_op = [](nvinfer1::ITensor* in, @@ -5293,7 +5380,8 @@ Status ConvertMatMulHelper(OpConverterParams* params, *tensor_a, get_matrix_op(tensor_a, transpose_a), *tensor_b, get_matrix_op(tensor_b, transpose_b)); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_name); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -5316,7 +5404,7 @@ Status ConvertMatMul(OpConverterParams* params) { bool transpose_b = attrs.get("transpose_b"); return ConvertMatMulHelper(params, inputs.at(0), inputs.at(1), transpose_a, - transpose_b, node_def.name()); + transpose_b, node_def); } Status ConvertBatchMatMul(OpConverterParams* params) { @@ -5379,14 +5467,16 @@ Status ConvertBatchMatMul(OpConverterParams* params) { nvinfer1::ITensor* tensor_l = nullptr; nvinfer1::ITensor* tensor_r = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(0), broadcasted_dims_l, params->validation_only, &tensor_l)); + inputs.at(0), broadcasted_dims_l, params->validation_only, &tensor_l, + node_def)); TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(1), broadcasted_dims_r, params->validation_only, &tensor_r)); + inputs.at(1), broadcasted_dims_r, params->validation_only, &tensor_r, + node_def)); if (params->validation_only) return Status::OK(); return ConvertMatMulHelper(params, TRT_TensorOrWeights(tensor_l), TRT_TensorOrWeights(tensor_r), transpose_a, - transpose_b, node_def.name()); + transpose_b, node_def); } Status ConvertSoftmax(OpConverterParams* params) { @@ -5408,6 +5498,7 @@ Status ConvertSoftmax(OpConverterParams* params) { nvinfer1::ISoftMaxLayer* layer = params->converter->network()->addSoftMax(*tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); // Tensorflow SoftMax assumes applying softmax on the last dimension. layer->setAxes(1 << (num_trt_dims - 1)); @@ -5452,6 +5543,7 @@ Status ConvertArgMinMax(OpConverterParams* params) { nvinfer1::ITopKLayer* layer = params->converter->network()->addTopK( *inputs.at(0).tensor(), topk_op, 1, reduce_axes); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def, "topk"); nvinfer1::ITensor* output_indices_tensor = layer->getOutput(1); // Squeeze on axis. 
@@ -5462,7 +5554,7 @@ Status ConvertArgMinMax(OpConverterParams* params) { nvinfer1::ITensor* output_tensor = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(output_indices_tensor), new_dims, - /*validation_only=*/false, &output_tensor)); + /*validation_only=*/false, &output_tensor, node_def)); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -5508,6 +5600,7 @@ Status ConvertTopK(OpConverterParams* params) { nvinfer1::ITopKLayer* layer = params->converter->network()->addTopK(*tensor, op, k, reduce_axes); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); nvinfer1::ITensor* output_value_tensor = layer->getOutput(0); nvinfer1::ITensor* output_indices_tensor = layer->getOutput(1); @@ -5583,6 +5676,7 @@ Status ConvertDepthSpaceShuffle(OpConverterParams* params) { nvinfer1::IShuffleLayer* first_shuffle = params->converter->network()->addShuffle(*inputs.at(0).tensor()); TFTRT_RETURN_ERROR_IF_NULLPTR(first_shuffle, node_def.name()); + SetLayerName(first_shuffle, node_def, "shuffle", /*op_instance=*/0); if (data_format == "NHWC") { first_shuffle->setFirstTranspose({2, 0, 1}); } @@ -5592,6 +5686,7 @@ Status ConvertDepthSpaceShuffle(OpConverterParams* params) { nvinfer1::IShuffleLayer* second_shuffle = params->converter->network()->addShuffle(*first_shuffle->getOutput(0)); TFTRT_RETURN_ERROR_IF_NULLPTR(second_shuffle, node_def.name()); + SetLayerName(second_shuffle, node_def, "shuffle", /*op_instance=*/1); second_shuffle->setReshapeDimensions(second_shuffle_shape); if (data_format == "NHWC") { second_shuffle->setSecondTranspose({1, 2, 0}); @@ -5619,9 +5714,11 @@ Status ConvertSquaredDifference(OpConverterParams* params) { nvinfer1::ITensor* tensor_l = nullptr; nvinfer1::ITensor* tensor_r = nullptr; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(0), broadcasted_dims_l, params->validation_only, &tensor_l)); + inputs.at(0), broadcasted_dims_l, params->validation_only, &tensor_l, + node_def)); TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( - inputs.at(1), broadcasted_dims_r, params->validation_only, &tensor_r)); + inputs.at(1), broadcasted_dims_r, params->validation_only, &tensor_r, + node_def)); if (params->validation_only) return Status::OK(); // Subtract x - y. @@ -5629,12 +5726,15 @@ Status ConvertSquaredDifference(OpConverterParams* params) { params->converter->network()->addElementWise( *tensor_l, *tensor_r, nvinfer1::ElementWiseOperation::kSUB); TFTRT_RETURN_ERROR_IF_NULLPTR(sub, node_def.name()); + SetLayerName(sub, node_def, "sub"); + // Multiply (x - y) * (x - y). 
nvinfer1::IElementWiseLayer* mul = params->converter->network()->addElementWise( *sub->getOutput(0), *sub->getOutput(0), nvinfer1::ElementWiseOperation::kPROD); TFTRT_RETURN_ERROR_IF_NULLPTR(mul, node_def.name()); + SetLayerName(mul, node_def, "mul"); params->outputs->push_back(TRT_TensorOrWeights(mul->getOutput(0))); return Status::OK(); @@ -5772,6 +5872,7 @@ Status ConvertCombinedNMS(OpConverterParams* params) { nvinfer1::IPluginV2Layer* layer = params->converter->network()->addPluginV2( &plugin_inputs[0], static_cast(plugin_inputs.size()), *plugin); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def, "plugin"); // Set plugin outputs nvinfer1::ITensor* output_nmsed_boxes = layer->getOutput(1); @@ -5785,8 +5886,8 @@ Status ConvertCombinedNMS(OpConverterParams* params) { nvinfer1::ITensor* output_nmsed_scores = nullptr; nvinfer1::ITensor* output_nmsed_classes = nullptr; - auto shrink_last_dim = [params](nvinfer1::ITensor* in_tensor, - nvinfer1::ITensor** out_tensor) { + auto shrink_last_dim = [&](int output_index, nvinfer1::ITensor** out_tensor) { + nvinfer1::ITensor* in_tensor = layer->getOutput(output_index); nvinfer1::Dims dims = in_tensor->getDimensions(); if (dims.d[dims.nbDims - 1] != 1) { return errors::Internal("Expect last dims to be 1, for tensor ", @@ -5795,15 +5896,12 @@ Status ConvertCombinedNMS(OpConverterParams* params) { --dims.nbDims; TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( TRT_TensorOrWeights(in_tensor), dims, - /*validation_only=*/false, out_tensor)); + /*validation_only=*/false, out_tensor, node_def, output_index)); return Status::OK(); }; - TF_RETURN_IF_ERROR( - shrink_last_dim(layer->getOutput(2), &output_nmsed_scores)); - TF_RETURN_IF_ERROR( - shrink_last_dim(layer->getOutput(3), &output_nmsed_classes)); - TF_RETURN_IF_ERROR( - shrink_last_dim(layer->getOutput(0), &output_num_detections)); + TF_RETURN_IF_ERROR(shrink_last_dim(2, &output_nmsed_scores)); + TF_RETURN_IF_ERROR(shrink_last_dim(3, &output_nmsed_classes)); + TF_RETURN_IF_ERROR(shrink_last_dim(0, &output_num_detections)); #endif // IS_TRT_VERSION_GE(6, 0, 0, 0) params->outputs->push_back(TRT_TensorOrWeights(output_nmsed_boxes)); @@ -5845,6 +5943,12 @@ Status ConvertResize(OpConverterParams* params) { // Verify resize mode. Initialize resize mode if supported. nvinfer1::ResizeMode resize_mode; if (node_def.op() == "ResizeBilinear") { +#if IS_TRT_VERSION_GE(7, 1, 0, 0) + if (!align_corners) { + return errors::InvalidArgument( + "Cannot Convert Bilinear Resize when align_corners=False"); + } +#endif resize_mode = nvinfer1::ResizeMode::kLINEAR; } else if (node_def.op() == "ResizeNearestNeighbor") { resize_mode = nvinfer1::ResizeMode::kNEAREST; @@ -5858,7 +5962,7 @@ Status ConvertResize(OpConverterParams* params) { // Transpose tensor from NHWC to NCHW format. TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - tensor, {0, 3, 1, 2}, StrCat(node_def.name(), "_to_NCHW"), &tensor)); + tensor, {0, 3, 1, 2}, &tensor, node_def, "to_NCHW")); // Calculate output dimensions. // Given input dimensions [N, C, H, W] and output size [H_out, W_out], @@ -5875,6 +5979,7 @@ Status ConvertResize(OpConverterParams* params) { nvinfer1::IResizeLayer* layer = params->converter->network()->addResize(*tensor); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def); // Set layer parameters. 
layer->setResizeMode(resize_mode); @@ -5885,7 +5990,7 @@ Status ConvertResize(OpConverterParams* params) { nvinfer1::ITensor* output = layer->getOutput(0); TF_RETURN_IF_ERROR(params->converter->TransposeTensor( - output, {0, 2, 3, 1}, StrCat(node_def.name(), "_to_NHWC"), &output)); + output, {0, 2, 3, 1}, &output, node_def, "to_NHWC")); params->outputs->push_back(TRT_TensorOrWeights(output)); // Success return Status::OK(); @@ -5934,6 +6039,7 @@ Status ConvertAddN(OpConverterParams* params) { nvinfer1::ILayer* layer = params->converter->network()->addElementWise( *lhs, *rhs, nvinfer1::ElementWiseOperation::kSUM); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + SetLayerName(layer, node_def, std::to_string(i)); lhs = layer->getOutput(0); } params->outputs->push_back(TRT_TensorOrWeights(lhs)); @@ -6056,6 +6162,8 @@ Status ConvertGraphDefToEngine( VLOG(1) << "Starting to convert TensorFlow ops to TensorRT layers"; std::vector output_tensors; + int num_layers = converter->network()->getNbLayers(); + absl::flat_hash_set layer_names; // Graph nodes are already topologically sorted during construction for (const auto& node_def : gdef.node()) { const string& node_name = node_def.name(); @@ -6134,6 +6242,25 @@ Status ConvertGraphDefToEngine( } else { TF_RETURN_IF_ERROR(converter->ConvertNode(node_def)); } + + // To support TF-TRT profiling, we ensure each ILayer has a non-empty name. + // BuildCudaEngine returns an error if there is any ILayer name collision. + // We want to report the error here before BuildCudaEngine in a more + // meaningful way. + int new_num_layers = converter->network()->getNbLayers(); + for (int i = num_layers; i < new_num_layers; i++) { + auto layer = converter->network()->getLayer(i); + if (layer->getName() == nullptr || + !layer_names.insert(layer->getName()).second) { + std::string error_message = + absl::StrCat("Converting node ", node_name, ", op=", node_def.op(), + layer->getName() ? "create a layer with name collision" + : "create a layer without a name"); + LOG_WARNING_WITH_PREFIX << error_message; + return errors::Internal(error_message); + } + } + num_layers = new_num_layers; } TF_RETURN_IF_ERROR(converter->RenameAndMarkOutputTensors(output_tensors)); if (convert_successfully) *convert_successfully = true; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index a621735fad1..4a84793e254 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/types/optional.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" @@ -515,14 +516,18 @@ class Converter { // Transpose 'input_tensor' with given permutation 'order_with_batch_dim' to // 'output_tensor'. The permutation 'order_with_batch_dim' contains the batch - // dimension which should always be 0. + // dimension which should always be 0. If this is for adding a transpose layer + // to support the conversion of 'node_def', callers need to provide a + // non-empty 'sub_op_name' appended to the name of 'node_def' to avoid layer + // name conflicts. 
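The SetLayerName helper itself is defined outside the hunks shown here; as a rough sketch only (the exact signature and separators are assumptions, not the actual helper), it presumably derives a unique, non-empty ILayer name from the originating NodeDef plus an optional sub-op tag and instance index, which is what the collision check added in ConvertGraphDefToEngine and test expectations such as "dummy_op-sub3" rely on.

    #include <string>
    #include "absl/strings/str_cat.h"
    #include "absl/strings/string_view.h"
    #include "absl/types/optional.h"
    #include "tensorflow/core/framework/node_def.pb.h"
    #include "third_party/tensorrt/NvInfer.h"

    // Hypothetical sketch of a SetLayerName-like helper.
    void SetLayerNameSketch(nvinfer1::ILayer* layer,
                            const tensorflow::NodeDef& node_def,
                            absl::string_view sub_op_name = "",
                            absl::optional<int> sub_op_instance = absl::nullopt) {
      std::string name = node_def.name();
      if (!sub_op_name.empty()) {
        absl::StrAppend(&name, "-", sub_op_name);       // e.g. "my_pool-to_NCHW"
      }
      if (sub_op_instance.has_value()) {
        absl::StrAppend(&name, "_", *sub_op_instance);  // e.g. "my_op-shuffle_1"
      }
      layer->setName(name.c_str());
    }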
Status TransposeTensor(nvinfer1::ITensor* input_tensor, const std::vector& order_with_batch_dim, - absl::string_view name, - nvinfer1::ITensor** output_tensor); + nvinfer1::ITensor** output_tensor, + const NodeDef& node_def, + absl::string_view sub_op_name = ""); - // Converts 'input' into 'tensor' with shape specified by 'dims' (which - // doesn't contain the batch dimension). + // Converts 'input' of 'node_def' into 'tensor' with shape specified by 'dims' + // (which doesn't contain the batch dimension). // // If validation_only is true, it doesn't do the conversion but only do some // minimum validation for the eligibility of the conversion, and *tensor will @@ -530,7 +535,9 @@ class Converter { Status PrepareTensorForShape(const TRT_TensorOrWeights& input, const nvinfer1::Dims& dims, const bool validation_only, - nvinfer1::ITensor** tensor); + nvinfer1::ITensor** tensor, + const NodeDef& node_def, + absl::optional op_instance = absl::nullopt); // Reshapes a dynamic shape tensor by removing or adding dimensions of size 1, // and/or permuting the dimensions. The new shape is derived from the shape of @@ -575,12 +582,14 @@ class Converter { Status DynamicReshape(nvinfer1::ITensor* input, std::vector> slices, OpConverterParams* params, nvinfer1::ITensor** output, - std::vector size_for_added_dims = {}); + std::vector size_for_added_dims = {}, + absl::optional op_instance = absl::nullopt); // Inserts a singleton dimension at axis for a dynamic shape tensor. Status DynamicExpandDims(nvinfer1::ITensor* input, const nvinfer1::Dims& dims, int axis, OpConverterParams* params, - nvinfer1::ITensor** output); + nvinfer1::ITensor** output, + absl::optional op_instance = absl::nullopt); // Helper function to add a squeeze op to the network. // @@ -667,6 +676,10 @@ class Converter { // acceptable by TRT. int batch_size_ = -1; + // Assign a ID to each constant layer we create, so that we can assign a + // unique name to the layer. + int next_constant_layer_id_ = 0; + friend class ConverterTest; friend class OpConverterTest; }; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index b127337e02a..86e6f0dd345 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include #include +#include #include #include @@ -203,6 +204,23 @@ void ExpectTrtDimsEqualsArray(const std::vector& lhs, << " actual: " << DebugString(rhs); } +void ExpectTrtLayerNames(absl::Span names, + nvinfer1::INetworkDefinition* network) { + EXPECT_EQ(network->getNbLayers(), names.size()); + + for (int i = 0; i < network->getNbLayers(); i++) { + auto layer = network->getLayer(i); + EXPECT_EQ(layer->getName(), names[i]); + } +} + +void VerifyTrtLayerNameNotEmpty(nvinfer1::INetworkDefinition* network) { + for (int i = 0; i < network->getNbLayers(); i++) { + auto layer = network->getLayer(i); + EXPECT_NE(layer->getName(), nullptr); + } +} + Matcher> ArrayFloatNear(const std::vector& values, float max_abs_error = 1e-5, bool nan_sensitive = false) { @@ -803,6 +821,8 @@ TEST_F(ConverterTest, ConvertNode) { TF_EXPECT_OK(GetTensorOrWeights("my_op:1", &actual_output_2)); EXPECT_EQ(&output_tensors[1], actual_output_2.tensor()); EXPECT_EQ(125, actual_output_2.tensor()->getDimensions().d[0]); + + VerifyTrtLayerNameNotEmpty(converter_->network()); } TEST_F(ConverterTest, AddAndGetInputs) { @@ -832,6 +852,8 @@ TEST_F(ConverterTest, AddAndGetInputs) { ExpectTrtDimsEqualsArray({1}, inputs[0].tensor()->getDimensions()); ExpectTrtDimsEqualsArray({2, 3}, inputs[2].tensor()->getDimensions()); ExpectTrtDimsEqualsArray({5, 3}, inputs[3].tensor()->getDimensions()); + + VerifyTrtLayerNameNotEmpty(converter_->network()); } TEST_F(ConverterTest, RenameAndMarkOutputTensors) { @@ -880,30 +902,33 @@ TEST_F(ConverterTest, RenameAndMarkOutputTensors) { } EXPECT_EQ("my_output", string(output_tensors[0]->getName())); EXPECT_EQ("my_output_1", string(output_tensors[1]->getName())); + + VerifyTrtLayerNameNotEmpty(converter_->network()); } TEST_F(ConverterTest, TransposeTensor) { nvinfer1::ITensor* input_tensor = converter_->network()->addInput( "", nvinfer1::DataType::kFLOAT, GetTestDims({2, 3, 5})); nvinfer1::ITensor* output_tensor = nullptr; - + NodeDef dummy_node_def = MakeNodeDef("dummy_op", "DummyOp", {}); // Rank doesn't match. ExpectStatus( - converter_->TransposeTensor(input_tensor, {0, 1}, "Bad perm", - &output_tensor), + converter_->TransposeTensor(input_tensor, {0, 1}, &output_tensor, + dummy_node_def, "sub1"), error::INVALID_ARGUMENT, "Rank of perm for transpose does not match with that of the input"); // Transpose at batch dimension. - ExpectStatus(converter_->TransposeTensor(input_tensor, {1, 0, 2, 3}, - "Batch perm", &output_tensor), - error::UNIMPLEMENTED, - "Transpose at batch dimension is not supported."); + ExpectStatus( + converter_->TransposeTensor(input_tensor, {1, 0, 2, 3}, &output_tensor, + dummy_node_def, "sub2"), + error::UNIMPLEMENTED, "Transpose at batch dimension is not supported."); // OK. 
- TF_EXPECT_OK(converter_->TransposeTensor(input_tensor, {0, 3, 1, 2}, "OK", - &output_tensor)); + TF_EXPECT_OK(converter_->TransposeTensor( + input_tensor, {0, 3, 1, 2}, &output_tensor, dummy_node_def, "sub3")); ExpectTrtDimsEqualsArray({5, 2, 3}, output_tensor->getDimensions()); + ExpectTrtLayerNames({"dummy_op-sub3"}, converter_->network()); } void TestPrepareTensorForShape( @@ -922,9 +947,11 @@ void TestPrepareTensorForShape( } nvinfer1::ITensor* output_tensor = nullptr; + NodeDef dummy_node_def = MakeNodeDef("dummy_op", "DummyOp", {}); for (bool validation_only : {false, true}) { const Status status = converter->PrepareTensorForShape( - input, GetTestDims(reshape_dims), validation_only, &output_tensor); + input, GetTestDims(reshape_dims), validation_only, &output_tensor, + dummy_node_def); if (expected_code == error::OK) { TF_EXPECT_OK(status); if (validation_only) { @@ -978,6 +1005,8 @@ TEST_F(ConverterTest, PrepareTensorForShape) { /*input_is_tensor=*/false, converter_.get(), weight_store_, error::INVALID_ARGUMENT, "Shape is not fully defined"); + + VerifyTrtLayerNameNotEmpty(converter_->network()); } TEST_F(ConverterTest, MaybeUpdateBatchSize) { @@ -1051,6 +1080,8 @@ TEST_F(ConverterTest, ProvideQuantizationRange) { // Symmetric range converter_->ProvideQuantizationRange(&fake_tensor, -6.123f, 6.123f); EXPECT_EQ(6.123f, quantization_ranges()[&fake_tensor]); + + VerifyTrtLayerNameNotEmpty(converter_->network()); } TEST_F(ConverterTest, MaybeApplyQuantizationRanges) { @@ -1077,6 +1108,8 @@ TEST_F(ConverterTest, MaybeApplyQuantizationRanges) { EXPECT_EQ(infer_3.getDynamicRange(), 5.0f); EXPECT_EQ(not_infer.getDynamicRange(), 100.0f); #endif + + VerifyTrtLayerNameNotEmpty(int8_converter->network()); } TEST_F(ConverterTest, PropagateQuantizationRanges) { @@ -1099,6 +1132,8 @@ TEST_F(ConverterTest, PropagateQuantizationRanges) { EXPECT_EQ(5.0f, ranges[&infer[i]]); } EXPECT_EQ(ranges.count(¬_infer), 0); + + VerifyTrtLayerNameNotEmpty(converter_->network()); } TEST_F(ConverterTest, GetTrtBroadcastShape) { @@ -1202,6 +1237,8 @@ TEST_F(ConverterTest, GetTrtBroadcastShape) { "(tensor #dims 4 vs broadcast #dims 5)"); symmetric_test({2, 3}, {7, 5}, kIsTensor, kIsTensor, {}, {}, error::INVALID_ARGUMENT, "Infeasible broadcast scheme"); + + VerifyTrtLayerNameNotEmpty(converter_->network()); } TEST_F(ConverterTest, CreateConstantLayer) { @@ -1216,6 +1253,8 @@ TEST_F(ConverterTest, CreateConstantLayer) { << DebugString(tensor->getType()); ExpectTrtDimsEqualsArray({3, 10}, tensor->getDimensions()); } + + VerifyTrtLayerNameNotEmpty(converter_->network()); } class ConvertGraphDefToEngineTest : public ::testing::Test { @@ -1575,6 +1614,9 @@ class OpConverterTest : public ::testing::Test { const char* expected_msg_substr = nullptr) { ExpectStatus(converter_->ConvertNode(node->def()), expected_code, expected_msg_substr); + if (expected_code == error::OK) { + VerifyTrtLayerNameNotEmpty(converter_->network()); + } } // Helper method to run both validation and conversion, when the expected @@ -6771,8 +6813,6 @@ void TestConvertResize(OpConverterTest* test) { typedef typename EnumToDataType::Type CType; std::vector> params { -// TODO(b/162442839): Enable the test parameters for TRT 7.1.3.x. 
-#if !IS_TRT_VERSION_GE(7, 1, 3, 0) { /*input_dims=*/{1, 2, 1}, // H, W, C /*output_resize_dims=*/{2, 3}, // H_out, W_out @@ -6784,7 +6824,6 @@ void TestConvertResize(OpConverterTest* test) { /*expected_bilinear_output_values=*/ CastTestVector({2.0f, 0.f, -1.0f, 2.0f, 0.f, -1.0f}), }, -#endif { /*input_dims=*/{1, 2, 1}, // H, W, C /*output_resize_dims=*/{2, 3}, // H_out, W_out @@ -6798,6 +6837,13 @@ void TestConvertResize(OpConverterTest* test) { } }; +// This use case is not supported as of TRT version 7.1 +#if IS_TRT_VERSION_GE(7, 1, 0, 0) + if (std::is_same::value) { + params.erase(params.begin()); + } +#endif + for (int i = 0; i < params.size(); ++i) { test->Reset(); // Create resize node. @@ -6840,7 +6886,7 @@ TEST_F(OpConverterTest, ConvertResize) { // First input is weight, should fail. Reset(); NodeDef node_def = - MakeResizeNodeDef("my_resize", DT_FLOAT, false); + MakeResizeNodeDef("my_resize", DT_FLOAT, true); AddTestWeights("input", {1, 2}, {1, 2}); AddTestWeights("size", {1, 2}, {1, 2}); RunValidationAndConversion( @@ -6852,7 +6898,7 @@ TEST_F(OpConverterTest, ConvertResize) { // output dimension is a tensor, should fail. Reset(); NodeDef node_def = - MakeResizeNodeDef("my_resize", DT_FLOAT, false); + MakeResizeNodeDef("my_resize", DT_FLOAT, true); AddTestTensor("input", {1, 2}); AddTestTensor("size", {1, 2}); RunValidationAndConversion( diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index 021e28ec6f0..35f966c816f 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -478,7 +478,7 @@ absl::Span GetInputsToDeterminateBatchSize( "Add", "AddV2", "Mul", - "Sub" + "Sub", "Div", "FloorDiv", "RealDiv", @@ -649,7 +649,7 @@ ClusterBatchSize GetClusterBatchSizeForNode( if (!graph_properties || !graph_properties->HasInputProperties(node->name())) { VLOG(3) << "doesn't have input property"; - return cluster_batch_size.SetBatchSizeValue(-1); + return cluster_batch_size.SetBatchSize(-1); } const std::vector& input_properties = @@ -660,7 +660,7 @@ ClusterBatchSize GetClusterBatchSizeForNode( const TensorShapeProto* leading_shape = optional_leading_shape.value(); DCHECK(!leading_shape->unknown_rank() && leading_shape->dim_size() >= 2); - return cluster_batch_size.SetBatchSizeValue(leading_shape->dim(0).size()); + return cluster_batch_size.SetBatchSize(leading_shape->dim(0).size()); } void AddSegmentForNode(const grappler::GraphProperties* graph_properties, @@ -668,12 +668,12 @@ void AddSegmentForNode(const grappler::GraphProperties* graph_properties, SimpleNode* node, const DeviceNameUtils::ParsedName& device_name, bool use_implicit_batch) { - segments->emplace_back( - node, + ClusterProperty property( GetClusterBatchSizeForNode(graph_properties, node == nullptr ? 
nullptr : node->tf_node(), use_implicit_batch), device_name); + segments->emplace_back(node, std::move(property)); } bool OpBatchSizeExceedMaximumBatchSize( @@ -681,10 +681,10 @@ bool OpBatchSizeExceedMaximumBatchSize( bool use_implicit_batch, absl::optional maximum_batch_size) { ClusterBatchSize cluster_batch_size = GetClusterBatchSizeForNode(graph_properties, node, use_implicit_batch); - if (cluster_batch_size.HasStaticBatchValue() && + if (cluster_batch_size.HasStaticBatchSize() && maximum_batch_size.has_value() && - cluster_batch_size.GetStaticBatchValue() > maximum_batch_size.value()) { - VLOG(2) << "OP batch size " << cluster_batch_size.GetStaticBatchValue() + cluster_batch_size.GetStaticBatchSize() > maximum_batch_size.value()) { + VLOG(2) << "OP batch size " << cluster_batch_size.GetStaticBatchSize() << " max_batch_size " << maximum_batch_size.value(); return true; } @@ -846,9 +846,9 @@ Status SegmentGraph(const Graph* tf_graph, // step until no output edges can be further contracted. This is because // contracting an output edge may unblock new edges for contracting. ClusterBatchSize expected_batch_size = - node_segments[node->id()].BatchSize(); + node_segments[node->id()].Property().BatchSize(); DeviceNameUtils::ParsedName expected_device_name = - node_segments[node->id()].DeviceName(); + node_segments[node->id()].Property().DeviceName(); VLOG(3) << "batch size " << expected_batch_size; while (true) { std::set contract_edges; @@ -869,7 +869,7 @@ Status SegmentGraph(const Graph* tf_graph, continue; } // Out node must have compatible batch size. - ClusterBatchSize out_batch_size = out_cluster->BatchSize(); + ClusterBatchSize out_batch_size = out_cluster->Property().BatchSize(); ClusterBatchSize merged_batch_size = expected_batch_size; if (!merged_batch_size.MergeIfCompatible(out_batch_size)) { VLOG(3) << "... ... incompatible batch sizes " @@ -879,7 +879,7 @@ Status SegmentGraph(const Graph* tf_graph, } const DeviceNameUtils::ParsedName& out_device_name = - out_cluster->DeviceName(); + out_cluster->Property().DeviceName(); absl::optional merged_device_name = MergeIfCompatible(expected_device_name, out_device_name); if (!merged_device_name.has_value()) { @@ -925,11 +925,13 @@ Status SegmentGraph(const Graph* tf_graph, graph->RemoveEdge(r); } } - if (expected_batch_size != node_segments[node->id()].BatchSize()) { + if (expected_batch_size != + node_segments[node->id()].Property().BatchSize()) { return errors::Internal( "expected batch size is not the same as the actual batch size"); } - if (expected_device_name != node_segments[node->id()].DeviceName()) { + if (expected_device_name != + node_segments[node->id()].Property().DeviceName()) { return errors::Internal( "expected device name is not the same as the actual device name"); } diff --git a/tensorflow/compiler/tf2tensorrt/segment/union_find.cc b/tensorflow/compiler/tf2tensorrt/segment/union_find.cc new file mode 100644 index 00000000000..289a2734183 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/segment/union_find.cc @@ -0,0 +1,136 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2tensorrt/segment/union_find.h" + +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/core/lib/core/errors.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { +namespace segment { + +namespace { +template +inline bool CheckIfCompatible(const absl::optional& a, + const absl::optional& b) { + if (!a.has_value() && !b.has_value()) { + return true; + } + if (a.has_value() && b.has_value()) { + return *a == *b; + } + return true; +} + +template +inline bool UnifyValues(absl::optional& a, absl::optional& b) { + if (a.has_value()) { + b = a; + } else { + a = b; + } + return true; +} + +template +inline absl::optional MergeCompatible(const absl::optional& a, + const absl::optional& b) { + DCHECK(CheckIfCompatible(a, b)); + return b; +} + +} // namespace + +ClusterBatchSize::ClusterBatchSize() + : has_dynamic_batch_size_(false), static_batch_size_(absl::nullopt) {} + +bool ClusterBatchSize::operator==(const ClusterBatchSize& other) { + return has_dynamic_batch_size_ == other.has_dynamic_batch_size_ && + static_batch_size_ == other.static_batch_size_; +} + +int ClusterBatchSize::GetStaticBatchSize() const { + DCHECK(HasStaticBatchSize()); + return static_batch_size_.value(); +} + +// Sets the batch size assuming that the object doesn't have a batch size yet: +// a non-negative input value representing a static batch size. +// a negative input value representing a dynamic batch size. +ClusterBatchSize& ClusterBatchSize::SetBatchSize(int batch_size) { + if (batch_size < 0) { + has_dynamic_batch_size_ = true; + return *this; + } + static_batch_size_ = MergeCompatible(static_batch_size_, batch_size); + return *this; +} + +bool ClusterBatchSize::MergeIfCompatible(const ClusterBatchSize& other) { + if (!CheckIfCompatible(static_batch_size_, other.static_batch_size_)) { + return false; + } + if (other.HasStaticBatchSize()) { + static_batch_size_ = other.GetStaticBatchSize(); + } + if (other.HasDynamicBatchSize()) { + has_dynamic_batch_size_ = true; + } + return true; +} + +string ClusterBatchSize::ToString() const { + string s; + absl::StrAppendFormat(&s, "batch_size=(%d,%d", HasDynamicBatchSize(), + HasStaticBatchSize()); + if (HasStaticBatchSize()) { + absl::StrAppendFormat(&s, ",%d", GetStaticBatchSize()); + } + absl::StrAppend(&s, ")"); + return s; +} + +ClusterProperty::ClusterProperty(const ClusterBatchSize& batch_size, + const DeviceNameUtils::ParsedName& device_name) + : batch_size_(batch_size), device_name_(device_name) {} + +Status ClusterProperty::Merge(const ClusterProperty& other) { + ClusterBatchSize merged_batch_size(batch_size_); + if (!merged_batch_size.MergeIfCompatible(other.batch_size_)) { + return errors::Internal( + "trying to merge clusters with incompatible batch sizes."); + } + + absl::optional merged_device_name = + MergeIfCompatible(device_name_, other.device_name_); + if (!merged_device_name.has_value()) { + return errors::Internal( + "trying to merge clusters with incompatible device assignment."); + } + + batch_size_ = std::move(merged_batch_size); + device_name_ = std::move(merged_device_name.value()); + return Status::OK(); +} + +} // namespace segment +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git 
a/tensorflow/compiler/tf2tensorrt/segment/union_find.h b/tensorflow/compiler/tf2tensorrt/segment/union_find.h index 54bbc251e4f..74a20aa4a24 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/union_find.h +++ b/tensorflow/compiler/tf2tensorrt/segment/union_find.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ #define TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ -#include "absl/strings/str_format.h" #include "absl/types/optional.h" +#include "tensorflow/core/util/device_name_utils.h" #if GOOGLE_CUDA && GOOGLE_TENSORRT @@ -31,108 +31,59 @@ namespace segment { // When constructing clusters for implicit batch mode, we support the // with both dynamic batch size and static batch size. We restrict nodes inside // a cluster to either have dynamic batch size or have the same value for static -// batch size. For this reason, we use a field has_dynamic_batch_value_ to keep +// batch size. For this reason, we use a field has_dynamic_batch_size_ to keep // track of whether the cluster has any node with dynamic batch size. We use -// field static_batch_value_ to keep track of whether the cluster has any node +// field static_batch_size_ to keep track of whether the cluster has any node // with static batch size and what the value of the static batch size, if any. // Examples: // cluster: a = a1[1,3] + a1[1,3] // ClusterBatchSize: has_dynamic_batch_size_ = false -// static_batch_value_ = {has value, 1} +// static_batch_size_ = {has value, 1} // // cluster: b = b1[-1,3] + b2[-1, 3] // ClusterBatchSize: has_dynamic_batch_size_ = true -// static_batch_value_ = {has no value} +// static_batch_size_ = {has no value} // // cluster: a = a1[1,3] + a1[1,3]; b = b1[-1,3] + b2[-1, 3] // ClusterBatchSize: has_dynamic_batch_size_ = true -// static_batch_value_ = {has value, 1} +// static_batch_size_ = {has value, 1} // // When constructing cluster for explicit batch mode, all ClusterBatchSize is // irrelevant. // // -absl::optional static_batch_value_; + class ClusterBatchSize { public: - ClusterBatchSize() - : has_dynamic_batch_value_(false), static_batch_value_(absl::nullopt) {} + ClusterBatchSize(); - bool operator==(const ClusterBatchSize& b) { - return HasDynamicBatchValue() == b.HasDynamicBatchValue() && - static_batch_value_ == b.static_batch_value_; - } + bool operator==(const ClusterBatchSize& other); + bool operator!=(const ClusterBatchSize& other) { return !(*this == other); } - bool operator!=(const ClusterBatchSize& b) { return !(*this == b); } - - int GetStaticBatchValue() const { - DCHECK(HasStaticBatchValue()); - return static_batch_value_.value(); - } - - // Sets the batch size value assuming that the object doesn't have a batch - // size value yet: - // a non-negative input value representing a known batch size. + // Sets the batch size assuming that the object doesn't have a batch size yet: + // a non-negative input value representing a static batch size. // a negative input value representing a dynamic batch size. 
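A short usage illustration (not from the patch, and assuming a TensorRT-enabled build) of the semantics documented in the comments above:

    #include "tensorflow/compiler/tf2tensorrt/segment/union_find.h"

    void ClusterBatchSizeExample() {
      using tensorflow::tensorrt::segment::ClusterBatchSize;
      ClusterBatchSize a;
      a.SetBatchSize(8);    // node with static batch size 8
      ClusterBatchSize b;
      b.SetBatchSize(-1);   // negative value marks a dynamic batch size

      // Compatible: 'a' now records a dynamic batch size *and* the static value 8.
      const bool merged = a.MergeIfCompatible(b);    // true

      ClusterBatchSize c;
      c.SetBatchSize(4);
      const bool conflict = a.MergeIfCompatible(c);  // false: static sizes 8 vs 4 differ
      (void)merged;
      (void)conflict;
    }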
- ClusterBatchSize SetBatchSizeValue(int value) { - if (value < 0) { - has_dynamic_batch_value_ = true; - return *this; - } - static_batch_value_ = value; - return *this; - } + ClusterBatchSize& SetBatchSize(int batch_size); + bool HasStaticBatchSize() const { return static_batch_size_.has_value(); } + int GetStaticBatchSize() const; - bool MergeIfCompatible(const ClusterBatchSize& b) { - bool is_compatible = MergeIfCompatible(b.static_batch_value_); - if (!is_compatible) return false; + bool MergeIfCompatible(const ClusterBatchSize& other); - if (!HasDynamicBatchValue() && b.HasDynamicBatchValue()) { - has_dynamic_batch_value_ = true; - } - - return true; - } - - // Returns a string for the batch size value. If the object has a static - // batch size value, return a string for the value. If the object has a - // dynamic size value, return -1. Otherwise, returns -2 to represent that - // a batch size hasn't been set yet. - string ToString() const { - string s; - absl::StrAppendFormat(&s, "batch_size=(%d,%d,", HasDynamicBatchValue(), - HasStaticBatchValue()); - if (HasStaticBatchValue()) { - absl::StrAppendFormat(&s, "%d", GetStaticBatchValue()); - } - absl::StrAppend(&s, ")"); - return s; - } - - bool HasStaticBatchValue() const { return static_batch_value_.has_value(); } + // Returns a string for the batch size. + // If the object has a static batch size, return a string for the value. + // If the object has a dynamic size, return -1. + // Otherwise, returns -2 to represent that the batch size hasn't been set + // yet. + std::string ToString() const; private: - bool HasDynamicBatchValue() const { return has_dynamic_batch_value_; } + bool HasDynamicBatchSize() const { return has_dynamic_batch_size_; } - private: - bool MergeIfCompatible(const absl::optional& b) { - bool is_compatible = !HasStaticBatchValue() || !b.has_value() || - GetStaticBatchValue() == b.value(); - if (!is_compatible) { - return false; - } - if (!HasStaticBatchValue() && b.has_value()) { - static_batch_value_ = b; - } - return true; - } - - private: // To track whether the cluster has any node with dynamic batch size. - bool has_dynamic_batch_value_; + bool has_dynamic_batch_size_; // To track whether the cluster has any node with static batch size, and the // unique value for static batch size. - absl::optional static_batch_value_; + absl::optional static_batch_size_; }; inline std::ostream& operator<<(std::ostream& os, @@ -140,89 +91,89 @@ inline std::ostream& operator<<(std::ostream& os, return os << batch_size.ToString(); } -// Represents a disjoint set of copyable values with type T. We use this data -// structure to construct clusters for TRTEngineOp. As such, this data structure -// has a field to record the batch size for the current cluster and merges the -// corresponding batch sizes when merging two clusters. Most of the methods in -// this class are side-effecting as they also compress the path from the object -// to the parent of its containing set. -template -class UnionFind { +// Represents the accumulated properties of a cluster during segmentation, +// including information about batch size and device assignment. Clusters shall +// have compatible properties in order to be merged together. 
+class ClusterProperty { public: - UnionFind() : size_(1), parent_(nullptr) {} - UnionFind(const T& v, ClusterBatchSize batch_size, - const DeviceNameUtils::ParsedName& device_name) - : size_(1), - cluster_batch_size_(batch_size), - cluster_device_name_(device_name), - parent_(nullptr), - value_(v) {} - - // Returns the number of elements in the cluster and compresses the path from - // this object to the root of the cluster. - int Size() { return FindRoot()->size_; } + ClusterProperty() {} + ClusterProperty(const ClusterBatchSize& batch_size, + const DeviceNameUtils::ParsedName& device_name); // Returns the batch size of the cluster and compresses the path from this // object to the root object. - ClusterBatchSize BatchSize() { return FindRoot()->cluster_batch_size_; } + const ClusterBatchSize& BatchSize() const { return batch_size_; } // Returns the device name of the cluster and compresses the path from this // object to the root object. - const DeviceNameUtils::ParsedName& DeviceName() { - return FindRoot()->cluster_device_name_; - } + const DeviceNameUtils::ParsedName& DeviceName() const { return device_name_; } - // Merges this cluster with 'other'. This cluster's size_ is updated to - // the size of the merged cluster; the size_ of 'other' becomes inaccessible - // as only the size_ of the root object is accessible. - Status Merge(UnionFind* other); - - // Retrieves the value for the root of the cluster. - T& ParentValue() { return FindRoot()->value_; } - - // Returns the value for the object. - T& Value() { return value_; } + Status Merge(const ClusterProperty& other); private: - // Returns the root object for the cluster and compresses the path from this + ClusterBatchSize batch_size_; + DeviceNameUtils::ParsedName device_name_; +}; + +// Represents a disjoint set of copyable value with type T and accumulated +// property of the values with type P. Most of the methods in this class are +// side-effecting as they also compress the path from the object to the parent +// of its containing set. +template +class UnionFind { + public: + UnionFind() : size_(1), parent_(nullptr) {} + UnionFind(const T& v, const P& p) + : size_(1), parent_(nullptr), value_(v), property_(p) {} + UnionFind(const T& v, P&& p) + : size_(1), parent_(nullptr), value_(v), property_(p) {} + + // Returns the number of elements in the set and compresses the path from + // this object to the root of the set. + int Size() { return FindRoot()->size_; } + + // Returns the accumulated property of all the elements in the set and + // compresses the path from this object to the root of the set. + const P& Property() { return FindRoot()->property_; } + + // Merges this set with 'other'. This updates the size_ and property_ of the + // set. The size_ and property_ of 'other' becomes inaccessible as only the + // size_ and property_ of the root of the set is accessible. + Status Merge(UnionFind* other); + + // Retrieves the value for the root of the set. + const T& ParentValue() { return FindRoot()->value_; } + + // Returns the value for the object. + const T& Value() const { return value_; } + + private: + // Returns the root object for the set and compresses the path from this // object to the root object. 
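A toy usage sketch (not from the patch; CountProperty is a made-up stand-in for ClusterProperty, and a TensorRT-enabled build is assumed) showing how Merge folds sizes and properties into the root of the set:

    #include <string>
    #include "tensorflow/compiler/tf2tensorrt/segment/union_find.h"
    #include "tensorflow/core/platform/status.h"

    // Hypothetical property type: any P with Status Merge(const P&) works.
    struct CountProperty {
      int count = 1;
      tensorflow::Status Merge(const CountProperty& other) {
        count += other.count;
        return tensorflow::Status::OK();
      }
    };

    void UnionFindExample() {
      using tensorflow::tensorrt::segment::UnionFind;
      UnionFind<std::string, CountProperty> a("a", CountProperty{});
      UnionFind<std::string, CountProperty> b("b", CountProperty{});
      const tensorflow::Status s = a.Merge(&b);  // Status::OK() on success
      // Both objects now resolve to the same root:
      //   a.Size() == 2, b.Size() == 2, a.Property().count == 2,
      //   a.ParentValue() == "a" (the root is the object Merge was called on).
      (void)s;
    }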
UnionFind* FindRoot(); int size_; - ClusterBatchSize cluster_batch_size_; - DeviceNameUtils::ParsedName cluster_device_name_; UnionFind* parent_; T value_; + P property_; }; -template -Status UnionFind::Merge(UnionFind* other) { +template +Status UnionFind::Merge(UnionFind* other) { UnionFind* a = FindRoot(); UnionFind* b = other->FindRoot(); if (a == b) return Status::OK(); - ClusterBatchSize batch_size = a->cluster_batch_size_; - if (!batch_size.MergeIfCompatible(other->cluster_batch_size_)) { - return errors::Internal( - "trying to merge clusters with incompatible batch sizes."); - } - - absl::optional device_name = - MergeIfCompatible(a->cluster_device_name_, other->cluster_device_name_); - if (!device_name.has_value()) { - return errors::Internal( - "trying to merge clusters with incompatible device assignment."); - } - - a->cluster_batch_size_ = batch_size; - a->cluster_device_name_ = *device_name; + P merged_property(a->property_); + TF_RETURN_IF_ERROR(merged_property.Merge(b->property_)); b->parent_ = a; a->size_ += b->size_; + a->property_ = std::move(merged_property); return Status::OK(); } -template -UnionFind* UnionFind::FindRoot() { +template +UnionFind* UnionFind::FindRoot() { if (!parent_) return this; // Path compression: update intermediate nodes to point to the root of the // equivalence class. diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 276444b3cb1..1cce3424fae 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -1,5 +1,5 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.bzl", "if_tpu", "tf_cc_binary", "tf_cc_test", "tf_copts", "tf_cuda_cc_test", "tf_openmp_copts") +load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_cc_binary", "tf_cc_test", "tf_copts", "tf_cuda_cc_test", "tf_openmp_copts") load( "//tensorflow/core/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", @@ -13,9 +13,6 @@ load("//tensorflow/compiler/xla:xla.bzl", "xla_py_proto_library") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "filegroup") - -# buildifier: disable=same-origin-load -load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load("//tensorflow/compiler/xla/service/cpu:build_defs.bzl", "runtime_copts") @@ -45,6 +42,7 @@ package_group( "//tensorflow/...", "//tensorflow_models/...", "//third_party/mlperf/submissions/training/v0_7/models/...", + "//third_party/py/keras/...", ], ) @@ -83,19 +81,6 @@ tf_proto_library( visibility = ["//visibility:public"], ) -# A proto library that is minimal in size and dependencies for platforms like Android. 
-tf_portable_proto_library( - name = "portable_tf2xla_proto", - config_string = "allow_all:true", - header_outs = ["//tensorflow/compiler/tf2xla/tf2xla.proto.h"], - portable_deps = ["//tensorflow/core:portable_proto_lib"], - proto_deps = [ - ":tf2xla_proto", - "//tensorflow/core:protos_all", - ], - visibility = ["//visibility:public"], -) - xla_py_proto_library( name = "tf2xla_py", has_services = False, @@ -313,7 +298,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/stream_executor:platform", - ] + if_tpu( + ] + if_libtpu( if_false = [ "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service/cpu:buffer_info_util", @@ -384,7 +369,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", - ] + if_tpu( + ] + if_libtpu( if_false = [ "//tensorflow/compiler/mlir:array_container_utils", "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", @@ -464,8 +449,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:session_options", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/common_runtime:core_cpu_internal", + "//tensorflow/core/platform:stream_executor_no_cuda", ], alwayslink = 1, ) @@ -892,13 +877,13 @@ cc_library( cc_library( name = "mlir_bridge_pass_registration", - srcs = if_tpu( + srcs = if_libtpu( if_false = [ "mlir_bridge_pass_registration.cc", ], if_true = [], ), - deps = if_tpu( + deps = if_libtpu( if_false = [ ":mlir_bridge_pass", "//tensorflow/compiler/mlir:mlir_graph_optimization_pass_registration", diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 7e8d3d7002a..b461aa43153 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -186,7 +186,7 @@ class StatelessCategoricalOp : public CategoricalOp { REGISTER_XLA_OP(Name("StatelessMultinomial") .CompileTimeConstantInput("num_samples") - .TypeConstraint("T", {DT_FLOAT, DT_BFLOAT16}) + .TypeConstraint("T", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}) .TypeConstraint("Tseed", DT_INT32), StatelessCategoricalOp); diff --git a/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc index 8b481d55a80..555905ebe6b 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace { @@ -38,16 +39,28 @@ class XlaReduceOp : public XlaOpKernel { context, dims_set.size() == dimensions_to_reduce_.size(), errors::InvalidArgument("Duplicate dimension in dimensions_to_reduce " "argument to XlaReduce")); + if (context->HasAttr("N")) { // variadic reduce + use_tuples_ = true; + OP_REQUIRES_OK(context, context->GetAttr("N", &n_)); + } else { + use_tuples_ = false; + n_ = 1; + } } void Compile(XlaOpKernelContext* context) override { - const TensorShape input_shape = context->InputShape("input"); - const TensorShape init_value_shape = context->InputShape("init_value"); + OP_REQUIRES(context, n_ * 2 == context->num_inputs(), + errors::InvalidArgument("Expected ", n_ * 2, " inputs but got ", + context->num_inputs())); + + const TensorShape input_shape = context->InputShape(0); + const TensorShape init_value_shape = context->InputShape(n_); const DataType dtype = context->input_type(0); const int rank = input_shape.dims(); OP_REQUIRES(context, TensorShapeUtils::IsScalar(init_value_shape), - errors::InvalidArgument("init_value must be a scalar")); + errors::InvalidArgument("init_value must be a scalar but got ", + init_value_shape.DebugString())); auto dim_in_range = [rank](int64 dim) { return dim >= 0 && dim < rank; }; OP_REQUIRES(context, @@ -67,35 +80,58 @@ class XlaReduceOp : public XlaOpKernel { compile_options.always_return_tuple = false; compile_options.is_entry_computation = false; XlaCompiler::CompilationResult reducer; - OP_REQUIRES_OK(context, context->compiler()->CompileFunction( - compile_options, *reducer_, - {reducer_arg, reducer_arg}, &reducer)); + OP_REQUIRES_OK( + context, + context->compiler()->CompileFunction( + compile_options, *reducer_, + std::vector(n_ * 2, reducer_arg), &reducer)); - xla::Shape scalar_shape; - OP_REQUIRES_OK(context, - TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape)); + xla::Shape expected_shape; + OP_REQUIRES_OK( + context, TensorShapeToXLAShape(dtype, TensorShape(), &expected_shape)); + if (use_tuples_) { + expected_shape = xla::ShapeUtil::MakeTupleShape( + std::vector(n_, expected_shape)); + } OP_REQUIRES( context, - xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape), + xla::ShapeUtil::Compatible(reducer.xla_output_shape, expected_shape), errors::InvalidArgument( "Invalid output shape of XlaReduce reducer. 
Expected ", - xla::ShapeUtil::HumanString(scalar_shape), " got ", + xla::ShapeUtil::HumanString(expected_shape), " got ", xla::ShapeUtil::HumanString(reducer.xla_output_shape))); + std::vector inputs; + std::vector inits; + inputs.reserve(n_); + inits.reserve(n_); + for (int i = 0; i < n_; i++) { + inputs.emplace_back(context->Input(i)); + inits.emplace_back(context->Input(n_ + i)); + } xla::XlaOp output = - xla::Reduce(context->Input("input"), context->Input("init_value"), - *reducer.computation, dimensions_to_reduce_); - context->SetOutput(0, output); + xla::Reduce(context->builder(), inputs, inits, *reducer.computation, + dimensions_to_reduce_); + if (use_tuples_) { + for (int i = 0; i < n_; i++) { + context->SetOutput(i, xla::GetTupleElement(output, i)); + } + } else { + context->SetOutput(0, output); + } } private: const NameAttrList* reducer_; std::vector dimensions_to_reduce_; + bool use_tuples_; + int n_; TF_DISALLOW_COPY_AND_ASSIGN(XlaReduceOp); }; REGISTER_XLA_OP(Name("XlaReduce"), XlaReduceOp); +REGISTER_XLA_OP(Name("XlaVariadicReduce"), XlaReduceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index eefef26dc24..b46429ef0d1 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -38,7 +38,7 @@ auto* mlir_bridge_gauge_v2 = monitoring::Gauge::New( // encapsulated graph to a particular device. Status MlirBridgePass::Run(const ConfigProto& config_proto, mlir::ModuleOp module) { - if (!config_proto.experimental().enable_mlir_bridge()) { + if (!IsEnabled(config_proto)) { VLOG(0) << "Skipping MLIR TPU Bridge, session flag not enabled"; mlir_bridge_gauge_v2->GetCell()->Set(false); return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h index f7541e634d4..09d3ef1f6d4 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h @@ -30,7 +30,8 @@ class MlirBridgePass : public MlirOptimizationPass { llvm::StringRef name() const override { return "bridge"; } bool IsEnabled(const ConfigProto& config_proto) const override { - return config_proto.experimental().enable_mlir_bridge(); + return config_proto.experimental().enable_mlir_bridge() || + tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge; } // This should be used as a thin mapper around mlir::ModulePass::runOnModule diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index 92a83436346..ac4d1f28803 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -161,7 +161,7 @@ Status ConvertGraphDefToXlaViaMlir( return ConvertMLIRToXlaComputation(*module, /*device_type=*/"XLA_CPU_JIT", computation, /*use_tuple_args=*/false, - /*always_return_tuple=*/true); + /*return_tuple=*/true); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 2f895b17219..00cbe7bc9bc 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -475,6 +475,60 @@ reducer: a reducer function to apply dimensions_to_reduce: dimension numbers over which to reduce )doc"); +REGISTER_OP("XlaVariadicReduce") + .Input("input: N * T") + .Input("init_value: N * T") + .Attr("N: int >= 1") + .Attr("T: numbertype") + .Attr("dimensions_to_reduce: list(int)") + .Attr("reducer: 
func") + .Output("output: N * T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + int n; + TF_RETURN_IF_ERROR(c->GetAttr("N", &n)); + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + c->MergeInput(i, c->input(j)); + } + } + if (c->RankKnown(c->input(0))) { + int rank = c->Rank(c->input(0)); + std::vector dimensions_to_reduce; + TF_RETURN_IF_ERROR( + c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce)); + std::set dims_set(dimensions_to_reduce.begin(), + dimensions_to_reduce.end()); + auto dim_in_range = [rank](int64 dim) { + return dim >= 0 && dim < rank; + }; + const int dimensions_to_reduce_size = dimensions_to_reduce.size(); + if (rank < dimensions_to_reduce_size || + dims_set.size() != dimensions_to_reduce.size() || + !absl::c_all_of(dimensions_to_reduce, dim_in_range)) { + return errors::InvalidArgument( + "Invalid dimensions_to_reduce argument to XlaVariadicReduce"); + } + for (int i = 0; i < n; i++) { + c->set_output( + i, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size())); + } + } else { + for (int i = 0; i < n; i++) { + c->set_output(i, c->input(i)); + } + } + return Status::OK(); + }) + .Doc(R"doc( +Wraps the variadic XLA Reduce operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce. + +input: the input tensor(s) +init_value: scalar initial value(s) for the reduction +reducer: a reducer function to apply +dimensions_to_reduce: dimension numbers over which to reduce +)doc"); + REGISTER_OP("XlaReduceWindow") .Input("input: T") .Input("init_value: T") @@ -762,7 +816,7 @@ REGISTER_OP("XlaScatter") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Output("output: T") - .SetShapeFn(UnchangedRank) + .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Wraps the XLA Scatter operator documented at https://www.tensorflow.org/xla/operation_semantics#scatter. diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 6278ea1a3a9..2e5667bc02f 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -340,6 +340,7 @@ def random_uniform(minval, maxval, dims, name=None): recv = gen_xla_ops.xla_recv reduce = gen_xla_ops.xla_reduce +variadic_reduce = gen_xla_ops.xla_variadic_reduce def reduce_window(operand, diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index e67e183aef9..fdd8484c249 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -56,7 +56,7 @@ limitations under the License. 
#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/util/dump_graph.h" -#ifndef LIBTFTPU +#ifndef LIBTPU_ON_GCE #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "tensorflow/compiler/mlir/utils/array_container_utils.h" #endif @@ -733,7 +733,7 @@ Status XlaCompiler::CompileFunction( } VLOG(1) << "===================================================="; -#ifdef LIBTFTPU +#ifdef LIBTPU_ON_GCE if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { VLOG(1) << "MLIR is not supported in this environment."; } @@ -743,9 +743,13 @@ Status XlaCompiler::CompileFunction( if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { VLOG(1) << "Using MLIR bridge"; GraphDebugInfo debug_info; + std::vector control_rets; + for (const auto* control_ret_node : fbody->control_ret_nodes) { + control_rets.push_back(control_ret_node->name()); + } TF_RETURN_IF_ERROR(CompileGraphToXlaHlo( std::move(*graph), mlir::SpanToArrayRef(args), - options_.device_type.type_string(), options.use_tuple_arg, + control_rets, options_.device_type.type_string(), options.use_tuple_arg, *options_.flib_def, debug_info, options_.shape_representation_fn, result)); } else { diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index cb202ce7600..831da22e033 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -277,6 +277,7 @@ tf_cc_test( ":types", ":util", "//tensorflow/core:test_main", + "//tensorflow/core/platform:bfloat16", ], ) @@ -320,7 +321,7 @@ cc_library( ":xla_data_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:regexp_internal", + "//tensorflow/core/platform:regexp", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", @@ -526,7 +527,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":types", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/strings", ], ) @@ -663,8 +664,8 @@ cc_library( ":statusor", ":types", "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", "//tensorflow/core:test", + "//tensorflow/core/platform:regexp", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 4172de1f6e0..409cf37762b 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -130,7 +130,7 @@ cc_library( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:source_map_util", "//tensorflow/compiler/xla/service:stream_pool", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", @@ -150,7 +150,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:compile_only_service", "//tensorflow/compiler/xla/service:compiler", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/memory", "@llvm-project//llvm:Support", ], @@ -174,7 +174,7 @@ cc_library( "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", 
"//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", @@ -257,6 +257,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 1389f548c5d..82a6128025f 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -267,9 +267,9 @@ StatusOr LocalExecutable::RunAsync( } static ShapedBuffer MaybeOwningShapeTreeToShapedBuffer( - Shape const& on_host_shape, const ShapeTree& tree, - se::Platform* platform, int device_ordinal) { - ShapedBuffer result(on_host_shape, tree.shape(), platform, device_ordinal); + const ShapeTree& tree, se::Platform* platform, + int device_ordinal) { + ShapedBuffer result(tree.shape(), platform, device_ordinal); auto it = tree.begin(); auto out_it = result.buffers().begin(); for (; it != tree.end(); ++it, ++out_it) { @@ -299,8 +299,8 @@ StatusOr LocalExecutable::RunAsync( shaped_buffer_ptrs.reserve(arguments.size()); for (size_t i = 0; i < arguments.size(); ++i) { shaped_buffers.push_back(MaybeOwningShapeTreeToShapedBuffer( - *argument_host_shapes[i], arguments[i].Buffers(), - backend_->platform(), stream->parent()->device_ordinal())); + arguments[i].Buffers(), backend_->platform(), + stream->parent()->device_ordinal())); shaped_buffer_ptrs.push_back(&shaped_buffers.back()); } diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index c7bbf9f8486..41212e69b2e 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -132,10 +132,10 @@ bool InstrIsSetBound(const HloInstructionProto* instr_proto) { namespace internal { -XlaOp XlaBuilderBuildFusion(XlaBuilder* builder, - absl::Span operands, - absl::string_view fusion_kind, - const XlaComputation& fused_computation) { +XlaOp XlaBuilderFriend::BuildFusion(XlaBuilder* builder, + absl::Span operands, + absl::string_view fusion_kind, + const XlaComputation& fused_computation) { return builder->ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; instr.set_fusion_kind(std::string(fusion_kind)); @@ -149,6 +149,21 @@ XlaOp XlaBuilderBuildFusion(XlaBuilder* builder, }); } +XlaOp XlaBuilderFriend::BuildBitcast(XlaBuilder* builder, XlaOp operand, + const Shape& shape) { + return builder->ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + return builder->AddInstruction(std::move(instr), HloOpcode::kBitcast, + {operand}); + }); +} + +HloInstructionProto* XlaBuilderFriend::GetInstruction(XlaOp op) { + return &op.builder() + ->instructions_[op.builder()->handle_to_index_[op.handle_]]; +} + } // namespace internal XlaOp operator-(XlaOp x) { return Neg(x); } @@ -676,8 +691,10 @@ XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, StatusOr XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, ComparisonDirection direction) { - return Compare(shape, lhs, rhs, direction, - Comparison::DefaultComparisonType(shape.element_type())); + TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(lhs)); + return Compare( + shape, lhs, 
rhs, direction, + Comparison::DefaultComparisonType(operand_shape.element_type())); } StatusOr XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 55bcd86b493..f736ae1d470 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -47,13 +47,21 @@ namespace xla { class XlaBuilder; class XlaOp; +class HloInstruction; namespace internal { -XlaOp XlaBuilderBuildFusion(XlaBuilder* builder, - absl::Span operands, - absl::string_view fusion_kind, - const XlaComputation& fused_computation); +struct XlaBuilderFriend { + static XlaOp BuildFusion(XlaBuilder* builder, + absl::Span operands, + absl::string_view fusion_kind, + const XlaComputation& fused_computation); + + static XlaOp BuildBitcast(XlaBuilder* builder, XlaOp operand, + const Shape& shape); + + static HloInstructionProto* GetInstruction(XlaOp op); +}; } // namespace internal @@ -107,6 +115,7 @@ class XlaOp { friend class XlaBuilder; friend class MlirHloBuilder; + friend struct internal::XlaBuilderFriend; // < 0 means "invalid handle". int64 handle_; @@ -1306,9 +1315,7 @@ class XlaBuilder { return LookUpInstructionByHandleInternal(op.handle()); } - friend XlaOp internal::XlaBuilderBuildFusion( - XlaBuilder* builder, absl::Span operands, - absl::string_view fusion_kind, const XlaComputation& fused_computation); + friend struct internal::XlaBuilderFriend; }; // RAII-style object: sets the current sharding assignment in builder on diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 7011c946203..bfd13c8ddf5 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -19,6 +19,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -1203,5 +1205,16 @@ TEST_F(XlaBuilderTest, AddFrontendAttribute) { TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); ExpectInstructionsAttributesMatch(*module, expected); } + +TEST_F(XlaBuilderTest, ComparisonType) { + XlaBuilder b(TestName()); + (void)Le(ConstantR0(&b, 1), ConstantR0(&b, 2)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::Compare(op::Constant(), op::Constant())); + EXPECT_EQ(Comparison::Type::kSigned, + DynCast(root)->type()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index 50fd6a12d66..43a3860e405 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -26,7 +26,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", ], @@ -109,7 +109,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", "//tensorflow/stream_executor:event", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", @@ -141,11 +141,12 @@ cc_library( "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:maybe_owning_device_memory", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", - "//tensorflow/core:allocator", "//tensorflow/core:lib", + "//tensorflow/core/framework:allocator", "//tensorflow/core/profiler/lib:connected_traceme", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:traceme_encode", @@ -167,6 +168,46 @@ cc_library( ], ) +cc_library( + name = "tpu_client", + srcs = ["tpu_client.cc"], + hdrs = ["tpu_client.h"], + visibility = [ + "//learning/brain/research/jax:__subpackages__", + "//learning/deepmind/tensorflow/tensorfn:__subpackages__", + "//learning/pathways:__subpackages__", + ], + deps = [ + ":local_device_state", + ":pjrt_client", + ":tracked_device_buffer", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/core:lib", + "//tensorflow/core/tpu:tpu_executor_dlsym_initializer", + "//tensorflow/core/tpu:tpu_on_demand_compiler", + "//tensorflow/stream_executor:device_memory", + "//tensorflow/stream_executor:stream", + "//tensorflow/stream_executor/lib", + "//tensorflow/stream_executor/tpu:tpu_computation_placer", + 
"//tensorflow/stream_executor/tpu:tpu_executable_interface", + "//tensorflow/stream_executor/tpu:tpu_executor", + "//tensorflow/stream_executor/tpu:tpu_executor_interface", + "//tensorflow/stream_executor/tpu:tpu_platform_interface", + "//tensorflow/stream_executor/tpu:tpu_topology_external", + "//tensorflow/stream_executor/tpu:tpu_transfer_manager", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + ], +) + cc_library( name = "interpreter_device", srcs = ["interpreter_device.cc"], diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.cc b/tensorflow/compiler/xla/pjrt/cpu_device.cc index e2543bda7df..c571ef2a4df 100644 --- a/tensorflow/compiler/xla/pjrt/cpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/cpu_device.cc @@ -28,7 +28,7 @@ CpuDevice::CpuDevice(int id, : PjRtDevice(id, std::move(local_device_state), kCpuPlatformName, /*device_kind=*/kCpuPlatformName) {} -StatusOr> GetCpuClient(bool asynchronous) { +StatusOr> GetCpuClient(bool asynchronous) { TF_ASSIGN_OR_RETURN(se::Platform * platform, PlatformUtil::GetPlatform("Host")); if (platform->VisibleDeviceCount() <= 0) { @@ -56,7 +56,7 @@ StatusOr> GetCpuClient(bool asynchronous) { devices.push_back(std::move(device)); } - return std::make_shared( + return std::make_unique( kCpuPlatformName, client, std::move(devices), /*host_id=*/0, /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr, /*should_stage_host_to_device_transfers=*/false, diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.h b/tensorflow/compiler/xla/pjrt/cpu_device.h index ad0079b1c4a..1036d8fedbb 100644 --- a/tensorflow/compiler/xla/pjrt/cpu_device.h +++ b/tensorflow/compiler/xla/pjrt/cpu_device.h @@ -28,7 +28,7 @@ class CpuDevice : public PjRtDevice { CpuDevice(int id, std::unique_ptr local_device_state); }; -StatusOr> GetCpuClient(bool asynchronous); +StatusOr> GetCpuClient(bool asynchronous); } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc index 298c41c7f58..c56b41861b0 100644 --- a/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc +++ b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc @@ -28,7 +28,7 @@ namespace { // computation wait for the inputs to be produced before executing. 
TEST(GpuMultiStream, Basics) { TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr client, + std::unique_ptr client, GetNvidiaGpuClient(/*asynchronous=*/true, GpuAllocatorConfig(), /*distributed_client=*/nullptr, /*node_id=*/0)); diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.cc b/tensorflow/compiler/xla/pjrt/interpreter_device.cc index c1149f2dbf9..376d8687892 100644 --- a/tensorflow/compiler/xla/pjrt/interpreter_device.cc +++ b/tensorflow/compiler/xla/pjrt/interpreter_device.cc @@ -28,7 +28,7 @@ InterpreterDevice::InterpreterDevice( : PjRtDevice(id, std::move(local_device_state), kInterpreterPlatformName, /*device_kind=*/kInterpreterPlatformName) {} -StatusOr> GetInterpreterClient() { +StatusOr> GetInterpreterClient() { TF_ASSIGN_OR_RETURN(se::Platform * platform, PlatformUtil::GetPlatform("Interpreter")); if (platform->VisibleDeviceCount() != 1) { @@ -50,7 +50,7 @@ StatusOr> GetInterpreterClient() { absl::make_unique(0, std::move(device_state)); devices.push_back(std::move(device)); - return std::make_shared( + return std::make_unique( kInterpreterPlatformName, client, std::move(devices), /*host_id=*/0, /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr, /*should_stage_host_to_device_transfers=*/false, diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.h b/tensorflow/compiler/xla/pjrt/interpreter_device.h index cf732f70124..4038d8dbf11 100644 --- a/tensorflow/compiler/xla/pjrt/interpreter_device.h +++ b/tensorflow/compiler/xla/pjrt/interpreter_device.h @@ -29,7 +29,7 @@ class InterpreterDevice : public PjRtDevice { std::unique_ptr local_device_state); }; -StatusOr> GetInterpreterClient(); +StatusOr> GetInterpreterClient(); } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc index 6e387f8738f..df92921c39d 100644 --- a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc @@ -301,7 +301,7 @@ GpuDevice::GpuDevice(int id, : PjRtDevice(id, std::move(local_device_state), kGpuPlatformName, std::move(device_kind), node_id) {} -StatusOr> GetNvidiaGpuClient( +StatusOr> GetNvidiaGpuClient( bool asynchronous, const GpuAllocatorConfig& allocator_config, std::shared_ptr distributed_client, int node_id) { TF_ASSIGN_OR_RETURN(LocalClient * xla_client, GetGpuXlaClient()); @@ -324,13 +324,12 @@ StatusOr> GetNvidiaGpuClient( devices = BuildLocalDevices(std::move(local_device_states)); } - std::shared_ptr pyclient = std::make_shared( + return std::unique_ptr(std::make_unique( "gpu", xla_client, std::move(devices), /*node_id=*/node_id, std::move(allocator), std::move(host_memory_allocator), /*should_stage_host_to_device_transfers=*/true, - /*gpu_run_options=*/std::move(gpu_run_options)); - return pyclient; + /*gpu_run_options=*/std::move(gpu_run_options))); } } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h index 4f22a169bd8..f480a37429a 100644 --- a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h +++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h @@ -53,7 +53,7 @@ struct GpuAllocatorConfig { // distributed_client may be nullptr in non-distributed settings. // distributed_client should not be Open()ed before calling this function. 
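Editor's aside: the PJRT client factories in this change now hand back a uniquely owned client instead of a shared one, leaving the ownership policy to the caller. A minimal sketch of why that signature is the more flexible one (Client and MakeClient are hypothetical placeholders, not the real factories):

#include <iostream>
#include <memory>
#include <string>
#include <utility>

// Stand-in for the real client type.
struct Client {
  explicit Client(std::string n) : name(std::move(n)) {}
  std::string name;
};

// A factory returning unique ownership lets each caller decide what to do
// with the object; the reverse (shared -> unique) is not possible.
std::unique_ptr<Client> MakeClient() { return std::make_unique<Client>("cpu"); }

int main() {
  // Caller keeps exclusive ownership...
  std::unique_ptr<Client> owned = MakeClient();
  // ...or promotes it to shared ownership when several components need it.
  std::shared_ptr<Client> shared = MakeClient();
  std::cout << owned->name << " " << shared->name << "\n";
  return 0;
}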
-StatusOr> GetNvidiaGpuClient( +StatusOr> GetNvidiaGpuClient( bool asynchronous, const GpuAllocatorConfig& allocator_config, std::shared_ptr distributed_client, int node_id); diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index f3dafd71ba5..02ae37b71db 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -90,6 +90,7 @@ limitations under the License. #include "tensorflow/compiler/xla/pjrt/local_device_state.h" #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -282,6 +283,11 @@ StatusOr> PjRtClient::GetParametersThatMustBeDonated( return parameters_to_donate; } +std::unique_ptr PjRtClient::GetHloCostAnalysis() { + return absl::make_unique( + client_->backend().compiler()->ShapeSizeBytesFunction()); +} + namespace { // Ensures that it is safe to deallocate any buffers that have been enqueued in @@ -1258,6 +1264,14 @@ StatusOr> PjRtBuffer::CopyToDevice( "CopyToDevice cannot accept the same source and destination devices"); } + // Copying across PjRtClients involves a copy through the host. + if (dst_device->client() != client_) { + TF_ASSIGN_OR_RETURN(std::shared_ptr literal, ToLiteral()); + return FromHostBuffer(literal->untyped_data(), literal->shape(), + HostBufferSemantics::kZeroCopy, nullptr, + dst_device->client(), dst_device); + } + TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device, dst_device->GetLocalDeviceState()); LocalDeviceState* transfer_local_device = diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index 39711534f79..cb4ef9da85b 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -195,6 +195,9 @@ class PjRtClient { return absl::optional(); } + // Returns a backend-specific HLO cost analysis visitor. + virtual std::unique_ptr GetHloCostAnalysis(); + protected: friend class PjRtBuffer; virtual void EnqueueCrossHostReceive( @@ -560,8 +563,10 @@ class PjRtBuffer { return GetBufferWithHold(ScopedHold::kExternalReference); } - // Copies the buffer to device `dst_device`. Returns an error if the buffer is - // already on dst_device. + // Copies the buffer to device `dst_device`, performing a d2d transfer when + // `dst_device` is sharing the same Client, and performing a d2h and h2d copy + // if `dst_device` lives on a different Client. + // Returns an error if the buffer is already on dst_device. StatusOr> CopyToDevice(PjRtDevice* dst_device); // Copies the buffer to the remote device encoded in serialized_descriptor. diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.cc b/tensorflow/compiler/xla/pjrt/tpu_client.cc new file mode 100644 index 00000000000..a8711631605 --- /dev/null +++ b/tensorflow/compiler/xla/pjrt/tpu_client.cc @@ -0,0 +1,246 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/pjrt/tpu_client.h" + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/lib/statusor.h" +#include "tensorflow/stream_executor/stream.h" +#include "tensorflow/stream_executor/tpu/tpu_computation_placer.h" +#include "tensorflow/stream_executor/tpu/tpu_executable_interface.h" +#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h" +#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h" +#include "tensorflow/stream_executor/tpu/tpu_stream.h" + +namespace tf_tpu = tensorflow::tpu; + +namespace xla { +namespace { + +class TpuDeviceState : public LocalDeviceState { + public: + TpuDeviceState(se::StreamExecutor* executor, LocalClient* client, + bool asynchronous); + + Status ThenMemcpyDeviceToDevice(se::Stream* transfer_stream, + se::Stream* dst_stream, + se::DeviceMemoryBase src_buffer, + se::DeviceMemoryBase dst_buffer) override; +}; + +TpuDeviceState::TpuDeviceState(se::StreamExecutor* executor, + LocalClient* client, bool asynchronous) + : LocalDeviceState(executor, client, LocalDeviceState::kAsynchronous, + asynchronous, + /*allow_event_reuse=*/false) {} + +Status TpuDeviceState::ThenMemcpyDeviceToDevice( + se::Stream* transfer_stream, se::Stream* dst_stream, + se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) { + auto* transfer_tpu_stream = tensorflow::down_cast( + transfer_stream->implementation()); + tf_tpu::TpuTopologyExternal topology = + tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology(); + // TODO(b/157179600): use device-to-device transfers when implemented instead + // of copying via host. 
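Editor's aside: when a direct device-to-device transfer is unavailable, the bytes are staged through a temporary host buffer, and the staging buffer may only be released after the second copy has completed. The toy sketch below shows just that data path; the vectors stand in for device memory, and the real code enqueues the two copies on separate streams with stream-to-stream waits and frees the host buffer from a host callback.

#include <cassert>
#include <cstring>
#include <memory>
#include <vector>

void CopyViaHost(const std::vector<char>& src_device,
                 std::vector<char>& dst_device) {
  assert(src_device.size() == dst_device.size());
  auto host_tmp = std::make_unique<char[]>(src_device.size());
  // Step 1: device-to-host transfer on the source side.
  std::memcpy(host_tmp.get(), src_device.data(), src_device.size());
  // Step 2: host-to-device transfer on the destination side; it must not
  // start before step 1 has finished.
  std::memcpy(dst_device.data(), host_tmp.get(), dst_device.size());
  // Step 3: only now is it safe to release the staging buffer (here, when
  // host_tmp goes out of scope).
}

int main() {
  std::vector<char> src = {'t', 'p', 'u'};
  std::vector<char> dst(src.size());
  CopyViaHost(src, dst);
  assert(dst == src);
  return 0;
}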
+ if (topology.version() == kTpuV4) { + LOG(WARNING) + << "device-to-device transfers not yet implemented, copying via host"; + auto* dst_tpu_stream = + tensorflow::down_cast(dst_stream->implementation()); + TF_RET_CHECK(src_buffer.size() == dst_buffer.size()); + auto host_tmp = std::make_unique(src_buffer.size()); + TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueTransferDeviceToHost( + src_buffer, host_tmp.get(), src_buffer.size())); + dst_stream->ThenWaitFor(transfer_stream); + TF_RETURN_IF_ERROR(dst_tpu_stream->EnqueueTransferHostToDevice( + dst_buffer, host_tmp.get(), dst_buffer.size())); + transfer_stream->ThenWaitFor(dst_stream); + char* tmp = host_tmp.release(); + dst_stream->ThenDoHostCallback([tmp] { delete[] tmp; }); + } else { + TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueOnTpuDeviceSendRecvLocal( + src_buffer, dst_buffer)); + } + return Status::OK(); +} + +class PjRtTpuClient : public PjRtClient { + public: + PjRtTpuClient(LocalClient* client, + std::vector> devices, int host_id, + tf_tpu::TpuPlatformInterface* tpu_platform); + + StatusOr GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const override; + + bool EnqueueD2DTransfersOnSrcStream() const override { + return tpu_platform_->topology().version() == kTpuV4; + } + + StatusOr> ExecutableFingerprint( + const PjRtExecutable& executable) const override; + + private: + tf_tpu::TpuPlatformInterface* tpu_platform_; +}; + +PjRtTpuClient::PjRtTpuClient(LocalClient* client, + std::vector> devices, + int host_id, + tf_tpu::TpuPlatformInterface* tpu_platform) + : PjRtClient("tpu", client, std::move(devices), host_id, + /*allocator=*/nullptr, + /*host_memory_allocator=*/nullptr, + /*should_stage_host_to_device_transfers=*/false, + /*gpu_run_options=*/nullptr), + tpu_platform_(tpu_platform) {} + +StatusOr PjRtTpuClient::GetDefaultDeviceAssignment( + int num_replicas, int num_partitions) const { + tf_tpu::TpuPlatformInterface* platform = + tf_tpu::TpuPlatformInterface::GetRegisteredPlatform(); + tf_tpu::TpuHostLocationExternal host = platform->GetTpuHostLocation(); + int num_local_devices = host.Cores(kTensorCore).size(); + if (num_replicas * num_partitions <= num_local_devices) { + return tf_tpu::TpuComputationPlacer::AssignLocalDevices(host, num_replicas, + num_partitions); + } + // Fallback to default global device assignment if we can't run locally. 
+ return PjRtClient::GetDefaultDeviceAssignment(num_replicas, num_partitions); +} + +StatusOr> PjRtTpuClient::ExecutableFingerprint( + const PjRtExecutable& executable) const { + if (executable.client() != this) { + return InvalidArgument( + "Passed executable from different client (platform '%s') to " + "PjRtTpuClient::ExecutableFingerprint", + executable.client()->platform_name()); + } + if (executable.executables().size() > 1) { + LOG(INFO) << "ExecutableFingerprint not fully implemented for MPMD " + "executables, fingerprint may not be unique."; + } + xla::TpuExecutableInterface* tpu_executable = + tensorflow::down_cast( + executable.executables()[0]->executable()); + return absl::optional(tpu_executable->fingerprint()); +} + +StatusOr>> GetTpuDevices( + LocalClient* client, + std::vector> local_device_states) { + std::vector> devices; + tf_tpu::TpuTopologyExternal topology = + tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology(); + + std::map core_id_to_device_ordinal; + for (int i = 0; i < client->device_count(); ++i) { + se::StreamExecutor* executor = + client->backend().stream_executor(i).ValueOrDie(); + tf_tpu::TpuExecutorInterface* tpu_executor = + tensorflow::down_cast( + executor->implementation()); + core_id_to_device_ordinal[tpu_executor->GetCoreLocationExternal().Id()] = i; + } + + for (const tf_tpu::TpuCoreLocationExternal& core : + topology.cores(TpuCoreTypeEnum::kTensorCore)) { + auto it = core_id_to_device_ordinal.find(core.Id()); + int device_ordinal = + (it != core_id_to_device_ordinal.end()) ? it->second : -1; + int host_id = topology.IdForHost(core.host_coordinates()); + const tf_tpu::TpuDimensionsExternal coords = core.chip_coordinates(); + std::array coords_array = {coords.x, coords.y, coords.z}; + std::unique_ptr local_device_state; + if (device_ordinal >= 0) { + local_device_state = std::move(local_device_states[device_ordinal]); + } + auto device = absl::make_unique( + core, std::move(local_device_state), host_id, coords_array, + std::string(tf_tpu::TpuVersionEnumToString(topology.version()))); + devices.push_back(std::move(device)); + } + return devices; +} + +} // namespace + +StatusOr> GetTpuClient( + bool asynchronous, absl::Duration init_retry_timeout) { + tf_tpu::TpuPlatformInterface* platform = + tf_tpu::TpuPlatformInterface::GetRegisteredPlatform(); + if (platform == nullptr) { + return InvalidArgument("TpuPlatform is not available."); + } + // NOTE: We retry in a loop since some pod failures are transient (e.g. some + // RPCs may timeout waiting for other hosts to come up, but will succeed + // at a later point if retried). + auto start = absl::Now(); + // TODO(b/165870356): TpuPlatform::Initialized() always returns true! 
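Editor's aside: GetTpuClient retries platform initialization until it succeeds or the caller-supplied init_retry_timeout elapses, because some pod-level failures are transient. A generic, self-contained version of that retry-until-deadline pattern (the names and the backoff sleep between attempts are illustrative additions, not part of the original code):

#include <chrono>
#include <functional>
#include <iostream>
#include <thread>

// Retries `try_init` until it reports success or `timeout` has elapsed.
// Returns true on success, false if the deadline was hit.
bool InitializeWithRetry(const std::function<bool()>& try_init,
                         std::chrono::milliseconds timeout,
                         std::chrono::milliseconds backoff) {
  const auto start = std::chrono::steady_clock::now();
  while (true) {
    if (try_init()) return true;
    if (std::chrono::steady_clock::now() - start >= timeout) return false;
    std::this_thread::sleep_for(backoff);  // Transient failure: wait, retry.
  }
}

int main() {
  int attempts = 0;
  // Pretend the platform only comes up on the third attempt.
  auto flaky_init = [&attempts] { return ++attempts >= 3; };
  bool ok = InitializeWithRetry(flaky_init, std::chrono::seconds(5),
                                std::chrono::milliseconds(10));
  std::cout << (ok ? "initialized" : "timed out") << " after " << attempts
            << " attempts\n";
  return 0;
}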
+ auto status = platform->Initialize({}); + while (!platform->Initialized()) { + status = platform->Initialize({}); + if (!status.ok()) { + LOG(ERROR) << "Platform initialization failed: " << status; + if ((absl::Now() - start) >= init_retry_timeout) { + return status; + } + } + } + if (platform->VisibleDeviceCount() <= 0) { + return InvalidArgument("No TPU devices found."); + } + LocalClientOptions options; + options.set_platform(platform); + TF_ASSIGN_OR_RETURN(LocalClient * client, + ClientLibrary::GetOrCreateLocalClient(options)); + + std::vector> local_device_states; + local_device_states.reserve(client->device_count()); + for (int i = 0; i < client->device_count(); ++i) { + se::StreamExecutor* executor = + client->backend().stream_executor(i).ValueOrDie(); + local_device_states.push_back( + absl::make_unique(executor, client, asynchronous)); + } + + TF_ASSIGN_OR_RETURN(auto devices, + GetTpuDevices(client, std::move(local_device_states))); + int host_id = platform->GetTpuHostLocation().Id(); + + return std::shared_ptr(absl::make_unique( + client, std::move(devices), host_id, platform)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/tpu_client.h b/tensorflow/compiler/xla/pjrt/tpu_client.h new file mode 100644 index 00000000000..1a458c1480b --- /dev/null +++ b/tensorflow/compiler/xla/pjrt/tpu_client.h @@ -0,0 +1,60 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PJRT_TPU_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_PJRT_TPU_CLIENT_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/stream_executor/tpu/tpu_topology.h" + +namespace xla { + +class PjRtTpuDevice : public PjRtDevice { + public: + PjRtTpuDevice(const tensorflow::tpu::TpuCoreLocationExternal core, + std::unique_ptr local_device_state, + int host_id, const std::array& coords, + std::string device_kind) + : PjRtDevice(core.Id(), std::move(local_device_state), + /*platform_name=*/"tpu", std::move(device_kind), host_id), + core_(core), + coords_(coords) {} + + const std::array& coords() const { return coords_; } + int core_on_chip() const { return core_.index(); } + const tensorflow::tpu::TpuCoreLocationExternal core() const { return core_; } + + std::string DebugString() const override { + return absl::StrFormat("TPU_%i(host=%i,(%i,%i,%i,%i))", id(), host_id(), + coords_[0], coords_[1], coords_[2], core_.index()); + } + + private: + const tensorflow::tpu::TpuCoreLocationExternal core_; + const std::array coords_; +}; + +StatusOr> GetTpuClient( + bool asynchronous, + absl::Duration init_retry_timeout = absl::ZeroDuration()); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PJRT_TPU_CLIENT_H_ diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc index 44c263367af..2d392d41f37 100644 --- a/tensorflow/compiler/xla/python/jax_jit.cc +++ b/tensorflow/compiler/xla/python/jax_jit.cc @@ -29,6 +29,7 @@ limitations under the License. #include #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" @@ -90,9 +91,7 @@ struct ArgSignature { template H AbslHashValue(H h, const ArgSignature& s) { h = H::combine(std::move(h), s.dtype); - if (!s.shape.empty()) { - h = H::combine_contiguous(std::move(h), &s.shape.front(), s.shape.size()); - } + h = H::combine_contiguous(std::move(h), s.shape.data(), s.shape.size()); return h; } @@ -123,10 +122,10 @@ struct CallSignature { std::vector static_args; // A PyTreeDef for each positional dynamic (i.e. not static) argument. std::vector dynamic_positional_args_treedef; - // Keyword arguments. Sorted by the interned keyword pointers. + // Keyword arguments. Sorted by the keyword name. std::vector keyword_args; // Shape and dtype for both the dynamic positional arguments and the keyword - // arguments (sorted by interned keyword pointers). + // arguments (sorted by keyword name). std::vector dynamic_args_signatures; PjRtDevice* device; @@ -177,11 +176,11 @@ H AbslHashValue(H h, const CallSignature& s) { // TODO(jblespiau): We should either ban non-hashable objects from jit or we // should hash them by object identity. 
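Editor's aside: the hashing change above swaps `&vec.front()` for `vec.data()`, which removes the explicit emptiness check: `front()` on an empty vector is undefined behavior, while `data()` with `size() == 0` is well defined. A minimal sketch of that `AbslHashValue` pattern, assuming Abseil is available; the Signature struct is a simplified stand-in for ArgSignature.

#include <iostream>
#include <vector>

#include "absl/hash/hash.h"

struct Signature {
  std::vector<int> shape;
  int dtype = 0;

  template <typename H>
  friend H AbslHashValue(H h, const Signature& s) {
    h = H::combine(std::move(h), s.dtype);
    // data()/size() is valid even for an empty vector, so rank-0 signatures
    // need no special case.
    return H::combine_contiguous(std::move(h), s.shape.data(), s.shape.size());
  }
};

int main() {
  Signature scalar{/*shape=*/{}, /*dtype=*/1};
  Signature matrix{/*shape=*/{2, 3}, /*dtype=*/1};
  std::cout << absl::Hash<Signature>{}(scalar) << " "
            << absl::Hash<Signature>{}(matrix) << "\n";
  return 0;
}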
h = H::combine_contiguous(std::move(h), - &s.dynamic_positional_args_treedef.front(), + s.dynamic_positional_args_treedef.data(), s.dynamic_positional_args_treedef.size()); - h = H::combine_contiguous(std::move(h), &s.keyword_args.front(), + h = H::combine_contiguous(std::move(h), s.keyword_args.data(), s.keyword_args.size()); - h = H::combine_contiguous(std::move(h), &s.dynamic_args_signatures.front(), + h = H::combine_contiguous(std::move(h), s.dynamic_args_signatures.data(), s.dynamic_args_signatures.size()); h = H::combine(std::move(h), s.device); return h; @@ -191,7 +190,7 @@ std::string CallSignature::DebugString() const { std::vector static_args_str; static_args_str.reserve(static_args.size()); for (auto& static_arg : static_args) { - static_args_str.emplace_back(py::cast(static_arg.str())); + static_args_str.emplace_back(py::cast(py::str(static_arg))); } std::vector signature_str; @@ -225,14 +224,14 @@ std::string CallSignature::DebugString() const { struct CacheEntry { std::shared_ptr executable; - xla::PjRtDevice* device; PyTreeDef out_pytree_def; - // These are the objects required to create a `DeviceArray` object. - // We use Python types within the vector because this is what we will be - // returning to Python. No need to convert back and forth. - // We need py::object to maintain the objects alive. - std::vector out_avals; - std::vector out_lazy_exprs; + // Callables (one for each output) to call on each output to get the Python + // object (usually a DeviceArray) that we should return. + // TODO(jblespiau): The goal of the C++ codepath being to be fast, thus, we + // should not call into Python. It will be trivial to fix this when + // omnistaging is the only option & when DeviceArray and PyBuffer are merged). + std::vector handlers; + // Ensures a single thread performs the compilation for a given executable. // // The first thread (holding the GIL) will create the CacheEntry associated to @@ -240,6 +239,9 @@ struct CacheEntry { // will wait for the notification. absl::Notification compilation_complete; absl::optional compilation_error = absl::nullopt; + // Trivial computation will fallback to Python. + // Running a jax(pmap) will also fallback to Python. + bool fall_back_to_python = false; }; // A `CompiledFunction` is associated to a `jax.jit(f)` and takes care of the @@ -247,9 +249,10 @@ struct CacheEntry { // the correct underlying `PyExecutable`. This class is thread-safe. class CompiledFunction { public: - CompiledFunction(py::function fun, py::function cache_miss_fun, - py::function python_f_jitted, bool jax_enable_x64, - bool jax_disable_jit, std::vector static_argnums); + CompiledFunction(py::function fun, py::function cache_miss, + py::function get_device, py::function get_jax_enable_x64, + py::function get_jax_disable_jit, + std::vector static_argnums); ~CompiledFunction(); // This function will: @@ -267,27 +270,19 @@ class CompiledFunction { } private: - CacheEntry& GetCacheEntry(const py::args& args, const py::kwargs& kwargs, + // Returns nullptr if not present in the cache. + CacheEntry* GetCacheEntryIfPresent(const CallSignature& signature); + // Should never return nullptr. 
+ CacheEntry* AddCacheEntry(const py::args& args, const py::kwargs& kwargs, const CallSignature& signature, - absl::optional cache_miss_return); - CacheEntry& SetAndReturnCacheEntry( - const py::args& args, const py::kwargs& kwargs, - const CallSignature& signature, - absl::optional cache_miss_return = absl::nullopt); - bool JitIsDisabled() { return GetDisableJit() || jax_disable_jit_; } + py::object out_and_fastpath_data); + bool JitIsDisabled() { return GetDisableJit() || jax_disable_jit_.value(); } + + bool always_fallback_to_python_ = false; const py::function fun_; // The Python function to jit. - // The Python function in charge of returning a `xla::PyExecutable` from - // the arguments passed to `jitted_f`. - const py::function cache_miss_fun_; - // A function to call as fallback. This is the result of calling the Python - // `jax.jit`. - // TODO(jblespiau): Delete this when the C++ codepath supports all features. - const py::function python_f_jitted_; - - // The value of the Python flag when the object was created. - const bool jax_enable_x64_; - const bool jax_disable_jit_; + // See JAX _cpp_jit in api.py for documentation. + const py::function cache_miss_; // We need to know the static arguments to remove them from the arguments // passed to the underlying PyExecutable. In sorted order. @@ -299,35 +294,43 @@ class CompiledFunction { // `CompiledFunction` is being instantiated from Python, the clients are not // yet available (done after GoogleInit). They will be during the first call // to `Call`. - std::shared_ptr pyclient_ = nullptr; - xla::PjRtDevice* default_device_ = nullptr; + // A function taking no arguments and returning the default device and whether + // jax.jit has been committed to it. + const py::function get_jax_enable_x64_; + const py::function get_jax_disable_jit_; + const py::function get_device_; - // IMPORTANT: The GIL is not always held, because we call back to Python and - // Python will release the GIL. - // Thus, we protect the critical section modifying the `executables_` map - // and more generally the compilation with some `absl::Notification`. - // The first thread reaching such point will be responsible to create the - // notification for the executable and others will wait until notified. - // It's safe because the first thread will be holding the GIL while - // initializing the `Notification`. - // - // absl::optional is not supported - bool first_compilation_started_ = false; - absl::Notification first_compilation_complete_; - absl::optional first_compilation_error_ = absl::nullopt; + // The writing of the following is protected by the mutex. + absl::Mutex mu_; + // The value of the Python flag. The value will be computed only during the + // first object call, because GoogleInit must have been executed. + absl::optional jax_enable_x64_ = absl::nullopt; + absl::optional jax_disable_jit_ = absl::nullopt; + + // The logic if the following: + // - if `device` or `backend` are not specified to `jax.jit`, we will use + // the input sticky buffer device, or `default_device_` if there is no + // such sticky buffer. + // - When one of `device` or `backend` is specified, this will determine + // the `default_device_` which will be used as the targeted device. In + // which case, we will always copy input buffers to this device. 
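Editor's aside: argument placement therefore follows a two-step rule: if the jitted function is committed to a device (because `device=` or `backend=` was passed), every input is copied there; otherwise the call follows the sticky input buffers, and sticky buffers on two different devices are an error. A compact sketch of that decision with simplified stand-in types (Arg and the device strings are hypothetical):

#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

struct Arg {
  // Device the argument already lives on; empty when the value is not sticky
  // (e.g. a plain host array) and may be placed anywhere.
  std::optional<std::string> sticky_device;
};

// Committed functions pin everything to the default device; uncommitted ones
// follow the single sticky input device, if any.
std::string ChooseDevice(bool is_committed, const std::string& default_device,
                         const std::vector<Arg>& args) {
  if (is_committed) return default_device;
  std::optional<std::string> data_device;
  for (const Arg& arg : args) {
    if (!arg.sticky_device) continue;
    if (data_device && *data_device != *arg.sticky_device) {
      throw std::invalid_argument(
          "arguments must be colocated on the same device, got " +
          *data_device + " and " + *arg.sticky_device);
    }
    data_device = arg.sticky_device;
  }
  return data_device.value_or(default_device);
}

int main() {
  std::vector<Arg> args = {{std::nullopt}, {std::string("gpu:1")}};
  // An uncommitted call follows the sticky buffer on gpu:1.
  std::string d = ChooseDevice(/*is_committed=*/false, "cpu:0", args);
  return d == "gpu:1" ? 0 : 1;
}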
+ std::shared_ptr default_pyclient_ = nullptr; + xla::ClientAndPtr default_pydevice_; + xla::PjRtDevice* default_device_ = nullptr; + bool is_committed_; }; -CompiledFunction::CompiledFunction(py::function fun, - py::function cache_miss_fun, - py::function python_f_jitted, - bool jax_enable_x64, bool jax_disable_jit, +CompiledFunction::CompiledFunction(py::function fun, py::function cache_miss, + py::function get_device, + py::function get_jax_enable_x64, + py::function get_jax_disable_jit, std::vector static_argnums) : fun_(std::move(fun)), - cache_miss_fun_(std::move(cache_miss_fun)), - python_f_jitted_(std::move(python_f_jitted)), - jax_enable_x64_(jax_enable_x64), - jax_disable_jit_(jax_disable_jit), - static_argnums_(std::move(static_argnums)) { + cache_miss_(std::move(cache_miss)), + static_argnums_(std::move(static_argnums)), + get_jax_enable_x64_(get_jax_enable_x64), + get_jax_disable_jit_(get_jax_disable_jit), + get_device_(std::move(get_device)) { std::sort(static_argnums_.begin(), static_argnums_.end()); } @@ -339,6 +342,34 @@ CompiledFunction::~CompiledFunction() { namespace { +// The equivalent of the Python jax/lazy.py::is_trivial: +// return (type(lexpr.input) is ArrayVar and +// lexpr.dims == tuple(range(len(lexpr.shape)))) +// +// Expects *only* instances of `DeviceArray`. +bool HasTrivialLazyExpr(py::handle device_array) { + static const auto* lazy_module = + new py::module(py::module::import("jax.lazy")); + + auto lexpr = py::getattr(device_array, "_lazy_expr"); + auto input = py::getattr(lexpr, "input"); + if (!input.get_type().is(lazy_module->attr("ArrayVar"))) { + return false; + } + py::tuple dims = py::cast(lexpr.attr("dims")); + py::tuple shape = py::cast(lexpr.attr("shape")); + + for (int i = 0; i < shape.size(); ++i) { + if (dims[i].is_none()) { + return false; + } + if (py::cast(dims[i]) != i) { + return false; + } + } + return true; +} + // The resulting information of the parsing and conversion of the arguments. struct ParsedArgumentsAsBuffers { // The call signature will be filled during 2 steps: @@ -391,8 +422,10 @@ void FlattenArguments(const py::args& args, const py::kwargs& py_kwargs, // Keyword arguments. std::vector> kwargs(py_kwargs.begin(), py_kwargs.end()); - // We first intern the keys, then sort them (by pointer) and then create - // the signatures. + // We first intern the keys, then sort them (by name, as in the Python path) + // (see also PyTreeDef::Flatten) and then create the signatures. + // TODO(jblespiau): We should be able to sort the keys by interned-key + // pointers, but this requires the Python compilation to do the same. arguments.signature.keyword_args.resize(kwargs.size()); for (size_t i = 0; i < kwargs.size(); ++i) { // Intern the key if not already interned. @@ -409,7 +442,7 @@ void FlattenArguments(const py::args& args, const py::kwargs& py_kwargs, std::sort(kwargs.begin(), kwargs.end(), [](const std::pair& a, const std::pair& b) { - return a.first.ptr() < b.first.ptr(); + return a.first < b.first; }); for (size_t i = 0; i < kwargs.size(); ++i) { arguments.signature.keyword_args[i].key = kwargs[i].first; @@ -478,7 +511,7 @@ StatusOr> ScalarToBuffer( "%s", absl::StrCat( "Not supported: The C++ jax jit execution path, only accepts " "DeviceArray, Numpy arrays, or Python scalars. 
Got type ", - py::cast(scalar.get_type().str()))); + py::cast(py::str(scalar.get_type())))); } const py::dtype* DtypeTo32BitDtype(const py::dtype& dtype) { @@ -491,28 +524,37 @@ const py::dtype* DtypeTo32BitDtype(const py::dtype& dtype) { static const auto* complex64_dt = new py::dtype("complex64"); static const auto* complex128_dt = new py::dtype("complex128"); - if (dtype == *int64_dt) { + if (dtype.equal(*int64_dt)) { return int32_dt; } - if (dtype == *float64_dt) { + if (dtype.equal(*float64_dt)) { return float32_dt; } - if (dtype == *uint64_dt) { + if (dtype.equal(*uint64_dt)) { return uint32_dt; } - if (dtype == *complex128_dt) { + if (dtype.equal(*complex128_dt)) { return complex64_dt; } return nullptr; } +bool IsFloat0(py::array arg) { + static const auto* dtypes_module = + new py::module(py::module::import("jax.dtypes")); + static const auto* float0_dtype = + new py::handle(dtypes_module->attr("float0")); + return float0_dtype->is(arg.attr("dtype")); +} + // Converts flattened arguments contained in ParsedArgumentsAsBuffers in // place. If arguments are `DeviceArray`, they must all be on the same `Device`. // -// Returns `OkStatus()` on success. +// Returns `OkStatus()` on success. Returning an error should lead to calling +// the Python fallback. Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, - xla::PjRtDevice* default_device, + xla::PjRtDevice* default_device, bool is_committed, ParsedArgumentsAsBuffers& arguments) { std::vector& arg_buffers = arguments.arg_buffers; auto& keep_alive = arguments.keep_alive; @@ -526,44 +568,49 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, const auto& device_array = xla_module->attr("DeviceArray"); static const auto* numpy_module = new py::module(py::module::import("numpy")); - const auto& array = numpy_module->attr("array"); + const auto& np_array = numpy_module->attr("array"); - // TODO(phawkins): consider device stickiness. - // We first check whether any `DeviceArray` is present and whether they are - // attached to any specific device. See also + // When the jitted function is not committed, we first check whether any + // sticky `DeviceArray` is present and on which device they live. See also: // https://github.com/google/jax/pull/1884 // https://github.com/google/jax/pull/1916 for the rationale why the // computation follows the data locality. // It's also similar to PyTorch's behavior. xla::PjRtDevice* data_device = nullptr; - for (py::handle arg : arguments.flat_dynamic_args) { - if (py::isinstance(arg, device_array)) { - xla::PyBuffer* buffer; - try { - // This can fail, e.g. when device_buffer is a `DeviceConstant`. - buffer = py::cast(arg.attr("device_buffer")); - } catch (const py::cast_error& e) { - return InvalidArgument( - "%s", - absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: " - "`device_buffer` field is of type ", - py::cast( - arg.attr("device_buffer").get_type().str()), - " while a `PyBuffer` was expected." + if (is_committed) { + data_device = default_device; + } else { + for (py::handle arg : arguments.flat_dynamic_args) { + // We specically only deal with DeviceArray (not ShardedDeviceArray). + // (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored"). + if (arg.get_type().is(device_array)) { + xla::PyBuffer* buffer; + if (arg.attr("_device").is_none()) { // Skip non-sticky devices. + continue; + } + try { + // This can fail, e.g. when device_buffer is a `DeviceConstant`. 
+ buffer = py::cast(arg.attr("device_buffer")); + } catch (const py::cast_error& e) { + return InvalidArgument( + "%s", + absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: " + "`device_buffer` field is of type ", + py::cast( + arg.attr("device_buffer").get_type().str()), + " while a `PyBuffer` was expected." - )); - } - xla::PjRtDevice* device = buffer->buffer()->device(); - if (data_device && (device != data_device)) { - return InvalidArgument( - "%s", - absl::StrCat( - "Arguments to a jit-compiled function must be colocated on the " - "same device. Arguments were found to be on the two following " - "different devices: ", - device->DebugString(), " and ", data_device->DebugString())); - } else { - data_device = device; + )); + } + xla::PjRtDevice* device = buffer->buffer()->device(); + if (data_device && (device != data_device)) { + throw std::invalid_argument(absl::StrCat( + "primitive arguments must be colocated on the same device (" + "C++ jax.jit). Arguments are on devices: ", + device->DebugString(), " and ", data_device->DebugString())); + } else { + data_device = device; + } } } } @@ -576,13 +623,26 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, xla::PjRtClient* pjrt_client = data_device->client(); for (py::handle arg : arguments.flat_dynamic_args) { - // We do not support here d2d transparent transfers. - // We assumes all the `DeviceArray` are already on the correct and shared - // device. - if (py::isinstance(arg, device_array)) { - xla::PyBuffer* buffer = - py::cast(arg.attr("device_buffer")); - arg_buffers.push_back(buffer->buffer()); + if (arg.get_type().is(device_array)) { + if (!HasTrivialLazyExpr(arg)) { + return InvalidArgument( + "Non-trivial lazy expression not supported in C++. " + "Falling back to Python."); + } + + PyBuffer* buffer = py::cast(arg.attr("device_buffer")); + if (buffer->device().contents == data_device) { + arg_buffers.push_back(buffer->buffer()); + } else { + // source and target platforms are the same, but different device. + // Perform a device-to-device copy. + // buffers from different XLA backends are passed through the host. + std::unique_ptr copied_buffer = + ValueOrThrow(buffer->buffer()->CopyToDevice(data_device)); + arg_buffers.push_back(copied_buffer.get()); + keep_alive.emplace_back(std::move(copied_buffer)); + } + ArgSignature sig; sig.dtype = buffer->shape().element_type(); sig.shape.assign(buffer->shape().dimensions().begin(), @@ -593,12 +653,17 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, // TODO(jblespiau): Can we improve this call? Do we need the underlying // GlobalPyRefManager() and co? py::array numpy_array = py::cast(arg); + if (IsFloat0(numpy_array)) { + return InvalidArgument( + "float0 numpy arrays not supported in C++. " + "It will fallback to Python."); + } // If jax_enable_x64 is not set, we need to coerce 32 bits types. // Note that this is calling back to Python! 
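The coercion applied next calls back into numpy through pybind11. A hedged sketch of that round trip, limited to the int64/float64 cases and using only pybind11 calls that already appear in this patch (py::dtype, .equal, py::module::import); the helper name is illustrative and the GIL is assumed to be held.

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Sketch only: coerce 64-bit arrays to their 32-bit counterparts when
// jax_enable_x64 is off. Calls back into numpy.array, so the GIL must be held.
py::object MaybeCoerceTo32Bit(py::array arr, bool jax_enable_x64) {
  if (jax_enable_x64) return arr;
  static const auto* np = new py::module(py::module::import("numpy"));
  static const auto* f64 = new py::dtype("float64");
  static const auto* i64 = new py::dtype("int64");
  const char* target = nullptr;
  if (arr.dtype().equal(*f64)) target = "float32";
  if (arr.dtype().equal(*i64)) target = "int32";
  if (target == nullptr) return arr;  // uint64/complex128 omitted for brevity
  return np->attr("array")(arr, py::dtype(target));
}
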
if (!jax_enable_x64) { const py::dtype* to_dtype = DtypeTo32BitDtype(numpy_array.dtype()); if (to_dtype) { - numpy_array = array(numpy_array, to_dtype); + numpy_array = np_array(numpy_array, *to_dtype); } } std::unique_ptr buffer = @@ -636,164 +701,145 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, } // namespace -CacheEntry& CompiledFunction::GetCacheEntry( - const py::args& args, const py::kwargs& kwargs, - const CallSignature& signature, - absl::optional cache_miss_return) { +CacheEntry* CompiledFunction::GetCacheEntryIfPresent( + const CallSignature& signature) { auto found_iterator = executables_.find(signature); if (found_iterator != executables_.end()) { // Cache hit! if (!found_iterator->second->compilation_complete.HasBeenNotified()) { py::gil_scoped_release gil_release; found_iterator->second->compilation_complete.WaitForNotification(); - if (found_iterator->second->compilation_error) { - throw std::invalid_argument( - found_iterator->second->compilation_error.value().error_message()); - } } - return *(found_iterator->second); + if (found_iterator->second->compilation_error) { + throw std::invalid_argument( + found_iterator->second->compilation_error.value().error_message()); + } + return found_iterator->second.get(); } - return SetAndReturnCacheEntry(args, kwargs, signature, cache_miss_return); + return nullptr; } -CacheEntry& CompiledFunction::SetAndReturnCacheEntry( - const py::args& args, const py::kwargs& kwargs, - const CallSignature& signature, - absl::optional cache_miss_return) { + +CacheEntry* CompiledFunction::AddCacheEntry(const py::args& args, + const py::kwargs& kwargs, + const CallSignature& signature, + py::object out_and_fastpath_data) { // We need to insert the element. auto result = executables_.emplace(signature, std::make_unique()); auto it = result.first; - CacheEntry& cache_entry = *(it->second.get()); + CacheEntry* cache_entry = it->second.get(); // CallSignatures in the cache own their keyword argument reference. result.first->first.IncRef(); - // Cache miss? Call the Python cache miss function. 
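The cache probe in GetCacheEntryIfPresent waits on an absl::Notification and re-raises any stored compilation error. Stripped of the jax specifics, the wait-or-rethrow pattern looks roughly like this; Entry and the string error are simplifications of CacheEntry and xla::Status.

#include <stdexcept>
#include <string>
#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"

struct Entry {
  absl::Notification compilation_complete;
  absl::optional<std::string> compilation_error;  // stands in for xla::Status
};

// Returns the entry once it is usable, or throws if compilation failed.
Entry* WaitForEntry(Entry* entry) {
  if (!entry->compilation_complete.HasBeenNotified()) {
    // In the real code the GIL is released before blocking here.
    entry->compilation_complete.WaitForNotification();
  }
  if (entry->compilation_error) {
    throw std::invalid_argument(*entry->compilation_error);
  }
  return entry;
}
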
- py::tuple executable_and_pytree; - if (cache_miss_return) { - executable_and_pytree = cache_miss_return.value(); - } else { - try { - executable_and_pytree = cache_miss_fun_(*args, **kwargs); - } catch (const py::error_already_set& e) { - cache_entry.compilation_error = InvalidArgument("%s", e.what()); - cache_entry.compilation_complete.Notify(); - throw; - } + py::tuple tuple = py::cast(out_and_fastpath_data); + CHECK_EQ(tuple.size(), 2); + if (tuple[1].is_none()) { + cache_entry->fall_back_to_python = true; + cache_entry->compilation_complete.Notify(); + return cache_entry; } - if (executable_and_pytree.size() != 4) { - throw std::runtime_error( - "AssertionError: The cache miss function should return 4 " - "arguments."); + + py::tuple executable_handlers_out_tree = py::cast(tuple[1]); + CHECK_EQ(executable_handlers_out_tree.size(), 3); + + auto executable = py::cast>( + executable_handlers_out_tree[0]); + std::vector handlers; + for (const auto& handler : + py::cast(executable_handlers_out_tree[1])) { + handlers.push_back(py::cast(handler)); } - cache_entry.executable = py::cast>( - std::move(executable_and_pytree[0])); + auto out_tree = py::cast(executable_handlers_out_tree[2]); + + cache_entry->executable = std::move(executable); int num_devices = - cache_entry.executable->pjrt_executable().local_devices().size(); - if (num_devices != 1) { - throw std::runtime_error(absl::StrCat( - "Running on more than a single device is not currently supported." - "The underlying PjRtExecutable has ", - num_devices)); - } - cache_entry.device = - cache_entry.executable->pjrt_executable().local_devices()[0]; - cache_entry.out_pytree_def = py::cast(executable_and_pytree[1]); + cache_entry->executable->pjrt_executable().local_devices().size(); + // The presence of jit(pmap) is detected from Python. + CHECK_EQ(num_devices, 1); - py::list shaped_arrays = - py::reinterpret_borrow(executable_and_pytree[2]); - py::list lazy_expressions = - py::reinterpret_borrow(executable_and_pytree[3]); + cache_entry->handlers = std::move(handlers); + cache_entry->out_pytree_def = std::move(out_tree); - cache_entry.out_avals.reserve(shaped_arrays.size()); - cache_entry.out_lazy_exprs.reserve(lazy_expressions.size()); - - int num_outputs = shaped_arrays.size(); - for (int i = 0; i < num_outputs; ++i) { - py::object shaped_array = - py::reinterpret_borrow(shaped_arrays[i]); - py::object lazy_expr = - py::reinterpret_borrow(lazy_expressions[i]); - - cache_entry.out_avals.push_back(shaped_array); - cache_entry.out_lazy_exprs.push_back(lazy_expr); - } - - cache_entry.compilation_complete.Notify(); + cache_entry->compilation_complete.Notify(); return cache_entry; } py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) { + if (always_fallback_to_python_) { + return py::cast(cache_miss_(*args, **kwargs))[0]; + } + // Delayed values are retrieved on the first call to `Call`. + if (!default_device_) { + // As we are calling Python code, that may release the GIL, we first hold + // mu_ before holding the GIL. + py::gil_scoped_release gil_release; + { + absl::MutexLock lock1(&mu_); + py::gil_scoped_acquire gil_aquire; + + jax_enable_x64_ = py::cast(get_jax_enable_x64_()); + jax_disable_jit_ = py::cast(get_jax_disable_jit_()); + if (!default_device_) { + py::object device_and_is_committed = get_device_(); + try { + default_pydevice_ = py::cast>( + device_and_is_committed.attr("default_device")); + } catch (const py::cast_error& e) { + // Pathways and Cloud TPU 2VM runtime. 
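The delayed initialization above is careful about lock order: the GIL is dropped before mu_ is taken and re-acquired afterwards, so the order is always mu_ before GIL. A reduced sketch of that ordering, with get_state standing in for the get_device_/get_jax_enable_x64_ callbacks and the caller assumed to hold the GIL on entry.

#include <pybind11/pybind11.h>
#include "absl/synchronization/mutex.h"

namespace py = pybind11;

void InitializeOnce(absl::Mutex* mu, bool* initialized, py::function get_state) {
  py::gil_scoped_release release_gil;  // step 1: give up the GIL
  absl::MutexLock lock(mu);            // step 2: serialize initialization on mu_
  py::gil_scoped_acquire acquire_gil;  // step 3: take the GIL back to call Python
  if (!*initialized) {
    get_state();                       // may itself release the GIL
    *initialized = true;
  }
}
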
+ always_fallback_to_python_ = true; + return py::cast(cache_miss_(*args, **kwargs))[0]; + } + default_pyclient_ = default_pydevice_.client; + default_device_ = default_pydevice_.contents; + is_committed_ = + py::cast(device_and_is_committed.attr("committed_to_device")); + } + } + } + CHECK(default_device_); if (JitIsDisabled()) { return fun_(*args, **kwargs); } ParsedArgumentsAsBuffers arguments; FlattenArguments(args, kwargs, static_argnums_, arguments); - // TODO(jblespiau): It would be preferable to have a single location for - // locking code. - absl::optional cache_miss_result = absl::nullopt; - if (!default_device_) { - // TODO(jblespiau): This code will deadlock if a jitted function - // recursively calls itself. - if (first_compilation_started_) { - if (!first_compilation_complete_.HasBeenNotified()) { - py::gil_scoped_release gil_release; - first_compilation_complete_.WaitForNotification(); - } - if (first_compilation_error_) { - throw std::invalid_argument( - first_compilation_error_.value().error_message()); - } - } else { - first_compilation_started_ = true; - try { - cache_miss_result = cache_miss_fun_(*args, **kwargs); - } catch (const py::error_already_set& e) { - first_compilation_error_ = InvalidArgument("%s", e.what()); - first_compilation_complete_.Notify(); - throw; - } - auto executable = py::cast>( - cache_miss_result.value()[0]); - - pyclient_ = executable->client(); - default_device_ = executable->LocalDevices()[0].contents; - if (!default_device_) { - throw std::invalid_argument( - "executable->LocalDevices()[0] should not be null!"); - } - first_compilation_complete_.Notify(); - } - } - CHECK(default_device_); - - // The C++ jit do not support Tracers arguments yet. The Python-based jit - // function will be called if any of the dynamic arguments is unsupported. - if (!ConvertArgsToBuffers(jax_enable_x64_, *pyclient_, default_device_, - arguments) + // The C++ jit do not support Tracers arguments inputs yet. The Python-based + // jit function will be called if any of the dynamic arguments is unsupported. + if (!ConvertArgsToBuffers(jax_enable_x64_.value(), *default_pyclient_, + default_device_, is_committed_, arguments) .ok()) { - return python_f_jitted_(*args, **kwargs); + return py::cast(cache_miss_(*args, **kwargs))[0]; } - CacheEntry& cache_entry = - GetCacheEntry(args, kwargs, arguments.signature, cache_miss_result); + CacheEntry* cache_entry = GetCacheEntryIfPresent(arguments.signature); + if (!cache_entry) { + py::object out_and_fastpath_data = cache_miss_(*args, **kwargs); + cache_entry = GetCacheEntryIfPresent(arguments.signature); + if (!cache_entry) { + cache_entry = AddCacheEntry(args, kwargs, arguments.signature, + out_and_fastpath_data); + } + CHECK(cache_entry); + if (cache_entry->fall_back_to_python) { + return py::cast(out_and_fastpath_data)[0]; + } + // As we have already computed the results, we can return it. + // It's even *required* e.g. if there are donated arguments, because + // otherwise the buffer which has been donated already will be invalid. 
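Taken together, the branches in Call() amount to a small decision tree between the Python fallback (cache_miss_) and the cached executable. The following toy function only summarizes that order; the callables and return type are placeholders, not the real signatures.

#include <functional>
#include <string>

std::string Dispatch(bool jit_disabled, bool args_convertible,
                     bool entry_falls_back,
                     const std::function<std::string()>& python_fallback,
                     const std::function<std::string()>& fast_path) {
  if (jit_disabled) return python_fallback();       // plain eager call of fun_
  if (!args_convertible) return python_fallback();  // tracer/float0/lazy arguments
  if (entry_falls_back) return python_fallback();   // entry marked fall_back_to_python
  return fast_path();                               // execute the cached PjRtExecutable
}
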
+ return py::cast(out_and_fastpath_data)[0]; + } + CHECK(cache_entry); + if (cache_entry->fall_back_to_python) { + return py::cast(cache_miss_(*args, **kwargs))[0]; + } std::vector> outputs = - ValueOrThrow(cache_entry.executable->PjRtExecute(arguments.arg_buffers)); - - static const auto* xla_module = - new py::module(py::module::import("jax.interpreters.xla")); - const auto& device_array = xla_module->attr("DeviceArray"); - - const std::vector& out_avals = cache_entry.out_avals; - const std::vector& out_lazy_exprs = cache_entry.out_lazy_exprs; + ValueOrThrow(cache_entry->executable->PjRtExecute(arguments.arg_buffers)); + const std::vector& handlers = cache_entry->handlers; py::list flat_device_arrays; for (int i = 0; i < outputs.size(); ++i) { - flat_device_arrays.append(device_array( - /*aval=*/out_avals[i], /*device=*/outputs[i]->device(), - /*lazy_expr=*/out_lazy_exprs[i], - /*device_buffer=*/std::move(outputs[i]))); + flat_device_arrays.append(handlers[i](std::move(outputs[i]))); } - return cache_entry.out_pytree_def.Unflatten(flat_device_arrays); + return cache_entry->out_pytree_def.Unflatten(flat_device_arrays); } } // namespace @@ -810,17 +856,27 @@ void BuildJaxjitSubmodule(pybind11::module& m) { jitlib.def("get_disable_jit", &GetDisableJit); jitlib.def( "jit", - [](py::function fun, py::function cache_miss_fun, - py::function fallback_on_unsupported_argument, bool jax_enable_x64, - bool jax_disable_jit, + [](py::function fun, py::function cache_miss, py::function get_device, + py::function get_jax_enable_x64, py::function get_jax_disable_jit, std::vector static_argnums) -> std::unique_ptr { return std::make_unique( - std::move(fun), std::move(cache_miss_fun), - std::move(fallback_on_unsupported_argument), jax_enable_x64, - jax_disable_jit, std::move(static_argnums)); + std::move(fun), std::move(cache_miss), std::move(get_device), + std::move(get_jax_enable_x64), std::move(get_jax_disable_jit), + std::move(static_argnums)); }); // Only for testing purposes + jitlib.def("_DtypeTo32BitDtype", [](const py::object obj) -> py::object { + py::dtype dtype = py::dtype::from_args(obj); + const py::dtype* res = DtypeTo32BitDtype(dtype); + if (res) { + return *res; + } else { + return py::none(); + } + }); + jitlib.def("_is_float0", &IsFloat0); + jitlib.def("_is_trivial", &HasTrivialLazyExpr); jitlib.def("_ScalarToBuffer", [](py::handle scalar, bool jax_enable_x64, std::shared_ptr client) { xla::PjRtClient* pjrt_client = client->pjrt_client(); diff --git a/tensorflow/compiler/xla/python/py_client.cc b/tensorflow/compiler/xla/python/py_client.cc index 6df11322564..07b915c640c 100644 --- a/tensorflow/compiler/xla/python/py_client.cc +++ b/tensorflow/compiler/xla/python/py_client.cc @@ -30,6 +30,8 @@ namespace xla { namespace py = pybind11; namespace pprof = tensorflow::tfprof::pprof; +PyClient::PyClient(std::unique_ptr pjrt_client) + : pjrt_client_(std::move(pjrt_client)) {} PyClient::PyClient(std::shared_ptr pjrt_client) : pjrt_client_(std::move(pjrt_client)) {} diff --git a/tensorflow/compiler/xla/python/py_client.h b/tensorflow/compiler/xla/python/py_client.h index f12a4ae4f0a..08249722d6c 100644 --- a/tensorflow/compiler/xla/python/py_client.h +++ b/tensorflow/compiler/xla/python/py_client.h @@ -88,6 +88,7 @@ ClientAndPtr WrapWithClient(std::shared_ptr client, T* contents) { // We use a wrapper class to add Python-specific functionality. 
class PyClient : public std::enable_shared_from_this { public: + explicit PyClient(std::unique_ptr pjrt_client); explicit PyClient(std::shared_ptr pjrt_client); PjRtClient* pjrt_client() const { return pjrt_client_.get(); } diff --git a/tensorflow/compiler/xla/python/tpu_driver/BUILD b/tensorflow/compiler/xla/python/tpu_driver/BUILD index 0854c76dc34..bda1db6a466 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/BUILD @@ -5,7 +5,6 @@ load( "//tensorflow/compiler/xla/python/tpu_driver:platform/external/tools.bzl", "external_deps", "go_grpc_library", - "go_proto_library", ) licenses(["notice"]) # Apache 2.0 @@ -135,12 +134,6 @@ cc_library( alwayslink = 1, ) -go_proto_library( - name = "tpu_service_go_proto", - compatible_with = ["//buildenv/target:gce"], - deps = [":tpu_service_proto"], -) - go_grpc_library( name = "tpu_service_go_grpc", srcs = [":tpu_service_proto"], diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD index 2592c439e7b..9d98d0cf654 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/BUILD +++ b/tensorflow/compiler/xla/python/tpu_driver/client/BUILD @@ -36,7 +36,7 @@ cc_library( "//tensorflow/compiler/xla/python/tpu_driver:tpu_driver_proto_cc", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:shaped_buffer", - "//tensorflow/core:allocator", + "//tensorflow/core/framework:allocator", "//tensorflow/core/platform:env", "//tensorflow/core/profiler/lib:traceme", "@com_google_absl//absl/memory", diff --git a/tensorflow/compiler/xla/python/tpu_driver/platform/external/tools.bzl b/tensorflow/compiler/xla/python/tpu_driver/platform/external/tools.bzl index 99b07b6c787..12a4390d317 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/platform/external/tools.bzl +++ b/tensorflow/compiler/xla/python/tpu_driver/platform/external/tools.bzl @@ -16,10 +16,6 @@ Build dependencies and utilities for the TPU driver interface. """ -def go_proto_library(**kwargs): - # A dummy macro placeholder for compatibility reason. - pass - def go_grpc_library(**kwargs): # A dummy macro placeholder for compatibility reason. 
pass diff --git a/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc index ac54df39895..ee2e2fa57b1 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc @@ -17,6 +17,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_split.h" +#include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/pjrt/semaphore.h" #include "tensorflow/compiler/xla/pjrt/worker_thread.h" #include "tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.h" @@ -29,6 +30,12 @@ namespace tpu_driver { namespace { +#define CHECK_EXISTS_OR_RETURN(container, target_op_id, operation_id) \ + { \ + auto p = CheckHandleExists(container, target_op_id, operation_id); \ + if (p != nullptr) return p; \ + } + using xla::Status; using xla::WorkerThread; @@ -54,11 +61,35 @@ class PodEvent : public Event { const int64_t operation_id_; }; +class ErrorEvent : public PodEvent { + public: + explicit ErrorEvent(PodTpuDriver* driver, int64_t operation_id, Status status) + : PodEvent(driver, operation_id) { + status_ = status; + } + + xla::Status Await() override { return status_; } + absl::optional AwaitWithTimeout( + absl::Duration duration) override { + return status_; + } + void AddCallback(std::function callback) override { + callback(status_); + } + + private: + Status status_; +}; + class CombinedEvent : public PodEvent { public: explicit CombinedEvent(PodTpuDriver* driver, int64_t operation_id, std::vector> events) - : PodEvent(driver, operation_id), events_(events) {} + : PodEvent(driver, operation_id), events_(events) { + for (auto& event : events_) { + event->AddCallback([this](Status s) { IncrementAndCheckComplete(s); }); + } + } xla::Status Await() override { for (auto& event : events_) { @@ -69,9 +100,10 @@ class CombinedEvent : public PodEvent { absl::optional AwaitWithTimeout( absl::Duration duration) override { - // TODO(frankchn): This might extend the timeout. for (auto& event : events_) { + auto start_time = absl::Now(); auto status = event->AwaitWithTimeout(duration); + duration -= absl::Now() - start_time; if (status == absl::nullopt) { return absl::nullopt; } else { @@ -81,13 +113,48 @@ class CombinedEvent : public PodEvent { return Status::OK(); } - void AddCallback(std::function callback) override { - // TODO(frankchn): This may return before every event is done. - events_[0]->AddCallback(std::move(callback)); + void AddCallback(std::function callback) + TF_LOCKS_EXCLUDED(mu_) override { + bool all_events_completed = false; + { + absl::MutexLock l(&mu_); + all_events_completed = events_completed_ == events_.size(); + } + if (all_events_completed) { + callback(event_status_); + } else { + absl::MutexLock l(&mu_); + callbacks_.push_back(std::move(callback)); + } } private: + void IncrementAndCheckComplete(Status s) TF_LOCKS_EXCLUDED(mu_) { + std::vector> callbacks; + { + absl::MutexLock l(&mu_); + + event_status_ = s; + events_completed_++; + if (events_completed_ == events_.size()) { + // Copy callbacks to a temporary to be invoked outside the mutex. 
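The copy-then-invoke step here is the standard way to run completion callbacks without holding the lock they were registered under. A self-contained sketch of the same pattern with absl::Mutex; the class and members are illustrative, not the real CombinedEvent.

#include <functional>
#include <utility>
#include <vector>
#include "absl/synchronization/mutex.h"

class Completion {
 public:
  void Finish() {
    std::vector<std::function<void()>> to_run;
    {
      absl::MutexLock lock(&mu_);
      done_ = true;
      to_run = std::move(callbacks_);
      callbacks_.clear();
    }
    for (const auto& cb : to_run) cb();  // invoked outside the lock
  }

  void AddCallback(std::function<void()> cb) {
    bool run_now = false;
    {
      absl::MutexLock lock(&mu_);
      if (done_) run_now = true;
      else callbacks_.push_back(std::move(cb));
    }
    if (run_now) cb();
  }

 private:
  absl::Mutex mu_;
  bool done_ = false;
  std::vector<std::function<void()>> callbacks_;
};
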
+ callbacks.assign(callbacks_.begin(), callbacks_.end()); + callbacks_.clear(); + } else { + return; + } + } + + for (const auto& callback : callbacks) { + callback(event_status_); + } + } + + absl::Mutex mu_; std::vector> events_; + std::vector> callbacks_ ABSL_GUARDED_BY(mu_); + int64_t events_completed_ ABSL_GUARDED_BY(mu_) = 0; + Status event_status_; }; class PodBufferHandle : public BufferHandle { @@ -160,6 +227,12 @@ class PodLoadedProgramHandle : public LoadedProgramHandle { }; struct EventInFlight { + EventInFlight() + : underlying_event(nullptr), + create_fn(nullptr), + incomplete_deps(), + callbacks() {} + std::shared_ptr underlying_event; std::function(void)> create_fn; @@ -176,14 +249,33 @@ class PodTpuDriver : public TpuDriver { event_thread_(tensorflow::Env::Default(), "grpc_pod_event_thread") { std::vector workers = absl::StrSplit( absl::StripPrefix(config.worker(), kPodTpuDriverPrefix), ','); + + int worker_count = 0; + + // Flag for environments where local core # == all cores in TPU system #, + // which means that we are connecting to separate TPU systems or we are in + // a test environment. + bool in_local_core_environment = false; + for (const auto& worker : workers) { TpuDriverConfig worker_config(config_); *(worker_config.mutable_worker()) = absl::StrCat("grpc://", worker); - drivers_.push_back( - CreateGrpcTpuDriver(worker_config, creds_).ConsumeValueOrDie()); + auto tpu_driver = + CreateGrpcTpuDriver(worker_config, creds_).ConsumeValueOrDie(); + + SystemInfo driver_info; + tpu_driver->QuerySystemInfo(&driver_info); + + if (driver_info.core_count() == driver_info.local_core_size()) { + drivers_.insert({worker_count, std::move(tpu_driver)}); + in_local_core_environment = true; + } else { + drivers_.insert({driver_info.host_id(), std::move(tpu_driver)}); + } + + worker_count++; } - int cumulative_core_id = 0; absl::flat_hash_set> processed_chips; for (int driver_num = 0; driver_num < workers.size(); ++driver_num) { @@ -208,17 +300,24 @@ class PodTpuDriver : public TpuDriver { } // Process all the unique chips that we have seen. 
+ int core_count = 0; for (auto& tpu_chip : *pod_info_.mutable_tpu_chip()) { for (auto& tpu_core : *tpu_chip.mutable_core()) { - int current_core = cumulative_core_id++; + int current_core = tpu_core.id(); + if (in_local_core_environment) { + current_core = core_count; + } - core_to_driver_.push_back(drivers_[tpu_chip.host_id()].get()); - core_to_driver_id_.push_back(tpu_chip.host_id()); - core_to_driver_core_.push_back(tpu_core.id()); + core_to_driver_.insert( + {current_core, drivers_[tpu_chip.host_id()].get()}); + core_to_driver_id_.insert({current_core, tpu_chip.host_id()}); + core_to_driver_core_.insert({current_core, tpu_core.id()}); tpu_core.set_id(current_core); tpu_core.set_core_on_host_index(current_core); *(pod_info_.add_local_core()) = tpu_core; + + core_count++; } // We are setting host_id to zero because we want this to look like one @@ -244,7 +343,7 @@ class PodTpuDriver : public TpuDriver { xla::Status Reset() override { for (auto& driver : drivers_) { - TF_RETURN_IF_ERROR(driver->Reset()); + TF_RETURN_IF_ERROR(driver.second->Reset()); } return xla::Status::OK(); } @@ -257,8 +356,8 @@ class PodTpuDriver : public TpuDriver { ScheduleRequest( operation_id, - [this, core_id, region, num_bytes, operation_id]() { - absl::MutexLock l(&mu_); + [this, core_id, region, num_bytes, + operation_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { underlying_buffers_.insert( {operation_id, core_to_driver_[core_id]->Allocate(core_to_driver_core_[core_id], @@ -279,8 +378,8 @@ class PodTpuDriver : public TpuDriver { ScheduleRequest( operation_id, - [this, core_id, region, shape, operation_id]() { - absl::MutexLock l(&mu_); + [this, core_id, region, shape, + operation_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { underlying_buffers_.insert( {operation_id, core_to_driver_[core_id]->Allocate(core_to_driver_core_[core_id], @@ -310,12 +409,14 @@ class PodTpuDriver : public TpuDriver { ScheduleRequest( operation_id, - [this, core_id, region, children_ids, operation_id]() { - absl::MutexLock l(&mu_); - + [this, core_id, region, children_ids, + operation_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) + -> std::shared_ptr { std::vector child_buffers; child_buffers.reserve(children_ids.size()); for (int i = 0; i < children_ids.size(); ++i) { + CHECK_EXISTS_OR_RETURN(underlying_buffers_, children_ids[i], + operation_id); child_buffers.push_back(underlying_buffers_[children_ids[i]].get()); } @@ -343,8 +444,10 @@ class PodTpuDriver : public TpuDriver { ScheduleRequest( operation_id, - [this, op_id, core_id]() { - absl::MutexLock l(&mu_); + [this, operation_id, op_id, + core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { + CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id); + auto buf_iter = underlying_buffers_.find(op_id); auto underlying_hn = std::move(buf_iter->second); underlying_buffers_.erase(buf_iter); @@ -369,8 +472,10 @@ class PodTpuDriver : public TpuDriver { ScheduleRequest( operation_id, - [this, src, op_id, core_id]() { - absl::MutexLock l(&mu_); + [this, src, operation_id, op_id, + core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { + CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id); + auto buf_iter = underlying_buffers_.find(op_id); return core_to_driver_[core_id]->TransferToDevice( src, buf_iter->second.get(), {}); @@ -392,8 +497,9 @@ class PodTpuDriver : public TpuDriver { ScheduleRequest( operation_id, - [this, dst, op_id, core_id]() { - absl::MutexLock l(&mu_); + [this, dst, operation_id, op_id, + core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> 
std::shared_ptr { + CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id); auto buf_iter = underlying_buffers_.find(op_id); return core_to_driver_[core_id]->TransferFromDevice( buf_iter->second.get(), dst, {}); @@ -406,27 +512,49 @@ class PodTpuDriver : public TpuDriver { std::shared_ptr TransferFromDeviceToDevice( const BufferHandle* src, BufferHandle* dst, absl::Span wait_for) override { - int64_t operation_id = GetOperationId(); - auto deps = GetDependencyOperationIds(wait_for); - deps.insert(static_cast(src)->operation_id()); - deps.insert(static_cast(dst)->operation_id()); + auto src_core_id = static_cast(src)->core_id(); + auto dst_core_id = static_cast(dst)->core_id(); - auto src_op_id = static_cast(src)->operation_id(); - auto dst_op_id = static_cast(dst)->operation_id(); - auto core_id = static_cast(dst)->core_id(); + auto src_driver_id = core_to_driver_id_[src_core_id]; + auto dst_driver_id = core_to_driver_id_[dst_core_id]; - ScheduleRequest( - operation_id, - [this, src_op_id, dst_op_id, core_id]() { - absl::MutexLock l(&mu_); - auto src_iter = underlying_buffers_.find(src_op_id); - auto dst_iter = underlying_buffers_.find(dst_op_id); - return core_to_driver_[core_id]->TransferFromDeviceToDevice( - src_iter->second.get(), dst_iter->second.get(), {}); - }, - deps); + if (src_driver_id == dst_driver_id) { + // They are in the same host, we can schedule it normally + int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); + deps.insert(static_cast(src)->operation_id()); + deps.insert(static_cast(dst)->operation_id()); - return std::make_shared(this, operation_id); + auto src_op_id = static_cast(src)->operation_id(); + auto dst_op_id = static_cast(dst)->operation_id(); + + ScheduleRequest( + operation_id, + [this, operation_id, src_op_id, dst_op_id, dst_core_id]() + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { + CHECK_EXISTS_OR_RETURN(underlying_buffers_, src_op_id, + operation_id); + CHECK_EXISTS_OR_RETURN(underlying_buffers_, dst_op_id, + operation_id); + + auto src_iter = underlying_buffers_.find(src_op_id); + auto dst_iter = underlying_buffers_.find(dst_op_id); + return core_to_driver_[dst_core_id]->TransferFromDeviceToDevice( + src_iter->second.get(), dst_iter->second.get(), {}); + }, + deps); + return std::make_shared(this, operation_id); + } else { + // src and dst are on different hosts, we have to bounce through us. 
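The cross-host branch that follows stages the data through a host buffer and ties the buffer's lifetime to the completion callback of the second transfer, capturing the first event so it stays alive until then. A toy version of that bounce, with placeholder Event/transfer types and a callback that fires immediately instead of asynchronously.

#include <cstddef>
#include <functional>
#include <memory>

struct Event {
  // In the real driver the callback fires asynchronously on completion; here it
  // fires immediately so the sketch is self-contained.
  void AddCallback(std::function<void()> cb) { cb(); }
};

using TransferOut = std::function<std::shared_ptr<Event>(char* dst)>;
using TransferIn = std::function<std::shared_ptr<Event>(const char* src, Event* wait_for)>;

std::shared_ptr<Event> BounceThroughHost(const TransferOut& read_src,
                                         const TransferIn& write_dst,
                                         size_t num_bytes) {
  char* staging = new char[num_bytes];
  std::shared_ptr<Event> src_done = read_src(staging);
  std::shared_ptr<Event> dst_done = write_dst(staging, src_done.get());
  // Free the staging buffer only after the second transfer has completed.
  dst_done->AddCallback([src_done, staging]() { delete[] staging; });
  return dst_done;
}
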
+ auto dst_size = dst->size_in_bytes(); + char* host_buf = new char[dst_size]; + + auto src_event = TransferFromDevice(src, host_buf, wait_for); + auto dst_event = TransferToDevice(host_buf, dst, {src_event.get()}); + dst_event->AddCallback( + [src_event, host_buf](xla::Status status) { delete[] host_buf; }); + return dst_event; + } } std::unique_ptr CompileProgram( @@ -437,8 +565,8 @@ class PodTpuDriver : public TpuDriver { ScheduleRequest( operation_id, - [this, operation_id, source, num_replicas]() { - absl::MutexLock l(&mu_); + [this, operation_id, source, + num_replicas]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { auto cph_iterator = underlying_cph_ .insert( @@ -473,8 +601,9 @@ class PodTpuDriver : public TpuDriver { ScheduleRequest( operation_id, - [this, operation_id, cph_op_id, core_id]() { - absl::MutexLock l(&mu_); + [this, operation_id, cph_op_id, + core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { + CHECK_EXISTS_OR_RETURN(underlying_cph_, cph_op_id, operation_id); auto cph_iter = underlying_cph_.find(cph_op_id); underlying_lph_.insert( @@ -505,9 +634,9 @@ class PodTpuDriver : public TpuDriver { ScheduleRequest( operation_id, - [this, op_id, core_id]() { - absl::MutexLock l(&mu_); - + [this, operation_id, op_id, + core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr { + CHECK_EXISTS_OR_RETURN(underlying_lph_, op_id, operation_id); auto lph_iter = underlying_lph_.find(op_id); auto event = core_to_driver_[core_id]->UnloadProgram( std::move(lph_iter->second), {}); @@ -526,6 +655,7 @@ class PodTpuDriver : public TpuDriver { const xla::DeviceAssignmentProto& device_assignment, absl::Span wait_for) override { int64_t operation_id = GetOperationId(); + auto deps = GetDependencyOperationIds(wait_for); deps.insert(static_cast(program)->operation_id()); @@ -550,23 +680,27 @@ class PodTpuDriver : public TpuDriver { ScheduleRequest( operation_id, - [this, core_id, op_id, input_op_ids, output_op_ids, - device_assignment]() { - absl::MutexLock l(&mu_); - + [this, operation_id, core_id, op_id, input_op_ids, output_op_ids, + device_assignment]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) + -> std::shared_ptr { std::vector underlying_inputs; std::vector underlying_outputs; underlying_inputs.reserve(input_op_ids.size()); for (auto input_op_id : input_op_ids) { + CHECK_EXISTS_OR_RETURN(underlying_buffers_, input_op_id, + operation_id); underlying_inputs.push_back(underlying_buffers_[input_op_id].get()); } underlying_outputs.reserve(output_op_ids.size()); for (auto output_op_id : output_op_ids) { + CHECK_EXISTS_OR_RETURN(underlying_buffers_, output_op_id, + operation_id); underlying_outputs.push_back( underlying_buffers_[output_op_id].get()); } + CHECK_EXISTS_OR_RETURN(underlying_lph_, op_id, operation_id); LoadedProgramHandle* handle = underlying_lph_[op_id].get(); return core_to_driver_[core_id]->ExecuteProgram( handle, underlying_inputs, underlying_outputs, device_assignment, @@ -583,12 +717,12 @@ class PodTpuDriver : public TpuDriver { // Helper methods for Event scheduling - absl::optional WaitForEvent(int64_t event_id, - absl::Duration duration) { + absl::optional WaitForEvent(int64_t event_id, absl::Duration duration) + TF_LOCKS_EXCLUDED(mu_) { std::shared_ptr underlying_event; { - absl::MutexLock l(&event_mu_); + absl::MutexLock l(&mu_); auto event = events_.find(event_id); if (event == events_.end()) { @@ -601,16 +735,20 @@ class PodTpuDriver : public TpuDriver { } auto done = [this, event_id]() { - event_mu_.AssertHeld(); - return events_[event_id].underlying_event != nullptr; + 
mu_.AssertHeld(); + if (events_.count(event_id) == 0) { + LOG(ERROR) << "Cannot find event id " << event_id + << " in WaitForEvent."; + } + return events_[event_id]->underlying_event != nullptr && + events_[event_id]->underlying_event.use_count() != 0; }; - auto status = - event_mu_.AwaitWithTimeout(absl::Condition(&done), duration); + auto status = mu_.AwaitWithTimeout(absl::Condition(&done), duration); if (!status) { return absl::nullopt; } - underlying_event = events_[event_id].underlying_event; + underlying_event = events_[event_id]->underlying_event; } // Wait for the underlying event without holding on to the event_lock_, or @@ -618,8 +756,9 @@ class PodTpuDriver : public TpuDriver { return underlying_event->AwaitWithTimeout(duration); } - void AddCallbackForEvent(int64_t event_id, std::function fn) { - absl::MutexLock l(&event_mu_); + void AddCallbackForEvent(int64_t event_id, std::function fn) + TF_LOCKS_EXCLUDED(mu_) { + absl::MutexLock l(&mu_); auto event = events_.find(event_id); if (event == events_.end()) { @@ -631,15 +770,17 @@ class PodTpuDriver : public TpuDriver { } } - if (event->second.underlying_event != nullptr) { - event->second.underlying_event->AddCallback(fn); + if (event->second->underlying_event != nullptr && + event->second->underlying_event.use_count() != 0) { + event->second->underlying_event->AddCallback(fn); } else { - event->second.callbacks.push_back(std::move(fn)); + event->second->callbacks.push_back(std::move(fn)); } } xla::Status GetCompiledProgramShape(int64_t op_id, - xla::ProgramShapeProto* program_shape) { + xla::ProgramShapeProto* program_shape) + TF_LOCKS_EXCLUDED(mu_) { absl::MutexLock l(&mu_); auto done = [this, op_id]() { @@ -655,14 +796,13 @@ class PodTpuDriver : public TpuDriver { const TpuDriverConfig& config_; std::shared_ptr<::grpc::ChannelCredentials> creds_; - std::vector> drivers_; - std::vector core_to_driver_id_; - std::vector core_to_driver_; - std::vector core_to_driver_core_; + absl::flat_hash_map> drivers_; + absl::flat_hash_map core_to_driver_id_; + absl::flat_hash_map core_to_driver_; + absl::flat_hash_map core_to_driver_core_; SystemInfo pod_info_; absl::Mutex mu_; - absl::Mutex event_mu_; absl::flat_hash_map> underlying_buffers_ ABSL_GUARDED_BY(mu_); @@ -672,9 +812,10 @@ class PodTpuDriver : public TpuDriver { absl::flat_hash_map> underlying_lph_ ABSL_GUARDED_BY(mu_); - absl::btree_map events_ ABSL_GUARDED_BY(event_mu_); + absl::btree_map> events_ + ABSL_GUARDED_BY(mu_); absl::flat_hash_map abnormal_event_status_ - ABSL_GUARDED_BY(event_mu_); + ABSL_GUARDED_BY(mu_); std::atomic operation_id_counter_{0}; @@ -694,18 +835,19 @@ class PodTpuDriver : public TpuDriver { // EventCompleted is executed on the event_thread_ worker thread. We want // to propagate the fact that the event is completed to any subsequent events // that might depend on this event. 
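Before moving on to EventCompleted below: the reworked WaitForEvent blocks on the single mu_ with an absl::Condition wrapping a lambda, the same construction shown in the diff. The pattern in isolation, with a toy value slot in place of the event table.

#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/optional.h"

class Slot {
 public:
  void Set(int value) {
    absl::MutexLock lock(&mu_);
    value_ = value;
  }

  // Returns the value, or nullopt if it was not produced within `timeout`.
  absl::optional<int> WaitFor(absl::Duration timeout) {
    absl::MutexLock lock(&mu_);
    auto ready = [this]() {
      mu_.AssertHeld();
      return value_.has_value();
    };
    if (!mu_.AwaitWithTimeout(absl::Condition(&ready), timeout)) {
      return absl::nullopt;
    }
    return value_;
  }

 private:
  absl::Mutex mu_;
  absl::optional<int> value_;
};
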
- void EventCompleted(int64_t event_id, Status status) { - absl::MutexLock l(&event_mu_); + void EventCompleted(int64_t event_id, Status status) TF_LOCKS_EXCLUDED(mu_) { + absl::MutexLock l(&mu_); - absl::btree_map::iterator curr_event; + absl::btree_map>::iterator + curr_event; if (!status.ok()) abnormal_event_status_.insert({event_id, status}); curr_event = events_.find(event_id); - DCHECK(curr_event->second.callbacks.empty()); - DCHECK(curr_event->second.incomplete_deps.empty()); + DCHECK(curr_event->second->callbacks.empty()); + DCHECK(curr_event->second->incomplete_deps.empty()); for (auto& event : events_) { - event.second.incomplete_deps.erase(event_id); + event.second->incomplete_deps.erase(event_id); // The if statement conditions on both // - all previous events have completed (incomplete_deps.empty()) // - the op creating this event has not been called yet @@ -713,16 +855,16 @@ class PodTpuDriver : public TpuDriver { // We call the create_fn that creates the event and adds any relevant // callbacks to the actual event, before setting create_fn to nullptr // to indicate that it has already been called - if (event.second.incomplete_deps.empty() && - event.second.create_fn != nullptr) { + if (event.second->incomplete_deps.empty() && + event.second->create_fn != nullptr) { // We were the last unfilled dependency, all other dependencies are // filled. We can now fire the create function. - event.second.underlying_event = event.second.create_fn(); - for (auto& fn : event.second.callbacks) { - event.second.underlying_event->AddCallback(std::move(fn)); + event.second->underlying_event = event.second->create_fn(); + for (auto& fn : event.second->callbacks) { + event.second->underlying_event->AddCallback(std::move(fn)); } - event.second.callbacks.clear(); - event.second.create_fn = nullptr; + event.second->callbacks.clear(); + event.second->create_fn = nullptr; } } @@ -732,12 +874,14 @@ class PodTpuDriver : public TpuDriver { void ScheduleRequest(int64_t operation_id, std::function(void)> fn, - const absl::flat_hash_set& deps) { - absl::MutexLock l(&event_mu_); - absl::btree_map::iterator event; + const absl::flat_hash_set& deps) + TF_LOCKS_EXCLUDED(mu_) { + absl::MutexLock l(&mu_); + absl::btree_map>::iterator event; absl::flat_hash_set incomplete_deps; - event = events_.insert({operation_id, {}}).first; + event = events_.insert({operation_id, absl::make_unique()}) + .first; for (const auto& dep : deps) { if (events_.count(dep) > 0) incomplete_deps.insert(dep); } @@ -746,9 +890,9 @@ class PodTpuDriver : public TpuDriver { // All dependencies have been fulfilled, we execute the request // immediately and add a callback to inform our event fulfilled thread // when it is done. - event->second.create_fn = nullptr; - event->second.underlying_event = fn(); - event->second.underlying_event->AddCallback( + event->second->create_fn = nullptr; + event->second->underlying_event = fn(); + event->second->underlying_event->AddCallback( [this, operation_id](Status status) { event_thread_.Schedule([this, operation_id, status]() { EventCompleted(operation_id, status); @@ -758,15 +902,28 @@ class PodTpuDriver : public TpuDriver { // There are some dependencies that are not yet fulfilled. We attach // the request to the event, and will execute it in the EventFulfilled // worker thread when all its dependencies are fulfilled. 
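The run-now-or-park logic in ScheduleRequest, combined with the dependency erasure in EventCompleted, forms a small dependency scheduler. A reduced sketch under those assumptions, with a plain std::function in place of the event-producing create_fn and completion reported explicitly instead of via the event thread.

#include <cstdint>
#include <functional>
#include <map>
#include "absl/container/flat_hash_set.h"

struct Pending {
  std::function<void()> create_fn;               // nullptr once executed
  absl::flat_hash_set<int64_t> incomplete_deps;  // ids this op still waits on
};

class Scheduler {
 public:
  void Schedule(int64_t id, std::function<void()> fn,
                const absl::flat_hash_set<int64_t>& deps) {
    Pending& p = pending_[id];
    for (int64_t dep : deps) {
      if (pending_.count(dep) > 0) p.incomplete_deps.insert(dep);
    }
    if (p.incomplete_deps.empty()) {
      fn();  // run immediately; completion is reported later via Completed()
    } else {
      p.create_fn = std::move(fn);  // park until the last dependency completes
    }
  }

  void Completed(int64_t id) {
    pending_.erase(id);
    for (auto& [other_id, p] : pending_) {
      p.incomplete_deps.erase(id);
      if (p.incomplete_deps.empty() && p.create_fn != nullptr) {
        p.create_fn();  // last dependency fulfilled: fire the parked operation
        p.create_fn = nullptr;
      }
    }
  }

 private:
  std::map<int64_t, Pending> pending_;
};
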
- event->second.create_fn = std::move(fn); - event->second.incomplete_deps = std::move(incomplete_deps); - event->second.callbacks.push_back([this, operation_id](Status status) { + event->second->create_fn = std::move(fn); + event->second->incomplete_deps = std::move(incomplete_deps); + event->second->callbacks.push_back([this, operation_id](Status status) { event_thread_.Schedule([this, operation_id, status]() { EventCompleted(operation_id, status); }); }); } } + + template + std::shared_ptr CheckHandleExists( + absl::flat_hash_map& container, int64_t target_op_id, + int64_t operation_id) { + if (container.count(target_op_id) == 0) { + return std::make_shared( + this, operation_id, + tensorflow::errors::InvalidArgument("Handle ", target_op_id, + " does not exist.")); + } + return nullptr; + } }; xla::Status PodEvent::Await() { diff --git a/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc index da51380c104..49a19cf9e7a 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/recording_tpu_driver.cc @@ -127,8 +127,11 @@ class RecordingLoadedProgramHandle : public LoadedProgramHandle { class RecordingTpuDriver : public TpuDriver { public: explicit RecordingTpuDriver(std::unique_ptr driver, - const std::string recording_path) - : driver_(std::move(driver)), recording_path_(recording_path) { + const std::string recording_path, + const bool flush) + : driver_(std::move(driver)), + recording_path_(recording_path), + flush_(flush) { auto file_status = tensorflow::Env::Default()->NewAppendableFile( recording_path_, &log_file_); if (!file_status.ok()) { @@ -466,6 +469,7 @@ class RecordingTpuDriver : public TpuDriver { private: std::unique_ptr driver_; const std::string recording_path_; + const bool flush_; std::unique_ptr log_file_; @@ -499,6 +503,22 @@ class RecordingTpuDriver : public TpuDriver { "corrupt. Error: " << data_status.ToString(); } + + if (flush_) { + auto flush_status = log_file_->Flush(); + if (!flush_status.ok()) { + LOG(WARNING) << "Unable to flush data to log file. File possibly " + "corrupt. Error: " + << flush_status.ToString(); + } + + auto sync_status = log_file_->Sync(); + if (!sync_status.ok()) { + LOG(WARNING) << "Unable to sync log file. File possibly " + "corrupt. 
Error: " + << sync_status.ToString(); + } + } } } @@ -521,6 +541,7 @@ xla::StatusOr> RegisterRecordingTpuDriver( std::string file; std::string worker; + bool flush = false; for (const auto& config : configs) { std::vector kv = @@ -531,6 +552,11 @@ xla::StatusOr> RegisterRecordingTpuDriver( if (kv[0] == "worker") { worker = kv[1]; } + if (kv[0] == "flush") { + if (kv[1] == "true" || kv[1] == "1") { + flush = true; + } + } } TpuDriverConfig worker_config; @@ -541,7 +567,7 @@ xla::StatusOr> RegisterRecordingTpuDriver( auto driver = driver_status.ConsumeValueOrDie(); return std::unique_ptr( - new RecordingTpuDriver(std::move(driver), file)); + new RecordingTpuDriver(std::move(driver), file, flush)); } // To record a sequence of operations, set the worker configuration string to diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 06605660b63..0e92c85f6f6 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -75,7 +75,6 @@ namespace { namespace py = pybind11; - struct Uniquer { absl::Mutex mu; NameUniquer name_uniquer TF_GUARDED_BY(mu); @@ -430,8 +429,13 @@ PYBIND11_MODULE(xla_extension, m) { }) .def_property( "device_assignment", - [](const CompileOptions& options) { - return options.executable_build_options.device_assignment(); + [](const CompileOptions& options) + -> absl::optional { + return options.executable_build_options.has_device_assignment() + ? absl::optional( + options.executable_build_options + .device_assignment()) + : absl::nullopt; }, [](CompileOptions& options, const DeviceAssignment& device_assignment) { @@ -466,32 +470,31 @@ PYBIND11_MODULE(xla_extension, m) { return local_device->client()->TransferToInfeedLocal( literal, local_device->device_ordinal()); }) - .def( - "transfer_from_outfeed", - [](const PjRtDevice& device, - const Shape& shape) -> StatusOr { - GlobalPyRefManager()->CollectGarbage(); - std::shared_ptr literal_shared; - { - py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device.GetLocalDeviceState()); - Shape shape_with_layout = shape; - ShapeUtil::ForEachMutableSubshape( - &shape_with_layout, [](Shape* subshape, const ShapeIndex&) { - if (!subshape->has_layout()) { - LayoutUtil::SetToDefaultLayout(subshape); - } - }); - TF_ASSIGN_OR_RETURN( - Literal literal, - local_device->client()->TransferFromOutfeedLocal( - shape_with_layout, local_device->device_ordinal())); + .def("transfer_from_outfeed", + [](const PjRtDevice& device, + const Shape& shape) -> StatusOr { + GlobalPyRefManager()->CollectGarbage(); + std::shared_ptr literal_shared; + { + py::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, + device.GetLocalDeviceState()); + Shape shape_with_layout = shape; + ShapeUtil::ForEachMutableSubshape( + &shape_with_layout, [](Shape* subshape, const ShapeIndex&) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); + TF_ASSIGN_OR_RETURN( + Literal literal, + local_device->client()->TransferFromOutfeedLocal( + shape_with_layout, local_device->device_ordinal())); - literal_shared = std::make_shared(std::move(literal)); - } - return LiteralToPython(std::move(literal_shared)); - }); + literal_shared = std::make_shared(std::move(literal)); + } + return LiteralToPython(std::move(literal_shared)); + }); py::class_>(m, "CpuDevice") .def("__repr__", [](const CpuDevice& device) { @@ -553,13 +556,13 @@ PYBIND11_MODULE(xla_extension, m) { m.def( "get_cpu_client", [](bool 
asynchronous) -> StatusOr> { - TF_ASSIGN_OR_RETURN(std::shared_ptr client, + TF_ASSIGN_OR_RETURN(std::unique_ptr client, GetCpuClient(asynchronous)); return std::make_shared(std::move(client)); }, py::arg("asynchronous") = true); m.def("get_interpreter_client", []() -> StatusOr> { - TF_ASSIGN_OR_RETURN(std::shared_ptr client, + TF_ASSIGN_OR_RETURN(std::unique_ptr client, GetInterpreterClient()); return std::make_shared(std::move(client)); }); @@ -569,7 +572,7 @@ PYBIND11_MODULE(xla_extension, m) { std::shared_ptr distributed_client, int node_id) -> StatusOr> { TF_ASSIGN_OR_RETURN( - std::shared_ptr client, + std::unique_ptr client, GetNvidiaGpuClient(asynchronous, allocator_config, std::move(distributed_client), node_id)); return std::make_shared(std::move(client)); @@ -820,6 +823,14 @@ PYBIND11_MODULE(xla_extension, m) { hlo_module.config().debug_options(), RenderedGraphFormat::kDot); }); + m.def( + "hlo_module_cost_analysis", + [](PyClient* client, + const HloModule& module) -> StatusOr> { + auto analysis = client->pjrt_client()->GetHloCostAnalysis(); + TF_RETURN_IF_ERROR(module.entry_computation()->Accept(analysis.get())); + return analysis->properties(); + }); py::class_ xla_op_class(m, "XlaOp"); diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index d65e0152f67..133483d2bb9 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -193,8 +193,8 @@ XLA_ELEMENT_TYPE_TO_DTYPE = { PrimitiveType.F64: np.dtype('float64'), PrimitiveType.C64: np.dtype('complex64'), PrimitiveType.C128: np.dtype('complex128'), - PrimitiveType.TUPLE: np.dtype(np.object), - PrimitiveType.TOKEN: np.dtype(np.object), + PrimitiveType.TUPLE: np.dtype(np.object_), + PrimitiveType.TOKEN: np.dtype(np.object_), } # Note the conversion on the key. 
Numpy has a known issue wherein dtype hashing diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 49c57a27ac0..3863d8a1481 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -105,7 +105,7 @@ def TestFactory(xla_backend, cloud_tpu=False): c, arguments=(), expected=None, - rtol=1e-7, + rtol=1e-4, atol=0): self._ExecuteAndAssertWith( functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), @@ -173,6 +173,13 @@ def TestFactory(xla_backend, cloud_tpu=False): self.assertTrue(hlo_text.startswith("HloModule acomputation")) self.assertIn("fusion", hlo_text) + @unittest.skipIf(cloud_tpu, "not implemented") + def testFlopEstimate(self): + computation = self.ExampleComputation() + properties = xla_client._xla.hlo_module_cost_analysis( + self.backend, computation.as_hlo_module()) + self.assertEqual(properties["flops"], 8.0) + tests.append(ComputationPrinting) class ComputationHashTest(absltest.TestCase): diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index cb2c194858c..15022d1a879 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -2,7 +2,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") load( "//tensorflow:tensorflow.bzl", - "if_tpu", + "if_libtpu", "tf_cc_binary", "tf_cc_test", ) @@ -57,7 +57,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", tf_grpc_cc_dependency(), - ] + if_tpu( + ] + if_libtpu( if_false = ["//tensorflow/compiler/xla/service:cpu_plugin"], if_true = [], ), diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 1003cd6c51a..491d1d67877 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -204,7 +204,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", + "//tensorflow/core/platform:regexp", "@com_google_absl//absl/strings", ], ) @@ -449,9 +449,9 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/core:human_readable_json", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/platform:human_readable_json", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -880,7 +880,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/strings", ], ) @@ -900,7 +900,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", "//third_party/eigen3", "@com_google_absl//absl/container:flat_hash_map", @@ -950,7 +950,7 @@ cc_library( "//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", 
"//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -986,7 +986,7 @@ cc_library( "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1014,7 +1014,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/strings", ], ) @@ -1025,7 +1025,7 @@ cc_library( ":service", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/cpu:cpu_transfer_manager", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -1058,7 +1058,7 @@ cc_library( ":service", "//tensorflow/compiler/xla/service/gpu:gpu_compiler", "//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", ] + if_cuda_is_configured([ "//tensorflow/compiler/xla/service/gpu:nvptx_compiler", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", @@ -1073,7 +1073,7 @@ cc_library( deps = [ ":service", "//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", ] + if_cuda_is_configured([ "//tensorflow/compiler/xla/service/mlir_gpu:mlir_compiler_impl", ]) + internal_cuda_deps(), @@ -1086,7 +1086,7 @@ cc_library( "//tensorflow/compiler/xla/service/interpreter:compiler", "//tensorflow/compiler/xla/service/interpreter:interpreter_transfer_manager", "//tensorflow/compiler/xla/service/interpreter:platform", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -1104,7 +1104,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", @@ -1126,8 +1126,8 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:ptr_util", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", ], @@ -1164,7 +1164,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor", "//tensorflow/stream_executor:device_description", "//tensorflow/stream_executor:device_memory_allocator", @@ -1188,7 +1188,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", 
"@com_google_absl//absl/types:span", ], ) @@ -1221,7 +1221,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1262,7 +1262,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/memory", ], ) @@ -1708,7 +1708,6 @@ cc_library( srcs = ["hlo_creation_utils.cc"], hdrs = [ "hlo_creation_utils.h", - "//tensorflow/compiler/xla:literal_util", ], deps = [ ":hlo", @@ -2374,6 +2373,7 @@ cc_library( ":hlo_pass_pipeline", ":hlo_verifier", ":tuple_simplifier", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -2855,7 +2855,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -2890,7 +2890,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -2953,7 +2953,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", ], @@ -4224,7 +4224,7 @@ cc_library( "//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:regexp_internal", + "//tensorflow/core/platform:regexp", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -4327,7 +4327,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/memory", ], ) @@ -4339,7 +4339,7 @@ tf_cc_test( ":stream_pool", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -4396,7 +4396,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//third_party/eigen3", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", @@ -4745,7 +4745,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", + "//tensorflow/core/platform:regexp", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", 
"@com_google_absl//absl/types:optional", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 4e7bd85e557..76b0236fcdd 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -4116,8 +4116,10 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { new_limits[i] -= low; } if (slice_in_padding) { - return ReplaceInstruction( - slice, MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape())); + HloInstruction* broadcast = + MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape()); + *(broadcast->mutable_shape()) = slice->shape(); + return ReplaceInstruction(slice, broadcast); } if (slice_undoes_pad && ReplaceInstructionIfSameShape(slice, pad_operand)) { return Status::OK(); @@ -4126,6 +4128,7 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { TF_ASSIGN_OR_RETURN(HloInstruction * new_slice, MakeSliceHlo(pad_operand, new_starts, new_limits, slice->slice_strides())); + *(new_slice->mutable_shape()) = slice->shape(); return ReplaceInstruction(slice, new_slice); } } @@ -4189,9 +4192,18 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { VLOG(3) << "Sink broadcast through slice"; VLOG(3) << "Original slice: " << slice->ToString(); VLOG(3) << "Original broadcast: " << broadcast->ToString(); - TF_ASSIGN_OR_RETURN(auto new_slice, - MakeSliceHlo(broadcast_operand, new_slice_starts, - new_slice_limits, new_slice_strides)); + auto new_slice_shape = broadcast_operand->shape(); + for (int64 i = 0; i < broadcast_operand->shape().rank(); ++i) { + int64 size_i = (new_slice_limits[i] - new_slice_starts[i] + + new_slice_strides[i] - 1) / + new_slice_strides[i]; + new_slice_shape.set_dimensions(i, size_i); + } + simplifier_->UpdateLayout(&new_slice_shape); + HloComputation* computation = broadcast_operand->parent(); + auto new_slice = computation->AddInstruction(HloInstruction::CreateSlice( + new_slice_shape, broadcast_operand, new_slice_starts, new_slice_limits, + new_slice_strides)); auto new_broadcast = HloInstruction::CreateBroadcast( slice->shape(), new_slice, broadcast->dimensions()); VLOG(3) << "New slice: " << slice->ToString(); @@ -4291,9 +4303,15 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice( VLOG(3) << "Original broadcast: " << operand->ToString(); HloInstruction* new_dynamic_slice = broadcast_operand; if (!new_slice_sizes.empty()) { - TF_ASSIGN_OR_RETURN( - new_dynamic_slice, - MakeDynamicSliceHlo(broadcast_operand, new_indices, new_slice_sizes)); + auto new_ds_shape = broadcast_operand->shape(); + for (int64 i = 0; i < broadcast_operand->shape().rank(); ++i) { + new_ds_shape.set_dimensions(i, new_slice_sizes[i]); + } + simplifier_->UpdateLayout(&new_ds_shape); + HloComputation* computation = broadcast_operand->parent(); + new_dynamic_slice = + computation->AddInstruction(HloInstruction::CreateDynamicSlice( + new_ds_shape, broadcast_operand, new_indices, new_slice_sizes)); } auto new_broadcast = HloInstruction::CreateBroadcast( dynamic_slice->shape(), new_dynamic_slice, operand->dimensions()); diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index cc501161ce9..19927ae1576 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -143,13 +143,10 @@ StatusOr> AllocationTracker::DeconstructTuple( // We only need to care 
about replica id 0 here, since the GlobalDataHandle is // the same for all buffers across replicas. const ShapedBuffer* shaped_buffer = replicated_buffers[0]; - if (!shaped_buffer->on_host_shape().IsTuple()) { + if (!shaped_buffer->on_device_shape().IsTuple()) { return InvalidArgument("global data handle %d is not a tuple", data.handle()); } - // If the on-host representation is a tuple, then the on-device one should be - // as well. - TF_RET_CHECK(shaped_buffer->on_device_shape().IsTuple()); if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) { return Unimplemented("Deconstructing nested tuples is not implemented."); @@ -160,7 +157,6 @@ StatusOr> AllocationTracker::DeconstructTuple( i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape()); ++i) { auto element_buffer = ShapedBuffer( - ShapeUtil::GetTupleElementShape(shaped_buffer->on_host_shape(), i), ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i), shaped_buffer->platform(), shaped_buffer->device_ordinal()); element_buffer.set_buffer(shaped_buffer->buffer(/*index=*/{i}), diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.h b/tensorflow/compiler/xla/service/collective_ops_utils.h index eb96d537fa8..4eaa9101cc4 100644 --- a/tensorflow/compiler/xla/service/collective_ops_utils.h +++ b/tensorflow/compiler/xla/service/collective_ops_utils.h @@ -82,22 +82,6 @@ struct RendezvousKey { collective_op_kind(collective_op_kind), op_id(op_id) {} - static RendezvousKey FromInstruction( - const RunId& run_id, std::vector global_devices, - int num_local_participants, const HloInstruction* instr) { - CollectiveOpKind collective_op_kind; - int64 op_id; - - std::tie(collective_op_kind, op_id) = - instr->channel_id().has_value() - ? std::make_pair(kCrossModule, instr->channel_id().value()) - : std::make_pair( - kCrossReplica, - static_cast(instr->GetModule()->unique_id())); - return RendezvousKey(run_id, std::move(global_devices), - num_local_participants, collective_op_kind, op_id); - } - template friend H AbslHashValue(H h, const RendezvousKey& k) { return H::combine(std::move(h), k.run_id, k.global_devices, diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index f03b27cdcc7..653f4555a77 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -28,14 +28,6 @@ namespace xla { /* static */ tensorflow::mutex Compiler::platform_compiler_mutex_( tensorflow::LINKER_INITIALIZED); -StatusOr< - std::tuple, std::unique_ptr>> -Compiler::RunHloPassesAndBufferAssignement( - std::unique_ptr module, se::StreamExecutor* executor, - se::DeviceMemoryAllocator* device_allocator) { - return Unimplemented("This compiler does not support this method"); -} - std::vector> Compiler::ComputeBackendConfigs(const HloInstruction& hlo, se::StreamExecutor* executor) const { diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 312a068ba65..253caac195c 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -188,7 +188,10 @@ class Compiler { std::tuple, std::unique_ptr>> RunHloPassesAndBufferAssignement(std::unique_ptr module, se::StreamExecutor* executor, - se::DeviceMemoryAllocator* device_allocator); + se::DeviceMemoryAllocator* device_allocator, + bool optimize) { + return Unimplemented("This compiler does not support this method"); + } // Compiles the HLO module for execution on a device given by the executor, // and 
returns an executable object or an error status. No HLO passes are diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.cc b/tensorflow/compiler/xla/service/conditional_code_motion.cc index d30186e0d89..855e75a76e0 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/call_graph.h" @@ -350,125 +351,130 @@ StatusOr ConvertSpecialMove(HloInstruction* conditional, return false; } - HloInstruction* old_root = - conditional->branch_computation(0)->root_instruction(); - if (old_root->opcode() != HloOpcode::kTuple) { - return false; - } else { - VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString(); - // Identify the gte using `index'. - auto find_gte = [](const HloInstruction* conditional_result, - int64 index) -> HloInstruction* { - for (HloInstruction* instr : conditional_result->users()) { - if (instr->opcode() != HloOpcode::kGetTupleElement) { - return nullptr; - } - if (instr->tuple_index() == index) { - return instr; - } - } - return nullptr; - }; - - // Captures tuple indices refering to converts to be rematerialized/hoisted. - absl::flat_hash_set kspecial_convert = FindSpecialConverts( - old_root, branch_count, conditional, is_layout_sensitive); - - // Exit if we cannot find any converts to be hoisted. - if (kspecial_convert.empty()) { + // Determining whether all branch roots are tuples + for (int branch_num = 0; branch_num < branch_count; ++branch_num) { + HloInstruction* branch_root = + conditional->branch_computation(branch_num)->root_instruction(); + if (branch_root->opcode() != HloOpcode::kTuple) { return false; } + } - TF_RETURN_IF_ERROR( - RestructureConditionalInstruction(conditional->parent(), conditional)); - - for (int branch = 0; branch < branch_count; branch++) { - old_root = conditional->branch_computation(branch)->root_instruction(); - absl::flat_hash_map map_inst_to_tuple_index; - std::vector new_operands(old_root->operand_count()); - absl::flat_hash_set to_hoist_set; - - for (int64 operand_num = 0; operand_num < old_root->operand_count(); - ++operand_num) { - map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] = - operand_num; + HloInstruction* old_root = + conditional->branch_computation(0)->root_instruction(); + VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString(); + // Identify the gte using `index'. + auto find_gte = [](const HloInstruction* conditional_result, + int64 index) -> HloInstruction* { + for (HloInstruction* instr : conditional_result->users()) { + if (instr->opcode() != HloOpcode::kGetTupleElement) { + return nullptr; } - for (int64 operand_num = 0; operand_num < old_root->operand_count(); - ++operand_num) { - HloInstruction* hoist = old_root->mutable_operand(operand_num); - if (!kspecial_convert.contains(operand_num)) { - new_operands[operand_num] = old_root->mutable_operand(operand_num); - continue; - } - - to_hoist_set.insert(hoist); - int64 new_tuple_count = old_root->operand_count(); - - // Replace the hoisted instr in the tuple with the operand/operands. - // We will replace at least one of the operands of the hoist at the - // tuple place; the rest will be added at the end. 
- bool inplace = true; - CHECK(!hoist->operands().empty()); - for (HloInstruction* prod : hoist->operands()) { - if (inplace) { - map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist]; - new_operands[map_inst_to_tuple_index[hoist]] = prod; - inplace = false; - } else { - map_inst_to_tuple_index[prod] = new_tuple_count++; - new_operands.push_back(prod); - } - } + if (instr->tuple_index() == index) { + return instr; } + } + return nullptr; + }; - // Create the new root instruction. - HloComputation* cur_branch = conditional->branch_computation(branch); - HloInstruction* new_branch_root = - cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands)); - // The shape can vary since the operands to convert are now - // being returned through the branches' root. - cur_branch->set_root_instruction(new_branch_root, true /*new shape*/); - TF_CHECK_OK(cur_branch->RemoveInstruction(old_root)); + // Captures tuple indices refering to converts to be rematerialized/hoisted. + absl::flat_hash_set kspecial_convert = FindSpecialConverts( + old_root, branch_count, conditional, is_layout_sensitive); - // Only one of the branches needs to change the conditional->parent(). - if (branch != 0) { + // Exit if we cannot find any converts to be hoisted. + if (kspecial_convert.empty()) { + return false; + } + + TF_RETURN_IF_ERROR( + RestructureConditionalInstruction(conditional->parent(), conditional)); + + for (int branch = 0; branch < branch_count; branch++) { + old_root = conditional->branch_computation(branch)->root_instruction(); + absl::flat_hash_map map_inst_to_tuple_index; + std::vector new_operands(old_root->operand_count()); + absl::flat_hash_set to_hoist_set; + + for (int64 operand_num = 0; operand_num < old_root->operand_count(); + ++operand_num) { + map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] = + operand_num; + } + for (int64 operand_num = 0; operand_num < old_root->operand_count(); + ++operand_num) { + HloInstruction* hoist = old_root->mutable_operand(operand_num); + if (!kspecial_convert.contains(operand_num)) { + new_operands[operand_num] = old_root->mutable_operand(operand_num); continue; } - HloComputation* conditional_parent = conditional->parent(); - HloInstruction* newconditional = - conditional_parent->AddInstruction(HloInstruction::CreateConditional( - cur_branch->root_instruction()->shape(), - conditional->mutable_operand(0), - absl::MakeSpan(conditional->branch_computations()), - absl::MakeSpan(conditional->operands()).subspan(1))); - // Ensure that all the users of conditional refer to the new one. - TF_RETURN_IF_ERROR( - conditional->ReplaceAllUsesWithDifferentShape(newconditional)); - TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional)); - conditional = newconditional; - // Add the hoisted instructions in the parent. - for (HloInstruction* hoist : to_hoist_set) { - VLOG(2) << "Hoisting instruction:" << hoist->ToString(); - int64 hoist_index = map_inst_to_tuple_index[hoist]; - // Find out the gte that captured the hoisted instr result. 
- HloInstruction* gte_hoist = find_gte(conditional, hoist_index); - CHECK(gte_hoist != nullptr); - std::vector new_operands; - for (HloInstruction* op : hoist->operands()) { - HloInstruction* gte = conditional_parent->AddInstruction( - HloInstruction::CreateGetTupleElement( - op->shape(), conditional, map_inst_to_tuple_index[op])); - new_operands.push_back(gte); + + to_hoist_set.insert(hoist); + int64 new_tuple_count = old_root->operand_count(); + + // Replace the hoisted instr in the tuple with the operand/operands. + // We will replace at least one of the operands of the hoist at the + // tuple place; the rest will be added at the end. + bool inplace = true; + CHECK(!hoist->operands().empty()); + for (HloInstruction* prod : hoist->operands()) { + if (inplace) { + map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist]; + new_operands[map_inst_to_tuple_index[hoist]] = prod; + inplace = false; + } else { + map_inst_to_tuple_index[prod] = new_tuple_count++; + new_operands.push_back(prod); } - HloInstruction* hoisted = conditional_parent->AddInstruction( - hoist->CloneWithNewOperands(hoist->shape(), new_operands)); - VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString(); - TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted)); - TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist)); } - // No need to explicitly delete a hoisted instruction since if its dead - // then the subsequent DCE will remove it. } + + // Create the new root instruction. + HloComputation* cur_branch = conditional->branch_computation(branch); + HloInstruction* new_branch_root = + cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands)); + // The shape can vary since the operands to convert are now + // being returned through the branches' root. + cur_branch->set_root_instruction(new_branch_root, true /*new shape*/); + TF_CHECK_OK(cur_branch->RemoveInstruction(old_root)); + + // Only one of the branches needs to change the conditional->parent(). + if (branch != 0) { + continue; + } + HloComputation* conditional_parent = conditional->parent(); + HloInstruction* newconditional = + conditional_parent->AddInstruction(HloInstruction::CreateConditional( + cur_branch->root_instruction()->shape(), + conditional->mutable_operand(0), + absl::MakeSpan(conditional->branch_computations()), + absl::MakeSpan(conditional->operands()).subspan(1))); + // Ensure that all the users of conditional refer to the new one. + TF_RETURN_IF_ERROR( + conditional->ReplaceAllUsesWithDifferentShape(newconditional)); + TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional)); + conditional = newconditional; + // Add the hoisted instructions in the parent. + for (HloInstruction* hoist : to_hoist_set) { + VLOG(2) << "Hoisting instruction:" << hoist->ToString(); + int64 hoist_index = map_inst_to_tuple_index[hoist]; + // Find out the gte that captured the hoisted instr result. 
+ HloInstruction* gte_hoist = find_gte(conditional, hoist_index); + CHECK(gte_hoist != nullptr); + std::vector new_operands; + for (HloInstruction* op : hoist->operands()) { + HloInstruction* gte = conditional_parent->AddInstruction( + HloInstruction::CreateGetTupleElement(op->shape(), conditional, + map_inst_to_tuple_index[op])); + new_operands.push_back(gte); + } + HloInstruction* hoisted = conditional_parent->AddInstruction( + hoist->CloneWithNewOperands(hoist->shape(), new_operands)); + VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString(); + TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted)); + TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist)); + } + // No need to explicitly delete a hoisted instruction since if its dead + // then the subsequent DCE will remove it. } VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString(); return true; @@ -498,7 +504,7 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( << conditional_parent->ToString(HloPrintOptions::Fingerprint()) << "\n"; int64 op_index = 0; - for (Boundary b : new_boundaries) { + for (const Boundary& b : new_boundaries) { HloInstruction* op = b.operands()[0]; CHECK(op != nullptr); VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n"; @@ -545,7 +551,7 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( for (int i = 0; i < branch_count; i++) { auto computation = conditional->branch_computation(i); std::vector elements; - for (auto b1 : new_boundaries) { + for (const auto& b1 : new_boundaries) { HloInstruction* op = b1.operands()[i]; CHECK(op != nullptr); VLOG(2) << "Adding to root " << i << " with " << op->ToString() << "\n"; @@ -556,7 +562,7 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( computation->set_root_instruction(tuple, true); VLOG(2) << "computation is :" << computation->ToString() << "\n"; // Remove hoisted instructions from the branches. - for (auto b2 : to_move_out) { + for (const auto& b2 : to_move_out) { auto instr_to_remove = b2.operands()[i]; // Double check to make sure it is safe to delete the instruction. // Complications may arise due to some operations in the alternative @@ -781,7 +787,7 @@ class GroupConnectedBoundaries { } } // Returns true if `instruction` is worth hoisting. - bool WorthHoisting(HloInstruction* instruction) { + bool WorthHoisting(HloInstruction* instruction, bool is_inside_branch) { // This is needed for the "moving-in" transformation, to prevent the root // of the parent computation (which contains the conditional) to be moved // inside the conditional. @@ -789,6 +795,8 @@ class GroupConnectedBoundaries { instruction == conditional_parent_->root_instruction()) { return false; } + // TODO[b/169182921] The following cost model is rather incomplete. Will + // need to extend to cover most of element-wise ops. switch (instruction->opcode()) { case HloOpcode::kConvert: // If Convert is after AllReduce, it is worth moving out AllReduce @@ -815,6 +823,11 @@ class GroupConnectedBoundaries { return true; } case HloOpcode::kAllReduce: + // It is not safe to move collective ops from outside to inside + // conditional branches, as it may cause synchronization problems + // when different layouts are assigned to different branches.
+ return is_inside_branch; + case HloOpcode::kAbs: case HloOpcode::kReduce: case HloOpcode::kAdd: case HloOpcode::kPower: @@ -999,7 +1012,8 @@ class GroupConnectedBoundaries { VLOG(2) << "visiting boundary " << b.ToString() << "\n"; if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical( b.operands(), is_layout_sensitive_)) && - IsSafeToMoveBoundary(b) && WorthHoisting(b.operands()[0])) { + IsSafeToMoveBoundary(b) && + WorthHoisting(b.operands()[0], b.IsInsideBranch())) { connected_boundaries_.push_back(b); VLOG(2) << "boundary can be moved\n"; int64 operand_count = (b.IsInsideBranch()) @@ -1087,6 +1101,13 @@ ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion( StatusOr ConditionalCodeMotion::Run(HloModule* module) { VLOG(2) << "Begin a new pass of conditional code motion optimization.\n"; + // Used to support debugging of the optimization, by disabling the opt after it has + // been applied a pre-determined number of times (to isolate the impact of transformations). + if (!ConsumeFuel("conditional_code_motion", [&] { + return "Skipping conditional opt after allowed limit reaching 0.\n"; + })) { + return false; + } bool changed = false; bool cleanup_changed = false; { @@ -1177,7 +1198,8 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { benefit_move_out += d.GetBenefit(); if (benefit_move_out >= benefit_move_in) { final_d = Decision::Direction::kMoveOutOfBranch; - VLOG(2) << "Current Decision is move out of branch\n"; + VLOG(2) << "Current Decision is move out of branch (" + << to_move_out.size() << ")\n"; } else { VLOG(2) << "Current Decision remains move into branch\n"; } @@ -1191,7 +1213,8 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { VLOG(2) << "Current Decision remains move out of branch\n"; } else { final_d = Decision::Direction::kMoveIntoBranch; - VLOG(2) << "Current Decision is move into branch (" + << to_move_in.size() << ")\n"; } break; case Decision::Direction::kNoChange: @@ -1260,6 +1283,13 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { new_boundaries_for_moveout[i])); changed |= result; } + VLOG(2) << "Done moving out of branches " << to_move_out.size() + << " times. \n"; + if (!ConsumeFuel("conditional_code_motion", [&] { + return "Skipping conditional opt after allowed limit reaching 0.\n"; + })) { + break; + } } else if (final_d == Decision::Direction::kMoveIntoBranch) { CHECK(to_move_in.size() == new_boundaries_for_movein.size()); for (int i = 0; i < to_move_in.size(); ++i) { @@ -1268,6 +1298,13 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { new_boundaries_for_movein[i])); changed |= result; } + VLOG(2) << "Done moving into branches " << to_move_in.size() + << " times. 
\n"; + if (!ConsumeFuel("conditional_code_motion", [&] { + return "Skipping conditional opt after allowed limit reaching 0.\n"; + })) { + break; + } } else if (pursue_full_conditional_code_motion_ && !conditional_is_shared) { // Invoke special handling for convert rematerialization/hoisting // We need to make sure no sharing is present in the branches because no @@ -1276,8 +1313,16 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { TF_ASSIGN_OR_RETURN( bool convert_result, ConvertSpecialMove(conditional, is_layout_sensitive_)); + if (convert_result) { + VLOG(2) << "Done special moving of convert\n"; + if (!ConsumeFuel("conditional_code_motion", [&] { + return "Skipping conditional opt after allowed limit reaching " + "0.\n"; + })) { + break; + } + } changed |= convert_result; - VLOG(2) << "Done special moving of convert\n"; } } if (changed) { diff --git a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc index 55ab613aebf..3b40acf54e3 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc @@ -78,6 +78,52 @@ ENTRY main { EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(), op::GetTupleElement()))); } +TEST_F(ConditionalCodeMotionTest, VerifyConditionalAnalysisWithWhileTuple) { + absl::string_view hlo_string = + R"( +HloModule RemoveDotOpOut + + body { + %p_body = (f32[2], bf16[2], s32[]) parameter(0) + %val = f32[2] get-tuple-element(p_body), index=0 + %val2 = bf16[2] get-tuple-element(p_body), index=1 + %const = s32[] constant(-1) + ROOT root = (f32[2], bf16[], s32[]) tuple(%val, %val2, %const) + } + + condition { + %p_cond = (f32[2], bf16[2], s32[]) parameter(0) + %gte = s32[] get-tuple-element(%p_cond), index=2 + %const = s32[] constant(42) + ROOT result = pred[] compare(%gte, %const), direction=EQ + } + + on_true { + %arg_tuple.1 = f32[2] parameter(0) + %const = s32[] constant(42) + %add.8493 = f32[2] add(f32[2] %arg_tuple.1, f32[2] %arg_tuple.1) + %convert.2894 = bf16[2] convert(f32[2] %add.8493) + ROOT %tuple.1 = (f32[2], bf16[2], s32[]) tuple(%add.8493, %convert.2894, %const) + } + on_false { + %arg_tuple.1 = f32[2] parameter(0) + %const = s32[] constant(42) + %add.8493 = f32[2] add(f32[2] %arg_tuple.1, f32[2] %arg_tuple.1) + %convert.2894 = bf16[2] convert(f32[2] %add.8493) + %while_init = (f32[2], bf16[2], s32[]) tuple(%add.8493, %convert.2894, %const) + ROOT while = (f32[2], bf16[2], s32[]) while(%while_init), condition=condition, body=body + } + ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.11 = f32[2] parameter(1) + ROOT conditional = (f32[2], bf16[2], s32[]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false + } +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); +} + TEST_F(ConditionalCodeMotionTest, MoveConvertOutConditionalRoot) { absl::string_view hlo_string = R"( @@ -576,6 +622,89 @@ ENTRY main { op::AllReduce(op::GetTupleElement(op::Conditional()))))))); } +TEST_F(ConditionalCodeMotionTest, DoNotMoveAllReduceIn) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +%add.64 (x.139: bf16[], y.139: bf16[]) -> bf16[] { + %x.139 = bf16[]{:T(512)} parameter(0) + %y.139 = bf16[]{:T(512)} parameter(1) + ROOT %add.44073 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.139, bf16[]{:T(512)} %y.139) +} + +%add.181 
(x.256: bf16[], y.256: bf16[]) -> bf16[] { + %x.256 = bf16[]{:T(512)} parameter(0) + %y.256 = bf16[]{:T(512)} parameter(1) + ROOT %add.44842 = bf16[]{:T(512)} add(bf16[]{:T(512)} %x.256, bf16[]{:T(512)} %y.256) +} + +on_true { + arg_tuple.1 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(0) + get-tuple-element.11 = bf16[2,54,168,128] get-tuple-element(arg_tuple.1), index=0 + get-tuple-element.12 = bf16[2,52,168,128] get-tuple-element(arg_tuple.1), index=1 + convolution.1 = bf16[3,3,128,128] convolution(bf16[2,54,168,128] + get-tuple-element.11, bf16[2,52,168,128] + get-tuple-element.12), window={size=52x168 pad=0_0x1_1}, + dim_labels=f01b_i01o->01bf + add.1 = bf16[3,3,128,128] add(bf16[3,3,128,128] convolution.1, bf16[3,3,128,128] convolution.1) + ROOT tuple.1 = (bf16[3,3,128,128]) tuple(add.1) +} + +on_false { + arg_tuple.2 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(0) + get-tuple-element.21 = bf16[2,86,104,128] + get-tuple-element(arg_tuple.2), index=0 + get-tuple-element.22 = bf16[2,84,104,128] + get-tuple-element(arg_tuple.2), index=1 + convolution.2 = bf16[3,3,128,128] + convolution(bf16[2,86,104,128] get-tuple-element.21, bf16[2,84,104,128] + get-tuple-element.22), window={size=84x104 pad=0_0x1_1}, + dim_labels=f01b_i01o->01bf + add.2 = bf16[3,3,128,128] add(bf16[3,3,128,128] convolution.2, bf16[3,3,128,128] convolution.2) + ROOT tuple.2 = (bf16[3,3,128,128]) tuple(add.2) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.3 = (bf16[2,54,168,128], bf16[2,52,168,128]) parameter(1) + arg_tuple.4 = (bf16[2,86,104,128], bf16[2,84,104,128]) parameter(2) + arg_tuple.5 = f32[3,3,128,128] parameter(3) + conditional = (bf16[3,3,128,128]) + conditional(pred.1, arg_tuple.3, arg_tuple.4), true_computation=on_true, + false_computation=on_false + get-first-index = bf16[3,3,128,128] get-tuple-element(conditional), index=0 + all-reduce.2 = bf16[3,3,128,128] + all-reduce(bf16[3,3,128,128] %get-first-index), + channel_id=485, replica_groups={{0,1}}, use_global_device_ids=true, + to_apply=%add.181, metadata={op_type="Conv2DBackpropFilter" + op_name="gradients/resnet50/conv2d_22/Conv2D_grad/Conv2DBackpropFilter"} + convert.2 = f32[3,3,128,128] + convert(bf16[3,3,128,128] %all-reduce.2), + metadata={op_type="Cast" op_name="Cast_15"} + ROOT result = (f32[3,3,128,128]) tuple(convert.2) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + CHECK(conditional != nullptr); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 6); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 6); + + // Checks if conditional shape has changed. 
+ ASSERT_TRUE(ShapeUtil::Compatible( + conditional->shape(), ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape( + BF16, {3, 3, 128, 128})}))); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(op::AllReduce( + op::GetTupleElement(op::Conditional())))))); +} + TEST_F(ConditionalCodeMotionTest, MovePowOpIn) { absl::string_view hlo_string = R"( diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index ddd8e19716a..e313dbe2415 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -876,30 +876,13 @@ class CopyRemover { // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not // updated as copies are removed. bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) { - VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value; - bool is_live_range_before = [&] { - if (a.uses.empty()) { - VLOG(2) << "Empty uses for " << *a.value; - return ordering_.IsDefinedBefore(*a.value, *b.value); - } - for (const HloUse* use : a.uses) { - VLOG(3) << "Checking use " << *use << " against " << *b.value; - if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) { - VLOG(2) << "Use " << *use << " is NOT before " << *b.value; - return false; - } - VLOG(3) << "Use " << *use << " is before " << *b.value; - } - return true; - }(); - if (is_live_range_before) { - VLOG(2) << " Live range of " << a.value->ToShortString() << " is before " - << b.value->ToShortString(); - } else { - VLOG(2) << " Live range of " << a.value->ToShortString() - << " is not before " << b.value->ToShortString(); + if (a.uses.empty()) { + VLOG(2) << "Empty uses for " << *a.value; + return ordering_.IsDefinedBefore(*a.value, *b.value); } - return is_live_range_before; + return absl::c_all_of(a.uses, [&](const HloUse* use) { + return ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_); + }); } // Returns whether 'node' is the last node in its list. 
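Note on the LiveRangeBefore rewrite in copy_insertion.cc above: absl::c_all_of is the container-based form of std::all_of, which is why the hand-written loop with early returns and VLOG statements collapses into a single predicate once the logging is dropped; short-circuiting on the first failing use is preserved. A minimal, self-contained sketch of the same pattern, using a hypothetical Use struct and integer positions rather than the real HloUse/HloValue/HloOrdering types:

#include <vector>
#include "absl/algorithm/container.h"

struct Use { int position; };  // hypothetical stand-in for HloUse

// Returns true only if every recorded use happens before the given
// definition point, mirroring how the rewritten LiveRangeBefore delegates
// the per-use check to a lambda passed to absl::c_all_of.
bool AllUsesBefore(const std::vector<const Use*>& uses, int def_position) {
  return absl::c_all_of(uses, [&](const Use* use) {
    return use->position < def_position;
  });
}
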
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index bdddae43462..80c5513bb9d 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -91,7 +91,7 @@ cc_library( "//tensorflow/compiler/xla/service:generic_transfer_manager", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor", "@com_google_absl//absl/base", "@com_google_absl//absl/memory", @@ -204,7 +204,7 @@ cc_library( "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@llvm-project//llvm:Core", "@llvm-project//llvm:Object", "@llvm-project//llvm:Support", @@ -321,11 +321,11 @@ cc_library( "//tensorflow/compiler/xla/service:maybe_owning_device_memory", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform:logging", "//tensorflow/core/platform:macros", "//tensorflow/core/platform:mutex", "//tensorflow/core/platform:platform_port", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/platform:types", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor:device_memory_allocator", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 0260b3926c7..1ffafd37a27 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -562,9 +562,11 @@ StatusOr< std::tuple, std::unique_ptr>> CpuCompiler::RunHloPassesAndBufferAssignement( std::unique_ptr module, se::StreamExecutor* executor, - se::DeviceMemoryAllocator* device_allocator) { - TF_ASSIGN_OR_RETURN( - module, RunHloPasses(std::move(module), executor, device_allocator)); + se::DeviceMemoryAllocator* device_allocator, bool optimize) { + if (optimize) { + TF_ASSIGN_OR_RETURN( + module, RunHloPasses(std::move(module), executor, device_allocator)); + } // Select an order for emitting the HLO instructions for each computation. 
// Using this sequence enables tighter buffer liveness analysis and reduced diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index d28ccd985a3..5c056fcacaa 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -138,9 +138,10 @@ class CpuCompiler : public LLVMCompiler { StatusOr< std::tuple, std::unique_ptr>> - RunHloPassesAndBufferAssignement( - std::unique_ptr module, se::StreamExecutor* executor, - se::DeviceMemoryAllocator* device_allocator) override; + RunHloPassesAndBufferAssignement(std::unique_ptr module, + se::StreamExecutor* executor, + se::DeviceMemoryAllocator* device_allocator, + bool optimize) override; StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 7431e829b8e..02bc445ce9a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -210,8 +210,7 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( absl::Span buffers, absl::Span arguments) { se::Stream* stream = run_options->stream(); - ExecutionOutput result(/*on_host_shape=*/result_shape(), - /*on_device_shape=*/result_shape(), + ExecutionOutput result(/*on_device_shape=*/result_shape(), run_options->allocator(), stream->parent()->device_ordinal()); const HloInputOutputAliasConfig& input_output_alias = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index e21ed7ad60e..bfd8e9e111a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -250,9 +250,9 @@ StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( size); } - if (size <= 0) { - return InvalidArgument("Outfeed shape must have positive size; got %d", - size); + if (size < 0) { + return InvalidArgument( + "Outfeed shape must have non-negative size; got %d", size); } int32 size_32 = static_cast(size); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 36566d6c25f..54822323137 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -449,7 +449,7 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value* program_buffer_address) { int64 length = ByteSizeOf(shape); - if (length <= 0 || length > std::numeric_limits::max()) { + if (length < 0 || length > std::numeric_limits::max()) { return InvalidArgument( "xfeed (infeed or outfeed) buffer length %d is outside the valid " "size range", diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index ffbd0d68ce9..23f5a5c434f 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -31,9 +31,15 @@ ParallelLoopEmitter::ParallelLoopEmitter( std::vector ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, - llvm::Type* index_type) { + llvm::Type* index_type, + llvm::Value* base_index) { CHECK_NE(index_type, nullptr); + CHECK_EQ(base_index, nullptr) + << "XLA CPU implementation 
of" + << " ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock doesn't support" + << " base_index, but it was requested."; + CHECK(!shape_.IsTuple()); CHECK(!ShapeUtil::IsScalar(shape_)); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h index a604e1db222..a11fd44f1ce 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -61,7 +61,8 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - absl::string_view loop_name, llvm::Type* index_type) override; + absl::string_view loop_name, llvm::Type* index_type, + llvm::Value* base_index) override; private: const DynamicLoopBounds* dynamic_loop_bounds_; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index b2ed9bd5f31..f6925ce5c80 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -56,6 +56,34 @@ CHECK: private unnamed_addr constant [48 x i8] /*match_optimized_ir=*/false); } +TEST_F(CpuOutfeedTest, OutfeedEmpty) { + const string hlo_text = R"( +HloModule Outfeed + +ENTRY main { + const_a = f32[2,0] constant({{}, {}}) + token0 = token[] after-all() + outfeed = token[] outfeed(f32[2,0] const_a, token0) + ROOT root = () tuple() +} +)"; + + string filecheck_pattern = R"( +CHECK: private unnamed_addr constant [0 x i8] +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/kTargetTripleForHost, /*cpu_name=*/kTargetCpuForHost, + /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/false); +} + TEST_F(CpuOutfeedTest, OutfeedTokenInTuple) { const string hlo_text = R"( HloModule OutfeedTokenInTuple diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc index e36eff09009..e364c0f1b42 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc @@ -128,6 +128,22 @@ TEST_F(InfeedManagerTest, MultiThreaded) { ProcessNextBuffer(length); } +TEST_F(InfeedManagerTest, OutfeedBasic) { + TestInfeedBuffer* b = new TestInfeedBuffer(32, /*expect_shape_match=*/true); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); + xfeed->outfeed()->EnqueueBuffersAtomically({b}); + + ProcessNextOutfeedBuffer(32, ShapeUtil::MakeShape(U8, {32})); +} + +TEST_F(InfeedManagerTest, OutfeedEmpty) { + TestInfeedBuffer* b = new TestInfeedBuffer(0, /*expect_shape_match=*/true); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); + xfeed->outfeed()->EnqueueBuffersAtomically({b}); + + ProcessNextOutfeedBuffer(0, ShapeUtil::MakeShape(U8, {0})); +} + TEST_F(InfeedManagerTest, OutfeedWrongShape) { TestInfeedBuffer* b = new TestInfeedBuffer(32, /*expect_shape_match=*/false); cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index d5cf2ee9ac0..e9a3c6b3018 100644 --- 
a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -59,11 +59,11 @@ void ExecutionInput::SetUnownedBuffer(const ShapeIndex& index, unowned_indices_.insert(index); } -xla::StatusOr ExecutionInput::ToShapedBuffer( +StatusOr ExecutionInput::ToShapedBuffer( se::DeviceMemoryAllocator* allocator, int device_ordinal) const { const Shape& input_shape = shape(); - xla::ShapedBuffer shaped_buffer(input_shape, input_shape, - allocator->platform(), device_ordinal); + ShapedBuffer shaped_buffer(input_shape, allocator->platform(), + device_ordinal); for (const auto& index_buffer : Buffers()) { const tensorflow::se::OwningDeviceMemory* mem = index_buffer.second.AsOwningDeviceMemory(); @@ -93,8 +93,7 @@ StatusOr Executable::ExecuteOnStream( static ExecutionInput MakeMaybeOwningDeviceMemoryTree( const ShapedBuffer& shaped_buffer) { - ExecutionInput result(shaped_buffer.on_device_shape(), - shaped_buffer.on_host_shape()); + ExecutionInput result(shaped_buffer.on_device_shape()); shaped_buffer.buffers().ForEachElement( [&](const ShapeIndex& index, const se::DeviceMemoryBase& mem) { result.SetBuffer(index, MaybeOwningDeviceMemory(mem)); diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 2e3ddedfb8c..1e1b3436a3c 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -60,15 +60,24 @@ namespace xla { // with their indices absent from unowned_indices_. class ExecutionInput { public: - explicit ExecutionInput(xla::Shape shape, xla::Shape host_shape) + explicit ExecutionInput(xla::Shape shape) : buffers_(std::move(shape)) { + SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape())); + } + // TODO(b/170310047): remove this overload. + ExecutionInput(xla::Shape shape, xla::Shape host_shape) : buffers_(std::move(shape)) { - SetHostShape(std::move(host_shape)); + SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape())); } - explicit ExecutionInput(ShapeTree buffers, - xla::Shape host_shape) + explicit ExecutionInput(ShapeTree buffers) : buffers_(std::move(buffers)) { - SetHostShape(std::move(host_shape)); + SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape())); + } + // TODO(b/170310047): remove this overload. + ExecutionInput(ShapeTree buffers, + xla::Shape host_shape) + : buffers_(std::move(buffers)) { + SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape())); } ExecutionInput(ExecutionInput&&) = default; @@ -144,10 +153,13 @@ class ExecutionOutput { std::vector to_be_released) : result_(std::move(result)), to_be_released_(std::move(to_be_released)) {} + // TODO(b/170310047): remove this overload. 
ExecutionOutput(Shape on_host_shape, Shape on_device_shape, se::DeviceMemoryAllocator* allocator, int device_ordinal) - : result_(std::move(on_host_shape), std::move(on_device_shape), allocator, - device_ordinal) {} + : result_(std::move(on_device_shape), allocator, device_ordinal) {} + ExecutionOutput(Shape on_device_shape, se::DeviceMemoryAllocator* allocator, + int device_ordinal) + : result_(std::move(on_device_shape), allocator, device_ordinal) {} ExecutionOutput(ExecutionOutput&&) = default; ExecutionOutput& operator=(ExecutionOutput&&) = default; diff --git a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc index ab6a3d01d21..17d3fb2b3d6 100644 --- a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc +++ b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc @@ -25,24 +25,33 @@ limitations under the License. namespace xla { FusionNodeIndexingEvaluation::FusionNodeIndexingEvaluation( - const HloInstruction* fusion) + const HloInstruction* fusion, int64 root_usage_count) : fusion_(fusion) { HloInstruction* root = fusion->fused_expression_root(); indexing_users_[root].insert(fusion); - index_usage_count_[fusion] = 1; + index_usage_count_[fusion] = root_usage_count; RecomputeCache(); } +// This constant is arbitrarily chosen. Essentially we don't want to have too +// much code duplication, because it slows down the compilation time. There is +// a tradeoff between compilation time and runtime here. +const int64 FusionNodeIndexingEvaluation::kAllowedCodeDuplication = 15; + bool FusionNodeIndexingEvaluation::CodeDuplicationTooHigh( const HloInstruction* producer) const { - // This constant is arbitrarily chosen. Essentially we don't want to have too - // much code duplication, because it slows down the compilation time. There is - // a tradeoff between compilation time and runtime here. - const int64 kAllowedCodeDuplication = 15; - return EvaluateEmittedInstructions(producer) > kAllowedCodeDuplication; } +bool FusionNodeIndexingEvaluation::MaxCodeDuplicationTooHigh() const { + for (const auto& entry : index_usage_count_) { + if (entry.second > kAllowedCodeDuplication) { + return true; + } + } + return false; +} + int64 FusionNodeIndexingEvaluation::EvaluateEmittedInstructions( const HloInstruction* producer) const { int64 total = 0; diff --git a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h index b85bf9104c7..abe154a5149 100644 --- a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h +++ b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h @@ -24,13 +24,19 @@ limitations under the License. namespace xla { class FusionNodeIndexingEvaluation { public: - explicit FusionNodeIndexingEvaluation(const HloInstruction* fusion); + explicit FusionNodeIndexingEvaluation(const HloInstruction* fusion, + int64 root_usage_count = 1); // Evaluate the number of times 'producer' would be emitted if it is fused // into 'fusion_'. If the duplication is "too high" (some arbitrary chosen // constant), returns true. bool CodeDuplicationTooHigh(const HloInstruction* producer) const; + // Evaluate the maximum code duplication inside the fusion node. If the + // maximum code duplication is "too high" (some arbitrary chosen constant), + // returns true. 
+ bool MaxCodeDuplicationTooHigh() const; + // Evaluate the number of times 'producer' would be emitted if it is fused // into 'fusion_'. int64 EvaluateEmittedInstructions(const HloInstruction* producer) const; @@ -53,6 +59,8 @@ class FusionNodeIndexingEvaluation { HloInstruction* fusion_operand); private: + static const int64 kAllowedCodeDuplication; + // Computes the 'indexing_users_' and 'index_usage_count_' maps based on the // current instructions inside the fusion node. Also updates // 'total_emitted_instructions_' accordingly. diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index c09757fe1af..d2febb5fb73 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -69,13 +69,8 @@ void GenericTransferManager::TransferLiteralFromDevice( TF_RET_CHECK(stream->parent()->device_ordinal() == device_buffer.device_ordinal()); - // The on-host and on-device shape should always be the same for the generic - // transfer manager. - TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(), - device_buffer.on_host_shape())); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - device_buffer.on_host_shape(), + device_buffer.on_device_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { if (subshape.IsArray()) { stream->ThenMemcpy( @@ -103,20 +98,15 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( << ShapeUtil::HumanString(shape) << "; device buffer: " << device_buffer; - // The on-host and on-device shape should always be the same for the generic - // transfer manager. - TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(), - device_buffer.on_host_shape())); - TF_RET_CHECK( - ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape())); + ShapeUtil::Compatible(literal.shape(), device_buffer.on_device_shape())); TF_RET_CHECK(stream->parent()->device_ordinal() == device_buffer.device_ordinal()); TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer)); return ShapeUtil::ForEachSubshapeWithStatus( - device_buffer.on_host_shape(), + device_buffer.on_device_shape(), [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { se::DeviceMemoryBase device_memory = device_buffer.buffer(index); if (device_subshape.IsArray()) { diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 57bd6defabb..b2ec656a2ba 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -216,7 +216,9 @@ cc_library( deps = [ ":backend_configs_cc", ":buffer_allocations", + ":cudnn_batchnorm_runner", ":gpu_constants", + ":gpu_conv_runner", ":gpu_executable", ":ir_emission_utils", ":nccl_all_reduce_thunk", @@ -263,6 +265,7 @@ cc_library( ":target_util", ":thunk", ":thunk_emitter", + "//tensorflow/compiler/mlir:name_utils", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:lhlo", "//tensorflow/compiler/mlir/xla:hlo_utils", @@ -322,6 +325,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", @@ -373,7 +377,7 @@ cc_library( 
"//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", @@ -392,7 +396,7 @@ cc_library( "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/memory", ], ) @@ -408,7 +412,7 @@ cc_library( "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -458,7 +462,8 @@ tf_cuda_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", + "//tensorflow/compiler/xla:xla_data_proto_cc", ] + if_cuda([ "//tensorflow/stream_executor/cuda:cuda_activation", "//tensorflow/stream_executor/cuda:cuda_gpu_executor", @@ -597,7 +602,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:scoped_annotation", "//tensorflow/stream_executor", @@ -643,7 +648,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@llvm-project//llvm:Core", @@ -687,7 +692,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core/protobuf:autotuning_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/util/proto:proto_utils", "//tensorflow/stream_executor:blas", "//tensorflow/stream_executor:device_memory", @@ -724,7 +729,7 @@ cc_library( "//tensorflow/core/protobuf:autotuning_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/util/proto:proto_utils", "//tensorflow/stream_executor:device_memory_allocator", ] + if_cuda_is_configured([ @@ -749,7 +754,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", + "//tensorflow/stream_executor:dnn", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], @@ -771,7 +777,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], 
@@ -826,7 +832,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:blas", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", @@ -849,7 +855,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:blas", "//tensorflow/stream_executor:device_memory_allocator", ]), @@ -1151,7 +1157,7 @@ cc_library( "//tensorflow/compiler/xla/service:generic_transfer_manager", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/memory", "@llvm-project//llvm:Core", ], @@ -1231,6 +1237,7 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:logistic_expander", "//tensorflow/compiler/xla/service:qr_expander", + "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:rng_bit_generator_expander", "//tensorflow/compiler/xla/service:rng_expander", "//tensorflow/compiler/xla/service:slice_sinker", @@ -1247,8 +1254,8 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:regexp_internal", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:regexp", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor:stream_executor_headers", "@com_google_absl//absl/memory", @@ -1311,7 +1318,7 @@ cc_library( "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/core:cuda_libdevice_path", + "//tensorflow/core/platform:cuda_libdevice_path", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/profiler/lib:traceme", @@ -1397,7 +1404,7 @@ cc_library( ":xfeed_queue", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:types", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", ], @@ -1435,7 +1442,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -1522,6 +1529,7 @@ cc_library( hdrs = ["stream_executor_util.h"], copts = tf_copts(), deps = [ + ":ir_emission_utils", ":launch_dimensions", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -1530,11 +1538,11 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_module_config", - "//tensorflow/core:cuda_libdevice_path", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:regexp_internal", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:cuda_libdevice_path", + "//tensorflow/core/platform:regexp", + 
"//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor:kernel_spec", "//tensorflow/stream_executor/gpu:gpu_asm_opts", @@ -1557,7 +1565,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo_module_config", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:stream_executor_headers", "//tensorflow/stream_executor/gpu:asm_compiler", ]), @@ -1616,7 +1624,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:pattern_matcher", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -1710,7 +1718,7 @@ cc_library( deps = [ ":gpu_autotuning_proto_cc", "//tensorflow/compiler/xla:debug_options_flags", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/protobuf:autotuning_proto_cc", "@com_google_absl//absl/container:flat_hash_map", ], diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc index 10a565308de..21b4ef40d97 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -610,6 +610,11 @@ static StatusOr DeviceCompare(se::Stream* stream, executor->GetDeviceDescription().threads_per_block_limit(); gpu_device_info.threads_per_warp = executor->GetDeviceDescription().threads_per_warp(); + gpu_device_info.shared_memory_per_block = + executor->GetDeviceDescription().shared_memory_per_block(); + gpu_device_info.threads_per_core_limit = + executor->GetDeviceDescription().threads_per_core_limit(); + gpu_device_info.core_count = executor->GetDeviceDescription().core_count(); LaunchDimensions dim = CalculateLaunchDimensions(buffer_shape, gpu_device_info); diff --git a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc index c34c299fea8..4ac5784e51a 100644 --- a/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cholesky_thunk.cc @@ -45,10 +45,7 @@ CholeskyThunk::CholeskyThunk(ThunkInfo thunk_info, info_buffer_(info_buffer), type_(type), batch_size_(batch_size), - a_batch_stride_( - n * n * - ShapeUtil::ByteSizeOfPrimitiveType( - thunk_info.hlo_instruction->operand(0)->shape().element_type())), + a_batch_stride_(n * n * ShapeUtil::ByteSizeOfPrimitiveType(type)), n_(n) {} Status CholeskyThunk::ExecuteOnStream(const ExecuteParams& params) { diff --git a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc index b3b5cf7e048..88982d3c034 100644 --- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/refcounting_hash_map.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/blocking_counter.h" @@ -217,16 +218,23 @@ RefcountingHashMap& GlobalRendezvousMap() { } // anonymous namespace +CollectivePermuteConfig GetCollectivePermuteConfig( + const HloInstruction* instr) { + CollectivePermuteConfig config; + auto* collective_permute = Cast(instr); + config.source_target_pairs = collective_permute->source_target_pairs(); + return config; +} + CollectivePermuteThunk::CollectivePermuteThunk( - ThunkInfo thunk_info, const BufferAllocation::Slice& src, - const BufferAllocation::Slice& dest) + ThunkInfo thunk_info, CollectivePermuteConfig&& config, + const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest) : Thunk(kCollectivePermute, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), + config_(std::move(config)), src_(src), dest_(dest) {} Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) { - auto* instr = Cast(hlo_instruction_); auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); @@ -245,7 +253,7 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) { // Figure out which replicas our data is copied to. std::vector dest_replicas; - for (const auto& src_dest : instr->source_target_pairs()) { + for (const auto& src_dest : config_.source_target_pairs) { if (src_dest.first == replica_id) { dest_replicas.push_back(src_dest.second); } @@ -260,7 +268,7 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) { // If no replica writes into us (i.e. we aren't the target of any copies), our // contract is that we zero our output. - if (absl::c_none_of(instr->source_target_pairs(), + if (absl::c_none_of(config_.source_target_pairs, [&](std::pair src_dest) { return src_dest.second == replica_id; })) { diff --git a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h index 44cc6a1c64e..bef86eec9af 100644 --- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h @@ -19,23 +19,30 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { namespace gpu { +struct CollectivePermuteConfig { + std::vector> source_target_pairs; +}; + +CollectivePermuteConfig GetCollectivePermuteConfig(const HloInstruction* instr); + // Thunk that implements the collective-permute HLO. 
class CollectivePermuteThunk : public Thunk { public: - CollectivePermuteThunk(ThunkInfo thunk_info, + CollectivePermuteThunk(ThunkInfo thunk_info, CollectivePermuteConfig&& config, const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest); Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; - BufferAllocation::Slice src_; - BufferAllocation::Slice dest_; + const CollectivePermuteConfig config_; + const BufferAllocation::Slice src_; + const BufferAllocation::Slice dest_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index 4cff48a89da..6560c1a819c 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -17,43 +17,51 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { namespace gpu { -ConditionalThunk::ConditionalThunk( - ThunkInfo thunk_info, - const BufferAllocation::Slice& branch_index_buffer_index, - absl::Span branch_operand_buffer_indexes, - std::vector branch_thunk_sequences) - : Thunk(Kind::kConditional, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), - branch_index_is_bool_( - thunk_info.hlo_instruction->operand(0)->shape().element_type() == - PRED), - branch_index_buffer_index_(branch_index_buffer_index), - branch_operand_buffer_indexes_(branch_operand_buffer_indexes.begin(), - branch_operand_buffer_indexes.end()) { - // Pass nullptr as the HloInstruction* to the branch_thunks_ +ConditionalThunkConfig GetConditionalThunkConfig( + const HloInstruction* instr, + std::vector&& branch_thunk_sequences, + std::vector>&& branch_profile_indices) { + ConditionalThunkConfig config; + config.branch_index_is_bool = + instr->operand(0)->shape().element_type() == PRED; + config.branch_count = instr->branch_count(); + // Pass nullptr as the HloInstruction* to the branch_thunks // constructors because these SequentialThunks are logically "part of" // this ConditionalThunk, and shouldn't be profiled separately from it. 
- branch_thunks_.reserve(branch_thunk_sequences.size()); + config.branch_thunks.reserve(branch_thunk_sequences.size()); for (auto& branch_thunk_sequence : branch_thunk_sequences) { - branch_thunks_.emplace_back( - new SequentialThunk(ThunkInfo(), std::move(branch_thunk_sequence))); + config.branch_thunks.emplace_back(new SequentialThunk( + Thunk::ThunkInfo(), std::move(branch_thunk_sequence))); } + config.branch_profile_indices = std::move(branch_profile_indices); + return config; } +ConditionalThunk::ConditionalThunk( + ThunkInfo thunk_info, ConditionalThunkConfig&& config, + const BufferAllocation::Slice& branch_index_buffer_index, + absl::Span branch_operand_buffer_indexes) + : Thunk(Kind::kConditional, thunk_info), + config_(std::move(config)), + branch_index_buffer_index_(branch_index_buffer_index), + branch_operand_buffer_indexes_(branch_operand_buffer_indexes.begin(), + branch_operand_buffer_indexes.end()) {} + Status ConditionalThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { - if (branch_index_is_bool_) { - TF_RET_CHECK(branch_thunks_.size() == 2); + if (config_.branch_index_is_bool) { + TF_RET_CHECK(config_.branch_thunks.size() == 2); } else { - TF_RET_CHECK(!branch_thunks_.empty()); + TF_RET_CHECK(!config_.branch_thunks.empty()); } - for (auto& branch_thunk : branch_thunks_) { + for (auto& branch_thunk : config_.branch_thunks) { TF_RETURN_IF_ERROR(branch_thunk->Initialize(executable, executor)); } return Status::OK(); @@ -69,7 +77,7 @@ Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) { bool pred = false; se::DeviceMemoryBase branch_index_address = params.buffer_allocations->GetDeviceAddress(branch_index_buffer_index_); - if (branch_index_is_bool_) { + if (config_.branch_index_is_bool) { stream.ThenMemcpy(&pred, branch_index_address, sizeof(bool)); } else { stream.ThenMemcpy(&branch_index, branch_index_address, sizeof(int32)); @@ -81,20 +89,20 @@ Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) { "Failed to retrieve branch_index value on stream %p: %s.", &stream, block_status.error_message()); } - if (branch_index_is_bool_) { + if (config_.branch_index_is_bool) { branch_index = pred ? 0 : 1; } else { // Handle default scenario for branch_index not in [0, num_branches). - if (branch_index < 0 || branch_index >= hlo_instruction_->branch_count()) { - branch_index = hlo_instruction_->branch_count() - 1; + if (branch_index < 0 || branch_index >= config_.branch_count) { + branch_index = config_.branch_count - 1; } } // Execute the branch computation corresponding to the value of branch_index. profiler.StartHloComputation(); - TF_RETURN_IF_ERROR(branch_thunks_[branch_index]->ExecuteOnStream(params)); - profiler.FinishHloComputation( - hlo_instruction_->branch_computation(branch_index)); + TF_RETURN_IF_ERROR( + config_.branch_thunks[branch_index]->ExecuteOnStream(params)); + profiler.FinishHloComputation(config_.branch_profile_indices[branch_index]); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h index f91f1c52146..bf4280cdb12 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h @@ -30,6 +30,18 @@ limitations under the License. 
namespace xla { namespace gpu { +struct ConditionalThunkConfig { + bool branch_index_is_bool; + int64 branch_count; + std::vector> branch_thunks; + std::vector> branch_profile_indices; +}; + +ConditionalThunkConfig GetConditionalThunkConfig( + const HloInstruction* instr, + std::vector&& branch_thunk_sequences, + std::vector>&& branch_profile_indices); + // ConditionalThunk implements the conditional instruction on GPU by reading the // predicate of the conditional and executing the true or the false computation // depending on the value of the predicate. @@ -43,10 +55,9 @@ namespace gpu { class ConditionalThunk : public Thunk { public: ConditionalThunk( - ThunkInfo thunk_info, + ThunkInfo thunk_info, ConditionalThunkConfig&& config, const BufferAllocation::Slice& branch_index_buffer_index, - absl::Span branch_operand_buffer_indexes, - std::vector branch_thunk_sequences); + absl::Span branch_operand_buffer_indexes); ConditionalThunk(const ConditionalThunk&) = delete; ConditionalThunk& operator=(const ConditionalThunk&) = delete; @@ -56,11 +67,9 @@ class ConditionalThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; - const bool branch_index_is_bool_; + const ConditionalThunkConfig config_; BufferAllocation::Slice branch_index_buffer_index_; std::vector branch_operand_buffer_indexes_; - std::vector> branch_thunks_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 3048db95c39..efa3a5802d6 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -31,15 +31,16 @@ namespace xla { namespace gpu { ConvolutionThunk::ConvolutionThunk( - ThunkInfo thunk_info, std::vector operand_slices, + ThunkInfo thunk_info, GpuConvConfig&& config, + std::vector operand_slices, BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice, BufferAllocation::Slice tuple_result_slice) : Thunk(Kind::kConvolution, thunk_info), - cudnn_call_(Cast(thunk_info.hlo_instruction)), operand_buffers_(std::move(operand_slices)), result_buffer_(result_slice), scratch_buffer_(scratch_slice), - tuple_result_buffer_(tuple_result_slice) {} + tuple_result_buffer_(tuple_result_slice), + config_(std::move(config)) {} Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) { const auto& buffer_allocations = *params.buffer_allocations; @@ -57,7 +58,7 @@ Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) { auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); - TF_RETURN_IF_ERROR(RunGpuConv(cudnn_call_, absl::MakeSpan(operand_se_buffers), + TF_RETURN_IF_ERROR(RunGpuConv(config_, absl::MakeSpan(operand_se_buffers), result_buffer, scratch, params.stream)); // Write the output tuple. diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 03fae88c6dc..7f8377ebe4c 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -43,7 +43,7 @@ class ConvolutionThunk : public Thunk { // write a tuple (result, scratch_memory) into `tuple_result_buffer`. // // operand_slices should be in the same order as cudnn_call->operands(). 
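The convolution thunk follows the same config pattern: the cuDNN custom call is parsed once into a GpuConvConfig, and ExecuteOnStream only forwards that config to RunGpuConv. A minimal construction sketch, assuming GetGpuConvConfig returns StatusOr<GpuConvConfig> as in the algorithm-picker changes further down, and that custom_call, the slices, and the emitter helpers AddThunkToThunkSequence/GetThunkInfo are already available (names are assumptions):

  // Parse the cudnn custom call into a config once, then hand it to the thunk.
  TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(custom_call));
  AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
      GetThunkInfo(custom_call), std::move(config), std::move(operand_slices),
      result_slice, scratch_slice, tuple_result_slice));

The constructor change below is what threads the config through.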
- ConvolutionThunk(ThunkInfo thunk_info, + ConvolutionThunk(ThunkInfo thunk_info, GpuConvConfig&& config, std::vector operand_slices, BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice, @@ -55,11 +55,13 @@ class ConvolutionThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloCustomCallInstruction* cudnn_call_; std::vector operand_buffers_; BufferAllocation::Slice result_buffer_; BufferAllocation::Slice scratch_buffer_; BufferAllocation::Slice tuple_result_buffer_; + + // Convolution config + const GpuConvConfig config_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.cc index adf6b68096d..6b01151b48a 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace gpu { @@ -109,26 +110,23 @@ DnnBatchDescriptors MakeBatchNormDescriptors(const Shape& shape, return batch_descs; } -void AssignCommonParams(const HloInstruction* batchnorm, +void AssignCommonParams(const CudnnBatchNormConfig& config, CudnnBatchNormParamsCommon* params, const se::DeviceMemoryBase& operand, - const se::DeviceMemory& scale, float epsilon, - int64 feature_index) { + const se::DeviceMemory& scale) { // The BatchNormTraining HLO outputs a tuple of three elements: output data, // batch mean, and batch variance. We want to make our descriptors based on // the shape of the output data. Batchnorm backward call outputs a tuple of // three elements: grad data, grad offset, and grad scale. We want to make // our descriptors based on the shape of the grad data. - const Shape& shape = batchnorm->shape().IsTuple() - ? batchnorm->shape().tuple_shapes(0) - : batchnorm->shape(); + const Shape& shape = config.output_shape; DnnBatchDescriptors batch_descs = - MakeBatchNormDescriptors(shape, feature_index); + MakeBatchNormDescriptors(shape, config.feature_index); params->operand_desc = batch_descs.input_desc; params->scale_offset_desc = batch_descs.scale_offset_desc; params->operand = operand; params->scale = scale; - params->epsilon = epsilon; + params->epsilon = config.epsilon; } template @@ -211,22 +209,33 @@ void RunCudnnBatchNormBackwardImpl(CudnnBatchNormBackwardParams* params, } // namespace +CudnnBatchNormConfig GetCudnnBatchNormConfig(const HloInstruction* instr, + float epsilon, + int64 feature_index) { + CudnnBatchNormConfig config; + + config.output_shape = instr->shape().IsTuple() + ? 
instr->shape().tuple_shapes(0) + : instr->shape(); + config.output_type = config.output_shape.element_type(); + config.epsilon = epsilon; + config.feature_index = feature_index; + return config; +} + Status RunCudnnBatchNormForwardInference( - const HloInstruction* batchnorm, se::DeviceMemoryBase operand, + const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand, se::DeviceMemoryBase output, se::DeviceMemory scale, se::DeviceMemory offset, se::DeviceMemory mean, - se::DeviceMemory variance, float epsilon, int64 feature_index, - se::Stream* stream) { + se::DeviceMemory variance, se::Stream* stream) { CudnnBatchNormForwardInferenceParams inference_params; - AssignCommonParams(batchnorm, &inference_params.common, operand, scale, - epsilon, feature_index); + AssignCommonParams(config, &inference_params.common, operand, scale); inference_params.offset = offset; inference_params.mean = mean; inference_params.variance = variance; inference_params.output = output; - PrimitiveType output_primitive_type = batchnorm->shape().element_type(); - switch (output_primitive_type) { + switch (config.output_type) { case F16: RunCudnnBatchNormForwardInferenceImpl(&inference_params, stream); @@ -235,29 +244,27 @@ Status RunCudnnBatchNormForwardInference( RunCudnnBatchNormForwardInferenceImpl(&inference_params, stream); break; default: - return Unimplemented("Primitive type not implemented for \"%s\" ", - batchnorm->ToString()); + return Unimplemented( + "Primitive type %s not implemented for batchnorm forward inference", + primitive_util::LowercasePrimitiveTypeName(config.output_type) + .c_str()); } return Status::OK(); } Status RunCudnnBatchNormForwardTraining( - const HloInstruction* batchnorm, se::DeviceMemoryBase operand, + const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand, se::DeviceMemoryBase output_data, se::DeviceMemory output_mean, se::DeviceMemory output_inv_stddev, se::DeviceMemory scale, - se::DeviceMemory offset, float epsilon, int64 feature_index, - se::Stream* stream) { + se::DeviceMemory offset, se::Stream* stream) { CudnnBatchNormForwardTrainingParams forward_params; - AssignCommonParams(batchnorm, &forward_params.common, operand, scale, epsilon, - feature_index); + AssignCommonParams(config, &forward_params.common, operand, scale); forward_params.offset = offset; forward_params.output_data = output_data; forward_params.output_mean = output_mean; forward_params.output_inv_stddev = output_inv_stddev; - PrimitiveType output_primitive_type = - batchnorm->shape().tuple_shapes(0).element_type(); - switch (output_primitive_type) { + switch (config.output_type) { case F16: RunCudnnBatchNormForwardTrainingImpl(&forward_params, stream); @@ -266,22 +273,23 @@ Status RunCudnnBatchNormForwardTraining( RunCudnnBatchNormForwardTrainingImpl(&forward_params, stream); break; default: - return Unimplemented("Primitive type not implemented for \"%s\" ", - batchnorm->ToString()); + return Unimplemented( + "Primitive type %s not implemented for batchnorm forward training", + primitive_util::LowercasePrimitiveTypeName(config.output_type) + .c_str()); } return Status::OK(); } Status RunCudnnBatchNormBackward( - const HloInstruction* batchnorm, se::DeviceMemoryBase operand, + const CudnnBatchNormConfig& config, se::DeviceMemoryBase operand, se::DeviceMemoryBase output_grad_data, se::DeviceMemoryBase grad_output, se::DeviceMemory output_grad_scale, se::DeviceMemory output_grad_offset, se::DeviceMemory scale, se::DeviceMemory mean, se::DeviceMemory inv_stddev, - float epsilon, int64 
feature_index, se::Stream* stream) { + se::Stream* stream) { CudnnBatchNormBackwardParams backward_params; - AssignCommonParams(batchnorm, &backward_params.common, operand, scale, - epsilon, feature_index); + AssignCommonParams(config, &backward_params.common, operand, scale); backward_params.output_grad_data = output_grad_data; backward_params.grad_output = grad_output; backward_params.output_grad_scale = output_grad_scale; @@ -289,9 +297,7 @@ Status RunCudnnBatchNormBackward( backward_params.mean = mean; backward_params.inv_stddev = inv_stddev; - PrimitiveType output_primitive_type = - batchnorm->shape().tuple_shapes(0).element_type(); - switch (output_primitive_type) { + switch (config.output_type) { case F16: RunCudnnBatchNormBackwardImpl(&backward_params, stream); break; @@ -299,8 +305,10 @@ Status RunCudnnBatchNormBackward( RunCudnnBatchNormBackwardImpl(&backward_params, stream); break; default: - return Unimplemented("Primitive type not implemented for \"%s\" ", - batchnorm->ToString()); + return Unimplemented( + "Primitive type %s not implemented for batchnorm backward", + primitive_util::LowercasePrimitiveTypeName(config.output_type) + .c_str()); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h index 9a630d013f7..b0791b01868 100755 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h @@ -28,27 +28,36 @@ limitations under the License. namespace xla { namespace gpu { +struct CudnnBatchNormConfig { + Shape output_shape; + PrimitiveType output_type; + float epsilon; + int64 feature_index; +}; + +CudnnBatchNormConfig GetCudnnBatchNormConfig(const HloInstruction *instr, + float epsilon, + int64 feature_index); + Status RunCudnnBatchNormForwardInference( - const HloInstruction* batchnorm, se::DeviceMemoryBase operand, + const CudnnBatchNormConfig &config, se::DeviceMemoryBase operand, se::DeviceMemoryBase output, se::DeviceMemory scale, se::DeviceMemory offset, se::DeviceMemory mean, - se::DeviceMemory variance, float epsilon, int64 feature_index, - se::Stream* stream); + se::DeviceMemory variance, se::Stream *stream); Status RunCudnnBatchNormForwardTraining( - const HloInstruction* batchnorm, se::DeviceMemoryBase operand, + const CudnnBatchNormConfig &config, se::DeviceMemoryBase operand, se::DeviceMemoryBase output_data, se::DeviceMemory output_mean, se::DeviceMemory output_inv_stddev, se::DeviceMemory scale, - se::DeviceMemory offset, float epsilon, int64 feature_index, - se::Stream* stream); + se::DeviceMemory offset, se::Stream *stream); Status RunCudnnBatchNormBackward( - const HloInstruction* batchnorm, se::DeviceMemoryBase operand, + const CudnnBatchNormConfig &config, se::DeviceMemoryBase operand, se::DeviceMemoryBase output_grad_data, se::DeviceMemoryBase grad_output, se::DeviceMemory output_grad_scale, se::DeviceMemory output_grad_offset, se::DeviceMemory scale, se::DeviceMemory mean, se::DeviceMemory inv_stddev, - float epsilon, int64 feature_index, se::Stream* stream); + se::Stream *stream); } // namespace gpu } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_RUNNER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index e91b2c4d0d2..dae490e0d18 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ 
b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -31,90 +31,21 @@ namespace gpu { namespace dnn = se::dnn; -namespace { -void CheckInputOutputPrimitivetypeAreValid(const HloInstruction* hlo) { - // All input and output statistics variables must be F32. Also, the last - // operand for CudnnBatchNormForwardInference, CudnnBatchNormForwardTraining, - // and CudnnBatchNormBackward is the feature_index which must be S64. - // The allowed types for non-statistics variables are as follows: - // CudnnBatchNormForwardInference: - // operand[0]: {half, float} - // out[0]: {half, float} - // CudnnBatchNormForwardTraining: - // operand[0]: {half, float} - // out[0]: {half, float} - // CudnnBatchNormBackward: - // operand[0]: {half, float} - // operand[4]: {half, float} - // out[0]: {half, float} - // Note non-statistics inputs and outputs mentioned above should be of the - // same type. - - // Check Inputs. - int64 num_operands = hlo->operand_count(); - PrimitiveType operand_primitive_type = - hlo->operand(0)->shape().element_type(); - CHECK(operand_primitive_type == F16 || operand_primitive_type == F32) - << "Not yet implemented"; - - for (int i = 1; i < num_operands - 2; i++) { - if (hlo->custom_call_target() == kCudnnBatchNormBackwardCallTarget && - i == 4) { - // The first operand to batchnorm grad is the input and the 4th operand is - // the grad_output, both of which can be Eigen::half. - CHECK_EQ(hlo->operand(i)->shape().element_type(), operand_primitive_type) - << "Invalid datatype"; - continue; - } - CHECK_EQ(hlo->operand(i)->shape().element_type(), F32) - << "Not yet implemented"; - } - - // The last operand is the feature index which must be int64. - CHECK_EQ(hlo->operand(num_operands - 1)->shape().element_type(), S64) - << "Not yet implemented"; - - // Check Outputs. 
- if (hlo->shape().IsTuple()) { - CHECK_EQ(hlo->shape().tuple_shapes(0).element_type(), - operand_primitive_type) - << "Invalid datatype"; - - for (int j = 1; j < hlo->shape().tuple_shapes_size(); j++) { - CHECK_EQ(hlo->shape().tuple_shapes(j).element_type(), F32) - << "Not yet implemented"; - } - } else { - CHECK_EQ(hlo->shape().element_type(), operand_primitive_type) - << "Invalid datatype"; - } -} -} // namespace - CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk( - ThunkInfo thunk_info, const BufferAllocation::Slice& operand, + ThunkInfo thunk_info, CudnnBatchNormConfig&& config, + const BufferAllocation::Slice& operand, const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, const BufferAllocation::Slice& mean, - const BufferAllocation::Slice& variance, float epsilon, int64 feature_index, + const BufferAllocation::Slice& variance, const BufferAllocation::Slice& output) : Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), + config_(std::move(config)), operand_(operand), scale_(scale), offset_(offset), mean_(mean), variance_(variance), - epsilon_(epsilon), - feature_index_(feature_index), - output_(output) { - const auto* hlo = hlo_instruction_; - CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); - CHECK_EQ(hlo->custom_call_target(), - kCudnnBatchNormForwardInferenceCallTarget); - CHECK( - LayoutUtil::LayoutsInShapesEqual(hlo->shape(), hlo->operand(0)->shape())); - CheckInputOutputPrimitivetypeAreValid(hlo); -} + output_(output) {} Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( const ExecuteParams& params) { @@ -131,8 +62,7 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( buffer_allocations.GetDeviceAddress(variance_)); auto& stream = *params.stream; TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardInference( - hlo_instruction_, operand, output_base, scale, offset, mean, variance, - epsilon_, feature_index_, &stream)); + config_, operand, output_base, scale, offset, mean, variance, &stream)); if (!stream.ok()) { return InternalError("BatchNormalizationForward call failed."); @@ -141,32 +71,22 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( } CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( - ThunkInfo thunk_info, const BufferAllocation::Slice& operand, + ThunkInfo thunk_info, CudnnBatchNormConfig&& config, + const BufferAllocation::Slice& operand, const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, - float epsilon, int64 feature_index, const BufferAllocation::Slice& output_data, const BufferAllocation::Slice& output_mean, const BufferAllocation::Slice& output_inv_stddev, const BufferAllocation::Slice& output_tuple) : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), + config_(std::move(config)), operand_(operand), scale_(scale), offset_(offset), - epsilon_(epsilon), - feature_index_(feature_index), output_data_(output_data), output_mean_(output_mean), output_inv_stddev_(output_inv_stddev), - output_tuple_(output_tuple) { - const auto* hlo = hlo_instruction_; - CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); - CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardTrainingCallTarget); - CHECK_EQ(hlo->shape().tuple_shapes_size(), 3); - CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0), - hlo->operand(0)->shape())); - CheckInputOutputPrimitivetypeAreValid(hlo); -} + output_tuple_(output_tuple) {} Status 
CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( const ExecuteParams& params) { @@ -185,10 +105,10 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( params.profiler->MakeScopedInstructionProfiler(profile_index()); auto& stream = *params.stream; TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardTraining( - hlo_instruction_, operand, output_data, output_mean, output_inv_stddev, + config_, operand, output_data, output_mean, output_inv_stddev, se::DeviceMemory(buffer_allocations.GetDeviceAddress(scale_)), se::DeviceMemory(buffer_allocations.GetDeviceAddress(offset_)), - epsilon_, feature_index_, &stream)); + &stream)); // Write the output tuple. const int kNumOutputs = 3; @@ -207,37 +127,26 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( } CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( - ThunkInfo thunk_info, const BufferAllocation::Slice& operand, + ThunkInfo thunk_info, CudnnBatchNormConfig&& config, + const BufferAllocation::Slice& operand, const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean, const BufferAllocation::Slice& inv_stddev, - const BufferAllocation::Slice& grad_output, float epsilon, - int64 feature_index, const BufferAllocation::Slice& output_grad_data, + const BufferAllocation::Slice& grad_output, + const BufferAllocation::Slice& output_grad_data, const BufferAllocation::Slice& output_grad_scale, const BufferAllocation::Slice& output_grad_offset, const BufferAllocation::Slice& output_tuple) : Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), + config_(std::move(config)), operand_(operand), scale_(scale), mean_(mean), inv_stddev_(inv_stddev), grad_output_(grad_output), - epsilon_(epsilon), - feature_index_(feature_index), output_grad_data_(output_grad_data), output_grad_scale_(output_grad_scale), output_grad_offset_(output_grad_offset), - output_tuple_(output_tuple) { - const auto* hlo = hlo_instruction_; - CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); - CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormBackwardCallTarget); - CHECK_EQ(hlo->shape().tuple_shapes_size(), 3); - CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0), - hlo->operand(0)->shape())); - CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0), - hlo->operand(4)->shape())); - CheckInputOutputPrimitivetypeAreValid(hlo); -} + output_tuple_(output_tuple) {} Status CudnnBatchNormBackwardThunk::ExecuteOnStream( const ExecuteParams& params) { @@ -256,12 +165,12 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream( params.profiler->MakeScopedInstructionProfiler(profile_index()); se::Stream* stream = params.stream; TF_RETURN_IF_ERROR(RunCudnnBatchNormBackward( - hlo_instruction_, operand, output_grad_data, grad_output, - output_grad_scale, output_grad_offset, + config_, operand, output_grad_data, grad_output, output_grad_scale, + output_grad_offset, se::DeviceMemory(buffer_allocations.GetDeviceAddress(scale_)), se::DeviceMemory(buffer_allocations.GetDeviceAddress(mean_)), se::DeviceMemory(buffer_allocations.GetDeviceAddress(inv_stddev_)), - epsilon_, feature_index_, stream)); + stream)); // Write the output tuple. 
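With the constructor-time CHECKs removed, epsilon and feature_index no longer live on the thunks at all; they travel inside CudnnBatchNormConfig, built by the emitter from the custom call. A minimal emitter-side sketch for the forward-training case (caller code, not part of ExecuteOnStream; epsilon, feature_index, and the buffer slices are assumed to have been extracted from the custom call already):

  // Snapshot shape, output type, epsilon, and feature_index into the config.
  CudnnBatchNormConfig config =
      GetCudnnBatchNormConfig(custom_call, epsilon, feature_index);
  auto thunk = absl::make_unique<CudnnBatchNormForwardTrainingThunk>(
      GetThunkInfo(custom_call), std::move(config), operand_slice, scale_slice,
      offset_slice, output_data_slice, output_mean_slice,
      output_inv_stddev_slice, output_tuple_slice);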
const int kNumOutputs = 3; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h index bb46017b8fb..d45e284ea2c 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -47,12 +48,12 @@ namespace gpu { class CudnnBatchNormForwardInferenceThunk : public Thunk { public: CudnnBatchNormForwardInferenceThunk(ThunkInfo thunk_info, + CudnnBatchNormConfig&& config, const BufferAllocation::Slice& operand, const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, const BufferAllocation::Slice& mean, const BufferAllocation::Slice& variance, - float epsilon, int64 feature_index, const BufferAllocation::Slice& output); CudnnBatchNormForwardInferenceThunk( @@ -63,23 +64,22 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; + CudnnBatchNormConfig config_; BufferAllocation::Slice operand_; BufferAllocation::Slice scale_; BufferAllocation::Slice offset_; BufferAllocation::Slice mean_; BufferAllocation::Slice variance_; - float epsilon_; - int64 feature_index_; BufferAllocation::Slice output_; }; class CudnnBatchNormForwardTrainingThunk : public Thunk { public: CudnnBatchNormForwardTrainingThunk( - ThunkInfo thunk_info, const BufferAllocation::Slice& operand, + ThunkInfo thunk_info, CudnnBatchNormConfig&& config, + const BufferAllocation::Slice& operand, const BufferAllocation::Slice& scale, - const BufferAllocation::Slice& offset, float epsilon, int64 feature_index, + const BufferAllocation::Slice& offset, const BufferAllocation::Slice& output_data, const BufferAllocation::Slice& output_mean, const BufferAllocation::Slice& output_inv_stddev, @@ -93,12 +93,10 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; + CudnnBatchNormConfig config_; BufferAllocation::Slice operand_; BufferAllocation::Slice scale_; BufferAllocation::Slice offset_; - float epsilon_; - int64 feature_index_; BufferAllocation::Slice output_data_; BufferAllocation::Slice output_mean_; BufferAllocation::Slice output_inv_stddev_; @@ -108,12 +106,12 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk { class CudnnBatchNormBackwardThunk : public Thunk { public: CudnnBatchNormBackwardThunk(ThunkInfo thunk_info, + CudnnBatchNormConfig&& config, const BufferAllocation::Slice& operand, const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean, const BufferAllocation::Slice& inv_stddev, const BufferAllocation::Slice& grad_output, - float epsilon, int64 feature_index, const BufferAllocation::Slice& output_grad_data, const BufferAllocation::Slice& output_grad_scale, const BufferAllocation::Slice& output_grad_offset, @@ -126,14 +124,12 @@ class CudnnBatchNormBackwardThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const 
HloInstruction* hlo_instruction_; + const CudnnBatchNormConfig config_; BufferAllocation::Slice operand_; BufferAllocation::Slice scale_; BufferAllocation::Slice mean_; BufferAllocation::Slice inv_stddev_; BufferAllocation::Slice grad_output_; - float epsilon_; - int64 feature_index_; BufferAllocation::Slice output_grad_data_; BufferAllocation::Slice output_grad_scale_; BufferAllocation::Slice output_grad_offset_; diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc index dae15659402..c9b2318af79 100644 --- a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc @@ -24,29 +24,12 @@ namespace gpu { CustomCallThunk::CustomCallThunk( ThunkInfo thunk_info, void* call_target, std::vector> operand_slices, - ShapeTree result_slices, std::string opaque) + ShapeTree result_slices, const std::string& opaque) : Thunk(Thunk::kCustomCall, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), call_target_(call_target), operand_slices_(std::move(operand_slices)), result_slices_(std::move(result_slices)), - opaque_(std::move(opaque)) { - const HloInstruction* instr = hlo_instruction_; - CHECK_EQ(instr->operand_count(), operand_slices_.size()); - for (int64 i = 0; i < instr->operand_count(); ++i) { - const auto& s1 = operand_slices_[i].shape(); - const auto& s2 = instr->operand(i)->shape(); - CHECK(ShapeUtil::Equal(s1, s2)) << absl::StreamFormat( - "Shape mismatch between instr->operand(%d) and " - "operand_slices[%d].shape(): %s vs %s", - i, i, s1.ToString(), s2.ToString()); - } - CHECK(ShapeUtil::Equal(instr->shape(), result_slices.shape())) - << absl::StreamFormat( - "Shape mismatch between instr->shape() and result_slices.shape(): " - "%s vs %s.", - instr->shape().ToString(), result_slices.shape().ToString()); -} + opaque_(opaque) {} // For each leaf in a preorder traversal of `slices`, appends its device address // to `buffers`. diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h index 31c03f5252f..f36eaa9cef2 100644 --- a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h @@ -41,16 +41,16 @@ class CustomCallThunk : public Thunk { CustomCallThunk( ThunkInfo thunk_info, void* call_target, std::vector> operand_slices, - ShapeTree result_slices, std::string opaque); + ShapeTree result_slices, + const std::string& opaque); Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; void* call_target_; std::vector> operand_slices_; ShapeTree result_slices_; - std::string opaque_; + const std::string opaque_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc index 318b8aff176..4cc19a23201 100644 --- a/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc @@ -14,10 +14,22 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { namespace gpu { +struct NcclAllReduceConfig::AuxData {}; + +NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig &&) = default; +NcclAllReduceConfig::~NcclAllReduceConfig() = default; + +NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr, + int64 replica_count) { + NcclAllReduceConfig config = {}; + return config; +} + /* static */ bool NcclAllReduceThunk::NcclIsEnabled() { return false; // Skylark selects this source file if NCCL is disabled. } @@ -32,20 +44,16 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { "compiler, which is necessary to build the NCCL source library."); } -NcclAllReduceThunk::~NcclAllReduceThunk() = default; - /*static*/ absl::flat_hash_set NcclAllReduceThunk::DevicesWithOpenNcclChannels() { return {}; } -struct NcclAllReduceThunk::AuxData {}; - NcclAllReduceThunk::NcclAllReduceThunk( - ThunkInfo thunk_info, int64 replica_count, + ThunkInfo thunk_info, NcclAllReduceConfig &&config, std::vector buffers) : Thunk(Thunk::kNcclAllReduce, thunk_info), - replica_count_(replica_count), + config_(std::move(config)), buffers_(std::move(buffers)) {} } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index ccd661d8ade..a9e6cd05c31 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -24,15 +24,16 @@ namespace xla { namespace gpu { ForThunk::ForThunk(ThunkInfo thunk_info, const int64 loop_limit, - std::unique_ptr body_thunk_sequence) + std::unique_ptr body_thunk_sequence, + absl::optional body_profile_index) : Thunk(Kind::kWhile, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), loop_limit_(loop_limit), body_thunk_sequence_(absl::make_unique( // Pass nullptr as the HloInstruction* to the body_thunk_sequence_ // constructor because this SequentialThunk is logically "part of" // this ForThunk, and shouldn't be profiled separately from it. - ThunkInfo(), std::move(*body_thunk_sequence))) {} + ThunkInfo(), std::move(*body_thunk_sequence))), + body_profile_index_(body_profile_index) {} Status ForThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { @@ -41,15 +42,14 @@ Status ForThunk::Initialize(const GpuExecutable& executable, } Status ForThunk::ExecuteOnStream(const ExecuteParams& params) { - VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for " - << (hlo_instruction_ ? hlo_instruction_->ToString() : ""); + VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters"; auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); for (int64 i = 0; i < loop_limit_; ++i) { params.profiler->StartHloComputation(); // Invoke loop body thunk sequence. 
TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params)); - params.profiler->FinishHloComputation(hlo_instruction_->while_body()); + params.profiler->FinishHloComputation(body_profile_index_); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h index b6ee950737e..9a8bd069290 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -32,7 +32,8 @@ namespace gpu { class ForThunk : public Thunk { public: ForThunk(ThunkInfo thunk_info, const int64 loop_limit, - std::unique_ptr body_thunk_sequence); + std::unique_ptr body_thunk_sequence, + absl::optional body_profile_index_); ForThunk(const ForThunk&) = delete; ForThunk& operator=(const ForThunk&) = delete; @@ -41,9 +42,9 @@ class ForThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; const int64 loop_limit_; std::unique_ptr body_thunk_sequence_; + const absl::optional body_profile_index_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index d08c732e611..7468114516d 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -469,6 +469,165 @@ TEST_F(FusionMergerTest, WillMergeExpensiveFusionsWithSingleConsumer) { EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie()); } +TEST_F(FusionMergerTest, NoMergeBecauseCodeDuplication) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule module + +and.reduce_sub_computation { + x = pred[] parameter(0) + y = pred[] parameter(1) + ROOT and = pred[] and(x, y) +} + +fused_computation.1 { + param_4.658 = f32[2,20,256]{2,0,1} parameter(4) + slice.1385 = f32[2,1,256]{2,0,1} slice(param_4.658), slice={[0:2], [11:12], [0:256]} + constant.6847 = s32[] constant(0) + broadcast.4823 = s32[3]{0} broadcast(constant.6847), dimensions={} + param_9.415 = s32[3]{0} parameter(9) + compare.700 = pred[3]{0} compare(broadcast.4823, param_9.415), direction=LE + constant.6846 = pred[] constant(true) + reduce.221 = pred[] reduce(compare.700, constant.6846), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2933 = pred[2,1,256]{2,0,1} broadcast(reduce.221), dimensions={} + param_5.528 = f32[2,512]{1,0} parameter(5) + slice.1384 = f32[2,256]{1,0} slice(param_5.528), slice={[0:2], [0:256]} + bitcast.341 = f32[2,1,256]{2,0,1} bitcast(slice.1384) + constant.5418 = f32[] constant(0) + broadcast.3227 = f32[2,1,256]{2,0,1} broadcast(constant.5418), dimensions={} + select.173 = f32[2,1,256]{2,0,1} select(broadcast.2933, bitcast.341, broadcast.3227) + add.573 = f32[2,1,256]{2,0,1} add(slice.1385, select.173) + param_0.299 = s32[] parameter(0) + constant.5157 = s32[] constant(11) + dynamic-update-slice.189 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.658, add.573, param_0.299, constant.5157, param_0.299) + slice.1383 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.189), slice={[0:2], [10:11], [0:256]} + constant.6800 = s32[] constant(0) + broadcast.4803 = s32[3]{0} broadcast(constant.6800), dimensions={} + param_8.484 = s32[3]{0} parameter(8) + compare.681 = pred[3]{0} compare(broadcast.4803, param_8.484), direction=LE + constant.6798 = pred[] constant(true) + reduce.203 = pred[] reduce(compare.681, constant.6798), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2932 = 
pred[2,1,256]{2,0,1} broadcast(reduce.203), dimensions={} + param_3.1169 = f32[2,512]{1,0} parameter(3) + slice.1382 = f32[2,256]{1,0} slice(param_3.1169), slice={[0:2], [0:256]} + bitcast.340 = f32[2,1,256]{2,0,1} bitcast(slice.1382) + select.172 = f32[2,1,256]{2,0,1} select(broadcast.2932, bitcast.340, broadcast.3227) + add.572 = f32[2,1,256]{2,0,1} add(slice.1383, select.172) + constant.5154 = s32[] constant(10) + dynamic-update-slice.188 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.189, add.572, param_0.299, constant.5154, param_0.299) + slice.1381 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.188), slice={[0:2], [9:10], [0:256]} + constant.6794 = s32[] constant(0) + broadcast.4801 = s32[3]{0} broadcast(constant.6794), dimensions={} + param_7.478 = s32[3]{0} parameter(7) + compare.679 = pred[3]{0} compare(broadcast.4801, param_7.478), direction=LE + constant.6793 = pred[] constant(true) + reduce.201 = pred[] reduce(compare.679, constant.6793), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2930 = pred[2,1,256]{2,0,1} broadcast(reduce.201), dimensions={} + param_2.1685 = f32[2,512]{1,0} parameter(2) + slice.1380 = f32[2,256]{1,0} slice(param_2.1685), slice={[0:2], [0:256]} + bitcast.339 = f32[2,1,256]{2,0,1} bitcast(slice.1380) + select.171 = f32[2,1,256]{2,0,1} select(broadcast.2930, bitcast.339, broadcast.3227) + add.571 = f32[2,1,256]{2,0,1} add(slice.1381, select.171) + constant.5153 = s32[] constant(9) + dynamic-update-slice.187 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.188, add.571, param_0.299, constant.5153, param_0.299) + slice.1379 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.187), slice={[0:2], [8:9], [0:256]} + constant.6788 = s32[] constant(0) + broadcast.4799 = s32[3]{0} broadcast(constant.6788), dimensions={} + param_6.495 = s32[3]{0} parameter(6) + compare.677 = pred[3]{0} compare(broadcast.4799, param_6.495), direction=LE + constant.6786 = pred[] constant(true) + reduce.199 = pred[] reduce(compare.677, constant.6786), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2929 = pred[2,1,256]{2,0,1} broadcast(reduce.199), dimensions={} + param_1.1408 = f32[2,512]{1,0} parameter(1) + slice.1378 = f32[2,256]{1,0} slice(param_1.1408), slice={[0:2], [0:256]} + bitcast.338 = f32[2,1,256]{2,0,1} bitcast(slice.1378) + select.170 = f32[2,1,256]{2,0,1} select(broadcast.2929, bitcast.338, broadcast.3227) + add.570 = f32[2,1,256]{2,0,1} add(slice.1379, select.170) + constant.5152 = s32[] constant(8) + ROOT dynamic-update-slice.186 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.187, add.570, param_0.299, constant.5152, param_0.299) +} + +fused_computation.2 { + param_4.655 = f32[2,20,256]{2,0,1} parameter(4) + slice.1369 = f32[2,1,256]{2,0,1} slice(param_4.655), slice={[0:2], [7:8], [0:256]} + param_6.483 = pred[] parameter(6) + broadcast.2927 = pred[2,1,256]{2,0,1} broadcast(param_6.483), dimensions={} + param_5.525 = f32[2,512]{1,0} parameter(5) + slice.1368 = f32[2,256]{1,0} slice(param_5.525), slice={[0:2], [0:256]} + bitcast.333 = f32[2,1,256]{2,0,1} bitcast(slice.1368) + constant.5415 = f32[] constant(0) + broadcast.3225 = f32[2,1,256]{2,0,1} broadcast(constant.5415), dimensions={} + select.161 = f32[2,1,256]{2,0,1} select(broadcast.2927, bitcast.333, broadcast.3225) + add.549 = f32[2,1,256]{2,0,1} add(slice.1369, select.161) + param_0.265 = s32[] parameter(0) + constant.5151 = s32[] constant(7) + dynamic-update-slice.185 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.655, 
add.549, param_0.265, constant.5151, param_0.265) + slice.1367 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.185), slice={[0:2], [6:7], [0:256]} + constant.6782 = s32[] constant(0) + broadcast.4797 = s32[3]{0} broadcast(constant.6782), dimensions={} + param_9.391 = s32[3]{0} parameter(9) + compare.675 = pred[3]{0} compare(broadcast.4797, param_9.391), direction=LE + constant.6781 = pred[] constant(true) + reduce.197 = pred[] reduce(compare.675, constant.6781), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2926 = pred[2,1,256]{2,0,1} broadcast(reduce.197), dimensions={} + param_3.1167 = f32[2,512]{1,0} parameter(3) + slice.1366 = f32[2,256]{1,0} slice(param_3.1167), slice={[0:2], [0:256]} + bitcast.332 = f32[2,1,256]{2,0,1} bitcast(slice.1366) + select.160 = f32[2,1,256]{2,0,1} select(broadcast.2926, bitcast.332, broadcast.3225) + add.548 = f32[2,1,256]{2,0,1} add(slice.1367, select.160) + constant.5150 = s32[] constant(6) + dynamic-update-slice.184 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.185, add.548, param_0.265, constant.5150, param_0.265) + slice.1365 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.184), slice={[0:2], [5:6], [0:256]} + constant.6776 = s32[] constant(0) + broadcast.4794 = s32[3]{0} broadcast(constant.6776), dimensions={} + param_8.464 = s32[3]{0} parameter(8) + compare.673 = pred[3]{0} compare(broadcast.4794, param_8.464), direction=LE + constant.6775 = pred[] constant(true) + reduce.195 = pred[] reduce(compare.673, constant.6775), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2925 = pred[2,1,256]{2,0,1} broadcast(reduce.195), dimensions={} + param_2.1684 = f32[2,512]{1,0} parameter(2) + slice.1364 = f32[2,256]{1,0} slice(param_2.1684), slice={[0:2], [0:256]} + bitcast.331 = f32[2,1,256]{2,0,1} bitcast(slice.1364) + select.159 = f32[2,1,256]{2,0,1} select(broadcast.2925, bitcast.331, broadcast.3225) + add.547 = f32[2,1,256]{2,0,1} add(slice.1365, select.159) + constant.5149 = s32[] constant(5) + dynamic-update-slice.183 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.184, add.547, param_0.265, constant.5149, param_0.265) + slice.1363 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.183), slice={[0:2], [4:5], [0:256]} + constant.6770 = s32[] constant(0) + broadcast.4792 = s32[3]{0} broadcast(constant.6770), dimensions={} + param_7.458 = s32[3]{0} parameter(7) + compare.671 = pred[3]{0} compare(broadcast.4792, param_7.458), direction=LE + constant.6769 = pred[] constant(true) + reduce.193 = pred[] reduce(compare.671, constant.6769), dimensions={0}, to_apply=and.reduce_sub_computation + broadcast.2924 = pred[2,1,256]{2,0,1} broadcast(reduce.193), dimensions={} + param_1.1405 = f32[2,512]{1,0} parameter(1) + slice.1362 = f32[2,256]{1,0} slice(param_1.1405), slice={[0:2], [0:256]} + bitcast.330 = f32[2,1,256]{2,0,1} bitcast(slice.1362) + select.158 = f32[2,1,256]{2,0,1} select(broadcast.2924, bitcast.330, broadcast.3225) + add.546 = f32[2,1,256]{2,0,1} add(slice.1363, select.158) + constant.5148 = s32[] constant(4) + ROOT dynamic-update-slice.182 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.183, add.546, param_0.265, constant.5148, param_0.265) +} + +ENTRY main { + param_0.0 = s32[] parameter(0) + param_1.0 = f32[2,512]{1,0} parameter(1) + param_2.0 = f32[2,512]{1,0} parameter(2) + param_3.0 = f32[2,512]{1,0} parameter(3) + param_4.0 = f32[2,20,256]{2,1,0} parameter(4) + param_5.0 = f32[2,512]{1,0} parameter(5) + param_6.0 = s32[3]{0} parameter(6) + param_7.0 = s32[3]{0} 
parameter(7) + param_8.0 = s32[3]{0} parameter(8) + param_9.0 = s32[3]{0} parameter(9) + fusion.1 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, param_4.0, param_5.0, param_6.0, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.1 + param_10 = pred[] parameter(10) + ROOT fusion.2 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, fusion.1, param_5.0, param_10, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.2 +} + )") + .ValueOrDie(); + EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc index 0320496ea98..5a8265a53a6 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc @@ -115,6 +115,8 @@ static StatusOr> DoUncachedGemmAutotune( absl::optional first_algorithm; std::vector profile_results; + GpuGemmConfig config = GetGpuGemmConfig(gemm); + for (se::blas::AlgorithmType algorithm : algorithms) { // Make sure the output buffer always has the same value if we use // the bias parameter. @@ -129,8 +131,7 @@ static StatusOr> DoUncachedGemmAutotune( // for all algorithms if we're targeting < sm_50. But because we pass a // non-null ProfileResult, DoGemmWithAlgorithm should always return true, // and the actual success-ness is returned in ProfileResult::is_valid. - CHECK(RunGemm(gemm, backend_config, lhs_buffer, rhs_buffer, output_buffer, - stream, + CHECK(RunGemm(config, lhs_buffer, rhs_buffer, output_buffer, stream, /*implements_whole_instruction=*/true, /*profile_index=*/-1, /*profiler=*/nullptr, diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index e55df0bb230..ea4f3951a3d 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -33,32 +33,40 @@ limitations under the License. 
namespace xla { namespace gpu { -GemmThunk::GemmThunk(ThunkInfo thunk_info, +GpuGemmConfig GetGpuGemmConfig(const HloInstruction *gemm) { + GpuGemmConfig config; + config.output_shape = gemm->shape(); + config.lhs_shape = gemm->operand(0)->shape(); + config.rhs_shape = gemm->operand(1)->shape(); + auto backend_config_or = gemm->backend_config(); + config.backend_config = std::move(backend_config_or.ValueOrDie()); + return config; +} + +GemmThunk::GemmThunk(ThunkInfo thunk_info, GpuGemmConfig &&config, const BufferAllocation::Slice &lhs_buffer, const BufferAllocation::Slice &rhs_buffer, const BufferAllocation::Slice &output_buffer, - bool implements_whole_instruction, - const GemmBackendConfig &backend_config) + bool implements_whole_instruction) : Thunk(Kind::kGemm, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), + config_(std::move(config)), lhs_buffer_(lhs_buffer), rhs_buffer_(rhs_buffer), output_buffer_(output_buffer), - implements_whole_instruction_(implements_whole_instruction), - backend_config_(backend_config) {} + implements_whole_instruction_(implements_whole_instruction) {} Status GemmThunk::ExecuteOnStream(const ExecuteParams ¶ms) { auto get_device_address = [&](const BufferAllocation::Slice &slice) { return params.buffer_allocations->GetDeviceAddress(slice); }; - VLOG(3) << "Running GEMM thunk on instruction: " << hlo_instruction_; + VLOG(3) << "Running GEMM thunk"; se::DeviceMemoryBase lhs_data = get_device_address(lhs_buffer_); se::DeviceMemoryBase rhs_data = get_device_address(rhs_buffer_); se::DeviceMemoryBase output_data = get_device_address(output_buffer_); - return RunGemm(hlo_instruction_, backend_config_, lhs_data, rhs_data, - output_data, params.stream, implements_whole_instruction_, - profile_index(), params.profiler); + return RunGemm(config_, lhs_data, rhs_data, output_data, params.stream, + implements_whole_instruction_, profile_index(), + params.profiler); } // This struct contains the metadata of a matrix, e.g., its base address and @@ -160,8 +168,7 @@ static bool DoGemmWithAlgorithm( .ok(); } -Status RunGemm(const HloInstruction *gemm, - const GemmBackendConfig &backend_config, +Status RunGemm(const GpuGemmConfig &gemm_config, se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, se::Stream *stream, bool implements_whole_instruction, @@ -170,14 +177,11 @@ Status RunGemm(const HloInstruction *gemm, se::blas::ProfileResult *profile_result, absl::optional algorithm) { VLOG(2) << "Executing a GemmThunk"; - CHECK(IsCublasGemm(*gemm)); - const Shape &output_shape = gemm->shape(); - const HloInstruction *lhs = gemm->operand(0); - const HloInstruction *rhs = gemm->operand(1); - - const Shape &lhs_shape = lhs->shape(); - const Shape &rhs_shape = rhs->shape(); + const Shape &output_shape = gemm_config.output_shape; + const Shape &lhs_shape = gemm_config.lhs_shape; + const Shape &rhs_shape = gemm_config.rhs_shape; + const GemmBackendConfig &backend_config = gemm_config.backend_config; const DotDimensionNumbers &dim_nums = backend_config.dot_dimension_numbers(); CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 1a51a7d4e0c..9d6613dbe77 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -33,17 +33,26 @@ namespace gpu { // This class stores everything that StreamExecutor needs to launch a BLAS gemm. // It is generated by IrEmitter. 
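GetGpuGemmConfig snapshots the operand shapes, the output shape, and the GemmBackendConfig off the instruction, which is what lets both GemmThunk::ExecuteOnStream and the autotuner above call RunGemm without touching an HloInstruction. A minimal sketch of the new call path, mirroring the autotuner usage (the buffer and stream variables are assumptions):

  // gemm is a const HloInstruction*; the config is built once up front.
  GpuGemmConfig config = GetGpuGemmConfig(gemm);
  TF_RETURN_IF_ERROR(RunGemm(config, lhs_buffer, rhs_buffer, output_buffer,
                             stream, /*implements_whole_instruction=*/true,
                             /*profile_index=*/-1));

The struct and the updated GemmThunk constructor follow.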
-// + +struct GpuGemmConfig { + Shape lhs_shape; + Shape rhs_shape; + Shape output_shape; + GemmBackendConfig backend_config; +}; + +GpuGemmConfig GetGpuGemmConfig(const HloInstruction* gemm); + // This is thread-compatible. class GemmThunk : public Thunk { public: // Constructs a thunk that computes "output = (lhs rhs) * alpha" using // BLAS gemm (alpha is stored in the instruction GemmBackendConfig). - GemmThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& lhs_buffer, + GemmThunk(ThunkInfo thunk_info, GpuGemmConfig&& config, + const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, - bool implements_whole_instruction, - const GemmBackendConfig& backend_config); + bool implements_whole_instruction); GemmThunk(const GemmThunk&) = delete; GemmThunk& operator=(const GemmThunk&) = delete; @@ -51,28 +60,27 @@ class GemmThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; + const GpuGemmConfig config_; const BufferAllocation::Slice lhs_buffer_; const BufferAllocation::Slice rhs_buffer_; const BufferAllocation::Slice output_buffer_; - bool implements_whole_instruction_; - GemmBackendConfig backend_config_; + const bool implements_whole_instruction_; }; // Run the given GEMM instruction `gemm` subject to the configuration -// in `backend_config` and the passed buffers. +// in `gemm_config` and the passed buffers. // // `implements_whole_instruction` is used for the default profiler creation // if the `profiler` is not supplied. False value indicates that the created // profiler will not specifically profile the `gemm` instruction. // // If `algorithm` is provided, it overrides the one specified in -// `backend_config`. +// `gemm_config.backend_config`. Status RunGemm( - const HloInstruction* gemm, const GemmBackendConfig& backend_config, - se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer, - se::DeviceMemoryBase output_buffer, se::Stream* stream, - bool implements_whole_instruction, absl::optional profile_index, + const GpuGemmConfig& gemm_config, se::DeviceMemoryBase lhs_buffer, + se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, + se::Stream* stream, bool implements_whole_instruction, + absl::optional profile_index, HloExecutionProfiler* profiler = nullptr, se::blas::ProfileResult* profile_result = nullptr, absl::optional algorithm = absl::nullopt); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index b74dbd6100a..e9435f4fa92 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -92,6 +92,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/logistic_expander.h" #include "tensorflow/compiler/xla/service/qr_expander.h" +#include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h" #include "tensorflow/compiler/xla/service/rng_expander.h" #include "tensorflow/compiler/xla/service/slice_sinker.h" @@ -229,6 +230,7 @@ Status GpuCompiler::OptimizeHloModule( // pass.AddPass(); pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); } @@ -617,6 +619,9 @@ StatusOr> GpuCompiler::RunBackend( stream_exec->GetDeviceDescription().threads_per_warp(); gpu_device_info.shared_memory_per_block = stream_exec->GetDeviceDescription().shared_memory_per_block(); + gpu_device_info.threads_per_core_limit = + stream_exec->GetDeviceDescription().threads_per_core_limit(); + gpu_device_info.core_count = stream_exec->GetDeviceDescription().core_count(); absl::optional cuda_compute_capability = [&]() -> absl::optional { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc index 8fb741323f3..925caadbb97 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc @@ -122,24 +122,28 @@ std::vector GetAlgorithms(CudnnConvKind kind, } StatusOr> GetMIOpenAlgorithms( - const HloCustomCallInstruction* conv, + const HloCustomCallInstruction* instr, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::StreamExecutor* stream_exec, ScratchAllocator* scratch_allocator, se::Stream* stream) { std::vector algorithms; - TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind, - GetDnnConvolutionKind(conv)); + TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr)); - TF_ASSIGN_OR_RETURN(se::dnn::DataType dtype, GetDnnDataType(conv)); + TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind, + GetDNNConvKindFromCudnnConvKind(config.kind)); + + TF_ASSIGN_OR_RETURN(se::dnn::DataType dtype, + GetDNNDataTypeFromPrimitiveType(config.output_type)); TF_ASSIGN_OR_RETURN(GpuConvParams params, - GetGpuConvParams(conv, operand_buffers, result_buffer)); + GetGpuConvParams(config, operand_buffers, result_buffer)); bool succ = stream_exec->GetMIOpenConvolveAlgorithms( - kind, dtype, stream, params.input_descriptor, params.input_buf, - params.filter_descriptor, params.filter_buf, params.output_descriptor, - params.output_buf, params.conv_desc, scratch_allocator, &algorithms); + kind, dtype, stream, params.config.input_descriptor, params.input_buf, + params.config.filter_descriptor, params.filter_buf, + params.config.output_descriptor, params.output_buf, + params.config.conv_desc, scratch_allocator, &algorithms); DCHECK(succ); return algorithms; @@ -442,6 +446,8 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( GetComputeCapability(stream_exec_), GetCudnnVersion(stream_exec_), blas_version, canonical_hlo); + TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr)); + for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) { XLA_SCOPED_LOGGING_TIMER_LEVEL( absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ", @@ -465,7 +471,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( options.profile_result = &profile_result; options.algo_override = alg; Status launch_status = - RunGpuConv(instr, absl::MakeSpan(operand_buffers), result_buffer, + RunGpuConv(config, 
absl::MakeSpan(operand_buffers), result_buffer, &scratch_allocator, stream, options); if (!launch_status.ok()) { @@ -700,6 +706,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto( absl::Milliseconds(profile_result.elapsed_time_in_ms())); } else { + TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr)); for (const auto& miopen_alg : algorithms) { const auto& alg = miopen_alg.algorithm(); XLA_SCOPED_LOGGING_TIMER_LEVEL( @@ -717,7 +724,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( options.algo_override = alg; options.scratch_size_override = miopen_alg.scratch_size(); Status launch_status = - RunGpuConv(instr, absl::MakeSpan(operand_buffers), result_buffer, + RunGpuConv(config, absl::MakeSpan(operand_buffers), result_buffer, &scratch_allocator, stream, options); if (!launch_status.ok()) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc index 5979ccbad58..e0ccbad3a01 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc @@ -79,15 +79,16 @@ Status RunGpuConvForward(GpuConvParams params, DeviceMemory filter_buf, DeviceMemory output_buf, AlgorithmConfig algorithm) { - if (params.conv_result_scale != 1) { + if (params.config.conv_result_scale != 1) { return InternalError( "StreamExecutor doesn't support scaled convolution: %lf.", - params.conv_result_scale); + params.config.conv_result_scale); } return stream->ConvolveWithAlgorithm( - params.input_descriptor, input_buf, params.filter_descriptor, filter_buf, - params.conv_desc, params.output_descriptor, &output_buf, - scratch_allocator, algorithm, options.profile_result); + params.config.input_descriptor, input_buf, + params.config.filter_descriptor, filter_buf, params.config.conv_desc, + params.config.output_descriptor, &output_buf, scratch_allocator, + algorithm, options.profile_result); } template @@ -102,13 +103,14 @@ Status RunGpuConvForwardActivation(GpuConvParams params, bias_desc.set_count(1) .set_height(1) .set_width(1) - .set_feature_map_count(params.output_descriptor.feature_map_count()) - .set_layout(params.output_descriptor.layout()); + .set_feature_map_count( + params.config.output_descriptor.feature_map_count()) + .set_layout(params.config.output_descriptor.layout()); se::DeviceMemory side_input(params.fusion->side_input_buf); // If there is no side input, use output as the side input. 
if (side_input.is_null()) { - if (params.fusion->side_input_scale != 0) { + if (params.config.fusion->side_input_scale != 0) { return InternalError( "Side input scale is not 0, yet no side input buffer is " "provided"); @@ -123,12 +125,13 @@ Status RunGpuConvForwardActivation(GpuConvParams params, } return stream->FusedConvolveWithAlgorithm( - params.input_descriptor, input_buf, params.conv_result_scale, - params.filter_descriptor, filter_buf, params.conv_desc, side_input, - params.fusion->side_input_scale, bias_desc, - DeviceMemory(params.fusion->bias_buf), params.fusion->mode, - params.output_descriptor, &output_buf, scratch_allocator, algorithm, - options.profile_result); + params.config.input_descriptor, input_buf, + params.config.conv_result_scale, params.config.filter_descriptor, + filter_buf, params.config.conv_desc, side_input, + params.config.fusion->side_input_scale, bias_desc, + DeviceMemory(params.fusion->bias_buf), + params.config.fusion->mode, params.config.output_descriptor, &output_buf, + scratch_allocator, algorithm, options.profile_result); } // StreamExecutor supports various data types via overloading, and the support @@ -149,31 +152,33 @@ Status RunGpuConvInternalImpl(GpuConvParams params, DeviceMemory filter_buf, DeviceMemory output_buf, AlgorithmConfig algorithm) { - switch (params.kind) { + switch (params.config.kind) { case CudnnConvKind::kForward: return RunGpuConvForward(params, scratch_allocator, stream, options, input_buf, filter_buf, output_buf, algorithm); case CudnnConvKind::kBackwardInput: - if (params.conv_result_scale != 1) { + if (params.config.conv_result_scale != 1) { return InternalError( "StreamExecutor doesn't support scaled convolution: %lf.", - params.conv_result_scale); + params.config.conv_result_scale); } return stream->ConvolveBackwardDataWithAlgorithm( - params.filter_descriptor, filter_buf, params.output_descriptor, - output_buf, params.conv_desc, params.input_descriptor, &input_buf, - scratch_allocator, algorithm, options.profile_result); + params.config.filter_descriptor, filter_buf, + params.config.output_descriptor, output_buf, params.config.conv_desc, + params.config.input_descriptor, &input_buf, scratch_allocator, + algorithm, options.profile_result); break; case CudnnConvKind::kBackwardFilter: - if (params.conv_result_scale != 1) { + if (params.config.conv_result_scale != 1) { return InternalError( "StreamExecutor doesn't support scaled convolution: %lf.", - params.conv_result_scale); + params.config.conv_result_scale); } return stream->ConvolveBackwardFilterWithAlgorithm( - params.input_descriptor, input_buf, params.output_descriptor, - output_buf, params.conv_desc, params.filter_descriptor, &filter_buf, - scratch_allocator, algorithm, options.profile_result); + params.config.input_descriptor, input_buf, + params.config.output_descriptor, output_buf, params.config.conv_desc, + params.config.filter_descriptor, &filter_buf, scratch_allocator, + algorithm, options.profile_result); break; case CudnnConvKind::kForwardActivation: { return RunGpuConvForwardActivation( @@ -195,7 +200,7 @@ Status RunGpuConvInternalImpl(GpuConvParams params, DeviceMemory filter_buf, DeviceMemory output_buf, AlgorithmConfig algorithm) { - switch (params.kind) { + switch (params.config.kind) { case CudnnConvKind::kForward: return RunGpuConvForward(params, scratch_allocator, stream, options, input_buf, filter_buf, output_buf, algorithm); @@ -218,7 +223,7 @@ Status RunGpuConvImpl(const GpuConvParams& params, auto input_buf = se::DeviceMemory(params.input_buf); 
auto filter_buf = se::DeviceMemory(params.filter_buf); auto output_buf = se::DeviceMemory(params.output_buf); - AlgorithmConfig algorithm = params.algorithm; + AlgorithmConfig algorithm = params.config.algorithm; if (options.algo_override.has_value()) { algorithm = AlgorithmConfig(*options.algo_override); @@ -238,7 +243,8 @@ Status RunGpuConvImpl(const GpuConvParams& params, if (!stream->ok()) { return InternalError( "Unable to launch convolution with type %s and algorithm (%d, %s)", - CudnnConvKindToString(params.kind), algorithm.algorithm()->algo_id(), + CudnnConvKindToString(params.config.kind), + algorithm.algorithm()->algo_id(), algorithm.algorithm_no_scratch().has_value() ? absl::StrCat(algorithm.algorithm_no_scratch()->algo_id()) : "none"); @@ -248,95 +254,83 @@ Status RunGpuConvImpl(const GpuConvParams& params, } // anonymous namespace -StatusOr GetGpuConvParams( - const HloCustomCallInstruction* conv, - absl::Span operand_buffers, - se::DeviceMemoryBase result_buffer) { - GpuConvParams params; +StatusOr GetGpuConvConfig( + const HloCustomCallInstruction* cudnn_call) { + GpuConvConfig config; + + config.input_type = cudnn_call->operand(0)->shape().element_type(); + config.output_type = cudnn_call->shape().tuple_shapes(0).element_type(); TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, - conv->backend_config()); - TF_ASSIGN_OR_RETURN(params.kind, GetCudnnConvKind(conv)); - const Shape* input_shape; - const Shape* filter_shape; - const Shape* output_shape; + cudnn_call->backend_config()); + TF_ASSIGN_OR_RETURN(config.kind, GetCudnnConvKind(cudnn_call)); // The third field is scratch size stored from conv_algorithm_picker // The operand is added to the shape field of the conv instruction // in GpuConvAlgorithmPicker::RunOnInstruction() call. 
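The comment above relies on the convention that each cuDNN convolution custom call returns a tuple of the convolution result and a scratch buffer, with the scratch size patched into the shape by the algorithm picker. A short sketch of reading it under that assumption:

    // Assumed result shape of the custom call: (conv_result, u8[scratch_bytes]).
    const Shape& result_shape = cudnn_call->shape().tuple_shapes(0);
    int64 scratch_bytes = cudnn_call->shape().tuple_shapes(1).dimensions(0);
    // scratch_bytes is what GetGpuConvConfig feeds into se::dnn::AlgorithmConfig below.
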
- params.algorithm = se::dnn::AlgorithmConfig( + config.algorithm = se::dnn::AlgorithmConfig( se::dnn::AlgorithmDesc(backend_config.algorithm(), backend_config.tensor_ops_enabled()), - conv->shape().tuple_shapes(1).dimensions(0)); - params.conv_result_scale = backend_config.conv_result_scale(); + cudnn_call->shape().tuple_shapes(1).dimensions(0)); + config.conv_result_scale = backend_config.conv_result_scale(); - switch (params.kind) { + Shape operand0_shape = cudnn_call->operand(0)->shape(); + Shape operand1_shape = cudnn_call->operand(1)->shape(); + Shape result_shape = cudnn_call->shape().tuple_shapes(0); + + switch (config.kind) { case CudnnConvKind::kForward: - input_shape = &conv->operand(0)->shape(); - filter_shape = &conv->operand(1)->shape(); - output_shape = &conv->shape().tuple_shapes(0); - params.input_buf = operand_buffers[0]; - params.filter_buf = operand_buffers[1]; - params.output_buf = result_buffer; + case CudnnConvKind::kForwardActivation: + config.input_shape = operand0_shape; + config.filter_shape = operand1_shape; + config.output_shape = result_shape; break; case CudnnConvKind::kBackwardInput: - input_shape = &conv->shape().tuple_shapes(0); - filter_shape = &conv->operand(1)->shape(); - output_shape = &conv->operand(0)->shape(); - params.input_buf = result_buffer; - params.filter_buf = operand_buffers[1]; - params.output_buf = operand_buffers[0]; + config.input_shape = result_shape; + config.filter_shape = operand1_shape; + config.output_shape = operand0_shape; break; case CudnnConvKind::kBackwardFilter: - input_shape = &conv->operand(0)->shape(); - filter_shape = &conv->shape().tuple_shapes(0); - output_shape = &conv->operand(1)->shape(); - params.input_buf = operand_buffers[0]; - params.filter_buf = result_buffer; - params.output_buf = operand_buffers[1]; + config.input_shape = operand0_shape; + config.filter_shape = result_shape; + config.output_shape = operand1_shape; break; - case CudnnConvKind::kForwardActivation: { - input_shape = &conv->operand(0)->shape(); - filter_shape = &conv->operand(1)->shape(); - output_shape = &conv->shape().tuple_shapes(0); - params.fusion.emplace(); - GpuConvParams::FusionParams& fusion = *params.fusion; - if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) { - return InternalError("Bad activation mode: %s", - backend_config.ShortDebugString()); - } - fusion.mode = static_cast( - backend_config.activation_mode()); - fusion.side_input_scale = backend_config.side_input_scale(); - params.input_buf = operand_buffers[0]; - params.filter_buf = operand_buffers[1]; - params.output_buf = result_buffer; - params.fusion->bias_buf = operand_buffers[2]; - if (operand_buffers.size() >= 4) { - params.fusion->side_input_buf = operand_buffers[3]; - } - } + default: + return InternalError("Unknown convolution kind"); } - const Window& window = conv->window(); + if (config.kind == CudnnConvKind::kForwardActivation) { + config.fusion.emplace(); + GpuConvConfig::FusionConfig& fusion = *config.fusion; + if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) { + return InternalError("Bad activation mode: %s", + backend_config.ShortDebugString()); + } + fusion.mode = + static_cast(backend_config.activation_mode()); + fusion.side_input_scale = backend_config.side_input_scale(); + } + + const Window& window = cudnn_call->window(); const ConvolutionDimensionNumbers& dnums = - conv->convolution_dimension_numbers(); + cudnn_call->convolution_dimension_numbers(); VLOG(3) << "Convolution Algorithm: " - << 
params.algorithm.algorithm()->algo_id(); + << config.algorithm.algorithm()->algo_id(); VLOG(3) << "tensor_ops_enabled: " - << params.algorithm.algorithm()->tensor_ops_enabled(); - VLOG(3) << "Convolution kind: " << CudnnConvKindToString(params.kind); - VLOG(3) << "input shape: " << ShapeUtil::HumanStringWithLayout(*input_shape); + << config.algorithm.algorithm()->tensor_ops_enabled(); + VLOG(3) << "Convolution kind: " << CudnnConvKindToString(config.kind); + VLOG(3) << "input shape: " + << ShapeUtil::HumanStringWithLayout(config.input_shape); VLOG(3) << "filter shape: " - << ShapeUtil::HumanStringWithLayout(*filter_shape); + << ShapeUtil::HumanStringWithLayout(config.filter_shape); VLOG(3) << "Output shape: " - << ShapeUtil::HumanStringWithLayout(*output_shape); + << ShapeUtil::HumanStringWithLayout(config.output_shape); VLOG(3) << "Window: { " << window.ShortDebugString() << " }"; VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }"; const int num_dimensions = window.dimensions_size(); - CHECK_LE(num_dimensions, 3) << conv->ToString(); + CHECK_LE(num_dimensions, 3) << cudnn_call->ToString(); // cuDNN does not support 1D convolutions. We therefore express 1D // convolutions as 2D convolutions where the first spatial dimension is 1. @@ -350,18 +344,18 @@ StatusOr GetGpuConvParams( window.dimensions_size() > 0 && window.dimensions()[0].window_reversal(); CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()) - << conv->ToString(); + << cudnn_call->ToString(); CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()) - << conv->ToString(); + << cudnn_call->ToString(); CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size()) - << conv->ToString(); + << cudnn_call->ToString(); for (const WindowDimension& dim : window.dimensions()) { - CHECK_EQ(dims_reversed, dim.window_reversal()) << conv->ToString(); - CHECK_EQ(dim.padding_low(), dim.padding_high()) << conv->ToString(); + CHECK_EQ(dims_reversed, dim.window_reversal()) << cudnn_call->ToString(); + CHECK_EQ(dim.padding_low(), dim.padding_high()) << cudnn_call->ToString(); CHECK_EQ(dim.base_dilation(), 1) << "cudnn does not support base dilation; it " "must be made explicit with a kPad: " - << conv->ToString(); + << cudnn_call->ToString(); } // cuDNN's convolution APIs support the BDYX layout for activations/output and @@ -370,12 +364,16 @@ StatusOr GetGpuConvParams( FilterLayout filter_dl; DataLayout output_dl; + const Shape* input_shape = &config.input_shape; + const Shape* filter_shape = &config.filter_shape; + const Shape* output_shape = &config.output_shape; + TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl), XlaConvLayoutsToStreamExecutorLayouts( dnums, input_shape->layout(), filter_shape->layout(), output_shape->layout())); - BatchDescriptor& input_descriptor = params.input_descriptor; + BatchDescriptor& input_descriptor = config.input_descriptor; input_descriptor = BatchDescriptor(effective_num_dimensions); input_descriptor.set_layout(input_dl) .set_feature_map_count( @@ -388,7 +386,7 @@ StatusOr GetGpuConvParams( input_shape->dimensions(dnums.input_spatial_dimensions(dim))); } - FilterDescriptor& filter_descriptor = params.filter_descriptor; + FilterDescriptor& filter_descriptor = config.filter_descriptor; filter_descriptor = FilterDescriptor(effective_num_dimensions); filter_descriptor.set_layout(filter_dl) .set_input_feature_map_count( @@ -401,11 +399,11 @@ StatusOr GetGpuConvParams( filter_shape->dimensions(dnums.kernel_spatial_dimensions(dim))); } - params.conv_desc = 
ConvolutionDescriptor(effective_num_dimensions); - params.conv_desc.set_group_count(conv->feature_group_count()); - params.conv_desc.set_convolution_not_crosscorr(dims_reversed); + config.conv_desc = ConvolutionDescriptor(effective_num_dimensions); + config.conv_desc.set_group_count(cudnn_call->feature_group_count()); + config.conv_desc.set_convolution_not_crosscorr(dims_reversed); for (int dim = 0; dim < num_dimensions; ++dim) { - params.conv_desc + config.conv_desc .set_zero_padding( static_cast(effective_num_dimensions - dim - 1), window.dimensions(dim).padding_low()) @@ -417,7 +415,7 @@ StatusOr GetGpuConvParams( window.dimensions(dim).window_dilation()); } - BatchDescriptor& output_descriptor = params.output_descriptor; + BatchDescriptor& output_descriptor = config.output_descriptor; output_descriptor = BatchDescriptor(effective_num_dimensions); output_descriptor.set_layout(output_dl) .set_feature_map_count( @@ -434,32 +432,70 @@ StatusOr GetGpuConvParams( input_descriptor.set_spatial_dim(static_cast(dim), 1); output_descriptor.set_spatial_dim(static_cast(dim), 1); filter_descriptor.set_spatial_dim(static_cast(dim), 1); - params.conv_desc.set_zero_padding(static_cast(dim), 0) + config.conv_desc.set_zero_padding(static_cast(dim), 0) .set_filter_stride(static_cast(dim), 1); } + return config; +} + +StatusOr GetGpuConvParams( + const GpuConvConfig& config, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer) { + GpuConvParams params; + params.config = config; + + switch (config.kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kForwardActivation: + params.input_buf = operand_buffers[0]; + params.filter_buf = operand_buffers[1]; + params.output_buf = result_buffer; + break; + case CudnnConvKind::kBackwardInput: + params.input_buf = result_buffer; + params.filter_buf = operand_buffers[1]; + params.output_buf = operand_buffers[0]; + break; + case CudnnConvKind::kBackwardFilter: + params.input_buf = operand_buffers[0]; + params.filter_buf = result_buffer; + params.output_buf = operand_buffers[1]; + break; + } + + if (config.kind == CudnnConvKind::kForwardActivation) { + params.fusion.emplace(); + GpuConvParams::FusionParams& fusion = *params.fusion; + fusion.bias_buf = operand_buffers[2]; + if (operand_buffers.size() >= 4) { + fusion.side_input_buf = operand_buffers[3]; + } + } + return params; } -Status RunGpuConv(const HloCustomCallInstruction* conv, +Status RunGpuConv(const gpu::GpuConvConfig& config, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::DeviceMemoryBase scratch_buf, se::Stream* stream, RunConvOptions options) { ScratchBufAllocator scratch_allocator(scratch_buf); - return RunGpuConv(conv, operand_buffers, result_buffer, &scratch_allocator, + return RunGpuConv(config, operand_buffers, result_buffer, &scratch_allocator, stream, options); } -Status RunGpuConv(const HloCustomCallInstruction* conv, +Status RunGpuConv(const gpu::GpuConvConfig& config, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::ScratchAllocator* scratch_allocator, se::Stream* stream, RunConvOptions options) { TF_ASSIGN_OR_RETURN(GpuConvParams params, - GetGpuConvParams(conv, operand_buffers, result_buffer)); + GetGpuConvParams(config, operand_buffers, result_buffer)); - PrimitiveType input_primitive_type = conv->operand(0)->shape().element_type(); + PrimitiveType input_primitive_type = config.input_type; switch (input_primitive_type) { case F16: return RunGpuConvImpl( @@ -471,8 +507,7 @@ Status RunGpuConv(const HloCustomCallInstruction* conv, 
return RunGpuConvImpl(params, scratch_allocator, stream, options); case S8: { - PrimitiveType output_primitive_type = - conv->shape().tuple_shapes(0).element_type(); + PrimitiveType output_primitive_type = config.output_type; switch (output_primitive_type) { case F32: return RunGpuConvImpl(params, scratch_allocator, @@ -481,12 +516,11 @@ Status RunGpuConv(const HloCustomCallInstruction* conv, return RunGpuConvImpl(params, scratch_allocator, stream, options); default: - return Unimplemented("Unimplemented convolution %s", - conv->ToString()); + return Unimplemented("Unimplemented convolution"); } } default: - return Unimplemented("Unimplemented convolution %s", conv->ToString()); + return Unimplemented("Unimplemented convolution"); } } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h index 3b8ce0f0f1c..5d27e6d6da7 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_RUNNER_H_ #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" @@ -25,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/dnn.h" namespace xla { namespace gpu { @@ -40,10 +42,10 @@ struct RunConvOptions { absl::optional scratch_size_override; }; -// Implementation struct exposed for debugging and log analysis. -struct GpuConvParams { - // Here are the fields related to cuDNN's fused convolution. The result thus - // is defined as: +// Structure to describe static properties of a GPU convolution. +struct GpuConvConfig { + // Field related to cuDNN's fused convolution are in FusionConfig & + // FusionParams structures. The result thus is defined as: // activation(conv_result_scale * conv(x, w) + // side_input_scale * side_input + broadcast(bias)) // @@ -54,23 +56,39 @@ struct GpuConvParams { // added to the final results. // // side_input_buf, if valid, must have the same shape as the output buffer. - struct FusionParams { + struct FusionConfig { se::dnn::ActivationMode mode; double side_input_scale; + }; + + PrimitiveType input_type; + PrimitiveType output_type; + CudnnConvKind kind; + se::dnn::AlgorithmConfig algorithm; + double conv_result_scale; + + se::dnn::BatchDescriptor input_descriptor; + se::dnn::FilterDescriptor filter_descriptor; + se::dnn::BatchDescriptor output_descriptor; + se::dnn::ConvolutionDescriptor conv_desc; + + Shape input_shape; + Shape filter_shape; + Shape output_shape; + absl::optional fusion; +}; + +// Implementation struct exposed for debugging and log analysis. 
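GpuConvConfig captures everything that can be derived from the HLO alone, while GpuConvParams additionally binds the device buffers for one particular execution. A minimal sketch of the intended call sequence, using only functions declared in this header:

    // Compile time / autotuning time: derive the static description once.
    TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(cudnn_call));

    // Run time: bind buffers for this execution. RunGpuConv does this
    // internally, so most callers only pass the config.
    TF_ASSIGN_OR_RETURN(
        GpuConvParams params,
        GetGpuConvParams(config, absl::MakeSpan(operand_buffers), result_buffer));
    TF_RETURN_IF_ERROR(RunGpuConv(config, absl::MakeSpan(operand_buffers),
                                  result_buffer, scratch_allocator, stream));
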
+struct GpuConvParams { + GpuConvConfig config; + struct FusionParams { se::DeviceMemoryBase bias_buf; se::DeviceMemoryBase side_input_buf; // nullable }; - CudnnConvKind kind; - se::dnn::BatchDescriptor input_descriptor; - se::dnn::FilterDescriptor filter_descriptor; - se::dnn::BatchDescriptor output_descriptor; se::DeviceMemoryBase input_buf; se::DeviceMemoryBase filter_buf; se::DeviceMemoryBase output_buf; - se::dnn::ConvolutionDescriptor conv_desc; - se::dnn::AlgorithmConfig algorithm; - double conv_result_scale; absl::optional fusion; }; @@ -89,21 +107,24 @@ struct GpuConvParams { // allocator and take note of how much memory is used. The next time you call // the same conv, you can provide an explicitly preallocated scratch buffer of // that size, if you like. -Status RunGpuConv(const HloCustomCallInstruction* conv, +Status RunGpuConv(const GpuConvConfig& conv_config, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::DeviceMemoryBase scratch_buf, se::Stream* stream, RunConvOptions = {}); -Status RunGpuConv(const HloCustomCallInstruction* conv, +Status RunGpuConv(const GpuConvConfig& conv_config, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer, se::ScratchAllocator* scratch_allocator, se::Stream* stream, RunConvOptions = {}); +StatusOr GetGpuConvConfig( + const HloCustomCallInstruction* cudnn_call); + // Implementation details exposed for debugging and log analysis. StatusOr GetGpuConvParams( - const HloCustomCallInstruction* conv, + const GpuConvConfig& conv_config, absl::Span operand_buffers, se::DeviceMemoryBase result_buffer); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_device_info.h b/tensorflow/compiler/xla/service/gpu/gpu_device_info.h index 7352bad1a66..afb773c4527 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_device_info.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_device_info.h @@ -32,6 +32,8 @@ struct GpuDeviceInfo { int threads_per_block_limit; int threads_per_warp; int shared_memory_per_block; + int threads_per_core_limit; + int core_count; }; } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index c963dfb2b2a..1a0d1e0beb6 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -450,8 +450,7 @@ StatusOr GpuExecutable::ExecuteAsyncOnStream( HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); const Shape& root_shape = root->shape(); auto device_ordinal = executor->device_ordinal(); - ExecutionOutput result(/*on_host_shape=*/root->shape(), - /*on_device_shape=*/root->shape(), memory_allocator, + ExecutionOutput result(/*on_device_shape=*/root->shape(), memory_allocator, device_ordinal); TF_ASSIGN_OR_RETURN(BufferAllocations buffer_allocations, diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc index 1f83ec71984..e73c4885e9e 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc @@ -90,6 +90,15 @@ void HloExecutionProfiler::FinishHloComputation( } } +void HloExecutionProfiler::FinishHloComputation( + absl::optional profile_index) { + if (do_profile_) { + profile_->SetCyclesTakenBy( + profile_index.value(), + GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_)); + } +} + void HloExecutionProfiler::StartHloInstruction() { if 
(do_profile_) { InitAndStartTimer(&timers_, stream_); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h index 1189143e3f9..860fa167790 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h @@ -53,6 +53,11 @@ class HloExecutionProfiler { // the time that the computation took to execute in the profile. void FinishHloComputation(const HloComputation* computation); + // If profiling is enabled stops the timer for a (sub)computation with the + // given profile index and records the time that the computation took to + // execute in the profile. + void FinishHloComputation(absl::optional profile_index); + // If profiling is enabled, starts a per-operation timer. void StartHloInstruction(); diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index 5fe459a70bc..dc3a0c788ac 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -26,14 +26,13 @@ InfeedThunk::InfeedThunk( ThunkInfo thunk_info, const ShapeTree& infeed_slices) : Thunk(Kind::kInfeed, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), infeed_slices_(infeed_slices) {} Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) { auto& stream = *params.stream; auto& buffer_allocations = *params.buffer_allocations; - VLOG(2) << "Infeeding to GPU: " << hlo_instruction_->ToString(); + VLOG(2) << "Infeeding to GPU"; auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index ab410661ba1..ec33235c466 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -43,7 +43,6 @@ class InfeedThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; const ShapeTree infeed_slices_; }; diff --git a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc index 154612824ef..4f4409ab896 100644 --- a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc +++ b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc @@ -35,8 +35,9 @@ XLAThunksDialect::XLAThunksDialect(MLIRContext *context) >(); } +} // namespace xla_thunks + #define GET_OP_CLASSES #include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc.inc" -} // namespace xla_thunks } // namespace mlir diff --git a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h index ede9adb9ab1..bc0da6a8fc8 100644 --- a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h +++ b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h @@ -33,10 +33,11 @@ class XLAThunksDialect : public Dialect { static StringRef getDialectNamespace() { return "xla_thunks"; } }; +} // namespace xla_thunks + #define GET_OP_CLASSES #include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h.inc" -} // namespace xla_thunks } // namespace mlir #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_XLA_THUNKS_OPS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 9d4ec358bd3..7743d19497d 100644 --- 
a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -472,39 +472,6 @@ StatusOr GetCudnnConvKind( return InternalError("Unexpected call target: %s", target); } -StatusOr GetDnnConvolutionKind( - const HloCustomCallInstruction* instr) { - absl::string_view target = instr->custom_call_target(); - if (target == kCudnnConvForwardCallTarget) { - return se::dnn::ConvolutionKind::FORWARD; - } - if (target == kCudnnConvBackwardInputCallTarget) { - return se::dnn::ConvolutionKind::BACKWARD_DATA; - } - if (target == kCudnnConvBackwardFilterCallTarget) { - return se::dnn::ConvolutionKind::BACKWARD_FILTER; - } - return InternalError("Unexpected call target: %s", target); -} - -StatusOr GetDnnDataType( - const HloCustomCallInstruction* conv) { - PrimitiveType output_primitive_type = - conv->shape().tuple_shapes(0).element_type(); - switch (output_primitive_type) { - case F16: - return se::dnn::ToDataType::value; - case F32: - return se::dnn::ToDataType::value; - case F64: - return se::dnn::ToDataType::value; - default: - break; - } - return InternalError("Unsupported convolution datatype : %s", - conv->ToString()); -} - string CudnnConvKindToString(CudnnConvKind kind) { switch (kind) { case CudnnConvKind::kForward: diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 6f731b2936f..a782eb3f507 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -55,12 +55,6 @@ enum class CudnnConvKind { StatusOr GetCudnnConvKind(const HloCustomCallInstruction* instr); -StatusOr GetDnnConvolutionKind( - const HloCustomCallInstruction* instr); - -StatusOr GetDnnDataType( - const HloCustomCallInstruction* conv); - // Converts a CudnnConvKind value to a string. string CudnnConvKindToString(CudnnConvKind kind); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index f7627c348b6..d33474f83c2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -41,6 +41,7 @@ limitations under the License. 
#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "tensorflow/compiler/mlir/utils/name_utils.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" @@ -201,12 +202,29 @@ StatusOr GetAllocationSliceForMlir( "StaticMemRefCastOp(ViewOp(arg))"); } -absl::string_view GetHloName(mlir::Operation* op) { - if (auto attr = op->getAttrOfType("name")) { - auto ref = attr.getValue(); - return absl::string_view(ref.data(), ref.size()); +StatusOr> GetMlirBufferSlices( + mlir::Operation* op, mlir::OperandRange operands, + absl::Span allocations) { + const auto buffer_is_written = [op](mlir::Value operand) { + llvm::SmallVector effects; + mlir::cast(op).getEffectsOnValue(operand, + effects); + return absl::c_any_of( + effects, [](const mlir::MemoryEffects::EffectInstance& instance) { + return mlir::isa(instance.getEffect()); + }); + }; + + std::vector slices; + for (mlir::Value operand : operands) { + slices.emplace_back(); + auto& slice = slices.back(); + TF_ASSIGN_OR_RETURN(slice.buffer_slice, + GetAllocationSliceForMlir(operand, allocations)); + slice.written = buffer_is_written(operand); + slice.shape = TypeToShape(operand.getType()); } - return ""; + return slices; } } // namespace @@ -1366,13 +1384,6 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { TF_ASSIGN_OR_RETURN(auto sort_op, lhlo_scratch_emitter_.EmitSortOp(sort)); result.op = sort_op; - result.name = GetHloName(sort_op); - // The name in sort op has no semantics, and it's for debug only. If the name - // doesn't exist, we should use a namer (e.g. count-based). - // TODO(timshen): use a namer instead of relying on the HloInstruction names. - if (result.name.empty()) { - result.name = sort->name(); - } const auto& buffer_assignment = ir_emitter_context_->buffer_assignment(); auto& slice = result.extra_slice; TF_ASSIGN_OR_RETURN(slice.buffer_slice, @@ -1385,48 +1396,30 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { return EmitSortFromMlir(result); } -Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) { +Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) { absl::Span allocations( ir_emitter_context_->buffer_assignment().Allocations()); - auto sort_op = mlir::cast(input.op); - - int operand_count = sort_op.operands().size(); - std::vector operand_shapes(operand_count); - std::vector slices; - std::vector output_shapes(sort_op.output().size()); - - for (int i = 0; i < operand_count; i++) { - operand_shapes[i] = - TypeToShape(sort_op.operands()[i].getType().cast()); - } - - // Craft n + 1 slices, where the first n are output parameters, and the last - // is the on-device tuple storage. We don't need n operands because sorting - // kernels are always in-place. 
- for (int i = 0; i < operand_count; i++) { - output_shapes[i] = - TypeToShape(sort_op.output()[i].getType().cast()); - MlirBufferSlice slice; - TF_ASSIGN_OR_RETURN( - slice.buffer_slice, - GetAllocationSliceForMlir(sort_op.output()[i], allocations)); - slice.written = true; - slice.shape = operand_shapes[i]; - slices.push_back(slice); - } - slices.push_back(input.extra_slice); + auto sort_op = mlir::cast(mlir_input.op); + std::string name = mlir::GetNameFromLoc(sort_op.getLoc()); + TF_ASSIGN_OR_RETURN( + std::vector operands, + GetMlirBufferSlices(sort_op, sort_op.operands(), allocations)); + TF_ASSIGN_OR_RETURN( + std::vector outputs, + GetMlirBufferSlices(sort_op, sort_op.output(), allocations)); + outputs.push_back(mlir_input.extra_slice); std::vector> thunks; - Shape keys_shape = operand_shapes[0]; + Shape keys_shape = operands[0].shape; int64 dimension_to_sort = sort_op.dimension(); - for (int64 i = 0; i < operand_count; ++i) { + for (int64 i = 0; i < operands.size(); ++i) { // We assume that the layout of all involved operands and outputs is the // same. TF_RET_CHECK( - LayoutUtil::LayoutsInShapesEqual(keys_shape, operand_shapes[i])); + LayoutUtil::LayoutsInShapesEqual(keys_shape, operands[i].shape)); TF_RET_CHECK( - LayoutUtil::LayoutsInShapesEqual(keys_shape, output_shapes[i])); + LayoutUtil::LayoutsInShapesEqual(keys_shape, outputs[i].shape)); // If possible, we share buffers. If that is not possible, we need to copy // the values, because the emitter does the sorting in-place. @@ -1439,18 +1432,18 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) { if (destination_buffer != source_address) { // TODO(b/26783907): Figure out why we never seem to share buffers for // key/value sort. - VLOG(2) << input.name << " requires initial D2D copy for operand " << i; + VLOG(2) << name << " requires initial D2D copy for operand " << i; thunks.push_back(absl::make_unique( Thunk::ThunkInfo(), /*source_address=*/source_address, /*destination_buffer=*/destination_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(operand_shapes[i]))); + /*mem_size=*/ShapeUtil::ByteSizeOf(operands[i].shape))); } } uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); - VLOG(2) << input.name << " requires " << num_stages << " stages."; + VLOG(2) << name << " requires " << num_stages << " stages."; CHECK_GE(1ULL << num_stages, dimension_to_sort_bound); CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound); @@ -1514,10 +1507,10 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) { // we have not enough threads, or not enough shared memory. Also it does not // give a speedup if the tile size is < 128. int64 total_shared_memory_needed = 0; - for (int64 i = 0; i < operand_count; ++i) { + for (int64 i = 0; i < operands.size(); ++i) { total_shared_memory_needed += kTileSize * - ShapeUtil::ByteSizeOfPrimitiveType(operand_shapes[i].element_type()); + ShapeUtil::ByteSizeOfPrimitiveType(operands[i].shape.element_type()); } bool no_tiling = kTileSize < 128 || @@ -1530,7 +1523,7 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) { "kTileSize=%d < 128, " "kThreadsPerBlock=%d > threads_per_block_limit=%d, " "total_shared_memory_needed=%d > shared_memory_per_block=%d", - input.name, (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock, + name, (no_tiling ? 
"won't" : "will"), kTileSize, kThreadsPerBlock, ir_emitter_context_->gpu_device_info().threads_per_block_limit, total_shared_memory_needed, ir_emitter_context_->gpu_device_info().shared_memory_per_block); @@ -1538,32 +1531,32 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) { uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock); LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock); VLOG(2) << absl::StreamFormat("%s launch dims: %d blocks, %d threads/block", - input.name, num_blocks, kThreadsPerBlock); + name, num_blocks, kThreadsPerBlock); std::vector ir_arrays; auto emit_kernel = [&](absl::Span xor_masks) { VLOG(2) << absl::StreamFormat( - "%s uses kernel for xor masks [%s]", input.name, + "%s uses kernel for xor masks [%s]", name, absl::StrJoin(xor_masks, ", ", [](std::string* out, int64 xor_mask) { absl::StrAppendFormat(out, "0x%x", xor_mask); })); - thunks.push_back(BuildKernelThunkForMlir(input.name, Thunk::ThunkInfo(), - slices, &ir_arrays)); + thunks.push_back( + BuildKernelThunkForMlir(name, Thunk::ThunkInfo(), outputs, &ir_arrays)); LaunchDimensions launch_dimensions = xor_masks.size() > 1 ? tiled_launch_dimensions : standard_launch_dimensions; UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), ir_emitter_context_->llvm_module()); std::vector values_arrays; - values_arrays.reserve(operand_count); - for (int64 i = 0; i < operand_count; ++i) { + values_arrays.reserve(operands.size()); + for (int64 i = 0; i < operands.size(); ++i) { values_arrays.push_back(ir_arrays[i]); } TF_ASSIGN_OR_RETURN( const HloComputation* comparator, GetOrCreateSubComputationFromRegion(&sort_op.comparator())); return llvm_ir::EmitSortInPlace( - dimension_to_sort, values_arrays, IrName(input.name), xor_masks, &b_, + dimension_to_sort, values_arrays, IrName(name), xor_masks, &b_, launch_dimensions, xor_masks.size() > 1 ? num_iterations_in_sort_dim : standard_num_iterations_in_sort_dim, @@ -1596,17 +1589,16 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput input) { TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); } VLOG(2) << absl::StreamFormat( - "%s requires %d thunks (including any D2D copies)", input.name, - thunks.size()); + "%s requires %d thunks (including any D2D copies)", name, thunks.size()); - AddThunkToThunkSequence( - absl::make_unique(input.thunk_info, std::move(thunks))); - if (operand_count > 1) { + AddThunkToThunkSequence(absl::make_unique( + mlir_input.thunk_info, std::move(thunks))); + if (operands.size() > 1) { // Emit the tuple as part of the last stage of sorting. // We are currently in the block sorted.in_bounds.after. b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); llvm_ir::EmitTuple( - ir_arrays[operand_count], + ir_arrays.back(), absl::MakeSpan(ir_arrays).subspan(0, ir_arrays.size() - 1), &b_); } return Status::OK(); @@ -1625,9 +1617,10 @@ Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) { } Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) { + CollectivePermuteConfig config = GetCollectivePermuteConfig(hlo); AddThunkToThunkSequence(absl::make_unique( - GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)), - GetAllocationSlice(*hlo))); + GetThunkInfo(hlo), std::move(config), + GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo))); return Status::OK(); } @@ -1659,9 +1652,10 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { *crs, crs->shape().IsTuple() ? 
ShapeIndex({i}) : ShapeIndex({})); tuple_element_buffers.push_back(buffers[i].destination_buffer); } + NcclAllReduceConfig config = + GetNcclAllReduceConfig(crs, hlo_module_config_.replica_count()); auto all_reduce_thunk = absl::make_unique( - GetThunkInfo(crs), - /*replica_count=*/hlo_module_config_.replica_count(), + GetThunkInfo(crs), std::move(config), /*buffers=*/std::move(buffers)); if (crs->shape().IsTuple()) { std::vector> thunks; @@ -2253,11 +2247,19 @@ StatusOr> IrEmitterUnnested::BuildWhileThunk( IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_)); TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get())); + const auto* index_map = ir_emitter_context_->profile_index_map(); + absl::optional condition_profile_index, body_profile_index; + if (index_map) { + condition_profile_index = index_map->GetProfileIndexFor(*condition); + body_profile_index = index_map->GetProfileIndexFor(*body); + } + return std::unique_ptr(new WhileThunk( GetThunkInfo(hlo), GetAllocationSlice(*condition->root_instruction()), // cond result ir_emitter_condition->ConsumeThunkSequence(), - ir_emitter_body->ConsumeThunkSequence())); + ir_emitter_body->ConsumeThunkSequence(), condition_profile_index, + body_profile_index)); } StatusOr> IrEmitterUnnested::BuildForThunk( @@ -2273,8 +2275,15 @@ StatusOr> IrEmitterUnnested::BuildForThunk( IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_)); TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get())); + const auto* index_map = ir_emitter_context_->profile_index_map(); + absl::optional body_profile_index; + if (index_map) { + body_profile_index = index_map->GetProfileIndexFor(*body); + } + return std::unique_ptr(new ForThunk( - GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence())); + GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence(), + body_profile_index)); } StatusOr> IrEmitterUnnested::BuildConditionalThunk( @@ -2286,7 +2295,15 @@ StatusOr> IrEmitterUnnested::BuildConditionalThunk( std::vector branch_operands; std::vector branch_thunks; - for (int j = 0; j < hlo->branch_count(); ++j) { + std::vector> branch_profile_indices; + + int branch_count = hlo->branch_count(); + branch_thunks.reserve(branch_count); + branch_profile_indices.reserve(branch_count); + + const auto* index_map = ir_emitter_context_->profile_index_map(); + + for (int j = 0; j < branch_count; ++j) { branch_operands.emplace_back(GetAllocationSlice(*hlo->operand(j + 1))); HloComputation* branch_computation = hlo->branch_computation(j); TF_ASSIGN_OR_RETURN( @@ -2295,17 +2312,25 @@ StatusOr> IrEmitterUnnested::BuildConditionalThunk( ir_emitter_context_)); TF_CHECK_OK(branch_computation->Accept(ir_emitter.get())); branch_thunks.push_back(std::move(*ir_emitter->ConsumeThunkSequence())); + + absl::optional profile_index; + if (index_map) { + profile_index = index_map->GetProfileIndexFor(*branch_computation); + } + branch_profile_indices.push_back(profile_index); } + ConditionalThunkConfig config = GetConditionalThunkConfig( + hlo, std::move(branch_thunks), std::move(branch_profile_indices)); return std::unique_ptr(new ConditionalThunk( - GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)), branch_operands, - std::move(branch_thunks))); + GetThunkInfo(hlo), std::move(config), + GetAllocationSlice(*hlo->operand(0)), branch_operands)); } Status IrEmitterUnnested::EmitTargetElementLoopInThunk( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk, - int unroll_factor) { + int 
unroll_factor, bool few_waves) { VLOG(3) << bindings_.ToString(); bool multi_output = hlo.shape().IsTuple(); @@ -2316,7 +2341,8 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( << ShapeUtil::HumanStringWithLayout(hlo.shape()) << " for unroll_factor " << unroll_factor; LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor); + element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor, + few_waves); UpdateLaunchDimensions(launch_dimensions, thunk, ir_emitter_context_->llvm_module()); if (!multi_output) { @@ -2402,8 +2428,27 @@ Status IrEmitterUnnested::EmitTargetElementLoop( std::unique_ptr kernel_thunk = BuildKernelThunk(&hlo, /*implements_whole_instruction=*/true); + + // Check if we want to schedule grid size that has fewer SM waves. + // This speed up computations in some cases. + bool few_waves = false; + auto few_waves_allow_instr = [](const HloInstruction* instr) { + return instr->IsElementwise() || instr->opcode() == HloOpcode::kParameter || + // We need to make the codegen broadcast aware before enabling + // more broadcast pattern. + (instr->opcode() == HloOpcode::kBroadcast && + instr->dimensions().empty()); + }; + if (hlo.opcode() == HloOpcode::kFusion) { + few_waves = + absl::c_all_of(hlo.fused_instructions_computation()->instructions(), + few_waves_allow_instr); + } else { + few_waves = few_waves_allow_instr(&hlo); + } + Status emit_status = EmitTargetElementLoopInThunk( - hlo, body_emitter, kernel_thunk.get(), unroll_factor); + hlo, body_emitter, kernel_thunk.get(), unroll_factor, few_waves); thunk_sequence_.emplace_back(std::move(kernel_thunk)); return emit_status; @@ -2887,7 +2932,7 @@ void IrEmitterUnnested::EmitEpilogueForReduction( current_output); llvm::Value* warp_id = b_.CreateUDiv(thread_id_info.thread_id_x, constant(kWarpSize)); - ksl.If(is_zero(thread_id_info.lane_id), [&] { + ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] { llvm::Value* shmem_output_addr = shared_to_global(b_.CreateInBoundsGEP( shared_cache, {b_.getInt32(0), constant(j), warp_id})); @@ -2895,7 +2940,7 @@ void IrEmitterUnnested::EmitEpilogueForReduction( }); EmitSyncThreads(); - ksl.If(is_zero(warp_id), [&] { + ksl.If("inter_warp_reduce", is_zero(warp_id), [&] { llvm::Value* block_accum_addr = shared_to_global(b_.CreateInBoundsGEP( shared_cache, {b_.getInt32(0), constant(j), thread_id_info.lane_id})); @@ -2915,10 +2960,11 @@ void IrEmitterUnnested::EmitEpilogueForReduction( EmitFullWarpShuffleDownLoopForReduce( reducers[i], element_type, /*block_accum_addr*/ selected_value); - ksl.If(is_zero(thread_id_info.thread_id_x), [&] { - TF_CHECK_OK(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, block_accum_addr)); - }); + ksl.If("reduction_atomic_update", is_zero(thread_id_info.thread_id_x), + [&] { + TF_CHECK_OK(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, block_accum_addr)); + }); }); } else { @@ -2953,10 +2999,11 @@ void IrEmitterUnnested::EmitEpilogueForReduction( b_.CreateICmpULT(thread_id_info.thread_id_x, tiling_kernel_info.output_tile_bounds[kDimY])); - ksl.If(b_.CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] { - TF_CHECK_OK(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, shmem_transposed_addr)); - }); + ksl.If("reduction_atomic_update", + b_.CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] { + TF_CHECK_OK(EmitAtomicOperationForNestedComputation( + *reducers[i], 
output_address, shmem_transposed_addr)); + }); } } } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index c36f0b7840d..5cc5e206167 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -58,7 +58,6 @@ struct MlirBufferSlice : public BufferSlice { struct MlirEmitterInput { mlir::Operation* op; - absl::string_view name; Thunk::ThunkInfo thunk_info; MlirBufferSlice extra_slice; }; @@ -161,7 +160,7 @@ class IrEmitterUnnested : public IrEmitter, Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; Status HandleSort(HloInstruction* sort) override; - Status EmitSortFromMlir(MlirEmitterInput input); + Status EmitSortFromMlir(MlirEmitterInput mlir_input); Status HandleTriangularSolve(HloInstruction* hlo) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleAllReduce(HloInstruction* crs) override; @@ -178,7 +177,7 @@ class IrEmitterUnnested : public IrEmitter, // `unroll_factor` is greater than one. Status EmitTargetElementLoopInThunk( const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter, - KernelThunk* thunk, int unroll_factor); + KernelThunk* thunk, int unroll_factor, bool few_waves = false); Status Postprocess(HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc index c23e8112cb0..5dbbb2d65da 100644 --- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc +++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc @@ -56,7 +56,7 @@ static int64 ThreadsPerBlockLimit(GpuDeviceInfo gpu_device_info) { // Calculates the launch dimensions used to invoke `hlo`. 
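The few_waves path added just below sizes the grid from the device rather than from the element count: threads_per_block is capped at 128 and block_count becomes core_count * (threads_per_core_limit / threads_per_block). On a hypothetical device with core_count = 80 and threads_per_core_limit = 2048, that yields 80 * (2048 / 128) = 1280 blocks, i.e. just enough blocks to fill one wave across the device, regardless of how many elements the kernel covers.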
LaunchDimensions CalculateLaunchDimensions(const Shape& shape, GpuDeviceInfo gpu_device_info, - int unroll_factor) { + int unroll_factor, bool few_waves) { int64 num_elements = ShapeUtil::ElementsIn(shape); if (num_elements <= 1) { return LaunchDimensions(); @@ -90,6 +90,11 @@ LaunchDimensions CalculateLaunchDimensions(const Shape& shape, } int64 block_count = CeilOfRatio(num_elements, threads_per_block); + if (few_waves) { + threads_per_block = std::min(threads_per_block, int64{128}); + block_count = gpu_device_info.core_count * + (gpu_device_info.threads_per_core_limit / threads_per_block); + } VLOG(2) << absl::StrFormat( "Initialized the block count to ceil(# of elements / threads per " "block) = ceil(%d/%d) = %d", diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h index dbe5a037e43..1472141a80e 100644 --- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h +++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h @@ -67,7 +67,8 @@ std::ostream& operator<<(std::ostream& out, LaunchDimensions CalculateLaunchDimensions(const Shape& shape, GpuDeviceInfo gpu_device_info, - int unroll_factor = 1); + int unroll_factor = 1, + bool few_waves = false); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.h b/tensorflow/compiler/xla/service/gpu/memset_thunk.h index 8a1890a0769..fb18b7041b7 100644 --- a/tensorflow/compiler/xla/service/gpu/memset_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.h @@ -55,7 +55,7 @@ class Memset32BitValueThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - uint32 value_; + const uint32 value_; const BufferAllocation::Slice dest_; }; diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index fcbd9e760c6..4ee78b0874c 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -104,6 +104,13 @@ HloInstruction* SelectPreferredFusionCandidate( std::vector GetProducerConsumerMultiOutputFusionCandidates( const HloInstruction* producer, const HloReachabilityMap& reachability) { std::vector fusion_candidates; + // If there is only one user, and it is not a multi-output fusion node, this + // fusion possibility was already considered and rejected by the FusionMerger + // pass. No need to try again! + if (producer->user_count() == 1 && + !producer->users()[0]->IsMultiOutputFusion()) { + return fusion_candidates; + } for (HloInstruction* consumer : producer->users()) { VLOG(3) << "Looking at producer " << producer->name() << " and its consumer " << consumer->name(); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 4d269665b42..283eb30d15f 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -798,6 +798,86 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) { // Check that we don't fuse too many reductions together. 
TEST_F(MultiOutputFusionTest, SharedMemoryBudget) { auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"( + fused_computation0 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } + fused_computation1 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } + fused_computation2 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } + fused_computation3 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } + fused_computation4 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } + fused_computation5 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } + fused_computation6 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } + fused_computation7 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } + fused_computation8 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } + fused_computation9 { + p0 = f32[64,64] parameter(0) + p1 = f32[64,64] parameter(1) + p2 = f32[] parameter(2) + add = f32[64,64] add(p0, p1) + ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0}, + to_apply=scalar_add_computation + } ENTRY computation { zero = f32[] constant(0) param0 = f32[64,64] parameter(0) @@ -810,36 +890,16 @@ TEST_F(MultiOutputFusionTest, SharedMemoryBudget) { param7 = f32[64,64] parameter(7) param8 = f32[64,64] parameter(8) param9 = f32[64,64] parameter(9) - add0 = f32[64,64] add(param0, param1) - add1 = f32[64,64] add(param1, param2) - add2 = f32[64,64] add(param2, param3) - add3 = f32[64,64] add(param3, param4) - add4 = f32[64,64] add(param4, param5) - add5 = f32[64,64] add(param5, param6) - add6 = f32[64,64] add(param6, param7) - add7 = f32[64,64] add(param7, param8) - add8 = f32[64,64] add(param8, param9) - add9 = f32[64,64] add(param9, param0) - out0 = f32[64] reduce(f32[64,64] add0, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation - out1 = f32[64] reduce(f32[64,64] add1, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation - out2 = f32[64] 
reduce(f32[64,64] add2, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation - out3 = f32[64] reduce(f32[64,64] add3, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation - out4 = f32[64] reduce(f32[64,64] add4, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation - out5 = f32[64] reduce(f32[64,64] add5, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation - out6 = f32[64] reduce(f32[64,64] add6, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation - out7 = f32[64] reduce(f32[64,64] add7, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation - out8 = f32[64] reduce(f32[64,64] add8, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation - out9 = f32[64] reduce(f32[64,64] add9, f32[] zero), dimensions={0}, - to_apply=scalar_add_computation + out0 = f32[64] fusion(param0, param1, zero), kind=kInput, calls=fused_computation0 + out1 = f32[64] fusion(param1, param2, zero), kind=kInput, calls=fused_computation1 + out2 = f32[64] fusion(param2, param3, zero), kind=kInput, calls=fused_computation2 + out3 = f32[64] fusion(param3, param4, zero), kind=kInput, calls=fused_computation3 + out4 = f32[64] fusion(param4, param5, zero), kind=kInput, calls=fused_computation4 + out5 = f32[64] fusion(param5, param6, zero), kind=kInput, calls=fused_computation5 + out6 = f32[64] fusion(param6, param7, zero), kind=kInput, calls=fused_computation6 + out7 = f32[64] fusion(param7, param8, zero), kind=kInput, calls=fused_computation7 + out8 = f32[64] fusion(param8, param9, zero), kind=kInput, calls=fused_computation8 + out9 = f32[64] fusion(param9, param0, zero), kind=kInput, calls=fused_computation9 ROOT out = (f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64]) tuple(f32[64] out0, f32[64] out1, f32[64] out2, f32[64] out3, f32[64] out4, f32[64] out5, f32[64] out6, f32[64] out7, f32[64] out8, f32[64] out9) } )")) diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc index 25ab9a7ce6e..b13f71c5a13 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -514,11 +514,40 @@ void RendezvousNcclAllReduce::CleanupImpl(std::shared_ptr handle, // header. In particular, this stores the thunk's cache of all NcclCliques it's // ever used. This causes those cliques to stay alive as long as the thunk // lives, which is how we avoid expensive reinitialization of NCCL cliques. 
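// A minimal sketch (not the real XLA types) of the refactoring introduced just below
// in nccl_all_reduce_thunk.cc: the data the all-reduce thunk needs at runtime is
// captured once into a plain config struct, so NcclAllReduceThunk no longer keeps a
// pointer back into the HLO graph. The struct and field names here are hypothetical
// simplifications of NcclAllReduceConfig and GetNcclAllReduceConfig.
#include <cstdint>
#include <optional>
#include <vector>

enum class CollectiveOpKindSketch { kCrossModule, kCrossReplica };

// Only the fields the config-building function reads off the instruction.
struct AllReduceInstrView {
  std::vector<int> operand_element_types;  // PrimitiveType stand-ins.
  std::optional<int64_t> channel_id;       // Present for cross-module all-reduces.
  int64_t module_unique_id = 0;
};

struct AllReduceConfigSketch {
  int64_t operand_count = 0;
  std::vector<int> operand_element_type;
  int64_t replica_count = 0;
  CollectiveOpKindSketch collective_op_kind = CollectiveOpKindSketch::kCrossReplica;
  int64_t op_id = 0;
};

// Mirrors the shape of GetNcclAllReduceConfig: everything is resolved at thunk
// construction time, so ExecuteOnStream never touches the HloInstruction again.
AllReduceConfigSketch MakeConfig(const AllReduceInstrView& instr,
                                 int64_t replica_count) {
  AllReduceConfigSketch config;
  config.operand_count =
      static_cast<int64_t>(instr.operand_element_types.size());
  config.operand_element_type = instr.operand_element_types;
  config.replica_count = replica_count;
  if (instr.channel_id.has_value()) {
    config.collective_op_kind = CollectiveOpKindSketch::kCrossModule;
    config.op_id = *instr.channel_id;
  } else {
    config.collective_op_kind = CollectiveOpKindSketch::kCrossReplica;
    config.op_id = instr.module_unique_id;
  }
  return config;
}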
-struct NcclAllReduceThunk::AuxData { +struct NcclAllReduceConfig::AuxData { tensorflow::mutex mu; absl::flat_hash_set> cliques TF_GUARDED_BY(mu); }; +NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig&&) = default; +NcclAllReduceConfig::~NcclAllReduceConfig() = default; + +NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr, + int64 replica_count) { + NcclAllReduceConfig config; + config.operand_count = instr->operands().size(); + config.operand_element_type.reserve(config.operand_count); + for (int i = 0; i < config.operand_count; i++) { + config.operand_element_type.push_back( + instr->operand(i)->shape().element_type()); + } + config.replica_count = replica_count; + config.replica_groups = instr->replica_groups(); + auto reduction_kind = MatchReductionComputation(instr->to_apply()); + CHECK(reduction_kind.has_value()); + config.reduction_kind = reduction_kind.value(); + + if (instr->channel_id().has_value()) { + config.collective_op_kind = RendezvousKey::kCrossModule; + config.op_id = instr->channel_id().value(); + } else { + config.collective_op_kind = RendezvousKey::kCrossReplica; + config.op_id = static_cast(instr->GetModule()->unique_id()); + } + config.aux_data = std::make_unique(); + return config; +} + /*static*/ bool NcclAllReduceThunk::CanImplement(const HloInstruction* crs) { auto operands_are_supported = [crs]() { return absl::c_all_of(crs->operands(), [](HloInstruction* operand) { @@ -541,14 +570,12 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() { } NcclAllReduceThunk::NcclAllReduceThunk( - ThunkInfo thunk_info, int64 replica_count, + ThunkInfo thunk_info, NcclAllReduceConfig&& config, std::vector buffers) : Thunk(Thunk::kNcclAllReduce, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), - replica_count_(replica_count), - buffers_(std::move(buffers)), - aux_data_(absl::make_unique()) { - CHECK_EQ(hlo_instruction_->operand_count(), buffers_.size()); + config_(std::move(config)), + buffers_(std::move(buffers)) { + CHECK_EQ(config_.operand_count, buffers_.size()); } // Figures out which devices (named by their replica-ids) are participating in @@ -558,7 +585,6 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); - auto* instr = Cast(hlo_instruction_); int64 local_device_ordinal = params.stream->parent()->device_ordinal(); GlobalDeviceId global_device_id; if (params.gpu_global_device_ids) { @@ -574,10 +600,10 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { // the same collective group as the caller. TF_ASSIGN_OR_RETURN( std::vector global_participating_replicas, - GetParticipatingReplicas(global_device_id, instr->replica_groups(), - replica_count_, *params.device_assn)); + GetParticipatingReplicas(global_device_id, config_.replica_groups, + config_.replica_count, *params.device_assn)); if (IsGlobalNcclConfig() && - global_participating_replicas.size() != replica_count_) { + global_participating_replicas.size() != config_.replica_count) { return InvalidArgument( "Partial replica groups are not allowed when using NCCL_COMM_ID " "environment configuration."); @@ -605,10 +631,10 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { } absl::c_sort(global_devices); - // Find or create the rendezvous for this collective operation. 
- RendezvousKey rendezvous_key = RendezvousKey::FromInstruction( - params.run_id, global_devices, local_devices.size(), hlo_instruction_); - + // Create the rendezvous for this collective operation. + RendezvousKey rendezvous_key(params.run_id, global_devices, + local_devices.size(), config_.collective_op_kind, + config_.op_id); if (VLOG_IS_ON(2)) { std::vector local_participants; for (const auto& entry : local_devices) { @@ -633,15 +659,12 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { params.buffer_allocations->GetDeviceAddress(buffer.source_buffer); pbuffer.destination_data = params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer); - pbuffer.primitive_type = - hlo_instruction_->operand(i)->shape().element_type(); + pbuffer.primitive_type = config_.operand_element_type[i]; participant.buffers.push_back(pbuffer); } participant.local_devices = std::move(local_devices); participant.nccl_unique_id_callback = params.nccl_unique_id_callback; - auto reduction_kind = MatchReductionComputation(hlo_instruction_->to_apply()); - CHECK(reduction_kind.has_value()); - participant.reduction_kind = *reduction_kind; + participant.reduction_kind = config_.reduction_kind; auto rendezvous_factory = [](const RendezvousKey& k) { return absl::make_unique(k); @@ -658,13 +681,11 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { // Keep the clique we used alive for as long as this Thunk lives. Creating // new NCCL cliques is expensive, and this is how we avoid thrashing them. { - tensorflow::mutex_lock lock(aux_data_->mu); - aux_data_->cliques.insert(std::move(clique)); + tensorflow::mutex_lock lock(config_.aux_data->mu); + config_.aux_data->cliques.insert(std::move(clique)); } return Status::OK(); } -NcclAllReduceThunk::~NcclAllReduceThunk() {} - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h index cbd4fd3aa51..20e4adef7b1 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h @@ -18,11 +18,13 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/collective_ops_utils.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -30,6 +32,30 @@ limitations under the License. namespace xla { namespace gpu { +struct NcclAllReduceConfig { + int64 operand_count; + std::vector operand_element_type; + int64 replica_count; + std::vector replica_groups; + ReductionKind reduction_kind; + RendezvousKey::CollectiveOpKind collective_op_kind; + int64 op_id; + + NcclAllReduceConfig() = default; + NcclAllReduceConfig(NcclAllReduceConfig &&); + ~NcclAllReduceConfig(); + + // Extra data stored in NcclAllReduceThunk whose types we don't want exposed + // in the header file. 
(This is mainly because the implementation of + // NcclAllReduceThunk is different depending on whether CUDA is enabled in the + // build, and we don't want to expose *that* mess in the header.) + struct AuxData; + std::unique_ptr aux_data; +}; + +NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr, + int64 replica_count); + // Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas. class NcclAllReduceThunk : public Thunk { public: @@ -56,9 +82,8 @@ class NcclAllReduceThunk : public Thunk { BufferAllocation::Slice source_buffer; BufferAllocation::Slice destination_buffer; }; - NcclAllReduceThunk(ThunkInfo thunk_info, int64 replica_count, + NcclAllReduceThunk(ThunkInfo thunk_info, NcclAllReduceConfig &&config, std::vector buffers); - ~NcclAllReduceThunk() override; Status ExecuteOnStream(const ExecuteParams& params) override; @@ -67,16 +92,8 @@ class NcclAllReduceThunk : public Thunk { static bool CanImplement(const HloInstruction* crs); private: - // Extra data stored in NcclAllReduceThunk whose types we don't want exposed - // in the header file. (This is mainly because the implementation of - // NcclAllReduceThunk is different depending on whether CUDA is enabled in the - // build, and we don't want to expose *that* mess in the header.) - struct AuxData; - - const HloInstruction* hlo_instruction_; - const int64 replica_count_; + const NcclAllReduceConfig config_; const std::vector buffers_; - std::unique_ptr aux_data_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc index 83066a4addf..6eef1b9f0b9 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc @@ -14,26 +14,34 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h" + #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace gpu { -OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, +OutfeedConfig GetOutfeedConfig(const HloInstruction* instr) { + OutfeedConfig config; + config.input_shape = instr->operand(0)->shape(); + return config; +} + +OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config, ShapeTree outfeed_slices) : Thunk(Kind::kOutfeed, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), + config_(std::move(config)), outfeed_slices_(std::move(outfeed_slices)) {} Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { auto& stream = *params.stream; auto& buffer_allocations = *params.buffer_allocations; - VLOG(2) << "Outfeeding from GPU: " << hlo_instruction_->ToString(); + VLOG(2) << "Outfeeding from GPU"; auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); @@ -42,13 +50,12 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { outfeed_manager->BlockingGetNextDestination(); // Nothing to be done for empty tuples. 
- if (ShapeUtil::IsEmptyTuple(hlo_instruction_->operand(0)->shape())) { + if (ShapeUtil::IsEmptyTuple(config_.input_shape)) { return Status::OK(); } - CHECK(ShapeUtil::Compatible(hlo_instruction_->operand(0)->shape(), - outfeed_buffers->shape())) + CHECK(ShapeUtil::Compatible(config_.input_shape, outfeed_buffers->shape())) << "XLA program outfeed request of shape " - << hlo_instruction_->operand(0)->shape().ToString() + << config_.input_shape.ToString() << " did not match the runtime's outfeed buffer of shape " << outfeed_buffers->shape().ToString(); diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h index 9174e605783..60c64858ee7 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h @@ -25,6 +25,12 @@ limitations under the License. namespace xla { namespace gpu { +struct OutfeedConfig { + Shape input_shape; +}; + +OutfeedConfig GetOutfeedConfig(const HloInstruction* instr); + // A thunk that outfeeds data. Data must be already resident on the host. This // thunk performs a host to device copy from the buffer allocated for the // outfeed op to the host location. @@ -32,7 +38,7 @@ class OutfeedThunk : public Thunk { public: // Constructs a OutfeedThunk that copies data to the host-side // outfeed queue from the buffers in the given shape tree. - OutfeedThunk(ThunkInfo thunk_info, + OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config, ShapeTree outfeed_slices); OutfeedThunk(const OutfeedThunk&) = delete; @@ -41,7 +47,7 @@ class OutfeedThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; + const OutfeedConfig config_; const ShapeTree outfeed_slices_; }; diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index 6b7b31e8288..45c4f25d8e8 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/gpu/target_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -58,7 +59,8 @@ ParallelLoopEmitter::ParallelLoopEmitter( std::vector ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, - llvm::Type* index_type) { + llvm::Type* index_type, + llvm::Value* base_index) { // Emit the following code in LLVM IR: // linear_index = blockIdx.x * blockDim.x + threadIdx.x; // if (linear_index < num_elements) { @@ -122,6 +124,12 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, "linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true); } + if (base_index != nullptr) { + linear_index_base = + b_->CreateAdd(linear_index_base, base_index, "linear_index_plus_base", + /*HasNUW=*/true, /*HasNSW=*/true); + } + array_indices.emplace_back(linear_index_base, shape_, b_); for (int i = 1; i < unroll_factor_; ++i) { llvm::Value* linear_index = @@ -147,5 +155,43 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, return array_indices; } +Status ParallelLoopEmitter::EmitLoop(absl::string_view loop_name, + llvm::Type* index_type) { + if (index_type == nullptr) { + index_type = b_->getInt64Ty(); + } + int64 total_threads = launch_dimensions_.launch_bound(); + int64 num_elements = ShapeUtil::ElementsIn(shape_); + // If all the elements are handled by the current threads, no need + // to add a loop inside the kernel. + if (total_threads * unroll_factor_ >= num_elements) { + VLOG(1) << "ParallelLoopEmitter::EmitLoop fallback"; + return LoopEmitter::EmitLoop(loop_name, index_type); + } + + KernelSupportLibrary ksl(b_, llvm_ir::UnrollMode::kDefaultUnroll); + auto constant = [&](int64 val) { + return llvm::ConstantInt::get(index_type, val); + }; + + TF_RETURN_IF_ERROR(ksl.ForWithStatus( + "loop", constant(0), constant(num_elements), + constant(total_threads * unroll_factor_), [&](llvm::Value* base_indvar) { + for (const llvm_ir::IrArray::Index& array_index : + EmitIndexAndSetExitBasicBlock(loop_name, index_type, + base_indvar)) { + TF_RETURN_IF_ERROR(body_emitter_(array_index)); + } + return Status::OK(); + })); + + // Set the insertion point of b_ to the loop exit, so that + // code emitted for later instructions will be correctly placed. + if (exit_bb_ != nullptr) { + b_->SetInsertPoint(exit_bb_); + } + return Status::OK(); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index 0a6b5430b23..5e142ec3832 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -57,7 +57,11 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - absl::string_view loop_name, llvm::Type* index_type) override; + absl::string_view loop_name, llvm::Type* index_type, + llvm::Value* base_index) override; + + Status EmitLoop(absl::string_view loop_name = "", + llvm::Type* index_type = nullptr); private: // The thread and block dimension to parallelize the loop on. 
diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index 8ea7c57c978..7293b1485fc 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -319,5 +319,35 @@ void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type, } } +StatusOr GetDNNConvKindFromCudnnConvKind( + CudnnConvKind kind) { + switch (kind) { + case CudnnConvKind::kBackwardFilter: + return se::dnn::BACKWARD_FILTER; + case CudnnConvKind::kBackwardInput: + return se::dnn::BACKWARD_DATA; + case CudnnConvKind::kForward: + return se::dnn::FORWARD; + default: + break; + } + return InternalError("Unexpected convolution kind"); +} + +StatusOr GetDNNDataTypeFromPrimitiveType( + PrimitiveType type) { + switch (type) { + case F16: + return se::dnn::ToDataType::value; + case F32: + return se::dnn::ToDataType::value; + case F64: + return se::dnn::ToDataType::value; + default: + break; + } + return InternalError("Unsupported convolution datatype"); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 6696d1957b3..2b58496e05c 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -19,6 +19,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/layout.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" @@ -86,6 +87,10 @@ se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config); void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type, int64* rng_state, se::DeviceMemoryBase buffer); +StatusOr GetDNNConvKindFromCudnnConvKind( + CudnnConvKind kind); +StatusOr GetDNNDataTypeFromPrimitiveType(PrimitiveType type); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc index 2f139563b4a..200829efddb 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -148,8 +148,8 @@ TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedSine) { HloModule test_module ENTRY SineFunc { - p0 = f32[160000]{0} parameter(0) - ROOT s = f32[160000]{0} sine(p0) + p0 = f32[1600000]{0} parameter(0) + ROOT s = f32[1600000]{0} sine(p0) })"; auto hlo_module = ParseAndReturnVerifiedModule(kUnfusedAddModule, config).ValueOrDie(); @@ -182,8 +182,8 @@ TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedCosine) { HloModule test_module ENTRY SineFunc { - p0 = f32[160000]{0} parameter(0) - ROOT s = f32[160000]{0} cosine(p0) + p0 = f32[1600000]{0} parameter(0) + ROOT s = f32[1600000]{0} cosine(p0) })"; auto hlo_module = ParseAndReturnVerifiedModule(kUnfusedAddModule, config).ValueOrDie(); @@ -216,8 +216,8 @@ TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedPower) { HloModule test_module ENTRY SineFunc { - p0 = f32[160000]{0} parameter(0) - ROOT s = f32[160000]{0} power(p0, p0) + p0 = f32[1600000]{0} parameter(0) + ROOT s = f32[1600000]{0} power(p0, p0) })"; auto hlo_module = 
ParseAndReturnVerifiedModule(kUnfusedAddModule, config).ValueOrDie(); @@ -241,8 +241,8 @@ TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedAtan2) { HloModule test_module ENTRY SineFunc { - p0 = f32[160000]{0} parameter(0) - ROOT s = f32[160000]{0} atan2(p0, p0) + p0 = f32[16000000]{0} parameter(0) + ROOT s = f32[16000000]{0} atan2(p0, p0) })"; auto hlo_module = ParseAndReturnVerifiedModule(kUnfusedAddModule, config).ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 7a9fedec629..64b685db379 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -69,10 +69,6 @@ class Thunk { }; struct ThunkInfo { - // Optional. It's only used by subclasses which haven't been migrated away - // from HloInstructions. Once the migration is done, Thunks should be fully - // serializable. - const HloInstruction* hlo_instruction = nullptr; absl::optional profile_index; std::string profile_annotation; }; diff --git a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc index 690d0c9de56..4c6c5bb846d 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc @@ -19,9 +19,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h" #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h" #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h" @@ -37,6 +39,65 @@ limitations under the License. namespace xla { namespace gpu { +namespace { +void CheckBatchNormInputOutputPrimitivetypeAreValid(const HloInstruction* hlo) { + // All input and output statistics variables must be F32. Also, the last + // operand for CudnnBatchNormForwardInference, CudnnBatchNormForwardTraining, + // and CudnnBatchNormBackward is the feature_index which must be S64. + // The allowed types for non-statistics variables are as follows: + // CudnnBatchNormForwardInference: + // operand[0]: {half, float} + // out[0]: {half, float} + // CudnnBatchNormForwardTraining: + // operand[0]: {half, float} + // out[0]: {half, float} + // CudnnBatchNormBackward: + // operand[0]: {half, float} + // operand[4]: {half, float} + // out[0]: {half, float} + // Note non-statistics inputs and outputs mentioned above should be of the + // same type. + + // Check Inputs. + int64 num_operands = hlo->operand_count(); + PrimitiveType operand_primitive_type = + hlo->operand(0)->shape().element_type(); + CHECK(operand_primitive_type == F16 || operand_primitive_type == F32) + << "Not yet implemented"; + + for (int i = 1; i < num_operands - 2; i++) { + if (hlo->custom_call_target() == kCudnnBatchNormBackwardCallTarget && + i == 4) { + // The first operand to batchnorm grad is the input and the 4th operand is + // the grad_output, both of which can be Eigen::half. 
+ CHECK_EQ(hlo->operand(i)->shape().element_type(), operand_primitive_type) + << "Invalid datatype"; + continue; + } + CHECK_EQ(hlo->operand(i)->shape().element_type(), F32) + << "Not yet implemented"; + } + + // The last operand is the feature index which must be int64. + CHECK_EQ(hlo->operand(num_operands - 1)->shape().element_type(), S64) + << "Not yet implemented"; + + // Check Outputs. + if (hlo->shape().IsTuple()) { + CHECK_EQ(hlo->shape().tuple_shapes(0).element_type(), + operand_primitive_type) + << "Invalid datatype"; + + for (int j = 1; j < hlo->shape().tuple_shapes_size(); j++) { + CHECK_EQ(hlo->shape().tuple_shapes(j).element_type(), F32) + << "Not yet implemented"; + } + } else { + CHECK_EQ(hlo->shape().element_type(), operand_primitive_type) + << "Invalid datatype"; + } +} +} // namespace std::unique_ptr ThunkEmitter::BuildFftThunk(const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); return absl::make_unique( @@ -72,15 +133,14 @@ std::unique_ptr ThunkEmitter::BuildTriangularSolveThunk( std::unique_ptr ThunkEmitter::BuildGemmThunk( const HloInstruction* inst) { - auto config_or = inst->backend_config(); - GemmBackendConfig gemm_config = std::move(config_or.ValueOrDie()); + GpuGemmConfig config = GetGpuGemmConfig(inst); const HloInstruction* lhs = inst->operand(0); const HloInstruction* rhs = inst->operand(1); // The bias is passed inside the output buffer. If those buffers are shared // we can just use it, otherwise copy the bias values into the output buffer // first. - if (gemm_config.beta() != 0.0) { + if (config.backend_config.beta() != 0.0) { const HloInstruction* bias = inst->operand(2); CHECK_EQ(bias->shape(), inst->shape()); if (GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) { @@ -91,22 +151,22 @@ std::unique_ptr ThunkEmitter::BuildGemmThunk( /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape()))); thunks.push_back(absl::make_unique( - context_->GetThunkInfo(inst), + context_->GetThunkInfo(inst), std::move(config), GetAllocationSlice(*lhs), // The buffer assigned to LHS. GetAllocationSlice(*rhs), // The buffer assigned to RHS. GetAllocationSlice(*inst), // The output buffer. - /*implements_whole_instruction=*/false, std::move(gemm_config))); + /*implements_whole_instruction=*/false)); return absl::make_unique(context_->GetThunkInfo(inst), std::move(thunks)); } } return absl::make_unique( - context_->GetThunkInfo(inst), + context_->GetThunkInfo(inst), std::move(config), GetAllocationSlice(*lhs), // The buffer assigned to LHS. GetAllocationSlice(*rhs), // The buffer assigned to RHS. GetAllocationSlice(*inst), // The output buffer. 
- /*implements_whole_instruction=*/true, std::move(gemm_config)); + /*implements_whole_instruction=*/true); } std::unique_ptr ThunkEmitter::BuildInfeedThunk( @@ -133,8 +193,9 @@ std::unique_ptr ThunkEmitter::BuildOutfeedThunk( *slice = status_or_slice.ValueOrDie(); } }); + OutfeedConfig config = GetOutfeedConfig(inst); return absl::make_unique(context_->GetThunkInfo(inst), - std::move(slices)); + std::move(config), std::move(slices)); } Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { @@ -154,16 +215,20 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { CHECK(feature_index->IsConstant()); int64 feature_index_value = feature_index->literal().Get({}); + CHECK_EQ(custom_call->shape().tuple_shapes_size(), 3); + CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape().tuple_shapes(0), + custom_call->operand(0)->shape())); + CheckBatchNormInputOutputPrimitivetypeAreValid(custom_call); + CudnnBatchNormConfig config = GetCudnnBatchNormConfig( + custom_call, epsilon_value, feature_index_value); AddThunkToThunkSequence( absl::make_unique( - context_->GetThunkInfo(custom_call), + context_->GetThunkInfo(custom_call), std::move(config), /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*offset=*/GetAllocationSlice(*custom_call->operand(2)), /*mean=*/GetAllocationSlice(*custom_call->operand(3)), /*variance=*/GetAllocationSlice(*custom_call->operand(4)), - /*epsilon=*/epsilon_value, - /*feature_index=*/feature_index_value, /*output=*/GetAllocationSlice(*custom_call))); return Status::OK(); } @@ -183,14 +248,14 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { auto output_data = GetAllocationSlice(*custom_call, {0}); auto output_mean = GetAllocationSlice(*custom_call, {1}); auto output_inv_stddev = GetAllocationSlice(*custom_call, {2}); + CudnnBatchNormConfig config = GetCudnnBatchNormConfig( + custom_call, epsilon_value, feature_index_value); AddThunkToThunkSequence( absl::make_unique( - context_->GetThunkInfo(custom_call), + context_->GetThunkInfo(custom_call), std::move(config), /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*offset=*/GetAllocationSlice(*custom_call->operand(2)), - /*epsilon=*/epsilon_value, - /*feature_index=*/feature_index_value, /*output_data=*/output_data, /*output_mean=*/output_mean, /*output_inv_stddev=*/output_inv_stddev, @@ -212,15 +277,22 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { auto output_grad_data = GetAllocationSlice(*custom_call, {0}); auto output_grad_scale = GetAllocationSlice(*custom_call, {1}); auto output_grad_offset = GetAllocationSlice(*custom_call, {2}); + CHECK_EQ(custom_call->shape().tuple_shapes_size(), 3); + CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape().tuple_shapes(0), + custom_call->operand(0)->shape())); + CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape().tuple_shapes(0), + custom_call->operand(4)->shape())); + CheckBatchNormInputOutputPrimitivetypeAreValid(custom_call); + + CudnnBatchNormConfig config = GetCudnnBatchNormConfig( + custom_call, epsilon_value, feature_index_value); AddThunkToThunkSequence(absl::make_unique( - context_->GetThunkInfo(custom_call), + context_->GetThunkInfo(custom_call), std::move(config), /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*mean=*/GetAllocationSlice(*custom_call->operand(2)), 
/*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), - /*epsilon=*/epsilon_value, - /*feature_index=*/feature_index_value, /*output_grad_data=*/output_grad_data, /*output_grad_scale=*/output_grad_scale, /*output_grad_offset=*/output_grad_offset, @@ -238,9 +310,13 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { auto conv_result_slice = GetAllocationSlice(*custom_call, {0}); auto scratch_slice = GetAllocationSlice(*custom_call, {1}); + TF_ASSIGN_OR_RETURN( + GpuConvConfig config, + GetGpuConvConfig(Cast(custom_call))); AddThunkToThunkSequence(absl::make_unique( - context_->GetThunkInfo(custom_call), std::move(operand_slices), - conv_result_slice, scratch_slice, tuple_result_slice)); + context_->GetThunkInfo(custom_call), std::move(config), + std::move(operand_slices), conv_result_slice, scratch_slice, + tuple_result_slice)); return Status::OK(); } @@ -310,11 +386,26 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { return slices; }; std::vector> operand_slices; - for (const auto* operand : custom_call->operands()) { + for (int64 i = 0; i < custom_call->operand_count(); i++) { + const auto* operand = custom_call->operand(i); operand_slices.push_back(get_slices_for_instr(operand)); + const auto& s1 = operand_slices.back().shape(); + const auto& s2 = operand->shape(); + CHECK(ShapeUtil::Equal(s1, s2)) << absl::StreamFormat( + "Shape mismatch between operand shape and " + "slice shape for operand %d: %s vs %s", + i, s1.ToString(), s2.ToString()); } ShapeTree result_slices = get_slices_for_instr(custom_call); + CHECK(ShapeUtil::Equal(custom_call->shape(), result_slices.shape())) + << absl::StreamFormat( + "Shape mismatch between instr->shape() and " + "result_slices.shape(): " + "%s vs %s.", + custom_call->shape().ToString(), + result_slices.shape().ToString()); + AddThunkToThunkSequence(absl::make_unique( context_->GetThunkInfo(custom_call), call_target, std::move(operand_slices), std::move(result_slices), @@ -385,7 +476,6 @@ Thunk::ThunkInfo ThunkEmitter::EmissionContext::GetThunkInfo( const HloInstruction* hlo) const { CHECK(hlo); Thunk::ThunkInfo info; - info.hlo_instruction = hlo; info.profile_annotation = absl::StrFormat( "Thunk:#hlo_op=%s,hlo_module=%s#", hlo->name(), hlo->GetModule()->name()); return info; diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index 792479df4ac..6397ad3bee0 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -27,9 +27,10 @@ WhileThunk::WhileThunk( ThunkInfo thunk_info, const BufferAllocation::Slice& condition_result_buffer_index, std::unique_ptr condition_thunk_sequence, - std::unique_ptr body_thunk_sequence) + std::unique_ptr body_thunk_sequence, + absl::optional condition_profile_index, + absl::optional body_profile_index) : Thunk(Kind::kWhile, thunk_info), - hlo_instruction_(thunk_info.hlo_instruction), condition_result_buffer_index_(condition_result_buffer_index), // Pass nullptr as the HloInstruction* to the condition_thunk_sequence_ // and body_thunk_sequence_ constructors because these SequentialThunks @@ -38,7 +39,9 @@ WhileThunk::WhileThunk( condition_thunk_sequence_(absl::make_unique( ThunkInfo(), std::move(*condition_thunk_sequence))), body_thunk_sequence_(absl::make_unique( - ThunkInfo(), std::move(*body_thunk_sequence))) {} + ThunkInfo(), std::move(*body_thunk_sequence))), + 
condition_profile_index_(condition_profile_index), + body_profile_index_(body_profile_index) {} Status WhileThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { @@ -62,7 +65,7 @@ Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) { profiler.StartHloComputation(); VLOG(3) << "Executing condition computation"; TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(params)); - profiler.FinishHloComputation(hlo_instruction_->while_condition()); + profiler.FinishHloComputation(condition_profile_index_); // Copy the result of condition computation and break the loop if 'false'. bool condition_result; @@ -86,7 +89,7 @@ Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) { // Invoke thunk sequence for while 'body' computation, and pass on // 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'. TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params)); - profiler.FinishHloComputation(hlo_instruction_->while_body()); + profiler.FinishHloComputation(body_profile_index_); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index 707bac15bb2..707edbdc192 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -42,7 +42,9 @@ class WhileThunk : public Thunk { WhileThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& condition_result_buffer_index, std::unique_ptr condition_thunk_sequence, - std::unique_ptr body_thunk_sequence); + std::unique_ptr body_thunk_sequence, + absl::optional condition_profile_index, + absl::optional body_profile_index); WhileThunk(const WhileThunk&) = delete; WhileThunk& operator=(const WhileThunk&) = delete; @@ -51,10 +53,11 @@ class WhileThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: - const HloInstruction* hlo_instruction_; const BufferAllocation::Slice condition_result_buffer_index_; std::unique_ptr condition_thunk_sequence_; std::unique_ptr body_thunk_sequence_; + const absl::optional condition_profile_index_; + const absl::optional body_profile_index_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 38c4da1f182..6323d0903a4 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -93,10 +93,13 @@ HloComputation::HloComputation( } HloInstruction* HloComputation::AddInstruction( - std::unique_ptr instruction) { + std::unique_ptr instruction, const std::string& new_name) { CHECK(instruction->opcode() != HloOpcode::kParameter) << "Parameter instructions cannot be added to a computation after " << "it has been built"; + if (!new_name.empty()) { + instruction->SetAndSanitizeName(new_name); + } return AddInstructionInternal(std::move(instruction)); } @@ -315,6 +318,8 @@ Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction, (*inst_it->second)->set_parent(nullptr); to_be_deleted_.emplace_back(inst_it->second->release()); to_be_deleted_.back()->DetachFromOperandsAndUsers(); + // Clear all operands to avoid Null operands. 
+ to_be_deleted_.back()->RemoveAllOperands(); to_be_deleted_.back()->MarkAsDead(); instructions_.erase(inst_it->second); instruction_iterators_.erase(inst_it); diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 1dcf1d9d7d3..d618a527070 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -122,7 +122,8 @@ class HloComputation { // Add an instruction to the computation. The computation takes ownership of // the instruction. - HloInstruction* AddInstruction(std::unique_ptr instruction); + HloInstruction* AddInstruction(std::unique_ptr instruction, + const std::string& new_name = ""); // Remove the param_no'th parameter from the computation. // Note this is only applicatable to the computation for the fusion diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index b91ec9d86ee..4fb7edd0104 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -2300,8 +2300,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector input_index(operand_shape.dimensions_size()); std::vector update_index(updates_shape.dimensions_size()); - std::vector input_scatter_index_clamped( - operand_shape.dimensions_size()); UpdateScatterIndexToInputIndex update_scatter_index_to_input_index( &scatter->scatter_dimension_numbers(), /*input_shape=*/operand_shape, @@ -2790,7 +2788,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // bound, call `f` with the base index. static void IterateThroughWindow( const Shape& window_shape, const Window& window, const Shape& base_shape, - const absl::Span& window_count_index, + const absl::Span window_count_index, const std::function&)>& f) { const int64 rank = base_shape.rank(); DimensionVector window_index(rank); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index e21ae719e4d..9675a2f0f0d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1911,6 +1911,8 @@ class HloInstruction { // by factory methods. 
HloInstruction(HloOpcode opcode, const Shape& shape); + void RemoveAllOperands() { operands_.clear(); } + void RemoveOperandAt(int index) { operands_.erase(operands_.begin() + index); } diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 749193a83ef..3c44b390969 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -387,6 +387,12 @@ TokKind HloLexer::LexNumberOrPattern() { return TokKind::kNegInf; } + static LazyRE2 neg_nan = {"-nan"}; + if (RE2::Consume(&consumable, *neg_nan)) { + current_ptr_ = consumable.begin(); + return TokKind::kNegNan; + } + return TokKind::kError; } @@ -502,6 +508,8 @@ string TokKindToString(TokKind kind) { return "kw_nan"; case TokKind::kw_inf: return "kw_inf"; + case TokKind::kNegNan: + return "kNegNan"; case TokKind::kNegInf: return "kNegInf"; case TokKind::kPrimitiveType: diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index b8c7debaab4..4068ad76581 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -65,6 +65,7 @@ enum class TokKind { kw_nan, kw_inf, + kNegNan, // -nan kNegInf, // -inf // Typed tokens. diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 37bdeaa1073..d04a7695f3c 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -2717,6 +2717,7 @@ bool HloParserImpl::ParseDenseLiteral(Literal* literal, const Shape& shape) { case TokKind::kInt: case TokKind::kDecimal: case TokKind::kw_nan: + case TokKind::kNegNan: case TokKind::kw_inf: case TokKind::kNegInf: { add_one_elem_seen(); @@ -4374,6 +4375,9 @@ bool HloParserImpl::ParseDouble(double* result) { case TokKind::kw_nan: *result = std::numeric_limits::quiet_NaN(); break; + case TokKind::kNegNan: + *result = -std::numeric_limits::quiet_NaN(); + break; case TokKind::kw_inf: *result = std::numeric_limits::infinity(); break; diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index d220d735622..3cb9a1c564b 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -2120,6 +2120,19 @@ ENTRY %ShortConstant.v4 () -> f32[67,89] { EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original); } +TEST_F(HloParserTest, NegativeNan) { + const string original = R"(HloModule NegativeNan_module + +ENTRY %NegativeNan () -> bf16[2] { + ROOT %constant = bf16[2]{0} constant({-nan, -nan}) +} + +)"; + auto result = ParseAndReturnUnverifiedModule(original); + EXPECT_EQ(Status::OK(), result.status()); + EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original); +} + TEST_F(HloParserTest, AttributesAnyOrder) { const string original = R"(HloModule any_order_module diff --git a/tensorflow/compiler/xla/service/hlo_pass_fix.h b/tensorflow/compiler/xla/service/hlo_pass_fix.h index 1de231a9a86..1533c53ba45 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_fix.h +++ b/tensorflow/compiler/xla/service/hlo_pass_fix.h @@ -48,7 +48,8 @@ class HloPassFix : public Pass { ++iteration_count; if (iteration_count == kLimit) { VLOG(1) << "Unexpectedly high number of iterations in HLO passes '" - << Pass::name() << "' exiting fixed point loop."; + << Pass::name() << "' for module '" << module->name() + << "'. 
Exiting fixed point loop."; // Return false in case this is fixed point is nested. return false; } diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 3b7b0b61f0a..74c385f16bd 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -58,6 +58,7 @@ StatusOr HloPassPipeline::RunPassesInternal( TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name)); bool changed = false; for (HloPassInterface* pass : passes) { + XLA_SCOPED_LOGGING_TIMER(absl::StrCat("HLO pass: ", pass->name())); absl::string_view pass_name = pass->name(); VLOG(1) << " HLO pass " << pass_name; VLOG(2) << " Module hash " << hlo->Hash(); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index e9d6191f945..59b1ac31e9b 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -1521,12 +1521,13 @@ StatusOr CompressInstruction(MemoryUsageTracker* memory_tracker, << ") to" << compact_shape.ToString(true); HloComputation* computation = best->parent(); - HloInstruction* compressed = computation->AddInstruction( - HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best)); + HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best), + /*new_name=*/best->name() + ".remat_compressed"); HloInstruction* uncompressed = computation->AddInstruction( - HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed)); + HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed), + /*new_name=*/best->name() + ".remat_uncompressed"); Item* compressed_item = instruction_list->CreateItem(compressed); compressed_item->placed = true; diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis.cc b/tensorflow/compiler/xla/service/hlo_replication_analysis.cc index dec119d8aba..2d1edbefd97 100644 --- a/tensorflow/compiler/xla/service/hlo_replication_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_replication_analysis.cc @@ -129,6 +129,13 @@ bool DetermineHloInstructionIsReplicated( return true; } + if (hlo->opcode() == HloOpcode::kCustomCall && + (hlo->custom_call_target() == "X64SplitLow" || + hlo->custom_call_target() == "X64SplitHigh" || + hlo->custom_call_target() == "X64Combine")) { + return all_operands_replicated(hlo); + } + if (hlo->IsElementwise() || // hlo->opcode() == HloOpcode::kConcatenate || // hlo->opcode() == HloOpcode::kConvolution || // diff --git a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc index c2d86e808c2..cc0f4c86f4d 100644 --- a/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_replication_analysis_test.cc @@ -501,6 +501,36 @@ ENTRY entry { FindInstruction(module.get(), "conditional"), {1})); } +TEST_F(HloReplicationAnalysisTest, X64SplitCombine) { + const string module_str = R"( +HloModule SimpleTupleSelect + +ENTRY entry { + param = (f64[]) parameter(0) + gte = f64[] get-tuple-element(param), index=0 + param-low = f32[] custom-call(gte), custom_call_target="X64SplitLow" + param-high = f32[] custom-call(gte), custom_call_target="X64SplitHigh" + ROOT result-combine = f64[] custom-call(param-low, param-high), custom_call_target="X64Combine" +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + 
ParseAndReturnVerifiedModule(module_str)); + auto param = module->entry_computation()->parameter_instruction(0); + param->set_parameter_replicated_at_leaf_buffers(absl::Span{true}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, + HloReplicationAnalysis::Run( + module.get(), /*cross_partition_spmd=*/false)); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "gte"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "param-low"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "param-high"), {})); + EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( + FindInstruction(module.get(), "result-combine"), {})); +} + TEST_F(HloReplicationAnalysisTest, SimpleTupleSelect) { const string module_str = R"( HloModule SimpleTupleSelect diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 3a5e7ca6f40..0d71c6d49ed 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -211,8 +211,7 @@ static std::vector ExecutionInputsFromScopedShapedBuffers( *buffer_tree.mutable_element(index) = execution_input_buffer; } }); - execution_inputs.emplace_back(std::move(buffer_tree), - input_buffer.on_host_shape()); + execution_inputs.emplace_back(std::move(buffer_tree)); } return execution_inputs; } diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 4244cdaceea..977f6ee8ea6 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -56,6 +56,13 @@ HloSharding HloSharding::PartialTile( HloSharding HloSharding::PartialTile( const Array& tile_assignment_last_dim_replicate) { + if (tile_assignment_last_dim_replicate.dimensions().back() == 1) { + auto new_tile_dims = tile_assignment_last_dim_replicate.dimensions(); + new_tile_dims.pop_back(); + auto fully_tiled = tile_assignment_last_dim_replicate; + fully_tiled.Reshape(new_tile_dims); + return HloSharding(fully_tiled); + } std::vector> sorted_groups( tile_assignment_last_dim_replicate.num_elements() / tile_assignment_last_dim_replicate.dimensions().back()); diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 4b58a0cb04a..c134b7ba6a6 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -121,9 +121,9 @@ cc_library( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform:macros", "//tensorflow/core/platform:mutex", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/platform:types", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", diff --git a/tensorflow/compiler/xla/service/interpreter/executable_base.cc b/tensorflow/compiler/xla/service/interpreter/executable_base.cc index 745750bffe1..00998994c0a 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable_base.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable_base.cc @@ -56,7 +56,7 @@ StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( } for (auto& argument : arguments) { const ShapeTree& buffers = argument.Buffers(); - argument_buffers.push_back(ShapedBuffer(buffers.shape(), buffers.shape(), + 
argument_buffers.push_back(ShapedBuffer(buffers.shape(), /*platform=*/nullptr, /*device_ordinal=*/device_ordinal)); auto in_it = buffers.begin(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 59f4466980f..9940b032558 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -158,10 +158,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:elemental_ir_emitter", + "//tensorflow/compiler/xla/service:fusion_node_indexing_evaluation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 6aa33a10d64..f8514a6cba3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -17,13 +17,13 @@ limitations under the License. #include #include +#include -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -44,37 +44,32 @@ using llvm_ir::IrArray; Status FusedIrEmitter::DefaultAction(const HloInstruction* hlo) { indexed_generators_[hlo] = [=](const IrArray::Index& index) -> StatusOr { - if (llvm::Value* generated_value = FindOrDefault( - generated_value_cache_[hlo], index.multidim(), nullptr)) { - llvm::BasicBlock* generated_value_bb = nullptr; - if (auto* generated_instruction = - llvm::dyn_cast(generated_value)) { - generated_value_bb = generated_instruction->getParent(); - } - // Ideally, we should be able to reuse the cached generated value if it - // dominates the current insertion block. However, the check for dominance - // can be expensive and unreliable when the function is being constructed. - // - // It's also worth experimenting what if we don't do caching at all. - // LLVM's CSE or GVN should be able to easily merge common subexpressions - // that would be regenerated without caching. But this might increase the - // JIT compilation time. 
- if (generated_value_bb == nullptr || - generated_value_bb == b_->GetInsertBlock()) { + auto cache = generated_value_cache_.find(hlo); + if (cache != generated_value_cache_.end()) { + auto key = std::make_pair(b_->GetInsertBlock(), index.multidim()); + if (llvm::Value* generated_value = + FindOrDefault(cache->second, key, nullptr)) { + VLOG(3) << "The cached generated value is reused."; + return generated_value; + } + auto null_key = std::make_pair(nullptr, index.multidim()); + if (llvm::Value* generated_value = + FindOrDefault(cache->second, null_key, nullptr)) { VLOG(3) << "The cached generated value is reused."; return generated_value; } - VLOG(3) << "The cached generated value can't be reused, because it is in " - "a different BB (" - << generated_value_bb->getName().str() - << ") from the current insertion block (" - << b_->GetInsertBlock()->getName().str() << ")."; } TF_ASSIGN_OR_RETURN(llvm::Value* const generated_value, elemental_emitter_->MakeElementGenerator( hlo, indexed_generators_)(index)); - generated_value_cache_[hlo][index.multidim()] = generated_value; + llvm::BasicBlock* generated_value_bb = nullptr; + if (auto* generated_instruction = + llvm::dyn_cast(generated_value)) { + generated_value_bb = generated_instruction->getParent(); + } + generated_value_cache_[hlo][std::make_pair( + generated_value_bb, index.multidim())] = generated_value; return generated_value; }; return Status::OK(); @@ -214,84 +209,15 @@ bool FusedIrEmitter::IsFusedIrEmitterInefficient( if (consumer->opcode() != HloOpcode::kFusion) { return false; } - // Collects for each instruction in the fusion node from which (indirect) - // users newly created index values are passed. Roughly speaking, we reuse - // index values if the shapes are equal when ignoring the element type (we may - // reuse also if the shape change is a bitcast, but we don't consider that - // here). By ignoring potential reuses our estimate whether the fusion emitter - // is inefficient is a bit more conservative than necessary. - absl::flat_hash_map> - indexing_users; - // Stores the number of different index accesses for each instruction in the - // fusion node. The fusion emitter caches access with the same index, so this - // value indicates how many times a specific instruction will be emitted. - absl::flat_hash_map index_usage_count; - index_usage_count[consumer] = 1; - - auto evaluate_fusion_computation = [&indexing_users, &index_usage_count]( - const HloInstruction* fusion) { - auto postorder = - fusion->fused_instructions_computation()->MakeInstructionPostOrder(); - std::reverse(postorder.begin(), postorder.end()); - for (const auto* instruction : postorder) { - if (instruction->opcode() == HloOpcode::kParameter) { - continue; - } - int64& total = index_usage_count[instruction]; - if (indexing_users[instruction].empty()) { - total = index_usage_count[fusion]; - } else { - total = 0; - for (const auto* user : indexing_users[instruction]) { - total += index_usage_count[user]; - } - } - for (const auto* operand : instruction->operands()) { - // For simplicity we assume that all shape and layout changing - // operations except Transposes invalidate index reuse. Transposes are - // special: although they are shape changing, we can reuse the - // multi-dimensional index for the operand by permuting it. 
- if (instruction->opcode() == HloOpcode::kTranspose || - Shape::Equal().IgnoreElementType()(operand->shape(), - instruction->shape())) { - // If the index is reused, it means the operand gets index values - // from the same set of (indirect) users as 'instruction' itself. - indexing_users[operand].insert(indexing_users[instruction].begin(), - indexing_users[instruction].end()); - } else { - // If the index is not reused, it means 'instruction' computes a - // new index derived from the index it gets. - indexing_users[operand].insert(instruction); - } - } - } - }; - evaluate_fusion_computation(consumer); - - // Also account for the 'producer' if it would be fused. Find the operand it - // corresponds to. - for (int64 operand_num = 0; operand_num < consumer->operand_count(); - ++operand_num) { - if (consumer->operand(operand_num) == producer) { - auto instruction = consumer->fused_parameter(operand_num); - int64& total = index_usage_count[producer]; - total = 0; - for (const auto* user : indexing_users[instruction]) { - total += index_usage_count[user]; - } - break; - } + FusionNodeIndexingEvaluation eval_consumer(consumer); + if (producer->opcode() != HloOpcode::kFusion) { + return eval_consumer.CodeDuplicationTooHigh(producer); } - - // If 'producer' is a fusion node as well, also evaluate it. - if (producer->opcode() == HloOpcode::kFusion) { - evaluate_fusion_computation(producer); - } - - // Check that the code duplication has at most a factor of 15 (where 15 is an - // arbitrary constant that seems to work). - return index_usage_count[producer] > 15; + // If 'producer' is a fusion node as well, also evaluate it. Pass the + // evaluated duplication of the fusion node if it is merged into consumer. + FusionNodeIndexingEvaluation eval_producer( + producer, eval_consumer.EvaluateEmittedInstructions(producer)); + return eval_producer.MaxCodeDuplicationTooHigh(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index d13b0262180..e19e970cb24 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -18,6 +18,7 @@ limitations under the License. 
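The duplication estimate now lives in `FusionNodeIndexingEvaluation` (not shown in this diff). For context, here is a toy, self-contained reimplementation of the index-usage counting idea that the removed code used, with the same arbitrary threshold of 15; `Node` and `CodeDuplicationTooHigh` are hypothetical stand-ins, not the real HLO classes:

```cpp
// Toy sketch: walk the fusion in reverse post-order, count how many distinct
// index values reach each node, and flag the fusion as too expensive if any
// count exceeds 15.
#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<Node*> operands;
  bool reuses_operand_index = true;  // e.g. elementwise or transpose-like ops
};

bool CodeDuplicationTooHigh(const std::vector<Node*>& reverse_post_order,
                            Node* root) {
  std::map<Node*, std::set<Node*>> indexing_users;
  std::map<Node*, int64_t> index_usage_count;
  index_usage_count[root] = 1;
  for (Node* node : reverse_post_order) {
    int64_t& total = index_usage_count[node];
    if (indexing_users[node].empty()) {
      total = index_usage_count[root];
    } else {
      total = 0;
      for (Node* user : indexing_users[node]) total += index_usage_count[user];
    }
    for (Node* operand : node->operands) {
      if (node->reuses_operand_index) {
        // The operand sees the same index values as `node` itself.
        indexing_users[operand].insert(indexing_users[node].begin(),
                                       indexing_users[node].end());
      } else {
        // `node` derives a new index, so it becomes the indexing user.
        indexing_users[operand].insert(node);
      }
    }
    if (total > 15) return true;  // same arbitrary threshold as above
  }
  return false;
}
```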
#include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" @@ -153,9 +154,10 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault { // Cache of generated values, lest we regenerate an element of a node with // multiple outgoing edges - absl::flat_hash_map< - const HloInstruction*, - absl::flat_hash_map, llvm::Value*>> + absl::flat_hash_map>, + llvm::Value*>> generated_value_cache_; }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index b6b3b2dd8b3..9d7f06f4f68 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -130,8 +130,14 @@ IrArray::Index LoopEmitter::EmitDynamicIndex(ForLoopNest* loop_nest, } std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( - absl::string_view loop_name, llvm::Type* index_type) { + absl::string_view loop_name, llvm::Type* index_type, + llvm::Value* base_index) { CHECK_NE(index_type, nullptr); + CHECK_EQ(base_index, nullptr) + << "XLA CPU implementation of" + << " LoopEmitter::EmitIndexAndSetExitBasicBlock doesn't support" + << " base_index, but it was requested."; + if (ShapeUtil::IsScalar(shape_)) { // No loop needed, so set exit_bb_ to nullptr. exit_bb_ = nullptr; @@ -164,7 +170,8 @@ Status LoopEmitter::EmitLoop(absl::string_view loop_name, } for (const IrArray::Index& array_index : - EmitIndexAndSetExitBasicBlock(loop_name, index_type)) { + EmitIndexAndSetExitBasicBlock(loop_name, index_type, + /*base_index*/ nullptr)) { TF_RETURN_IF_ERROR(body_emitter_(array_index)); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index 008205a642a..a356741f74b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -71,11 +71,13 @@ class LoopEmitter { // specifies the element, will return multiple indices if the loop is // unrolled. std::vector EmitIndexAndSetExitBasicBlock() { - return EmitIndexAndSetExitBasicBlock(/*loop_name=*/"", b_->getInt64Ty()); + return EmitIndexAndSetExitBasicBlock(/*loop_name=*/"", b_->getInt64Ty(), + /*base_index*/ nullptr); } virtual std::vector EmitIndexAndSetExitBasicBlock( - absl::string_view loop_name, llvm::Type* index_type); + absl::string_view loop_name, llvm::Type* index_type, + llvm::Value* base_index); // Emits a complete loop nest for every element in the given shape. Status EmitLoop(absl::string_view loop_name = "", diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 6d4b0e65010..8fb01fa45f2 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -29,6 +29,105 @@ const HeapSimulator::Chunk kDummyChunk{-1, -1}; // pow(kWhileExecutionCount, nesting_level) times. 
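The cost analysis assumes each while body executes `kWhileExecutionCount` (5) times, so an instruction nested `n` while loops deep is weighted by `5^n` in the cumulative elapsed-time sum, and an interval's elapsed time is later normalized by the interval's minimum while-nest level. A small sketch of that arithmetic under those assumptions; the helper names and data layout are illustrative, not XLA's:

```cpp
// Minimal sketch of the nest-level weighting described above.
#include <cmath>
#include <cstdint>
#include <vector>

constexpr int64_t kWhileExecutionCount = 5;

// elapsed[i]   : unweighted elapsed time of the instruction at logical time i
// nest_level[i]: while-loop nest level of the instruction at logical time i
std::vector<double> WeightedCumsum(const std::vector<double>& elapsed,
                                   const std::vector<int>& nest_level) {
  std::vector<double> cumsum(elapsed.size() + 1, 0.0);
  for (size_t i = 0; i < elapsed.size(); ++i) {
    cumsum[i + 1] =
        cumsum[i] + elapsed[i] * std::pow(kWhileExecutionCount, nest_level[i]);
  }
  return cumsum;
}

// cumsum[i] holds the weighted sum of the first i instructions, so the
// weighted elapsed time of [start, end) is divided back by 5^min_nest_level.
double IntervalElapsed(const std::vector<double>& cumsum, int start, int end,
                       int min_nest_level_in_interval) {
  return (cumsum[end] - cumsum[start]) /
         std::pow(kWhileExecutionCount, min_nest_level_in_interval);
}
```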
const int kWhileExecutionCount = 5; +bool LooksLikeAnActivation(const HloInstruction* inst) { + for (HloInstruction* user : inst->users()) { + switch (user->opcode()) { + case HloOpcode::kConvolution: + case HloOpcode::kDot: + if (user->operand(0) == inst) { + return true; + } + break; + case HloOpcode::kGather: + if (user->operand(1) == inst) { + return true; + } + break; + case HloOpcode::kFusion: + for (int i = 0; i < user->operand_count(); ++i) { + if (user->operand(i) == inst && + LooksLikeAnActivation(user->fused_parameter(i))) { + return true; + } + } + break; + case HloOpcode::kBitcast: + return LooksLikeAnActivation(user); + default: + return true; + } + } + return false; +} + +bool IsCrossProgramPrefetchCandidate( + const HloValue& value, const MemorySpaceAssignment::Options& options) { + return value.instruction()->parent() == + value.instruction()->GetModule()->entry_computation() && + value.instruction()->opcode() == HloOpcode::kParameter && + (!value.shape().has_layout() || + value.shape().layout().memory_space() != + options.alternate_memory_space) && + value.index().size() == 1 && value.shape().IsArray() && + !value.uses().empty() && + options.size_fn(value) <= options.max_size_in_bytes && + absl::c_all_of(value.uses(), [&](const HloUse& use) { + const HloInstruction* inst = + use.instruction->operand(use.operand_number); + + // Skip the LooksLikeAnActivation test since we're testing the + // parent GTE and its children below. + if (inst->opcode() == HloOpcode::kBitcast && + inst->operand(0)->opcode() == HloOpcode::kGetTupleElement && + inst->operand(0)->operand(0)->opcode() == + HloOpcode::kParameter) { + return true; + } + + return inst->opcode() == HloOpcode::kGetTupleElement && + !LooksLikeAnActivation(inst); + }); +} + +absl::optional +FindCrossProgramPrefetchCandidate( + const HloAliasAnalysis& alias_analysis, const HloLiveRange& hlo_live_range, + const MemorySpaceAssignment::Options& options) { + std::vector candidates; + for (const HloBuffer& buffer : alias_analysis.buffers()) { + CHECK_GE(buffer.values().size(), 1); + const HloValue* value = buffer.values().at(0); + if (IsCrossProgramPrefetchCandidate(*value, options)) { + MemorySpaceAssignment::BufferInterval interval; + interval.buffer = value; + interval.size = options.size_fn(*value); + interval.start = 0; + interval.end = hlo_live_range.schedule_end_time(); + interval.need_allocation = true; + interval.colocations = {++buffer.values().begin(), buffer.values().end()}; + candidates.emplace_back(interval); + } + } + + // The buffer_interval_compare ought to do a good job picking the most + // appropriate buffer to cross program prefetch, but empirically, it makes + // worse choices than just picking the largest buffer. + // TODO(b/152421603): Investigate. + auto size_compare = [](const auto& x, const auto& y) { + return x.size < y.size; + }; + auto& compare = options.default_cross_program_prefetch_heuristic && + options.buffer_interval_compare + ? 
*options.buffer_interval_compare + : size_compare; + + auto best_candidate = absl::c_max_element(candidates, compare); + if (best_candidate == candidates.end()) { + return absl::nullopt; + } + return *best_candidate; +} + } // namespace /*static*/ StatusOr> @@ -64,12 +163,16 @@ float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit( while_nest_multiplier = it->second; } else { while_nest_multiplier = tensorflow::MathUtil::IPow( - kWhileExecutionCount, CalculateWhileLoopNestLevel(&instruction)); + kWhileExecutionCount, + CalculateComputationNestLevel(&instruction, + /*while_only=*/true)); cache->while_nest_multiplier[&instruction] = while_nest_multiplier; } } else { while_nest_multiplier = tensorflow::MathUtil::IPow( - kWhileExecutionCount, CalculateWhileLoopNestLevel(&instruction)); + kWhileExecutionCount, + CalculateComputationNestLevel(&instruction, + /*while_only=*/true)); } return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) * while_nest_multiplier; @@ -125,8 +228,8 @@ float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness( return alternate_mem_benefit / std::sqrt(interval.size); } -int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel( - const HloInstruction* instruction) const { +int MemorySpaceAssignmentCostAnalysis::CalculateComputationNestLevel( + const HloInstruction* instruction, bool while_only) const { int nest_level = 0; const HloComputation* computation = instruction->parent(); while (!computation->IsEntryComputation()) { @@ -134,7 +237,7 @@ int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel( auto callsites = node.caller_callsites(); CHECK_EQ(callsites.size(), 1) << "The module is not flattened!"; auto callsite = callsites[0]; - if (callsite.instruction()->opcode() == HloOpcode::kWhile) { + if (!while_only || callsite.instruction()->opcode() == HloOpcode::kWhile) { ++nest_level; } computation = callsite.instruction()->parent(); @@ -280,6 +383,8 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( float preferred_async_copy_to_overlap_ratio) : while_nest_level_( cost_analysis.hlo_live_range().instruction_schedule().size(), 0), + computation_nest_level_( + cost_analysis.hlo_live_range().instruction_schedule().size(), 0), cost_analysis_(cost_analysis), min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio), max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio), @@ -303,9 +408,12 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( instructions_elapsed_time.resize(logical_time + 1, 0.0); while_nest_level_.resize(logical_time + 1, 0); } - int nest_level = cost_analysis_.CalculateWhileLoopNestLevel( - instruction_and_logical_time.first); - while_nest_level_[logical_time] = nest_level; + int while_nest_level = cost_analysis_.CalculateComputationNestLevel( + instruction_and_logical_time.first, /*while_only=*/true); + while_nest_level_[logical_time] = while_nest_level; + int computation_nest_level = cost_analysis_.CalculateComputationNestLevel( + instruction_and_logical_time.first, /*while_only=*/false); + computation_nest_level_[logical_time] = computation_nest_level; if (instruction->opcode() == HloOpcode::kWhile || instruction->opcode() == HloOpcode::kConditional) { continue; @@ -313,8 +421,8 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( float elapsed_time = cost_analysis_.GetInstructionElapsed( *instruction_and_logical_time.first); instructions_elapsed_time[logical_time] = - elapsed_time * - 
tensorflow::MathUtil::IPow(kWhileExecutionCount, nest_level); + elapsed_time * tensorflow::MathUtil::IPow(kWhileExecutionCount, + while_nest_level); } // As an optimization, create a cumulative sum vector of elapsed time. float cumsum = 0.0; @@ -384,14 +492,14 @@ int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime( /*output_in_alternate_mem=*/false); inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem; } - int end_nest_level = while_nest_level_[end_time]; + int end_nest_level = computation_nest_level_[end_time]; // Find the latest time we're allowed to start prefetching. float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed; int latest_prefetch_time; for (latest_prefetch_time = end_time - 1; latest_prefetch_time >= start_time && - (while_nest_level_[latest_prefetch_time] != end_nest_level || + (computation_nest_level_[latest_prefetch_time] != end_nest_level || min_interval > GetLogicalIntervalElapsed(latest_prefetch_time, end_time) + inst_elapsed_reduction); @@ -412,13 +520,13 @@ int64 CostAnalysisPrefetchIntervalPicker::PreferredPrefetchStartTime( preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed; float best_interval = GetLogicalIntervalElapsed(earliest_prefetch_start_time, prefetch_end_time); - int end_nest_level = while_nest_level_[prefetch_end_time]; + int end_nest_level = computation_nest_level_[prefetch_end_time]; for (int64 prefetch_start_time = earliest_prefetch_start_time + 1; prefetch_start_time <= latest_prefetch_start_time; ++prefetch_start_time) { float interval = GetLogicalIntervalElapsed(prefetch_start_time, prefetch_end_time); - if (while_nest_level_[prefetch_start_time] == end_nest_level && + if (computation_nest_level_[prefetch_start_time] == end_nest_level && std::abs(preferred_interval - interval) < std::abs(preferred_interval - best_interval)) { best_interval = interval; @@ -432,10 +540,11 @@ int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime( int64 original_prefetch_end_time, int64 proposed_prefetch_end_time) const { // Iterate towards the beginning until we find a suitable end time that is the // same while nest level as the original prefetch end time. - int64 original_nest_level = while_nest_level_[original_prefetch_end_time]; + int64 original_nest_level = + computation_nest_level_[original_prefetch_end_time]; int64 new_prefetch_end_time; for (new_prefetch_end_time = proposed_prefetch_end_time; - while_nest_level_[new_prefetch_end_time] != original_nest_level; + computation_nest_level_[new_prefetch_end_time] != original_nest_level; --new_prefetch_end_time) { } return new_prefetch_end_time; @@ -456,7 +565,7 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use, /*output_in_alternate_mem=*/false); inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem; end_logical_time_ = end_time; - int end_nest_level = while_nest_level_[end_logical_time_]; + int end_nest_level = computation_nest_level_[end_logical_time_]; // Find the latest time we're allowed to start prefetching. 
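A hedged sketch of the backward scan used to pick the latest allowed prefetch start: walk from `end_time - 1` toward `start_time`, accepting only logical times at the same computation nest level as the end time and with enough overlap remaining. It omits the `inst_elapsed_reduction` adjustment from the real code; all names are illustrative:

```cpp
// Sketch of the backward scan described above. `IntervalElapsed` is assumed to
// return the (normalized) elapsed time between two logical times.
#include <functional>
#include <vector>

int LatestPrefetchStartTime(
    int start_time, int end_time, double min_interval,
    const std::vector<int>& computation_nest_level,
    const std::function<double(int, int)>& IntervalElapsed) {
  const int end_nest_level = computation_nest_level[end_time];
  int latest = end_time - 1;
  while (latest >= start_time &&
         (computation_nest_level[latest] != end_nest_level ||
          min_interval > IntervalElapsed(latest, end_time))) {
    --latest;
  }
  return latest;  // may end up below start_time if no feasible start exists
}
```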
float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed_; @@ -468,7 +577,7 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use, max_overlap_multiplier_ * async_copy_elapsed_; for (earliest_prefetch_time_ = start_time; earliest_prefetch_time_ <= end_logical_time_ && - (while_nest_level_[earliest_prefetch_time_] != end_nest_level || + (computation_nest_level_[earliest_prefetch_time_] != end_nest_level || max_interval < GetLogicalIntervalElapsed(earliest_prefetch_time_, end_logical_time_)); ++earliest_prefetch_time_) { @@ -506,8 +615,8 @@ int64 CostAnalysisPrefetchIntervalPicker::Next() { if (using_increasing_prefetch_time_iterator_) { int64 prefetch_time = increasing_prefetch_time_iterator_++; while (increasing_prefetch_time_iterator_ <= latest_prefetch_time_ && - while_nest_level_[increasing_prefetch_time_iterator_] != - while_nest_level_[end_logical_time_]) { + computation_nest_level_[increasing_prefetch_time_iterator_] != + computation_nest_level_[end_logical_time_]) { ++increasing_prefetch_time_iterator_; } if (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_) { @@ -517,8 +626,8 @@ int64 CostAnalysisPrefetchIntervalPicker::Next() { } else { int64 prefetch_time = decreasing_prefetch_time_iterator_--; while (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_ && - while_nest_level_[decreasing_prefetch_time_iterator_] != - while_nest_level_[end_logical_time_]) { + computation_nest_level_[decreasing_prefetch_time_iterator_] != + computation_nest_level_[end_logical_time_]) { --decreasing_prefetch_time_iterator_; } if (increasing_prefetch_time_iterator_ <= latest_prefetch_time_) { @@ -562,11 +671,11 @@ float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed( // Since elapsed_time_cumsum_ is already weighed by the while loop nesting // level, normalize the elapsed time by dividing with the nesting factor of // the interval (start and end times). - int interval_nest_level = GetMinWhileNestLevel(start_time, end_time); + int interval_while_nest_level = GetMinWhileNestLevel(start_time, end_time); return (elapsed_time_cumsum_[end_time - 1] - elapsed_time_cumsum_[start_time]) / tensorflow::MathUtil::IPow(kWhileExecutionCount, - interval_nest_level); + interval_while_nest_level); } std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const { @@ -709,12 +818,13 @@ void AlternateMemoryBestFitHeap::CreateAllocationValues( } void AlternateMemoryBestFitHeap::FindAliases( - std::vector* allocation_values) const { + std::vector* allocation_values, + bool skip_values_with_no_uses) const { absl::flat_hash_map values_by_defining_inst; for (AllocationValue& value : *allocation_values) { // Skip the value if it doesn't have any uses. 
- if (value.uses().empty()) { + if (value.uses().empty() && skip_values_with_no_uses) { continue; } CHECK_EQ(values_by_defining_inst.count(value.defining_instruction()), 0); @@ -981,6 +1091,17 @@ void AlternateMemoryBestFitHeap::DumpDebugStringsIfEnabled() const { } HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { + if (options_.enable_cross_program_prefetch) { + absl::optional + prefetch_candidate = FindCrossProgramPrefetchCandidate( + alias_analysis_, hlo_live_range_, options_); + if (prefetch_candidate) { + HloModule* module = + prefetch_candidate->buffer->instruction()->GetModule(); + AllocateCrossProgramPrefetchBuffer(module, prefetch_candidate); + } + } + std::vector sorted_buffer_intervals = GetSortedBufferIntervals(); @@ -1157,7 +1278,7 @@ void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals( for (const auto& colocated_interval : colocated_intervals) { CreateAllocationValues(*colocated_interval, allocation_values); } - FindAliases(&allocation_values); + FindAliases(&allocation_values, /*skip_values_with_no_uses=*/true); } AlternateMemoryBestFitHeap::Result @@ -2507,107 +2628,6 @@ MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare( }; } -namespace { - -bool LooksLikeAnActivation(const HloInstruction* inst) { - for (HloInstruction* user : inst->users()) { - switch (user->opcode()) { - case HloOpcode::kConvolution: - case HloOpcode::kDot: - if (user->operand(0) == inst) { - return true; - } - break; - case HloOpcode::kGather: - if (user->operand(1) == inst) { - return true; - } - break; - case HloOpcode::kFusion: - for (int i = 0; i < user->operand_count(); ++i) { - if (user->operand(i) == inst && - LooksLikeAnActivation(user->fused_parameter(i))) { - return true; - } - } - break; - case HloOpcode::kBitcast: - return LooksLikeAnActivation(user); - default: - return true; - } - } - return false; -} - -bool IsCrossProgramPrefetchCandidate( - const HloValue& value, const MemorySpaceAssignment::Options& options) { - return value.instruction()->parent() == - value.instruction()->GetModule()->entry_computation() && - value.instruction()->opcode() == HloOpcode::kParameter && - (!value.shape().has_layout() || - value.shape().layout().memory_space() != - options.alternate_memory_space) && - value.index().size() == 1 && value.shape().IsArray() && - !value.uses().empty() && - options.size_fn(value) <= options.max_size_in_bytes && - absl::c_all_of(value.uses(), [&](const HloUse& use) { - const HloInstruction* inst = - use.instruction->operand(use.operand_number); - - // Skip the LooksLikeAnActivation test since we're testing the - // parent GTE and its children below. 
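When no custom comparator applies, the candidate selection above simply takes the largest buffer by size, as the comment in the code explains. A standalone sketch of that max-by-size selection with an empty-list guard; `Interval` and `PickLargest` are stand-ins for the real `BufferInterval` logic:

```cpp
// Standalone sketch of the size-based candidate selection above: among the
// candidate buffer intervals, pick the one with the largest size, returning
// "no candidate" when the list is empty.
#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>

struct Interval {
  int64_t size = 0;
  int64_t start = 0;
  int64_t end = 0;
};

std::optional<Interval> PickLargest(const std::vector<Interval>& candidates) {
  auto best = std::max_element(
      candidates.begin(), candidates.end(),
      [](const Interval& x, const Interval& y) { return x.size < y.size; });
  if (best == candidates.end()) {
    return std::nullopt;
  }
  return *best;
}
```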
- if (inst->opcode() == HloOpcode::kBitcast && - inst->operand(0)->opcode() == HloOpcode::kGetTupleElement && - inst->operand(0)->operand(0)->opcode() == - HloOpcode::kParameter) { - return true; - } - - return inst->opcode() == HloOpcode::kGetTupleElement && - !LooksLikeAnActivation(inst); - }); -} - -absl::optional -FindCrossProgramPrefetchCandidate( - const HloAliasAnalysis& alias_analysis, const HloLiveRange& hlo_live_range, - const MemorySpaceAssignment::Options& options) { - std::vector candidates; - for (const HloBuffer& buffer : alias_analysis.buffers()) { - CHECK_GE(buffer.values().size(), 1); - const HloValue* value = buffer.values().at(0); - if (IsCrossProgramPrefetchCandidate(*value, options)) { - MemorySpaceAssignment::BufferInterval interval; - interval.buffer = value; - interval.size = options.size_fn(*value); - interval.start = 0; - interval.end = hlo_live_range.schedule_end_time(); - interval.need_allocation = true; - interval.colocations = {++buffer.values().begin(), buffer.values().end()}; - candidates.emplace_back(interval); - } - } - - // The buffer_interval_compare ought to do a good job picking the most - // appropriate buffer to cross program prefetch, but empirically, it makes - // worse choices than just picking the largest buffer. - // TODO(b/152421603): Investigate. - auto size_compare = [](const auto& x, const auto& y) { - return x.size < y.size; - }; - auto& compare = options.default_cross_program_prefetch_heuristic && - options.buffer_interval_compare - ? *options.buffer_interval_compare - : size_compare; - - auto best_candidate = absl::c_max_element(candidates, compare); - if (best_candidate == candidates.end()) { - return absl::nullopt; - } - return *best_candidate; -} -} // namespace /*static*/ StatusOr> MemorySpaceAssignment::Run(HloModule* module, @@ -2658,13 +2678,6 @@ Status MemorySpaceAssignment::FindAllocationSequence( auto algorithm = absl::make_unique( &allocations_, options_, alias_analysis, hlo_live_range); - if (options_.enable_cross_program_prefetch) { - absl::optional - prefetch_candiate = FindCrossProgramPrefetchCandidate( - alias_analysis, hlo_live_range, options_); - algorithm->AllocateCrossProgramPrefetchBuffer(module_, prefetch_candiate); - } - HeapSimulator::Options heap_simulator_options; heap_simulator_options.may_reuse_operand_buffers = false; TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module_, diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index 409a44d319d..b50015f82b3 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -149,9 +149,11 @@ class MemorySpaceAssignmentCostAnalysis { int64 GetScheduleEndTime() const; - // Returns the number of nested while loop levels this instruction resides in. - // 0 means it is not in a while loop. - int CalculateWhileLoopNestLevel(const HloInstruction* instruction) const; + // Returns the number of nested computation levels this instruction resides + // in. If while_only is true, it returns the while loop nest level and 0 + // means the instruction is not in a while loop. + int CalculateComputationNestLevel(const HloInstruction* instruction, + bool while_only) const; const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; } @@ -360,6 +362,7 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { // (in cumulative sum) and while nesting level. 
std::vector elapsed_time_cumsum_; std::vector while_nest_level_; + std::vector computation_nest_level_; // Maintain the index of the most recent (before this instruction) nest level // change in order to efficiently determine the minimum nest level in an // interval. @@ -728,6 +731,16 @@ class MemorySpaceAssignment { // All the positions where this use aliases with. The aliased positions // must get the same allocation. std::vector aliases; + + bool operator==(const Use& other) const { + return hlo_use == other.hlo_use && time == other.time && + aliases == other.aliases; + } + + template + friend H AbslHashValue(H h, const Use& s) { + return H::combine(std::move(h), s.hlo_use, s.time, s.aliases); + } }; AllocationValue(const HloValue* value, const HloPosition& position, @@ -823,6 +836,8 @@ class MemorySpaceAssignment { AllocationSequence allocations_; + HloModule* module() { return module_; } + private: // Process calls Process methods of the allocations after the allocations have // been finalized. @@ -949,6 +964,38 @@ class AlternateMemoryBestFitHeap HeapSimulator::Result Finish() override; + protected: + // Given a buffer interval, returns the colocated intervals. Unlike the + // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it + // returns the colocated intervals sorted by scheduled time. + std::vector GetSortedColocatedIntervals( + const BufferInterval& interval) const; + + // Given a BufferInterval, creates AllocationValue objects and corresponding + // AllocationSequences and appends them into allocation_sequence_list_. + void CreateAllocationValues( + const BufferInterval& buffer_interval, + std::vector& allocation_values) const; + + // Given colocated intervals, populates allocation_values with the + // corresponding AllocationValue objects. + void CreateAllocationValuesFromColocatedIntervals( + absl::Span + colocated_intervals, + std::vector& allocation_values); + + // Go through all the uses in the AllocationValues and find the aliasing + // positions. + void FindAliases(std::vector* allocation_values, + bool skip_values_with_no_uses) const; + + MemorySpaceAssignment::AllocationSequence* allocations() { + return allocations_; + } + const MemorySpaceAssignment::Options& options() { return options_; } + const HloAliasAnalysis& alias_analysis() { return alias_analysis_; } + const HloLiveRange& hlo_live_range() { return hlo_live_range_; } + private: // We inherit AllocationBlock struct to attach the Allocation information to // make importing repacked offsets easier. @@ -1096,18 +1143,6 @@ class AlternateMemoryBestFitHeap bool IsUseAllowedInAlternateMemory(const AllocationValue& value, const HloUse& use) const; - // Given a BufferInterval, creates AllocationValue objects and corresponding - // AllocationSequences and appends them into allocation_sequence_list_. - void CreateAllocationValues( - const BufferInterval& buffer_interval, - std::vector& allocation_values) const; - - // Given colocated intervals, populates allocation_values with the - // corresponding AllocationValue objects. - void CreateAllocationValuesFromColocatedIntervals( - absl::Span colocated_intervals, - std::vector& allocation_values); - // Finds allocations for allocation values generated from colocated intervals. // All of the allocation values have a must-alias relationship with each // other. 
Returns either kSuccess if all of the sites could be placed in the @@ -1115,10 +1150,6 @@ class AlternateMemoryBestFitHeap Result AllocateAllocationValues( absl::Span allocation_values); - // Go through all the uses in the AllocationValues and find the aliasing - // positions. - void FindAliases(std::vector* allocation_values) const; - // Finds an allocation for an allocation request for a segment (see the // documentation for AllocationRequest above how a segment is defined). // @@ -1194,12 +1225,6 @@ class AlternateMemoryBestFitHeap bool AreIntervalsReservedInAlternateMemory( absl::Span colocated_intervals) const; - // Given a buffer interval, returns the colocated intervals. Unlike the - // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it - // returns the colocated intervals sorted by scheduled time. - std::vector GetSortedColocatedIntervals( - const BufferInterval& interval) const; - // Since the allocations are recorded to the AllocationSequence, we don't // maintain result_ in GlobalDecreasingSizeBestFitHeap. Override AddToChunkMap // to avoid unnecessarily adding the chunk to the chunk map. diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index 5af61eac5d1..f574bd37937 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -5148,5 +5148,75 @@ TEST_F(CostAnalysisPrefetchIntervalPickerTest, NestedWhile) { 4); } +TEST_F(CostAnalysisPrefetchIntervalPickerTest, ConsecutiveConditionals) { + // This is a test for b/170668492, where prefetching for consecutive + // conditionals can cause the prefetch to start in the conditional's + // computation. 
+ absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + true_computation.0 { + p0 = (f32[3]{0}) parameter(0) // 5 + gte = f32[3]{0} get-tuple-element(p0), index=0 // 6 + ROOT neg1 = f32[3]{0} negate(gte) // 7 + } + + false_computation.0 { + p0 = (f32[3]{0}) parameter(0) // 8 + gte = f32[3]{0} get-tuple-element(p0), index=0 // 9 + ROOT neg2 = f32[3]{0} negate(gte) // 10 + } + + true_computation.1 { + p0 = (f32[3]{0}) parameter(0) // 12 + gte = f32[3]{0} get-tuple-element(p0), index=0 // 13 + ROOT neg1 = f32[3]{0} negate(gte) // 14 + } + + false_computation.1 { + p0 = (f32[3]{0}) parameter(0) // 15 + gte = f32[3]{0} get-tuple-element(p0), index=0 // 16 + ROOT neg2 = f32[3]{0} negate(gte) // 17 + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) // 0 + p1 = f32[3]{0} parameter(1) // 1 + p2 = pred[] parameter(2) // 2 + tuple0 = (f32[3]{0}) tuple(p0) // 3 + tuple1 = (f32[3]{0}) tuple(p1) // 4 + conditional0 = f32[3]{0} conditional(p2, tuple0, tuple0), true_computation=true_computation.0, false_computation=false_computation.0 // 11 + conditional1 = f32[3]{0} conditional(p2, tuple1, tuple1), true_computation=true_computation.1, false_computation=false_computation.1 // 18 + ROOT tuple2 = (f32[3]{0}, f32[3]{0}) tuple(conditional0, conditional1) // 19 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloCostAnalysis hlo_cost_analysis(ShapeSize); + TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, + FakeMemorySpaceAssignmentCostAnalysis::Create( + hlo_cost_analysis, *module)); + CostAnalysisPrefetchIntervalPicker interval_picker( + *cost_analysis, + /*min_async_copy_to_overlap_ratio=*/1.0, + /*max_async_copy_to_overlap_ratio=*/12.0, + /*preferred_async_copy_to_overlap_ratio=*/2.0); + + LOG(INFO) << module->ToString(); + + HloInstruction* conditional1 = + module->entry_computation()->GetInstructionWithName("conditional1"); + const HloUse use{conditional1, /*operand_number=*/1, /*operand_index=*/{0}}; + const Shape& shape = + module->entry_computation()->parameter_instruction(0)->shape(); + + // Expect that the prefetch to start before conditional0's called + // computations. 
+ EXPECT_LT(interval_picker.LatestPrefetchStartTime(shape, /*start_time=*/0, + /*end_time=*/11, &use), + 5); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index 954fb1360fd..4eaed3a12e6 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -76,7 +76,7 @@ cc_library( ":emission_context", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service/gpu:target_constants", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "@llvm-project//llvm:Core", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", @@ -117,7 +117,7 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:stream_executor_util", "//tensorflow/compiler/xla/service/gpu:target_constants", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", - "//tensorflow/core:cuda_libdevice_path", + "//tensorflow/core/platform:cuda_libdevice_path", "//tensorflow/core:lib", "//tensorflow/stream_executor/gpu:asm_compiler", ]), @@ -191,6 +191,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], @@ -210,6 +211,7 @@ cc_library( "//tensorflow/compiler/mlir/hlo:lhlo_fuse_linalg", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine", "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -276,6 +278,7 @@ tf_cc_binary( "//tensorflow/core:lib", "@llvm-project//llvm:Support", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:Support", ], ) @@ -293,6 +296,7 @@ tf_cc_binary( "@llvm-project//mlir:IR", "@llvm-project//mlir:MlirOptLib", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index db03aaeb7d1..a664a316e13 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -35,11 +35,12 @@ limitations under the License. 
#include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project +#include "mlir/Transforms/Bufferize.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/xla/service/mlir_gpu/passes.h" #include "tensorflow/compiler/xla/util.h" @@ -48,7 +49,7 @@ namespace mlir_gpu { Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { mlir::PassManager pm(module.getContext()); - applyPassManagerCLOptions(pm); + tensorflow::applyTensorflowAndCLOptions(pm); // We have to anticipate later unrolling in tiling to make sure that we get // the requested tiling after unrolling. Compute the new tiling here if @@ -124,8 +125,6 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { pm.addNestedPass<::mlir::FuncOp>( ::mlir::mhlo::createLegalizeTrigonometricToApproximationPass()); } - // Move scalar operations into the launch to ensure smaller signatures. - pm.addPass(createMoveScalarComputationsIntoGpuLaunchPass()); // Take launches to launches with kernels. pm.addPass(::mlir::createGpuKernelOutliningPass()); // Make sure the kernel signature resembled the original function's @@ -181,7 +180,7 @@ class LowerToNVVMPass Status LowerKernelBodiesToNVVM(mlir::ModuleOp module) { // We cannot verify as the signature of the kernel is rewritten. ::mlir::PassManager pm(module.getContext(), /*verifyPasses=*/false); - applyPassManagerCLOptions(pm); + tensorflow::applyTensorflowAndCLOptions(pm); // Rewrite kernel functions to LLVM IR. auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>(); @@ -250,7 +249,7 @@ class LowerToROCDLPass Status LowerKernelBodiesToROCDL(mlir::ModuleOp module) { // We cannot verify as the signature of the kernel is rewritten. ::mlir::PassManager pm(module.getContext(), /*verifyPasses=*/false); - applyPassManagerCLOptions(pm); + tensorflow::applyTensorflowAndCLOptions(pm); auto enable_if_vlog_is_on = [](mlir::Pass*, mlir::Operation*) { return VLOG_IS_ON(1); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc index 2e94c1a54f2..52af857efe5 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc @@ -438,7 +438,6 @@ StatusOr> TransformKernelToXlaThunk( // Finally, create the thunk and set the launch dimensions. gpu::Thunk::ThunkInfo info; - info.hlo_instruction = instr; auto thunk = absl::make_unique(info, buffers, kernel.getName().str()); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/passes.cc b/tensorflow/compiler/xla/service/mlir_gpu/passes.cc index 7f4810e3f54..84751bc0507 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/passes.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/passes.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Transforms/LoopUtils.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" @@ -54,6 +55,21 @@ struct FusionOpRemoverPass : FusionOpRemoverPassBase { } }; +template +bool HasEffectsOnValue(mlir::Value value, mlir::Operation* op) { + auto mem_effects_interface = + mlir::dyn_cast_or_null(op); + if (!mem_effects_interface) { + return false; + } + llvm::SmallVector effects; + mem_effects_interface.getEffects(effects); + return llvm::any_of(effects, + [op](const mlir::MemoryEffects::EffectInstance& effect) { + return mlir::isa(effect.getEffect()); + }); +} + struct StoreForwardingPass : StoreForwardingPassBase { mlir::StoreOp findStore(mlir::Operation* op, std::function matches) { @@ -87,10 +103,9 @@ struct StoreForwardingPass : StoreForwardingPassBase { while (auto subviewOp = mlir::dyn_cast_or_null(defOp)) { defOp = subviewOp.source().getDefiningOp(); } - if (auto allocOp = mlir::dyn_cast_or_null(defOp)) { - return allocOp.getOperation(); - } - return nullptr; + return HasEffectsOnValue(memref, defOp) + ? defOp + : nullptr; } // Retrieves AllocOp from the cache or actually looks for it. @@ -101,7 +116,7 @@ struct StoreForwardingPass : StoreForwardingPassBase { if (allocOpIt != memrefToAllocOp->end()) { return allocOpIt->second; } - auto allocOp = SearchAllocOp(memref); + mlir::Operation* allocOp = SearchAllocOp(memref); memrefToAllocOp->insert({memref, allocOp}); return allocOp; } @@ -169,13 +184,18 @@ struct DeadTempBufferRemovalPass void runOnFunction() override { llvm::SmallVector dead_ops; - getFunction().walk([&](mlir::AllocOp allocOp) { - if (!operationConsideredDead(allocOp)) { + getFunction().walk([&](mlir::Operation* op) { + if (op->getNumResults() != 1 || + !HasEffectsOnValue(op->getResult(0), + op)) { + return; + } + if (!operationConsideredDead(op)) { return; } // TODO(herhut): There should be a generic helper for this. - recursiveErase(allocOp, &dead_ops); + recursiveErase(op, &dead_ops); }); for (auto op : dead_ops) { op->erase(); @@ -183,63 +203,6 @@ struct DeadTempBufferRemovalPass } }; -struct MoveScalarComputationsIntoGpuLaunchPass - : MoveScalarComputationsIntoGpuLaunchPassBase< - MoveScalarComputationsIntoGpuLaunchPass> { - static bool isInliningBeneficiary(mlir::Operation* op) { - return llvm::isa(op); - } - - static bool extractBeneficiaryOps( - mlir::Operation* op, llvm::SmallVectorImpl* ops, - llvm::SetVector args) { - if (!isInliningBeneficiary(op)) { - return false; - } - - ops->push_back(op); - for (auto operand : op->getOperands()) { - // It is an existing arg, keep going. 
- if (args.count(operand)) { - continue; - } - mlir::Operation* definingOp = operand.getDefiningOp(); - if (!definingOp || !extractBeneficiaryOps(definingOp, ops, args)) { - return false; - } - } - return true; - } - - static void inlineOperationsIntoLaunch(mlir::gpu::LaunchOp launch) { - llvm::SetVector used_above; - mlir::getUsedValuesDefinedAbove(launch.body(), used_above); - mlir::BlockAndValueMapping inlined_map; - for (mlir::Value v : used_above) { - llvm::SmallVector ops_to_move; - mlir::Operation* definingOp = v.getDefiningOp(); - if (definingOp && - extractBeneficiaryOps(definingOp, &ops_to_move, used_above)) { - mlir::OpBuilder b(launch.body()); - for (mlir::Operation* op : llvm::reverse(ops_to_move)) { - auto result = b.clone(*op, inlined_map); - for (auto pair : llvm::zip(op->getResults(), result->getResults())) { - mlir::replaceAllUsesInRegionWith(std::get<0>(pair), - std::get<1>(pair), launch.body()); - } - inlined_map.map(op->getResults(), result->getResults()); - } - } - } - } - - void runOnFunction() override { - getFunction().walk( - [](mlir::gpu::LaunchOp launch) { inlineOperationsIntoLaunch(launch); }); - } -}; - struct RewriteKernelSignaturePass : RewriteKernelSignaturePassBase { void runOnFunction() override { @@ -394,11 +357,6 @@ std::unique_ptr createDeadTempBufferRemovalPass() { return absl::make_unique(); } -std::unique_ptr -createMoveScalarComputationsIntoGpuLaunchPass() { - return absl::make_unique(); -} - std::unique_ptr createRewriteKernelSignaturePass() { return absl::make_unique(); } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/passes.h b/tensorflow/compiler/xla/service/mlir_gpu/passes.h index 19ebe53f7ce..832321387c6 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/passes.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/passes.h @@ -37,10 +37,6 @@ std::unique_ptr createStoreForwardingPass(); /// that loads and stores are side-effect free (in bounds, no aliasing, etc.). std::unique_ptr createDeadTempBufferRemovalPass(); -/// Moves scalar computations to the GPULaunchOp body. -std::unique_ptr -createMoveScalarComputationsIntoGpuLaunchPass(); - /// Sorts the operands to the kernel for a deterministic order. First operands /// that are defined by function arguments, followed by operands that are /// returned from the function. This only works for simple functions without diff --git a/tensorflow/compiler/xla/service/mlir_gpu/passes.td b/tensorflow/compiler/xla/service/mlir_gpu/passes.td index 1b19fcf5274..55fe15ad6ff 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/passes.td +++ b/tensorflow/compiler/xla/service/mlir_gpu/passes.td @@ -46,15 +46,6 @@ def DeadTempBufferRemovalPass }]; } -def MoveScalarComputationsIntoGpuLaunchPass - : FunctionPass<"mlir-gpu-inline-scalar-computation"> { - let summary = "Pass to Move scalar computations to the GPULaunchOp body."; - let constructor = "createMoveScalarComputationsIntoGpuLaunchPass()"; - let description = [{ - Moves scalar computations to the GPULaunchOp body. 
- }]; -} - def RewriteKernelSignaturePass : FunctionPass<"mlir-gpu-rewrite-signatures"> { let summary = "Rewrite kernel signatures to be deterministic."; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/dead_temp_buffer_removal.mlir b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/dead_temp_buffer_removal.mlir index 15329d6181a..58132f4ea45 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/dead_temp_buffer_removal.mlir +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/dead_temp_buffer_removal.mlir @@ -6,6 +6,18 @@ func @dead() { %0 = alloc() : memref<42xi32> %c0 = constant 0 : i32 %c12 = constant 12 : index + // CHECK-NOT: store + store %c0, %0[%c12] : memref<42xi32> + return +} + +// CHECK-LABEL: @dead_alloca +func @dead_alloca() { + // CHECK-NOT: alloca + %0 = alloc() : memref<42xi32> + %c0 = constant 0 : i32 + %c12 = constant 12 : index + // CHECK-NOT: store store %c0, %0[%c12] : memref<42xi32> return } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/fusion_op_remover.mlir b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/fusion_op_remover.mlir new file mode 100644 index 00000000000..69ebbbd5a72 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/fusion_op_remover.mlir @@ -0,0 +1,20 @@ +// RUN: xla-mlir-gpu-opt --mlir-gpu-fusion-op-remover %s | FileCheck %s + +// CHECK-LABEL: func @fusion_memref +func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, + %input3: memref<10xf32>, %out: memref<10xf32>) -> () { + // CHECK-NOT: lmhlo.fusion + "lmhlo.fusion"() ( { + %0 = tensor_load %input1 : memref<10xf32> + %1 = tensor_load %input2 : memref<10xf32> + %2 = "mhlo.add"(%0, %1) {name = "add"} + : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + %3 = tensor_load %input3 : memref<10xf32> + %4 = "mhlo.multiply"(%2, %3) {name = "multiply"} + : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + tensor_store %4, %out : memref<10xf32> + // CHECK-NOT: lmhlo.terminator + "lmhlo.terminator"() : () -> () + } ) : () -> () + return +} diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/rewrite_kernel_signatures.mlir b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/rewrite_kernel_signatures.mlir new file mode 100644 index 00000000000..cff1989f05b --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/rewrite_kernel_signatures.mlir @@ -0,0 +1,138 @@ +// RUN: xla-mlir-gpu-opt --mlir-gpu-rewrite-signatures %s --split-input-file --verify-diagnostics | FileCheck %s + +module attributes {gpu.container_module} { + +// CHECK-LABEL: @kernel_module +gpu.module @kernel_module { + // CHECK-LABEL: gpu.func @kernel + // CHECK-SAME: %{{.*}}: memref<32xf32>, %{{.*}}: memref<16xf32>, + // CHECK-SAME: %{{.*}}: memref<8xf32> + gpu.func @kernel(%arg0: memref<8xf32>, %arg1: memref<16xf32>, + %arg2: memref<32xf32>) kernel { + gpu.return + } +} + + // CHECK-LABEL: @caller +func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> { + %cst = constant 8 : index + %res = alloc() : memref<8xf32> + + // CHECK: gpu.launch_func + // CHECK-SAME: index, memref<32xf32>, memref<16xf32>, memref<8xf32>) + "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %res, %arg1, %arg0) + { kernel = @kernel_module::@kernel } + : (index, index, index, index, index, index, + memref<8xf32>, memref<16xf32>, memref<32xf32>) -> () + + return %res : memref<8xf32> +} + +} + +// ----- + +module attributes {gpu.container_module} { + +gpu.module @kernel_module { + // expected-error @+1 
{{number of kernel arguments does not match numberof arguments and results of surrounding function}} + gpu.func @kernel(%arg0: memref<16xf32>, %arg1: memref<32xf32>) kernel { + gpu.return + } +} + +func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> { + %cst = constant 8 : index + %res = alloc() : memref<8xf32> + + "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %arg1, %arg0) + { kernel = @kernel_module::@kernel } + : (index, index, index, index, index, index, + memref<16xf32>, memref<32xf32>) -> () + + return %res : memref<8xf32> +} + +} + +// ----- + +module attributes {gpu.container_module} { + +gpu.module @kernel_module { + // expected-error @+1 {{result 0 of containing function is not an argument to the kernel}} + gpu.func @kernel(%arg0: memref<16xf32>, %arg1: memref<32xf32>, + %arg2: memref<8xf32>) kernel { + gpu.return + } +} + +func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> { + %cst = constant 8 : index + %res = alloc() : memref<8xf32> + %fake = alloc() : memref<8xf32> + + "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %arg1, %arg0, %fake) + { kernel = @kernel_module::@kernel } + : (index, index, index, index, index, index, + memref<16xf32>, memref<32xf32>, memref<8xf32>) -> () + + return %res : memref<8xf32> +} + +} + +// ----- + +module attributes {gpu.container_module} { + +gpu.module @kernel_module { + // expected-error @+1 {{argument 1 to containing function is not an argument to the kernel}} + gpu.func @kernel(%arg0: memref<16xf32>, %arg1: memref<32xf32>, + %arg2: memref<8xf32>) kernel { + gpu.return + } +} + +func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> { + %cst = constant 8 : index + %res = alloc() : memref<8xf32> + %fake = alloc() : memref<16xf32> + + "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %fake, %arg0, %res) + { kernel = @kernel_module::@kernel } + : (index, index, index, index, index, index, + memref<16xf32>, memref<32xf32>, memref<8xf32>) -> () + + return %res : memref<8xf32> +} + +} + +// ----- + +module attributes {gpu.container_module} { + +gpu.module @kernel_module { + gpu.func @kernel(%arg0: memref<8xf32>, %arg1: memref<16xf32>, + %arg2: memref<32xf32>) kernel { + gpu.return + } +} + +// expected-error @+1 {{surrounding function has more than one block}} +func @caller(%arg0: memref<32xf32>, %arg1: memref<16xf32>) -> memref<8xf32> { + %cst = constant 8 : index + %res = alloc() : memref<8xf32> + br ^bb1 + + ^bb1: + "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %res, %arg1, %arg0) + { kernel = @kernel_module::@kernel } + : (index, index, index, index, index, index, + memref<8xf32>, memref<16xf32>, memref<32xf32>) -> () + + return %res : memref<8xf32> +} + +} diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/store_forwarding_pass.mlir b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/store_forwarding_pass.mlir new file mode 100644 index 00000000000..8b993bb56a5 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/passes/store_forwarding_pass.mlir @@ -0,0 +1,72 @@ +// RUN: xla-mlir-gpu-opt --mlir-gpu-store-forwarding %s | FileCheck %s + +// CHECK-LABEL: @forward +func @forward() -> f32 { + %0 = alloc() : memref<1024xf32> + %c42 = constant 24 : index + // CHECK: %[[CST:.*]] = constant 1.0 + %c1 = constant 1.0 : f32 + store %c1, %0[%c42] : memref<1024xf32> + // CHECK-NOT: load + %1 = load %0[%c42] : memref<1024xf32> + // CHECK: return %[[CST]] + return %1 : f32 +} + +// CHECK-LABEL: @forward_alloca +func 
@forward_alloca() -> f32 { + %0 = alloca() : memref<1024xf32> + %c42 = constant 24 : index + // CHECK: %[[CST:.*]] = constant 1.0 + %c1 = constant 1.0 : f32 + store %c1, %0[%c42] : memref<1024xf32> + // CHECK-NOT: load + %1 = load %0[%c42] : memref<1024xf32> + // CHECK: return %[[CST]] + return %1 : f32 +} + +// CHECK-LABEL: @wrong_index +func @wrong_index() -> f32 { + %0 = alloc() : memref<1024xf32> + %c42 = constant 24 : index + %c12 = constant 12 : index + %c1 = constant 1.0 : f32 + store %c1, %0[%c42] : memref<1024xf32> + // CHECK: %[[RES:.*]] = load + %1 = load %0[%c12] : memref<1024xf32> + // CHECK: return %[[RES]] + return %1 : f32 +} + +// CHECK-LABEL: @wrong_memref +func @wrong_memref() -> f32 { + %0 = alloc() : memref<1024xf32> + %1 = alloc() : memref<1024xf32> + %c42 = constant 24 : index + %c1 = constant 1.0 : f32 + store %c1, %0[%c42] : memref<1024xf32> + // CHECK: %[[RES:.*]] = load + %2 = load %1[%c42] : memref<1024xf32> + // CHECK: return %[[RES]] + return %2 : f32 +} + +// CHECK-LABEL: @with_parallel_loop +func @with_parallel_loop() { + %0 = alloc() : memref<1024xf32> + %c0 = constant 0 : index + %c42 = constant 24 : index + %c1 = constant 1 : index + // CHECK: %[[CST:.*]] = constant 1.100000e+01 : f32 + %c11 = constant 1.100000e+01 : f32 + store %c11, %0[%c42] : memref<1024xf32> + // CHECK: scf.parallel + scf.parallel (%i0) = (%c0) to (%c42) step (%c1) { + // CHECK-NOT: load + %1 = load %0[%c42] : memref<1024xf32> + // CHECK-NEXT: store %[[CST]] + store %1, %0[%c0] : memref<1024xf32> + } + return +} diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 473a9ca7456..67c7896cebd 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -31,13 +31,18 @@ limitations under the License. 
namespace xla { -ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape, - const se::Platform* platform, int device_ordinal) - : on_host_shape_(std::move(on_host_shape)), - on_device_shape_(std::move(on_device_shape)), +ShapedBuffer::ShapedBuffer(Shape on_device_shape, const se::Platform* platform, + int device_ordinal) + : on_device_shape_(std::move(on_device_shape)), platform_(platform), device_ordinal_(device_ordinal), - buffers_(&on_device_shape_) {} + buffers_(&on_device_shape_) { + on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape_); +} + +ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape, + const se::Platform* platform, int device_ordinal) + : ShapedBuffer(on_device_shape, platform, device_ordinal) {} ShapedBuffer::ShapedBuffer(ShapedBuffer&& s) : on_host_shape_(std::move(s.on_host_shape_)), @@ -52,8 +57,8 @@ ShapedBuffer::ShapedBuffer(ShapedBuffer&& s) } ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) { - on_host_shape_ = std::move(s.on_host_shape_); on_device_shape_ = std::move(s.on_device_shape_); + on_host_shape_ = std::move(s.on_host_shape_); platform_ = s.platform_; device_ordinal_ = s.device_ordinal_; buffers_ = std::move(s.buffers_); @@ -68,12 +73,9 @@ ShapedBuffer::~ShapedBuffer() {} StatusOr ShapedBuffer::SubShapedBuffer( const ShapeIndex& index) const { - TF_ASSIGN_OR_RETURN(const Shape* host_sub_shape, - ShapeUtil::TryGetSubshape(on_host_shape(), index)); TF_ASSIGN_OR_RETURN(const Shape* device_sub_shape, ShapeUtil::TryGetSubshape(on_device_shape(), index)); - ShapedBuffer sub_shaped_buffer(*host_sub_shape, *device_sub_shape, platform_, - device_ordinal_); + ShapedBuffer sub_shaped_buffer(*device_sub_shape, platform_, device_ordinal_); TF_ASSIGN_OR_RETURN(ShapeTree sub_buffers, buffers_.SubShapeTree(index)); sub_shaped_buffer.set_buffers(std::move(sub_buffers)); @@ -88,12 +90,11 @@ void ShapedBuffer::clear() { } string ShapedBuffer::ToString() const { - string s = absl::StrCat( - "ShapedBuffer(", platform_->Name(), ":", device_ordinal(), - "), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()), - ", on-device shape=" + - ShapeUtil::HumanStringWithLayout(on_device_shape()), - ":\n"); + string s = + absl::StrCat("ShapedBuffer(", platform_->Name(), ":", device_ordinal(), + "), on-device shape=" + + ShapeUtil::HumanStringWithLayout(on_device_shape()), + ":\n"); ShapeUtil::ForEachSubshape( on_device_shape(), [this, &s](const Shape& subshape, const ShapeIndex& index) { @@ -116,13 +117,19 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) { return out; } +ScopedShapedBuffer::ScopedShapedBuffer(Shape on_device_shape, + se::DeviceMemoryAllocator* allocator, + int device_ordinal) + : ShapedBuffer(std::move(on_device_shape), allocator->platform(), + device_ordinal), + allocator_(allocator) {} + ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape, Shape on_device_shape, se::DeviceMemoryAllocator* allocator, int device_ordinal) - : ShapedBuffer(std::move(on_host_shape), std::move(on_device_shape), - allocator->platform(), device_ordinal), - allocator_(allocator) {} + : ScopedShapedBuffer(std::move(on_device_shape), allocator, + device_ordinal) {} ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer, se::DeviceMemoryAllocator* allocator) @@ -171,13 +178,11 @@ void ScopedShapedBuffer::Deallocate() { } ScopedShapedBuffer ScopedShapedBuffer::TakeSubTree(ShapeIndexView index) { - const xla::Shape& sub_on_host_shape = - xla::ShapeUtil::GetSubshape(on_host_shape(), 
{index}); const xla::Shape& sub_on_device_shape = xla::ShapeUtil::GetSubshape(on_device_shape(), {index}); - ScopedShapedBuffer output(sub_on_host_shape, sub_on_device_shape, - memory_allocator(), device_ordinal()); + ScopedShapedBuffer output(sub_on_device_shape, memory_allocator(), + device_ordinal()); auto src_it = buffers().find(index); auto dst_it = output.buffers().begin(); while (dst_it != output.buffers().end()) { diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index 995b0ece7cd..7f1248998a6 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -43,6 +43,10 @@ class ShapedBuffer { // both the on-host and on-device shape are required. The on-device shape // determines the number of device allocations (DeviceMemoryBase) held by the // ShapedBuffer. + ShapedBuffer(Shape on_device_shape, const se::Platform* platform, + int device_ordinal); + + // TODO(b/170310047): remove this overload. ShapedBuffer(Shape on_host_shape, Shape on_device_shape, const se::Platform* platform, int device_ordinal); @@ -97,14 +101,18 @@ class ShapedBuffer { // Reset the shape of this shaped buffer and underlying buffer structure. // // Precondition: EqualStructure(this->on_device_shape_, on_device_shape). - void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) { + void set_shapes(const Shape& on_device_shape) { CHECK(ShapeUtil::EqualStructure(on_device_shape, on_device_shape_)) << "Structures are not the same. new: " << on_device_shape << ", old: " << on_device_shape_; - on_host_shape_ = on_host_shape; + on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape); on_device_shape_ = on_device_shape; buffers_.replace_shape_ptr(&on_device_shape_); } + // TODO(b/170310047): remove this overload. + void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) { + set_shapes(on_device_shape); + } // Returns the underlying ShapeTree containing all the device addresses in the // ShapedBuffer. @@ -119,7 +127,6 @@ class ShapedBuffer { string ToString() const; protected: - // The shape of the data when represented on the host. Shape on_host_shape_; // The shape of the data on the device. @@ -148,6 +155,10 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); class ScopedShapedBuffer : public ShapedBuffer { public: // Creates a ScopedShapedBuffer with null DeviceMemoryBases at each index. + explicit ScopedShapedBuffer(Shape on_device_shape, + se::DeviceMemoryAllocator* allocator, + int device_ordinal); + // TODO(b/170310047): remove this overload. 
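The old two-shape constructors are kept only as forwarding shims: they drop the host shape and delegate to the device-shape-only constructor, which derives the host shape itself. A minimal sketch of that delegation pattern; `Buffer`, `Shape`, and `DeriveHostShape` are illustrative, not the real `ShapedBuffer`/`ShapeUtil` API:

```cpp
// Minimal sketch of the delegation pattern used above: the legacy constructor
// that took both shapes now ignores the host shape and forwards to the
// device-shape-only constructor, which derives the host shape itself.
#include <string>
#include <utility>

struct Shape {
  std::string repr;
};

Shape DeriveHostShape(const Shape& device_shape) {
  // Placeholder for ShapeUtil::DeviceShapeToHostShape-style logic.
  return device_shape;
}

class Buffer {
 public:
  explicit Buffer(Shape on_device_shape)
      : on_device_shape_(std::move(on_device_shape)),
        on_host_shape_(DeriveHostShape(on_device_shape_)) {}

  // Deprecated overload kept for callers that still pass both shapes; the
  // host shape argument is ignored and recomputed from the device shape.
  Buffer(Shape /*on_host_shape*/, Shape on_device_shape)
      : Buffer(std::move(on_device_shape)) {}

 private:
  Shape on_device_shape_;
  Shape on_host_shape_;
};
```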
explicit ScopedShapedBuffer(Shape on_host_shape, Shape on_device_shape, se::DeviceMemoryAllocator* allocator, int device_ordinal); diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc index a2c208d62e4..49751d10c5a 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer_test.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc @@ -97,12 +97,12 @@ class TestAllocator : public se::DeviceMemoryAllocator { TEST(ScopedShapedBufferTest, TestMoveAssignmentOperator) { Shape s = ShapeUtil::MakeShape(F32, {1}); TestAllocator allocator; - ScopedShapedBuffer sb1(s, s, &allocator, /*device_ordinal=*/0); + ScopedShapedBuffer sb1(s, &allocator, /*device_ordinal=*/0); sb1.set_buffer( allocator.Allocate(/*device_ordinal=*/0, /*size=*/42).ValueOrDie(), /*index=*/{}); - ScopedShapedBuffer sb2(s, s, &allocator, /*device_ordinal=*/1); + ScopedShapedBuffer sb2(s, &allocator, /*device_ordinal=*/1); sb2.set_buffer( allocator.Allocate(/*device_ordinal=*/1, /*size=*/10).ValueOrDie(), /*index=*/{}); @@ -119,7 +119,7 @@ TEST(ScopedShapedBufferTest, TestTakeSubTree) { s = xla::ShapeUtil::MakeTupleShape(std::vector(2, s)); s = xla::ShapeUtil::MakeTupleShape(std::vector(3, s)); - ScopedShapedBuffer sb(s, s, &allocator, /*device_ordinal=*/0); + ScopedShapedBuffer sb(s, &allocator, /*device_ordinal=*/0); sb.buffers().ForEachMutableElement( [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { TF_ASSERT_OK_AND_ASSIGN( @@ -156,8 +156,7 @@ TEST(ScopedShapedBufferTest, TestSubShapeTree) { Shape tuple_shape = xla::ShapeUtil::MakeTupleShape({array_shape, array_shape}); TestAllocator allocator; - ScopedShapedBuffer sb(tuple_shape, tuple_shape, &allocator, - /*device_ordinal=*/0); + ScopedShapedBuffer sb(tuple_shape, &allocator, /*device_ordinal=*/0); sb.buffers().ForEachMutableElement( [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { TF_ASSERT_OK_AND_ASSIGN( @@ -182,7 +181,7 @@ void BM_TakeSubTree(int iters, int depth, int fan_out) { std::vector shapes(fan_out, shape); shape = xla::ShapeUtil::MakeTupleShape(shapes); } - xla::ScopedShapedBuffer shaped_buffer(shape, shape, /*allocator=*/&allocator, + xla::ScopedShapedBuffer shaped_buffer(shape, /*allocator=*/&allocator, /*device_ordinal=*/0); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter.cc b/tensorflow/compiler/xla/service/space_to_batch_converter.cc index 6339e6806fb..47aee8ed5a8 100644 --- a/tensorflow/compiler/xla/service/space_to_batch_converter.cc +++ b/tensorflow/compiler/xla/service/space_to_batch_converter.cc @@ -52,7 +52,7 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { Status HandleConvolution(HloInstruction* convolution) override; // Runs the visitor on a computation. - static bool Run(HloComputation* computation); + static bool Run(int64 limit_on_batch_size, HloComputation* computation); // Returns whether any convolution ops were rewritten. const bool changed() const { return changed_; } @@ -60,18 +60,23 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { ~ConvolutionVisitor() override = default; private: - explicit ConvolutionVisitor(HloComputation* computation) - : computation_(computation) {} + explicit ConvolutionVisitor(int64 limit_on_batch_size, + HloComputation* computation) + : computation_(computation), limit_on_batch_size_(limit_on_batch_size) {} // Current HloComputation instance the ConvolutionVisitor is traversing. 
HloComputation* computation_; // Whether rewrite has occurred. bool changed_ = false; + + // Limit on batch size to apply this technique on. + int64 limit_on_batch_size_; }; -bool ConvolutionVisitor::Run(HloComputation* computation) { - ConvolutionVisitor visitor(computation); +bool ConvolutionVisitor::Run(int64 limit_on_batch_size, + HloComputation* computation) { + ConvolutionVisitor visitor(limit_on_batch_size, computation); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -93,27 +98,29 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { constexpr int64 kLowLimitForSplitCount = 4; constexpr int64 kHighLimitForSplitCount = 24; + // Batch in batch_group_count has different semantics (it isn't true batch). + // Consider supporting this case in future if needed. + if (convolution->batch_group_count() != 1) { + return Status::OK(); + } + if (convolution->window().dimensions(kChosenSpatialDim).window_dilation() != 1) { return Status::OK(); } + // TODO(b/168316428): Support base dilations. if (convolution->window().dimensions(kChosenSpatialDim).base_dilation() != 1) { return Status::OK(); } - if (convolution->window().dimensions(kChosenSpatialDim).window_reversal()) { - return Status::OK(); - } - int64 activations_batch_dim = dim_numbers.input_batch_dimension(); const int64 old_batch_size = convolution->operand(0)->shape().dimensions(activations_batch_dim); - // TODO(b/168316428): Only doing this for batch 1 currently. Extend later. - if (old_batch_size != 1) { + if (old_batch_size > limit_on_batch_size_) { return Status::OK(); } @@ -265,11 +272,20 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { // -2 low padding and +2 high padding) to create shape B. Then, we select // between A and B such that halo regions are placed into A at the right // locations. + + // The benefit of the above mentioned scheme is that it allows for batch + // growth. Here are some examples of the size increases it causes for a 3x3 + // kernel. + // with batch=1, [1,16] -> [4,4] -> [4,6] -> [1,24] growth of 8. + // with batch=2, [2,16] -> [8,4] -> [8,6] -> [1,48] growth of 16. + // with batch=3, [3,16] -> [12,4] -> [12,6] -> [1,72] growth of 24. + std::vector reshape_dimensions( activations->shape().dimensions().begin(), activations->shape().dimensions().end()); + reshape_dimensions[spatial_dimension_to_split] = spatial_split_size; - reshape_dimensions[activations_batch_dim] = num_splits; + reshape_dimensions[activations_batch_dim] = num_splits * old_batch_size; TF_ASSIGN_OR_RETURN(HloInstruction * batch_increased_reshape, MakeReshapeHlo(reshape_dimensions, activations)); @@ -341,11 +357,19 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { TF_ASSIGN_OR_RETURN(HloInstruction * select, MakeSelectHlo(shape_mask, straightened_activations, rotated_activations, convolution)); - VLOG(1) << "Select generated"; + VLOG(1) << "Select generated" << select->ToString(); // Increase batch size for one last time. - TF_ASSIGN_OR_RETURN( - activations, MakeReshapeHlo(pad_applied->shape().dimensions(), select)); + std::vector combined_batch_dimensions( + pad_applied->shape().dimensions().begin(), + pad_applied->shape().dimensions().end()); + + combined_batch_dimensions[activations_batch_dim] = + old_batch_size * num_splits; + TF_ASSIGN_OR_RETURN(activations, + MakeReshapeHlo(combined_batch_dimensions, select)); + + VLOG(1) << "Batch merge done " << activations->ToString(); // Now, we rewrite the convolution with a larger batch. 
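To make the growth numbers in the examples above concrete (assuming num_splits = 4, as in those examples): the first reshape splits the spatial dimension across num_splits extra batches, the halo step widens each split from 4 to 6 (kernel_size - 1 = 2 for a 3x3 kernel), and the final merge multiplies that back out, so the total growth is old_batch * num_splits * (kernel_size - 1): 4*2 = 8, 8*2 = 16 and 12*2 = 24 for batch sizes 1, 2 and 3.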
const auto& activations_shape = activations->shape(); @@ -389,28 +413,35 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { VLOG(1) << "new_conv " << new_conv->ToString(); + const int64 output_split_spatial_dim = + new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim); + const int64 output_batch_dim = new_dim_numbers.output_batch_dimension(); + Shape new_shape = new_conv->shape(); - const int64 new_batch_size = - new_shape.dimensions(new_dim_numbers.output_batch_dimension()); - const int64 new_spatial_dim_size = new_shape.dimensions( - new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim)); - new_shape.set_dimensions( - new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim), - new_batch_size * new_spatial_dim_size); - new_shape.set_dimensions(new_dim_numbers.output_batch_dimension(), - old_batch_size); + const int64 new_batch_size = new_shape.dimensions(output_batch_dim); + const int64 new_spatial_dim_size = + new_shape.dimensions(output_split_spatial_dim); + + CHECK_EQ(new_batch_size % old_batch_size, 0); + + const int64 output_split_batch_size = new_batch_size / old_batch_size; + + std::vector new_dimensions(new_conv->shape().dimensions().begin(), + new_conv->shape().dimensions().end()); + new_dimensions[output_split_spatial_dim] = + output_split_batch_size * new_spatial_dim_size; + new_dimensions[new_dim_numbers.output_batch_dimension()] = old_batch_size; // Reshape the output of the new conv into the old convolutions shape. TF_ASSIGN_OR_RETURN(HloInstruction * reshape, - MakeReshapeHlo(new_shape, new_conv)); + MakeReshapeHlo(new_dimensions, new_conv)); convolution->SetupDerivedInstruction(reshape); std::vector start_indices(rank, 0), - end_indices(new_shape.dimensions().begin(), new_shape.dimensions().end()), + end_indices(new_dimensions.begin(), new_dimensions.end()), strides(rank, 1); - end_indices[new_dim_numbers.output_spatial_dimensions(kChosenSpatialDim)] = - convolution->shape().dimensions( - dim_numbers.output_spatial_dimensions(kChosenSpatialDim)); + end_indices[output_split_spatial_dim] = convolution->shape().dimensions( + dim_numbers.output_spatial_dimensions(kChosenSpatialDim)); // This slicing is getting rid of the padding we added to evenly divide space. TF_ASSIGN_OR_RETURN( @@ -435,7 +466,7 @@ StatusOr ConvolutionSpaceToBatchConverter::Run(HloModule* module) { module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (ConvolutionVisitor::Run(comp)) { + if (ConvolutionVisitor::Run(limit_on_batch_size_, comp)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter.h b/tensorflow/compiler/xla/service/space_to_batch_converter.h index da4102ddf37..a92abda0337 100644 --- a/tensorflow/compiler/xla/service/space_to_batch_converter.h +++ b/tensorflow/compiler/xla/service/space_to_batch_converter.h @@ -26,7 +26,8 @@ namespace xla { // batch. class ConvolutionSpaceToBatchConverter : public HloModulePass { public: - ConvolutionSpaceToBatchConverter() = default; + explicit ConvolutionSpaceToBatchConverter(int64 limit_on_batch_size = 1) + : limit_on_batch_size_(limit_on_batch_size) {} absl::string_view name() const override { return "convolution-space-to-batch-converter"; @@ -35,6 +36,8 @@ class ConvolutionSpaceToBatchConverter : public HloModulePass { // Run convolution rewriting on the given computation. Returns whether the // computation was changed. 
StatusOr Run(HloModule* module) override; + + int64 limit_on_batch_size_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc index 0495d7f1031..bbc3882cde9 100644 --- a/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc +++ b/tensorflow/compiler/xla/service/space_to_batch_converter_test.cc @@ -65,31 +65,42 @@ ENTRY computation { TEST_F(ConvolutionSpaceToBatchConverterTest, SimpleBatch2) { string hlo_string = R"( - HloModule module -ENTRY computation { - %p0 = bf16[2,258,258,32] parameter(0) - %p1 = bf16[3,3,32,32] parameter(1) - ROOT %convolution = bf16[2,256,256,32] convolution(%p0, %p1), window={size=3x3}, - dim_labels=b01f_01io->b01f -} + ENTRY computation { + %p0 = bf16[2,258,258,32] parameter(0) + %p1 = bf16[3,3,32,32] parameter(1) + ROOT %convolution = bf16[2,256,256,32] convolution(%p0, %p1), window={size=3x3}, + dim_labels=b01f_01io->b01f + } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_string)); - ConvolutionSpaceToBatchConverter converter; - ASSERT_FALSE(converter.Run(module.get()).ValueOrDie()); + ConvolutionSpaceToBatchConverter converter(/*limit_on_batch_size=*/2); + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); + auto computation = module->entry_computation(); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Transpose()); + EXPECT_THAT(root->operand(0), op::Slice()); + auto reshape = root->operand(0)->operand(0); + EXPECT_THAT(reshape, op::Reshape()); + EXPECT_THAT(reshape->operand(0), op::Convolution()); + const int64 batch_dim = reshape->operand(0) + ->convolution_dimension_numbers() + .output_batch_dimension(); + // Verify that the transform has increased the batch size. 
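A short sketch of how the new knob is meant to be used when building a pass pipeline (the pipeline name and limit value here are illustrative, not part of this change; `module` is an HloModule* assumed to be in scope):

  xla::HloPassPipeline pipeline("space-to-batch");
  // Convolutions whose batch size exceeds the limit are left unchanged.
  pipeline.AddPass<xla::ConvolutionSpaceToBatchConverter>(
      /*limit_on_batch_size=*/4);
  TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));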
+ EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 1); } -TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) { +TEST_F(ConvolutionSpaceToBatchConverterTest, Batch4WithStrideAndPad) { string hlo_string = R"( HloModule module ENTRY computation { - %p0 = bf16[1,224,224,3]{3,2,1,0} parameter(0) + %p0 = bf16[4,224,224,3]{3,2,1,0} parameter(0) %p1 = bf16[7,7,3,64]{3,2,1,0} parameter(1) - ROOT %convolution.3 = bf16[1,112,112,64]{3,2,1,0} convolution(%p0, %p1), + ROOT %convolution.3 = bf16[4,112,112,64]{3,2,1,0} convolution(%p0, %p1), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f } )"; @@ -97,7 +108,7 @@ TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) { ParseAndReturnVerifiedModule(hlo_string)); auto computation = module->entry_computation(); - ConvolutionSpaceToBatchConverter converter; + ConvolutionSpaceToBatchConverter converter(/*limit_on_batch_size=*/4); ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Transpose()); @@ -109,7 +120,7 @@ TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithStrideAndPad) { ->convolution_dimension_numbers() .output_batch_dimension(); - EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 1); + EXPECT_GT(reshape->operand(0)->shape().dimensions(batch_dim), 4); } TEST_F(ConvolutionSpaceToBatchConverterTest, Batch1WithKernelDilation) { diff --git a/tensorflow/compiler/xla/service/spmd/convolution_handler.cc b/tensorflow/compiler/xla/service/spmd/convolution_handler.cc index f2d996c4b63..0d34c5b62e9 100644 --- a/tensorflow/compiler/xla/service/spmd/convolution_handler.cc +++ b/tensorflow/compiler/xla/service/spmd/convolution_handler.cc @@ -281,6 +281,22 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS( rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero); } + if (original_hlo->feature_group_count() > 1 && + (lhs.sharding().tile_assignment().dim(dnums.input_feature_dimension()) > + 1 || + rhs.sharding().tile_assignment().dim( + dnums.kernel_output_feature_dimension()) > 1)) { + return nullptr; + } + + if (original_hlo->batch_group_count() > 1 && + (lhs.sharding().tile_assignment().dim(dnums.input_batch_dimension()) > + 1 || + rhs.sharding().tile_assignment().dim( + dnums.kernel_output_feature_dimension()) > 1)) { + return nullptr; + } + // Reshard RHS so that each shard computes the partial sum of the full // shape result, and add AllReduce. See HandleConvolutionTiledLhsAndRhs() // that reshards LHS. @@ -574,6 +590,22 @@ PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS( rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero, reversed_rhs_dims); } + + if (original_hlo->feature_group_count() > 1 && + (lhs.sharding().tile_assignment().dim(dnums.input_feature_dimension()) > + 1 || + rhs.sharding().tile_assignment().dim( + dnums.kernel_output_feature_dimension()) > 1)) { + return nullptr; + } + + if (original_hlo->batch_group_count() > 1 && + (lhs.sharding().tile_assignment().dim(dnums.input_batch_dimension()) > + 1 || + rhs.sharding().tile_assignment().dim( + dnums.kernel_output_feature_dimension()) > 1)) { + return nullptr; + } // Reshard LHS by exchanging halo such that each shard computes the partial // sum of the full shape result, and add AllReduce. 
// diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc index 65a7b3f2927..30fc8355402 100644 --- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc +++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc @@ -1132,6 +1132,46 @@ StatusOr PartitionDotGroupOnContracting( .hlo(); } +DotConvDimsMapping ConvertDimsMappingWithFeatureGroupCount( + const DotConvDimsMapping& dims_mapping, HloInstruction* original_hlo) { + const auto& dnums = original_hlo->convolution_dimension_numbers(); + DotConvDimsMapping new_dims_mapping; + new_dims_mapping.batch_dims = dims_mapping.batch_dims; + new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims; + // Append batch dims. + new_dims_mapping.batch_dims.emplace_back(); + new_dims_mapping.batch_dims.back().lhs = dnums.input_feature_dimension(); + new_dims_mapping.batch_dims.back().rhs = + dnums.kernel_output_feature_dimension(); + new_dims_mapping.batch_dims.back().output = dnums.output_feature_dimension(); + new_dims_mapping.batch_dims.back().spatial = -1; + // Setup non contracting dims. + new_dims_mapping.lhs_non_contracting_dims.emplace_back(); + new_dims_mapping.lhs_non_contracting_dims.back().lhs = + dnums.input_batch_dimension(); + new_dims_mapping.rhs_non_contracting_dims.emplace_back(); + new_dims_mapping.rhs_non_contracting_dims.back().rhs = + dnums.kernel_input_feature_dimension(); + return new_dims_mapping; +} + +DotConvDimsMapping ConvertDimsMappingWithBatchGroupCount( + const DotConvDimsMapping& dims_mapping, HloInstruction* original_hlo) { + const auto& dnums = original_hlo->convolution_dimension_numbers(); + DotConvDimsMapping new_dims_mapping; + new_dims_mapping.batch_dims = dims_mapping.batch_dims; + new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims; + new_dims_mapping.contracting_dims = dims_mapping.contracting_dims; + // Append batch dims. + new_dims_mapping.batch_dims.emplace_back(); + new_dims_mapping.batch_dims.back().lhs = dnums.input_batch_dimension(); + new_dims_mapping.batch_dims.back().rhs = + dnums.kernel_output_feature_dimension(); + new_dims_mapping.batch_dims.back().output = dnums.output_feature_dimension(); + new_dims_mapping.batch_dims.back().spatial = -1; + return new_dims_mapping; +} + // Recursive partitioning function. If there are partial dimensions matching in // the operands and output, group the devices and recursively partition the // in-group dot. @@ -1210,6 +1250,15 @@ StatusOr PartitionDot( (original_hlo->opcode() == HloOpcode::kConvolution && (original_hlo->batch_group_count() > 1 || original_hlo->feature_group_count() > 1))) { + // Partition with kernel_input_feature_dim > 1 and feature_group_count > 1 + // is not supported. + const auto& dnums = original_hlo->convolution_dimension_numbers(); + if (original_hlo->feature_group_count() > 1 && + rhs.hlo()->shape().dimensions(dnums.kernel_input_feature_dimension()) > + 1) { + return nullptr; + } + TF_ASSIGN_OR_RETURN( auto partitioned_conv, PartitionConvolution(lhs, rhs, output_base_shape, output_sharding, @@ -1220,6 +1269,56 @@ StatusOr PartitionDot( if (partitioned_conv) { return partitioned_conv; } + + // Recursively partition on different types of dimensions for convolution. + // Case 0.a: Group partitions by feature group count. 
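The intent of ConvertDimsMappingWithFeatureGroupCount above is easiest to see on the depthwise case from the tests below (feature_group_count = 1024): the feature dimension appears once, un-contracted, in the LHS (input_feature_dimension), the RHS (kernel_output_feature_dimension) and the output (output_feature_dimension), so for partitioning purposes it behaves like a dot batch dimension; that is why it is appended to batch_dims and then handled by the existing PartitionDotGroupOnBatch path. ConvertDimsMappingWithBatchGroupCount plays the same role for batch_group_count, pairing the input batch dimension with the kernel output-feature dimension.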
+ if (original_hlo->feature_group_count() > 1 || + original_hlo->batch_group_count() > 1) { + DotConvDimsMapping new_dims_mapping; + if (original_hlo->feature_group_count() > 1) { + new_dims_mapping = + ConvertDimsMappingWithFeatureGroupCount(dims_mapping, original_hlo); + } + + if (original_hlo->batch_group_count() > 1) { + new_dims_mapping = + ConvertDimsMappingWithBatchGroupCount(dims_mapping, original_hlo); + } + + const int64 conv_lhs_contracting_partitions = get_partitions_for_dims( + lhs.sharding(), new_dims_mapping.contracting_dims, 0); + const int64 conv_rhs_contracting_partitions = get_partitions_for_dims( + rhs.sharding(), new_dims_mapping.contracting_dims, 1); + const int64 conv_lhs_non_contracting_partitions = get_partitions_for_dims( + lhs.sharding(), new_dims_mapping.lhs_non_contracting_dims, 0); + const int64 conv_rhs_non_contracting_partitions = get_partitions_for_dims( + rhs.sharding(), new_dims_mapping.rhs_non_contracting_dims, 1); + const int64 conv_lhs_batch_partitions = get_partitions_for_dims( + lhs.sharding(), new_dims_mapping.batch_dims, 0); + const int64 conv_rhs_batch_partitions = get_partitions_for_dims( + rhs.sharding(), new_dims_mapping.batch_dims, 1); + const int64 conv_output_batch_partitions = get_partitions_for_dims( + output_sharding, new_dims_mapping.batch_dims, 2); + if ((conv_lhs_batch_partitions == conv_output_batch_partitions || + conv_rhs_batch_partitions == conv_output_batch_partitions) && + conv_output_batch_partitions > 1) { + TF_ASSIGN_OR_RETURN( + auto try_partitioned_conv, + PartitionDotGroupOnBatch( + lhs, rhs, output_base_shape, output_sharding, new_dims_mapping, + num_partitions, conv_lhs_contracting_partitions, + conv_rhs_contracting_partitions, + conv_lhs_non_contracting_partitions, + conv_rhs_non_contracting_partitions, create_sharded_dot, + conv_window, module, original_hlo, + require_matching_devices_to_group, options, b, + windowed_dot_general_loops)); + if (try_partitioned_conv) { + return try_partitioned_conv; + } + } + return nullptr; + } } TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc index 43e6dbf5d51..84c40e888a3 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -5708,6 +5708,179 @@ ENTRY entry { op::Shape("f32[8,801,1,1024]"))); } +TEST_F(SpmdPartitioningTest, + PartitionConvGroupOnFeatureGroupCount_RHSPartialReplicate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,2,1,2]0,1,2,3} + %rhs = f32[5,1,1,1024] parameter(1) + %rhs.copy = f32[5,1,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2,2]0,2,1,3 last_tile_dim_replicate} + ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=b01f_01io->b01f,feature_group_count=1024, + window={size=5x1 pad=2_2x0_0}, + sharding={devices=[1,2,1,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Reshape())), + op::Shape("f32[16,401,1,512]")); + auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"), + 
op::CollectivePermute(op::Slice(lhs))); + auto right_halo = AllOf(op::Shape("f32[16,2, 1, 512]"), + op::CollectivePermute(op::Slice(lhs))); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[5,1,1,512]")); + EXPECT_THAT( + root, + AllOf(op::Convolution( + op::Select(_, op::Concatenate(left_halo, lhs, right_halo), _), + rhs), + op::Shape("f32[16, 401, 1, 512]"))); +} + +TEST_F(SpmdPartitioningTest, + PartitionConvGroupOnFeatureGroupCount_RHSAlignWithOutput) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,2,1,2]0,1,2,3} + %rhs = f32[5,1,1,1024] parameter(1), sharding={replicated} + ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs), + dim_labels=b01f_01io->b01f,feature_group_count=1024, + window={size=5x1 pad=2_2x0_0}, + sharding={devices=[1,2,1,2]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Reshape())), + op::Shape("f32[16,401,1,512]")); + auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"), + op::CollectivePermute(op::Slice(lhs))); + auto right_halo = AllOf(op::Shape("f32[16,2, 1, 512]"), + op::CollectivePermute(op::Slice(lhs))); + auto rhs = + AllOf(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape()), + op::Shape("f32[5,1,1,512]")); + EXPECT_THAT( + root, + AllOf(op::Convolution( + op::Select(_, op::Concatenate(left_halo, lhs, right_halo), _), + rhs), + op::Shape("f32[16, 401, 1, 512]"))); +} + +TEST_F(SpmdPartitioningTest, + PartitionConvGroupOnFeatureGroupCount_LHSAlignWithOutput) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[2,1,1,1,2]0,1,2,3 last_tile_dim_replicate} + %rhs = f32[5,1,1,1024] parameter(1) + %rhs.copy = f32[5,1,1,1024] copy(%rhs), + sharding={devices=[1,1,1,2,2]0,2,1,3 last_tile_dim_replicate} + ROOT %conv = f32[16,801,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=b01f_01io->b01f,feature_group_count=1024, + window={size=5x1 pad=2_2x0_0}, + sharding={devices=[1,2,1,2]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[8,801,1,1024]")); + auto resharded_lhs = + AllOf(op::Reshape(op::Transpose(op::AllToAll(op::Reshape( + op::Pad(op::DynamicSlice(lhs, op::Constant(), op::Constant(), + op::Constant(), op::Reshape()), + op::Constant()))))), + op::Shape("f32[16,401,1,512]")); + auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"), + op::CollectivePermute(op::Slice(resharded_lhs))); + auto right_halo = AllOf(op::Shape("f32[16,2, 1, 512]"), + op::CollectivePermute(op::Slice(resharded_lhs))); + auto rhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Constant(), + op::Constant(), op::Reshape())), + op::Shape("f32[5,1,1,512]")); + 
EXPECT_THAT( + root, + AllOf( + op::Convolution( + op::Select( + _, op::Concatenate(left_halo, resharded_lhs, right_halo), _), + rhs), + op::Shape("f32[16, 401, 1, 512]"))); +} + +TEST_F(SpmdPartitioningTest, PartitionConvGroupOnBatchGroupCount) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[16,801,1,1024] parameter(0) + %lhs.copy = f32[16,801,1,1024] copy(%lhs), + sharding={devices=[1,2,1,2]0,1,2,3} + %rhs = f32[16,801,1,1024] parameter(1) + %rhs.copy = f32[16,801,1,1024] copy(%rhs), + sharding={devices=[1,2,1,2]0,1,2,3} + ROOT %conv = f32[5,1,1,1024] convolution(%lhs.copy, %rhs.copy), + dim_labels=f01b_i01o->01bf,batch_group_count=1024, + window={size=801x1 pad=2_2x0_0}, + sharding={devices=[1,1,1,2,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Select(_, + op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Reshape())), + _), + op::Shape("f32[16,401,1,512]")); + auto left_halo = AllOf(op::Shape("f32[16,2, 1, 512]"), + op::CollectivePermute(op::Slice(lhs))); + auto right_halo = AllOf(op::Shape("f32[16,2, 1, 512]"), + op::CollectivePermute(op::Slice(lhs))); + auto rhs = AllOf(op::Copy(op::DynamicSlice( + op::Pad(op::Parameter(), op::Constant()), op::Constant(), + op::Reshape(), op::Constant(), op::Reshape())), + op::Shape("f32[16,401,1,512]")); + auto conv = AllOf(op::Convolution(op::Concatenate(left_halo, lhs, right_halo), + op::Select(_, rhs, _)), + op::Shape("f32[5,1,1,512]")); + EXPECT_THAT(root, AllOf(op::CollectivePermute(op::AllReduce(conv)), + op::Shape("f32[5,1,1,512]"))); +} + TEST_F(SpmdPartitioningTest, PartitionConvWithFeatureGroupCountAlignOuputWithRHS) { const char* const hlo_string = R"( diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 0fd64209152..913bfed926a 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -169,8 +169,7 @@ Status TransferManager::TransferArrayToDeviceAsync( "%d < %d", dest.size(), GetByteSizeRequirement(on_device_shape)); } - ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape, - stream->parent()->platform(), + ShapedBuffer shaped_buffer(on_device_shape, stream->parent()->platform(), stream->parent()->device_ordinal()); shaped_buffer.set_buffer(dest, /*index=*/{}); return TransferLiteralToDevice(stream, literal, shaped_buffer, @@ -194,8 +193,7 @@ void TransferManager::TransferArrayFromDevice( "%d < %d", source.size(), GetByteSizeRequirement(shape))); } - ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape, - stream->parent()->platform(), + ShapedBuffer shaped_buffer(shape, stream->parent()->platform(), stream->parent()->device_ordinal()); shaped_buffer.set_buffer(source, /*index=*/{}); return TransferLiteralFromDevice(stream, shaped_buffer, literal, @@ -406,8 +404,8 @@ StatusOr TransferManager::AllocateScopedShapedBuffer( Shape on_device_shape = HostShapeToDeviceShape(on_host_shape); TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape)); - ScopedShapedBuffer shaped_buffer(on_host_shape, std::move(on_device_shape), - allocator, device_ordinal); + ScopedShapedBuffer shaped_buffer(std::move(on_device_shape), allocator, + device_ordinal); // Allocate an 
appropriate sized buffer for each element in the shape // including the tuple pointer arrays. diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index c0670d26eee..c49d7d899e7 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -51,7 +51,11 @@ class TransferManager { // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user // needing to consider device-specific behaviors. virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const { - return host_shape; + // Strips off any preexisting tiling or memory space information. + // TODO(phawkins): fix clients not to including tiling or memory space + // information in shapes passed to this function and turn this into an + // assertion. + return ShapeUtil::DeviceShapeToHostShape(host_shape); } // Base class for specifying platform specific transfer metadata that can be @@ -189,6 +193,7 @@ class TransferManager { // shapes, and returns static shapes with dynamic shapes updated. // The shape of the buffer also have to be compatible with the host shape and // device shape. + // TODO(b/170310047): remove host_shape. virtual Status ReadDynamicShapes(se::Stream* stream, ShapedBuffer* device_buffer, Shape* host_shape, Shape* device_shape); diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index c9799453939..614dfc4ffe6 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -35,17 +35,46 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( const HloInstruction& dot, const TransposeFolding::TransposableGemmOperandsFn& transposable_gemm_operands) { - if (HloOpcode::kDot != dot.opcode() || - dot.dot_dimension_numbers().lhs_batch_dimensions_size() != 0) { + if (HloOpcode::kDot != dot.opcode()) { return {}; } + if (!absl::c_equal(dot.dot_dimension_numbers().lhs_batch_dimensions(), + dot.dot_dimension_numbers().rhs_batch_dimensions())) { + return {}; + } + + int64 num_batch_dims = + dot.dot_dimension_numbers().lhs_batch_dimensions_size(); + int64 expected_rank = 2 + num_batch_dims; + auto is_r2_transpose = [&](const HloInstruction& transpose) { + if (transpose.opcode() != HloOpcode::kTranspose) { + return false; + } + const auto& transpose_dims = transpose.dimensions(); + if (transpose_dims.size() != expected_rank) { + return false; + } + + // Check that the transpose doesn't touch any batch dimensions, but does + // transpose the non-batch ones. 
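A small illustration of the dimension bookkeeping used in the loop below and in FoldTransposeIntoDot: output dimension i of an HLO transpose reads from operand dimension dimensions(i), so a permutation is foldable when it maps every batch dimension to itself and swaps the two non-batch dimensions, and after folding, the dot's contracting dimension on that side becomes dimensions(old_contracting_dim). A hypothetical helper (not part of this change) capturing the remap:

  // transpose_dims is the transpose's `dimensions` attribute; the return
  // value is the operand dimension the folded dot must contract over.
  int64 FoldedContractingDim(absl::Span<const int64> transpose_dims,
                             int64 contracting_dim) {
    return transpose_dims[contracting_dim];
  }

  // E.g. dimensions={0,1,3,2} with rhs_contracting_dims={2} yields
  // FoldedContractingDim(...) == 3, matching the FoldBatchDotTranspose test.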
+ for (int64 i = 0; i != expected_rank; ++i) { + bool is_batch = absl::c_linear_search( + dot.dot_dimension_numbers().lhs_batch_dimensions(), + transpose_dims[i]); + if ((transpose_dims[i] == i) != is_batch) { + return false; + } + } + return true; + }; + TransposeFolding::OperandIndices operand_set; for (int64 i = 0; i < dot.operand_count(); ++i) { auto& operand = *dot.operand(i); - if (operand.IsRank2Transpose()) { + if (is_r2_transpose(operand)) { operand_set.push_back(i); - } else if (operand.shape().rank() != 2) { + } else if (operand.shape().rank() != expected_rank) { return {}; } } @@ -84,25 +113,25 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) { HloInstruction* new_lhs = dot->mutable_operand(0); HloInstruction* new_rhs = dot->mutable_operand(1); - CHECK_EQ(new_dim_numbers.lhs_batch_dimensions_size(), 0); - CHECK_EQ(new_dim_numbers.rhs_batch_dimensions_size(), 0); CHECK_EQ(new_dim_numbers.lhs_contracting_dimensions_size(), 1); CHECK_EQ(new_dim_numbers.rhs_contracting_dimensions_size(), 1); for (int64 operand_index : pair.second) { - // We've checked that there aren't any batch dimensions and that the inputs - // are rank 2, and shape inference guarantees that there is exactly one - // contracting dimension. + // We checked that the batch dimensions are not touched by the transpose, + // and shape inference guarantees that there is exactly one contracting + // dimension. if (operand_index == 0) { CHECK_EQ(new_lhs->opcode(), HloOpcode::kTranspose); new_dim_numbers.set_lhs_contracting_dimensions( - 0, 1 - new_dim_numbers.lhs_contracting_dimensions(0)); + 0, + new_lhs->dimensions(new_dim_numbers.lhs_contracting_dimensions(0))); new_lhs = new_lhs->mutable_operand(0); } else { CHECK_EQ(operand_index, 1); CHECK_EQ(new_rhs->opcode(), HloOpcode::kTranspose); new_dim_numbers.set_rhs_contracting_dimensions( - 0, 1 - new_dim_numbers.rhs_contracting_dimensions(0)); + 0, + new_rhs->dimensions(new_dim_numbers.rhs_contracting_dimensions(0))); new_rhs = new_rhs->mutable_operand(0); } } diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 8a2112c87dc..3fe69d22e9c 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -42,7 +42,7 @@ namespace { class TransposeFoldingTest : public HloTestBase { protected: - void FoldTranspose(HloModule* module) { + bool FoldTranspose(HloModule* module) { TransposeFolding transpose_folding( [](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { @@ -52,7 +52,9 @@ class TransposeFoldingTest : public HloTestBase { const TransposeFolding::OperandIndices& candidate_operands) { return candidate_operands; }); - EXPECT_IS_OK(transpose_folding.Run(module).status()); + auto folded = transpose_folding.Run(module); + EXPECT_IS_OK(folded.status()); + return *folded; } }; @@ -465,5 +467,81 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { new_conv->convolution_dimension_numbers().output_spatial_dimensions(1)); } +TEST_F(TransposeFoldingTest, FoldBatchDotTranspose) { + string hlo_string = R"( +HloModule FoldBatchDotTranspose + +ENTRY entry_computation { + x = f32[7,7,2,3]{3,2,1,0} parameter(0) + y = f32[7,7,2,3]{3,2,1,0} parameter(1) + transpose = f32[7,7,3,2]{3,2,1,0} transpose(y), dimensions={0,1,3,2} + ROOT dot = f32[7,7,2,2]{3,2,1,0} dot(x, transpose), lhs_contracting_dims={3}, + rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1} +} 
+)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_TRUE(FoldTranspose(module.get())); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/3)); +} + +TEST_F(TransposeFoldingTest, NoFoldBatchDotTransposeBatch) { + string hlo_string = R"( +HloModule NoFoldBatchDotTransposeBatch + +ENTRY entry_computation { + x = f32[7,7,2,3]{3,2,1,0} parameter(0) + y = f32[7,7,2,3]{3,2,1,0} parameter(1) + transpose = f32[7,7,3,2]{3,2,1,0} transpose(y), dimensions={1,0,3,2} + ROOT dot = f32[7,7,2,2]{3,2,1,0} dot(x, transpose), lhs_contracting_dims={3}, + rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_FALSE(FoldTranspose(module.get())); +} + +TEST_F(TransposeFoldingTest, FoldBatchDotTransposeNonContiguousBatch) { + string hlo_string = R"( +HloModule FoldBatchDotTransposeNonContiguousBatch + +ENTRY entry_computation { + x = f32[7,2,7,3]{3,2,1,0} parameter(0) + y = f32[7,2,7,3]{3,2,1,0} parameter(1) + transpose = f32[7,3,7,2]{3,2,1,0} transpose(y), dimensions={0,3,2,1} + ROOT dot = f32[7,7,2,2]{3,2,1,0} dot(x, transpose), lhs_contracting_dims={3}, + rhs_contracting_dims={1}, lhs_batch_dims={0,2}, rhs_batch_dims={0,2} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_TRUE(FoldTranspose(module.get())); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/3)); +} + +TEST_F(TransposeFoldingTest, NoFoldBatchDotTransposeIdentity) { + string hlo_string = R"( +HloModule NoFoldBatchDotTransposeIdentity + +ENTRY entry_computation { + x = f32[7,7,2,3]{3,2,1,0} parameter(0) + y = f32[7,7,3,2]{3,2,1,0} parameter(1) + transpose = f32[7,7,3,2]{3,2,1,0} transpose(y), dimensions={0,1,2,3} + ROOT dot = f32[7,7,2,2]{3,2,1,0} dot(x, transpose), lhs_contracting_dims={3}, + rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_FALSE(FoldTranspose(module.get())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 0833919b124..238879ebdc0 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -1074,7 +1074,8 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, } } - return std::make_tuple(true, deleted_indices, inserted_indices); + return std::make_tuple(!deleted_indices.empty() || !inserted_indices.empty(), + deleted_indices, inserted_indices); } /* static */ std::vector> @@ -1623,4 +1624,14 @@ static Shape MergeDimensions(absl::Span segs, return absl::nullopt; } +Shape ShapeUtil::DeviceShapeToHostShape(Shape s) { + ForEachMutableSubshape(&s, [](Shape* subshape, const ShapeIndex& index) { + if (subshape->IsArray()) { + subshape->mutable_layout()->clear_tiles(); + subshape->mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace); + } + }); + return s; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 3f69a8b0aca..5a5695d32ee 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -783,6 +783,10 @@ class ShapeUtil { static 
absl::optional> FindTranspose021(const Shape& a, const Shape& b); + // Strips device-specific information, namely tiling and memory-space + // information, from a shape. + static Shape DeviceShapeToHostShape(Shape s); + private: // Validates the shape size is sane. This makes sure it's safe to do // calculations in int64 without overflowing. diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 4e2030667ee..414b53d4f67 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -558,6 +558,8 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) { ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape1))); EXPECT_FALSE(std::get<0>( ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2))); + EXPECT_FALSE(std::get<0>( + ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape0))); } TEST(ShapeUtilTest, ForEachIndex) { diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 19a7b69b98c..98ed49ad76a 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -71,9 +71,9 @@ cc_library( hdrs = ["manifest_checking_test.h"], deps = [ ":test_macros_header", - "//tensorflow/core:regexp_internal", "//tensorflow/core:test", "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:regexp", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], @@ -171,8 +171,8 @@ cc_library( "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", @@ -233,8 +233,8 @@ cc_library( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -312,7 +312,7 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", "//third_party/eigen3", "@com_google_absl//absl/memory", @@ -509,8 +509,8 @@ xla_test( "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", "//tensorflow/core:test", + "//tensorflow/core/platform:regexp", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", @@ -555,8 +555,8 @@ xla_test( "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -1458,8 +1458,8 @@ xla_test( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", - 
"//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", ], ) @@ -1915,8 +1915,8 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -1924,9 +1924,14 @@ xla_test( name = "concat_test", srcs = ["concat_test.cc"], deps = [ + ":client_library_test_base", + ":hlo_test_base", + ":literal_test_util", ":test_macros_header", + ":xla_internal_test_main", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -1934,9 +1939,6 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], ) @@ -1954,8 +1956,8 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", ], @@ -1984,8 +1986,8 @@ xla_test( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -2045,8 +2047,8 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", ], ) @@ -2400,8 +2402,8 @@ xla_test( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", ], ) @@ -2523,8 +2525,8 @@ xla_test( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:stream_pool", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory_allocator", ], ) diff --git a/tensorflow/compiler/xla/tests/buffer_donation_test.cc b/tensorflow/compiler/xla/tests/buffer_donation_test.cc index f78083fe2af..7915737178d 100644 --- a/tensorflow/compiler/xla/tests/buffer_donation_test.cc +++ b/tensorflow/compiler/xla/tests/buffer_donation_test.cc @@ -119,8 +119,7 @@ class BufferDonationTest : public HloTestBase { } }); - args.emplace_back( - ExecutionInput(std::move(owned_buffers), argument_literal.shape())); + args.emplace_back(ExecutionInput(std::move(owned_buffers))); } StatusOr output_status = diff --git 
a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 4f5b525a342..9df83e30ad4 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -21,11 +21,13 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/platform/test.h" @@ -34,6 +36,7 @@ namespace xla { namespace { using ConcatTest = ClientLibraryTestBase; +using ConcatTestHlo = HloTestBase; using ::testing::HasSubstr; // Concatenate expects at least one argument. @@ -518,6 +521,250 @@ XLA_TEST_F(ConcatTest, ConcatDeeplyNested) { ComputeAndCompareR1(&builder, expected, {a_data.get()}); } +// TODO(b/169314478): Enable the test when the slow compilation is fixed. +XLA_TEST_F(ConcatTestHlo, DISABLED_ConcatWithBitcast) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule jit_broken.874 + +primitive_computation_add.866 { + parameter.867 = f32[] parameter(0) + parameter.868 = f32[] parameter(1) + ROOT add.869 = f32[] add(parameter.867, parameter.868) +} + +ENTRY jit_broken.874 { + parameter.38 = f32[4,2]{1,0} parameter(0) + reshape.723 = f32[4,2,1]{2,1,0} reshape(parameter.38) + reshape.724 = f32[4,2,1]{2,1,0} reshape(parameter.38) + concatenate.42 = f32[4,2,2]{2,1,0} concatenate(reshape.723, reshape.724), dimensions={2} + slice.351 = f32[4,1,2]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:2]} + reshape.1058 = f32[4,2]{1,0} reshape(slice.351) + slice.352 = f32[4,1]{1,0} slice(reshape.1058), slice={[0:4], [1:2]} + reshape.1059 = f32[4]{0} reshape(slice.352) + slice.353 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]} + reshape.1060 = f32[4]{0} reshape(slice.353) + add.124 = f32[4]{0} add(reshape.1059, reshape.1060) + slice.354 = f32[4,1]{1,0} slice(reshape.1058), slice={[0:4], [0:1]} + reshape.1061 = f32[4]{0} reshape(slice.354) + slice.379 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]} + reshape.1062 = f32[4]{0} reshape(slice.379) + add.89 = f32[4]{0} add(reshape.1061, reshape.1062) + subtract.126 = f32[4]{0} subtract(add.124, add.89) + is-finite.127 = pred[4]{0} is-finite(subtract.126) + not.128 = pred[4]{0} not(is-finite.127) + abs.129 = f32[4]{0} abs(subtract.126) + constant.130 = f32[] constant(inf) + broadcast.131 = f32[4]{0} broadcast(constant.130), dimensions={} + compare.132 = pred[4]{0} compare(abs.129, broadcast.131), direction=EQ, type=UNSIGNED + not.133 = pred[4]{0} not(compare.132) + and.134 = pred[4]{0} and(not.128, not.133) + add.135 = f32[4]{0} add(add.124, add.89) + maximum.125 = f32[4]{0} maximum(add.124, add.89) + abs.136 = f32[4]{0} abs(subtract.126) + negate.137 = f32[4]{0} negate(abs.136) + exponential.138 = f32[4]{0} exponential(negate.137) + log-plus-one.139 = f32[4]{0} log-plus-one(exponential.138) + add.140 = f32[4]{0} add(maximum.125, log-plus-one.139) + select.141 = f32[4]{0} select(and.134, add.135, 
add.140) + slice.356 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]} + reshape.1064 = f32[4]{0} reshape(slice.356) + add.214 = f32[4]{0} add(select.141, reshape.1064) + slice.380 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]} + reshape.1066 = f32[4]{0} reshape(slice.380) + add.179 = f32[4]{0} add(select.141, reshape.1066) + subtract.216 = f32[4]{0} subtract(add.214, add.179) + is-finite.217 = pred[4]{0} is-finite(subtract.216) + not.218 = pred[4]{0} not(is-finite.217) + abs.219 = f32[4]{0} abs(subtract.216) + constant.220 = f32[] constant(inf) + broadcast.221 = f32[4]{0} broadcast(constant.220), dimensions={} + compare.222 = pred[4]{0} compare(abs.219, broadcast.221), direction=EQ, type=UNSIGNED + not.223 = pred[4]{0} not(compare.222) + and.224 = pred[4]{0} and(not.218, not.223) + add.225 = f32[4]{0} add(add.214, add.179) + maximum.215 = f32[4]{0} maximum(add.214, add.179) + abs.226 = f32[4]{0} abs(subtract.216) + negate.227 = f32[4]{0} negate(abs.226) + exponential.228 = f32[4]{0} exponential(negate.227) + log-plus-one.229 = f32[4]{0} log-plus-one(exponential.228) + add.230 = f32[4]{0} add(maximum.215, log-plus-one.229) + select.231 = f32[4]{0} select(and.224, add.225, add.230) + slice.359 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]} + reshape.1068 = f32[4]{0} reshape(slice.359) + add.304 = f32[4]{0} add(select.231, reshape.1068) + slice.381 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]} + reshape.1070 = f32[4]{0} reshape(slice.381) + add.269 = f32[4]{0} add(select.231, reshape.1070) + subtract.306 = f32[4]{0} subtract(add.304, add.269) + is-finite.307 = pred[4]{0} is-finite(subtract.306) + not.308 = pred[4]{0} not(is-finite.307) + abs.309 = f32[4]{0} abs(subtract.306) + constant.310 = f32[] constant(inf) + broadcast.311 = f32[4]{0} broadcast(constant.310), dimensions={} + compare.312 = pred[4]{0} compare(abs.309, broadcast.311), direction=EQ, type=UNSIGNED + not.313 = pred[4]{0} not(compare.312) + and.314 = pred[4]{0} and(not.308, not.313) + add.315 = f32[4]{0} add(add.304, add.269) + maximum.305 = f32[4]{0} maximum(add.304, add.269) + abs.316 = f32[4]{0} abs(subtract.306) + negate.317 = f32[4]{0} negate(abs.316) + exponential.318 = f32[4]{0} exponential(negate.317) + log-plus-one.319 = f32[4]{0} log-plus-one(exponential.318) + add.320 = f32[4]{0} add(maximum.305, log-plus-one.319) + select.321 = f32[4]{0} select(and.314, add.315, add.320) + slice.362 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]} + reshape.1072 = f32[4]{0} reshape(slice.362) + add.394 = f32[4]{0} add(select.321, reshape.1072) + slice.382 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]} + reshape.1074 = f32[4]{0} reshape(slice.382) + add.359 = f32[4]{0} add(select.321, reshape.1074) + subtract.396 = f32[4]{0} subtract(add.394, add.359) + is-finite.397 = pred[4]{0} is-finite(subtract.396) + not.398 = pred[4]{0} not(is-finite.397) + abs.399 = f32[4]{0} abs(subtract.396) + constant.400 = f32[] constant(inf) + broadcast.401 = f32[4]{0} broadcast(constant.400), dimensions={} + compare.402 = pred[4]{0} compare(abs.399, broadcast.401), direction=EQ, type=UNSIGNED + not.403 = pred[4]{0} not(compare.402) + and.404 = pred[4]{0} and(not.398, not.403) + add.405 = f32[4]{0} add(add.394, add.359) + maximum.395 = f32[4]{0} maximum(add.394, add.359) + abs.406 = f32[4]{0} abs(subtract.396) + negate.407 = f32[4]{0} negate(abs.406) + exponential.408 = f32[4]{0} exponential(negate.407) + 
log-plus-one.409 = f32[4]{0} log-plus-one(exponential.408) + add.410 = f32[4]{0} add(maximum.395, log-plus-one.409) + select.411 = f32[4]{0} select(and.404, add.405, add.410) + slice.365 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]} + reshape.1076 = f32[4]{0} reshape(slice.365) + add.484 = f32[4]{0} add(select.411, reshape.1076) + slice.383 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]} + reshape.1078 = f32[4]{0} reshape(slice.383) + add.449 = f32[4]{0} add(select.411, reshape.1078) + subtract.486 = f32[4]{0} subtract(add.484, add.449) + is-finite.487 = pred[4]{0} is-finite(subtract.486) + not.488 = pred[4]{0} not(is-finite.487) + abs.489 = f32[4]{0} abs(subtract.486) + constant.490 = f32[] constant(inf) + broadcast.491 = f32[4]{0} broadcast(constant.490), dimensions={} + compare.492 = pred[4]{0} compare(abs.489, broadcast.491), direction=EQ, type=UNSIGNED + not.493 = pred[4]{0} not(compare.492) + and.494 = pred[4]{0} and(not.488, not.493) + add.495 = f32[4]{0} add(add.484, add.449) + maximum.485 = f32[4]{0} maximum(add.484, add.449) + abs.496 = f32[4]{0} abs(subtract.486) + negate.497 = f32[4]{0} negate(abs.496) + exponential.498 = f32[4]{0} exponential(negate.497) + log-plus-one.499 = f32[4]{0} log-plus-one(exponential.498) + add.500 = f32[4]{0} add(maximum.485, log-plus-one.499) + select.501 = f32[4]{0} select(and.494, add.495, add.500) + slice.368 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]} + reshape.1080 = f32[4]{0} reshape(slice.368) + add.574 = f32[4]{0} add(select.501, reshape.1080) + slice.384 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]} + reshape.1082 = f32[4]{0} reshape(slice.384) + add.539 = f32[4]{0} add(select.501, reshape.1082) + subtract.576 = f32[4]{0} subtract(add.574, add.539) + is-finite.577 = pred[4]{0} is-finite(subtract.576) + not.578 = pred[4]{0} not(is-finite.577) + abs.579 = f32[4]{0} abs(subtract.576) + constant.580 = f32[] constant(inf) + broadcast.581 = f32[4]{0} broadcast(constant.580), dimensions={} + compare.582 = pred[4]{0} compare(abs.579, broadcast.581), direction=EQ, type=UNSIGNED + not.583 = pred[4]{0} not(compare.582) + and.584 = pred[4]{0} and(not.578, not.583) + add.585 = f32[4]{0} add(add.574, add.539) + maximum.575 = f32[4]{0} maximum(add.574, add.539) + abs.586 = f32[4]{0} abs(subtract.576) + negate.587 = f32[4]{0} negate(abs.586) + exponential.588 = f32[4]{0} exponential(negate.587) + log-plus-one.589 = f32[4]{0} log-plus-one(exponential.588) + add.590 = f32[4]{0} add(maximum.575, log-plus-one.589) + select.591 = f32[4]{0} select(and.584, add.585, add.590) + slice.371 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]} + reshape.1084 = f32[4]{0} reshape(slice.371) + add.664 = f32[4]{0} add(select.591, reshape.1084) + slice.385 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]} + reshape.1086 = f32[4]{0} reshape(slice.385) + add.629 = f32[4]{0} add(select.591, reshape.1086) + subtract.666 = f32[4]{0} subtract(add.664, add.629) + is-finite.667 = pred[4]{0} is-finite(subtract.666) + not.668 = pred[4]{0} not(is-finite.667) + abs.669 = f32[4]{0} abs(subtract.666) + constant.670 = f32[] constant(inf) + broadcast.671 = f32[4]{0} broadcast(constant.670), dimensions={} + compare.672 = pred[4]{0} compare(abs.669, broadcast.671), direction=EQ, type=UNSIGNED + not.673 = pred[4]{0} not(compare.672) + and.674 = pred[4]{0} and(not.668, not.673) + add.675 = f32[4]{0} add(add.664, add.629) + maximum.665 = f32[4]{0} 
maximum(add.664, add.629) + abs.676 = f32[4]{0} abs(subtract.666) + negate.677 = f32[4]{0} negate(abs.676) + exponential.678 = f32[4]{0} exponential(negate.677) + log-plus-one.679 = f32[4]{0} log-plus-one(exponential.678) + add.680 = f32[4]{0} add(maximum.665, log-plus-one.679) + select.681 = f32[4]{0} select(and.674, add.675, add.680) + slice.374 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]} + reshape.1088 = f32[4]{0} reshape(slice.374) + add.754 = f32[4]{0} add(select.681, reshape.1088) + slice.386 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]} + reshape.1090 = f32[4]{0} reshape(slice.386) + add.719 = f32[4]{0} add(select.681, reshape.1090) + subtract.756 = f32[4]{0} subtract(add.754, add.719) + is-finite.757 = pred[4]{0} is-finite(subtract.756) + not.758 = pred[4]{0} not(is-finite.757) + abs.759 = f32[4]{0} abs(subtract.756) + constant.760 = f32[] constant(inf) + broadcast.761 = f32[4]{0} broadcast(constant.760), dimensions={} + compare.762 = pred[4]{0} compare(abs.759, broadcast.761), direction=EQ, type=UNSIGNED + not.763 = pred[4]{0} not(compare.762) + and.764 = pred[4]{0} and(not.758, not.763) + add.765 = f32[4]{0} add(add.754, add.719) + maximum.755 = f32[4]{0} maximum(add.754, add.719) + abs.766 = f32[4]{0} abs(subtract.756) + negate.767 = f32[4]{0} negate(abs.766) + exponential.768 = f32[4]{0} exponential(negate.767) + log-plus-one.769 = f32[4]{0} log-plus-one(exponential.768) + add.770 = f32[4]{0} add(maximum.755, log-plus-one.769) + select.771 = f32[4]{0} select(and.764, add.765, add.770) + slice.377 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [1:2]} + reshape.1092 = f32[4]{0} reshape(slice.377) + add.844 = f32[4]{0} add(select.771, reshape.1092) + slice.387 = f32[4,1,1]{2,1,0} slice(concatenate.42), slice={[0:4], [0:1], [0:1]} + reshape.1094 = f32[4]{0} reshape(slice.387) + add.809 = f32[4]{0} add(select.771, reshape.1094) + subtract.846 = f32[4]{0} subtract(add.844, add.809) + is-finite.847 = pred[4]{0} is-finite(subtract.846) + not.848 = pred[4]{0} not(is-finite.847) + abs.849 = f32[4]{0} abs(subtract.846) + constant.850 = f32[] constant(inf) + broadcast.851 = f32[4]{0} broadcast(constant.850), dimensions={} + compare.852 = pred[4]{0} compare(abs.849, broadcast.851), direction=EQ, type=UNSIGNED + not.853 = pred[4]{0} not(compare.852) + and.854 = pred[4]{0} and(not.848, not.853) + add.855 = f32[4]{0} add(add.844, add.809) + maximum.845 = f32[4]{0} maximum(add.844, add.809) + abs.856 = f32[4]{0} abs(subtract.846) + negate.857 = f32[4]{0} negate(abs.856) + exponential.858 = f32[4]{0} exponential(negate.857) + log-plus-one.859 = f32[4]{0} log-plus-one(exponential.858) + add.860 = f32[4]{0} add(maximum.845, log-plus-one.859) + select.861 = f32[4]{0} select(and.854, add.855, add.860) + constant.865 = f32[] constant(0) + reduce.2 = f32[] reduce(select.861, constant.865), dimensions={0}, to_apply=primitive_computation_add.866 + reduce.3 = f32[] reduce(select.861, constant.865), dimensions={0}, to_apply=primitive_computation_add.866 + add.77 = f32[] add(reduce.2, reduce.3) + constant.719 = f32[] constant(0.125) + multiply = f32[] multiply(add.77, constant.719) + ROOT tuple.873 = (f32[]) tuple(multiply) +})") + .ConsumeValueOrDie(); + auto input_array = absl::make_unique>(4, 2); + input_array->FillUnique(1.0f); + auto input = LiteralUtil::CreateR2FromArray2D(*input_array); + EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, absl::nullopt)); +} + // Describes a binary rank-2 concatenation test. 
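The generated HLO above repeats one pattern per pair of slices: a numerically stable log-add-exp. When the difference of the two operands is NaN, the plain sum is selected (propagating the NaN, and yielding +inf for the inf/inf case); otherwise the result is max(a, b) + log1p(exp(-|a - b|)), which avoids overflow for large magnitudes. A minimal standalone C++ sketch of that pattern follows; the scalar LogAddExp helper and its signature are illustrative only and are not part of this test or of XLA's API.

#include <algorithm>
#include <cmath>

// Stable log(exp(a) + exp(b)): mirrors the is-finite / compare-to-inf /
// log-plus-one / select sequence emitted for each slice pair in the HLO above.
float LogAddExp(float a, float b) {
  const float diff = a - b;
  if (std::isnan(diff)) {
    return a + b;  // NaN (or inf - inf) case: fall back to the plain sum.
  }
  return std::max(a, b) + std::log1p(std::exp(-std::fabs(diff)));
}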
struct R2BinarySpec { int64 lhs_dim0; @@ -578,7 +825,7 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { {x_data.get(), y_data.get()}, ErrorSpec(1e-4)); } -// Test that the HLO optimization to replace a concat of a bradcasted scalar +// Test that the HLO optimization to replace a concat of a broadcasted scalar // produces the correct result in rank 1. XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); @@ -604,7 +851,7 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { {x_data.get(), y_data.get(), z_data.get()}, ErrorSpec(1e-4)); } -// Test that the HLO optimization to replace a concat of a bradcasted scalar +// Test that the HLO optimization to replace a concat of a broadcasted scalar // produces the correct result in rank 3 with both high and low padding in // different dimensions. XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index bdbf7044fe2..aa02deb7bca 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -266,7 +266,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", + "//tensorflow/core/platform:regexp", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf_headers", ], diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 4034e5fdd27..6e7deda13f0 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -374,7 +374,9 @@ std::pair SplitF64ToF32(double x) { // Only values within the range of F32 are supported, unless it is infinity. // Small values with large negative exponents would be rounded to zero. - CHECK(std::isfinite(x_f32)) << x; + if (!std::isfinite(x_f32)) { + LOG(WARNING) << "Out of range F64 constant detected: " << x; + } // The high float is simply the double rounded to the nearest float. Because // we are rounding to nearest with ties to even, the error introduced in diff --git a/tensorflow/compiler/xla/util_test.cc b/tensorflow/compiler/xla/util_test.cc index 69acc59d8a2..5477dfba18d 100644 --- a/tensorflow/compiler/xla/util_test.cc +++ b/tensorflow/compiler/xla/util_test.cc @@ -15,10 +15,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" +#include #include #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/bfloat16.h" namespace xla { namespace { @@ -103,5 +105,26 @@ TEST(UtilTest, SanitizeFileName) { EXPECT_EQ(SanitizeFileName("/A\\B[C]"), "_A_B_C_"); } +TEST(UtilTest, RoundTripFpToString) { + EXPECT_EQ(RoundTripFpToString(std::numeric_limits::quiet_NaN()), + "nan"); + EXPECT_EQ(RoundTripFpToString(-std::numeric_limits::quiet_NaN()), + "-nan"); + EXPECT_EQ(RoundTripFpToString( + std::numeric_limits::quiet_NaN()), + "nan"); + EXPECT_EQ(RoundTripFpToString( + -std::numeric_limits::quiet_NaN()), + "-nan"); + EXPECT_EQ(RoundTripFpToString(std::numeric_limits::quiet_NaN()), + "nan"); + EXPECT_EQ(RoundTripFpToString(-std::numeric_limits::quiet_NaN()), + "-nan"); + EXPECT_EQ(RoundTripFpToString(std::numeric_limits::quiet_NaN()), + "nan"); + EXPECT_EQ(RoundTripFpToString(-std::numeric_limits::quiet_NaN()), + "-nan"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 2d311dd2f70..7da8d2cb84d 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -245,6 +245,13 @@ message ComputationStats { double transcendental_count = 2; } +// The type optimization profiles in use. +enum ProfileType { + INVALID = 0; + WINDOW = 1; + FLAG = 2; +} + // Symbolization metadata for HLO Instructions. // // This metadata is used for debugging XLA code generation, as well as @@ -268,6 +275,8 @@ message OpMetadata { // e.g. it could be the file and line of user code that generated the op. string source_file = 3; int32 source_line = 4; + + repeated ProfileType profile_type = 5; } // Profile data from the execution of a computation. diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD index c3d3689b587..1b699e7d8df 100644 --- a/tensorflow/compiler/xrt/BUILD +++ b/tensorflow/compiler/xrt/BUILD @@ -74,7 +74,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:regexp_internal", + "//tensorflow/core/platform:regexp", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor", "//tensorflow/stream_executor:device_memory_allocator", diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 4c2a6e9ca05..f5bbff7827d 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -47,11 +47,9 @@ # Public mobile targets, e.g. for Android: # # filegroup ":android_proto_srcs" - Protos -# filegroup ":android_srcs" - Core sources # cc_library ":portable_tensorflow_lib" - Native library # cc_library ":portable_tensorflow_lib_lite" - Native library, without ops, # supporting SELECTIVE_REGISTRATION feature. 
-# portable_proto_library ":portable_proto_lib" (Google-internal) # # Note that :framework and :lib have incomplete transitive dependencies (they # declare but do not define some symbols) if framework_shared_object=True @@ -69,10 +67,9 @@ load( "if_chromiumos", "if_cuda_or_rocm", "if_ios", + "if_libtpu", "if_mobile", "if_not_windows", - "if_tpu", - "tf_android_core_proto_headers", "tf_cc_test", "tf_cc_test_mkl", "tf_cc_tests", @@ -80,26 +77,18 @@ load( "tf_cuda_library", "tf_defines_nortti_if_lite_protos", "tf_features_nomodules_if_mobile", - "tf_genrule_cmd_append_to_srcs", "tf_opts_nortti_if_lite_protos", - "tf_portable_full_lite_protos", "transitive_hdrs", ) # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "cc_header_only_library") -# buildifier: disable=same-origin-load -load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") - # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "if_nccl") # buildifier: disable=same-origin-load -load("//tensorflow:tensorflow.bzl", "tensorflow_opensource_extra_deps") - -# buildifier: disable=same-origin-load -load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") +load("@tf_toolchains//macros:cpp.bzl", "tensorflow_opensource_extra_deps") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu") @@ -110,9 +99,6 @@ load("//tensorflow:tensorflow.bzl", "tf_monitoring_framework_deps") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "filegroup") -# buildifier: disable=same-origin-load -load("//tensorflow:tensorflow.bzl", "genrule") - # For platform specific build config load( "//tensorflow/core/platform:build_config.bzl", @@ -146,7 +132,6 @@ load( "if_mkl", "mkl_deps", ) -# Placeholder for Google-internal load statements. package( default_visibility = [ @@ -170,14 +155,8 @@ package_group( # Export the BUILD file so automated tooling can check licenses exports_files([ "BUILD", - "ops/ops.pbtxt", ]) -package_group( - name = "experimental_access", - packages = ["//tensorflow/core/common_runtime/..."], -) - # Authorized users go here. package_group(name = "friends") @@ -188,7 +167,6 @@ package_group(name = "friends") # # Note that some protos are in neither additional_core_proto_srcs nor this # filegroup; e.g. ones with individual proto_library targets. 
-# LINT.IfChange COMMON_PROTO_SRCS = [ "//tensorflow/core/protobuf:bfc_memory_map.proto", "//tensorflow/core/protobuf:config.proto", @@ -251,7 +229,6 @@ ERROR_CODES_PROTO_SRCS = [ "//tensorflow/core/protobuf:error_codes.proto", "//tensorflow/core/lib/core:error_codes.proto", ] -# LINT.ThenChange(//tensorflow/core/portable_proto_config.asciipb) CORE_PROTO_SRCS = COMMON_PROTO_SRCS + EXAMPLE_PROTO_SRCS + FRAMEWORK_PROTO_SRCS + UTIL_PROTO_SRCS + PROFILER_PROTO_SRCS + ERROR_CODES_PROTO_SRCS @@ -259,6 +236,7 @@ tf_proto_library( name = "protos_all", srcs = [], cc_api_version = 2, + create_go_proto = False, make_default_target_header_only = True, protodeps = [ "//tensorflow/core/example:protos_all", @@ -297,7 +275,6 @@ cc_library( hdrs = ["//tensorflow/core/platform:base_hdrs"], copts = tf_copts(), tags = ["avoid_dep"], - visibility = [":__subpackages__"], deps = [ "//tensorflow/core/platform", "//tensorflow/core/platform:byte_order", @@ -319,11 +296,6 @@ alias( visibility = ["//tensorflow/core/kernels:friends"], ) -alias( - name = "human_readable_json", - actual = "//tensorflow/core/platform:human_readable_json", -) - # Minimal lib so that tools used for mobile compilation # don't have to depend on lib/platformlib. cc_library( @@ -385,22 +357,6 @@ cc_library( ], ) -# APIs defined in lib_experimental are for experimental usage and may be -# subject to change. Its visibility is limited to selected packages. -cc_library( - name = "lib_experimental", - hdrs = [ - "//tensorflow/core/lib/core:legacy_lib_core_threadpool_options_header", - ], - visibility = [ - ":experimental_access", - "//tensorflow/cc:__pkg__", - ], - deps = [ - ":lib", - ], -) - alias( name = "feature_util", actual = "//tensorflow/core/example:feature_util", @@ -469,7 +425,9 @@ tf_cuda_library( "//tensorflow/core/framework:control_flow.h", # TODO(josh11b): Make internal? "//tensorflow/core/framework:dataset.h", "//tensorflow/core/framework:dataset_stateful_op_allowlist.h", + "//tensorflow/core/framework:device.h", "//tensorflow/core/framework:device_base.h", + "//tensorflow/core/framework:device_factory.h", "//tensorflow/core/framework:function.h", "//tensorflow/core/framework:function_handle_cache.h", "//tensorflow/core/framework:graph_def_util.h", @@ -534,30 +492,6 @@ tf_cuda_library( ], ) -# TODO(gonnet): Remove this alias once all users have been moved to the actual target. -alias( - name = "allocator", - actual = "//tensorflow/core/framework:allocator", - visibility = ["//visibility:public"], -) - -# TODO(gonnet): Remove this alias once all users have been moved to the actual target. -alias( - name = "allocator_registry_impl", - actual = "//tensorflow/core/framework:allocator_registry_impl", - visibility = ["//visibility:public"], -) - -alias( - name = "overflow", - actual = "//tensorflow/core/util:overflow", -) - -alias( - name = "exec_on_stall", - actual = "//tensorflow/core/util:exec_on_stall", -) - alias( name = "ptr_util", actual = "//tensorflow/core/util:ptr_util", @@ -612,157 +546,7 @@ cc_library( ], ) -# Generates library per group of ops. 
-tf_gen_op_libs( - is_external = False, - op_lib_names = [ - "batch_ops", - "bitwise_ops", - "boosted_trees_ops", - "tensor_forest_ops", - "candidate_sampling_ops", - "checkpoint_ops", - "clustering_ops", - "collective_ops", - "control_flow_ops", - "count_ops", - "ctc_ops", - "data_flow_ops", - "dataset_ops", - "decode_proto_ops", - "encode_proto_ops", - "experimental_dataset_ops", - "function_ops", - "functional_ops", - "image_ops", - "io_ops", - "linalg_ops", - "list_ops", - "map_ops", - "lookup_ops", - "manip_ops", - "math_ops", - "mkl_nn_ops", - "nccl_ops", - "nn_ops", - "no_op", - "parsing_ops", - "random_grad", - "random_ops", - "special_math_ops", - "stateful_random_ops", - "remote_fused_graph_ops", - "rnn_ops", - "rpc_ops", - "scoped_allocator_ops", - "sdca_ops", - "set_ops", - "script_ops", - "sendrecv_ops", - "sparse_csr_matrix_ops", - "sparse_ops", - "spectral_ops", - "state_ops", - "stateless_random_ops", - "stateless_random_ops_v2", - "summary_ops", - "training_ops", - ], - deps = [ - ":lib", - ":protos_all_cc", - ], -) - -tf_gen_op_libs( - is_external = False, - op_lib_names = [ - "logging_ops", - ], - deps = [ - ":lib", - ":protos_all_cc", - # TODO(b/162630222): remove this dependency. - "//tensorflow/c/kernels:histogram_summary_op_lib", - "//tensorflow/c/kernels:merge_summary_op_lib", - "//tensorflow/c/kernels:summary_op_lib", - ], -) - -tf_gen_op_libs( - op_lib_names = [ - "string_ops", - ], - deps = [ - ":lib_internal", - ":lib_proto_parsing", - "@com_google_absl//absl/strings", - ], -) - -tf_gen_op_libs( - op_lib_names = [ - "array_ops", - ], - deps = [ - ":lib", - ":protos_all_cc", - ], -) - -tf_gen_op_libs( - op_lib_names = [ - "mkl_array_ops", - ], - deps = [":protos_all_cc"], -) - -tf_gen_op_libs( - op_lib_names = [ - "audio_ops", - ], - deps = [":lib"], -) - -tf_gen_op_libs( - op_lib_names = ["debug_ops"], - deps = [":lib"], -) - -tf_gen_op_libs( - is_external = False, - op_lib_names = [ - "resource_variable_ops", - ], - deps = [":lib"], -) - -tf_gen_op_libs( - op_lib_names = [ - "tpu_configuration_ops", - "tpu_cross_replica_ops", - "tpu_embedding_ops", - "tpu_embedding_load_retrieve_ops", - "tpu_functional_ops", - "tpu_heartbeat_ops", - "tpu_host_compute_ops", - "tpu_infeed_ops", - "tpu_outfeed_ops", - "tpu_ordinal_selector_ops", - "tpu_replication_ops", - ], - deps = [ - ":lib", - ":lib_proto_parsing", - ":protos_all_cc", - "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_cc", - "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc", - "//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils", - "//tensorflow/core/tpu:tpu_embedding_output_layout_utils", - ], -) - -# And one for all user ops +# One target for all user ops cc_library( name = "user_ops_op_lib", srcs = glob(["user_ops/**/*.cc"]), @@ -773,215 +557,29 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "word2vec_ops", - srcs = ["ops/word2vec_ops.cc"], - linkstatic = 1, - visibility = ["//tensorflow:internal"], - deps = [":framework"], - alwayslink = 1, -) - -cc_library( - name = "cudnn_rnn_ops", - srcs = [ - "ops/cudnn_rnn_ops.cc", - ], - linkstatic = 1, - visibility = ["//tensorflow:internal"], - deps = [ - ":framework", - ":lib", - ":lib_internal", - ":stream_executor", - "//tensorflow/core/kernels:bounds_check_lib", - ], - alwayslink = 1, -) - -tf_gen_op_libs( - op_lib_names = [ - "cudnn_rnn_ops", - ], - deps = [ - ":lib", - ], -) - -cc_library( - name = "ragged_ops", - deps = [ - ":ragged_array_ops_op_lib", - ":ragged_conversion_ops_op_lib", - 
":ragged_math_ops_op_lib", - ], -) - -tf_gen_op_libs( - op_lib_names = [ - "ragged_array_ops", - "ragged_conversion_ops", - "ragged_math_ops", - ], - deps = ["//tensorflow/core/util:ragged_to_dense_util"], -) - cc_library( name = "ops", visibility = ["//visibility:public"], deps = [ - ":array_ops_op_lib", - ":audio_ops_op_lib", - ":batch_ops_op_lib", - ":bitwise_ops_op_lib", - ":boosted_trees_ops_op_lib", - ":tensor_forest_ops_op_lib", - ":candidate_sampling_ops_op_lib", - ":checkpoint_ops_op_lib", - ":clustering_ops_op_lib", - ":collective_ops_op_lib", - ":control_flow_ops_op_lib", - ":count_ops_op_lib", - ":ctc_ops_op_lib", - ":cudnn_rnn_ops_op_lib", - ":data_flow_ops_op_lib", - ":dataset_ops_op_lib", - ":debug_ops_op_lib", - ":decode_proto_ops_op_lib", - ":encode_proto_ops_op_lib", - ":experimental_dataset_ops_op_lib", - ":function_ops_op_lib", - ":functional_ops_op_lib", - ":image_ops_op_lib", - ":io_ops_op_lib", - ":linalg_ops_op_lib", - ":list_ops_op_lib", - ":map_ops_op_lib", - ":logging_ops_op_lib", - ":lookup_ops_op_lib", - ":manip_ops_op_lib", - ":math_ops_op_lib", - ":nccl_ops_op_lib", - ":nn_ops_op_lib", - ":no_op_op_lib", - ":parsing_ops_op_lib", - ":ragged_ops", - ":random_ops_op_lib", - ":rnn_ops_op_lib", - ":special_math_ops_op_lib", - ":stateful_random_ops_op_lib", - ":remote_fused_graph_ops_op_lib", - ":resource_variable_ops_op_lib", - ":rpc_ops_op_lib", - ":scoped_allocator_ops_op_lib", - ":script_ops_op_lib", - ":sdca_ops_op_lib", - ":sendrecv_ops_op_lib", - ":set_ops_op_lib", - ":sparse_csr_matrix_ops_op_lib", - ":sparse_ops_op_lib", - ":summary_ops_op_lib", - ":spectral_ops_op_lib", - ":state_ops_op_lib", - ":stateless_random_ops_op_lib", - ":stateless_random_ops_v2_op_lib", - ":string_ops_op_lib", - ":training_ops_op_lib", ":user_ops_op_lib", - ":word2vec_ops", "//tensorflow/c/kernels:bitcast_op_lib", "//tensorflow/c/kernels:histogram_summary_op_lib", "//tensorflow/c/kernels:merge_summary_op_lib", "//tensorflow/c/kernels:summary_op_lib", + "//tensorflow/core/ops:ops", ] + if_chromiumos( [], - # Non-tpu platforms don't need tpu dependency. It would be best to guard - # them by if_tpu. But there is no such flag yet. + # Non-tpu platforms don't need tpu dependency. 
[ - ":tpu_configuration_ops_op_lib", - ":tpu_cross_replica_ops_op_lib", - ":tpu_embedding_ops_op_lib", - ":tpu_embedding_load_retrieve_ops_op_lib", - ":tpu_functional_ops_op_lib", - ":tpu_heartbeat_ops_op_lib", - ":tpu_host_compute_ops_op_lib", - ":tpu_infeed_ops_op_lib", - ":tpu_outfeed_ops_op_lib", - ":tpu_ordinal_selector_ops_op_lib", - ":tpu_replication_ops_op_lib", "//tensorflow/core/tpu/ops", ], - ) + if_mkl([ - ":mkl_array_ops_op_lib", - ":mkl_nn_ops_op_lib", - ]) + if_tensorrt([ + ) + if_tensorrt([ "//tensorflow/compiler/tf2tensorrt:trt_engine_resource_ops_op_lib", "//tensorflow/compiler/tf2tensorrt:trt_op_libs", - ]) + if_tpu( + ]) + if_libtpu( if_false = ["//tensorflow/compiler/mlir/tensorflow:mlir_passthrough_op"], if_true = [], ), - alwayslink = 1, -) - -cc_library( - name = "array_grad", - srcs = ["ops/array_grad.cc"], - linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 - visibility = ["//visibility:public"], - deps = [ - ":array_ops_op_lib", - ":framework", - ":lib", - "//tensorflow/c/kernels:bitcast_op_lib", - ], - alwayslink = 1, -) - -cc_library( - name = "functional_grad", - srcs = ["ops/functional_grad.cc"], - linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 - visibility = ["//visibility:public"], - deps = [ - ":framework", - ":functional_ops_op_lib", - ":lib", - ], - alwayslink = 1, -) - -cc_library( - name = "math_grad", - srcs = [ - "ops/math_grad.cc", - "ops/random_grad.cc", - "ops/stateless_random_grad.cc", - ], - linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 - visibility = ["//visibility:public"], - deps = [ - ":framework", - ":lib", - ":math_ops_op_lib", - ":protos_all_cc", - ], - alwayslink = 1, -) - -cc_library( - name = "nn_grad", - srcs = ["ops/nn_grad.cc"], - linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 - visibility = ["//visibility:public"], - deps = [ - ":framework", - ":lib", - ":nn_ops_op_lib", - ] + if_mkl([ - ":mkl_nn_ops_op_lib", - ]), - alwayslink = 1, ) alias( @@ -1213,10 +811,10 @@ tf_cuda_library( ":ops", ":protos_all_cc", ":test", - ":testlib_ops", # TODO(gunan): resolve dependency issues and load these kernels dynamically. ":testlib_kernels_impl", "//tensorflow/cc:scope", + "//tensorflow/core/common_runtime:testlib_ops", "//tensorflow/core/framework:fake_input", "//tensorflow/core/framework:function_testlib", "//tensorflow/core/framework:shape_inference_testutil", @@ -1226,13 +824,6 @@ tf_cuda_library( ], ) -alias( - name = "testlib_ops", - testonly = 1, - actual = - "//tensorflow/core/common_runtime:testlib_ops", -) - # This is a link-only library to provide a DirectSession # implementation of the Session interface. 
tf_cuda_library( @@ -1246,13 +837,6 @@ tf_cuda_library( alwayslink = 1, ) -# ----------------------------------------------------------------------------- -# MKL targets -alias( - name = "mkl_graph_util", - actual = "//tensorflow/core/graph:mkl_graph_util", -) - # ----------------------------------------------------------------------------- # Public Android targets @@ -1292,7 +876,7 @@ filegroup( "**/*main.cc", ], ), - visibility = ["//visibility:private"], + visibility = ["//visibility:public"], ) # Sources required to build the TensorFlow framework with runtime on @@ -1352,11 +936,98 @@ filegroup( visibility = ["//visibility:public"], ) -alias( - name = "android_srcs", - actual = ":mobile_srcs", - visibility = ["//visibility:public"], -) +# All the aliases for stuff under ops/ +# Once the dependencies move to the real targets, remove the aliases here! + +[ + alias( + name = "%s" % (name,), + actual = "//tensorflow/core/ops:%s" % (name,), + visibility = ["//visibility:public"], + ) + for name in [ + "array_grad", + "array_ops_op_lib", + "audio_ops_op_lib", + "batch_ops_op_lib", + "bitwise_ops_op_lib", + "boosted_trees_ops_op_lib", + "candidate_sampling_ops_op_lib", + "checkpoint_ops_op_lib", + "clustering_ops_op_lib", + "collective_ops_op_lib", + "control_flow_ops_op_lib", + "count_ops_op_lib", + "ctc_ops_op_lib", + "cudnn_rnn_ops_op_lib", + "data_flow_ops_op_lib", + "dataset_ops_op_lib", + "debug_ops_op_lib", + "decode_proto_ops_op_lib", + "encode_proto_ops_op_lib", + "experimental_dataset_ops_op_lib", + "function_ops_op_lib", + "functional_grad", + "functional_ops_op_lib", + "image_ops_op_lib", + "io_ops_op_lib", + "linalg_ops_op_lib", + "list_ops_op_lib", + "logging_ops_op_lib", + "lookup_ops_op_lib", + "manip_ops_op_lib", + "map_ops_op_lib", + "math_grad", + "math_ops_op_lib", + "mkl_array_ops_op_lib", + "mkl_nn_ops_op_lib", + "nccl_ops_op_lib", + "nn_grad", + "nn_ops_op_lib", + "no_op_op_lib", + "parsing_ops_op_lib", + "portable_op_registrations_and_gradients", + "ragged_array_ops_op_lib", + "ragged_conversion_ops_op_lib", + "ragged_math_ops_op_lib", + "ragged_ops", + "random_grad_op_lib", + "random_ops_op_lib", + "remote_fused_graph_ops_op_lib", + "resource_variable_ops_op_lib", + "rnn_ops_op_lib", + "rpc_ops_op_lib", + "scoped_allocator_ops_op_lib", + "script_ops_op_lib", + "sdca_ops_op_lib", + "sendrecv_ops_op_lib", + "set_ops_op_lib", + "sparse_csr_matrix_ops_op_lib", + "sparse_ops_op_lib", + "special_math_ops_op_lib", + "spectral_ops_op_lib", + "state_ops_op_lib", + "stateful_random_ops_op_lib", + "stateless_random_ops_op_lib", + "stateless_random_ops_v2_op_lib", + "string_ops_op_lib", + "summary_ops_op_lib", + "tensor_forest_ops_op_lib", + "tpu_configuration_ops_op_lib", + "tpu_cross_replica_ops_op_lib", + "tpu_embedding_ops_op_lib", + "tpu_embedding_load_retrieve_ops_op_lib", + "tpu_functional_ops_op_lib", + "tpu_heartbeat_ops_op_lib", + "tpu_host_compute_ops_op_lib", + "tpu_infeed_ops_op_lib", + "tpu_outfeed_ops_op_lib", + "tpu_ordinal_selector_ops_op_lib", + "tpu_replication_ops_op_lib", + "training_ops_op_lib", + "word2vec_ops", + ] +] # Native library support for mobile applications. Does not contain # operators, use :portable_tensorflow_lib if you want full operator @@ -1371,7 +1042,7 @@ alias( # Compiles to a trivial library on non-mobile to prevent irrelevant # build errors. If not building this e.g. 
as part of an android_binary, # a command such as the following must be used: -# bazel build -c opt tensorflow/core:android_tensorflow_lib \ +# bazel build -c opt tensorflow/core:portable_tensorflow_lib \ # --define=TENSORFLOW_PROTOS=lite \ # --crosstool_top=//external:android/crosstool \ # --cpu=armeabi-v7a \ @@ -1394,24 +1065,6 @@ cc_library( alwayslink = 1, ) -alias( - name = "android_tensorflow_lib_lite", - actual = ":portable_tensorflow_lib_lite", - visibility = ["//visibility:public"], -) - -alias( - name = "android_tensorflow_lib_lite_nortti", - actual = ":portable_tensorflow_lib_lite", - visibility = ["//visibility:public"], -) - -alias( - name = "android_tensorflow_lib_lite_nortti_lite_protos", - actual = ":portable_tensorflow_lib_lite", - visibility = ["//visibility:public"], -) - cc_library( name = "mobile_additional_lib_deps", deps = tf_additional_lib_deps() + [ @@ -1422,26 +1075,6 @@ cc_library( ], ) -alias( - name = "ios_tensorflow_lib_lite", - actual = ":portable_tensorflow_lib_lite", - visibility = ["//visibility:public"], -) - -# Full TensorFlow library with operator support. Use this unless reducing -# binary size (by packaging a reduced operator set) is a concern. -alias( - name = "android_tensorflow_lib", - actual = ":portable_tensorflow_lib", - visibility = ["//visibility:public"], -) - -alias( - name = "ios_tensorflow_lib", - actual = ":portable_tensorflow_lib", - visibility = ["//visibility:public"], -) - cc_library( name = "portable_tensorflow_lib", srcs = if_mobile([":portable_op_registrations_and_gradients"]), @@ -1462,52 +1095,6 @@ cc_library( alwayslink = 1, ) -alias( - name = "android_op_registrations_and_gradients", - actual = ":portable_op_registrations_and_gradients", - visibility = ["//visibility:public"], -) - -filegroup( - name = "portable_op_registrations_and_gradients", - srcs = ["//tensorflow/c/kernels:android_all_ops"] + glob( - [ - "ops/**/*.cc", - "ops/**/*.h", - ], - exclude = [ - "**/*test.cc", - "**/*testutil*", - "**/*testlib*", - "**/*main.cc", - "**/tpu_*", - ], - ), - visibility = ["//visibility:public"], -) - -filegroup( - name = "android_test_srcs", - testonly = 1, - # TODO(andrewharp/nhua): - # make more test-related sources portable e.g. "//tensorflow/core/platform:test.cc", - srcs = tf_portable_full_lite_protos( - full = [ - "//tensorflow/core/framework:android_test_hdrs", - "//tensorflow/core/framework:android_test_srcs", - "//tensorflow/core/platform:android_test_srcs", - "//tensorflow/core/util:android_test_srcs", - ], - lite = [ - "//tensorflow/core/framework:android_test_hdrs", - "//tensorflow/core/framework:android_test_srcs_no_core", - "//tensorflow/core/platform:android_test_srcs", - "//tensorflow/core/util:android_test_srcs", - ], - ), - visibility = ["//visibility:public"], -) - # This is like android_test_srcs, minus the things that are already in mobile_srcs. filegroup( name = "android_test_srcs_no_core", @@ -1522,18 +1109,6 @@ filegroup( ) # Portable library providing testing functionality for TensorFlow. 
-alias( - name = "android_tensorflow_test_lib", - actual = ":portable_tensorflow_test_lib", - visibility = ["//visibility:public"], -) - -alias( - name = "ios_tensorflow_test_lib", - actual = ":portable_tensorflow_test_lib", - visibility = ["//visibility:public"], -) - cc_library( name = "portable_tensorflow_test_lib", testonly = 1, @@ -1553,7 +1128,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":portable_tensorflow_lib", - ":protos_all_cc", "//tensorflow/core/kernels:portable_tensorflow_kernels", "//tensorflow/core/platform/default/build_config:gtest", "//third_party/eigen3", @@ -1606,105 +1180,8 @@ alias( ] ] -# The following targets will be moved to core/protobuf. The aliases are only temporary -# since moving existing users will require several CLs over several projects. -[ - [ - alias( - name = "protobuf_%s_pyclif%s" % (proto_name, target_suffix), - actual = "//tensorflow/core/protobuf:%s_pyclif%s" % (proto_name, target_suffix), - visibility = ["//visibility:public"], - ) - for target_suffix in [ - "", - "_pb2", - ] - ] - for proto_name in [ - "config", - "device_properties", - "graph_debug_info", - "meta_graph", - "saved_model", - ] -] - # ----------------------------------------------------------------------------- # Internal targets - -alias( - name = "autotuning_proto", - actual = "//tensorflow/core/protobuf:autotuning_proto", - visibility = [ - "//tensorflow:internal", - ], -) - -alias( - name = "autotuning_proto_cc", - actual = "//tensorflow/core/protobuf:autotuning_proto_cc", - visibility = [ - "//tensorflow:internal", - ], -) - -alias( - name = "conv_autotuning_proto", - actual = "//tensorflow/core/protobuf:conv_autotuning_proto", - visibility = [ - "//tensorflow:internal", - ], -) - -alias( - name = "conv_autotuning_proto_cc", - actual = "//tensorflow/core/protobuf:conv_autotuning_proto_cc", - visibility = [ - "//tensorflow:internal", - ], -) - -alias( - name = "worker_proto_cc", - actual = "//tensorflow/core/protobuf:worker_proto_cc", - visibility = [ - "//tensorflow:internal", - ], -) - -alias( - name = "worker_service_proto_cc", - actual = "//tensorflow/core/protobuf:worker_service_proto_cc", - visibility = [ - "//tensorflow:internal", - ], -) - -alias( - name = "master_proto_cc", - actual = "//tensorflow/core/protobuf:master_proto_cc", - visibility = [ - "//learning/brain/frameworks/uptc:__subpackages__", - "//tensorflow:internal", - ], -) - -alias( - name = "master_service_proto_cc", - actual = "//tensorflow/core/protobuf:master_service_proto_cc", - visibility = [ - "//tensorflow:internal", - ], -) - -alias( - name = "eager_service_proto_cc", - actual = "//tensorflow/core/protobuf:eager_service_proto_cc", - visibility = [ - "//tensorflow:internal", - ], -) - filegroup( name = "lib_internal_private_headers", srcs = [ @@ -1918,6 +1395,7 @@ cc_library( "//tensorflow/core/platform:denormal", "//tensorflow/core/platform:dynamic_annotations", "//tensorflow/core/platform:env", + "//tensorflow/core/platform:env_impl", "//tensorflow/core/platform:error", "//tensorflow/core/platform:errors", "//tensorflow/core/platform:file_statistics", @@ -2022,21 +1500,6 @@ cc_library( ], ) -alias( - name = "png_internal", - actual = "//tensorflow/core/lib/png:png_io", -) - -alias( - name = "portable_png_internal", - actual = "//tensorflow/core/lib/png:png_io", -) - -alias( - name = "android_png_internal", - actual = "//tensorflow/core/lib/png:png_io", -) - cc_library( name = "tflite_portable_logging", hdrs = [ @@ -2113,16 +1576,6 @@ cc_library( ], ) -alias( - name = 
"android_jpeg_internal", - actual = ":portable_jpeg_internal", -) - -alias( - name = "android_gif_internal", - actual = ":portable_gif_internal", -) - alias( name = "error_codes_proto_impl", actual = "//tensorflow/core/protobuf:error_codes_proto_impl", @@ -2133,11 +1586,6 @@ alias( actual = "//tensorflow/core/protobuf:error_codes_proto_impl_cc", ) -alias( - name = "error_codes_proto_cc", - actual = "//tensorflow/core/lib/core:error_codes_proto_cc", -) - alias( name = "version_lib", actual = "//tensorflow/core/util:version_info", @@ -2149,6 +1597,7 @@ filegroup( "//tensorflow/core/example:feature_util.h", "//tensorflow/core/framework:framework_internal_private_hdrs", "//tensorflow/core/graph:framework_internal_private_headers", + "//tensorflow/core/public:session_options.h", "//tensorflow/core/util:framework_internal_private_hdrs", "//tensorflow/core/util:memmapped_file_system_hdrs", "//tensorflow/core/util/sparse:framework_internal_private_headers_group", @@ -2273,6 +1722,7 @@ tf_cuda_library( "//tensorflow/core/framework:tensor", "//tensorflow/core/framework:tensor_shape", "//tensorflow/core/kernels:bounds_check", + "//tensorflow/core/platform:env_impl", "//tensorflow/core/platform/default/build_config:platformlib", "//tensorflow/core/profiler/lib:annotated_traceme", "//tensorflow/core/profiler/lib:traceme", @@ -2314,15 +1764,10 @@ cc_header_only_library( ], visibility = ["//visibility:public"], deps = [ - ":stream_executor", + "//tensorflow/core/platform:stream_executor", ], ) -alias( - name = "stream_executor", - actual = "//tensorflow/core/platform:stream_executor", -) - # Like stream_executor library, but compiles without --config=cuda # and does not include any cuda dependencies. alias( @@ -2380,12 +1825,6 @@ tf_cuda_library( alwayslink = 1, ) -alias( - name = "core_cpu_impl", - actual = - "//tensorflow/core/common_runtime:core_cpu_impl", -) - alias( name = "core_cpu_lib", actual = @@ -2398,18 +1837,6 @@ alias( "//tensorflow/core/common_runtime:core_cpu_internal", ) -alias( - name = "regexp_internal", - actual = - "//tensorflow/core/platform:regexp", - visibility = [ - "//tensorflow/compiler:__subpackages__", - "//tensorflow/core/kernels:__subpackages__", - "//tensorflow/core/profiler:__subpackages__", - "//tensorflow/stream_executor:__subpackages__", - ], -) - alias( name = "direct_session_internal", actual = @@ -2422,14 +1849,6 @@ alias( visibility = ["//visibility:public"], ) -alias( - name = "replay_log_proto_cc", - actual = "//tensorflow/core/protobuf:replay_log_proto_cc", - visibility = [ - "//tensorflow:internal", - ], -) - alias( name = "gpu_runtime", actual = @@ -2459,12 +1878,6 @@ alias( actual = "//tensorflow/core/framework:tensor_testutil", ) -# TODO(gonnet): Remove this alias once all users have been moved to the actual target. 
-alias( - name = "shape_inference_testutil", - actual = "//tensorflow/core/framework:shape_inference_testutil", -) - # Main program for tests alias( name = "test_main", @@ -2604,18 +2017,6 @@ tf_cc_test( ], ) -tf_cc_test( - name = "framework_op_gen_lib_test", - size = "small", - srcs = ["//tensorflow/core/framework:op_gen_lib_test.cc"], - deps = [ - ":protos_all_cc", - ":test", - ":test_main", - "//tensorflow/core/framework:op_gen_lib", - ], -) - test_suite( name = "higher_level_tests", tests = [ @@ -2680,22 +2081,6 @@ tf_cc_tests( ], ) -tf_cc_test( - name = "cudnn_rnn_ops_test_cc", - size = "small", - srcs = [ - "ops/cudnn_rnn_ops_test.cc", - ], - deps = [ - ":core", - ":framework", - ":lib", - ":test", - ":test_main", - ":testlib", - ], -) - tf_cc_test_mkl( name = "mkl_related_tests", size = "small", @@ -2772,223 +2157,6 @@ tf_cc_tests_gpu( ], ) -tf_cc_test_gpu( - name = "variant_op_copy_test", - size = "small", - srcs = ["//tensorflow/core/framework:variant_op_copy_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags(), - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session", - ":framework", - ":framework_internal", - ":gpu_runtime", - ":lib", - ":lib_internal", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:client_session", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", - "//tensorflow/core/kernels:array", - "//third_party/eigen3", - ], -) - -tf_cc_test( - name = "framework_run_handler_util_test", - size = "small", - srcs = ["//tensorflow/core/framework:run_handler_util_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":framework_internal", - ":lib", - ":test", - ":test_main", - ], -) - -tf_cc_test( - name = "framework_run_handler_test", - size = "small", - srcs = ["//tensorflow/core/framework:run_handler_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core_cpu", - ":direct_session_internal", - ":framework_internal", - ":lib", - ":lib_internal", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/core/framework:tensor_testutil", - "//tensorflow/core/kernels:cwise_op", - "//tensorflow/core/kernels:matmul_op", - "//third_party/eigen3", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/synchronization", - ], -) - -tf_cc_test( - name = "framework_op_segment_test", - size = "small", - srcs = ["//tensorflow/core/framework:op_segment_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/core/kernels:cwise_op", - "//tensorflow/core/kernels:ops_util", - "//third_party/eigen3", - ], -) - -tf_cc_test( - name = "ops_array_grad_test", - size = "small", - srcs = ["ops/array_grad_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/core/kernels:array", - "//tensorflow/core/kernels:cwise_op", - "//tensorflow/core/kernels:function_ops", - "//tensorflow/core/kernels:math", - "//third_party/eigen3", - ], -) - -tf_cc_test( - name = "ops_math_grad_test", - size = 
"small", - srcs = ["ops/math_grad_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = ["no_gpu"], - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/core/kernels:array", - "//tensorflow/core/kernels:data_flow", - "//tensorflow/core/kernels:function_ops", - "//tensorflow/core/kernels:math", - "//third_party/eigen3", - ], -) - -tf_cc_test( - name = "ops_remote_fused_graph_ops_test", - size = "small", - srcs = ["ops/remote_fused_graph_ops_test.cc"], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/core/kernels:remote_fused_graph_ops", - ], -) - -tf_cc_test( - name = "ops_tests", - size = "small", - srcs = [ - "ops/array_ops_test.cc", - "ops/candidate_sampling_ops_test.cc", - "ops/control_flow_ops_test.cc", - "ops/ctc_ops_test.cc", - "ops/data_flow_ops_test.cc", - "ops/functional_ops_test.cc", - "ops/image_ops_test.cc", - "ops/io_ops_test.cc", - "ops/linalg_ops_test.cc", - "ops/math_ops_test.cc", - "ops/nn_ops_test.cc", - "ops/parsing_ops_test.cc", - "ops/random_ops_test.cc", - "ops/rnn_ops_test.cc", - "ops/set_ops_test.cc", - "ops/shape_function_test.cc", - "ops/sparse_csr_matrix_ops_test.cc", - "ops/sparse_ops_test.cc", - "ops/spectral_ops_test.cc", - "ops/state_ops_test.cc", - "ops/string_ops_test.cc", - "ops/training_ops_test.cc", - ], - linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/cc:cc_ops", - "//third_party/eigen3", - ], -) - # Test data filegroup( name = "image_testdata", @@ -3023,22 +2191,6 @@ filegroup( visibility = ["//visibility:public"], ) -alias( - name = "lmdb_testdata", - actual = "//tensorflow/core/lib/lmdb:lmdb_testdata", - visibility = ["//visibility:public"], -) - -alias( - name = "cuda_libdevice_path", - actual = "//tensorflow/core/platform:cuda_libdevice_path", -) - -# Normalize CORE_PROTO_SRCS to generate valid output file names. -PORTABLE_PROTO_HEADERS_OUT = tf_android_core_proto_headers(CORE_PROTO_SRCS) + [ - "//google/protobuf/any.proto.h", -] - transitive_hdrs( name = "headers", visibility = ["//tensorflow:__subpackages__"], @@ -3047,7 +2199,7 @@ transitive_hdrs( ":framework", ":lib", ":protos_all_cc", - ":stream_executor", "//tensorflow/core/platform:platform_strings", + "//tensorflow/core/platform:stream_executor", ], ) diff --git a/tensorflow/core/api_def/base_api/api_def_Qr.pbtxt b/tensorflow/core/api_def/base_api/api_def_Qr.pbtxt index ac8f7597aae..3d8fbea9424 100644 --- a/tensorflow/core/api_def/base_api/api_def_Qr.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Qr.pbtxt @@ -34,6 +34,10 @@ END Computes the QR decomposition of each inner matrix in `tensor` such that `tensor[..., :, :] = q[..., :, :] * r[..., :,:])` +Currently, the gradient for the QR decomposition is well-defined only when +the first `P` columns of the inner matrix are linearly independent, where +`P` is the minimum of `M` and `N`, the 2 inner-most dimmensions of `tensor`. + ```python # a is a tensor. 
# q is a tensor of orthonormal matrices. diff --git a/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV4.pbtxt b/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV4.pbtxt new file mode 100644 index 00000000000..a84ccb78436 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV4.pbtxt @@ -0,0 +1,8 @@ +op { + graph_op_name: "QuantizeAndDequantizeV4" + summary: "Returns the gradient of `QuantizeAndDequantizeV4`." + description: < - - +For usage examples see the python [tf.tensor_scatter_nd_update]( +https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) function -In Python, this scatter operation would look like this: - - >>> indices = tf.constant([[4], [3], [1], [7]]) - >>> updates = tf.constant([9, 10, 11, 12]) - >>> tensor = tf.ones([8], dtype=tf.int32) - >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates)) - tf.Tensor([ 1 11 1 10 9 1 1 12], shape=(8,), dtype=int32) - -We can also, insert entire slices of a higher rank tensor all at once. For -example, if we wanted to insert two slices in the first dimension of a -rank-3 tensor with two matrices of new values. - -In Python, this scatter operation would look like this: - - >>> indices = tf.constant([[0], [2]]) - >>> updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], - ... [7, 7, 7, 7], [8, 8, 8, 8]], - ... [[5, 5, 5, 5], [6, 6, 6, 6], - ... [7, 7, 7, 7], [8, 8, 8, 8]]]) - >>> tensor = tf.ones([4, 4, 4], dtype=tf.int32) - >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()) - [[[5 5 5 5] - [6 6 6 6] - [7 7 7 7] - [8 8 8 8]] - [[1 1 1 1] - [1 1 1 1] - [1 1 1 1] - [1 1 1 1]] - [[5 5 5 5] - [6 6 6 6] - [7 7 7 7] - [8 8 8 8]] - [[1 1 1 1] - [1 1 1 1] - [1 1 1 1] - [1 1 1 1]]] - -Note that on CPU, if an out of bound index is found, an error is returned. -On GPU, if an out of bound index is found, the index is ignored. 
END } diff --git a/tensorflow/core/api_def/java_api/api_def_QuantizeAndDequantizeV4.pbtxt b/tensorflow/core/api_def/java_api/api_def_QuantizeAndDequantizeV4.pbtxt new file mode 100644 index 00000000000..8054405368a --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_QuantizeAndDequantizeV4.pbtxt @@ -0,0 +1,3 @@ +op { + graph_op_name: "QuantizeAndDequantizeV4Grad" +} diff --git a/tensorflow/core/api_def/java_api/api_def_QuantizeAndDequantizeV4Grad.pbtxt b/tensorflow/core/api_def/java_api/api_def_QuantizeAndDequantizeV4Grad.pbtxt new file mode 100644 index 00000000000..8054405368a --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_QuantizeAndDequantizeV4Grad.pbtxt @@ -0,0 +1,3 @@ +op { + graph_op_name: "QuantizeAndDequantizeV4Grad" +} diff --git a/tensorflow/core/api_def/python_api/api_def_QuantizeAndDequantizeV4.pbtxt b/tensorflow/core/api_def/python_api/api_def_QuantizeAndDequantizeV4.pbtxt new file mode 100644 index 00000000000..0ed576f0690 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_QuantizeAndDequantizeV4.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "QuantizeAndDequantizeV4Grad" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_QuantizeAndDequantizeV4Grad.pbtxt b/tensorflow/core/api_def/python_api/api_def_QuantizeAndDequantizeV4Grad.pbtxt new file mode 100644 index 00000000000..0ed576f0690 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_QuantizeAndDequantizeV4Grad.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "QuantizeAndDequantizeV4Grad" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorScatterUpdate.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorScatterUpdate.pbtxt index aa4ab86c7c3..96ff5fe0091 100644 --- a/tensorflow/core/api_def/python_api/api_def_TensorScatterUpdate.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_TensorScatterUpdate.pbtxt @@ -1,11 +1,4 @@ op { graph_op_name: "TensorScatterUpdate" - deprecation_message: "Use tensor_scatter_nd_update instead" - endpoint { - name: "tensor_scatter_nd_update" - } - endpoint { - name: "tensor_scatter_update" - deprecation_version: 2 - } + visibility: HIDDEN } diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index f0378ad7538..8f1769cf753 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -1,6 +1,6 @@ load( "//tensorflow:tensorflow.bzl", - "if_tpu", + "if_libtpu", "tf_cc_test", "tf_cc_test_mkl", "tf_cc_tests", @@ -93,7 +93,7 @@ cc_library( deps = [ ":core_cpu", "//tensorflow/core/common_runtime/gpu:gpu_runtime", - ] + if_tpu(["//tensorflow/core/tpu:tpu_runtime"]), + ] + if_libtpu(["//tensorflow/core/tpu:tpu_runtime"]), ) filegroup( @@ -499,28 +499,19 @@ cc_library( cc_library( name = "device", - srcs = ["device.cc"], hdrs = ["device.h"], copts = tf_copts(), deps = [ "//tensorflow/core:framework_internal", - "//tensorflow/core:graph", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", ], ) cc_library( name = "device_factory", - srcs = ["device_factory.cc"], hdrs = ["device_factory.h"], copts = tf_copts(), deps = [ - ":device", - ":session_options", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", + "//tensorflow/core:framework_internal", ], ) @@ -1739,7 +1730,6 @@ tf_cuda_library( "//tensorflow/core:framework_internal", "//tensorflow/core:graph", "//tensorflow/core:lib", - "//tensorflow/core:lib_experimental", "//tensorflow/core:lib_internal", 
"//tensorflow/core:protos_all_cc", "//tensorflow/core/debug:debug_graph_utils", @@ -2230,7 +2220,10 @@ tf_cuda_cc_test( srcs = ["direct_session_test.cc"], args = [] + if_cuda(["--heap_check=local"]), # The GPU tracer leaks memory linkstatic = tf_kernel_tests_linkstatic(), - tags = ["notsan"], # b/168317266 + tags = [ + "noasan", # b/168317266 + "notsan", # b/168317266 + ], deps = [ ":core_cpu", ":core_cpu_internal", @@ -2274,6 +2267,7 @@ tf_cc_test( srcs = ["direct_session_test.cc"], linkstatic = tf_kernel_tests_linkstatic(), tags = [ + "noasan", #b/168811551 "notsan", #b/168811551 ], deps = [ diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc index f5c2522f142..213736fb68b 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.cc +++ b/tensorflow/core/common_runtime/base_collective_executor.cc @@ -306,7 +306,7 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, void BaseCollectiveExecutor::CompleteParamsAsync( const DeviceAttributes& device, CollectiveParams* cp, CancellationManager* cancel_mgr, StatusCallback done) { - cp->instance.gpu_ring_order = *gpu_ring_order_; + cp->group.gpu_ring_order = *gpu_ring_order_; const auto is_callback_called = std::make_shared>(false); auto done_with_timeout = done; auto timeout_microseconds = @@ -402,9 +402,9 @@ void BaseCollectiveExecutor::UnblockDependencies( mutex_lock l(launch_mu_); if (launched_.find(col_params.instance.instance_key) == launched_.end()) { const string& task_name = - col_params.instance.task_names[col_params.default_rank]; + col_params.group.task_names[col_params.default_rank]; const int32 num_devices = - col_params.instance.num_devices_per_task.at(task_name); + col_params.group.num_devices_per_task.at(task_name); launched_[col_params.instance.instance_key] = num_devices; } if (--launched_[col_params.instance.instance_key] == 0) { diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc index 8b2ece5d9de..9c46314af67 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc @@ -97,9 +97,11 @@ void CollectiveParamResolverLocal::CompleteGroupLocal( auto it = group_table_.find(cp->group.group_key); if (it == group_table_.end()) { gr = new GroupRec; + mutex_lock grl(gr->mu); gr->group.group_key = cp->group.group_key; gr->group.group_size = cp->group.group_size; gr->group.device_type = cp->group.device_type; + gr->group.gpu_ring_order = cp->group.gpu_ring_order; // Initialize group runtime details. CollectiveImplementationInterface* col_impl; @@ -164,6 +166,7 @@ void CollectiveParamResolverLocal::CompleteGroupLocal( " but that group has size ", gr->group.group_size); } } + bool new_device = false; if (gr->status.ok()) { // Insert device if not already present. auto it = gr->devices.find(device.name()); @@ -177,15 +180,7 @@ void CollectiveParamResolverLocal::CompleteGroupLocal( } else { // This is a new device that has not yet joined the group. gr->devices[device.name()] = device; - if (gr->devices.size() == gr->group.group_size) { - // The group is full after adding this device, calculate the number - // of tasks. 
- std::unordered_set tasks; - for (const auto& item : gr->devices) { - tasks.insert(TaskNameFromDeviceName(item.first)); - } - gr->group.num_tasks = static_cast(tasks.size()); - } + new_device = true; if (VLOG_IS_ON(1)) { string dev_buf; for (const auto& d : gr->devices) { @@ -211,7 +206,6 @@ void CollectiveParamResolverLocal::CompleteGroupLocal( } if (gr->status.ok()) { - cp->group.runtime_details = gr->group.runtime_details; // If the group is not yet complete, queue to wait for it. VLOG(2) << "group_size " << gr->group.group_size << " set size " << gr->devices.size() << " gr " << gr; @@ -221,6 +215,10 @@ void CollectiveParamResolverLocal::CompleteGroupLocal( return; } CHECK_EQ(gr->devices.size(), gr->group.group_size); + // We get a full group. Fill in remaining fields in gr->group. + if (new_device) { + FinishGroup(gr); + } } // At this point, we either have a full group, or an error status. Ensure // that all callbacks are invoked with the appropriate status. @@ -248,16 +246,16 @@ typedef std::unordered_map TaskDeviceMap; typedef std::unordered_map GlobalDeviceMap; // Create a populated GlobalDeviceMap from CollInstanceParams and localities. -GlobalDeviceMap BuildDevRecs(const CollInstanceParams& ip, +GlobalDeviceMap BuildDevRecs(const CollGroupParams& gp, const std::vector& attributes) { GlobalDeviceMap gdm; - CHECK_EQ(ip.device_names.size(), ip.task_names.size()); - CHECK_EQ(ip.device_names.size(), attributes.size()); - for (int i = 0; i < ip.device_names.size(); ++i) { - TaskDeviceMap& tdm = gdm[ip.task_names[i]]; - DevRec* dr = &tdm[ip.device_names[i]]; - dr->task = ip.task_names[i]; - dr->device = ip.device_names[i]; + CHECK_EQ(gp.device_names.size(), gp.task_names.size()); + CHECK_EQ(gp.device_names.size(), attributes.size()); + for (int i = 0; i < gp.device_names.size(); ++i) { + TaskDeviceMap& tdm = gdm[gp.task_names[i]]; + DevRec* dr = &tdm[gp.device_names[i]]; + dr->task = gp.task_names[i]; + dr->device = gp.device_names[i]; dr->original_rank = i; dr->local_rank = 0; // Will be populated later by OrderTaskDeviceMap. dr->global_rank = 0; // Will be populated later by EstablishGlobalRank. @@ -378,25 +376,23 @@ void OrderTaskDeviceMap(const string& gpu_ring_order, TaskDeviceMap* tdm) { } } -// The first time a shared CollectiveParams is established for a -// shared set of instances we compute a good rank order for all the -// devices in the group, that is appropriate for a ring algorithm. -// This order need not be the same across different instance groups -// sharing the same device group where there is more than one good -// order. +// The first time a CollGroupParams is established for a group we compute a good +// rank order for all the devices in the group, that is appropriate for a ring +// algorithm. GlobalDeviceMap EstablishGlobalRank( - CollectiveParams* cp, const std::vector& attributes) { + const CollGroupParams& gp, + const std::vector& attributes) { VLOG(1) << "EstablishGlobalRank"; - GlobalDeviceMap gdm = BuildDevRecs(cp->instance, attributes); + GlobalDeviceMap gdm = BuildDevRecs(gp, attributes); for (auto& iter : gdm) { TaskDeviceMap& tdm = iter.second; - OrderTaskDeviceMap(cp->instance.gpu_ring_order, &tdm); + OrderTaskDeviceMap(gp.gpu_ring_order, &tdm); } // Connect the global rank order by the order in which tasks first appear. 
std::set ordered_tasks; int next_rank = 0; - for (int i = 0; i < cp->instance.task_names.size(); ++i) { - const string& task_name = cp->instance.task_names[i]; + for (int i = 0; i < gp.task_names.size(); ++i) { + const string& task_name = gp.task_names[i]; if (ordered_tasks.find(task_name) != ordered_tasks.end()) { continue; } @@ -411,81 +407,104 @@ GlobalDeviceMap EstablishGlobalRank( } // Count the devices associated with each task and set -// cp->same_num_devices_per_task. Requires cp->instance.task_names +// gp->same_num_devices_per_task. Requires gp->task_names // be sorted. -void SetDevPerTask(CollectiveParams* cp) { - cp->instance.num_devices_per_task.clear(); - const string* last_task_name = &cp->instance.task_names[0]; +void SetDevPerTask(CollGroupParams* gp) { + gp->num_devices_per_task.clear(); + const string* last_task_name = &gp->task_names[0]; int count = 0; - for (const string& task_name : cp->instance.task_names) { + for (const string& task_name : gp->task_names) { if (task_name == *last_task_name) { ++count; } else { - cp->instance.num_devices_per_task[*last_task_name] = count; + gp->num_devices_per_task[*last_task_name] = count; count = 1; last_task_name = &task_name; } } - cp->instance.num_devices_per_task[*last_task_name] = count; + gp->num_devices_per_task[*last_task_name] = count; - cp->instance.same_num_devices_per_task = false; + gp->same_num_devices_per_task = false; int dev_per_task = -1; - for (const auto& task_dev : cp->instance.num_devices_per_task) { + for (const auto& task_dev : gp->num_devices_per_task) { if (dev_per_task == -1) { dev_per_task = task_dev.second; } else if (dev_per_task != task_dev.second) { return; } } - cp->instance.same_num_devices_per_task = true; - CHECK_EQ((cp->group.group_size % cp->group.num_tasks), 0); + gp->same_num_devices_per_task = true; + CHECK_EQ((gp->group_size % gp->num_tasks), 0); } -// Sort cp->instance.device_names lexicographically, but do by first -// computing a reordering permutation so we can keep cp->instance.task_names +// Sort gp->device_names lexicographically, but do by first +// computing a reordering permutation so we can keep gp->task_names // in corresponding order. -void SortDevicesAndTasks(CollectiveParams* cp) { - VLOG(1) << "SortDevicesAndTasks " << cp << " instance " << &cp->instance; - CHECK(cp); - CHECK_EQ(cp->group.group_size, cp->instance.device_names.size()); - CHECK_EQ(cp->group.group_size, cp->instance.task_names.size()); - std::vector perm(cp->group.group_size); +void SortDevicesAndTasks(CollGroupParams* gp) { + VLOG(1) << "SortDevicesAndTasks " << gp << " " << gp; + CHECK(gp); + CHECK_EQ(gp->group_size, gp->device_names.size()); + CHECK_EQ(gp->group_size, gp->task_names.size()); + std::vector perm(gp->group_size); // TODO(tucker): substitute std::iota when the windows build supports it. 
// std::iota(perm.begin(), perm.end(), 0); for (int i = 0; i < perm.size(); ++i) { perm[i] = i; } - std::sort(perm.begin(), perm.end(), [cp](int a, int b) { - return cp->instance.device_names[a] < cp->instance.device_names[b]; + std::sort(perm.begin(), perm.end(), [gp](int a, int b) { + return gp->device_names[a] < gp->device_names[b]; }); std::vector new_devs; std::vector new_tasks; - new_devs.reserve(cp->group.group_size); - new_tasks.reserve(cp->group.group_size); + new_devs.reserve(gp->group_size); + new_tasks.reserve(gp->group_size); for (int pi : perm) { - new_devs.push_back(cp->instance.device_names[pi]); - new_tasks.push_back(cp->instance.task_names[pi]); + new_devs.push_back(gp->device_names[pi]); + new_tasks.push_back(gp->task_names[pi]); } - cp->instance.device_names = std::move(new_devs); - cp->instance.task_names = std::move(new_tasks); - VLOG(1) << "Modified device_names on " << cp; - SetDevPerTask(cp); + gp->device_names = std::move(new_devs); + gp->task_names = std::move(new_tasks); + VLOG(1) << "Modified device_names on " << gp; + SetDevPerTask(gp); } } // namespace +void CollectiveParamResolverLocal::FinishGroup(GroupRec* gr) { + gr->group.device_names.reserve(gr->devices.size()); + gr->group.task_names.reserve(gr->devices.size()); + std::vector attributes; + // Unique tasks. It's used to calculate num_tasks. + std::unordered_set tasks; + attributes.reserve(gr->devices.size()); + for (const auto& item : gr->devices) { + gr->group.device_names.push_back(item.first); + string task_name = TaskNameFromDeviceName(item.first); + gr->group.task_names.push_back(task_name); + tasks.insert(task_name); + attributes.push_back(item.second); + } + gr->group.num_tasks = static_cast(tasks.size()); + // Sort device_names lexicographically, keeping task_names in corresponding + // order. Also set number of devices per task. + SortDevicesAndTasks(&gr->group); + // Establish the final order of gp->device_names and gp->task_names by + // considering localities of all devices. 
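The SortDevicesAndTasks change above keeps its permutation-based co-sort: it sorts an index permutation by device name and then reorders both parallel vectors through it, so device_names and task_names stay aligned. A small standalone sketch of that technique with illustrative values:

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Parallel arrays, as in the group parameters: entry i of each vector
  // describes the same group member.
  std::vector<std::string> device_names = {"/task:1/device:GPU:0",
                                           "/task:0/device:GPU:1",
                                           "/task:0/device:GPU:0"};
  std::vector<std::string> task_names = {"/task:1", "/task:0", "/task:0"};

  // Sort a permutation by device name instead of sorting the vectors directly,
  // so both vectors can be reordered consistently afterwards.
  std::vector<int> perm(device_names.size());
  for (int i = 0; i < static_cast<int>(perm.size()); ++i) perm[i] = i;
  std::sort(perm.begin(), perm.end(), [&](int a, int b) {
    return device_names[a] < device_names[b];
  });

  std::vector<std::string> new_devs, new_tasks;
  for (int pi : perm) {
    new_devs.push_back(device_names[pi]);
    new_tasks.push_back(task_names[pi]);
  }
  device_names = std::move(new_devs);
  task_names = std::move(new_tasks);

  for (size_t i = 0; i < device_names.size(); ++i) {
    std::cout << device_names[i] << " on " << task_names[i] << '\n';
  }
  return 0;
}
```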
+ CompleteDefaultRanking(attributes, &gr->group); +} + void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name, CollectiveParams* cp) { cp->task.is_local.resize(cp->group.group_size, false); for (int i = 0; i < cp->group.group_size; ++i) { - cp->task.is_local[i] = (cp->instance.task_names[i] == task_name); + cp->task.is_local[i] = (cp->group.task_names[i] == task_name); } } void CollectiveParamResolverLocal::SetDefaultRank(const string& device, CollectiveParams* cp) { - CHECK_EQ(cp->group.group_size, cp->instance.device_names.size()) << cp; + CHECK_EQ(cp->group.group_size, cp->group.device_names.size()) << cp; for (int i = 0; i < cp->group.group_size; ++i) { - if (cp->instance.device_names[i] == device) { + if (cp->group.device_names[i] == device) { cp->default_rank = i; break; } @@ -494,40 +513,14 @@ void CollectiveParamResolverLocal::SetDefaultRank(const string& device, void CollectiveParamResolverLocal::InitInstanceSharedParams( const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir) { - std::vector attributes; ir->shared.instance = cp->instance; - { - mutex_lock gl(gr->mu); - ir->shared.group = gr->group; - ir->shared.instance.device_names.clear(); - ir->shared.instance.task_names.clear(); - ir->shared.instance.device_names.reserve(gr->devices.size()); - ir->shared.instance.task_names.reserve(gr->devices.size()); - attributes.reserve(gr->devices.size()); - for (const auto& item : gr->devices) { - ir->shared.instance.device_names.push_back(item.first); - ir->shared.instance.task_names.push_back( - TaskNameFromDeviceName(item.first)); - attributes.push_back(item.second); - } - VLOG(2) << "Initialized names for instance: " - << ir->shared.instance.ToString(); - } ir->shared.default_rank = -1; - // Sort device_names lexicographically, keeping task_names in corresponding - // order. Also set number of devices per task. - SortDevicesAndTasks(&ir->shared); - - // Get Locality data for all devices. - // Set is_local and task_names in *shared prior to invoking // GetDeviceAttributesAsync. In a distributed context this function can be // called by a derived class, some of the devices may be non-local and // GetDeviceAttributesAsync will use those fields to launch RPCs. CompleteTaskIsLocal(task_name_, &ir->shared); - - CompleteDefaultRanking(gr, cp, ir, attributes); } // NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks @@ -535,33 +528,31 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams( // TensorFlow runtime. This set of devices may be a superset of the devices // participating in this instance of collectives. void CollectiveParamResolverLocal::CompleteDefaultRanking( - const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir, - const std::vector& attributes) { + const std::vector& attributes, CollGroupParams* gp) { // Establish an instance-specific default rank order for devices // based on localities. This rank order should be a good ring // order, if possible. 
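The reordering that CompleteDefaultRanking performs (continued below) boils down to scattering names from their original rank to a locality-derived global rank. A small sketch of that single step, with made-up rank values in place of the DevRec fields:

```cpp
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Devices in their lexicographic order (original_rank == index).
  std::vector<std::string> device_names = {"GPU:0", "GPU:1", "GPU:2", "GPU:3"};

  // A locality-derived ring order, expressed as original_rank -> global_rank.
  // In the real code this comes from the per-device global_rank fields.
  std::vector<int> global_rank = {0, 3, 1, 2};

  std::vector<std::string> reordered(device_names.size());
  for (int original_rank = 0;
       original_rank < static_cast<int>(device_names.size()); ++original_rank) {
    reordered[global_rank[original_rank]] = device_names[original_rank];
  }

  // Prints GPU:0, GPU:2, GPU:3, GPU:1.
  for (const auto& d : reordered) std::cout << d << '\n';
  return 0;
}
```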
- GlobalDeviceMap gdm = EstablishGlobalRank(&ir->shared, attributes); + GlobalDeviceMap gdm = EstablishGlobalRank(*gp, attributes); // Reflect the new global ranking on shared - size_t num_devices = ir->shared.group.group_size; + size_t num_devices = gp->group_size; std::vector new_device_names(num_devices, ""); std::vector new_task_names(num_devices, ""); for (const auto& git : gdm) { const TaskDeviceMap& tdm = git.second; for (const auto& tit : tdm) { const DevRec& dr = tit.second; - new_device_names[dr.global_rank] = - ir->shared.instance.device_names[dr.original_rank]; - new_task_names[dr.global_rank] = - ir->shared.instance.task_names[dr.original_rank]; + new_device_names[dr.global_rank] = gp->device_names[dr.original_rank]; + new_task_names[dr.global_rank] = gp->task_names[dr.original_rank]; } } - ir->shared.instance.device_names = new_device_names; - ir->shared.instance.task_names = new_task_names; + gp->device_names = new_device_names; + gp->task_names = new_task_names; if (VLOG_IS_ON(2)) { string buf; for (const auto& d : new_device_names) strings::StrAppend(&buf, "\n", d); - VLOG(2) << "Optimized device order for " << ir->shared.name << ": " << buf; + VLOG(2) << "Optimized device order for group " << gp->group_key << ": " + << buf; } } @@ -655,10 +646,13 @@ void CollectiveParamResolverLocal::CompleteInstanceLocal( // Populate the group portion of *cp from *gr. Most of it should already // match. - DCHECK_EQ(cp->group.group_key, gr->group.group_key); - DCHECK_EQ(cp->group.group_size, gr->group.group_size); - DCHECK_EQ(cp->group.device_type, gr->group.device_type); - cp->group = gr->group; + { + mutex_lock l(gr->mu); + DCHECK_EQ(cp->group.group_key, gr->group.group_key); + DCHECK_EQ(cp->group.group_size, gr->group.group_size); + DCHECK_EQ(cp->group.device_type, gr->group.device_type); + cp->group = gr->group; + } InstanceRec* ir = GetOrCreateInstanceRec(gr, cp); CompleteInstanceFromInitializedIRec(device, gr, cp, ir, is_source, done); @@ -756,11 +750,11 @@ void CollectiveParamResolverLocal::WaitForGroup(InstanceRec* ir, } } } - if (ir->known_count < ir->shared.group.group_size) { + if (ir->known_count < cp->group.group_size) { ir->known_waiters.push_back(f); return; } - CHECK_EQ(ir->known_count, ir->shared.group.group_size); + CHECK_EQ(ir->known_count, cp->group.group_size); if (ir->source_rank < 0) { // NOTE(ayushd): changing the error message below would also require // updating CompleteParamsBroadcastForgotSend test in diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h index 2b30ba1a14d..19d68a14a63 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.h +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h @@ -69,8 +69,8 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { // Used to complete/verify CollGroup. struct GroupRec { - CollGroupParams group; mutable mutex mu; + CollGroupParams group TF_GUARDED_BY(mu); Status status TF_GUARDED_BY(mu); std::unordered_map devices TF_GUARDED_BY(mu); std::vector waiting TF_GUARDED_BY(mu); @@ -88,6 +88,9 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { const GroupRecCallback& done) TF_LOCKS_EXCLUDED(group_mu_); + // Finishes the group parameters once all members of the group are there. + void FinishGroup(GroupRec* gr) TF_EXCLUSIVE_LOCKS_REQUIRED(gr->mu); + // Used to complete/verify CollInstance. 
struct InstanceRec; @@ -131,11 +134,10 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { void InitInstanceSharedParams(const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir) TF_LOCKS_EXCLUDED(gr->mu); - // Establishes the final order of ir->shared.instance.device_names and - // ir->shared.instance.task_names by considering localities of all devices. - void CompleteDefaultRanking(const GroupRec* gr, const CollectiveParams* cp, - InstanceRec* ir, - const std::vector& attributes); + // Establishes the final order of gp->device_names and gp->task_names by + // considering localities of all devices. + void CompleteDefaultRanking(const std::vector& attributes, + CollGroupParams* gp); // Finish populating *cp. // Precondition: *gr has been fully populated by CompleteGroupLocal. diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc index af2e7fe8057..e1ac46f2e53 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc @@ -59,32 +59,21 @@ class CollectiveParamResolverLocalTest : public ::testing::Test { } void RunCompleteDefaultRanking( - const CollectiveParams& shared_cp, - const std::vector& attributes, + CollGroupParams group, const std::vector& attributes, const std::vector& gpu_ring_order, const std::vector& expected_device_order) { - CollectiveParams cp; - cp.instance.device_names = shared_cp.instance.device_names; - CollectiveParamResolverLocal::InstanceRec ir; - { - mutex_lock l(ir.mu); - ir.shared.name = shared_cp.name; - ir.shared.group = shared_cp.group; - ir.shared.instance = shared_cp.instance; - if (!gpu_ring_order.empty()) { - ir.shared.instance.gpu_ring_order = ""; - for (int i = 0; i < static_cast(gpu_ring_order.size() - 1); - ++i) { - ir.shared.instance.gpu_ring_order = strings::StrCat( - ir.shared.instance.gpu_ring_order, gpu_ring_order[i], ","); - } - ir.shared.instance.gpu_ring_order = strings::StrCat( - ir.shared.instance.gpu_ring_order, gpu_ring_order.back()); + if (!gpu_ring_order.empty()) { + group.gpu_ring_order = ""; + for (int i = 0; i < static_cast(gpu_ring_order.size() - 1); ++i) { + group.gpu_ring_order = + strings::StrCat(group.gpu_ring_order, gpu_ring_order[i], ","); } - VLOG(2) << "gpu_ring_order " << ir.shared.instance.gpu_ring_order; - prl_->CompleteDefaultRanking(nullptr, &cp, &ir, attributes); - EXPECT_EQ(ir.shared.instance.device_names, expected_device_order); + group.gpu_ring_order = + strings::StrCat(group.gpu_ring_order, gpu_ring_order.back()); } + VLOG(2) << "gpu_ring_order " << group.gpu_ring_order; + prl_->CompleteDefaultRanking(attributes, &group); + EXPECT_EQ(group.device_names, expected_device_order); } DeviceAttributes GetDeviceAttributes(const string& device_name) { @@ -101,19 +90,15 @@ class CollectiveParamResolverLocalTest : public ::testing::Test { TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) { constexpr int kNumGpus = 8; - CollectiveParams cp; + CollGroupParams group; std::vector attributes(kNumGpus); - cp.name = "PRLTest"; - cp.group.device_type = DeviceType("GPU"); - cp.group.num_tasks = 1; - cp.group.group_size = kNumGpus; - cp.instance.instance_key = 5; - cp.instance.type = REDUCTION_COLLECTIVE; - cp.instance.data_type = DataType(DT_FLOAT); + group.device_type = DeviceType("GPU"); + group.num_tasks = 1; + group.group_size = kNumGpus; std::unordered_set clique1 = {0, 1, 6, 7}; for (int 
gpu_idx = 0; gpu_idx < kNumGpus; ++gpu_idx) { - cp.instance.task_names.push_back("/job:localhost/replica:0/task:0"); - cp.instance.device_names.push_back(strings::StrCat( + group.task_names.push_back("/job:localhost/replica:0/task:0"); + group.device_names.push_back(strings::StrCat( "/job:localhost/replica:0/task:0/device:GPU:", gpu_idx)); DeviceLocality locality; // Build localities so that 0,1,6,7 and 2,3,4,5 form 2 strongly connected @@ -138,7 +123,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) { } *attributes[gpu_idx].mutable_locality() = locality; } - RunCompleteDefaultRanking(cp, attributes, {1, 3, 5, 7, 6, 4, 2, 0}, + RunCompleteDefaultRanking(group, attributes, {1, 3, 5, 7, 6, 4, 2, 0}, { "/job:localhost/replica:0/task:0/device:GPU:1", "/job:localhost/replica:0/task:0/device:GPU:3", @@ -149,7 +134,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) { "/job:localhost/replica:0/task:0/device:GPU:2", "/job:localhost/replica:0/task:0/device:GPU:0", }); - RunCompleteDefaultRanking(cp, attributes, {7, 6, 5, 4, 3, 2, 1, 0}, + RunCompleteDefaultRanking(group, attributes, {7, 6, 5, 4, 3, 2, 1, 0}, { "/job:localhost/replica:0/task:0/device:GPU:7", "/job:localhost/replica:0/task:0/device:GPU:6", @@ -162,7 +147,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) { }); // With no gpu_ring_order passed, automatic link detection should kick in. // Starting at dev 0, the best order would be: 0,1,6,7,3,2,4,5 - RunCompleteDefaultRanking(cp, attributes, {}, + RunCompleteDefaultRanking(group, attributes, {}, { "/job:localhost/replica:0/task:0/device:GPU:0", "/job:localhost/replica:0/task:0/device:GPU:1", @@ -189,18 +174,17 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) { cp->instance.type = REDUCTION_COLLECTIVE; cp->instance.data_type = DataType(DT_FLOAT); cp->instance.shape = TensorShape({5}); - cp->instance.device_names.push_back( - strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i)); cp->instance.impl_details.subdiv_offsets.push_back(0); cp->is_source = false; Env::Default()->SchedClosure([this, i, cp, ¬e, &statuses]() { - prl_->CompleteParamsAsync( - GetDeviceAttributes(cp->instance.device_names[0]), cp, - nullptr /*CancellationManager*/, - [&statuses, ¬e, i](const Status& s) { - statuses[i] = s; - note[i].Notify(); - }); + string device = + strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); + prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp, + nullptr /*CancellationManager*/, + [&statuses, ¬e, i](const Status& s) { + statuses[i] = s; + note[i].Notify(); + }); }); } for (int i = 0; i < NUM_DEVS; ++i) { @@ -208,17 +192,17 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) { } for (int i = 0; i < NUM_DEVS; ++i) { TF_ASSERT_OK(statuses[i]); - ASSERT_EQ(cps[i].instance.device_names.size(), 3); + ASSERT_EQ(cps[i].group.device_names.size(), 3); for (int j = 0; j < NUM_DEVS; ++j) { EXPECT_EQ( strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j), - cps[i].instance.device_names[j]); + cps[i].group.device_names[j]); EXPECT_TRUE(cps[i].task.is_local[j]); } EXPECT_EQ(cps[i].instance.impl_details.subdiv_source_rank.size(), 0); EXPECT_FALSE(cps[i].is_source); EXPECT_EQ(cps[i].default_rank, i); - EXPECT_TRUE(cps[i].instance.same_num_devices_per_task); + EXPECT_TRUE(cps[i].group.same_num_devices_per_task); } } @@ -233,8 +217,6 @@ void InitializeCollectiveParamsForBroadcast(int instance_key, int device_idx, cp->instance.type = BROADCAST_COLLECTIVE; 
cp->instance.data_type = DataType(DT_FLOAT); cp->instance.shape = TensorShape({5}); - cp->instance.device_names.push_back(strings::StrCat( - "/job:localhost/replica:0/task:0/device:CPU:", device_idx)); cp->instance.impl_details.subdiv_offsets.push_back(0); cp->is_source = is_source; } @@ -248,13 +230,14 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) { CollectiveParams* cp = &cps[i]; InitializeCollectiveParamsForBroadcast(kInstanceKey, i, i == 1, cp); Env::Default()->SchedClosure([this, i, cp, ¬e, &statuses]() { - prl_->CompleteParamsAsync( - GetDeviceAttributes(cp->instance.device_names[0]), cp, - nullptr /*CancellationManager*/, - [&statuses, ¬e, i](const Status& s) { - statuses[i] = s; - note[i].Notify(); - }); + string device = + strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); + prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp, + nullptr /*CancellationManager*/, + [&statuses, ¬e, i](const Status& s) { + statuses[i] = s; + note[i].Notify(); + }); }); } for (int i = 0; i < NUM_DEVS; ++i) { @@ -262,16 +245,16 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) { } for (int i = 0; i < NUM_DEVS; ++i) { TF_ASSERT_OK(statuses[i]); - ASSERT_EQ(cps[i].instance.device_names.size(), 3); + ASSERT_EQ(cps[i].group.device_names.size(), 3); for (int j = 0; j < NUM_DEVS; ++j) { EXPECT_EQ( strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", j), - cps[i].instance.device_names[j]); + cps[i].group.device_names[j]); EXPECT_TRUE(cps[i].task.is_local[j]); } EXPECT_EQ(cps[i].is_source, (i == 1)); EXPECT_EQ(cps[i].default_rank, i); - EXPECT_TRUE(cps[i].instance.same_num_devices_per_task); + EXPECT_TRUE(cps[i].group.same_num_devices_per_task); } } @@ -287,13 +270,14 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcastForgotSender) { CollectiveParams* cp = &cps[i]; InitializeCollectiveParamsForBroadcast(kInstanceKey, i, false, cp); Env::Default()->SchedClosure([this, i, cp, ¬e, &statuses]() { - prl_->CompleteParamsAsync( - GetDeviceAttributes(cp->instance.device_names[0]), cp, - nullptr /*CancellationManager*/, - [&statuses, ¬e, i](const Status& s) { - statuses[i] = s; - note[i].Notify(); - }); + string device = + strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); + prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp, + nullptr /*CancellationManager*/, + [&statuses, ¬e, i](const Status& s) { + statuses[i] = s; + note[i].Notify(); + }); }); } for (int i = 0; i < NUM_DEVS; ++i) { diff --git a/tensorflow/core/common_runtime/collective_util.cc b/tensorflow/core/common_runtime/collective_util.cc index a94e6cb0a36..1a5001758ea 100644 --- a/tensorflow/core/common_runtime/collective_util.cc +++ b/tensorflow/core/common_runtime/collective_util.cc @@ -60,8 +60,8 @@ string SubdivPermDebugString(const CollectiveParams& col_params) { for (int di = 0; di < subdiv_perms[sdi].size(); ++di) { int idx = subdiv_perms[sdi][di]; if (idx >= 0) { - CHECK_GT(col_params.instance.device_names.size(), idx); - strings::StrAppend(&buf, col_params.instance.device_names[idx], "\n"); + CHECK_GT(col_params.group.device_names.size(), idx); + strings::StrAppend(&buf, col_params.group.device_names[idx], "\n"); } } strings::StrAppend(&buf, " subdiv_offsets: "); diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h index 9e2db9faaf1..83785e3341b 100644 --- a/tensorflow/core/common_runtime/device.h +++ b/tensorflow/core/common_runtime/device.h @@ -12,191 +12,9 @@ WITHOUT WARRANTIES OR 
CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -// A Device is a something that can perform computations as part of a -// model. Devices can be local (runs computation on this machine), or -// remote (contacts a device local to another machine using an RPC to -// do the work). Devices are registered in a DeviceSet, which is also -// responsible for the Device <-> id mapping. -// -// Device names -// * Every Device should have a unique name with the format: -// /job:___/replica:___/task:___/(gpu|cpu):___ -// An example name would be "/job:train/replica:0/task:3/device:GPU:2". -// * Task numbers are within the specified replica, so there are as -// many "task zeros" as replicas. - #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_ -#include -#include - -#include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/control_flow.h" -#include "tensorflow/core/framework/device_attributes.pb.h" -#include "tensorflow/core/framework/device_base.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/op_segment.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/device_name_utils.h" - -namespace tensorflow { - -class Device : public DeviceBase { - public: - // Callback type that takes a Status and returns void. - typedef std::function DoneCallback; - - Device(Env* env, const DeviceAttributes& device_attributes); - ~Device() override; - - // Full name of this device (see top comment). - const std::string& name() const override { return device_attributes_.name(); } - - // Parsed name of this device - const DeviceNameUtils::ParsedName& parsed_name() const { - return parsed_name_; - } - - // Describes what kind of device this is. This is intended to be - // human-readable and not computer-parsed, except that two devices - // with the same device_type() are expected to perform similarly - // (both from a computation and communication perspective). - const std::string& device_type() const { - return device_attributes_.device_type(); - } - - // Returns an aggregation of device attributes. - const DeviceAttributes& attributes() const override { - return device_attributes_; - } - - // Performs the actual compute function. - // - // Subclasses may override this function if they wish to perform - // some initialization before each compute. - virtual void Compute(OpKernel* op_kernel, OpKernelContext* context) { - op_kernel->Compute(context); - } - - // Asynchronous kernel's compute. - virtual void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, - AsyncOpKernel::DoneCallback done) { - op_kernel->ComputeAsync(context, std::move(done)); - } - - // Blocks until all operations queued on the device at the time of - // the call have completed. Returns any error pending on the device - // at completion. 
- virtual Status Sync() = 0; - - // Calls the given callback when all operations queued on the device at the - // time of the call have completed. The callback is passed any error pending - // on the device at completion. - // TODO(b/112409994): Consolidate these two APIs, removing the synchronous - // version. - virtual void Sync(const DoneCallback& done); - - // On session completion, the executor may call Device::Sync() depending on - // flag settings. Override this to return false for devices that don't allow - // such calls. Instead, these devices must use other mechanisms (such as - // num_deferred_ops) to ensure the device has finished processing necessary - // work at session completion. In addition, for these devices, RefreshStatus - // must be called at session completion to retrieve execution result status. - // - // Devices that override this function must also implement RefreshStatus. - virtual bool AllowsSyncOnCompletion() const { return true; } - - // This is used in conjunction with AllowsSyncOnCompletion to allow the - // executor to get execution result status at session completion. - // - // For supported devices, this call returns the underlying device stream's - // current status in a non-blocking way, without using blocking calls such as - // Stream::BlockHostUntilDone or Device::Sync. When applicable, the device - // status is also updated with the retrieved stream status. - virtual Status RefreshStatus() { - return errors::Unimplemented( - "RefreshStatus is not supported on this device."); - } - - // Optionally modify the device's GraphDef before execution. - // - // This method should be considered experimental and is supplied to enable - // prototyping of TensorFlow device implementations that need to modify - // the GraphDef before execution. - // - // 'graph' supplies the partition of the graph assigned to this - // device. - virtual Status MaybeRewriteGraph(std::unique_ptr* /*graph*/) { - return Status::OK(); - } - - // Sets `out_context` a new DeviceContext* for executing a graph, or nullptr - // if the device does not support contexts. Returns an error status if any - // error occurred while trying to create a context, otherwise OK. - // - // The caller takes ownership of one reference on the output DeviceContext*, - // and should call Unref(). - virtual Status TryGetDeviceContext(DeviceContext** out_context) { - *out_context = nullptr; - return Status::OK(); - } - - // Returns the op segment of this device. The caller can reuse op - // kernels registered for the same session running on this device. - OpSegment* op_segment() { return &op_seg_; } - - // Returns the resource manager associated w/ this device. - virtual ResourceMgr* resource_manager() { return rmgr_; } - - // Summarizes the status of this Device, for debugging. - std::string DebugString() const { return device_attributes_.DebugString(); } - - // Assembles the parameter components into a complete DeviceAttributes value. - static DeviceAttributes BuildDeviceAttributes( - const std::string& name, DeviceType device, Bytes memory_limit, - const DeviceLocality& locality, const std::string& physical_device_desc); - - static DeviceAttributes BuildDeviceAttributes( - const std::string& name, DeviceType device, Bytes memory_limit, - const DeviceLocality& locality) { - // Pass in an empty string as physical device name. - return BuildDeviceAttributes(name, device, memory_limit, locality, ""); - } - - // Clears the resource manager associated with this device. 
- void ClearResourceMgr() { rmgr_->Clear(); } - - virtual bool IsLocal() const { return true; } - - protected: - void DeleteResourceMgr() { - delete rmgr_; - rmgr_ = nullptr; - } - - private: - const DeviceAttributes device_attributes_; - DeviceNameUtils::ParsedName parsed_name_; - - // op_seg_ maps session handle and op name to OpKernel objects. - OpSegment op_seg_; - - // Resources associated w/ this device. E.g., shared variables, etc. - ResourceMgr* rmgr_ = nullptr; - - TF_DISALLOW_COPY_AND_ASSIGN(Device); -}; - -} // namespace tensorflow +#include "tensorflow/core/framework/device.h" #endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_ diff --git a/tensorflow/core/common_runtime/device_factory.h b/tensorflow/core/common_runtime/device_factory.h index f10a718db05..1b5a662639a 100644 --- a/tensorflow/core/common_runtime/device_factory.h +++ b/tensorflow/core/common_runtime/device_factory.h @@ -12,140 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_ -#include -#include - -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { - -class Device; -struct SessionOptions; - -class DeviceFactory { - public: - virtual ~DeviceFactory() {} - static void Register(const std::string& device_type, DeviceFactory* factory, - int priority); - static DeviceFactory* GetFactory(const std::string& device_type); - - // Append to "*devices" all suitable devices, respecting - // any device type specific properties/counts listed in "options". - // - // CPU devices are added first. - static Status AddDevices(const SessionOptions& options, - const std::string& name_prefix, - std::vector>* devices); - - // Helper for tests. Create a single device of type "type". The - // returned device is always numbered zero, so if creating multiple - // devices of the same type, supply distinct name_prefix arguments. - static std::unique_ptr NewDevice(const string& type, - const SessionOptions& options, - const string& name_prefix); - - // Iterate through all device factories and build a list of all of the - // possible physical devices. - // - // CPU is are added first. - static Status ListAllPhysicalDevices(std::vector* devices); - - // Get details for a specific device among all device factories. - // 'device_index' indexes into devices from ListAllPhysicalDevices. - static Status GetAnyDeviceDetails( - int device_index, std::unordered_map* details); - - // For a specific device factory list all possible physical devices. - virtual Status ListPhysicalDevices(std::vector* devices) = 0; - - // Get details for a specific device for a specific factory. Subclasses - // can store arbitrary device information in the map. 'device_index' indexes - // into devices from ListPhysicalDevices. - virtual Status GetDeviceDetails(int device_index, - std::unordered_map* details) { - return Status::OK(); - } - - // Most clients should call AddDevices() instead. - virtual Status CreateDevices( - const SessionOptions& options, const std::string& name_prefix, - std::vector>* devices) = 0; - - // Return the device priority number for a "device_type" string. - // - // Higher number implies higher priority. 
- // - // In standard TensorFlow distributions, GPU device types are - // preferred over CPU, and by default, custom devices that don't set - // a custom priority during registration will be prioritized lower - // than CPU. Custom devices that want a higher priority can set the - // 'priority' field when registering their device to something - // higher than the packaged devices. See calls to - // REGISTER_LOCAL_DEVICE_FACTORY to see the existing priorities used - // for built-in devices. - static int32 DevicePriority(const std::string& device_type); -}; - -namespace dfactory { - -template -class Registrar { - public: - // Multiple registrations for the same device type with different priorities - // are allowed. Priorities are used in two different ways: - // - // 1) When choosing which factory (that is, which device - // implementation) to use for a specific 'device_type', the - // factory registered with the highest priority will be chosen. - // For example, if there are two registrations: - // - // Registrar("CPU", 125); - // Registrar("CPU", 150); - // - // then CPUFactory2 will be chosen when - // DeviceFactory::GetFactory("CPU") is called. - // - // 2) When choosing which 'device_type' is preferred over other - // DeviceTypes in a DeviceSet, the ordering is determined - // by the 'priority' set during registration. For example, if there - // are two registrations: - // - // Registrar("CPU", 100); - // Registrar("GPU", 200); - // - // then DeviceType("GPU") will be prioritized higher than - // DeviceType("CPU"). - // - // The default priority values for built-in devices is: - // GPU: 210 - // GPUCompatibleCPU: 70 - // ThreadPoolDevice: 60 - // Default: 50 - explicit Registrar(const std::string& device_type, int priority = 50) { - DeviceFactory::Register(device_type, new Factory(), priority); - } -}; - -} // namespace dfactory - -#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...) \ - INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \ - __COUNTER__, ##__VA_ARGS__) - -#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \ - ctr, ...) 
\ - static ::tensorflow::dfactory::Registrar \ - INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr)(device_type, \ - ##__VA_ARGS__) - -// __COUNTER__ must go through another macro to be properly expanded -#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr) ___##ctr##__object_ - -} // namespace tensorflow +#include "tensorflow/core/framework/device_factory.h" #endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_ diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index a049de9a188..91f60d6ebe2 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -451,7 +451,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", - "//tensorflow/core:mkl_graph_util", + "//tensorflow/core/graph:mkl_graph_util", ], alwayslink = 1, ) diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 7362b4738b8..757ac1f7783 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -573,7 +573,7 @@ Status EagerContext::FindFunctionOpData( return func_lib_def_.LookUp(name, op_data); } -const FunctionDef* EagerContext::FindFunctionDef(const string& name) { +const FunctionDef* EagerContext::FindFunctionDef(const string& name) const { return func_lib_def_.Find(name); } diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index d0f1a6c4e78..f48da696d48 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -79,21 +79,6 @@ namespace eager { class RemoteMgr; } // namespace eager -// LINT.IfChange -// Note: Keep in sync with exported copy of enum in eager/c_api.h. -enum ContextDevicePlacementPolicy { - // Running operations with input tensors on the wrong device will fail. - DEVICE_PLACEMENT_EXPLICIT = 0, - // Copy the tensor to the right device but log a warning. - DEVICE_PLACEMENT_WARN = 1, - // Silently copy the tensor, which has a performance cost since the operation - // will be blocked till the copy completes. This is the default policy. - DEVICE_PLACEMENT_SILENT = 2, - // Placement policy which silently copies int32 tensors but not other dtypes. - DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, -}; -// LINT.ThenChange(//tensorflow/c/eager/c_api.h) - class RunMetadataListener { public: virtual ~RunMetadataListener() {} @@ -186,7 +171,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { std::function)>* runner() { return &runner_; } // Specify a executor for this thread. - void SetExecutorForThread(EagerExecutor* executor); + void SetExecutorForThread(EagerExecutor* executor) override; const std::shared_ptr> prioritized_device_type_list() const { @@ -195,15 +180,16 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { } // Clear pending nodes in thread executors and kernel caches. - void ClearCachesAndThreadExecutors(); + void ClearCachesAndThreadExecutors() override; // Clear pending nodes in default executor and kernel caches. void ClearCachesAndDefaultExecutor(); // Sets the device placement policy for the current thread. - void SetThreadLocalDevicePlacementPolicy(ContextDevicePlacementPolicy policy); + void SetThreadLocalDevicePlacementPolicy( + ContextDevicePlacementPolicy policy) override; // Returns the device placement policy for the current thread. 
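The device.h and device_factory.h hunks above reduce both files to forwarding headers, with the real declarations now living under tensorflow/core/framework. A generic sketch of that pattern, with hypothetical file names, so existing include paths keep compiling while the implementation moves:

```cpp
// old_location/widget.h -- reduced to a forwarding header; the declarations
// formerly defined here now live in new_location/widget.h.
#ifndef OLD_LOCATION_WIDGET_H_
#define OLD_LOCATION_WIDGET_H_

#include "new_location/widget.h"  // Real declarations live here now.

#endif  // OLD_LOCATION_WIDGET_H_
```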
- ContextDevicePlacementPolicy GetDevicePlacementPolicy() const; + ContextDevicePlacementPolicy GetDevicePlacementPolicy() const override; // Select an appropriate device for an operation. // @@ -227,16 +213,19 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { Status FindFunctionOpData(const string& name, const tensorflow::OpRegistrationData** op_data); - const FunctionDef* FindFunctionDef(const string& name); + const FunctionDef* FindFunctionDef(const string& name) const override; Device* HostCPU() const { return host_cpu_device_; } Device* CanonicalDevice(Device* d) const { return HostCPU() == d ? nullptr : d; } + const DeviceNameUtils::ParsedName& HostCPUParsedName() const override { + return HostCPU()->parsed_name(); + } GraphCollector* GetGraphCollector() { return &graph_collector_; } - EagerExecutor& Executor(); + EagerExecutor& Executor() override; // Add the given `fdef` to the local FunctionLibraryDefinition. And add an // entry to the KernelAndDevice cache for it if it's not exist. @@ -267,9 +256,13 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel); bool LogDevicePlacement() const { return log_device_placement_; } - void SetLogDevicePlacement(bool enable) { log_device_placement_ = enable; } + void SetLogDevicePlacement(bool enable) override { + log_device_placement_ = enable; + } bool AllowSoftPlacement() const { return allow_soft_placement_; } - void SetAllowSoftPlacement(bool enable) { allow_soft_placement_ = enable; } + void SetAllowSoftPlacement(bool enable) override { + allow_soft_placement_ = enable; + } bool LogMemory() const { return log_memory_; } Rendezvous* GetRendezvous() const { return rendezvous_; } @@ -316,7 +309,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { // TODO(apassos) clean up RunMetadata storage. mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; } bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_); - void SetShouldStoreGraphs(bool value); + void SetShouldStoreGraphs(bool value) override; RunMetadata* RunMetadataProto() { return &run_metadata_; } void ClearRunMetadata() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_); diff --git a/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h b/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h index daa9e8f52fb..133799e6ab3 100644 --- a/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h +++ b/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h @@ -55,7 +55,8 @@ class EagerOpRewriteRegistry { // Phases at which the Eager op rewrite pass should run. // For now we only added PRE_EXECUTION. Expand as needed. enum Phase { - PRE_EXECUTION = 0 // right before executing an eager op + PRE_EXECUTION = 0, // right before executing an eager op + POST_PLACEMENT = 1 // after device placement }; // Add a rewrite pass to the registry. @@ -70,7 +71,7 @@ class EagerOpRewriteRegistry { static EagerOpRewriteRegistry* Global(); private: - static constexpr int32 kNumPhases = 1; + static constexpr int32 kNumPhases = 2; // Holds all the registered Eager op rewrites. 
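The registry change above adds a second rewrite phase (POST_PLACEMENT) and bumps kNumPhases to 2. A self-contained sketch of a phase-indexed registry of this shape, using stand-in types rather than the TensorFlow classes, to show how growing the enum and the backing array is all that is needed:

```cpp
#include <array>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

enum Phase { PRE_EXECUTION = 0, POST_PLACEMENT = 1 };
constexpr int kNumPhases = 2;

class RewriteRegistry {
 public:
  using Rewrite = std::function<void(std::string&)>;

  void Register(Phase phase, Rewrite r) {
    rewrites_[phase].push_back(std::move(r));
  }

  void Run(Phase phase, std::string& op) {
    for (auto& r : rewrites_[phase]) r(op);
  }

 private:
  // One slot per phase; adding a phase only grows this array.
  std::array<std::vector<Rewrite>, kNumPhases> rewrites_;
};

int main() {
  RewriteRegistry registry;
  registry.Register(POST_PLACEMENT, [](std::string& op) {
    op += " (rewritten after placement)";
  });
  std::string op = "MatMul";
  registry.Run(POST_PLACEMENT, op);
  std::cout << op << '\n';
  return 0;
}
```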
std::array, kNumPhases> rewrites_; }; diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 8b3126fc186..5102681934e 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -51,6 +51,7 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/core/util/device_name_utils.h" #if !defined(IS_MOBILE_PLATFORM) #include "tensorflow/core/distributed_runtime/eager/eager_client.h" @@ -337,6 +338,30 @@ void AppendTensorShapeToFingerprint(const PartialTensorShape& shape, } } +Status GetFuncAttr(const EagerOperation* op, const EagerContext& ctx, + const char* attr_name, bool* value) { + Status status = op->Attrs().Get(attr_name, value); + if (status.ok()) { + DVLOG(2) << "Caller explicitly specifies " + << (attr_name ? "=true " : "=false, ") << op->DebugString(); + return Status::OK(); + } + + const FunctionDef* function_def = + ctx.pflr()->GetFunctionLibraryDefinition()->Find(op->Name()); + if (function_def == nullptr) { + return errors::NotFound("Failed to find function '", op->Name(), "'"); + } + + status = GetNodeAttr(AttrSlice(&function_def->attr()), attr_name, value); + if (status.ok()) { + DVLOG(2) << "Function definition explicitly specifies " + << (attr_name ? "=true" : "=false"); + return Status::OK(); + } + return status; +} + Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx, bool* compile_with_xla) { if (!op->is_function()) { @@ -352,27 +377,8 @@ Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx, return Status::OK(); } - // Does node have an explicit request to compile or not? - Status status = op->Attrs().Get(kXlaMustCompileAttr, compile_with_xla); + Status status = GetFuncAttr(op, ctx, kXlaMustCompileAttr, compile_with_xla); if (status.ok()) { - DVLOG(2) << "Caller explicitly requested " - << (*compile_with_xla ? "" : "not ") - << "to compile with XLA: " << op->DebugString(); - return Status::OK(); - } - - // Does FunctionDef have an explicit request to compile or not? - const FunctionDef* function_def = - ctx.pflr()->GetFunctionLibraryDefinition()->Find(op->Name()); - if (function_def == nullptr) { - return errors::NotFound("Failed to find function '", op->Name(), "'"); - } - - status = GetNodeAttr(AttrSlice(&function_def->attr()), kXlaMustCompileAttr, - compile_with_xla); - if (status.ok()) { - DVLOG(2) << "Function definition explicitly specifies " - << (*compile_with_xla ? 
"" : "not ") << "to compile with XLA"; return Status::OK(); } @@ -488,6 +494,7 @@ Status GetOrCreateKernelAndDevice( DVLOG(2) << "Creating new kernel for " << op->Name() << " on device " << DeviceNameOrUnspecified(op->Device()); bool run_function_with_flr = false; + bool function_outputs_on_op_device = false; if (op->is_function()) { bool compile_with_xla; TF_RETURN_IF_ERROR(MustCompileWithXLA(op, ctx, &compile_with_xla)); @@ -499,6 +506,8 @@ Status GetOrCreateKernelAndDevice( } else { run_function_with_flr = true; } + GetFuncAttr(op, ctx, kOutputsOnOpDevice, &function_outputs_on_op_device) + .IgnoreError(); } const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef(); @@ -548,6 +557,7 @@ Status GetOrCreateKernelAndDevice( std::move(composite_devices), std::move(input_resource_variable_dtypes_and_shapes), runner, ctx.GetCollectiveExecutorHandle(), ctx.HostCPU(), op->Name(), + function_outputs_on_op_device, [&ctx](const int64 step_id) { return ctx.CreateRendezvous(step_id); }, get_op_id)); } else { @@ -721,8 +731,23 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, TF_RETURN_IF_ERROR(executor.status()); core::RefCountPtr kernel; - TF_RETURN_IF_ERROR( - GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel)); + auto status = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel); + + // Run all the registered rewrite pass after the placement, regardless whether + // the placement is successful or not. The passes can either create new ops + // (without placement) or update some fields of the input op. + std::unique_ptr out_op; + TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite( + EagerOpRewriteRegistry::POST_PLACEMENT, op, &out_op)); + if (out_op) { + op = out_op.get(); + // If the out op doesn't have device, either because it is a new op or + // the op wasn't placed successfully, then we do the placement again. 
+ if (op->Device() == kVariantDeviceNull) { + status = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel); + } + } + if (!status.ok()) return status; int num_outputs = kernel->num_outputs(); TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel)); diff --git a/tensorflow/core/common_runtime/eager/execute_node_test.cc b/tensorflow/core/common_runtime/eager/execute_node_test.cc index e2d7a7a2da1..5f0cd660f12 100644 --- a/tensorflow/core/common_runtime/eager/execute_node_test.cc +++ b/tensorflow/core/common_runtime/eager/execute_node_test.cc @@ -37,7 +37,7 @@ class TestKernelAndDeviceFunc final : public KernelAndDeviceFunc { /*flr=*/nullptr, /*pflr=*/nullptr, /*input_devices=*/{}, /*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{}, /*runner=*/nullptr, /*collective_executor=*/nullptr, - host_cpu_device, /*name=*/"", + host_cpu_device, /*name=*/"", /*outputs_on_op_device=*/false, /*rendezvous_creator=*/nullptr, /*get_op_id=*/nullptr), test_input_devices_(std::move(input_devices)) {} diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index e0e87c36f03..d2327ff9592 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -160,6 +160,17 @@ Status KernelAndDeviceFunc::InstantiateFunc(const Context& ctx, } options.composite_devices = composite_devices_; options.input_resource_dtypes_and_shapes = input_resource_dtypes_and_shapes_; + if (outputs_on_op_device_) { + const FunctionLibraryDefinition* lib_def = + pflr_->GetFunctionLibraryDefinition(); + const FunctionDef* fdef = lib_def->Find(ndef.op()); + if (fdef == nullptr) { + return errors::InvalidArgument("Failed to find function ", ndef.op()); + } + for (int i = 0; i < fdef->signature().output_arg_size(); ++i) { + options.output_devices.push_back(options.target); + } + } const auto& it = ndef.attr().find("executor_type"); if (it != ndef.attr().end()) { diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index 0a765510d7b..f48452aa46b 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -47,6 +47,8 @@ limitations under the License. namespace tensorflow { +static constexpr const char* const kOutputsOnOpDevice = "_OutputsOnOpDevice"; + class ProcessFunctionLibraryRuntime; class FunctionLibraryRuntime; @@ -270,12 +272,14 @@ class KernelAndDeviceFunc : public KernelAndDevice { std::function)>* runner, std::unique_ptr collective_executor, Device* host_cpu_device, const string& name, + const bool outputs_on_op_device, std::function rendezvous_creator, std::function get_op_id) : KernelAndDevice(flr, runner, std::move(collective_executor), host_cpu_device), pflr_(pflr), handle_(kInvalidHandle), + outputs_on_op_device_(outputs_on_op_device), input_devices_(std::move(input_devices)), composite_devices_(std::move(composite_devices)), input_resource_dtypes_and_shapes_( @@ -336,6 +340,11 @@ class KernelAndDeviceFunc : public KernelAndDevice { FunctionLibraryRuntime::Handle handle_; // Indicates whether the function needs to execute cross process. bool is_cross_process_; + + // If true, function outputs are explicitly assigned to the default device; + // if false, the output devices are inferred by pflr_. + bool outputs_on_op_device_; + // CPU devices are null. 
Resource handles' devices are actual backing // devices. std::vector output_devices_; diff --git a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc index 4f89c3b6b60..a18301231cf 100644 --- a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc +++ b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc @@ -25,11 +25,15 @@ namespace tensorflow { class EagerOpRewriteTest : public ::testing::Test { public: - EagerOpRewriteTest() {} + EagerOpRewriteTest() : eager_ctx_(nullptr) {} + ~EagerOpRewriteTest() { + if (eager_ctx_) { + eager_ctx_->Unref(); + } + } // Creates a new op to be used as input to MKL eager rewrite. - static std::unique_ptr CreateOp( - const string op_name) { + std::unique_ptr CreateOp(const string op_name) { std::unique_ptr device_mgr = absl::make_unique(DeviceFactory::NewDevice( "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0")); @@ -37,22 +41,21 @@ class EagerOpRewriteTest : public ::testing::Test { bool lazy_remote_tensor_copy = false; tensorflow::Rendezvous* rendezvous = new tensorflow::IntraProcessRendezvous(device_mgr.get()); - tensorflow::EagerContext* eager_ctx = new tensorflow::EagerContext( + eager_ctx_ = new tensorflow::EagerContext( SessionOptions(), tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, async, lazy_remote_tensor_copy, device_mgr.get(), false, rendezvous); EagerExecutor executor_(false); std::unique_ptr op( - new tensorflow::EagerOperation(eager_ctx)); + new tensorflow::EagerOperation(eager_ctx_)); EXPECT_EQ(Status::OK(), op.get()->Reset(op_name.c_str(), nullptr, false, &executor_)); - eager_ctx->Unref(); return op; } // Validates the result of MKL eager rewrite. - static void CheckRewrite(EagerOperation* orig_op, string expected_op_name) { + void CheckRewrite(EagerOperation* orig_op, string expected_op_name) { std::unique_ptr out_op; EagerOpRewriteRegistry::Global()->RunRewrite( EagerOpRewriteRegistry::PRE_EXECUTION, orig_op, &out_op); @@ -65,6 +68,9 @@ class EagerOpRewriteTest : public ::testing::Test { EXPECT_EQ(actual_op_name, expected_op_name); } + + protected: + tensorflow::EagerContext* eager_ctx_; }; #define CONV_OPS \ diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index 620685ea3c1..0b17bf7e3e6 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -980,7 +980,7 @@ void TensorHandle::Poison(Status status, const Device* d) { Status TensorHandle::CopyToDevice(const EagerContext& ctx, tensorflow::Device* d, - tensorflow::Tensor* output) { + tensorflow::Tensor* output) const { tensorflow::Device* dstd = (d == nullptr) ? ctx.HostCPU() : d; tensorflow::Device* srcd = absl::get(DeviceOrHostCPU(ctx)); const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr; diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index eed31b79b0f..f54ebd45889 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -91,6 +91,10 @@ class TensorHandle : public ImmediateExecutionTensorHandle { // Create a handle which packs the given handles of the same dtype and shape. // If handles are on different devices, assign the packed handle to a // CompositeDevice. 
+ // + // The new tensor handle shares ownership of the given handle: their reference + // count will be increased by one after a call to `CreatePackedHandle`. + // TODO(b/170414377): Use `TensorHandlePtr` instead. static Status CreatePackedHandle(std::vector&& handles, const tensorflow::DataType dtype, const tensorflow::TensorShape& shape, @@ -221,8 +225,9 @@ class TensorHandle : public ImmediateExecutionTensorHandle { void Poison(Status status, const Device* d); // TODO(b/154282629): Consider moving it to EagerContext. + // Copies to the tensor on the given device `d`, or to host iff `d` is null. Status CopyToDevice(const EagerContext& ctx, tensorflow::Device* d, - tensorflow::Tensor* output); + tensorflow::Tensor* output) const; Status InferenceShape( shape_inference::InferenceContext* const inference_context, @@ -336,6 +341,11 @@ class TensorHandle : public ImmediateExecutionTensorHandle { // shape. class PackedTensorHandleData { public: + // Initialize handle data from list of tensor handles. + // Ownership of the tensor handles is shared between the + // `PackedTensorHandleData` and the caller (the reference count for the + // given handles is incremented). + // TODO(b/170414377): Use `TensorHandlePtr` instead. PackedTensorHandleData(std::vector&& handles, const TensorShape& shape); @@ -357,6 +367,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle { Status ExtractPackedHandle(const int index, TensorHandle** handle) const; private: + // TODO(b/170414377): Use `TensorHandlePtr` instead. const std::vector handles_; const TensorShape shape_; diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 74305228e6f..cd3e73a30b9 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/core/framework/control_flow.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/log_memory.h" +#include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_segment.h" @@ -323,6 +324,12 @@ class ExecutorState { // REQUIRES: `!ready->empty()`. void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready); + // A wrapper for runner_ to keep track of the pending queue length. Op + // execution should dispatch work using this function instead of using runner_ + // directly. + template + void RunTask(Closure&& c); + // Clean up when this executor is done. void Finish(); void ScheduleFinish(); @@ -423,6 +430,30 @@ ExecutorState::~ExecutorState() { delete slice_reader_cache_; } +template +template +void ExecutorState::RunTask(Closure&& c) { + // Align the atomic variables at 64 bytes to avoid false-sharing, assuming the + // cacheline size is 64 bytes or smaller. + alignas(64) static std::atomic num_enqueue_ops{0}; + alignas(64) static std::atomic num_dequeue_ops{0}; + + auto n_enqueues = num_enqueue_ops.fetch_add(1, std::memory_order_relaxed); + // Sample the queue length on every 16 enqueue operations. This amortizes the + // cost of metric updates across 16 operations. + if (n_enqueues % 16 == 0) { + auto n_dequeues = num_dequeue_ops.load(std::memory_order_relaxed); + metrics::UpdateGraphPendingQueueLength(n_enqueues - n_dequeues); + } + + // mutable is needed because std::forward in the lambda body may move + // the Closure `c`. 
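The RunTask wrapper above approximates the executor's pending queue length from a pair of cache-line-aligned counters, sampling only every 16th enqueue to keep the metric cheap. A standalone sketch of the same bookkeeping, with a print statement standing in for the metrics call and illustrative names throughout:

```cpp
#include <atomic>
#include <functional>
#include <iostream>
#include <utility>

// Align the counters to 64 bytes to avoid false sharing between them.
alignas(64) static std::atomic<int64_t> num_enqueue_ops{0};
alignas(64) static std::atomic<int64_t> num_dequeue_ops{0};

void ReportPendingQueueLength(int64_t length) {
  std::cout << "pending ~= " << length << '\n';  // stand-in for a metrics call
}

template <typename Runner, typename Closure>
void RunTask(Runner&& runner, Closure&& c) {
  auto n_enqueues = num_enqueue_ops.fetch_add(1, std::memory_order_relaxed);
  // Sample the queue length on every 16th enqueue to amortize the update cost.
  if (n_enqueues % 16 == 0) {
    auto n_dequeues = num_dequeue_ops.load(std::memory_order_relaxed);
    ReportPendingQueueLength(n_enqueues - n_dequeues);
  }
  // mutable so the forwarded closure can be invoked inside the lambda.
  runner([c = std::forward<Closure>(c)]() mutable {
    num_dequeue_ops.fetch_add(1, std::memory_order_relaxed);
    c();
  });
}

int main() {
  auto inline_runner = [](std::function<void()> fn) { fn(); };
  for (int i = 0; i < 3; ++i) {
    RunTask(inline_runner, [i] { std::cout << "task " << i << " ran\n"; });
  }
  return 0;
}
```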
+ runner_([c = std::forward(c)]() mutable { + num_dequeue_ops.fetch_add(1, std::memory_order_relaxed); + std::forward(c)(); + }); +} + template void ExecutorState::RunAsync(Executor::DoneCallback done) { TaggedNodeSeq ready; @@ -1116,7 +1147,7 @@ void ExecutorState::ScheduleReady( // regardless of the `runner_` implementation, all kernels will run // sequentially on the same thread, and thread wakeup overhead and // executor mutex contention will be minimized. - runner_([this, ready = std::move(*ready), scheduled_nsec]() { + RunTask([this, ready = std::move(*ready), scheduled_nsec]() { for (auto& tagged_node : ready) { Process(tagged_node, scheduled_nsec); } @@ -1131,7 +1162,7 @@ void ExecutorState::ScheduleReady( if (inline_ready == nullptr) { // Schedule to run all the ready ops in thread pool. for (auto& tagged_node : *ready) { - runner_([=]() { Process(tagged_node, scheduled_nsec); }); + RunTask([=]() { Process(tagged_node, scheduled_nsec); }); } } else { for (auto& tagged_node : *ready) { @@ -1143,7 +1174,7 @@ void ExecutorState::ScheduleReady( if (curr_expensive_node) { // Dispatch to another thread since there is plenty of work to // do for this thread. - runner_(std::bind(&ExecutorState::Process, this, + RunTask(std::bind(&ExecutorState::Process, this, *curr_expensive_node, scheduled_nsec)); } curr_expensive_node = &tagged_node; @@ -1156,7 +1187,7 @@ void ExecutorState::ScheduleReady( } else { // There are inline nodes to run already. We dispatch this expensive // node to other thread. - runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node, + RunTask(std::bind(&ExecutorState::Process, this, *curr_expensive_node, scheduled_nsec)); } } diff --git a/tensorflow/core/common_runtime/gpu/BUILD b/tensorflow/core/common_runtime/gpu/BUILD index efbf30d4e65..d500944c3e5 100644 --- a/tensorflow/core/common_runtime/gpu/BUILD +++ b/tensorflow/core/common_runtime/gpu/BUILD @@ -58,7 +58,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", ], ) @@ -153,7 +153,6 @@ tf_cuda_library( ":gpu_id_impl", ":gpu_init_impl", ":gpu_lib", - "//tensorflow/core:core_cpu_impl", "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -161,7 +160,8 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:stream_executor", + "//tensorflow/core/common_runtime:core_cpu_impl", + "//tensorflow/core/platform:stream_executor", "//tensorflow/core/platform:tensor_float_32_utils", "//tensorflow/core/profiler/lib:annotated_traceme", "//tensorflow/core/profiler/lib:scoped_annotation", @@ -181,7 +181,7 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", "//third_party/eigen3", ] + if_static([":gpu_runtime_impl"]), ) @@ -219,8 +219,8 @@ tf_cuda_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor", "//tensorflow/core/framework:allocator", + "//tensorflow/core/platform:stream_executor", ], ) @@ -234,7 +234,7 @@ tf_cuda_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", ] + if_static( 
[":gpu_init_impl"], ), @@ -255,7 +255,7 @@ tf_cuda_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", ], alwayslink = 1, ) diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 57af898ecd2..17e0ee4da1f 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -601,7 +601,11 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { // Based on the semantics of Device::Sync this call should wait for // all streams not just the current one. -Status BaseGPUDevice::Sync() { return GPUUtil::SyncAll(this); } +Status BaseGPUDevice::Sync() { + return tensorflow_gpu_device_info() + ->stream->parent() + ->BlockHostUntilAllStreamsAreDone(); +} void BaseGPUDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc index ea38349d61c..e1d81a18464 100644 --- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc +++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc @@ -80,20 +80,20 @@ Status HierarchicalTreeBroadcaster::InitializeCollectiveParams( CHECK_EQ(col_params->instance.impl_details.collective_name, "HierarchicalTreeBroadcast"); const string& device_name = - col_params->instance.device_names[col_params->default_rank]; + col_params->group.device_names[col_params->default_rank]; // Start by counting the devices in each task. // Precondition: device_names must be sorted so that all devices in // the same task are adjacent. VLOG(2) << "Sorted task names: " - << absl::StrJoin(col_params->instance.task_names, ", "); + << absl::StrJoin(col_params->group.task_names, ", "); std::vector dev_per_task; - const string* prior_task_name = &col_params->instance.task_names[0]; + const string* prior_task_name = &col_params->group.task_names[0]; int dev_count = 1; for (int di = 1; di < col_params->group.group_size; ++di) { - if (col_params->instance.task_names[di] != *prior_task_name) { + if (col_params->group.task_names[di] != *prior_task_name) { dev_per_task.push_back(dev_count); dev_count = 1; - prior_task_name = &col_params->instance.task_names[di]; + prior_task_name = &col_params->group.task_names[di]; } else { ++dev_count; } @@ -135,14 +135,13 @@ Status HierarchicalTreeBroadcaster::InitializeCollectiveParams( if (source_task == ti) { // Source device belongs to this task. perm.push_back(col_params->source_rank); - participate = - col_params->instance.device_names[col_params->source_rank] == - device_name; + participate = col_params->group.device_names[col_params->source_rank] == + device_name; } else { // Source does not belong to this task, choose dev 0. 
perm.push_back(device_count); participate = - col_params->instance.device_names[device_count] == device_name; + col_params->group.device_names[device_count] == device_name; } if (participate) col_params->subdiv_rank.push_back(ti); device_count += dev_per_task[ti]; @@ -165,7 +164,7 @@ Status HierarchicalTreeBroadcaster::InitializeCollectiveParams( int subdiv_source = 0; for (int di = 0; di < dev_per_task[ti]; di++) { perm.push_back(abs_di); - if (col_params->instance.device_names[abs_di] == device_name) { + if (col_params->group.device_names[abs_di] == device_name) { participate = true; col_params->subdiv_rank.push_back(di); } @@ -417,11 +416,11 @@ void HierarchicalTreeBroadcaster::DispatchSend(int subdiv, int dst_rank, col_params_->instance.impl_details.subdiv_permutations[subdiv][dst_rank]; VLOG(3) << "DispatchSend " << send_buf_key << " from_device " << col_ctx_->device_name << " to_device " - << col_params_->instance.device_names[dst_idx] << " subdiv=" << subdiv + << col_params_->group.device_names[dst_idx] << " subdiv=" << subdiv << " dst_rank=" << dst_rank << " dst_idx=" << dst_idx; col_ctx_->col_exec->remote_access()->PostToPeer( - col_params_->instance.device_names[dst_idx], - col_params_->instance.task_names[dst_idx], send_buf_key, col_ctx_->device, + col_params_->group.device_names[dst_idx], + col_params_->group.task_names[dst_idx], send_buf_key, col_ctx_->device, col_ctx_->op_ctx->op_device_context(), col_ctx_->op_ctx->output_alloc_attr(0), src_tensor, col_ctx_->device_locality, done); @@ -435,12 +434,12 @@ void HierarchicalTreeBroadcaster::DispatchRecv(int subdiv, int src_rank, int src_idx = col_params_->instance.impl_details.subdiv_permutations[subdiv][src_rank]; VLOG(3) << "DispatchRecv " << recv_buf_key << " from_device " - << col_params_->instance.device_names[src_idx] << " to_device " + << col_params_->group.device_names[src_idx] << " to_device " << col_ctx_->device_name << " subdiv=" << subdiv << " src_rank=" << src_rank << " src_idx=" << src_idx; col_ctx_->col_exec->remote_access()->RecvFromPeer( - col_params_->instance.device_names[src_idx], - col_params_->instance.task_names[src_idx], + col_params_->group.device_names[src_idx], + col_params_->group.task_names[src_idx], col_params_->task.is_local[src_idx], recv_buf_key, col_ctx_->device, col_ctx_->op_ctx->op_device_context(), col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor, diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc index 41d88595b77..112e6c9881c 100644 --- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc +++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc @@ -328,8 +328,8 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { dev_name = strings::StrCat(task_name, "/device:CPU:", di); } VLOG(2) << "dev=" << dev_name; - col_params_.instance.device_names.push_back(dev_name); - col_params_.instance.task_names.push_back(task_name); + col_params_.group.device_names.push_back(dev_name); + col_params_.group.task_names.push_back(task_name); col_params_.task.is_local.push_back(true); } } @@ -337,7 +337,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { for (int di = 0; di < num_devices_per_worker; di++) { int default_rank = wi * num_devices_per_worker + di; instances_.push_back(new DeviceInstance( - default_rank, col_params_.instance.device_names[default_rank], + default_rank, col_params_.group.device_names[default_rank], device_type, this)); 
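Most of the churn in this and the following broadcaster/permuter/ring hunks is one refactor: `device_names` and `task_names` move from the per-instance collective params to the shared group params, so tests can copy the whole group in a single assignment. A simplified sketch of the field move; this is not the real `CollectiveParams` definition, which carries many more fields:

```cpp
#include <string>
#include <vector>

struct CollGroupParams {
  int group_key = 0;
  int group_size = 0;
  std::string device_type;
  std::vector<std::string> device_names;  // moved here from the instance
  std::vector<std::string> task_names;    // moved here from the instance
};

struct CollInstanceParams {
  int instance_key = 0;
  // device_names / task_names no longer live here.
};

struct CollectiveParams {
  CollGroupParams group;
  CollInstanceParams instance;
  int default_rank = -1;
};

int main() {
  CollectiveParams parent;
  parent.group.device_names = {"/job:worker/replica:0/task:0/device:CPU:0"};
  parent.group.task_names = {"/job:worker/replica:0/task:0"};

  // Tests in this patch now copy the whole group in one assignment instead of
  // copying group_key, device_type, group_size, device_names, task_names
  // field by field.
  CollectiveParams child;
  child.group = parent.group;
  return child.group.device_names.size() == 1 ? 0 : 1;
}
```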
} } @@ -539,8 +539,8 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { string task_name = strings::StrCat("/job:worker/replica:0/task:", ti); for (int di = 0; di < num_gpus; di++) { string dev_name = strings::StrCat(task_name, "/device:GPU:", di); - cp->instance.task_names.push_back(task_name); - cp->instance.device_names.push_back(dev_name); + cp->group.task_names.push_back(task_name); + cp->group.device_names.push_back(dev_name); } } } @@ -557,22 +557,16 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_)); col_params_.name = parent_->col_params_.name; col_params_.instance.data_type = parent_->col_params_.instance.data_type; - col_params_.group.group_key = parent_->col_params_.group.group_key; + col_params_.group = parent_->col_params_.group; col_params_.instance.instance_key = parent_->col_params_.instance.instance_key; - col_params_.group.device_type = parent_->col_params_.group.device_type; - col_params_.group.group_size = parent_->col_params_.group.group_size; - col_params_.instance.device_names = - parent_->col_params_.instance.device_names; - col_params_.instance.task_names = - parent_->col_params_.instance.task_names; col_params_.task.is_local = parent_->col_params_.task.is_local; col_params_.instance.impl_details.subdiv_permutations = parent_->col_params_.instance.impl_details.subdiv_permutations; col_params_.subdiv_rank = parent_->col_params_.subdiv_rank; int group_size = col_params_.group.group_size; - CHECK_EQ(group_size, col_params_.instance.device_names.size()); + CHECK_EQ(group_size, col_params_.group.device_names.size()); // Default rank is order in device_names. col_params_.default_rank = rank; @@ -789,8 +783,8 @@ TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) { string task_name = strings::StrCat("/job:worker/replica:0/task:", ti); for (int di = 0; di < dev_per_task[ti]; di++) { string dev_name = strings::StrCat(task_name, "/device:GPU:", di); - cp.instance.task_names.push_back(task_name); - cp.instance.device_names.push_back(dev_name); + cp.group.task_names.push_back(task_name); + cp.group.device_names.push_back(dev_name); cp.group.group_size++; } } diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc index 5cde4f9049c..ff010ad8a63 100644 --- a/tensorflow/core/common_runtime/lower_if_op.cc +++ b/tensorflow/core/common_runtime/lower_if_op.cc @@ -65,6 +65,11 @@ class CondBuilder { // Adds input to both the then and else nodes from src:src_output. Status AddInput(Node* src, int src_output); + // Finalizes the node described by `node_builder`. If `coloc_attr_` is not + // nullptr, adds the colocation attr to the node before finalizing it. + Status SetColocationAndFinalize(NodeBuilder node_builder, Graph* graph, + Node** created_node); + // The merged outputs of the then and else nodes. std::vector outputs_; @@ -72,6 +77,10 @@ class CondBuilder { Node* control_predecessor_; // The original If op. Node* if_op_; + // The colocation attr on the original If op. If it exists, control flow nodes + // created in the lowering (except the data Switch nodes) will inherit this + // attribute. + const AttrValue* coloc_attr_; // The node with the same name as the original If op: // (a) IdentityN node with same outputs if 'keep_node_fetchable_ == true' // and if the original If op had non-zero data outputs. 
@@ -102,9 +111,10 @@ class CondBuilder { }; CondBuilder::CondBuilder(Node* if_op, const NameAttrList& then_fn, - const NameAttrList& else_fn, - bool keep_node_fetchable, Graph* graph) + const NameAttrList& else_fn, bool keep_node_fetchable, + Graph* graph) : if_op_(if_op), + coloc_attr_(if_op_->attrs().Find(kColocationAttrName)), graph_(graph), name_(if_op->name()), keep_node_fetchable_(keep_node_fetchable), @@ -126,27 +136,39 @@ CondBuilder::CondBuilder(Node* if_op, const NameAttrList& then_fn, } } +Status CondBuilder::SetColocationAndFinalize(NodeBuilder node_builder, + Graph* graph, + Node** created_node) { + if (coloc_attr_ != nullptr) { + node_builder = node_builder.Attr(kColocationAttrName, *coloc_attr_); + } + return node_builder.Finalize(graph, created_node); +} + Status CondBuilder::CreatePivotNodes() { // Construct the basic cond body (consisting of feeding in the predicate to // create pivot nodes). Node* switch_pred; - TF_RETURN_IF_ERROR(NodeBuilder(NewName("switch_pred"), "Switch", - graph_->op_registry(), &debug_info_) - .Input(NodeOut(pred_)) - .Input(NodeOut(pred_)) - .Device(if_op_->requested_device()) - .Finalize(graph_, &switch_pred)); + TF_RETURN_IF_ERROR( + SetColocationAndFinalize(NodeBuilder(NewName("switch_pred"), "Switch", + graph_->op_registry(), &debug_info_) + .Input(NodeOut(pred_)) + .Input(NodeOut(pred_)) + .Device(if_op_->requested_device()), + graph_, &switch_pred)); control_predecessor_ = switch_pred; - TF_RETURN_IF_ERROR(NodeBuilder(NewName("pivot_f"), "Identity", - graph_->op_registry(), &debug_info_) - .Input(switch_pred, kElseBranch) - .Device(if_op_->requested_device()) - .Finalize(graph_, &pivot_f_)); - TF_RETURN_IF_ERROR(NodeBuilder(NewName("pivot_t"), "Identity", - graph_->op_registry(), &debug_info_) - .Input(switch_pred, kThenBranch) - .Device(if_op_->requested_device()) - .Finalize(graph_, &pivot_t_)); + TF_RETURN_IF_ERROR( + SetColocationAndFinalize(NodeBuilder(NewName("pivot_f"), "Identity", + graph_->op_registry(), &debug_info_) + .Input(switch_pred, kElseBranch) + .Device(if_op_->requested_device()), + graph_, &pivot_f_)); + TF_RETURN_IF_ERROR( + SetColocationAndFinalize(NodeBuilder(NewName("pivot_t"), "Identity", + graph_->op_registry(), &debug_info_) + .Input(switch_pred, kThenBranch) + .Device(if_op_->requested_device()), + graph_, &pivot_t_)); return Status::OK(); } @@ -160,19 +182,24 @@ Status CondBuilder::AddInput(Node* src, int src_output) { // Colocate the Switch node with the `src` node. // // This is to avoid unnecessary Host<->Device copies between src and the - // Switch node. This aligns with the implementation of legacy tf.cond in - // control_flow_ops.py. The legacy impl colocates the Switch with the - // input tensor which resets the device stack and forces the Switch to have - // the same device as the input node (if set) and sets the colocation _class - // attr. It also ignores the existing colocation constraints on the input node - // using colocate_with(ignore_existing=True). - TF_RETURN_IF_ERROR(NodeBuilder(NewName(src->name()), "Switch", - graph_->op_registry(), &debug_info) - .Input(src, src_output) - .Input(pred_) - .Device(src->requested_device()) - .Attr("_class", {src->name()}) - .Finalize(graph_, &input)); + // Switch node. + // + // NOTE(rachelim): Here, we don't use `CondBuilder::SetColocationAndFinalize`, + // and instead ignore the existing colocation stack. This is aligned with the + // legacy impl in control_flow_ops.py. 
The legacy impl colocates this Switch + // with the input tensor which resets the device stack and forces the Switch + // to have the same device as the input node (if set) and sets the colocation + // _class attr. It also ignores the existing colocation stack in the context + // by using colocate_with(ignore_existing=True). + TF_RETURN_IF_ERROR( + NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry(), + &debug_info) + .Input(src, src_output) + .Input(pred_) + .Device(src->requested_device()) + .Attr(kColocationAttrName, + {absl::StrCat(kColocationGroupPrefix, src->name())}) + .Finalize(graph_, &input)); then_call_builder_.Input(input, kThenBranch); else_call_builder_.Input(input, kElseBranch); return Status::OK(); @@ -198,6 +225,8 @@ Status CondBuilder::AddInputs() { Status CondBuilder::AddOutputs() { // Construct the then and else nodes. + // NOTE(rachelim): Here, we don't use `CondBuilder::SetColocationAndFinalize` + // because the colocation for branch nodes is applied in python. TF_RETURN_IF_ERROR(then_call_builder_.Finalize(graph_, &then_call_node_)); graph_->AddControlEdge(pivot_t_, then_call_node_); TF_RETURN_IF_ERROR(else_call_builder_.Finalize(graph_, &else_call_node_)); @@ -207,12 +236,12 @@ Status CondBuilder::AddOutputs() { std::vector merges(then_call_node_->num_outputs()); outputs_.resize(merges.size()); for (int i = 0; i < then_call_node_->num_outputs(); ++i) { - TF_RETURN_IF_ERROR( + TF_RETURN_IF_ERROR(SetColocationAndFinalize( NodeBuilder(NewName("output"), "Merge", graph_->op_registry(), &debug_info_) .Input({NodeOut(then_call_node_, i), NodeOut(else_call_node_, i)}) - .Device(if_op_->requested_device()) - .Finalize(graph_, &merges[i])); + .Device(if_op_->requested_device()), + graph_, &merges[i])); outputs_[i] = NodeOut(merges[i], 0); } @@ -226,12 +255,13 @@ Status CondBuilder::AddOutputs() { // // We will use this node to rewrite outgoing control edges from lowered 'If' // node. All data edges will read tensors directly from Merge nodes. 
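`SetColocationAndFinalize`, used throughout this file, is a small conditional-attribute helper: if the original `If` op carried a colocation constraint, stamp it onto the node being built, then finalize as usual. (As I understand it, `kColocationAttrName` and `kColocationGroupPrefix` expand to the `_class` attribute and the `loc:@` prefix.) A toy sketch of the pattern with stand-in builder types, not TensorFlow APIs:

```cpp
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Toy stand-ins for NodeBuilder/AttrValue, only to show the helper shape.
struct Node {
  std::string name;
  std::map<std::string, std::vector<std::string>> attrs;
};

struct NodeBuilder {
  Node node;
  NodeBuilder& Attr(const std::string& key, std::vector<std::string> value) {
    node.attrs[key] = std::move(value);
    return *this;
  }
  Node Finalize() { return node; }
};

// Mirrors CondBuilder::SetColocationAndFinalize: if the lowered If op carried
// a colocation constraint, every control-flow node created inherits it.
Node SetColocationAndFinalize(
    NodeBuilder builder,
    const std::optional<std::vector<std::string>>& coloc_attr) {
  if (coloc_attr.has_value()) {
    builder.Attr("_class", *coloc_attr);  // "_class" is the colocation attr
  }
  return builder.Finalize();
}

int main() {
  std::optional<std::vector<std::string>> coloc =
      std::vector<std::string>{"loc:@some_variable"};
  NodeBuilder b;
  b.node.name = "cond/pivot_t";
  Node pivot = SetColocationAndFinalize(std::move(b), coloc);
  return pivot.attrs.count("_class") == 1 ? 0 : 1;
}
```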
- TF_RETURN_IF_ERROR(NodeBuilder(NewName("branch_executed"), "Merge", - graph_->op_registry(), &debug_info_) - .Input({pivot_t_, pivot_f_}) - .ControlInputs({then_call_node_, else_call_node_}) - .Device(if_op_->requested_device()) - .Finalize(graph_, &branch_executed_node_)); + TF_RETURN_IF_ERROR(SetColocationAndFinalize( + NodeBuilder(NewName("branch_executed"), "Merge", graph_->op_registry(), + &debug_info_) + .Input({pivot_t_, pivot_f_}) + .ControlInputs({then_call_node_, else_call_node_}) + .Device(if_op_->requested_device()), + graph_, &branch_executed_node_)); TF_RETURN_IF_ERROR(BuildLoweredIfOutput()); diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index 176670c8aa5..99c41e4c75e 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -305,6 +305,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.mkl_native_fused_conv2d = "_MklNativeFusedConv2D"; csinfo_.mkl_native_fused_depthwise_conv2d = "_MklNativeFusedDepthwiseConv2dNative"; + csinfo_.mkl_native_fused_matmul = "_MklNativeFusedMatMul"; csinfo_.mkl_native_pad_with_conv2d = "_MklNativePadWithConv2D"; csinfo_.mkl_native_pad_with_fused_conv2d = "_MklNativePadWithFusedConv2D"; csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D"; @@ -407,7 +408,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); rinfo_.push_back({csinfo_.concatv2, mkl_op_registry::GetMklOpName(csinfo_.concatv2), - CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); + CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()}); rinfo_.push_back( {csinfo_.conjugate_transpose, mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose), @@ -496,8 +497,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { : csinfo_.mkl_fused_depthwise_conv2d, CopyAttrsFusedConv2D, FusedDepthwiseConv2DRewrite, GetRewriteCause()}); - rinfo_.push_back({csinfo_.fused_matmul, csinfo_.mkl_fused_matmul, - CopyAttrsAllCheckConstFilter, FusedMatMulRewrite}); + rinfo_.push_back({csinfo_.fused_matmul, + native_fmt ? 
csinfo_.mkl_native_fused_matmul + : csinfo_.mkl_fused_matmul, + CopyAttrsAllCheckConstFilter, FusedMatMulRewrite, + GetRewriteCause()}); rinfo_.push_back( {csinfo_.identity, mkl_op_registry::GetMklOpName(csinfo_.identity), @@ -549,7 +553,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); rinfo_.push_back({csinfo_.quantized_concatv2, mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2), - CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); + CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()}); rinfo_.push_back({csinfo_.quantized_conv2d, mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d), CopyAttrsQuantizedConv2D, AlwaysRewrite, @@ -961,6 +965,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string mkl_native_fused_batch_norm_ex; string mkl_native_fused_conv2d; string mkl_native_fused_depthwise_conv2d; + string mkl_native_fused_matmul; string mkl_native_pad_with_conv2d; string mkl_native_pad_with_fused_conv2d; string mkl_pad_with_conv2d; @@ -1496,6 +1501,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass { } return false; } + // For oneDNN, only int32 is supported for axis data type + static bool ConcatV2Rewrite(const Node* n) { + DataType T; + GetNodeAttr(n->def(), "Tidx", &T); + return (T == DT_INT32); + } static bool DequantizeRewrite(const Node* n) { DCHECK(n); @@ -2447,7 +2458,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); for (auto ws : wsinfo_) { if (orig_node->type_string() == ws.fwd_op && - mkl_op_registry::IsMklLayoutDependentOp( + mkl_op_registry::IsMklOp( mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { // If this op is a fwd op, then we need to check if there is an // edge from this node's fwd_slot to bwdop's bwd_slot. If there is @@ -2474,7 +2485,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( nb->Attr("workspace_enabled", false); } } else if (orig_node->type_string() == ws.bwd_op && - mkl_op_registry::IsMklLayoutDependentOp( + mkl_op_registry::IsMklOp( mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { // If this op is a bwd op, then we need to add workspace edge and @@ -2498,10 +2509,14 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( CHECK_NOTNULL(ws_tensors); // Add workspace edge between fwd op and bwd op. ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot)); - // Add Mkl tensor edge for workspace edge between fwd op and bwd op. - ws_tensors->push_back(NodeBuilder::NodeOut( - e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot, - e->src()->num_outputs()))); + // Check if we are running in native format mode. If so, + // we don't need to have an Mkl metadata tensor for the workspace. + if (!NativeFormatEnabled()) { + // Add Mkl tensor edge for workspace edge between fwd op and bwd op. 
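`ConcatV2Rewrite` (and the workspace-edge change whose hunk continues just below) follows the layout pass's usual shape: each candidate op is paired with a predicate deciding whether the oneDNN variant may be substituted, and concat is now gated on an int32 axis type since oneDNN only supports int32 for the axis. An illustrative sketch of that gating with stand-in types, not the pass's real `rinfo_` structures:

```cpp
#include <functional>
#include <map>
#include <string>

enum class DataType { DT_INT32, DT_INT64 };

struct NodeInfo {
  std::string op;
  DataType tidx;  // value of the "Tidx" attr
};

bool ConcatV2Rewrite(const NodeInfo& n) {
  return n.tidx == DataType::DT_INT32;  // oneDNN requires an int32 axis
}

bool AlwaysRewrite(const NodeInfo&) { return true; }

int main() {
  // Per-op rewrite predicates, analogous to the rinfo_ entries in the pass.
  std::map<std::string, std::function<bool(const NodeInfo&)>> rewrite_pred = {
      {"ConcatV2", ConcatV2Rewrite},
      {"QuantizedConcatV2", ConcatV2Rewrite},
      {"Conv2D", AlwaysRewrite},
  };

  NodeInfo int64_concat{"ConcatV2", DataType::DT_INT64};
  NodeInfo int32_concat{"ConcatV2", DataType::DT_INT32};

  bool skip = !rewrite_pred["ConcatV2"](int64_concat);    // stays vanilla op
  bool rewrite = rewrite_pred["ConcatV2"](int32_concat);  // becomes _Mkl op
  return (skip && rewrite) ? 0 : 1;
}
```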
+ ws_tensors->push_back(NodeBuilder::NodeOut( + e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot, + e->src()->num_outputs()))); + } *are_ws_tensors_added = true; // In terms of input ordering, we add these calls to add Input // here because workspace edge (and its Mkl tensor) is the last @@ -3660,6 +3675,15 @@ Status MklLayoutRewritePass::RewriteNodeForJustOpNameChange( return s; } + std::vector workspace_tensors; + bool are_workspace_tensors_available = false; + AddWorkSpaceEdgeIfNeeded(g, orig_node, &nb, &workspace_tensors, + &are_workspace_tensors_available); + if (are_workspace_tensors_available) { + DCHECK_EQ(workspace_tensors.size(), 1); + nb.Input(workspace_tensors[0].node, workspace_tensors[0].index); + } + if (!NativeFormatEnabled()) { ri->copy_attrs(const_cast(orig_node), &nb, true); } else { diff --git a/tensorflow/core/common_runtime/mkl_layout_pass_test.cc b/tensorflow/core/common_runtime/mkl_layout_pass_test.cc index bbe20f4436d..4366a7892d3 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass_test.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass_test.cc @@ -4258,6 +4258,28 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) { "A->D:2;B->D;B:1->D:1;C->E;D->E:1"); } +TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_IndexTest) { + InitGraph( + "node { name: 'A' op: 'Const' " + " attr { key: 'dtype' value { type: DT_INT64 } }" + " attr { key: 'value' value { " + " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " + " int_val: 0 } } } }" + "node { name: 'B' op: 'InputList'" + " attr { key: 'N' value { i: 2 } }}" + "node { name: 'C' op: 'Input'}" + "node { name: 'D' op: 'ConcatV2'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'Tidx' value { type: DT_INT64 } }" + " attr { key: 'N' value { i: 2 } }" + " input: ['B:0', 'B:1', 'A']}" + "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['C', 'D'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|" + "A->D:2;B->D;B:1->D:1;C->E;D->E:1"); +} + TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) { InitGraph( "node { name: 'A' op: 'Input'}" diff --git a/tensorflow/core/common_runtime/permuter.cc b/tensorflow/core/common_runtime/permuter.cc index 61b8dcb79c8..1053cae1f7d 100644 --- a/tensorflow/core/common_runtime/permuter.cc +++ b/tensorflow/core/common_runtime/permuter.cc @@ -87,7 +87,7 @@ void Permuter::DispatchSend(int src_rank, int target_rank, const Tensor* tensor, << " target_rank=" << target_rank << " src_rank=" << src_rank; col_ctx_->col_exec->remote_access()->PostToPeer( col_params_->instance.devices[target_rank], - col_params_->instance.task_names[target_rank], send_buf_key, + col_params_->group.task_names[target_rank], send_buf_key, col_ctx_->device, col_ctx_->op_ctx->op_device_context(), col_ctx_->op_ctx->output_alloc_attr(0), tensor, col_ctx_->device_locality, done); @@ -103,7 +103,7 @@ void Permuter::DispatchRecv(int src_rank, int target_rank, Tensor* tensor, << " target_rank=" << target_rank << " src_rank=" << src_rank; col_ctx_->col_exec->remote_access()->RecvFromPeer( col_params_->instance.devices[src_rank], - col_params_->instance.task_names[src_rank], + col_params_->group.task_names[src_rank], col_params_->task.is_local[src_rank], recv_buf_key, col_ctx_->device, col_ctx_->op_ctx->op_device_context(), col_ctx_->op_ctx->output_alloc_attr(0), tensor, col_ctx_->device_locality, diff --git a/tensorflow/core/common_runtime/permuter_test.cc 
b/tensorflow/core/common_runtime/permuter_test.cc index a07162142f7..1b65a9ebe5f 100644 --- a/tensorflow/core/common_runtime/permuter_test.cc +++ b/tensorflow/core/common_runtime/permuter_test.cc @@ -183,11 +183,11 @@ class PermuterTest : public ::testing::Test { } else { dev_name = strings::StrCat(task_name, "/device:CPU:", di); } - col_params_.instance.device_names.push_back(dev_name); + col_params_.group.device_names.push_back(dev_name); col_params_.instance.devices.push_back(dev_name); int default_rank = wi * num_devices_per_worker + di; permutation_.push_back(default_rank); - col_params_.instance.task_names.push_back(task_name); + col_params_.group.task_names.push_back(task_name); col_params_.task.is_local.push_back(true); } } @@ -212,7 +212,7 @@ class PermuterTest : public ::testing::Test { for (int di = 0; di < num_devices_per_worker; di++) { int default_rank = wi * num_devices_per_worker + di; instances_.push_back(new DeviceInstance( - default_rank, col_params_.instance.device_names[default_rank], + default_rank, col_params_.group.device_names[default_rank], device_type, this)); } } @@ -323,16 +323,14 @@ class PermuterTest : public ::testing::Test { col_params_.instance.instance_key = parent_->col_params_.instance.instance_key; col_params_.group.device_type = parent_->col_params_.group.device_type; - col_params_.instance.device_names = - parent_->col_params_.instance.device_names; + col_params_.group.device_names = parent_->col_params_.group.device_names; col_params_.instance.devices = parent_->col_params_.instance.devices; col_params_.instance.permutation = parent->col_params_.instance.permutation; - col_params_.instance.task_names = - parent_->col_params_.instance.task_names; + col_params_.group.task_names = parent_->col_params_.group.task_names; col_params_.task.is_local = parent_->col_params_.task.is_local; CHECK_EQ(col_params_.instance.devices.size(), - col_params_.instance.device_names.size()); + col_params_.group.device_names.size()); // Default rank is order in device_names. col_params_.default_rank = rank; } diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 018d4395d19..40c31185eac 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -730,13 +730,24 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( function_name, function_key, ret_node_names.size(), lib_def->ReachableDefinitions(*fdef), std::move(ret_types)); + // Do not run function/graph optimization passes for component functions, + // since they have already processed the main function. + const bool should_run_optimization_passes = !options.is_component_function; + if (!should_run_optimization_passes) { + VLOG(1) << "Skipping function/graph optimization passes when instantiating " + "component function " + << function_name; + } + // Mapping from a function body node name to the control output name. 
std::unordered_map node_name_to_control_ret; bool control_rets_updated = false; - TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run( - *dev_set, options.config_proto, &graph, &data->lib_def_, - &control_ret_node_names, &control_rets_updated)); + if (should_run_optimization_passes) { + TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run( + *dev_set, options.config_proto, &graph, &data->lib_def_, + &control_ret_node_names, &control_rets_updated)); + } if (control_rets_updated) { // Function graph pass may have resulted in different nodes/node names for @@ -761,17 +772,8 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( optimization_options.device_set = dev_set.get(); optimization_options.is_function_graph = true; - // Do not run graph optimization passes for component functions, since they - // have already processed the main function. - bool should_run_graph_passes = !options.is_component_function; - if (!should_run_graph_passes) { - VLOG(1) << "Skipping graph optimization passes when instantiating " - "component function " - << function_name; - } - DumpGraph("Before running PRE_PLACEMENT passes", graph.get()); - if (should_run_graph_passes) { + if (should_run_optimization_passes) { TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( OptimizationPassRegistry::PRE_PLACEMENT, optimization_options)); } @@ -786,7 +788,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( TF_RETURN_IF_ERROR(placer.Run()); DumpGraph("Before running POST_PLACEMENT passes", graph.get()); - if (should_run_graph_passes) { + if (should_run_optimization_passes) { TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( OptimizationPassRegistry::POST_PLACEMENT, optimization_options)); } @@ -807,7 +809,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( } DumpGraph("Before running POST_REWRITE_FOR_EXEC passes", graph.get()); - if (should_run_graph_passes) { + if (should_run_optimization_passes) { TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options)); } @@ -845,7 +847,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( // Normally POST_PARTITIONING passes are run by distributed workers. // Distributed workers are currently not supported in this code path, so we // run the passes here. 
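The refactor in `InstantiateMultiDevice` collapses the earlier `should_run_graph_passes` flag into a single `should_run_optimization_passes` flag computed up front, so the function optimization registry and every grouped graph pass are skipped together for component functions. A standalone sketch of that gating, with stand-in types and the pass-group names taken from the surrounding hunks:

```cpp
#include <cstdio>
#include <string>
#include <vector>

struct InstantiateOptions {
  bool is_component_function = false;
};

void RunPassGroup(const std::string& name) {
  std::printf("running %s passes\n", name.c_str());
}

void InstantiateMultiDevice(const std::string& function_name,
                            const InstantiateOptions& options) {
  // Component functions were already optimized as part of the main function,
  // so rerunning the passes would only duplicate work.
  const bool should_run_optimization_passes = !options.is_component_function;
  if (!should_run_optimization_passes) {
    std::printf("skipping optimization passes for component function %s\n",
                function_name.c_str());
  }

  const std::vector<std::string> groups = {
      "function_optimization", "PRE_PLACEMENT", "POST_PLACEMENT",
      "POST_REWRITE_FOR_EXEC", "POST_PARTITIONING"};
  for (const std::string& group : groups) {
    if (should_run_optimization_passes) RunPassGroup(group);
  }
}

int main() {
  InstantiateMultiDevice("main_fn", {/*is_component_function=*/false});
  InstantiateMultiDevice("component_fn", {/*is_component_function=*/true});
}
```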
- if (should_run_graph_passes) { + if (should_run_optimization_passes) { TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( OptimizationPassRegistry::POST_PARTITIONING, optimization_options)); } diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc index 4af624ae1d2..2ee74477231 100644 --- a/tensorflow/core/common_runtime/rendezvous_mgr.cc +++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc @@ -151,7 +151,7 @@ void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr, RefCountedIntraProcessRendezvous::RefCountedIntraProcessRendezvous( const DeviceMgr* device_mgr) - : device_mgr_(device_mgr) {} + : device_mgr_(device_mgr), local_(this) {} RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() {} @@ -176,7 +176,7 @@ void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) { PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous( const DeviceMgr* device_mgr) - : device_mgr_(device_mgr) {} + : device_mgr_(device_mgr), local_(nullptr) {} PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {} diff --git a/tensorflow/core/common_runtime/ring_alg.cc b/tensorflow/core/common_runtime/ring_alg.cc index 870429bd883..b7a3bd11ec6 100644 --- a/tensorflow/core/common_runtime/ring_alg.cc +++ b/tensorflow/core/common_runtime/ring_alg.cc @@ -164,7 +164,7 @@ Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) { Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) { const string& device_name = - col_params->instance.device_names[col_params->default_rank]; + col_params->group.device_names[col_params->default_rank]; // Each subdiv permutation is a ring formed by rotating each // single-task subsequence of devices by an offset. This makes most // sense when each task has the same number of devices but we can't @@ -175,15 +175,15 @@ Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) { // Precondition: device_names must be sorted so that all devices in // the same task are adjacent. 
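The ring algorithm's subdivision permutations are described above as rings formed by rotating each task's contiguous block of devices by an offset. The sketch below reconstructs only that rotation step under simplified assumptions (one fixed offset per call); the real pass derives offsets per subdivision in `GenerateSubdivsInCollectiveParams`:

```cpp
#include <cassert>
#include <vector>

// Build one ring permutation by rotating each task's block of device indices
// by `offset`, keeping devices of the same task adjacent.
std::vector<int> RingPermutation(const std::vector<int>& dev_per_task,
                                 int offset) {
  std::vector<int> perm;
  int prior_dev_count = 0;
  for (int dev_count : dev_per_task) {
    for (int di = 0; di < dev_count; ++di) {
      int offset_di = (di + offset) % dev_count;
      perm.push_back(prior_dev_count + offset_di);
    }
    prior_dev_count += dev_count;
  }
  return perm;
}

int main() {
  // Two tasks with 3 and 2 devices, rotated by 1.
  std::vector<int> perm = RingPermutation({3, 2}, /*offset=*/1);
  std::vector<int> expected = {1, 2, 0, 4, 3};
  assert(perm == expected);
}
```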
VLOG(2) << "Sorted task names: " - << absl::StrJoin(col_params->instance.task_names, ", "); + << absl::StrJoin(col_params->group.task_names, ", "); std::vector dev_per_task; - const string* prior_task_name = &col_params->instance.task_names[0]; + const string* prior_task_name = &col_params->group.task_names[0]; int dev_count = 1; for (int di = 1; di < col_params->group.group_size; ++di) { - if (col_params->instance.task_names[di] != *prior_task_name) { + if (col_params->group.task_names[di] != *prior_task_name) { dev_per_task.push_back(dev_count); dev_count = 1; - prior_task_name = &col_params->instance.task_names[di]; + prior_task_name = &col_params->group.task_names[di]; } else { ++dev_count; } @@ -227,7 +227,7 @@ Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) { int permuted_di = prior_dev_count + offset_di; int rank = static_cast(perm.size()); perm.push_back(permuted_di); - if (col_params->instance.device_names[permuted_di] == device_name) { + if (col_params->group.device_names[permuted_di] == device_name) { DCHECK_EQ(permuted_di, col_params->default_rank); col_params->subdiv_rank[sdi] = rank; } @@ -385,8 +385,8 @@ void RingAlg::DispatchSend(RingField* rf, const StatusCallback& done) { int send_to_dev_idx = col_params_->instance.impl_details .subdiv_permutations[rf->subdiv_idx][send_to_rank]; col_ctx_->col_exec->remote_access()->PostToPeer( - col_params_->instance.device_names[send_to_dev_idx], - col_params_->instance.task_names[send_to_dev_idx], send_buf_key, + col_params_->group.device_names[send_to_dev_idx], + col_params_->group.task_names[send_to_dev_idx], send_buf_key, col_ctx_->device, col_ctx_->op_ctx->op_device_context(), col_ctx_->op_ctx->output_alloc_attr(0), &rf->chunk, col_ctx_->device_locality, done); @@ -404,8 +404,8 @@ void RingAlg::DispatchRecv(RingField* rf, const StatusCallback& done) { ? 
&rf->tmp_chunk : &rf->chunk; col_ctx_->col_exec->remote_access()->RecvFromPeer( - col_params_->instance.device_names[rf->recv_dev_idx], - col_params_->instance.task_names[rf->recv_dev_idx], + col_params_->group.device_names[rf->recv_dev_idx], + col_params_->group.task_names[rf->recv_dev_idx], col_params_->task.is_local[rf->recv_dev_idx], recv_buf_key, col_ctx_->device, col_ctx_->op_ctx->op_device_context(), col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor, diff --git a/tensorflow/core/common_runtime/ring_gatherer.cc b/tensorflow/core/common_runtime/ring_gatherer.cc index ecffd4a6eea..8ddf3e2c004 100644 --- a/tensorflow/core/common_runtime/ring_gatherer.cc +++ b/tensorflow/core/common_runtime/ring_gatherer.cc @@ -71,9 +71,9 @@ void RingGatherer::Run(StatusCallback done) { if (VLOG_IS_ON(1)) { string buf; - for (int r = 0; r < col_params_->instance.device_names.size(); ++r) { + for (int r = 0; r < col_params_->group.device_names.size(); ++r) { strings::StrAppend(&buf, "dev ", r, " : ", - col_params_->instance.device_names[r], "\n"); + col_params_->group.device_names[r], "\n"); } for (int sd = 0; sd < col_params_->instance.impl_details.subdiv_permutations.size(); diff --git a/tensorflow/core/common_runtime/ring_gatherer_test.cc b/tensorflow/core/common_runtime/ring_gatherer_test.cc index 8fccaf4ea6f..6b51993e2f4 100644 --- a/tensorflow/core/common_runtime/ring_gatherer_test.cc +++ b/tensorflow/core/common_runtime/ring_gatherer_test.cc @@ -222,8 +222,8 @@ class RingGathererTest : public ::testing::Test { dev_name = strings::StrCat(task_name, "/gpu:", di % gpu_devices_.size()); } - col_params_.instance.device_names.push_back(dev_name); - col_params_.instance.task_names.push_back(task_name); + col_params_.group.device_names.push_back(dev_name); + col_params_.group.task_names.push_back(task_name); // Normally each device would set is_local to its own perspective but // this test runs in a single process so is_local is always true. col_params_.task.is_local.push_back(true); @@ -240,7 +240,7 @@ class RingGathererTest : public ::testing::Test { for (int di = 0; di < num_devices; ++di) { int rank = wi * num_devices + di; instances_.push_back(new DeviceInstance( - rank, col_params_.instance.device_names[rank], device_type_, this)); + rank, col_params_.group.device_names[rank], device_type_, this)); } } } @@ -389,9 +389,7 @@ class RingGathererTest : public ::testing::Test { << "Couldn't find device " << dev_name << " existing devices: " << parent_->dev_mgr_->DebugString(); col_params_.name = parent_->col_params_.name; - col_params_.group.group_key = parent_->col_params_.group.group_key; - col_params_.group.device_type = parent_->col_params_.group.device_type; - col_params_.group.group_size = parent_->col_params_.group.group_size; + col_params_.group = parent_->col_params_.group; col_params_.instance = parent->col_params_.instance; col_params_.task.is_local = parent_->col_params_.task.is_local; col_params_.subdiv_rank = parent_->col_params_.subdiv_rank; @@ -399,7 +397,7 @@ class RingGathererTest : public ::testing::Test { int num_subdivs = static_cast(col_params_.subdiv_rank.size()); int group_size = col_params_.group.group_size; CHECK_EQ(group_size, - static_cast(col_params_.instance.device_names.size())); + static_cast(col_params_.group.device_names.size())); // Id of this device is at rank position in first subdiv perm. 
int my_device_id = col_params_.instance.impl_details.subdiv_permutations[0][rank]; @@ -547,8 +545,8 @@ CollectiveParams SetUpCollectiveParams(const int num_devs_per_task, int dev_id = i % num_devs_per_task; string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id); string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id); - cp.instance.task_names.push_back(task_name); - cp.instance.device_names.push_back(device_name); + cp.group.task_names.push_back(task_name); + cp.group.device_names.push_back(device_name); } return cp; } diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc index ab4542d58d8..71f0226549f 100644 --- a/tensorflow/core/common_runtime/ring_reducer.cc +++ b/tensorflow/core/common_runtime/ring_reducer.cc @@ -67,9 +67,9 @@ void RingReducer::Run(StatusCallback done) { if (VLOG_IS_ON(1)) { string buf; - for (int r = 0; r < col_params_->instance.device_names.size(); ++r) { + for (int r = 0; r < col_params_->group.device_names.size(); ++r) { strings::StrAppend(&buf, "dev ", r, " : ", - col_params_->instance.device_names[r], "\n"); + col_params_->group.device_names[r], "\n"); } for (int sd = 0; sd < col_params_->instance.impl_details.subdiv_permutations.size(); diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc index d50b83d5644..ad20a243151 100644 --- a/tensorflow/core/common_runtime/ring_reducer_test.cc +++ b/tensorflow/core/common_runtime/ring_reducer_test.cc @@ -238,15 +238,15 @@ class RingReducerTest : public ::testing::Test { // Set up all of the fake device contexts. for (int wi = 0; wi < num_workers; ++wi) { string task_name = strings::StrCat("/job:worker/replica:0/task:", wi); - col_params_.instance.num_devices_per_task[task_name] = num_devices; + col_params_.group.num_devices_per_task[task_name] = num_devices; for (int di = 0; di < num_devices; ++di) { string dev_name = strings::StrCat(task_name, "/cpu:", di); if (device_type == DEVICE_GPU) { dev_name = strings::StrCat(task_name, "/gpu:", di % gpu_devices_.size()); } - col_params_.instance.device_names.push_back(dev_name); - col_params_.instance.task_names.push_back(task_name); + col_params_.group.device_names.push_back(dev_name); + col_params_.group.task_names.push_back(task_name); // Normally each device would set is_local to its own perspective but // this test runs in a single process so is_local is always true. 
col_params_.task.is_local.push_back(true); @@ -263,7 +263,7 @@ class RingReducerTest : public ::testing::Test { for (int di = 0; di < num_devices; ++di) { int rank = wi * num_devices + di; instances_.push_back(new DeviceInstance( - rank, col_params_.instance.device_names[rank], device_type_, this)); + rank, col_params_.group.device_names[rank], device_type_, this)); } } } @@ -414,9 +414,7 @@ class RingReducerTest : public ::testing::Test { << "Couldn't find device " << dev_name << " existing devices: " << parent_->dev_mgr_->DebugString(); col_params_.name = parent_->col_params_.name; - col_params_.group.group_key = parent_->col_params_.group.group_key; - col_params_.group.device_type = parent_->col_params_.group.device_type; - col_params_.group.group_size = parent_->col_params_.group.group_size; + col_params_.group = parent_->col_params_.group; col_params_.instance = parent->col_params_.instance; col_params_.task.is_local = parent_->col_params_.task.is_local; col_params_.subdiv_rank = parent_->col_params_.subdiv_rank; @@ -424,7 +422,7 @@ class RingReducerTest : public ::testing::Test { int num_subdivs = static_cast(col_params_.subdiv_rank.size()); int group_size = col_params_.group.group_size; CHECK_EQ(group_size, - static_cast(col_params_.instance.device_names.size())); + static_cast(col_params_.group.device_names.size())); // Id of this device is at rank position in first subdiv perm. int my_device_id = col_params_.instance.impl_details.subdiv_permutations[0][rank]; @@ -574,8 +572,8 @@ CollectiveParams SetUpCollectiveParams(const int num_devs_per_task, int dev_id = i % num_devs_per_task; string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id); string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id); - cp.instance.task_names.push_back(task_name); - cp.instance.device_names.push_back(device_name); + cp.group.task_names.push_back(task_name); + cp.group.device_names.push_back(device_name); } return cp; } diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD index e86603828ec..f0d715b06b7 100644 --- a/tensorflow/core/data/BUILD +++ b/tensorflow/core/data/BUILD @@ -14,8 +14,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - cc_library( name = "compression_utils", srcs = ["compression_utils.cc"], diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index 6272dc477b3..d386dc6ec6b 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -27,8 +27,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - tf_proto_library( name = "common_proto", srcs = ["common.proto"], diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc index 43a639246ef..2973548ab19 100644 --- a/tensorflow/core/data/service/dispatcher_impl.cc +++ b/tensorflow/core/data/service/dispatcher_impl.cc @@ -54,6 +54,17 @@ constexpr char kJournalDir[] = "tf_data_dispatcher_journal"; // The name of the datasets directory inside the dispatcher's working directory. 
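The dispatcher change below rewrites the registered dataset graph so that every hash-table op carries `use_node_name_sharing=true`, preventing the backing resources from being deleted when the registration-time graph is torn down. A minimal sketch of that rewrite with stand-in `NodeDef`/`GraphDef` structs; the real code mutates the protobuf graph inside the `DatasetDef`:

```cpp
#include <array>
#include <map>
#include <string>
#include <vector>

struct NodeDef {
  std::string op;
  std::map<std::string, bool> attr;
};
struct GraphDef {
  std::vector<NodeDef> node;
};

constexpr std::array<const char*, 8> kNodeNameSharingOps = {
    "HashTable",                 "HashTableV2",
    "MutableHashTable",          "MutableHashTableV2",
    "MutableDenseHashTable",     "MutableDenseHashTableV2",
    "MutableHashTableOfTensors", "MutableHashTableOfTensorsV2",
};

// Stamp use_node_name_sharing=true onto every hash-table op in the graph.
void EnableNodeNameSharing(GraphDef* graph) {
  for (NodeDef& node : graph->node) {
    for (const char* op : kNodeNameSharingOps) {
      if (node.op == op) {
        node.attr["use_node_name_sharing"] = true;
      }
    }
  }
}

int main() {
  GraphDef g;
  g.node.push_back({"HashTableV2", {}});
  g.node.push_back({"MapDataset", {}});
  EnableNodeNameSharing(&g);
  return g.node[0].attr["use_node_name_sharing"] ? 0 : 1;
}
```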
constexpr char kDatasetsDir[] = "datasets"; +constexpr std::array kNodeNameSharingOps = { + "HashTable", + "HashTableV2", + "MutableHashTable", + "MutableHashTableV2", + "MutableDenseHashTable", + "MutableDenseHashTableV2", + "MutableHashTableOfTensors", + "MutableHashTableOfTensorsV2", +}; + using Dataset = DispatcherState::Dataset; using Worker = DispatcherState::Worker; using NamedJobKey = DispatcherState::NamedJobKey; @@ -312,14 +323,24 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset( GetOrRegisterDatasetResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); uint64 fingerprint; - const GraphDef& graph = request->dataset().graph(); - TF_RETURN_IF_ERROR(HashGraph(graph, &fingerprint)); + DatasetDef dataset_def = request->dataset(); + TF_RETURN_IF_ERROR(HashGraph(dataset_def.graph(), &fingerprint)); + // Set `use_node_name_sharing` to `true` so that resources aren't deleted + // prematurely. Otherwise, resources may be deleted when their ops are + // deleted at the end of the GraphRunner::Run used by standalone::Dataset. + for (NodeDef& node : *dataset_def.mutable_graph()->mutable_node()) { + for (const auto& op : kNodeNameSharingOps) { + if (node.op() == op) { + (*node.mutable_attr())["use_node_name_sharing"].set_b(true); + } + } + } mutex_lock l(mu_); #if defined(PLATFORM_GOOGLE) - VLOG_LINES(4, - absl::StrCat("Registering dataset graph: ", graph.DebugString())); + VLOG_LINES(4, absl::StrCat("Registering dataset graph: ", + dataset_def.graph().DebugString())); #else - VLOG(4) << "Registering dataset graph: " << graph.DebugString(); + VLOG(4) << "Registering dataset graph: " << dataset_def.graph().DebugString(); #endif std::shared_ptr dataset; Status s = state_.DatasetFromFingerprint(fingerprint, dataset); @@ -334,7 +355,7 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset( } int64 id; - TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, request->dataset(), id)); + TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, dataset_def, id)); response->set_dataset_id(id); VLOG(3) << "Registered new dataset with id " << id; return Status::OK(); @@ -450,22 +471,6 @@ Status DataServiceDispatcherImpl::ValidateMatchingJob( "processing mode <", actual, ">"); } - if (job->dataset_id != dataset_id) { - std::shared_ptr job_dataset_def; - TF_RETURN_IF_ERROR(GetDatasetDef(job->dataset_id, job_dataset_def)); - std::shared_ptr dataset_def; - TF_RETURN_IF_ERROR(GetDatasetDef(dataset_id, dataset_def)); - Status s = CheckGraphsEqual(job_dataset_def->graph(), dataset_def->graph()); - return errors::FailedPrecondition( - "Tried to create a job with name ", job_name, " for dataset ", - dataset_id, - ", but there is already an existing job with that name for dataset ", - job->dataset_id, - ". This either means that you are distributing two different datasets " - "under the same job_name, or that your dataset is being constructed " - "non-deterministically. 
Comparing the datasets results in: ", - s); - } return Status::OK(); } diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD index 45cfda90ff1..c594b936010 100644 --- a/tensorflow/core/debug/BUILD +++ b/tensorflow/core/debug/BUILD @@ -230,7 +230,6 @@ tf_cc_test( "//tensorflow/core:framework", "//tensorflow/core:functional_ops_op_lib", "//tensorflow/core:lib", - "//tensorflow/core:master_proto_cc", "//tensorflow/core:math_ops_op_lib", "//tensorflow/core:nn_ops_op_lib", "//tensorflow/core:no_op_op_lib", @@ -244,6 +243,7 @@ tf_cc_test( "//tensorflow/core/distributed_runtime/rpc:grpc_testlib", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/protobuf:master_proto_cc", ], ) diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 2459dfb5949..d3030f4ca0b 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -24,8 +24,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - exports_files( ["server_lib.h"], visibility = ["//tensorflow:internal"], @@ -68,9 +66,9 @@ cc_library( hdrs = ["message_wrappers.h"], deps = [ "//tensorflow/core:framework", - "//tensorflow/core:master_proto_cc", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:worker_proto_cc", + "//tensorflow/core/protobuf:master_proto_cc", + "//tensorflow/core/protobuf:worker_proto_cc", ], ) @@ -106,7 +104,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:worker_proto_cc", + "//tensorflow/core/protobuf:worker_proto_cc", "@com_google_absl//absl/types:optional", ], ) @@ -146,7 +144,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:ptr_util", - "//tensorflow/core:worker_proto_cc", + "//tensorflow/core/protobuf:worker_proto_cc", ], ) @@ -183,7 +181,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:worker_proto_cc", + "//tensorflow/core/protobuf:worker_proto_cc", ], ) @@ -197,7 +195,7 @@ cc_library( ":message_wrappers", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", - "//tensorflow/core:worker_proto_cc", + "//tensorflow/core/protobuf:worker_proto_cc", ], ) @@ -267,8 +265,8 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core:worker_proto_cc", "//tensorflow/core/framework:tensor_testutil", + "//tensorflow/core/protobuf:worker_proto_cc", ], ) @@ -299,7 +297,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:worker_proto_cc", + "//tensorflow/core/protobuf:worker_proto_cc", ], ) @@ -311,7 +309,7 @@ cc_library( ":message_wrappers", ":request_id", "//tensorflow/core:lib", - "//tensorflow/core:master_proto_cc", + "//tensorflow/core/protobuf:master_proto_cc", ], ) @@ -332,9 +330,9 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:master_proto_cc", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:worker_proto_cc", + "//tensorflow/core/protobuf:master_proto_cc", + "//tensorflow/core/protobuf:worker_proto_cc", ], ) @@ -356,9 +354,9 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:master_proto_cc", "//tensorflow/core:protos_all_cc", 
"//tensorflow/core/debug:debug_graph_utils", + "//tensorflow/core/protobuf:master_proto_cc", ], ) @@ -440,10 +438,10 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:worker_proto_cc", "//tensorflow/core/debug", "//tensorflow/core/profiler/lib:connected_traceme", "//tensorflow/core/profiler/lib:traceme_encode", + "//tensorflow/core/protobuf:worker_proto_cc", ], ) @@ -457,7 +455,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:worker_proto_cc", + "//tensorflow/core/protobuf:worker_proto_cc", ], ) @@ -523,8 +521,8 @@ tf_cc_test( "//tensorflow/core:session_options", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core:worker_proto_cc", "//tensorflow/core/nccl:collective_communicator", + "//tensorflow/core/protobuf:worker_proto_cc", ], ) @@ -540,7 +538,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib_internal", # protobuf::Any "//tensorflow/core:protos_all_cc", - "//tensorflow/core:worker_proto_cc", + "//tensorflow/core/protobuf:worker_proto_cc", "@com_google_absl//absl/memory", ], ) @@ -561,7 +559,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", - "//tensorflow/core:worker_proto_cc", + "//tensorflow/core/protobuf:worker_proto_cc", ], ) @@ -673,7 +671,6 @@ tf_cuda_cc_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:master_proto_cc", "//tensorflow/core:protos_all_cc", "//tensorflow/core:state_ops_op_lib", "//tensorflow/core:test", @@ -689,6 +686,7 @@ tf_cuda_cc_test( "//tensorflow/core/kernels:dense_update_ops", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:variable_ops", + "//tensorflow/core/protobuf:master_proto_cc", tf_grpc_cc_dependency(), ], ) @@ -713,7 +711,6 @@ tf_cuda_cc_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:master_proto_cc", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -722,6 +719,7 @@ tf_cuda_cc_test( "//tensorflow/core/distributed_runtime/rpc:grpc_testlib", "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + "//tensorflow/core/protobuf:master_proto_cc", tf_grpc_cc_dependency(), ], ) @@ -783,7 +781,7 @@ cc_library( deps = [ ":message_wrappers", "//tensorflow/core:lib", - "//tensorflow/core:worker_proto_cc", + "//tensorflow/core/protobuf:worker_proto_cc", ], ) @@ -797,7 +795,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core:worker_proto_cc", + "//tensorflow/core/protobuf:worker_proto_cc", ], ) diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc index ee96d63b0da..238d29065d2 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc @@ -250,27 +250,32 @@ Status CollectiveParamResolverDistributed::UpdateGroupCache( gr->devices[device.name()] = device; } gr->group.runtime_details.communicator_key = resp.communicator_key(); + FinishGroup(gr.get()); } + GroupRec* previous_gr = nullptr; { // Group membership should never change. 
Once a record is in group_table_ // it never gets removed. mutex_lock l(group_mu_); - auto it = group_table_.find(gr->group.group_key); + auto it = group_table_.find(resp.group_key()); if (it == group_table_.end()) { VLOG(2) << "UpdateGroupCache: communicator_key=" - << absl::CEscape(gr->group.runtime_details.communicator_key); + << absl::CEscape(resp.communicator_key()); group_table_[gr->group.group_key] = std::move(gr); } else { - auto& previous_gr = group_table_[gr->group.group_key]; - if (previous_gr->group.runtime_details.communicator_key != - gr->group.runtime_details.communicator_key) { - return errors::Internal( - "UpdateGroupCache: CompleteGroupResponse for group ", - gr->group.group_key, " gives communicator_key=", - absl::CEscape(gr->group.runtime_details.communicator_key), - " but cache already holds communicator_key=", - absl::CEscape(previous_gr->group.runtime_details.communicator_key)); - } + previous_gr = it->second.get(); + } + } + if (previous_gr != nullptr) { + mutex_lock grl(previous_gr->mu); + if (previous_gr->group.runtime_details.communicator_key != + resp.communicator_key()) { + return errors::Internal( + "UpdateGroupCache: CompleteGroupResponse for group ", + resp.group_key(), + " gives communicator_key=", absl::CEscape(resp.communicator_key()), + " but cache already holds communicator_key=", + absl::CEscape(previous_gr->group.runtime_details.communicator_key)); } } return Status::OK(); @@ -362,7 +367,7 @@ void CollectiveParamResolverDistributed::CompleteInstanceDistributed( if (group_leader_.empty()) { // This is the group leader so resolution is local. return CompleteInstanceLocal(device, gr, cp, cp->is_source, done); - } else if (InstanceIsCached(gr->group.group_key, cp->instance.instance_key)) { + } else if (InstanceIsCached(cp->group.group_key, cp->instance.instance_key)) { return CompleteInstanceLocal(device, gr, cp, cp->is_source, done); } else { CompleteInstanceCall* call = new CompleteInstanceCall( diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc index 8f35f9e8d27..f08f7a3275d 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc @@ -253,18 +253,18 @@ class DeviceResDistTest : public ::testing::Test { int idx = wi * num_devices + di; TF_ASSERT_OK(status_[device_name]); EXPECT_EQ(cp_[device_name].default_rank, idx); - EXPECT_EQ(cp_[device_name].instance.device_names.size(), dev_count); - EXPECT_EQ(cp_[device_name].instance.device_names[idx], device_name); - EXPECT_EQ(cp_[device_name].instance.task_names[idx], task_name); + EXPECT_EQ(cp_[device_name].group.device_names.size(), dev_count); + EXPECT_EQ(cp_[device_name].group.device_names[idx], device_name); + EXPECT_EQ(cp_[device_name].group.task_names[idx], task_name); ValidateDeviceResolver(cp_[device_name], task_name); if (idx > 0) { EXPECT_EQ(cp_[dev0].group.runtime_details.communicator_key, cp_[device_name].group.runtime_details.communicator_key); for (int i = 0; i < dev_count; ++i) { - EXPECT_EQ(cp_[dev0].instance.device_names[i], - cp_[device_name].instance.device_names[i]); - EXPECT_EQ(cp_[dev0].instance.task_names[i], - cp_[device_name].instance.task_names[i]); + EXPECT_EQ(cp_[dev0].group.device_names[i], + cp_[device_name].group.device_names[i]); + EXPECT_EQ(cp_[dev0].group.task_names[i], + cp_[device_name].group.task_names[i]); } } } @@ -272,7 
+272,7 @@ class DeviceResDistTest : public ::testing::Test { } void ValidateDeviceResolver(const CollectiveParams& cp, const string& task) { - for (const string& device_name : cp.instance.device_names) { + for (const string& device_name : cp.group.device_names) { DeviceAttributes attributes; TF_ASSERT_OK( dev_resolvers_[task]->GetDeviceAttributes(device_name, &attributes)); diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD index 08814653e3e..0b0c38e5a35 100644 --- a/tensorflow/core/distributed_runtime/eager/BUILD +++ b/tensorflow/core/distributed_runtime/eager/BUILD @@ -17,8 +17,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - cc_library( name = "remote_tensor_handle", hdrs = ["remote_tensor_handle.h"], @@ -61,11 +59,11 @@ cc_library( hdrs = ["destroy_tensor_handle_node.h"], deps = [ ":eager_client", - "//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", + "//tensorflow/core/protobuf:eager_service_proto_cc", ], ) @@ -73,10 +71,10 @@ cc_library( name = "eager_client", hdrs = ["eager_client.h"], deps = [ - "//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:call_options", + "//tensorflow/core/protobuf:eager_service_proto_cc", ], ) @@ -87,13 +85,13 @@ cc_library( deps = [ ":eager_client", "//tensorflow/core:core_cpu", - "//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:shape_inference", "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/protobuf:eager_service_proto_cc", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], @@ -112,7 +110,6 @@ cc_library( "//tensorflow/c:c_api_internal", "//tensorflow/c:tf_status_helper", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -130,6 +127,7 @@ cc_library( "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/protobuf:eager_service_proto_cc", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", @@ -146,7 +144,6 @@ tf_cc_test( ":remote_mgr", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", - "//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", @@ -159,6 +156,7 @@ tf_cc_test( "//tensorflow/core/distributed_runtime:test_utils", "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", + "//tensorflow/core/protobuf:eager_service_proto_cc", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", ], diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index fcc2d33b9c8..9d35ddf08f7 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ 
b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -954,6 +954,7 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) { /*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{}, /*runner=*/nullptr, /*collective_executor=*/nullptr, local_device, fdef_.signature().name(), + /*outputs_on_op_device=*/false, [ctx](const int64 step_id) { return ctx->CreateRendezvous(step_id); }, [=]() { return op_id; })); @@ -1001,6 +1002,7 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) { /*composite_devices=*/{}, /*input_resource_dtypes_and_shapes=*/{}, /*runner=*/nullptr, /*collective_executor=*/nullptr, local_device, fdef_.signature().name(), + /*outputs_on_op_device=*/false, [ctx](const int64 step_id) { return ctx->CreateRendezvous(step_id); }, [=]() { return op_id; })); diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index f2c9889b549..14a358e8ac2 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -31,8 +31,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "c_srcs", data = glob([ @@ -96,10 +94,10 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:worker_proto_cc", "//tensorflow/core/distributed_runtime:tensor_coding", "//tensorflow/core/distributed_runtime:worker_cache_logger", "//tensorflow/core/distributed_runtime:worker_interface", + "//tensorflow/core/protobuf:worker_proto_cc", tf_grpc_cc_dependency(), ], ) @@ -129,7 +127,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:worker_proto_cc", + "//tensorflow/core/protobuf:worker_proto_cc", "@com_google_absl//absl/flags:flag", tf_grpc_cc_dependency(), ], @@ -218,13 +216,13 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:worker_proto_cc", "//tensorflow/core/distributed_runtime:graph_mgr", "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface", "//tensorflow/core/distributed_runtime:worker", "//tensorflow/core/distributed_runtime:worker_cache", "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime:worker_session", + "//tensorflow/core/protobuf:worker_proto_cc", "@com_google_absl//absl/container:flat_hash_map", tf_grpc_cc_dependency(), ], @@ -236,8 +234,8 @@ cc_library( hdrs = ["grpc_worker_service_impl.h"], deps = [ ":grpc_util", - "//tensorflow/core:worker_proto_cc", "//tensorflow/core/distributed_runtime:tensor_coding", + "//tensorflow/core/protobuf:worker_proto_cc", tf_grpc_cc_dependency(), ], ) @@ -251,10 +249,10 @@ cc_library( ":grpc_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:master_proto_cc", "//tensorflow/core/distributed_runtime:call_options", "//tensorflow/core/distributed_runtime:master_interface", "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/protobuf:master_proto_cc", "@com_google_absl//absl/time", ], alwayslink = 1, @@ -271,9 +269,9 @@ cc_library( ":grpc_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:master_proto_cc", "//tensorflow/core/distributed_runtime:master", "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/protobuf:master_proto_cc", tf_grpc_cc_dependency(), ], alwayslink = 1, @@ -284,7 +282,7 @@ 
cc_library( srcs = ["grpc_master_service_impl.cc"], hdrs = ["grpc_master_service_impl.h"], deps = [ - "//tensorflow/core:master_proto_cc", + "//tensorflow/core/protobuf:master_proto_cc", tf_grpc_cc_dependency(), ], ) @@ -447,13 +445,13 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:master_proto_cc", "//tensorflow/core:protos_all_cc", "//tensorflow/core/distributed_runtime:call_options", "//tensorflow/core/distributed_runtime:local_master", "//tensorflow/core/distributed_runtime:master_interface", "//tensorflow/core/distributed_runtime:message_wrappers", "//tensorflow/core/distributed_runtime:request_id", + "//tensorflow/core/protobuf:master_proto_cc", ], alwayslink = 1, ) @@ -481,7 +479,6 @@ tf_cuda_cc_tests( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:master_proto_cc", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -489,6 +486,7 @@ tf_cuda_cc_tests( "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:test_utils", "//tensorflow/core/platform:blocking_counter", + "//tensorflow/core/protobuf:master_proto_cc", ], ) @@ -510,7 +508,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", - "//tensorflow/core:worker_proto_cc", + "//tensorflow/core/protobuf:worker_proto_cc", tf_grpc_cc_dependency(), ], ) @@ -523,7 +521,7 @@ tf_cc_test( ":grpc_util", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core:worker_proto_cc", + "//tensorflow/core/protobuf:worker_proto_cc", tf_grpc_dependency(), tf_grpc_cc_dependency(), ], @@ -546,7 +544,6 @@ tf_cuda_cc_test( "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:master_proto_cc", "//tensorflow/core:protos_all_cc", "//tensorflow/core:state_ops_op_lib", "//tensorflow/core:test", @@ -557,6 +554,7 @@ tf_cuda_cc_test( "//tensorflow/core/kernels:dense_update_ops", "//tensorflow/core/kernels:matmul_op", "//tensorflow/core/kernels:variable_ops", + "//tensorflow/core/protobuf:master_proto_cc", ], ) diff --git a/tensorflow/core/distributed_runtime/rpc/eager/BUILD b/tensorflow/core/distributed_runtime/rpc/eager/BUILD index ac0339fe6cd..12dbbadefb8 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/eager/BUILD @@ -9,14 +9,12 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - cc_library( name = "grpc_eager_service", srcs = ["grpc_eager_service.h"], hdrs = ["grpc_eager_service.h"], deps = [ - "//tensorflow/core:eager_service_proto_cc", + "//tensorflow/core/protobuf:eager_service_proto_cc", "//tensorflow/stream_executor/platform", tf_grpc_cc_dependency(), ], @@ -28,7 +26,6 @@ cc_library( hdrs = ["grpc_eager_client.h"], deps = [ ":grpc_eager_service", - "//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:call_options", @@ -37,6 +34,7 @@ cc_library( "//tensorflow/core/distributed_runtime/rpc:grpc_client_cq_tag", "//tensorflow/core/distributed_runtime/rpc:grpc_state", "//tensorflow/core/distributed_runtime/rpc:grpc_util", + "//tensorflow/core/protobuf:eager_service_proto_cc", tf_grpc_cc_dependency(), ], ) @@ -47,7 +45,6 @@ cc_library( hdrs = ["grpc_eager_service_impl.h"], deps = [ ":grpc_eager_service", - 
"//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:framework", "//tensorflow/core:ptr_util", "//tensorflow/core/distributed_runtime/eager:eager_service_impl", @@ -56,6 +53,7 @@ cc_library( "//tensorflow/core/distributed_runtime/rpc:grpc_channel", "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + "//tensorflow/core/protobuf:eager_service_proto_cc", tf_grpc_cc_dependency(), ], ) diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index 10bd8936a74..f833adcf932 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -194,7 +194,7 @@ class GrpcWorkerServiceThread { auto closure = [this, call]() { \ Status s = worker_->method(&call->request, &call->response); \ if (!s.ok()) { \ - VLOG(1) << "Bad response from " << #method << ": " << s; \ + VLOG(3) << "Bad response from " << #method << ": " << s; \ } \ call->SendResponse(ToGrpcStatus(s)); \ }; \ @@ -223,7 +223,7 @@ class GrpcWorkerServiceThread { Schedule([this, call]() { worker_->GetStepSequenceAsync( &call->request, &call->response, [call](const Status& s) { - VLOG(1) << "Bad response from GetStepSequence:" << s; + VLOG(3) << "Bad response from GetStepSequence:" << s; call->SendResponse(ToGrpcStatus(s)); }); }); @@ -232,7 +232,7 @@ class GrpcWorkerServiceThread { void MarkRecvFinishedHandler( WorkerCall* call) { - VLOG(1) << "Clean cache entry for request " << call->request.request_id(); + VLOG(3) << "Clean cache entry for request " << call->request.request_id(); worker_->RemoveCacheEntryForId(call->request.request_id()); call->SendResponse(::grpc::Status::OK); ENQUEUE_REQUEST(MarkRecvFinished, false); @@ -249,9 +249,9 @@ class GrpcWorkerServiceThread { worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response, [call, call_opts, wrapped_request, wrapped_response](const Status& s) { - VLOG(1) << "RunGraph::Done"; + VLOG(3) << "RunGraph::Done"; if (!s.ok()) { - VLOG(1) << "Bad response from RunGraph:" << s; + VLOG(3) << "Bad response from RunGraph:" << s; } call->ClearCancelCallback(); delete call_opts; @@ -275,7 +275,7 @@ class GrpcWorkerServiceThread { call->ClearCancelCallback(); delete call_opts; if (!s.ok()) { - VLOG(1) << "Bad response from RecvTensor:" << s; + VLOG(3) << "Bad response from RecvTensor:" << s; } call->SendResponse(ToGrpcStatus(s)); }); @@ -292,7 +292,7 @@ class GrpcWorkerServiceThread { call->ClearCancelCallback(); delete call_opts; if (!s.ok()) { - VLOG(1) << "Bad response from RecvBuf:" << s; + VLOG(3) << "Bad response from RecvBuf:" << s; } call->SendResponse(ToGrpcStatus(s)); }); @@ -311,7 +311,7 @@ class GrpcWorkerServiceThread { call->ClearCancelCallback(); delete call_opts; if (!s.ok()) { - VLOG(1) << "Bad response from CompleteGroup:" << s; + VLOG(3) << "Bad response from CompleteGroup:" << s; } call->SendResponse(ToGrpcStatus(s)); }); @@ -330,7 +330,7 @@ class GrpcWorkerServiceThread { call->ClearCancelCallback(); delete call_opts; if (!s.ok()) { - VLOG(1) << "Bad response from CompleteInstance:" << s; + VLOG(3) << "Bad response from CompleteInstance:" << s; } call->SendResponse(ToGrpcStatus(s)); }); @@ -430,7 +430,7 @@ GrpcWorker::GrpcWorker(WorkerEnv* worker_env, const ConfigProto& config) } void GrpcWorker::EnableResponseCache() { - VLOG(1) << "Enabling gRPC tensor response cache."; + VLOG(3) << "Enabling gRPC tensor response cache."; 
response_cache_ = absl::make_unique(); } @@ -441,7 +441,7 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request, ::grpc::ByteBuffer* response, StatusCallback done) { - VLOG(1) << "GrpcRecvTensorAsync req: " << request->DebugString(); + VLOG(3) << "GrpcRecvTensorAsync req: " << request->DebugString(); const int64 request_id = request->request_id(); const int64 step_id = request->step_id(); diff --git a/tensorflow/core/example/BUILD b/tensorflow/core/example/BUILD index 5c5168d295f..19521c0ea82 100644 --- a/tensorflow/core/example/BUILD +++ b/tensorflow/core/example/BUILD @@ -17,7 +17,7 @@ load( package( default_visibility = [ - "//tensorflow/core:__subpackages__", + "//visibility:public", ], licenses = ["notice"], # Apache 2.0 ) @@ -31,7 +31,6 @@ cc_library( hdrs = ["example_parser_configuration.h"], copts = tf_copts(), linkstatic = 1, - visibility = ["//visibility:public"], deps = [ "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -69,7 +68,6 @@ tf_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:direct_session_internal", "//tensorflow/core:example_parser_configuration", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -80,6 +78,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "//tensorflow/core/common_runtime:direct_session_internal", "//tensorflow/core/kernels:example_parsing_ops", ], ) @@ -91,9 +90,6 @@ filegroup( ], ) -# TODO(bmzhao): Split this target into separate tf_proto_libraries. -# This target is a holdover from the target tensorflow/core:example_protos, -# but splitting it while require multiple LSCs. tf_proto_library( name = "example_protos", srcs = [ @@ -102,9 +98,6 @@ tf_proto_library( ], cc_api_version = 2, make_default_target_header_only = True, - tags = [ - "alt_dep=//third_party/tensorflow/core:protos_all", - ], ) tf_proto_library( @@ -146,7 +139,6 @@ tf_pyclif_proto_library( closure_proto_library( name = "example_protos_closure", - visibility = ["//visibility:public"], deps = [":example_protos"], ) diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index e255eafde16..215a4ddccbf 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -6,6 +6,7 @@ load( ) load( "//tensorflow:tensorflow.bzl", + "tf_cc_test", "tf_cc_tests", "tf_copts", "tf_cuda_library", @@ -14,6 +15,9 @@ load( # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "filegroup") +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") + # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_generate_proto_text_sources") @@ -22,6 +26,7 @@ load("//tensorflow:tensorflow.bzl", "tf_selective_registration_deps") load( "//tensorflow/core/platform:build_config_root.bzl", "if_static", + "tf_cuda_tests_tags", ) load( "//tensorflow/core/platform:rules_cc.bzl", @@ -45,7 +50,9 @@ exports_files( "control_flow.h", "dataset.h", "dataset_stateful_op_allowlist.h", + "device.h", "device_base.h", + "device_factory.h", "function.h", "function_handle_cache.h", "graph_def_util.h", @@ -173,7 +180,9 @@ filegroup( "control_flow.h", "dataset.h", "dataset_stateful_op_allowlist.h", + "device.h", "device_base.h", + "device_factory.h", "function.h", "function_handle_cache.h", "graph_def_util.h", @@ -246,7 +255,9 @@ filegroup( "cancellation.cc", "collective.cc", "dataset.cc", + 
"device.cc", "device_base.cc", + "device_factory.cc", "function.cc", "function_handle_cache.cc", "graph_def_util.cc", @@ -337,8 +348,12 @@ filegroup( "dataset.cc", "dataset.h", "dataset_stateful_op_allowlist.h", + "device.cc", + "device.h", "device_base.cc", "device_base.h", + "device_factory.cc", + "device_factory.h", "function.cc", "function.h", "function_handle_cache.cc", @@ -454,7 +469,7 @@ cc_library( "allocator.h", ], features = ["parse_headers"], - visibility = ["//tensorflow/core:__subpackages__"], + visibility = ["//visibility:public"], deps = [ ":numeric_types", ":type_traits", @@ -467,6 +482,7 @@ cc_library( "//tensorflow/core/lib/strings:strcat", "//tensorflow/core/lib/strings:stringprintf", "//tensorflow/core/platform:env", + "//tensorflow/core/platform:env_impl", "//tensorflow/core/platform:logging", "//tensorflow/core/platform:macros", "//tensorflow/core/platform:mutex", @@ -533,6 +549,7 @@ cc_library( srcs = ["shape_inference_testutil.cc"], hdrs = ["shape_inference_testutil.h"], copts = tf_copts(), + visibility = ["//tensorflow:internal"], deps = [ ":node_def_proto_cc", "//tensorflow/core:framework", @@ -545,7 +562,7 @@ cc_library( name = "reader_base", srcs = ["reader_base.cc"], hdrs = ["reader_base.h"], - visibility = ["//tensorflow/core:__subpackages__"], + visibility = ["//visibility:public"], deps = [ ":reader_base_proto_cc", "//tensorflow/core:framework", @@ -557,7 +574,7 @@ cc_library( name = "op_gen_lib", srcs = ["op_gen_lib.cc"], hdrs = ["op_gen_lib.h"], - visibility = ["//tensorflow/core:__subpackages__"], + visibility = ["//visibility:public"], deps = [ ":api_def_proto_cc", ":attr_value_proto_cc", @@ -1020,6 +1037,110 @@ exports_files( ) # Framework tests. +tf_cc_test( + name = "framework_op_gen_lib_test", + size = "small", + srcs = ["op_gen_lib_test.cc"], + deps = [ + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/framework:op_gen_lib", + ], +) + +tf_cc_test_gpu( + name = "variant_op_copy_test", + size = "small", + srcs = ["variant_op_copy_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:client_session", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/core", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:direct_session", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:gpu_runtime", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:array", + "//third_party/eigen3", + ], +) + +tf_cc_test( + name = "framework_run_handler_util_test", + size = "small", + srcs = ["run_handler_util_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "framework_run_handler_test", + size = "small", + srcs = ["run_handler_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:direct_session_internal", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + 
"//tensorflow/core:testlib", + "//tensorflow/core/framework:tensor_testutil", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:matmul_op", + "//third_party/eigen3", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + ], +) + +tf_cc_test( + name = "framework_op_segment_test", + size = "small", + srcs = ["op_segment_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + "//tensorflow/cc:cc_ops", + "//tensorflow/core", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:direct_session_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:ops_util", + "//third_party/eigen3", + ], +) + tf_cc_tests( name = "higher_level_tests", size = "small", @@ -1085,7 +1206,6 @@ tf_cc_tests( "//tensorflow/core", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:direct_session_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", @@ -1095,6 +1215,7 @@ tf_cc_tests( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "//tensorflow/core/common_runtime:direct_session_internal", "//tensorflow/core/kernels:ops_util", "//tensorflow/core/platform:regexp", "//tensorflow/core/util:protos_test_cc", diff --git a/tensorflow/core/framework/collective.cc b/tensorflow/core/framework/collective.cc index df19752d197..36e26ba9fe9 100644 --- a/tensorflow/core/framework/collective.cc +++ b/tensorflow/core/framework/collective.cc @@ -54,10 +54,23 @@ string CollGroupRuntimeDetails::ToString() const { } string CollGroupParams::ToString() const { - return strings::StrCat( + string v = strings::StrCat( "CollGroupParams {group_key=", group_key, " group_size=", group_size, " device_type=", device_type.type_string(), " num_tasks=", num_tasks, - " runtime_details=", runtime_details.ToString(), "}"); + " runtime_details=", runtime_details.ToString(), " devices {"); + for (const auto& d : device_names) { + strings::StrAppend(&v, d, ","); + } + strings::StrAppend(&v, "} task_names={"); + for (const auto& n : task_names) { + strings::StrAppend(&v, n, ", "); + } + strings::StrAppend(&v, "} num_devices_per_task={"); + for (const auto& dpt : num_devices_per_task) { + strings::StrAppend(&v, dpt.first, ": ", dpt.second, ", "); + } + strings::StrAppend(&v, "}"); + return v; } CollInstanceParams& CollInstanceParams::operator=( @@ -67,12 +80,6 @@ CollInstanceParams& CollInstanceParams::operator=( type = other.type; data_type = other.data_type; shape = other.shape; - device_names.clear(); - device_names.assign(other.device_names.begin(), other.device_names.end()); - task_names.assign(other.task_names.begin(), other.task_names.end()); - same_num_devices_per_task = other.same_num_devices_per_task; - num_devices_per_task = other.num_devices_per_task; - gpu_ring_order = other.gpu_ring_order; impl_details.subdiv_offsets.assign( other.impl_details.subdiv_offsets.begin(), other.impl_details.subdiv_offsets.end()); @@ -96,17 +103,6 @@ string CollInstanceParams::ToString() const { strings::StrCat("CollInstanceParams { instance_key=", instance_key, " type=", type, " data_type=", DataTypeString(data_type), " shape=", 
shape.DebugString(), " devices {"); - for (const auto& d : device_names) { - strings::StrAppend(&v, d, ","); - } - strings::StrAppend(&v, "} task_names={"); - for (const auto& n : task_names) { - strings::StrAppend(&v, n, ", "); - } - strings::StrAppend(&v, "} num_devices_per_task={"); - for (const auto& dpt : num_devices_per_task) { - strings::StrAppend(&v, dpt.first, ": ", dpt.second, ", "); - } strings::StrAppend(&v, "}, collective_name=", impl_details.collective_name, ", subdiv_offsets={"); strings::StrAppend(&v, "}, subdiv_offsets={"); @@ -186,7 +182,7 @@ CollectiveContext::CollectiveContext( input(input), output(output), device(nullptr), - device_name(col_params.instance.device_names[col_params.default_rank]) {} + device_name(col_params.group.device_names[col_params.default_rank]) {} /*static*/ int64 CollectiveExecutor::kInvalidId = -1; diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h index f6020a29748..e63efcc15bf 100644 --- a/tensorflow/core/framework/collective.h +++ b/tensorflow/core/framework/collective.h @@ -63,6 +63,17 @@ struct CollGroupParams { int32 group_key; int32 group_size; DeviceType device_type; + // Fully qualified name of device for each member, in default rank order. + std::vector device_names; + // Task name prefix of corresponding device name. + std::vector task_names; + // True if every task has the same number of devices. + bool same_num_devices_per_task = false; + // Task -> number of devices on that task. + std::unordered_map num_devices_per_task; + // If passed in to GPUOptions in ConfigProto, defines a good ring order for + // GPUs. Assumes same GPU configuration at each worker. + string gpu_ring_order = ""; int32 num_tasks; // number of distinct tasks in group CollGroupRuntimeDetails runtime_details; string ToString() const; @@ -98,17 +109,6 @@ struct CollInstanceParams { CollectiveType type = UNDEFINED_COLLECTIVE; DataType data_type = DT_FLOAT; TensorShape shape = {0}; - // Fully qualified name of device for each member, in default rank order. - std::vector device_names; - // Task name prefix of corresponding device name. - std::vector task_names; - // True if every task has the same number of devices. - bool same_num_devices_per_task = false; - // Task -> number of devices on that task. - std::unordered_map num_devices_per_task; - // If passed in to GPUOptions in ConfigProto, defines a good ring order for - // GPUs. Assumes same GPU configuration at each worker. 
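With device_names, task_names, num_devices_per_task and gpu_ring_order moving from CollInstanceParams to CollGroupParams here, per-member lookups now read from the group rather than the instance. A minimal sketch of the new access pattern (helper names hypothetical, not part of this change; it mirrors the CollectiveContext and test updates elsewhere in this diff):

// #include "tensorflow/core/framework/collective.h"
// Given already-resolved CollectiveParams, a member's own device and task are
// now found on cp.group (previously cp.instance), indexed by default_rank.
inline const string& MyDeviceName(const CollectiveParams& cp) {
  return cp.group.device_names[cp.default_rank];
}
inline const string& MyTaskName(const CollectiveParams& cp) {
  return cp.group.task_names[cp.default_rank];
}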
- string gpu_ring_order = ""; CollImplDetails impl_details; string ToString() const; CollInstanceParams& operator=(const struct CollInstanceParams& other); diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index c54418ba648..60c95e04799 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -1120,9 +1120,26 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) { return Status::OK(); } +Status AvgPoolGradShape(shape_inference::InferenceContext* c) { + ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); + TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); + c->set_output(0, s); + return Status::OK(); +} + Status FusedBatchNormShape(shape_inference::InferenceContext* c) { + string data_format_str; + TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); + TensorFormat data_format; + if (!FormatFromString(data_format_str, &data_format)) { + return errors::InvalidArgument("Invalid data format string: ", + data_format_str); + } + const int rank = + (data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 5 : 4; ShapeHandle x; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x)); bool is_training; TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training)); @@ -1131,14 +1148,8 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) { exponential_avg_factor = 1.0f; // default value } int number_inputs = (is_training && exponential_avg_factor == 1.0f) ? 3 : 5; - string data_format_str; - TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); - TensorFormat data_format; - if (!FormatFromString(data_format_str, &data_format)) { - return errors::InvalidArgument("Invalid data format string: ", - data_format_str); - } - int channel_dim_index = GetTensorFeatureDimIndex(4, data_format); + + int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format); DimensionHandle channel_dim = c->Dim(x, channel_dim_index); // covers scale, offset, and if is_training is false, mean, variance @@ -1191,13 +1202,6 @@ Status FusedBatchNormExShape(shape_inference::InferenceContext* c) { } Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) { - ShapeHandle y_backprop; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop)); - ShapeHandle x; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x)); - - bool is_training; - TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training)); string data_format_str; TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); TensorFormat data_format; @@ -1205,7 +1209,17 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) { return errors::InvalidArgument("Invalid data format string: ", data_format_str); } - int channel_dim_index = GetTensorFeatureDimIndex(4, data_format); + const int rank = + (data_format_str == "NDHWC" or data_format_str == "NCDHW") ? 
5 : 4; + ShapeHandle y_backprop; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &y_backprop)); + ShapeHandle x; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &x)); + + bool is_training; + TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training)); + + int channel_dim_index = GetTensorFeatureDimIndex(rank, data_format); DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index); TF_RETURN_IF_ERROR( c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim)); @@ -1577,6 +1591,10 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) { return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/false); } +Status MaxPoolGradShape(shape_inference::InferenceContext* c) { + return UnchangedShapeWithRank(c, 4); +} + Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c) { return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/true); } @@ -1765,6 +1783,18 @@ Status Pool3DShape(shape_inference::InferenceContext* c) { return Status::OK(); } +Status MaxPool3DGradShape(shape_inference::InferenceContext* c) { + return UnchangedShapeWithRank(c, 5); +} + +Status AvgPool3DGradShape(shape_inference::InferenceContext* c) { + ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); + TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); + c->set_output(0, s); + return Status::OK(); +} + Status UnknownShape(shape_inference::InferenceContext* c) { for (int i = 0; i < c->num_outputs(); ++i) { c->set_output(i, c->UnknownShape()); diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index 3b14666305e..ba03514c739 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -147,6 +147,9 @@ Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c); // Shape function for AvgPool-like operations. Status AvgPoolShape(shape_inference::InferenceContext* c); +// Shape function for AvgPoolGrad-like operations. +Status AvgPoolGradShape(shape_inference::InferenceContext* c); + // Shape function for FusedBatchNorm and FusedBatchNormV2 operations. Status FusedBatchNormShape(shape_inference::InferenceContext* c); @@ -178,9 +181,18 @@ Status MaxPoolShape(shape_inference::InferenceContext* c); // Shape function for MaxPoolV2-like operations. Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs); +// Shape function for MaxPoolGrad-like operations. +Status MaxPoolGradShape(shape_inference::InferenceContext* c); + // Shape function for 3D Pooling operations. Status Pool3DShape(shape_inference::InferenceContext* c); +// Shape function for MaxPool3DGrad-like operations. +Status MaxPool3DGradShape(shape_inference::InferenceContext* c); + +// Shape function for AvgPool3DGrad-like operations. +Status AvgPool3DGradShape(shape_inference::InferenceContext* c); + // Shape function for use with ops whose output shapes are unknown. Status UnknownShape(shape_inference::InferenceContext* c); diff --git a/tensorflow/core/common_runtime/device.cc b/tensorflow/core/framework/device.cc similarity index 92% rename from tensorflow/core/common_runtime/device.cc rename to tensorflow/core/framework/device.cc index 9925814a48a..50453822230 100644 --- a/tensorflow/core/common_runtime/device.cc +++ b/tensorflow/core/framework/device.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/op_segment.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/random.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/core/framework/device.h b/tensorflow/core/framework/device.h new file mode 100644 index 00000000000..0f544bdd123 --- /dev/null +++ b/tensorflow/core/framework/device.h @@ -0,0 +1,202 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// A Device is a something that can perform computations as part of a +// model. Devices can be local (runs computation on this machine), or +// remote (contacts a device local to another machine using an RPC to +// do the work). Devices are registered in a DeviceSet, which is also +// responsible for the Device <-> id mapping. +// +// Device names +// * Every Device should have a unique name with the format: +// /job:___/replica:___/task:___/(gpu|cpu):___ +// An example name would be "/job:train/replica:0/task:3/device:GPU:2". +// * Task numbers are within the specified replica, so there are as +// many "task zeros" as replicas. + +#ifndef TENSORFLOW_CORE_FRAMEWORK_DEVICE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_DEVICE_H_ + +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/control_flow.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_segment.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/types.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +class Device : public DeviceBase { + public: + // Callback type that takes a Status and returns void. + typedef std::function DoneCallback; + + Device(Env* env, const DeviceAttributes& device_attributes); + ~Device() override; + + // Full name of this device (see top comment). + const std::string& name() const override { return device_attributes_.name(); } + + // Parsed name of this device + const DeviceNameUtils::ParsedName& parsed_name() const { + return parsed_name_; + } + + // Describes what kind of device this is. 
This is intended to be + // human-readable and not computer-parsed, except that two devices + // with the same device_type() are expected to perform similarly + // (both from a computation and communication perspective). + const std::string& device_type() const { + return device_attributes_.device_type(); + } + + // Returns an aggregation of device attributes. + const DeviceAttributes& attributes() const override { + return device_attributes_; + } + + // Performs the actual compute function. + // + // Subclasses may override this function if they wish to perform + // some initialization before each compute. + virtual void Compute(OpKernel* op_kernel, OpKernelContext* context) { + op_kernel->Compute(context); + } + + // Asynchronous kernel's compute. + virtual void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, + AsyncOpKernel::DoneCallback done) { + op_kernel->ComputeAsync(context, std::move(done)); + } + + // Blocks until all operations queued on the device at the time of + // the call have completed. Returns any error pending on the device + // at completion. + virtual Status Sync() = 0; + + // Calls the given callback when all operations queued on the device at the + // time of the call have completed. The callback is passed any error pending + // on the device at completion. + // TODO(b/112409994): Consolidate these two APIs, removing the synchronous + // version. + virtual void Sync(const DoneCallback& done); + + // On session completion, the executor may call Device::Sync() depending on + // flag settings. Override this to return false for devices that don't allow + // such calls. Instead, these devices must use other mechanisms (such as + // num_deferred_ops) to ensure the device has finished processing necessary + // work at session completion. In addition, for these devices, RefreshStatus + // must be called at session completion to retrieve execution result status. + // + // Devices that override this function must also implement RefreshStatus. + virtual bool AllowsSyncOnCompletion() const { return true; } + + // This is used in conjunction with AllowsSyncOnCompletion to allow the + // executor to get execution result status at session completion. + // + // For supported devices, this call returns the underlying device stream's + // current status in a non-blocking way, without using blocking calls such as + // Stream::BlockHostUntilDone or Device::Sync. When applicable, the device + // status is also updated with the retrieved stream status. + virtual Status RefreshStatus() { + return errors::Unimplemented( + "RefreshStatus is not supported on this device."); + } + + // Optionally modify the device's GraphDef before execution. + // + // This method should be considered experimental and is supplied to enable + // prototyping of TensorFlow device implementations that need to modify + // the GraphDef before execution. + // + // 'graph' supplies the partition of the graph assigned to this + // device. + virtual Status MaybeRewriteGraph(std::unique_ptr* /*graph*/) { + return Status::OK(); + } + + // Sets `out_context` a new DeviceContext* for executing a graph, or nullptr + // if the device does not support contexts. Returns an error status if any + // error occurred while trying to create a context, otherwise OK. + // + // The caller takes ownership of one reference on the output DeviceContext*, + // and should call Unref(). 
+ virtual Status TryGetDeviceContext(DeviceContext** out_context) { + *out_context = nullptr; + return Status::OK(); + } + + // Returns the op segment of this device. The caller can reuse op + // kernels registered for the same session running on this device. + OpSegment* op_segment() { return &op_seg_; } + + // Returns the resource manager associated w/ this device. + virtual ResourceMgr* resource_manager() { return rmgr_; } + + // Summarizes the status of this Device, for debugging. + std::string DebugString() const { return device_attributes_.DebugString(); } + + // Assembles the parameter components into a complete DeviceAttributes value. + static DeviceAttributes BuildDeviceAttributes( + const std::string& name, DeviceType device, Bytes memory_limit, + const DeviceLocality& locality, const std::string& physical_device_desc); + + static DeviceAttributes BuildDeviceAttributes( + const std::string& name, DeviceType device, Bytes memory_limit, + const DeviceLocality& locality) { + // Pass in an empty string as physical device name. + return BuildDeviceAttributes(name, device, memory_limit, locality, ""); + } + + // Clears the resource manager associated with this device. + void ClearResourceMgr() { rmgr_->Clear(); } + + virtual bool IsLocal() const { return true; } + + protected: + void DeleteResourceMgr() { + delete rmgr_; + rmgr_ = nullptr; + } + + private: + const DeviceAttributes device_attributes_; + DeviceNameUtils::ParsedName parsed_name_; + + // op_seg_ maps session handle and op name to OpKernel objects. + OpSegment op_seg_; + + // Resources associated w/ this device. E.g., shared variables, etc. + ResourceMgr* rmgr_ = nullptr; + + TF_DISALLOW_COPY_AND_ASSIGN(Device); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_DEVICE_H_ diff --git a/tensorflow/core/common_runtime/device_factory.cc b/tensorflow/core/framework/device_factory.cc similarity index 98% rename from tensorflow/core/common_runtime/device_factory.cc rename to tensorflow/core/framework/device_factory.cc index 2872da15c26..dda5d5f0101 100644 --- a/tensorflow/core/common_runtime/device_factory.cc +++ b/tensorflow/core/framework/device_factory.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/device_factory.h" #include #include #include #include -#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/device.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/core/framework/device_factory.h b/tensorflow/core/framework/device_factory.h new file mode 100644 index 00000000000..43d1d615555 --- /dev/null +++ b/tensorflow/core/framework/device_factory.h @@ -0,0 +1,151 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_ +#define TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_ + +#include +#include + +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class Device; +struct SessionOptions; + +class DeviceFactory { + public: + virtual ~DeviceFactory() {} + static void Register(const std::string& device_type, DeviceFactory* factory, + int priority); + static DeviceFactory* GetFactory(const std::string& device_type); + + // Append to "*devices" all suitable devices, respecting + // any device type specific properties/counts listed in "options". + // + // CPU devices are added first. + static Status AddDevices(const SessionOptions& options, + const std::string& name_prefix, + std::vector>* devices); + + // Helper for tests. Create a single device of type "type". The + // returned device is always numbered zero, so if creating multiple + // devices of the same type, supply distinct name_prefix arguments. + static std::unique_ptr NewDevice(const string& type, + const SessionOptions& options, + const string& name_prefix); + + // Iterate through all device factories and build a list of all of the + // possible physical devices. + // + // CPU is are added first. + static Status ListAllPhysicalDevices(std::vector* devices); + + // Get details for a specific device among all device factories. + // 'device_index' indexes into devices from ListAllPhysicalDevices. + static Status GetAnyDeviceDetails( + int device_index, std::unordered_map* details); + + // For a specific device factory list all possible physical devices. + virtual Status ListPhysicalDevices(std::vector* devices) = 0; + + // Get details for a specific device for a specific factory. Subclasses + // can store arbitrary device information in the map. 'device_index' indexes + // into devices from ListPhysicalDevices. + virtual Status GetDeviceDetails(int device_index, + std::unordered_map* details) { + return Status::OK(); + } + + // Most clients should call AddDevices() instead. + virtual Status CreateDevices( + const SessionOptions& options, const std::string& name_prefix, + std::vector>* devices) = 0; + + // Return the device priority number for a "device_type" string. + // + // Higher number implies higher priority. + // + // In standard TensorFlow distributions, GPU device types are + // preferred over CPU, and by default, custom devices that don't set + // a custom priority during registration will be prioritized lower + // than CPU. Custom devices that want a higher priority can set the + // 'priority' field when registering their device to something + // higher than the packaged devices. See calls to + // REGISTER_LOCAL_DEVICE_FACTORY to see the existing priorities used + // for built-in devices. + static int32 DevicePriority(const std::string& device_type); +}; + +namespace dfactory { + +template +class Registrar { + public: + // Multiple registrations for the same device type with different priorities + // are allowed. Priorities are used in two different ways: + // + // 1) When choosing which factory (that is, which device + // implementation) to use for a specific 'device_type', the + // factory registered with the highest priority will be chosen. 
+ // For example, if there are two registrations: + // + // Registrar("CPU", 125); + // Registrar("CPU", 150); + // + // then CPUFactory2 will be chosen when + // DeviceFactory::GetFactory("CPU") is called. + // + // 2) When choosing which 'device_type' is preferred over other + // DeviceTypes in a DeviceSet, the ordering is determined + // by the 'priority' set during registration. For example, if there + // are two registrations: + // + // Registrar("CPU", 100); + // Registrar("GPU", 200); + // + // then DeviceType("GPU") will be prioritized higher than + // DeviceType("CPU"). + // + // The default priority values for built-in devices is: + // GPU: 210 + // GPUCompatibleCPU: 70 + // ThreadPoolDevice: 60 + // Default: 50 + explicit Registrar(const std::string& device_type, int priority = 50) { + DeviceFactory::Register(device_type, new Factory(), priority); + } +}; + +} // namespace dfactory + +#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...) \ + INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \ + __COUNTER__, ##__VA_ARGS__) + +#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \ + ctr, ...) \ + static ::tensorflow::dfactory::Registrar \ + INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr)(device_type, \ + ##__VA_ARGS__) + +// __COUNTER__ must go through another macro to be properly expanded +#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr) ___##ctr##__object_ + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_ diff --git a/tensorflow/core/framework/local_rendezvous.cc b/tensorflow/core/framework/local_rendezvous.cc index 3535b57e7db..34053808b4a 100644 --- a/tensorflow/core/framework/local_rendezvous.cc +++ b/tensorflow/core/framework/local_rendezvous.cc @@ -187,6 +187,20 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key, CancellationToken token = CancellationManager::kInvalidToken; bool already_cancelled = false; if (cm != nullptr) { + // Increment the refcount when cancellation manager is present, to make + // sure the rendezvous outlives the recv and its cancel callbacks. + // This refcount is dropped in exactly one of the following cases: + // (1) Recv registers cancellation callback to cm, and then cm is + // cancelled, unref in the cancellation callback; + // (2) Recv registers cancellation callback to cm, but cm is already + // cancelled, unref in the already_cancelled check; + // (3) Recv is successful, and item done callback finishes deregistering + // the cancellation callback, unref in the item done callback; + // (4) Recv is successful, but the item done callback fails to deregister + // the cancellation callback because cm already StartCancel, in this + // case the cancellation callback will be invoked by the cm anyway, + // unref in the cancellation callback. 
+ if (rc_owner_) rc_owner_->Ref(); token = cm->get_cancellation_token(); already_cancelled = !cm->RegisterCallback(token, [this, token, key_hash] { Item* item = nullptr; @@ -230,10 +244,14 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key, Rendezvous::Args(), item->args, Tensor(), /*is_dead=*/false); delete item; } + // Unref case (1) and (4) + if (rc_owner_) rc_owner_->Unref(); }); } if (already_cancelled) { mu_.unlock(); + // Unref case (2) + if (rc_owner_) rc_owner_->Unref(); done(StatusGroup::MakeDerived( errors::Cancelled("RecvAsync is cancelled.")), Rendezvous::Args(), recv_args, Tensor(), /*is_dead=*/false); @@ -250,10 +268,17 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key, // cancellation manager may no longer be live after `done` is called. queue->push_back(new Item( recv_args, - [cm, token, done = std::move(done)]( + [this, cm, token, done = std::move(done)]( const Status& s, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, const Tensor& v, bool dead) { - cm->TryDeregisterCallback(token); + // TryDeregisterCallback returns true when the cancellation callback + // is successfully deregistered. If it fails because the CM already + // StartAbort, Unref will happen inside the cancellation callback + // when called by the CM. + if (cm->TryDeregisterCallback(token)) { + // Unref case (3) + if (this->rc_owner_) this->rc_owner_->Unref(); + } done(s, send_args, recv_args, v, dead); }, token)); diff --git a/tensorflow/core/framework/local_rendezvous.h b/tensorflow/core/framework/local_rendezvous.h index 19c218793b6..ed3a5d4dd73 100644 --- a/tensorflow/core/framework/local_rendezvous.h +++ b/tensorflow/core/framework/local_rendezvous.h @@ -35,7 +35,11 @@ namespace tensorflow { // is not expected to be needed. class LocalRendezvous { public: - LocalRendezvous() = default; + // If the class wrapping LocalRendezvous is refcounted (i.e., extending + // Rendezvous), pass in its pointer in constructor so the LocalRendezvous + // can make sure it outlives the async recv requests. + // Pass in nullptr if the wrapping class is not refcounted. + explicit LocalRendezvous(Rendezvous* owner) : rc_owner_(owner) {} ~LocalRendezvous(); Status Send(const Rendezvous::ParsedKey& key, @@ -62,6 +66,9 @@ class LocalRendezvous { typedef gtl::FlatMap Table; + // Pointer to the owner class of this LocalRendezvous if it is refcounted. + const Rendezvous* rc_owner_; + // TODO(zhifengc): shard table_. 
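Since LocalRendezvous no longer has a default constructor, every instantiation now states whether a refcounted owner must be kept alive across asynchronous recvs. A minimal sketch of the two call patterns (illustrative only; the owning case is exactly what LocalRendezvousWrapper does in rendezvous.cc below):

// Standalone use with no refcounted owner: the Ref()/Unref() bookkeeping above
// is skipped entirely.
LocalRendezvous local_only(/*owner=*/nullptr);

// Inside a refcounted Rendezvous subclass, pass the owner so pending recvs keep
// it alive, e.g. a member initialized as:  MyRendezvous() : impl_(this) {}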
mutex mu_; Table table_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/core/framework/metrics.cc b/tensorflow/core/framework/metrics.cc index f5aff3a4e11..dccfbbfaeca 100644 --- a/tensorflow/core/framework/metrics.cc +++ b/tensorflow/core/framework/metrics.cc @@ -42,6 +42,12 @@ auto* graph_run_time_usecs_histogram = monitoring::Sampler<0>::New( // Power of 2 with bucket count 20 (> 17 minutes) {monitoring::Buckets::Exponential(1000, 2, 20)}); +auto* graph_pending_queue_length_histogram = monitoring::Sampler<0>::New( + {"/tensorflow/core/graph_pending_queue_length_histogram", + "The number of pending (ready but not running) tasks in graph executor."}, + // Power of 1.5 with bucket count 30 (> 191k) + {monitoring::Buckets::Exponential(1, 1.5, 30)}); + auto* graph_run_input_tensor_bytes = monitoring::Sampler<0>::New( {"/tensorflow/core/graph_run_input_tensor_bytes", "The size of input tensors in bytes."}, @@ -252,6 +258,12 @@ void UpdateGraphExecTime(const uint64 running_time_usecs) { } } +void UpdateGraphPendingQueueLength(uint64 len) { + static auto* graph_pending_queue_length_cell = + graph_pending_queue_length_histogram->GetCell(); + graph_pending_queue_length_cell->Add(len); +} + void UpdateGraphOptimizationPassTime(const string& pass_name, const uint64 running_time_usecs) { if (running_time_usecs > 0) { diff --git a/tensorflow/core/framework/metrics.h b/tensorflow/core/framework/metrics.h index f7c90ce593e..ef04e09aaa4 100644 --- a/tensorflow/core/framework/metrics.h +++ b/tensorflow/core/framework/metrics.h @@ -95,6 +95,7 @@ void RecordGraphInputTensors(const size_t size); void RecordGraphOutputTensors(const size_t size); void UpdateGraphExecTime(const uint64 running_time_usecs); +void UpdateGraphPendingQueueLength(uint64 len); // Records that one output of an op of type `op_name` was unused. 
void RecordUnusedOutput(const string& op_name); diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index ef0ed993b6f..951052b794a 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -1488,7 +1488,7 @@ void Model::OptimizeHillClimb(int64 cpu_budget, int64 ram_budget, double best_delta = -1.0L; Parameter* best_parameter = nullptr; for (auto& pair : parameters) { - if (pair.second->value == pair.second->max) { + if (pair.second->value >= pair.second->max) { continue; } pair.second->value++; diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc index 0eda2c6492f..3b7e597e425 100644 --- a/tensorflow/core/framework/op_def_builder.cc +++ b/tensorflow/core/framework/op_def_builder.cc @@ -145,7 +145,7 @@ bool ProcessCompoundType(const StringPiece type_string, AttrValue* allowed) { return true; } -void FinalizeAttr(StringPiece spec, OpDef* op_def, +void FinalizeAttr(StringPiece spec, bool allow_attr_type_any, OpDef* op_def, std::vector* errors) { OpDef::AttrDef* attr = op_def->add_attr(); StringPiece orig(spec); @@ -175,6 +175,8 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def, type = "tensor"; } else if (absl::ConsumePrefix(&spec, "func")) { type = "func"; + } else if (absl::ConsumePrefix(&spec, "any") && allow_attr_type_any) { + type = "any"; } else if (ConsumeCompoundAttrType(&spec, &type_string)) { type = "type"; AttrValue* allowed = attr->mutable_allowed_values(); @@ -633,13 +635,18 @@ OpDefBuilder& OpDefBuilder::SetShapeFn(OpShapeInferenceFn fn) { return *this; } +OpDefBuilder& OpDefBuilder::AllowAttrTypeAny() { + allow_attr_type_any_ = true; + return *this; +} + Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const { std::vector errors = errors_; *op_reg_data = op_reg_data_; OpDef* op_def = &op_reg_data->op_def; for (StringPiece attr : attrs_) { - FinalizeAttr(attr, op_def, &errors); + FinalizeAttr(attr, allow_attr_type_any_, op_def, &errors); } for (StringPiece input : inputs_) { FinalizeInputOrOutput(input, false, op_def, &errors); diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h index b69ee46cd59..6789558bc80 100644 --- a/tensorflow/core/framework/op_def_builder.h +++ b/tensorflow/core/framework/op_def_builder.h @@ -142,6 +142,10 @@ class OpDefBuilder { // python/framework/common_shapes.py OpDefBuilder& SetShapeFn(OpShapeInferenceFn fn); + // Allows the `` in calls to `Attr()` to be "any". + // This is used by PythonAPIWrapper for pass-through parameters. + OpDefBuilder& AllowAttrTypeAny(); + // Sets op_reg_data->op_def to the requested OpDef and // op_reg_data->shape_inference_fn to the requested shape inference function, // or returns an error. @@ -168,6 +172,7 @@ class OpDefBuilder { std::vector control_outputs_; std::string doc_; std::vector errors_; + bool allow_attr_type_any_ = false; }; } // namespace tensorflow diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index 1648471b6ef..9b5648927d1 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -1152,30 +1152,17 @@ TEST(RegisteredKernels, GetRegisteredKernelsForOp) { EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU"); } -struct KernelNameAndBuilderFunction { - using FnType = std::unique_ptr(); - char const* name; - FnType* builder; -}; - -#define EXTRACT_KERNEL_NAME_TO_STRUCT_IMPL(name, kernel_builder, ...) 
\ - KernelNameAndBuilderFunction { \ - name, +[]() { \ - return std::unique_ptr(kernel_builder.Build()); \ - } \ - } -#define EXTRACT_KERNEL_NAME_TO_STRUCT(kernel_builder) \ - TF_EXTRACT_KERNEL_NAME(EXTRACT_KERNEL_NAME_TO_STRUCT_IMPL, kernel_builder) +// EXTRACT_KERNEL_NAME_TO_STRING wraps TF_EXTRACT_KERNEL_NAME for testing +// (it involves quite a bit of macro-magic). +#define EXTRACT_KERNEL_NAME_TO_STRING_IMPL(name, kernel_builder, ...) name +#define EXTRACT_KERNEL_NAME_TO_STRING(kernel_builder) \ + TF_EXTRACT_KERNEL_NAME(EXTRACT_KERNEL_NAME_TO_STRING_IMPL, kernel_builder) TEST(RegisterKernelMacro, ExtractName) { - constexpr char const* kName = "Foo"; - constexpr char const* kLabel = "Label"; - constexpr KernelNameAndBuilderFunction kKernelInfo = - EXTRACT_KERNEL_NAME_TO_STRUCT(Name(kName).Label(kLabel)); - EXPECT_THAT(kKernelInfo.name, ::testing::StrEq(kName)); - std::unique_ptr kernel_def = kKernelInfo.builder(); - EXPECT_THAT(kernel_def->op(), ::testing::StrEq(kName)); - EXPECT_THAT(kernel_def->label(), ::testing::StrEq(kLabel)); + static constexpr char const* kName = "Foo"; + static constexpr char const* kExtractedName = + EXTRACT_KERNEL_NAME_TO_STRING(Name(kName).Label("Label")); + EXPECT_THAT(kExtractedName, ::testing::StrEq(kName)); } } // namespace diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc index 764f8995d02..9d63265af62 100644 --- a/tensorflow/core/framework/rendezvous.cc +++ b/tensorflow/core/framework/rendezvous.cc @@ -151,7 +151,7 @@ Status RendezvousInterface::Recv(const ParsedKey& key, const Args& args, namespace { class LocalRendezvousWrapper : public Rendezvous { public: - LocalRendezvousWrapper() = default; + LocalRendezvousWrapper() : impl_(this) {} Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val, const bool is_dead) override { diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index bb79b278cb1..10b54476d18 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -133,9 +133,14 @@ struct DimensionOrConstant { struct ShapeAndType { ShapeAndType() {} ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {} + ShapeAndType(ShapeHandle s, DataType t, SpecializedType specialized_t) + : shape(s), dtype(t), specialized_type(specialized_t) {} ShapeHandle shape; DataType dtype = DT_INVALID; + // The type of a variant-dtype tensor sometimes affects graph building + // (e.g. for vectorization), and needs to be know statically in such cases. + SpecializedType specialized_type = ST_INVALID; }; // Shape inference functions registered on ops in REGISTER_OP implement diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc index 565014d14b1..b564ac144cc 100644 --- a/tensorflow/core/framework/tensor_shape.cc +++ b/tensorflow/core/framework/tensor_shape.cc @@ -276,7 +276,6 @@ void TensorShapeRep::SlowCopyFrom(const TensorShapeRep& b) { // set_ndims_byte(b.ndims_byte()); // set_data_type(b.data_type()); } else { - DCHECK_EQ(b.tag(), REP_OUT_OF_LINE); set_ndims_byte(b.ndims_byte()); set_data_type(b.data_type()); if (tag() == REP_OUT_OF_LINE) { @@ -502,8 +501,8 @@ TensorShapeIter TensorShapeBase::begin() const { template TensorShapeIter TensorShapeBase::end() const { - CHECK(!unknown_rank()); - return TensorShapeIter(static_cast(this), dims()); + const int max_dim = unknown_rank() ? 
-1 : dims(); + return TensorShapeIter(static_cast(this), max_dim); } string TensorShapeRep::DebugString() const { diff --git a/tensorflow/core/framework/tensor_testutil.cc b/tensorflow/core/framework/tensor_testutil.cc index 89a20b9a039..804d5df31ed 100644 --- a/tensorflow/core/framework/tensor_testutil.cc +++ b/tensorflow/core/framework/tensor_testutil.cc @@ -216,6 +216,8 @@ void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) { switch (x.dtype()) { case DT_HALF: return ExpectClose(x, y, atol, rtol); + case DT_BFLOAT16: + return ExpectClose(x, y, atol, rtol); case DT_FLOAT: return ExpectClose(x, y, atol, rtol); case DT_DOUBLE: diff --git a/tensorflow/core/framework/types.proto b/tensorflow/core/framework/types.proto index 900132c0db9..61549ae08ce 100644 --- a/tensorflow/core/framework/types.proto +++ b/tensorflow/core/framework/types.proto @@ -74,3 +74,14 @@ enum DataType { // https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, // https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, // https://www.tensorflow.org/code/tensorflow/python/framework/function.py) + +// For identifying the underlying type of a variant. For variants, the types +// listed here are a subset of the types in the variant type registry, +// corresponding to commonly used variants which must occasionally be +// special-cased. +enum SpecializedType { + // Invalid/unknown specialized type. + ST_INVALID = 0; + // "tensorflow::TensorList" in the variant type registry. + ST_TENSOR_LIST = 1; +} \ No newline at end of file diff --git a/tensorflow/core/graph/BUILD b/tensorflow/core/graph/BUILD index b9461291bfc..e3b5076b9f6 100644 --- a/tensorflow/core/graph/BUILD +++ b/tensorflow/core/graph/BUILD @@ -64,6 +64,7 @@ filegroup( "graph_node_util.h", "node_builder.h", "tensor_id.h", + "types.h", ], ) diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 3d3921d68c0..dedbdb29624 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -617,9 +617,9 @@ class Graph { int num_edge_ids() const { return edges_.size(); } // Returns the Edge associated with an id, or nullptr if no edge - // with that id (the node with that id was removed and the id has + // with that id (the edge with that id was removed and the id has // not yet been re-used). *this owns the returned instance. - // REQUIRES: 0 <= id < num_node_ids(). + // REQUIRES: 0 <= id < num_edge_ids(). const Edge* FindEdgeId(int id) const { return edges_[id]; } // Access to the set of all edges. Example usage: diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD index 8a9b23ff96f..f3bea5c75a3 100644 --- a/tensorflow/core/grappler/BUILD +++ b/tensorflow/core/grappler/BUILD @@ -74,7 +74,7 @@ tf_cuda_library( hdrs = ["devices.h"], cuda_deps = [ "//tensorflow/core/common_runtime/gpu:gpu_init", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", ], visibility = ["//visibility:public"], deps = [ diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 3299e2b4dc6..41910fc9efc 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -854,6 +854,15 @@ class SymbolicShapeRefiner { } } + // ReplaceInputWithConst() may break GraphView's internal node mapping + // structure; hence, we separately build node name to NodeDef* map, for the + // output nodes (before GraphView becomes invalid). 
Note that we use string, + // not string_view. + absl::flat_hash_map output_nodes; + for (const auto& output_arg : grappler_function_item.outputs()) { + output_nodes[output_arg.node_name] = gv.GetNode(output_arg.node_name); + } + // Replace input nodes with Consts, if values are known. Note that // we don't check exceptions here as it's done in the above loop. auto* ctx = GetNodeContext(function_node); @@ -884,11 +893,13 @@ class SymbolicShapeRefiner { &grappler_function_item)); } } + // node_name to NodeDef* map in GraphView gv can be broken due to + // ReplaceInputWithConst(). gv should not be used after this. // Replace output _Retval nodes with Identity nodes. _Retval is a system op // without outputs and registered shape function. for (const auto& output_arg : grappler_function_item.outputs()) { - NodeDef* output_node = gv.GetNode(output_arg.node_name); + NodeDef* output_node = output_nodes[output_arg.node_name]; DCHECK_EQ(output_node->op(), "_Retval"); output_node->set_op("Identity"); output_node->mutable_attr()->erase("index"); @@ -911,12 +922,12 @@ class SymbolicShapeRefiner { // inputs, so port_id >= 0. TensorId out_tensor = ParseTensorName(out_arg.node_name); - const NodeDef* retnode = gv.GetNode(out_tensor.node()); - if (retnode == nullptr) { + if (output_nodes.count(out_tensor.node()) <= 0) { return errors::FailedPrecondition( "Unable to find return function_node ", out_tensor.node(), " for ", function_node->name()); } + const NodeDef* retnode = output_nodes[out_tensor.node()]; auto output_properties = gp.GetOutputProperties(retnode->name()); int output_properties_size = output_properties.size(); diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 9fd43cdee12..5409e6077bc 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -34,6 +34,9 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" +#ifdef INTEL_MKL +#include "tensorflow/core/graph/mkl_graph_util.h" +#endif namespace tensorflow { namespace grappler { @@ -251,9 +254,13 @@ TEST_F(GraphPropertiesTest, DynamicProperties) { EXPECT_EQ(1, prop.shape().dim(1).size()); const auto out_props = properties.GetOutputProperties(node.name()); #ifdef INTEL_MKL - // Intel MKL AddN OP would have two output. - // One is the real output, another one for MKL metadata - EXPECT_EQ(2, out_props.size()); + if (!NativeFormatEnabled()) { + // Intel MKL AddN OP would have two output. 
+ // One is the real output, another one for MKL metadata + EXPECT_EQ(2, out_props.size()); + } else { + EXPECT_EQ(1, out_props.size()); + } #else EXPECT_EQ(1, out_props.size()); #endif // INTEL_MKL diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 69735fdca07..16b62ab3f3a 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -2364,6 +2364,11 @@ Costs OpLevelCostEstimator::PredictResizeBilinear( const OpContext& op_context) const { bool found_unknown_shapes = false; + if (op_context.op_info.outputs().empty() || + op_context.op_info.inputs().empty()) { + return Costs::ZeroCosts(/*inaccurate=*/true); + } + const int64 input_size = CalculateTensorSize(op_context.op_info.inputs(0), &found_unknown_shapes); const int64 output_size = diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc index dd38acbb93c..e38e8d187a1 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc @@ -56,6 +56,13 @@ class TestOpLevelCostEstimator : public OpLevelCostEstimator { DeviceInfo device_info_; }; +void ExpectZeroCost(const Costs& cost) { + EXPECT_TRUE(cost.inaccurate); + EXPECT_EQ(cost.compute_time, Costs::Duration::zero()); + EXPECT_EQ(cost.execution_time, Costs::Duration::zero()); + EXPECT_EQ(cost.memory_time, Costs::Duration::zero()); +} + // Wrangles the minimum number of proto fields to set up a matrix. void DescribeMatrix(int rows, int columns, OpInfo* op_info) { auto input = op_info->add_inputs(); @@ -2082,10 +2089,27 @@ TEST_F(OpLevelCostEstimatorTest, PureMemoryOpExecutionTime) { EXPECT_EQ(cost.persistent_memory, 0); } } + TEST_F(OpLevelCostEstimatorTest, ResizeBilinearExecutionTime) { const int kImageDim = 255; const int kChannelSize = 10; const int kComputeLerpCost = 9; + { + OpContext op_context; + SetCpuDevice(&op_context.op_info); + op_context.op_info.set_op("ResizeBilinear"); + DescribeTensor4D(1, kImageDim, kImageDim, kChannelSize, + op_context.op_info.add_inputs()); + // Test with no output. + auto cost = PredictCosts(op_context); + ExpectZeroCost(cost); + op_context.op_info.clear_inputs(); + + DescribeTensor4D(0, 0, 0, 0, op_context.op_info.add_outputs()); + // Test with no input. + cost = PredictCosts(op_context); + ExpectZeroCost(cost); + } { // Test with size 0 output. 
OpContext op_context; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index c1643bc7bee..a8a337bc3fa 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -361,6 +361,8 @@ std::unique_ptr ReadyNodeManagerFactory( return nullptr; } +SchedulerState::~SchedulerState() {} + SchedulerState::SchedulerState(const bool use_static_shapes, const bool use_aggressive_shape_inference, Cluster* cluster, @@ -1259,15 +1261,23 @@ void SchedulerState::SetNodeStateTimeScheduled(const NodeDef* node) { node_state.time_scheduled = device.GetCurrTime(); } +VirtualScheduler::~VirtualScheduler() {} + VirtualScheduler::VirtualScheduler(const bool use_static_shapes, const bool use_aggressive_shape_inference, Cluster* cluster, ReadyNodeManager* ready_nodes, std::unique_ptr placer) - : scheduler_state_(use_static_shapes, use_aggressive_shape_inference, - cluster, std::move(placer)), + : scheduler_state_(absl::make_unique( + use_static_shapes, use_aggressive_shape_inference, cluster, + std::move(placer))), ready_nodes_(ready_nodes) {} +VirtualScheduler::VirtualScheduler( + ReadyNodeManager* ready_nodes, + std::unique_ptr scheduler_state) + : scheduler_state_(std::move(scheduler_state)), ready_nodes_(ready_nodes) {} + Status VirtualScheduler::Init(const GrapplerItem* item) { // SchedulerState::Init() preprocesses the input grappler_item and // graph_properties to extract necessary information for emulating tensorflow @@ -1275,7 +1285,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) { // DeviceState) for virtual scheduling. TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates())); std::vector initial_nodes; - auto status = scheduler_state_.Init(item, &initial_nodes); + auto status = scheduler_state_->Init(item, &initial_nodes); if (status.ok()) { // Add the set of initial nodes to ready_nodes_ for (auto node : initial_nodes) { @@ -1285,17 +1295,17 @@ Status VirtualScheduler::Init(const GrapplerItem* item) { return status; } -OpContext VirtualScheduler::GetCurrNode() const { +OpContext VirtualScheduler::GetCurrNode() { const NodeDef* node = ready_nodes_->GetCurrNode(); - return scheduler_state_.CreateOpContext(node); + return scheduler_state_->CreateOpContext(node); } bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { // Update graph_costs_ and per-op costs. const NodeDef* node = ready_nodes_->GetCurrNode(); - auto new_nodes = scheduler_state_.MarkNodeExecuted( + auto new_nodes = scheduler_state_->MarkNodeExecuted( node, node_costs, - scheduler_state_.CreateOpContext(ready_nodes_->GetCurrNode())); + scheduler_state_->CreateOpContext(ready_nodes_->GetCurrNode())); ready_nodes_->RemoveCurrNode(); // Add the set of new nodes obtained from MarkNodeExecuted() to ready_nodes_. for (auto node : new_nodes) { diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index 0968d2ae11d..04f1e571ae5 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -324,6 +324,21 @@ class SchedulerState { SchedulerState(const bool use_static_shapes, const bool use_aggressive_shape_inference, Cluster* cluster, std::unique_ptr placer); + // Move constructor. Explicitly defined because it otherwise gets implicitly + // deleted. SchedulerState is a move-only class, as we have a + // for it in VirtualScheduler. 
A derivative of VirtualScheduler can move a + // SchedulerState to VirtualScheduler when it is constructed, + // which is where this move constructor is needed. + SchedulerState(SchedulerState&& arg) = default; + // We explicitly delete the assignment and copy operators; this already + // happens implicitly, but we state it here explicitly for clarity. + SchedulerState& operator=(SchedulerState&& arg) = delete; + SchedulerState(const SchedulerState&) = delete; + SchedulerState& operator=(const SchedulerState&) = delete; + // Destructor. Must be defined such that a derivative class can override it + // and allow proper destruction of the derivative class. If this is not done + // properly, memory leaks can occur. + virtual ~SchedulerState(); // Sets up the graph while also performing some necessary transformations // initial_nodes is the set of nodes (primary inputs) discovered by Init() // which may be added by a ReadyNodeManager (or related/derivative scheduler) @@ -332,12 +347,14 @@ class SchedulerState { std::vector* initial_nodes, bool create_explicit_channel_device = true); - Costs Summary() const; + virtual Costs Summary() const; // Like the above, but writes detailed stats to RunMetadata. // If metadata is nullptr, then just calls and return Summary(). - Costs Summary(RunMetadata* metadata); + virtual Costs Summary(RunMetadata* metadata); // Generates RunMetadata's step_stats and partition_graphs fields from results // of the virtual execution of the graph. + // TODO(rdegruijl) See if we can make this function and caller Summary() + // const. void GenerateRunMetadata(RunMetadata* metadata); // Returns per device memory usage. @@ -438,6 +455,15 @@ class VirtualScheduler { const bool use_aggressive_shape_inference, Cluster* cluster, ReadyNodeManager* ready_nodes, std::unique_ptr placer); + // This constructor can be called by a derivative of VirtualScheduler to + // construct the base class. It lets VirtualScheduler take ownership of + // a new SchedulerState or a derivative thereof. + // Note that this constructor does not set a VirtualPlacer; in this + // constructor the VirtualPlacer is passed as a member of the SchedulerState + // that is passed as an argument. + VirtualScheduler(ReadyNodeManager* ready_nodes, + std::unique_ptr scheduler_state); + virtual ~VirtualScheduler(); // Initializes the scheduler for the specific grappler item. // Should be called immediately after the c'tor or when the scheduler will be @@ -447,51 +473,51 @@ class VirtualScheduler { // This function should be called at least once after the scheduler is // constructed. An uninitialized or failed-to-initialize scheduler will cause // undefined behavior. - Status Init(const GrapplerItem* item); + virtual Status Init(const GrapplerItem* item); // Gets the current scheduled node for execution; the caller of this function // can accordingly simulate the execution of the current scheduled node. - OpContext GetCurrNode() const; + virtual OpContext GetCurrNode(); // Marks the current scheduled node as executed. Note that we should call this // function only after the execution of the node has been simulated; // node_costs_ capture the simulated costs of the node. // Returns true if there is any node to be scheduled. - bool MarkCurrNodeExecuted(const Costs& node_costs); + virtual bool MarkCurrNodeExecuted(const Costs& node_costs); // Prints out summary of execution (timing, memory usage, etc.) 
- Costs Summary() const { return scheduler_state_.Summary(); } + Costs Summary() const { return scheduler_state_->Summary(); } // Like the above, but writes detailed stats to RunMetadata. // If metadata is nullptr, then just calls and return Summary(). Costs Summary(RunMetadata* metadata) { - return scheduler_state_.Summary(metadata); + return scheduler_state_->Summary(metadata); } // Generates RunMetadata's step_stats and partition_graphs fields from results // of the virtual execution of the graph. void GenerateRunMetadata(RunMetadata* metadata) { - scheduler_state_.GenerateRunMetadata(metadata); + scheduler_state_->GenerateRunMetadata(metadata); } // Returns per device memory usage. const std::unordered_map GetPeakMemoryUsage() const { - return scheduler_state_.GetPeakMemoryUsage(); + return scheduler_state_->GetPeakMemoryUsage(); } const std::unordered_map GetPersistentMemoryUsage() const { - return scheduler_state_.GetPersistentMemoryUsage(); + return scheduler_state_->GetPersistentMemoryUsage(); } // Returns VirtualScheduler (read only) device and node states. const std::unordered_map* GetDeviceStates() const { - return scheduler_state_.GetDeviceStates(); + return scheduler_state_->GetDeviceStates(); } const std::unordered_map* GetNodeStates() const { - return scheduler_state_.GetNodeStates(); + return scheduler_state_->GetNodeStates(); } void enable_mem_usage_tracking() { - scheduler_state_.enable_mem_usage_tracking(); + scheduler_state_->enable_mem_usage_tracking(); } - private: + protected: // The state of the scheduler and the execution of the graph is encapsulated // by the scheduler_state_ object. - SchedulerState scheduler_state_; + std::unique_ptr scheduler_state_; // ready_nodes_ is responsible for ordering the traversal of the graph. ReadyNodeManager* ready_nodes_; // Not owned. }; diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc index 94907d2ee6c..9b41df4cab5 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc @@ -965,6 +965,7 @@ class AutoMixedPrecisionImpl { bool NodeImplicitlyReadsNonResourceVariable(const NodeDef& node) const; void ConvertBatchNormOpsToV2(); bool SupportsF16(const NodeTypeId& node_type) const; + bool SupportsF16DataType(const NodeTypeId& node_type) const; const NodeTypeId* GetTensorListFloat32NodeTypeId(const NodeDef& node) const; bool IsSourceOrSinkOp(const string& op) const; void FindFloat32TensorListOpClustersAndDenylistUnsafe( @@ -974,6 +975,7 @@ class AutoMixedPrecisionImpl { const absl::flat_hash_set& tensor_list_nodes, std::vector* implicit_data_edges) const; void AddAllowlistOps(absl::flat_hash_set* allow_set) const; + void RemoveAllowsetWithFp32(absl::flat_hash_set* allow_set) const; void PropagateDenyFwdThroughClearAndInfer( absl::flat_hash_set* deny_set) const; void ForceColorMatchBetweenTensorListOps( @@ -1200,6 +1202,15 @@ bool AutoMixedPrecisionImpl::SupportsF16(const NodeTypeId& node_type) const { NodeHasF16KernelForTypeAttr(*node_type.node, node_type.type_attr); } +bool AutoMixedPrecisionImpl::SupportsF16DataType( + const NodeTypeId& node_type) const { + const OpDef* op_def; + Status status = + OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def); + if (!status.ok()) return false; + return AllowedDataTypes(*op_def, node_type.type_attr).Contains(target_dtype_); +} + // TODO(mconley): Make this change the node's name (to aid debugging). 
Need to // make sure that doing this won't break anything. void AutoMixedPrecisionImpl::ConvertBatchNormOpsToV2() { @@ -1366,6 +1377,12 @@ Status AutoMixedPrecisionImpl::Optimize() { PropagateAllowThroughClear(deny_set, &allow_set); VLOG(2) << "Finished pass 4"; + VLOG(2) << "Beginning pass 5 to remove nodes that cannot be changed " "to F16 " "from the allow set"; + RemoveAllowsetWithFp32(&allow_set); + VLOG(2) << "Finished pass 5"; + VLOG(2) << "Forcing color match between data structure ops"; for (const auto& cluster : tensor_list_clusters) { ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set); @@ -1684,6 +1701,25 @@ void AutoMixedPrecisionImpl::PropagateAllowThroughClear( } } +// If an op has one or more type_attrs that cannot be converted to F16 (e.g. +// FusedBatchNormV2/FusedBatchNormV3, whose type_attr 'U' only supports float), +// remove the node from allow_set. +void AutoMixedPrecisionImpl::RemoveAllowsetWithFp32( + absl::flat_hash_set* allow_set) const { + for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) { + const NodeTypeId& root = *graph_type_view_.GetNode(root_idx); + if (f16_allowlist_.count(root.node->op()) && allow_set->count(root_idx) && + !SupportsF16DataType(root)) { + auto erased = allow_set->erase(root_idx); + if (VLOG_IS_ON(2) && erased) { + VLOG(2) << "UnPainting type " << root.type_attr.DebugString() + << " of node " << root.node->name() << " ALLOW because its op " + << root.node->op() << " does not support the F16 DataType"; + } + } + } +} + // Forces NextIteration nodes and their output Merge node(s) to have the same // color. Specifically, it removes them all from allow_set if any of the Merge // nodes is not in allow_set, otherwise it adds the NextIteration node to diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 6492fb00d5d..db75284a60d 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -17,6 +17,7 @@ cc_library( deps = [ ":autotune_buffer_sizes", ":disable_intra_op_parallelism", + ":disable_prefetch_legacy_autotune", ":enable_gradient_descent", ":filter_fusion", ":filter_with_random_uniform_fusion", @@ -131,6 +132,41 @@ tf_cc_test( ], ) +cc_library( + name = "disable_prefetch_legacy_autotune", + srcs = ["disable_prefetch_legacy_autotune.cc"], + hdrs = ["disable_prefetch_legacy_autotune.h"], + deps = [ + ":graph_utils", + ":optimizer_base", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core:lib_internal", + ] + tf_protos_all(), + alwayslink = 1, +) + +tf_cc_test( + name = "disable_prefetch_legacy_autotune_test", + srcs = ["disable_prefetch_legacy_autotune_test.cc"], + deps = [ + ":disable_prefetch_legacy_autotune", + ":graph_test_utils", + ":graph_utils", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", + ], +) + cc_library( name = "enable_gradient_descent", srcs = ["enable_gradient_descent.cc"], diff --git a/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.cc 
b/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.cc index fc34c63a741..92d1d8911cc 100644 --- a/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.cc +++ b/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.cc @@ -31,7 +31,6 @@ namespace grappler { namespace { constexpr char kLegacyAutotune[] = "legacy_autotune"; -constexpr char kBufferSizeMin[] = "buffer_size_min"; constexpr char kPrefetchDataset[] = "PrefetchDataset"; constexpr std::array kAsyncDatasetOps = { @@ -55,45 +54,12 @@ Status AutotuneBufferSizes::OptimizeAndCollectStats(Cluster* cluster, } MutableGraphView graph(output); - absl::flat_hash_set already_prefetched; - // 1) Collect about all existing `PrefetchDataset` nodes, replacing - // `prefetch(N)` with `prefetch(AUTOTUNE, buffer_size_min=N)` for all N !=-1. - for (NodeDef& node : *(output->mutable_node())) { - if (node.op() == kPrefetchDataset) { - NodeDef* buffer_size_node = graph.GetNode(node.input(1)); - // We only consider to rewrite if `buffer_size` is constant. - if (buffer_size_node->op() == "Const") { - int64 initial_buffer_size = - buffer_size_node->attr().at("value").tensor().int64_val(0); - if (initial_buffer_size != data::model::kAutotune) { - buffer_size_node->mutable_attr() - ->at("value") - .mutable_tensor() - ->set_int64_val(0, data::model::kAutotune); - node.mutable_attr()->at(kBufferSizeMin).set_i(initial_buffer_size); - } - } else { - return errors::FailedPrecondition( - "The autotune_buffer_sizes rewrite does not currently support " - "non-constant buffer_size input."); - } - NodeDef* prefetched_node = graph_utils::GetInputNode(node, graph); - if (prefetched_node) { - already_prefetched.insert(prefetched_node->name()); - } - } - } - std::vector async_datasets; - // 2) Insert `prefetch(AUTOTUNE)` after all asynchronous transformations that - // are not followed by a `prefetch` yet. for (const NodeDef& node : item.graph.node()) { - if (already_prefetched.find(node.name()) != already_prefetched.end()) { - continue; - } for (const auto& async_dataset_op : kAsyncDatasetOps) { if (node.op() == async_dataset_op) { async_datasets.push_back(&node); + stats->num_changes++; break; } } @@ -115,6 +81,7 @@ Status AutotuneBufferSizes::OptimizeAndCollectStats(Cluster* cluster, *prefetch_node.mutable_input()->Add() = async_dataset_node->name(); // `buffer_size` input *prefetch_node.mutable_input()->Add() = autotune_value->name(); + for (const auto& attr_name : {"output_types", "output_shapes"}) { graph_utils::CopyAttribute(attr_name, *async_dataset_node, &prefetch_node); @@ -125,15 +92,6 @@ Status AutotuneBufferSizes::OptimizeAndCollectStats(Cluster* cluster, graph.UpdateFanouts(async_dataset_node->name(), added_node->name())); } - for (NodeDef& node : *output->mutable_node()) { - // 3) Switch from using legacy algorithm to using performance model - // based algorithm for autotuning of all `prefetch` nodes. - if (node.op() == kPrefetchDataset) { - (*node.mutable_attr())[kLegacyAutotune].set_b(false); - stats->num_changes++; - } - } - return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.h b/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.h index 24127132a8e..c0eb43dd6a4 100644 --- a/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.h +++ b/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.h @@ -24,18 +24,8 @@ namespace grappler { constexpr char kAutotune[] = "autotune"; -// This optimization does the following: -// -// 1. 
Adds `prefetch(AUTOTUNE)` after all asynchronous tf.data transformations -// (e.g. parallel map, parallel interleave, and map + batch) if they are not -// followed by a `prefetch` yet. -// -// 2. If there exists any `prefetch(buffer_size=N)` for `N>=0`, it will replace -// the transformation with autotunable version of `prefetch` which uses N as -// the minimum size of the buffer. -// -// 3. Switches from using legacy autotuning for `prefetch` to using an algorithm -// based on the performance model. +// This optimization adds `prefetch(AUTOTUNE)` after all asynchronous tf.data +// transformations (e.g. parallel map, parallel interleave, and map + batch). class AutotuneBufferSizes : public TFDataOptimizerBase { public: AutotuneBufferSizes() = default; diff --git a/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes_test.cc b/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes_test.cc index bd87539ea3c..7b4f9d10a8c 100644 --- a/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes_test.cc +++ b/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes_test.cc @@ -105,7 +105,8 @@ TEST_P(SimpleInject, AutotuneBufferSizesTest) { EXPECT_TRUE(graph_utils::ContainsNodeWithOp("PrefetchDataset", output)); int index = graph_utils::FindGraphNodeWithOp("PrefetchDataset", output); const NodeDef prefetch_node = output.node(index); - EXPECT_FALSE(prefetch_node.attr().at("legacy_autotune").b()); + EXPECT_TRUE(prefetch_node.attr().find("legacy_autotune") == + prefetch_node.attr().end()); EXPECT_EQ(prefetch_node.input_size(), 2); NodeDef async_node = output.node( graph_utils::FindGraphNodeWithName(prefetch_node.input(0), output)); @@ -149,100 +150,6 @@ TEST_P(AutotuneSetting, AutotuneBufferSizesTest) { INSTANTIATE_TEST_SUITE_P(Test, AutotuneSetting, ::testing::Values(false, true)); -class MultipleNodes : public ::testing::TestWithParam> { -}; - -TEST_P(MultipleNodes, AutotuneBufferSizesTest) { - const bool legacy_autotune = std::get<0>(GetParam()); - const int64 initial_buffer_size = std::get<1>(GetParam()); - - GrapplerItem item; - MutableGraphView graph(&item.graph); - - NodeDef *start_val = graph_utils::AddScalarConstNode(0, &graph); - NodeDef *stop_val = graph_utils::AddScalarConstNode(10, &graph); - NodeDef *step_val = graph_utils::AddScalarConstNode(1, &graph); - - std::vector range_inputs(3); - range_inputs[0] = start_val->name(); - range_inputs[1] = stop_val->name(); - range_inputs[2] = step_val->name(); - std::vector> range_attrs; - NodeDef *range_node = graph_utils::AddNode("range", "RangeDataset", - range_inputs, range_attrs, &graph); - - NodeDef *parallelism_val = graph_utils::AddScalarConstNode(1, &graph); - std::vector map_inputs1(2); - map_inputs1[0] = range_node->name(); - map_inputs1[1] = parallelism_val->name(); - std::vector> map_attrs(4); - AttrValue attr_val; - SetAttrValue("value", &attr_val); - map_attrs[0] = std::make_pair("f", attr_val); - map_attrs[1] = std::make_pair("Targuments", attr_val); - map_attrs[2] = std::make_pair("output_types", attr_val); - map_attrs[3] = std::make_pair("output_shapes", attr_val); - NodeDef *map_node1 = graph_utils::AddNode("map1", "ParallelMapDatasetV2", - map_inputs1, map_attrs, &graph); - - NodeDef *buffer_size_val = - graph_utils::AddScalarConstNode(initial_buffer_size, &graph); - std::vector prefetch_inputs(2); - prefetch_inputs[0] = map_node1->name(); - prefetch_inputs[1] = buffer_size_val->name(); - std::vector> prefetch_attrs(2); - AttrValue legacy_autotune_attr; - SetAttrValue(legacy_autotune, 
&legacy_autotune_attr); - prefetch_attrs[0] = std::make_pair("legacy_autotune", legacy_autotune_attr); - AttrValue buffer_size_min_attr; - SetAttrValue(0, &buffer_size_min_attr); - prefetch_attrs[1] = std::make_pair("buffer_size_min", buffer_size_min_attr); - NodeDef *prefetch_node = graph_utils::AddNode( - "prefetch", "PrefetchDataset", prefetch_inputs, prefetch_attrs, &graph); - - std::vector map_inputs2(2); - map_inputs2[0] = prefetch_node->name(); - map_inputs2[1] = parallelism_val->name(); - graph_utils::AddNode("map2", "ParallelMapDatasetV2", map_inputs2, map_attrs, - &graph); - - EXPECT_EQ(item.graph.node_size(), 9); - - GraphDef output; - TF_ASSERT_OK(OptimizeWithAutotuneBufferSizes(item, &output, true)); - EXPECT_EQ(output.node_size(), 11); - - std::vector prefetch_indices = - graph_utils::FindAllGraphNodesWithOp("PrefetchDataset", output); - EXPECT_EQ(prefetch_indices.size(), 2); - NodeDef new_prefetch_node1 = output.node(prefetch_indices[0]); - NodeDef new_prefetch_node2 = output.node(prefetch_indices[1]); - - EXPECT_EQ(new_prefetch_node1.input_size(), 2); - EXPECT_FALSE(new_prefetch_node1.attr().at("legacy_autotune").b()); - EXPECT_EQ(new_prefetch_node1.attr().at("buffer_size_min").i(), - (initial_buffer_size == -1 ? 0 : initial_buffer_size)); - NodeDef new_map_node1 = output.node( - graph_utils::FindGraphNodeWithName(new_prefetch_node1.input(0), output)); - EXPECT_EQ(new_map_node1.name(), "map1"); - NodeDef new_buffer_size_val1 = output.node( - graph_utils::FindGraphNodeWithName(new_prefetch_node1.input(1), output)); - EXPECT_EQ(new_buffer_size_val1.attr().at("value").tensor().int64_val(0), -1); - - EXPECT_EQ(new_prefetch_node2.input_size(), 2); - EXPECT_FALSE(new_prefetch_node2.attr().at("legacy_autotune").b()); - NodeDef new_map_node2 = output.node( - graph_utils::FindGraphNodeWithName(new_prefetch_node2.input(0), output)); - EXPECT_EQ(new_map_node2.name(), "map2"); - NodeDef new_buffer_size_val2 = output.node( - graph_utils::FindGraphNodeWithName(new_prefetch_node2.input(1), output)); - EXPECT_EQ(new_buffer_size_val2.attr().at("value").tensor().int64_val(0), -1); -} - -INSTANTIATE_TEST_SUITE_P(Test, MultipleNodes, - ::testing::Combine(::testing::Values(true, false), - ::testing::Values(-1, 3))); - } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.cc b/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.cc new file mode 100644 index 00000000000..8b0d783c334 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.cc @@ -0,0 +1,75 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.h" + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { +namespace grappler { +namespace { + +constexpr char kLegacyAutotune[] = "legacy_autotune"; +constexpr char kPrefetchDataset[] = "PrefetchDataset"; + +} // namespace + +Status DisablePrefetchLegacyAutotune::OptimizeAndCollectStats( + Cluster* cluster, const GrapplerItem& item, GraphDef* output, + OptimizationStats* stats) { + *output = item.graph; + if (!autotune_) { + VLOG(1) << "The optimization disable_prefetch_legacy_autotune is not " + "applied if autotune is off."; + return Status::OK(); + } + + MutableGraphView graph(output); + + for (NodeDef& node : *output->mutable_node()) { + if (node.op() == kPrefetchDataset) { + if (node.attr().find(kLegacyAutotune) == node.attr().end() || + node.attr().at(kLegacyAutotune).b()) { + // If `legacy_autotune` does not exist as an attr or it is true, set it + // to false. + (*node.mutable_attr())[kLegacyAutotune].set_b(false); + stats->num_changes++; + } + } + } + + return Status::OK(); +} + +void DisablePrefetchLegacyAutotune::Feedback(Cluster* cluster, + const GrapplerItem& item, + const GraphDef& optimize_output, + double result) { + // no-op +} + +REGISTER_GRAPH_OPTIMIZER_AS(DisablePrefetchLegacyAutotune, + "disable_prefetch_legacy_autotune"); + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.h b/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.h new file mode 100644 index 00000000000..5910c2a19e0 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.h @@ -0,0 +1,67 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_DISABLE_PREFETCH_LEGACY_AUTOTUNE_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_DISABLE_PREFETCH_LEGACY_AUTOTUNE_H_ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +constexpr char kAutotune[] = "autotune"; + +// This optimization disables the legacy autotune option for PrefetchDataset. 
+class DisablePrefetchLegacyAutotune : public TFDataOptimizerBase { + public: + DisablePrefetchLegacyAutotune() = default; + ~DisablePrefetchLegacyAutotune() override = default; + + string name() const override { return "disable_prefetch_legacy_autotune"; }; + + bool UsesFunctionLibrary() const override { return false; } + + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + if (!config) return Status::OK(); + + const string& autotune = config->parameter_map().at(kAutotune).s(); + if (autotune == "true") { + autotune_ = true; + } else if (autotune == "false") { + autotune_ = false; + } else { + return errors::InvalidArgument("Received an invalid value for parameter ", + kAutotune, ": ", autotune); + } + return Status::OK(); + } + + Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) override; + + private: + bool autotune_ = true; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_DISABLE_PREFETCH_LEGACY_AUTOTUNE_H_ diff --git a/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune_test.cc b/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune_test.cc new file mode 100644 index 00000000000..849e52f6e1b --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune_test.cc @@ -0,0 +1,91 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.h" + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +using test::function::NDef; + +Status OptimizeWithDisablePrefetchLegacyAutotune(const GrapplerItem &item, + GraphDef *output, + bool autotune) { + DisablePrefetchLegacyAutotune optimizer; + RewriterConfig_CustomGraphOptimizer config; + if (autotune) { + (*config.mutable_parameter_map())["autotune"].set_s("true"); + } else { + (*config.mutable_parameter_map())["autotune"].set_s("false"); + } + TF_RETURN_IF_ERROR(optimizer.Init(&config)); + return optimizer.Optimize(nullptr, item, output); +} + +class RewriteTest : public ::testing::TestWithParam {}; + +TEST_P(RewriteTest, DisablePrefetchLegacyAutotune) { + const bool autotune = GetParam(); + GrapplerItem item; + + item.graph = test::function::GDef({ + NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, + {{"output_shapes", gtl::ArraySlice{}}, + {"output_types", gtl::ArraySlice{}}}), + NDef("prefetch1", "PrefetchDataset", {"range"}, + {{"legacy_autotune", true}}), + NDef("prefetch2", "PrefetchDataset", {"prefetch1"}, + {{"legacy_autotune", false}}), + NDef("prefetch3", "PrefetchDataset", {"prefetch2"}, {}), + }); + + GraphDef output; + TF_ASSERT_OK( + OptimizeWithDisablePrefetchLegacyAutotune(item, &output, autotune)); + + NodeDef prefetch_node1 = + output.node(graph_utils::FindGraphNodeWithName("prefetch1", output)); + EXPECT_EQ(prefetch_node1.attr().at("legacy_autotune").b(), !autotune); + NodeDef prefetch_node2 = + output.node(graph_utils::FindGraphNodeWithName("prefetch2", output)); + EXPECT_FALSE(prefetch_node2.attr().at("legacy_autotune").b()); + NodeDef prefetch_node3 = + output.node(graph_utils::FindGraphNodeWithName("prefetch3", output)); + if (autotune) { + EXPECT_FALSE(prefetch_node3.attr().at("legacy_autotune").b()); + } else { + EXPECT_TRUE(prefetch_node3.attr().find("legacy_autotune") == + prefetch_node3.attr().end()); + } +} + +INSTANTIATE_TEST_SUITE_P(Test, RewriteTest, ::testing::Values(false, true)); + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc b/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc index fa31be8f4cb..4ece16542c8 100644 --- a/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc +++ b/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.cc @@ -31,6 +31,8 @@ namespace { constexpr char kAlgorithm[] = "algorithm"; constexpr char kModelDataset[] = "ModelDataset"; +constexpr char kRetValOp[] = "_Retval"; + constexpr int64 HILL_CLIMB = 0; constexpr int64 GRADIENT_DESCENT = 1; @@ -47,9 +49,21 @@ Status EnableGradientDescent::OptimizeAndCollectStats( } MutableGraphView graph(output); - 
int index = graph_utils::FindGraphNodeWithOp(kModelDataset, *output); + for (const auto& fetch_name : item.fetch) { + // If the GrapplerItem is derived from a FunctionDef, we don't optimize it, + // because we only want to enable gradient descent on the main dataset + // pipeline. + auto fetch = graph.GetNode(fetch_name); + if (fetch == nullptr || fetch->op() == kRetValOp) { + // Heuristic: If the fetch nodes are Retval ops, this item is from a + // function. + return Status::OK(); + } + } + int index = graph_utils::FindGraphNodeWithOp(kModelDataset, *output); NodeDef& model_node = *(output->mutable_node(index)); + if (model_node.attr().at(kAlgorithm).i() == HILL_CLIMB) { (*model_node.mutable_attr())[kAlgorithm].set_i(GRADIENT_DESCENT); stats->num_changes++; diff --git a/tensorflow/core/grappler/optimizers/data/enable_gradient_descent_test.cc b/tensorflow/core/grappler/optimizers/data/enable_gradient_descent_test.cc index 86d66f386ab..c623c53f0b1 100644 --- a/tensorflow/core/grappler/optimizers/data/enable_gradient_descent_test.cc +++ b/tensorflow/core/grappler/optimizers/data/enable_gradient_descent_test.cc @@ -41,12 +41,13 @@ Status OptimizeWithEnableGradientDescent(const GrapplerItem &item, return optimizer.Optimize(nullptr, item, output); } -class SimpleRewrite : public ::testing::TestWithParam> { -}; +class SimpleRewrite + : public ::testing::TestWithParam> {}; TEST_P(SimpleRewrite, EnableGradientDescentTest) { const bool autotune = std::get<0>(GetParam()); const int64 algorithm_index = std::get<1>(GetParam()); + const string op = std::get<2>(GetParam()); using test::function::NDef; GrapplerItem item; @@ -58,7 +59,9 @@ TEST_P(SimpleRewrite, EnableGradientDescentTest) { NDef("batch_size", "Const", {}, {{"value", 5}, {"dtype", DT_INT32}}), NDef("batch", "BatchDataset", {"range", "batch_size"}, {}), NDef("model", "ModelDataset", {"batch"}, - {{"algorithm", algorithm_index}})}); + {{"algorithm", algorithm_index}}), + NDef("Sink", op, {"model"}, {})}); + item.fetch.push_back("Sink"); GraphDef output; TF_ASSERT_OK(OptimizeWithEnableGradientDescent(item, &output, autotune)); @@ -67,12 +70,13 @@ TEST_P(SimpleRewrite, EnableGradientDescentTest) { NodeDef model_node = output.node(graph_utils::FindGraphNodeWithName("model", output)); EXPECT_EQ(model_node.attr().at("algorithm").i(), - autotune ? 1 : algorithm_index); + (autotune && op != "_Retval") ? 1 : algorithm_index); } -INSTANTIATE_TEST_SUITE_P(Test, SimpleRewrite, - ::testing::Combine(::testing::Values(false, true), - ::testing::Values(0, 1))); +INSTANTIATE_TEST_SUITE_P( + Test, SimpleRewrite, + ::testing::Combine(::testing::Values(false, true), ::testing::Values(0, 1), + ::testing::Values("Identity", "_Retval"))); } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc index cefa71f8944..cd46a7356ac 100644 --- a/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/data/meta_optimizer.cc @@ -35,7 +35,7 @@ using ConfigMap = std::map; // tf.data optimizations, in the order we want to perform them. 
-constexpr std::array kTFDataOptimizations = { +constexpr std::array kTFDataOptimizations = { "noop_elimination", "disable_intra_op_parallelism", "shuffle_and_repeat_fusion", @@ -53,6 +53,7 @@ constexpr std::array kTFDataOptimizations = { "reorder_data_discarding_ops", "slack", "autotune_buffer_sizes", + "disable_prefetch_legacy_autotune", "enable_gradient_descent"}; // Parses a list of string optimizer configurations into a map from diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc index 8307b37407e..0a2b308af63 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc @@ -81,6 +81,7 @@ constexpr int kDepthOut = 16; { 0, 3, 1, 2 } #endif // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM) +template Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size, const string& padding, const string& device) { int batch_size = 8; @@ -91,15 +92,15 @@ Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size, int stride = 1; TensorShape input_shape( DIMS(batch_size, input_height, input_width, input_depth)); - Tensor input_data(DT_FLOAT, input_shape); - test::FillIota(&input_data, 1.0f); + Tensor input_data(DataTypeToEnum::value, input_shape); + test::FillIota(&input_data, static_cast(1)); Output input = ops::Const(s->WithOpName("Input"), Input::Initializer(input_data)); TensorShape filter_shape( {filter_size, filter_size, input_depth, filter_count}); - Tensor filter_data(DT_FLOAT, filter_shape); - test::FillIota(&filter_data, 1.0f); + Tensor filter_data(DataTypeToEnum::value, filter_shape); + test::FillIota(&filter_data, static_cast(1)); Output filter = ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data)); @@ -356,6 +357,25 @@ TEST_F(GenericLayoutOptimizerTest, CPUDevice) { #endif // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM) } +TEST_F(GenericLayoutOptimizerTest, NoOptimizeIntegerConvolution) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 4, 2, "VALID", ""); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); + GrapplerItem item; + TF_ASSERT_OK(s.ToGraphDef(&item.graph)); + + GenericLayoutOptimizer optimizer(REWRITER_CONFIG); + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); + + Status status; + utils::GraphView graph_view(&output, &status); + TF_ASSERT_OK(status); + auto* conv_node = graph_view.GetNode("Conv2D"); + ASSERT_NE(conv_node, nullptr); + VerifyDataFormatAttributeMatch(conv_node, SRC_DATA_FORMAT); +} + TEST_F(GenericLayoutOptimizerTest, Connectivity) { Scope scope = Scope::NewRootScope(); auto conv = SimpleConv2D(&scope, 4, 2, "VALID", diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc index c425ef51c2f..ef68b7e7898 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc @@ -30,11 +30,13 @@ limitations under the License. 
#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/frame.h" +#include "tensorflow/core/grappler/utils/graph_view.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/protobuf/device_properties.pb.h" #include "tensorflow/core/util/device_name_utils.h" @@ -83,6 +85,16 @@ inline bool AttrDataFormatMatch(const utils::MutableNodeView& node, return AttrDataFormatMatch(node, src_data_format, &missing); } +bool IsNonFloatingConv2D(const utils::MutableNodeView& node) { + if (IsConv2D(*node.node()) || IsConv2DBackpropInput(*node.node())) { + const auto* attr = node.GetAttr(kAttrT); + if (attr != nullptr) { + return !kDataTypeIsFloating.Contains(attr->type()); + } + } + return false; +} + // Utils for layout agnostic transposer. bool IsComparisonOp(const NodeDef& node) { @@ -155,6 +167,55 @@ std::vector GetDimensionIndicesFromLabel( return indices; } +// RAII-styled object for keeping track of 4D to 5D data format +// upgrade/conversion. Currently only NHWC -> NDHWC and NCHW -> NCDHW are +// supported. +class ScopedDataFormatUpgrader { + public: + ScopedDataFormatUpgrader(TransposeContext* context, int rank) + : context_(context) { + if (rank == 5 && IsSupportedDataFormat(context_->src_format) && + IsSupportedDataFormat(context_->dst_format)) { + old_src_format_ = context_->src_format; + old_dst_format_ = context_->dst_format; + std::string new_src_format = GetUpgradedDataFormat(context_->src_format); + std::string new_dst_format = GetUpgradedDataFormat(context_->dst_format); + context_->AssignDeviceAndDataFormats(context_->target_device, + new_src_format, new_dst_format); + upgraded_ = true; + } + } + + ScopedDataFormatUpgrader(const ScopedDataFormatUpgrader&) = delete; + ScopedDataFormatUpgrader& operator=(const ScopedDataFormatUpgrader&) = delete; + + ~ScopedDataFormatUpgrader() { + if (upgraded_) { + context_->AssignDeviceAndDataFormats(context_->target_device, + old_src_format_, old_dst_format_); + } + } + + private: + bool IsSupportedDataFormat(absl::string_view data_format) { + return data_format == "NHWC" || data_format == "NCHW"; + } + + std::string GetUpgradedDataFormat(absl::string_view data_format) { + if (data_format == "NHWC") { + return "NDHWC"; + } + + DCHECK_EQ(data_format, "NCHW"); + return "NCDHW"; + } + + TransposeContext* context_ = nullptr; + bool upgraded_ = false; + std::string old_src_format_; + std::string old_dst_format_; +}; + } // namespace // TransposeContext. @@ -205,15 +266,19 @@ bool Transposer::ShouldProcess(const TransposeContext& context, GetDeviceName(context.virtual_placer.get(), *node_def); string device; string task; - bool is_on_target_device = + const bool is_on_target_device = DeviceNameUtils::SplitDeviceName(device_name, &task, &device) && absl::StrContains(absl::AsciiStrToLower(device), absl::AsciiStrToLower(context.target_device)); // Only checks data format for layout sensitive op. 
- bool data_format_match = !IsLayoutSensitiveOp(*node_def) || - AttrDataFormatMatch(node, context.src_format); - return is_on_target_device && data_format_match && + const bool data_format_match = !IsLayoutSensitiveOp(*node_def) || + AttrDataFormatMatch(node, context.src_format); + + // Only transposes floating point nodes. + const bool is_integer_conv2d = IsNonFloatingConv2D(node); + + return is_on_target_device && data_format_match && !is_integer_conv2d && !context.nodes_to_preserve.contains(node_def->name()) && !(node.NumRegularFanouts() == 0 && node.NumControlledFanouts() == 0); } @@ -654,7 +719,11 @@ Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context, Status DefaultLayoutSensitiveOpTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsDefaultLayoutSensitiveOp(*node->node())); - if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) { + const auto* output_shape_attr = node->GetAttr(kAttrOutputShape); + const auto& shape = output_shape_attr->list().shape(0); + const int rank = shape.dim_size(); + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) { return Status::OK(); } VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() @@ -755,13 +824,6 @@ Status Conv2DBackpropInputTransposer::TransposeNode( Status Conv3DTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsConv3D(*node->node())); - // Update the format from 4D to 5D layout. - std::string src_format = context->src_format; - std::string dst_format = context->dst_format; - std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW"; - std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW"; - context->AssignDeviceAndDataFormats(context->target_device, src_format_3d, - dst_format_3d); if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 5)) { return Status::OK(); } @@ -771,22 +833,12 @@ Status Conv3DTransposer::TransposeNode(TransposeContext* context, TF_RETURN_IF_ERROR(UpdateNode(context, node)); TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); - // Change back the format from 5D to 4D layout. - context->AssignDeviceAndDataFormats(context->target_device, src_format, - dst_format); return context->graph_view->GetMutationBuilder()->Apply(); } Status Conv3DBackpropFilterTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsConv3DBackpropFilterV2(*node->node())); - // Update the format from 4D to 5D layout. - std::string src_format = context->src_format; - std::string dst_format = context->dst_format; - std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW"; - std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW"; - context->AssignDeviceAndDataFormats(context->target_device, src_format_3d, - dst_format_3d); if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 5)) { return Status::OK(); } @@ -799,22 +851,12 @@ Status Conv3DBackpropFilterTransposer::TransposeNode( // No need to update output shape, as it is always of shape // [filter_height, filter_width, in_channels, out_channels], regardless of // whether NCHW or NHWC is used. - // Change back the format from 5D to 4D layout. 
- context->AssignDeviceAndDataFormats(context->target_device, src_format, - dst_format); return context->graph_view->GetMutationBuilder()->Apply(); } Status Conv3DBackpropInputTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsConv3DBackpropInputV2(*node->node())); - // Update the format from 4D to 5D layout. - std::string src_format = context->src_format; - std::string dst_format = context->dst_format; - std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW"; - std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW"; - context->AssignDeviceAndDataFormats(context->target_device, src_format_3d, - dst_format_3d); if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 5)) { return Status::OK(); } @@ -826,9 +868,6 @@ Status Conv3DBackpropInputTransposer::TransposeNode( UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute)); TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose)); TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); - // Change back the format from 5D to 4D layout. - context->AssignDeviceAndDataFormats(context->target_device, src_format, - dst_format); return context->graph_view->GetMutationBuilder()->Apply(); } @@ -865,7 +904,11 @@ bool FusedBatchNormGradTransposer::IsTraining( Status FusedBatchNormGradTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsFusedBatchNormGrad(*node->node())); - if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) || + const auto* output_shape_attr = node->GetAttr(kAttrOutputShape); + const auto& shape = output_shape_attr->list().shape(0); + const int rank = shape.dim_size(); + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank) || !IsTraining(*node)) { return Status::OK(); } @@ -1046,10 +1089,20 @@ std::vector LayoutAgnosticOpTransposer::GetVariadic4DFaninPorts( Status DefaultLayoutAgnosticOpTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsDefaultLayoutAgnosticOp(*node->node())); - if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) || + const auto* output_shape_attr = node->GetAttr(kAttrOutputShape); + const auto& shape = output_shape_attr->list().shape(0); + const int rank = shape.dim_size(); + if (rank != 4 && rank != 5) { + return Status::OK(); + } + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + if (!ShouldProcess(*context, *node) || !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } + VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() + << "' with op '" << node->GetOp() << "' from data format '" + << context->src_format << "' to '" << context->dst_format << "'"; TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose)); TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); @@ -1074,19 +1127,20 @@ bool BinaryOpTransposer::IsNDOperateWithMD(const utils::MutableNodeView& node, } bool BinaryOpTransposer::IsFaninShapeSupported( - const utils::MutableNodeView& node) { - return (IsNDOperateWithMD(node, 4, 0) || IsNDOperateWithMD(node, 4, 1) || - IsNDOperateWithMD(node, 4, 4) || IsNDOperateWithMD(node, 0, 4) || - IsNDOperateWithMD(node, 1, 4)); + const utils::MutableNodeView& node, int rank) { + return (IsNDOperateWithMD(node, rank, 0) || 
+ IsNDOperateWithMD(node, rank, 1) || + IsNDOperateWithMD(node, rank, rank) || + IsNDOperateWithMD(node, 0, rank) || IsNDOperateWithMD(node, 1, rank)); } -std::vector BinaryOpTransposer::Get4DDataFaninPorts( - const utils::MutableNodeView& node) { +std::vector BinaryOpTransposer::GetNDDataFaninPorts( + const utils::MutableNodeView& node, int rank) { std::vector values; - if (IsFaninPortRankN(node, 0, 4)) { + if (IsFaninPortRankN(node, 0, rank)) { values.push_back(0); } - if (IsFaninPortRankN(node, 1, 4)) { + if (IsFaninPortRankN(node, 1, rank)) { values.push_back(1); } return values; @@ -1116,12 +1170,10 @@ Status BinaryOpTransposer::AddNodeReshape( return status; } -Status BinaryOpTransposer::AddNodeShapeConst(utils::Mutation* mutation, - absl::string_view node_name, - absl::string_view node_device, - bool node_in_frame, - int num_channels, - absl::string_view depended_node) { +Status BinaryOpTransposer::AddNodeShapeConst( + utils::Mutation* mutation, absl::string_view node_name, + absl::string_view node_device, bool node_in_frame, int num_channels, + absl::string_view depended_node, int rank) { NodeDef new_node; new_node.set_name(string(node_name)); new_node.set_op(kOpConst); @@ -1131,8 +1183,9 @@ Status BinaryOpTransposer::AddNodeShapeConst(utils::Mutation* mutation, new_node.mutable_attr()->insert({"dtype", attr_data_type}); AttrValue attr_tensor; - Tensor tensor(DT_INT32, TensorShape({4})); - std::vector shape = {1, num_channels, 1, 1}; + Tensor tensor(DT_INT32, TensorShape({rank})); + std::vector shape(rank, 1); + shape[1] = num_channels; for (int i = 0; i < static_cast(shape.size()); i++) { tensor.flat()(i) = shape[i]; } @@ -1150,12 +1203,13 @@ Status BinaryOpTransposer::AddNodeShapeConst(utils::Mutation* mutation, return status; } -Status BinaryOpTransposer::MaybeReshapeVectorFanin( - TransposeContext* context, utils::MutableNodeView* node) { +Status BinaryOpTransposer::MaybeReshapeVectorFanin(TransposeContext* context, + utils::MutableNodeView* node, + int rank) { int vector_index = -1; - if (IsNDOperateWithMD(*node, 4, 1)) { + if (IsNDOperateWithMD(*node, rank, 1)) { vector_index = 1; - } else if (IsNDOperateWithMD(*node, 1, 4)) { + } else if (IsNDOperateWithMD(*node, 1, rank)) { vector_index = 0; } if (vector_index != -1) { @@ -1177,7 +1231,7 @@ Status BinaryOpTransposer::MaybeReshapeVectorFanin( TF_RETURN_IF_ERROR( AddNodeShapeConst(mutation, shape_const_node_name, node_device, context->frames.IsInFrame(*node->node()), vector_size, - fanin_node->GetName())); + fanin_node->GetName(), rank)); const auto* t_attr = node->GetAttr(kAttrT); if (t_attr == nullptr) { return errors::InvalidArgument("Missing attribute ", kAttrT); @@ -1195,13 +1249,20 @@ Status BinaryOpTransposer::MaybeReshapeVectorFanin( Status BinaryOpTransposer::TransposeNode(TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsBinaryOp(*node->node())); - if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node) || + const auto* output_shape_attr = node->GetAttr(kAttrOutputShape); + const auto& shape = output_shape_attr->list().shape(0); + const int rank = shape.dim_size(); + ScopedDataFormatUpgrader data_format_upgrader(context, rank); + if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node, rank) || !IsAfterDstToSrcTransform(*context, *node)) { return Status::OK(); } - TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, Get4DDataFaninPorts(*node), - node, kOpTranspose)); - TF_RETURN_IF_ERROR(MaybeReshapeVectorFanin(context, node)); + VLOG(3) << "GenericLayoutOptimizer: 
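AddNodeShapeConst is generalized above from a hard-coded {1, num_channels, 1, 1} constant to a rank-parameterized one. A tiny standalone illustration of the resulting shape vector, using plain C++ rather than the grappler mutation API:

// All ones except the channel dimension (index 1), matching the generalized
// `std::vector<int> shape(rank, 1); shape[1] = num_channels;` in the hunk above.
#include <iostream>
#include <vector>

std::vector<int> ChannelBroadcastShape(int rank, int num_channels) {
  std::vector<int> shape(rank, 1);
  shape[1] = num_channels;   // channel dimension in NCHW / NCDHW layouts
  return shape;
}

int main() {
  for (int v : ChannelBroadcastShape(4, 64)) std::cout << v << ' ';  // 1 64 1 1
  std::cout << '\n';
  for (int v : ChannelBroadcastShape(5, 64)) std::cout << v << ' ';  // 1 64 1 1 1
  std::cout << '\n';
}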
transforming node '" << node->GetName() + << "' with op '" << node->GetOp() << "' from data format '" + << context->src_format << "' to '" << context->dst_format << "'"; + TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp( + context, GetNDDataFaninPorts(*node, rank), node, kOpTranspose)); + TF_RETURN_IF_ERROR(MaybeReshapeVectorFanin(context, node, rank)); TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); } @@ -1376,25 +1437,10 @@ Status ReduceTransposer::TransposeNode(TransposeContext* context, regular_fanin.node_view()->GetAttr(kAttrOutputShape); const auto& shape = output_shape_attr->list().shape(0); const int rank = shape.dim_size(); - std::string src_format = context->src_format; - std::string dst_format = context->dst_format; - // Update the format from 4D to 5D layout if necessary. - bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW") && - (dst_format == "NHWC" || dst_format == "NCHW"); - if (allow_5d) { - std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW"; - std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW"; - context->AssignDeviceAndDataFormats(context->target_device, src_format_3d, - dst_format_3d); - } + ScopedDataFormatUpgrader data_format_upgrader(context, rank); if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, rank) || !IsReduceAxisSupported(*context, *node) || !IsAfterDstToSrcTransform(*context, *node)) { - // Change back to the original layout due to early exit. - if (allow_5d) { - context->AssignDeviceAndDataFormats(context->target_device, src_format, - dst_format); - } return Status::OK(); } VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() @@ -1407,11 +1453,6 @@ Status ReduceTransposer::TransposeNode(TransposeContext* context, TF_RETURN_IF_ERROR( UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); } - // Change back the format from 5D to 4D layout. 
- if (allow_5d) { - context->AssignDeviceAndDataFormats(context->target_device, src_format, - dst_format); - } return context->graph_view->GetMutationBuilder()->Apply(); } diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h index 61720df791b..11a223ee097 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h @@ -347,19 +347,21 @@ class BinaryOpTransposer : public LayoutAgnosticOpTransposer { private: bool IsNDOperateWithMD(const utils::MutableNodeView& node, int n, int m); - bool IsFaninShapeSupported(const utils::MutableNodeView& node); - std::vector Get4DDataFaninPorts(const utils::MutableNodeView& node); + bool IsFaninShapeSupported(const utils::MutableNodeView& node, int rank); + std::vector GetNDDataFaninPorts(const utils::MutableNodeView& node, + int rank); Status AddNodeShapeConst(utils::Mutation* mutation, absl::string_view node_name, absl::string_view node_device, bool node_in_frame, - int num_channels, absl::string_view depended_node); + int num_channels, absl::string_view depended_node, + int rank); Status AddNodeReshape(utils::Mutation* mutation, absl::string_view node_name, absl::string_view node_device, absl::string_view input_name, absl::string_view shape_const_node_name, const DataType& data_type); Status MaybeReshapeVectorFanin(TransposeContext* context, - utils::MutableNodeView* node); + utils::MutableNodeView* node, int rank); }; class ConcatOpTransposer : public LayoutAgnosticOpTransposer { diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index 6f88f81548b..0bd683b7b64 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -1100,7 +1100,8 @@ class Conv2DProcessor : public NodeProcessor { protected: bool ShouldProcess() const override { return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) && - HasOutputs() && (!IsGemmUsed() || no_gemm_) && IsOnGPU(); + HasOutputs() && (!IsGemmUsed() || no_gemm_) && IsOnGPU() && + IsDataTypeFloat(); } TensorShapeProto GetShape(const string& input_name) const { @@ -1131,6 +1132,13 @@ class Conv2DProcessor : public NodeProcessor { return false; } + bool IsDataTypeFloat() const { + if (node_->attr().find("T") != node_->attr().end()) { + return kDataTypeIsFloating.Contains(node_->attr().at("T").type()); + } + return false; + } + // The logic inside this function is based on the internal implementation of // Conv2D, Conv2DBackpropInput, and Conv2DBackpropFilter ops, and thus // needs to be updated accordingly if the internal implementation changes. diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc index c92693adef4..f9c9747fd8c 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc @@ -20,6 +20,7 @@ limitations under the License. 
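Conv2DProcessor now refuses to rewrite a node whose "T" attribute is not a floating-point type, which the new DataTypeIsInt32 test below exercises. A minimal sketch of that gate, with stand-in enum values rather than TensorFlow's DataType and kDataTypeIsFloating:

#include <map>
#include <set>
#include <string>

enum FakeDataType { DT_FLOAT_, DT_HALF_, DT_BFLOAT16_, DT_DOUBLE_, DT_INT32_ };

bool IsDataTypeFloatSketch(const std::map<std::string, FakeDataType>& attrs) {
  static const std::set<FakeDataType> kFloating = {DT_FLOAT_, DT_HALF_,
                                                   DT_BFLOAT16_, DT_DOUBLE_};
  auto it = attrs.find("T");
  if (it == attrs.end()) return false;   // no "T" attr: do not rewrite the layout
  return kFloating.count(it->second) > 0;
}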
#include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/grappler/clusters/single_machine.h" #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/costs/virtual_placer.h" @@ -56,11 +57,13 @@ class LayoutOptimizerTest : public GrapplerTest { void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); } + template Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size, const string& padding) { - return SimpleConv2D(s, input_size, filter_size, padding, ""); + return SimpleConv2D(s, input_size, filter_size, padding, ""); } + template Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size, const string& padding, const string& device) { int batch_size = 8; @@ -71,15 +74,15 @@ class LayoutOptimizerTest : public GrapplerTest { int stride = 1; TensorShape input_shape( {batch_size, input_height, input_width, input_depth}); - Tensor input_data(DT_FLOAT, input_shape); - test::FillIota(&input_data, 1.0f); + Tensor input_data(DataTypeToEnum::value, input_shape); + test::FillIota(&input_data, static_cast(1)); Output input = ops::Const(s->WithOpName("Input"), Input::Initializer(input_data)); TensorShape filter_shape( {filter_size, filter_size, input_depth, filter_count}); - Tensor filter_data(DT_FLOAT, filter_shape); - test::FillIota(&filter_data, 1.0f); + Tensor filter_data(DataTypeToEnum::value, filter_shape); + test::FillIota(&filter_data, static_cast(1)); Output filter = ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data)); @@ -359,6 +362,20 @@ TEST_F(LayoutOptimizerTest, ExplicitPadding) { EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer")); } +TEST_F(LayoutOptimizerTest, DataTypeIsInt32) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 4, 2, "EXPLICIT"); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map(&output); + EXPECT_FALSE( + node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer")); +} + TEST_F(LayoutOptimizerTest, Pad) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); auto conv = SimpleConv2D(&s, 4, 2, "VALID"); diff --git a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc index f5cc44b1464..f8574a4e0d3 100644 --- a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc +++ b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc @@ -175,7 +175,9 @@ class MklRemapperTest : public GrapplerTest { auto tensors = EvaluateNodes(output, item.fetch, item.feed); EXPECT_EQ(1, tensors_expected.size()); EXPECT_EQ(1, tensors.size()); - test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); + // Using relative tolerance since oneDNN could produce different results + // when float32 numbers need to be rounded during accumulation. 
+ test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6); } }; diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 115428ff5ef..e8dd1a68bb3 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -838,12 +838,6 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index, #ifndef ENABLE_MKLDNN_V1 // We fuse FusedBatchNorm on GPU or MKL CPU. if (!NodeIsOnGpu(fused_batch_norm_node_def)) return false; -#else - if (NativeFormatEnabled()) { - // Temporarily disable FusedBatchNorm fusion on CPU until - // we support it under native format mode - if (!NodeIsOnGpu(fused_batch_norm_node_def)) return false; - } #endif DataType t_dtype = GetDataTypeFromAttr(*fused_batch_norm_node_def, "T"); @@ -1438,29 +1432,41 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) { utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); Status status; - if (fused_node.attr().at(kDataFormat).s() == "NCHW") { + string x_format = fused_node.attr().at(kDataFormat).s(); + if (x_format == "NCHW" or x_format == "NCDHW") { // Need to reshape the last 4 inputs NodeDef new_shape; const string new_shape_name = - AddPrefixToNodeName("NCHWShape", fused_node.name()); + AddPrefixToNodeName(x_format + "Shape", fused_node.name()); new_shape.set_name(new_shape_name); new_shape.set_op("Const"); new_shape.set_device(fused_node.device()); *new_shape.add_input() = AsControlDependency(scale); (*new_shape.mutable_attr())["dtype"].set_type(DT_INT32); - Tensor t(DT_INT32, {4}); - t.flat()(0) = 1; - t.flat()(1) = -1; - t.flat()(2) = 1; - t.flat()(3) = 1; - t.AsProtoTensorContent( - (*new_shape.mutable_attr())["value"].mutable_tensor()); + if (x_format == "NCHW") { + Tensor t(DT_INT32, {4}); + t.flat()(0) = 1; + t.flat()(1) = -1; + t.flat()(2) = 1; + t.flat()(3) = 1; + t.AsProtoTensorContent( + (*new_shape.mutable_attr())["value"].mutable_tensor()); + } else { + Tensor t(DT_INT32, {5}); + t.flat()(0) = 1; + t.flat()(1) = -1; + t.flat()(2) = 1; + t.flat()(3) = 1; + t.flat()(4) = 1; + t.AsProtoTensorContent( + (*new_shape.mutable_attr())["value"].mutable_tensor()); + } mutation->AddNode(std::move(new_shape), &status); TF_RETURN_IF_ERROR(status); NodeDef reshaped_scale; reshaped_scale.set_name( - AddPrefixToNodeName("NCHWShapedScale", fused_node.name())); + AddPrefixToNodeName(x_format + "ShapedScale", fused_node.name())); reshaped_scale.set_op("Reshape"); reshaped_scale.set_device(fused_node.device()); *reshaped_scale.add_input() = scale; @@ -1473,7 +1479,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) { NodeDef reshaped_offset; reshaped_offset.set_name( - AddPrefixToNodeName("NCHWShapedOffset", fused_node.name())); + AddPrefixToNodeName(x_format + "ShapedOffset", fused_node.name())); reshaped_offset.set_op("Reshape"); reshaped_offset.set_device(fused_node.device()); *reshaped_offset.add_input() = offset; @@ -1486,7 +1492,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) { NodeDef reshaped_mean; reshaped_mean.set_name( - AddPrefixToNodeName("NCHWShapedMean", fused_node.name())); + AddPrefixToNodeName(x_format + "ShapedMean", fused_node.name())); reshaped_mean.set_op("Reshape"); reshaped_mean.set_device(fused_node.device()); *reshaped_mean.add_input() = mean; @@ -1499,7 +1505,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) { NodeDef reshaped_variance; 
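AddBatchNormNodes now builds its reshape constant for either NCHW or NCDHW inputs. The values are visible in the hunk above ({1,-1,1,1} versus {1,-1,1,1,1}); the helper below is just a compact restatement of that choice in standalone C++, not the remapper code itself:

#include <string>
#include <vector>

std::vector<int> PerChannelReshape(const std::string& x_format) {
  // Rank 4 for "NCHW", rank 5 for "NCDHW"; -1 lets Reshape infer the channel count,
  // so scale/offset/mean/variance vectors broadcast against the activations.
  const int rank = (x_format == "NCDHW") ? 5 : 4;
  std::vector<int> shape(rank, 1);
  shape[1] = -1;
  return shape;  // {1,-1,1,1} or {1,-1,1,1,1}
}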
reshaped_variance.set_name( - AddPrefixToNodeName("NCHWShapedVariance", fused_node.name())); + AddPrefixToNodeName(x_format + "ShapedVariance", fused_node.name())); reshaped_variance.set_op("Reshape"); reshaped_variance.set_device(fused_node.device()); *reshaped_variance.add_input() = variance; diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc index 9f5a2c4716a..2aa564104db 100644 --- a/tensorflow/core/grappler/optimizers/remapper_test.cc +++ b/tensorflow/core/grappler/optimizers/remapper_test.cc @@ -451,8 +451,10 @@ class RemapperFuseMatMulWithBiasTest : public RemapperTest { ASSERT_EQ(tensors_expected.size(), 1); auto tensors = EvaluateNodes(output, item.fetch, item.feed); ASSERT_EQ(tensors.size(), 1); - typedef typename EnumToDataType::Type T; - test::ExpectTensorNear(tensors[0], tensors_expected[0], 1e-6); + if (DTYPE == DT_BFLOAT16) + test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2); + else + test::ExpectClose(tensors[0], tensors_expected[0], 1e-6); } }; @@ -704,8 +706,10 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest { ASSERT_EQ(tensors_expected.size(), 1); auto tensors = EvaluateNodes(output, item.fetch, item.feed); ASSERT_EQ(tensors.size(), 1); - typedef typename EnumToDataType::Type T; - test::ExpectTensorNear(tensors[0], tensors_expected[0], 1e-6); + if (DTYPE == DT_BFLOAT16) + test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2); + else + test::ExpectClose(tensors[0], tensors_expected[0], 1e-6); } } }; diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index a3cb0096489..6c045152434 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -19,6 +19,9 @@ load("//tensorflow/core/kernels/mlir_generated:build_defs.bzl", "if_mlir_generat # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "cc_header_only_library") +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") + # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_kernel_library") @@ -511,10 +514,10 @@ tf_cuda_library( hdrs = ["gpu_utils.h"], deps = [ ":gpu_util_hdrs", - "//tensorflow/core:autotuning_proto_cc", - "//tensorflow/core:conv_autotuning_proto_cc", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", + "//tensorflow/core/protobuf:autotuning_proto_cc", + "//tensorflow/core/protobuf:conv_autotuning_proto_cc", "//tensorflow/core/util:env_var", "//tensorflow/core/util/proto:proto_utils", "//tensorflow/stream_executor/gpu:asm_compiler", @@ -556,11 +559,11 @@ tf_cc_test( deps = [ "//tensorflow/core:all_kernels", "//tensorflow/core:core_cpu", - "//tensorflow/core:direct_session_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/common_runtime:direct_session_internal", ], ) @@ -712,7 +715,7 @@ cc_library( alias( name = "bounds_check", - actual = "//tensorflow/core:framework_bounds_check", + actual = "//tensorflow/core/framework:bounds_check", visibility = [":friends"], ) @@ -890,6 +893,7 @@ cc_library( hdrs = [ "eigen_spatial_convolutions-inl.h", ], + compatible_with = get_compatible_with_portable(), deps = [ ":eigen_convolution_helpers", ], @@ -900,6 +904,7 @@ cc_library( hdrs = [ "eigen_convolution_helpers.h", ], + compatible_with = get_compatible_with_portable(), ) # OpKernel libraries 
---------------------------------------------------------- @@ -1100,7 +1105,7 @@ tf_kernel_library( tf_kernel_library( name = "one_hot_op", prefix = "one_hot_op", - deps = ARRAY_DEPS + ["//tensorflow/core:overflow"], + deps = ARRAY_DEPS + ["//tensorflow/core/util:overflow"], ) tf_kernel_library( @@ -1491,7 +1496,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", "//third_party/eigen3", ], ) @@ -1647,10 +1652,6 @@ tf_cuda_cc_test( name = "conv_ops_test", size = "medium", srcs = ["conv_ops_test.cc"], - tags = [ - "no_cuda11", # b/159664089 - "no_oss", - ], deps = [ ":conv_ops", ":ops_testutil", @@ -3114,7 +3115,7 @@ tf_kernel_library( tf_kernel_library( name = "summary_image_op", prefix = "summary_image_op", - deps = LOGGING_DEPS + ["//tensorflow/core:png_internal"], + deps = LOGGING_DEPS + ["//tensorflow/core/lib/png:png_io"], ) # TODO(b/162630222): remove this target @@ -3851,7 +3852,7 @@ tf_kernel_library( "//tensorflow/core/platform/default/build_config:cudnn_plugin", "//tensorflow/stream_executor:tf_allocator_adapter", "//tensorflow/stream_executor:stream_executor_headers", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", ]) + if_cuda_or_rocm([ ":gpu_utils", "//tensorflow/stream_executor/gpu:redzone_allocator", @@ -3989,7 +3990,7 @@ tf_kernel_library( ":reduction_ops", ]) + if_cuda([ "@local_config_cuda//cuda:cub_headers", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", "//tensorflow/stream_executor/cuda:cuda_stream", ]) + if_rocm([ "@local_config_rocm//rocm:rocprim", @@ -4004,7 +4005,7 @@ tf_kernel_library( ":redux_functor", ":transpose_functor", ] + if_cuda([ - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", ]), ) @@ -4224,7 +4225,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", "//third_party/eigen3", ], ) @@ -6191,6 +6192,7 @@ filegroup( "queue_op.cc", "queue_ops.cc", "ragged_range_op.cc", + "ragged_gather_op.cc", "ragged_tensor_to_sparse_kernel.cc", "ragged_tensor_to_tensor_op.cc", "random_op.cc", @@ -6249,6 +6251,7 @@ filegroup( "string_lower_op.cc", "string_util.cc", "string_split_op.cc", + "string_strip_op.cc", "string_to_hash_bucket_op.cc", "substr_op.cc", "tensor_array.cc", @@ -6443,12 +6446,6 @@ filegroup( ) # LINT.ThenChange(//tensorflow/contrib/makefile/tf_op_files.txt) -alias( - name = "android_tensorflow_kernels", - actual = ":portable_tensorflow_kernels", - visibility = ["//visibility:public"], -) - cc_library( name = "portable_tensorflow_kernels", srcs = if_mobile([ @@ -6465,9 +6462,9 @@ cc_library( deps = [ "//tensorflow/core:portable_gif_internal", "//tensorflow/core:portable_jpeg_internal", - "//tensorflow/core:portable_png_internal", "//tensorflow/core:portable_tensorflow_lib_lite", "//tensorflow/core:protos_all_cc_impl", + "//tensorflow/core/lib/png:png_io", "//tensorflow/core/platform:strong_hash", "//third_party/eigen3", "//third_party/fft2d:fft2d_headers", @@ -6630,7 +6627,7 @@ tf_cc_binary( deps = [ ] + select({ "//tensorflow:android": [ - ":android_tensorflow_kernels", + ":portable_tensorflow_kernels", "//tensorflow/core:portable_tensorflow_lib", "//tensorflow/core:portable_tensorflow_test_lib", ], @@ -6691,7 +6688,7 @@ cc_binary( 
"//tensorflow/cc:client_session", ] + select({ "//tensorflow:android": [ - ":android_tensorflow_kernels", + ":portable_tensorflow_kernels", "//tensorflow/core:portable_tensorflow_lib", "//tensorflow/core:portable_tensorflow_test_lib", ], @@ -6776,7 +6773,7 @@ cc_binary( "//tensorflow/cc:client_session", ] + select({ "//tensorflow:android": [ - ":android_tensorflow_kernels", + ":portable_tensorflow_kernels", "//tensorflow/core:portable_tensorflow_lib", "//tensorflow/core:portable_tensorflow_test_lib", ], @@ -6897,7 +6894,7 @@ cc_binary( "//tensorflow/cc:client_session", ] + select({ "//tensorflow:android": [ - ":android_tensorflow_kernels", + ":portable_tensorflow_kernels", "//tensorflow/core:portable_tensorflow_lib", "//tensorflow/core:portable_tensorflow_test_lib", ], @@ -7039,7 +7036,7 @@ cc_binary( "//tensorflow/cc:client_session", ] + select({ "//tensorflow:android": [ - ":android_tensorflow_kernels", + ":portable_tensorflow_kernels", "//tensorflow/core:portable_tensorflow_lib", "//tensorflow/core:portable_tensorflow_test_lib", ], diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD index 16de576d2c9..8f233957032 100644 --- a/tensorflow/core/kernels/batching_util/BUILD +++ b/tensorflow/core/kernels/batching_util/BUILD @@ -249,6 +249,8 @@ cc_library( "//tensorflow/core/kernels/batching_util:threadsafe_status", "//tensorflow/core/platform:status", "//tensorflow/core/platform:thread_annotations", + "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/profiler/lib:traceme_encode", "//tensorflow/core/util:incremental_barrier", ], ) diff --git a/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h b/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h index cdaf5331e9f..df09bfd6099 100644 --- a/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h @@ -141,6 +141,39 @@ class BasicBatchScheduler : public BatchScheduler { // (Keep them mirrored to the ones in SharedBatchScheduler::QueueOptions and // SharedBatchScheduler::Options.) struct Options { + // Options related with (underlying) shared batch scheduler. + // 'thread_pool_name' and 'num_batch_threads' are used to initialize + // a shared batch scheduler underlyingly iff 'shared_batch_scheduler' is + // nullptr. + // + // There are two ways to specify threading: + // 1) Have each session create its own pool. + // 2) Have multiple sessions share the same pool. + // + // In general, the number of threads should be tied to roughly the number of + // compute resources (CPU cores or accelerator cores) backing the threads. + // Sharing a thread pool helps alleviate potential over allocation of + // threads to limited compute resources. + + // To have each session create its own thread pool (1) set + // thread_pool_name/num_batch_threads. + + // To share a thread pool (2) create a scheduler and pass it in. + + // The name to use for the pool of batch threads. + string thread_pool_name = {"batch_threads"}; + + // The number of threads to use to process batches. + // Must be >= 1, and should be tuned carefully. + int num_batch_threads = port::MaxParallelism(); + + // If specified, this scheduler will be used underlyingly to schedule + // batches. Note setting this means `thread_pool_name` and + // `num_batch_threads` are ignored. + std::shared_ptr> shared_batch_scheduler = + nullptr; + + // Options for queue. // The maximum size of each batch. 
// // The scheduler may form batches of any size between 1 and this number @@ -163,12 +196,6 @@ class BasicBatchScheduler : public BatchScheduler { // avoid latency spikes. int64 batch_timeout_micros = 0; - // The name to use for the pool of batch threads. - string thread_pool_name = {"batch_threads"}; - - // The number of threads to use to process batches. - // Must be >= 1, and should be tuned carefully. - int num_batch_threads = port::MaxParallelism(); // The maximum allowable number of enqueued (accepted by Schedule() but // not yet being processed on a batch thread) tasks in terms of batches. @@ -272,13 +299,19 @@ Status BasicBatchScheduler::Create( std::function>)> process_batch_callback, std::unique_ptr* scheduler) { - typename SharedBatchScheduler::Options shared_scheduler_options; - shared_scheduler_options.thread_pool_name = options.thread_pool_name; - shared_scheduler_options.num_batch_threads = options.num_batch_threads; - shared_scheduler_options.env = options.env; std::shared_ptr> shared_scheduler; - TF_RETURN_IF_ERROR(SharedBatchScheduler::Create( - shared_scheduler_options, &shared_scheduler)); + + if (options.shared_batch_scheduler == nullptr) { + typename SharedBatchScheduler::Options shared_scheduler_options; + shared_scheduler_options.thread_pool_name = options.thread_pool_name; + shared_scheduler_options.num_batch_threads = options.num_batch_threads; + shared_scheduler_options.env = options.env; + + TF_RETURN_IF_ERROR(SharedBatchScheduler::Create( + shared_scheduler_options, &shared_scheduler)); + } else { + shared_scheduler = options.shared_batch_scheduler; + } typename SharedBatchScheduler::QueueOptions shared_scheduler_queue_options; diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc index 98175b5b9d0..d638760b833 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc @@ -20,6 +20,8 @@ limitations under the License. 
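BasicBatchScheduler::Create above now reuses a caller-supplied SharedBatchScheduler when Options::shared_batch_scheduler is set and only builds a private one otherwise. A stripped-down sketch of that decision, with stand-in types instead of the templated scheduler classes:

#include <memory>
#include <string>

struct SchedulerSketch {};  // stand-in for SharedBatchScheduler<TaskType>

struct OptionsSketch {
  std::string thread_pool_name = "batch_threads";
  int num_batch_threads = 4;
  std::shared_ptr<SchedulerSketch> shared_batch_scheduler;  // optional, may be null
};

std::shared_ptr<SchedulerSketch> ResolveScheduler(const OptionsSketch& options) {
  if (options.shared_batch_scheduler != nullptr) {
    // Sharing one pool across sessions avoids over-subscribing compute resources;
    // thread_pool_name and num_batch_threads are ignored in this branch.
    return options.shared_batch_scheduler;
  }
  // Otherwise each caller gets its own pool sized by num_batch_threads.
  return std::make_shared<SchedulerSketch>();
}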
#include "tensorflow/core/kernels/batching_util/concat_split_util.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/monitoring/percentile_sampler.h" +#include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/profiler/lib/traceme_encode.h" #include "tensorflow/core/util/incremental_barrier.h" namespace tensorflow { @@ -202,6 +204,11 @@ Status BatchResourceBase::ConcatInputTensors( const int padded_batch_size = RoundToLowestAllowedBatchSize(batch.size()); const int padding_amount = padded_batch_size - batch.size(); + profiler::TraceMe trace_me([padded_batch_size, padding_amount]() { + return profiler::TraceMeEncode( + "ConcatInputTensors", {{"batch_size_after_padding", padded_batch_size}, + {"padding_amount", padding_amount}}); + }); RecordPaddingSize(padding_amount, GetModelName(context), padded_batch_size); RecordProcessedBatchSize(padded_batch_size, GetModelName(context)); diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h index 04b84e6054e..9bb853f708e 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h @@ -809,9 +809,7 @@ void Queue::ProcessBatch(std::unique_ptr> batch) { profiler::TraceMeConsumer trace_me( [&] { return profiler::TraceMeEncode( - "ProcessBatch", - {{"size", batch->size()}, - {"padding", max_execution_batch_size() - batch->size()}}); + "ProcessBatch", {{"batch_size_before_padding", batch->size()}}); }, profiler::ContextType::kSharedBatchScheduler, batch->traceme_context_id()); diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD index 8b25e00965e..ec7082318b3 100644 --- a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD +++ b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD @@ -10,8 +10,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "weighted_quantiles_hdrs", srcs = [ diff --git a/tensorflow/core/kernels/collective_nccl_test.cc b/tensorflow/core/kernels/collective_nccl_test.cc index e288ba66804..f7725151d8a 100644 --- a/tensorflow/core/kernels/collective_nccl_test.cc +++ b/tensorflow/core/kernels/collective_nccl_test.cc @@ -136,15 +136,14 @@ class NcclTestBase : public ::testing::Test { col_params_.instance.data_type = DT_FLOAT; col_params_.instance.impl_details.collective_name = collective_name_; const string task_name = "/job:worker/replica:0/task:0"; - col_params_.instance.num_devices_per_task[task_name] = num_ranks; + col_params_.group.num_devices_per_task[task_name] = num_ranks; for (int rank = 0; rank < num_ranks; ++rank) { - col_params_.instance.device_names.push_back( - device_names[rank % num_gpus]); - col_params_.instance.task_names.push_back(task_name); + col_params_.group.device_names.push_back(device_names[rank % num_gpus]); + col_params_.group.task_names.push_back(task_name); } for (int rank = 0; rank < num_ranks; ++rank) { instances_.push_back(absl::make_unique( - rank, col_params_.instance.device_names[rank], this)); + rank, col_params_.group.device_names[rank], this)); } } @@ -251,9 +250,7 @@ class NcclTestBase : public ::testing::Test { << parent_->dev_mgr_->DebugString(); col_params_.name = parent_->col_params_.name; col_params_.default_rank = rank; - col_params_.group.group_key = parent_->col_params_.group.group_key; - col_params_.group.device_type = parent_->col_params_.group.device_type; - 
col_params_.group.group_size = parent_->col_params_.group.group_size; + col_params_.group = parent_->col_params_.group; col_params_.instance = parent->col_params_.instance; } diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc index eed2c4c1d54..a3db45dfea6 100644 --- a/tensorflow/core/kernels/collective_ops.cc +++ b/tensorflow/core/kernels/collective_ops.cc @@ -64,8 +64,7 @@ class CollectiveOpKernel : public AsyncOpKernel { // immediately. bool CanProceedWithCompute(OpKernelContext* c, CollectiveExecutor* col_exec, const DoneCallback& done) { - if (col_params_.group.group_size > - col_params_.instance.device_names.size()) { + if (col_params_.group.group_size > col_params_.group.device_names.size()) { // This is the first invocation: Finish initializing col_params_. // Schedule the `CompleteParamsAsync` call on a work queue that can handle // blocking work because it's not guaranteed that this call cannot block. @@ -457,6 +456,9 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel { OP_REQUIRES_OK( c, c->GetAttr("communication_hint", &col_params_->instance.impl_details.communication_hint)); + OP_REQUIRES_OK( + c, c->GetAttr("timeout_seconds", + &col_params_->instance.impl_details.timeout_seconds)); // Prepare OpKernels for reduction and final operations. // The merge_op takes two inputs NodeDef sub_node; @@ -511,7 +513,8 @@ class CollectiveReduceV2OpKernel : public AsyncOpKernel { col_params->instance.data_type = col_params_->instance.data_type; col_params->instance.impl_details.communication_hint = col_params_->instance.impl_details.communication_hint; - col_params->instance.impl_details.timeout_seconds = 0; + col_params->instance.impl_details.timeout_seconds = + col_params_->instance.impl_details.timeout_seconds; col_params->instance.impl_details.subdiv_offsets = col_params_->instance.impl_details.subdiv_offsets; col_params->merge_op = std::move(col_params_->merge_op); diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index 02a82892fed..a78af454f09 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -149,21 +149,9 @@ class SelectV2Op : public OpKernel { // The `cond`, `then`, and `else` are broadcastable (bcast.IsValid()), // This matches the behavior of numpy. - // TODO (yongtang): Consolidate into n-ary broadcast, instead of multiple - // 2-ary broadcast. - - // Combine `then` and `else`. - BCast then_else_bcast(BCast::FromShape(then->shape()), - BCast::FromShape(else_->shape()), false); - OP_REQUIRES(ctx, then_else_bcast.IsValid(), - errors::InvalidArgument( - "then ", then->shape().DebugString(), " and else ", - else_->shape().DebugString(), " must be broadcastable")); - // Combine `cond` with `then` and `else`. - BCast bcast( - BCast::FromShape(cond->shape()), - BCast::FromShape(BCast::ToShape(then_else_bcast.output_shape())), - false); + BCastList<3> bcast({cond->shape().dim_sizes(), then->shape().dim_sizes(), + else_->shape().dim_sizes()}, + false); OP_REQUIRES(ctx, bcast.IsValid(), errors::InvalidArgument( "condition ", cond->shape().DebugString(), ", then ", @@ -172,12 +160,9 @@ class SelectV2Op : public OpKernel { // Broadcast `cond`, `then` and `else` to combined shape, // in order to obtain the reshape. 
- BCast cond_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())), - BCast::FromShape(cond->shape()), false); - BCast then_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())), - BCast::FromShape(then->shape()), false); - BCast else_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())), - BCast::FromShape(else_->shape()), false); + BCast cond_bcast(bcast.output_shape(), cond->shape().dim_sizes(), false); + BCast then_bcast(bcast.output_shape(), then->shape().dim_sizes(), false); + BCast else_bcast(bcast.output_shape(), else_->shape().dim_sizes(), false); OP_REQUIRES( ctx, cond_bcast.IsValid() && then_bcast.IsValid() && else_bcast.IsValid(), diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index 681bc1f7c35..22c93dc5306 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -465,9 +465,16 @@ Status MakeIteratorFromInputElement( GetDatasetFromVariantTensor(return_values[0], &returned_dataset)); // Create an iterator for the dataset that was returned by `f`. - return returned_dataset->MakeIterator( - ctx, parent, strings::StrCat(prefix, "[", thread_index, "]"), - out_iterator); + std::string iterator_prefix = strings::StrCat(prefix, "[", thread_index, "]"); + if (ctx->split_provider() == nullptr) { + return returned_dataset->MakeIterator(ctx, parent, iterator_prefix, + out_iterator); + } + // Strip out the split provider so that it doesn't apply to sub-iterators. + IteratorContext::Params params(ctx); + params.split_provider = nullptr; + return returned_dataset->MakeIterator(IteratorContext(std::move(params)), + parent, iterator_prefix, out_iterator); } /* static */ diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index 1e5af931cdd..c3ef8122ebb 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -132,12 +132,12 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:regexp_internal", "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler/optimizers/data:graph_utils", "//tensorflow/core/kernels/data:dataset_utils", "//tensorflow/core/kernels/data:name_utils", "//tensorflow/core/kernels/data:serialization_utils", + "//tensorflow/core/platform:regexp", ], ) @@ -307,7 +307,7 @@ tf_cc_test( name = "lmdb_dataset_op_test", size = "small", srcs = ["lmdb_dataset_op_test.cc"], - data = ["//tensorflow/core:lmdb_testdata"], + data = ["//tensorflow/core/lib/lmdb:lmdb_testdata"], deps = [ ":lmdb_dataset_op", "//tensorflow/core:experimental_dataset_ops_op_lib", diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc index 53e20c12449..01044957857 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -108,6 +108,7 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, uint64 hash, const std::string& path, const std::string& compression, + const std::string& reader_prefix, const std::string& writer_prefix, std::unique_ptr reader_func, std::unique_ptr shard_func); @@ -138,6 +139,8 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { const uint64 hash_; const tstring path_; const 
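SelectV2Op above replaces three pairwise BCast computations with one BCastList<3> over cond, then, and else. For reference, this is what an n-ary NumPy-style broadcast of the three shapes computes; the function below is an independent illustration, not the BCastList implementation:

#include <algorithm>
#include <optional>
#include <vector>

std::optional<std::vector<int>> BroadcastShapes(
    const std::vector<std::vector<int>>& shapes) {
  size_t rank = 0;
  for (const auto& s : shapes) rank = std::max(rank, s.size());
  std::vector<int> out(rank, 1);
  for (const auto& s : shapes) {
    for (size_t i = 0; i < s.size(); ++i) {
      int& o = out[rank - s.size() + i];       // right-align dimensions
      const int d = s[i];
      if (o == 1) o = d;
      else if (d != 1 && d != o) return std::nullopt;  // incompatible shapes
    }
  }
  return out;
}
// Example: cond {8,1}, then {8,3}, else {1,3} all broadcast to {8,3}.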
std::string compression_; + const std::string reader_prefix_; + const std::string writer_prefix_; std::unique_ptr reader_func_; std::unique_ptr shard_func_; @@ -294,6 +297,7 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Passthrough SnapshotDatasetV2Op::Dataset::Dataset( OpKernelContext* ctx, const DatasetBase* input, uint64 hash, const std::string& path, const std::string& compression, + const std::string& reader_prefix, const std::string& writer_prefix, std::unique_ptr reader_func, std::unique_ptr shard_func) : DatasetBase(DatasetContext(ctx)), @@ -302,6 +306,8 @@ SnapshotDatasetV2Op::Dataset::Dataset( path_(path), compression_(compression == kCompressionAuto ? io::compression::kSnappy : compression), + reader_prefix_(reader_prefix), + writer_prefix_(writer_prefix), reader_func_(std::move(reader_func)), shard_func_(std::move(shard_func)) { input_->Ref(); @@ -363,6 +369,12 @@ Status SnapshotDatasetV2Op::Dataset::AsGraphDefInternal( AttrValue compression_attr; b->BuildAttrValue(compression_, &compression_attr); + AttrValue reader_prefix_attr; + b->BuildAttrValue(reader_prefix_, &reader_prefix_attr); + + AttrValue writer_prefix_attr; + b->BuildAttrValue(writer_prefix_, &writer_prefix_attr); + AttrValue reader_func_attr; b->BuildAttrValue(reader_func_->func(), &reader_func_attr); @@ -386,6 +398,8 @@ Status SnapshotDatasetV2Op::Dataset::AsGraphDefInternal( std::make_pair(3, shard_func_other_args)}, /*attrs=*/ {{kCompression, compression_attr}, + {kReaderPrefix, reader_prefix_attr}, + {kWriterPrefix, writer_prefix_attr}, {kReaderFunc, reader_func_attr}, {kShardFunc, shard_func_attr}, {kReaderFuncTarguments, reader_func_arguments_types_attr}, @@ -401,7 +415,8 @@ SnapshotDatasetV2Op::Dataset::Iterator::Iterator(const Params& params) Status SnapshotDatasetV2Op::Dataset::Iterator::Initialize( IteratorContext* ctx) { - return ctx->env()->RecursivelyCreateDir(hash_dir_); + return ctx->env()->RecursivelyCreateDir( + io::JoinPath(dataset()->writer_prefix_, hash_dir_)); } Status SnapshotDatasetV2Op::Dataset::Iterator::SaveInternal( @@ -460,7 +475,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator( experimental::SnapshotMetadataRecord metadata; bool file_exists; TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile( - ctx->env(), hash_dir_, &metadata, &file_exists)); + ctx->env(), io::JoinPath(dataset()->reader_prefix_, hash_dir_), + &metadata, &file_exists)); if (!file_exists) { return errors::DataLoss("Snapshot metadata file in ", hash_dir_, " does not exist any more."); @@ -476,7 +492,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator( experimental::SnapshotMetadataRecord metadata; bool file_exists; TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile( - ctx->env(), hash_dir_, &metadata, &file_exists)); + ctx->env(), io::JoinPath(dataset()->reader_prefix_, hash_dir_), + &metadata, &file_exists)); // `pending_snapshot_expiry_seconds` is a legacy option where we would not // write snapshots that we think were still on-going. 
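The new reader_prefix and writer_prefix attributes are joined in front of the snapshot path or hash directory before metadata is read or the run directory is created, so reads and writes can be redirected to different mounts. A simplified stand-in for io::JoinPath to show the composition; the paths and prefixes below are hypothetical:

#include <iostream>
#include <string>

std::string JoinPathSketch(const std::string& a, const std::string& b) {
  if (a.empty()) return b;   // empty prefix: snapshot layout is unchanged
  if (b.empty()) return a;
  if (a.back() == '/') return a + b;
  return a + "/" + b;
}

int main() {
  const std::string hash_dir = "snapshots/9aef331200b";         // hypothetical
  std::cout << JoinPathSketch("", hash_dir) << '\n';             // snapshots/9aef331200b
  std::cout << JoinPathSketch("/readonly_mount", hash_dir) << '\n';
}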
We decided that this @@ -522,8 +539,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize( TF_RETURN_IF_ERROR( dataset()->reader_func_->Instantiate(ctx, &instantiated_reader_func_)); - auto hash_dir = - snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_); + auto hash_dir = snapshot_util::HashDirectory( + io::JoinPath(dataset()->reader_prefix_, dataset()->path_), + dataset()->hash_); bool metadata_file_exists; experimental::SnapshotMetadataRecord metadata; TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile( @@ -624,8 +642,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::WriteMetadataFile( metadata.add_dtype(output_dtype); } metadata.set_finalized(finalized); - tstring hash_directory = - snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_); + tstring hash_directory = io::JoinPath( + dataset()->writer_prefix_, + snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_)); return snapshot_util::WriteMetadataFile(env, hash_directory, &metadata); } @@ -678,7 +697,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetNextInternal( // Creates the run directory. run_dir_ = snapshot_util::RunDirectory( - snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_), + snapshot_util::HashDirectory( + io::JoinPath(dataset()->writer_prefix_, dataset()->path_), + dataset()->hash_), run_id_); TF_RETURN_IF_ERROR(ctx->env()->RecursivelyCreateDir(run_dir_)); TF_RETURN_IF_ERROR(WriteMetadataFile(ctx->env(), /*finalized=*/false)); @@ -757,7 +778,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::RestoreInternal( run_id_ = static_cast(run_id_signed); run_dir_ = snapshot_util::RunDirectory( - snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_), + snapshot_util::HashDirectory( + io::JoinPath(dataset()->writer_prefix_, dataset()->path_), + dataset()->hash_), run_id_); current_checkpoint_id_ = static_cast(current_checkpoint_id); @@ -798,6 +821,14 @@ SnapshotDatasetV2Op::SnapshotDatasetV2Op(OpKernelConstruction* ctx) OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr(kCompression, &compression_)); + if (ctx->HasAttr(kReaderPrefix)) { + OP_REQUIRES_OK(ctx, ctx->GetAttr(kReaderPrefix, &reader_prefix_)); + } + + if (ctx->HasAttr(kWriterPrefix)) { + OP_REQUIRES_OK(ctx, ctx->GetAttr(kWriterPrefix, &writer_prefix_)); + } + OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kReaderFunc, reader_params, &reader_func_metadata_)); OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kShardFunc, shard_params, @@ -821,6 +852,9 @@ void SnapshotDatasetV2Op::MakeDataset(OpKernelContext* ctx, DatasetBase* input, ctx, AsGraphDef(ctx, input, SerializationContext(params), &graph_def)); OP_REQUIRES_OK(ctx, HashGraph(graph_def, &graph_hash)); + // Different compression modes should result in different graph hashes. 
+ graph_hash = Hash64Combine(graph_hash, Hash64(compression_)); + std::unique_ptr reader_func; OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, reader_func_metadata_, @@ -831,8 +865,8 @@ void SnapshotDatasetV2Op::MakeDataset(OpKernelContext* ctx, DatasetBase* input, kShardFuncOtherArgs, &shard_func)); *output = new SnapshotDatasetV2Op::Dataset( - ctx, input, graph_hash, path, compression_, std::move(reader_func), - std::move(shard_func)); + ctx, input, graph_hash, path, compression_, reader_prefix_, + writer_prefix_, std::move(reader_func), std::move(shard_func)); } namespace { diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.h b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.h index 041819205c8..d7097f43190 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.h +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.h @@ -44,6 +44,8 @@ class SnapshotDatasetV2Op : public UnaryDatasetOpKernel { static constexpr const char* const kOutputTypes = "output_types"; static constexpr const char* const kOutputShapes = "output_shapes"; static constexpr const char* const kCompression = "compression"; + static constexpr const char* const kReaderPrefix = "reader_prefix"; + static constexpr const char* const kWriterPrefix = "writer_prefix"; static constexpr const char* const kCompressionAuto = "AUTO"; static constexpr const char* const kReaderFunc = "reader_func"; static constexpr const char* const kShardFunc = "shard_func"; @@ -71,6 +73,8 @@ class SnapshotDatasetV2Op : public UnaryDatasetOpKernel { std::vector output_shapes_; std::string compression_; + std::string reader_prefix_; + std::string writer_prefix_; std::shared_ptr reader_func_metadata_; std::shared_ptr shard_func_metadata_; diff --git a/tensorflow/core/kernels/data/hash_utils.cc b/tensorflow/core/kernels/data/hash_utils.cc index 4518a525761..688d6d847c9 100644 --- a/tensorflow/core/kernels/data/hash_utils.cc +++ b/tensorflow/core/kernels/data/hash_utils.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/op_def_util.h" @@ -160,6 +161,9 @@ class GraphHasher { while (!bfs_queue.empty()) { const NodeDef* node = bfs_queue.front(); bfs_queue.pop(); + if (visited.contains(node->name())) { + continue; + } visited.insert(node->name()); NodeRep node_rep; for (int i = 0; i < node->input_size(); ++i) { @@ -253,7 +257,7 @@ class GraphHasher { if (!s.ok()) { return errors::FailedPrecondition("Nodes ", this_node->name(), " and ", that_node->name(), - " are not the same: ", s); + " are not the same:\n", s); } return s; } @@ -297,15 +301,20 @@ class GraphHasher { Status HashNodeNonInput(const NodeDef* node, bool hash_functions, uint64* hash) { - // Hash Op. - uint64 op_hash = Hash64(node->op()); - // Hash Attrs. We get the list of attrs from the op registry and then look // up their values in the NodeDef attr map. This avoids looping over // a map which is non-deterministic. 
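Mixing the compression string into the graph hash means two snapshot datasets that differ only in compression no longer collide on the same hash directory. A toy illustration of that combine step using std::hash; the mixing constant is the usual boost-style one, not TensorFlow's Hash64Combine:

#include <functional>
#include <iostream>
#include <string>

size_t HashCombineSketch(size_t seed, const std::string& s) {
  // Standard hash-combine mixing step, for illustration only.
  return seed ^ (std::hash<std::string>{}(s) + 0x9e3779b97f4a7c15ULL +
                 (seed << 6) + (seed >> 2));
}

int main() {
  const size_t graph_hash = 0x12345678;  // pretend this came from HashGraph()
  std::cout << HashCombineSketch(graph_hash, "SNAPPY") << '\n';
  std::cout << HashCombineSketch(graph_hash, "GZIP") << '\n';    // different directory
}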
uint64 attrs_hash = 0; const OpRegistrationData* reg; - TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(node->op(), ®)); + TF_RETURN_IF_ERROR(flib_->LookUp(node->op(), ®)); + uint64 op_hash = 0; + if (reg->is_function_op) { + if (hash_functions) { + TF_RETURN_IF_ERROR(HashFunction(node->op(), node->attr(), &op_hash)); + } + } else { + op_hash = Hash64(node->op()); + } for (const auto& attr : reg->op_def.attr()) { const auto& attr_key = attr.name(); @@ -331,17 +340,24 @@ class GraphHasher { Status CheckNodesEqualNonInput(const NodeDef* this_node, GraphHasher* that, const NodeDef* that_node, bool compare_functions) { - if (this_node->op() != that_node->op()) { - return errors::FailedPrecondition( - "ops for nodes ", this_node->name(), " and ", that_node->name(), - " are different: ", this_node->op(), " != ", that_node->op()); - } - // We get the list of attrs from the op registry and then look // up their values in the NodeDef attr map. This avoids looping over // a map which is non-deterministic. const OpRegistrationData* reg; - TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(this_node->op(), ®)); + TF_RETURN_IF_ERROR(flib_->LookUp(this_node->op(), ®)); + if (reg->is_function_op) { + if (compare_functions) { + TF_RETURN_IF_ERROR( + CheckFunctionsEqual(this_node->op(), this_node->attr(), that, + that_node->op(), that_node->attr())); + } + } else { + if (this_node->op() != that_node->op()) { + return errors::FailedPrecondition( + "ops for nodes ", this_node->name(), " and ", that_node->name(), + " are different: ", this_node->op(), " != ", that_node->op()); + } + } for (const auto& attr : reg->op_def.attr()) { const auto& attr_key = attr.name(); @@ -443,12 +459,17 @@ class GraphHasher { } Status HashFunction(const NameAttrList& func, uint64* hash) { - const FunctionDef* fdef = flib_->Find(func.name()); + return HashFunction(func.name(), func.attr(), hash); + } + + Status HashFunction(const std::string& name, const AttrValueMap& attrs, + uint64* hash) { + const FunctionDef* fdef = flib_->Find(name); // Convert to a GraphDef. std::unique_ptr fbody; TF_RETURN_IF_ERROR( - FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()), flib_, &fbody)); + FunctionDefToBodyHelper(*fdef, AttrSlice(&attrs), flib_, &fbody)); GraphDef graph_def = fbody->graph->ToGraphDefDebug(); // For each return node, we create a new GraphHasher to compute a hash. 
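HashNodeNonInput above now looks the op up through the function library and, for function-valued ops, hashes the function body (when hash_functions is set) instead of the op name, which is what lets the HashNodeSameFunctionsOps test added below treat identically-bodied functions with different names as equal. A control-flow sketch with stand-in types:

#include <functional>
#include <string>

struct OpInfoSketch {                  // stand-in for OpRegistrationData
  bool is_function_op = false;
  std::string body_fingerprint;        // pretend: canonicalized function body
};

size_t HashOpSketch(const std::string& op_name, const OpInfoSketch& reg,
                    bool hash_functions) {
  if (reg.is_function_op) {
    // Two differently named functions with identical bodies hash the same.
    return hash_functions ? std::hash<std::string>{}(reg.body_fingerprint) : 0;
  }
  return std::hash<std::string>{}(op_name);  // primitive op: the name is enough
}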
@@ -476,35 +497,44 @@ class GraphHasher { Status CheckFunctionsEqual(const NameAttrList& this_func, GraphHasher* that, const NameAttrList& that_func) { - Status s = CheckFunctionsEqualHelper(this_func, that, that_func); + return CheckFunctionsEqual(this_func.name(), this_func.attr(), that, + that_func.name(), that_func.attr()); + } + Status CheckFunctionsEqual(const std::string& this_name, + const AttrValueMap& this_attrs, GraphHasher* that, + const std::string& that_name, + const AttrValueMap& that_attrs) { + Status s = CheckFunctionsEqualHelper(this_name, this_attrs, that, that_name, + that_attrs); if (!s.ok()) { - return errors::FailedPrecondition("Functions ", this_func.name(), " and ", - that_func.name(), - " are not the same: ", s); + return errors::FailedPrecondition("Functions ", this_name, " and ", + that_name, " are not the same:\n", s); } return s; } - Status CheckFunctionsEqualHelper(const NameAttrList& this_func, + Status CheckFunctionsEqualHelper(const std::string& this_name, + const AttrValueMap& this_attrs, GraphHasher* that, - const NameAttrList& that_func) { - const FunctionDef* this_fdef = flib_->Find(this_func.name()); - const FunctionDef* that_fdef = that->flib_->Find(that_func.name()); + const std::string& that_name, + const AttrValueMap& that_attrs) { + const FunctionDef* this_fdef = flib_->Find(this_name); + const FunctionDef* that_fdef = that->flib_->Find(that_name); // Convert to GraphDefs. std::unique_ptr this_fbody; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *this_fdef, AttrSlice(&this_func.attr()), flib_, &this_fbody)); + *this_fdef, AttrSlice(&this_attrs), flib_, &this_fbody)); GraphDef this_graph_def = this_fbody->graph->ToGraphDefDebug(); std::unique_ptr that_fbody; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *that_fdef, AttrSlice(&that_func.attr()), that->flib_, &that_fbody)); + *that_fdef, AttrSlice(&that_attrs), that->flib_, &that_fbody)); GraphDef that_graph_def = that_fbody->graph->ToGraphDefDebug(); if (this_fbody->ret_nodes.size() != that_fbody->ret_nodes.size()) { return errors::FailedPrecondition( - "Different numbers of ret nodes for functions ", this_func.name(), - " and ", that_func.name(), ": ", this_fbody->ret_nodes.size(), " vs ", + "Different numbers of ret nodes for functions ", this_name, " and ", + that_name, ": ", this_fbody->ret_nodes.size(), " vs ", that_fbody->ret_nodes.size()); } for (int i = 0; i < this_fbody->ret_nodes.size(); ++i) { diff --git a/tensorflow/core/kernels/data/hash_utils_test.cc b/tensorflow/core/kernels/data/hash_utils_test.cc index 3bb55006fa6..f61ab2a9c44 100644 --- a/tensorflow/core/kernels/data/hash_utils_test.cc +++ b/tensorflow/core/kernels/data/hash_utils_test.cc @@ -603,6 +603,106 @@ TEST_F(DatasetHashUtilsTest, HashNodeSameFunctionListsDifferentNames) { TF_EXPECT_OK(CheckSubgraphsEqual(gd, n2, gd, n3)); } +TEST_F(DatasetHashUtilsTest, HashNodeSameFunctionsOps) { + GraphDef gd; + + FunctionDefLibrary* fl1 = gd.mutable_library(); + FunctionDef* f1 = fl1->add_function(); + + FunctionDef func = FunctionDefHelper::Create( + "AddAndMul", {"i: float"}, {"o: float"}, {}, + {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}}, + {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}}, + /*ret_def=*/{{"o", "ret:z:0"}}, + /*control_ret_def=*/{{"must_execute", "add"}}); + *f1 = func; + + FunctionDef* f2 = fl1->add_function(); + func = FunctionDefHelper::Create( + "AddAndMul2", {"i: float"}, {"o: float"}, {}, + {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}}, + {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}}, + 
/*ret_def=*/{{"o", "ret:z:0"}}, + /*control_ret_def=*/{{"must_execute", "add"}}); + *f2 = func; + FunctionLibraryDefinition flib(OpRegistry::Global(), gd.library()); + + NodeDef* n1 = gd.add_node(); + TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const") + .Attr("value", 1) + .Device("CPU:0") + .Finalize(n1)); + + NodeDef* n2 = gd.add_node(); + TF_CHECK_OK(NodeDefBuilder("graph_1/node_2", "AddAndMul", &flib) + .Input(n1->name(), 0, DT_FLOAT) + .Device("CPU:0") + .Finalize(n2)); + + NodeDef* n3 = gd.add_node(); + TF_CHECK_OK(NodeDefBuilder("graph_1/node_2", "AddAndMul2", &flib) + .Input(n1->name(), 0, DT_FLOAT) + .Device("CPU:0") + .Finalize(n3)); + + uint64 hash1 = GetHash(gd, *n2); + uint64 hash2 = GetHash(gd, *n3); + EXPECT_EQ(hash1, hash2); + TF_EXPECT_OK(CheckSubgraphsEqual(gd, n2, gd, n3)); +} + +TEST_F(DatasetHashUtilsTest, HashNodeDifferentFunctionsOps) { + GraphDef gd; + + FunctionDefLibrary* fl1 = gd.mutable_library(); + FunctionDef* f1 = fl1->add_function(); + + FunctionDef func = FunctionDefHelper::Create( + "AddAndMul", {"i: float"}, {"o: float"}, {}, + {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}}, + {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}}, + /*ret_def=*/{{"o", "ret:z:0"}}, + /*control_ret_def=*/{{"must_execute", "add"}}); + *f1 = func; + + FunctionDef* f2 = fl1->add_function(); + func = FunctionDefHelper::Create( + "AddAndMul2", {"i: float"}, {"o: float"}, {}, + {{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}}, + {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}}, + /*ret_def=*/{{"o", "ret:z:0"}}, + /*control_ret_def=*/{{"must_execute", "ret"}}); + *f2 = func; + FunctionLibraryDefinition flib(OpRegistry::Global(), gd.library()); + + NodeDef* n1 = gd.add_node(); + TF_CHECK_OK(NodeDefBuilder("graph_1/node_1", "Const") + .Attr("value", 1) + .Device("CPU:0") + .Finalize(n1)); + + NodeDef* n2 = gd.add_node(); + TF_CHECK_OK(NodeDefBuilder("graph_1/node_2", "AddAndMul", &flib) + .Input(n1->name(), 0, DT_FLOAT) + .Device("CPU:0") + .Finalize(n2)); + + NodeDef* n3 = gd.add_node(); + TF_CHECK_OK(NodeDefBuilder("graph_1/node_2", "AddAndMul2", &flib) + .Input(n1->name(), 0, DT_FLOAT) + .Device("CPU:0") + .Finalize(n3)); + + uint64 hash1 = GetHash(gd, *n2); + uint64 hash2 = GetHash(gd, *n3); + EXPECT_NE(hash1, hash2); + Status s = CheckSubgraphsEqual(gd, n2, gd, n3); + EXPECT_NE(s.code(), error::OK); + EXPECT_THAT( + s.error_message(), + ContainsRegex("Functions AddAndMul and AddAndMul2 are not the same")); +} + TEST_F(DatasetHashUtilsTest, HashNodeDifferentFunctions) { GraphDef gd; diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index 61383789f60..6e079937885 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -83,7 +83,9 @@ void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, // The map that stores the live experiment names and for how much percentage // of the Borg jobs, the experiments will be randomly turned on. 
// clang-format off - absl::flat_hash_map live_experiments; + absl::flat_hash_map live_experiments = { + {"enable_gradient_descent", 5} + }; // clang-format on auto hash_func = [](const string& str) { return Hash64(str); }; optimizations = SelectOptimizations( @@ -112,15 +114,13 @@ void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, std::vector graduated_experiments = {"disable_intra_op_parallelism"}; // clang-format on - // Add the graduated experiments to the optimization list. Also log and - // record. + // Add the graduated experiments to the optimization list and log them. for (auto& experiment : graduated_experiments) { if (std::find(optimizations.begin(), optimizations.end(), experiment) == optimizations.end()) { optimizations.push_back(experiment); } VLOG(1) << "The graduated experiment \"" << experiment << "\" is applied."; - metrics::RecordTFDataExperiment(experiment); } // If there are no optimizations to be applied, directly return the input. diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index 87ea4531d5d..8c1a49d9dda 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -405,7 +405,6 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { TF_LOCKS_EXCLUDED(*mu_) { mutex_lock l(*mu_); num_calls_--; - RecordBufferEnqueue(ctx.get(), result->return_values); result->notification.Notify(); cond_var_->notify_all(); } @@ -428,6 +427,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { auto done = [this, ctx, result](Status status) { result->status.Update(status); + RecordBufferEnqueue(ctx.get(), result->return_values); CallCompleted(ctx, result); }; diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.cc b/tensorflow/core/kernels/data/prefetch_autotuner.cc index b76a0e5714a..db7b186c9f3 100644 --- a/tensorflow/core/kernels/data/prefetch_autotuner.cc +++ b/tensorflow/core/kernels/data/prefetch_autotuner.cc @@ -20,12 +20,11 @@ limitations under the License. namespace tensorflow { namespace data { -PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size, - int64 buffer_size_min) +PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size) : buffer_limit_(initial_buffer_size) { if (initial_buffer_size == model::kAutotune) { mode_ = Mode::kUpswing; - buffer_limit_ = std::max(int64{1}, buffer_size_min); + buffer_limit_ = 1; } } diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.h b/tensorflow/core/kernels/data/prefetch_autotuner.h index 28a243e6261..3257dcdea1a 100644 --- a/tensorflow/core/kernels/data/prefetch_autotuner.h +++ b/tensorflow/core/kernels/data/prefetch_autotuner.h @@ -39,7 +39,7 @@ namespace data { // PrefetchAutotuner is NOT thread safe. 
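The live_experiments map above rolls "enable_gradient_descent" out to 5% of jobs. One plausible way such a percentage is applied, hashing a stable job identifier and comparing against the threshold, is sketched below; this is an assumption about the mechanism, not the actual SelectOptimizations implementation:

#include <functional>
#include <string>
#include <utility>
#include <vector>

std::vector<std::string> SelectExperimentsSketch(
    const std::string& job_name,
    const std::vector<std::pair<std::string, int>>& live_experiments) {
  std::vector<std::string> selected;
  for (const auto& [name, percent] : live_experiments) {
    // Hash experiment name plus job so each experiment rolls out to an
    // independent slice of jobs and a given job keeps a stable decision.
    const size_t h = std::hash<std::string>{}(name + "/" + job_name);
    if (static_cast<int>(h % 100) < percent) selected.push_back(name);
  }
  return selected;
}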
class PrefetchAutotuner { public: - explicit PrefetchAutotuner(int64 initial_buffer_size, int64 buffer_size_min); + explicit PrefetchAutotuner(int64 initial_buffer_size); int64 buffer_limit() const { return buffer_limit_; } diff --git a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc index a5d7ae6fa3c..e9d291d22f7 100644 --- a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc +++ b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc @@ -23,7 +23,7 @@ namespace data { namespace { TEST(PrefetchAutotuner, Disabled) { - PrefetchAutotuner t(2, 0); + PrefetchAutotuner t(2); EXPECT_EQ(2, t.buffer_limit()); t.RecordConsumption(0); t.RecordConsumption(2); @@ -33,7 +33,7 @@ TEST(PrefetchAutotuner, Disabled) { } TEST(PrefetchAutotuner, Enabled) { - PrefetchAutotuner t(model::kAutotune, 0); + PrefetchAutotuner t(model::kAutotune); EXPECT_EQ(1, t.buffer_limit()); t.RecordConsumption(0); // Expect buffer limit to stay the same. EXPECT_EQ(1, t.buffer_limit()); @@ -58,7 +58,7 @@ TEST(PrefetchAutotuner, Enabled) { } TEST(PrefetchAutotuner, EnabledSteady) { - PrefetchAutotuner t(model::kAutotune, 0); + PrefetchAutotuner t(model::kAutotune); EXPECT_EQ(1, t.buffer_limit()); t.RecordConsumption(0); // Expect buffer limit to stay the same! EXPECT_EQ(1, t.buffer_limit()); @@ -80,29 +80,6 @@ TEST(PrefetchAutotuner, EnabledSteady) { } } -TEST(PrefetchAutotuner, StartWithMin) { - PrefetchAutotuner t(model::kAutotune, 2); - EXPECT_EQ(2, t.buffer_limit()); - t.RecordConsumption(0); // Expect buffer limit to stay the same! - EXPECT_EQ(2, t.buffer_limit()); - t.RecordConsumption(2); - EXPECT_EQ(2, t.buffer_limit()); - t.RecordConsumption(0); // Expect buffer limit to increase. - EXPECT_EQ(4, t.buffer_limit()); - t.RecordConsumption(4); // Expect buffer limit to stay the same! - EXPECT_EQ(4, t.buffer_limit()); - t.RecordConsumption(0); // Expect buffer limit to increase. - EXPECT_EQ(8, t.buffer_limit()); - - // Never reach zero again. - std::vector consumption_values = {3, 5, 7, 1, 4, 6, 8, 3, 5, 1, 2, 4}; - for (int i = 0; i < consumption_values.size(); ++i) { - t.RecordConsumption(consumption_values[i]); - EXPECT_EQ(8, t.buffer_limit()) - << "Failed at index " << i << " with value: " << consumption_values[i]; - } -} - } // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index 792de71f98c..4a55514ffd1 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -45,7 +45,6 @@ namespace data { /* static */ constexpr const char* const PrefetchDatasetOp::kOutputShapes; /* static */ constexpr const char* const PrefetchDatasetOp::kSlackPeriod; /* static */ constexpr const char* const PrefetchDatasetOp::kLegacyAutotune; -/* static */ constexpr const char* const PrefetchDatasetOp::kBufferSizeMin; // Determines the fraction of slack time by which to delay prefetching of data. 
constexpr double kSleepFactor = 0.2; @@ -58,13 +57,12 @@ constexpr char kErrorMessageSuffix[] = ".error_message"; class PrefetchDatasetOp::Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size, - int64 slack_period, bool legacy_autotune, int64 buffer_size_min) + int64 slack_period, bool legacy_autotune) : DatasetBase(DatasetContext(ctx)), input_(input), buffer_size_(buffer_size), slack_period_(slack_period), - legacy_autotune_(legacy_autotune), - buffer_size_min_(buffer_size_min) { + legacy_autotune_(legacy_autotune) { input_->Ref(); } @@ -111,14 +109,10 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { b->BuildAttrValue(slack_period_, &slack_period_attr); AttrValue legacy_autotune_attr; b->BuildAttrValue(legacy_autotune_, &legacy_autotune_attr); - AttrValue buffer_size_min_attr; - b->BuildAttrValue(buffer_size_min_, &buffer_size_min_attr); - TF_RETURN_IF_ERROR( b->AddDataset(this, {input_graph_node, buffer_size}, {std::make_pair(kSlackPeriod, slack_period_attr), - std::make_pair(kLegacyAutotune, legacy_autotune_attr), - std::make_pair(kBufferSizeMin, buffer_size_min_attr)}, + std::make_pair(kLegacyAutotune, legacy_autotune_attr)}, output)); return Status::OK(); } @@ -130,11 +124,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { : DatasetIterator(params), mu_(std::make_shared()), cond_var_(std::make_shared()), - buffer_size_min_(params.dataset->buffer_size_min_), - auto_tuner_(params.dataset->buffer_size_, buffer_size_min_), + auto_tuner_(params.dataset->buffer_size_), legacy_autotune_(params.dataset->legacy_autotune_), buffer_size_(std::make_shared( - params.dataset->buffer_size_, mu_, cond_var_)) { + legacy_autotune_ ? 0 : params.dataset->buffer_size_, mu_, + cond_var_)) { slack_us_ = 0; } @@ -146,7 +140,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { Status Initialize(IteratorContext* ctx) override { mutex_lock l(*mu_); if (buffer_size_->value == model::kAutotune) { - buffer_size_->value = buffer_size_min_; + buffer_size_->value = 0; } TF_RETURN_IF_ERROR(RegisterCancellationCallback( ctx->cancellation_manager(), [this]() { CancelThreads(); }, @@ -219,8 +213,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { return model::MakeAsyncKnownRatioNode( std::move(args), /*ratio=*/1, - {model::MakeParameter(kBufferSize, buffer_size_, - /*min=*/buffer_size_min_, + {model::MakeParameter(kBufferSize, buffer_size_, /*min=*/0, /*max=*/std::numeric_limits::max())}); } @@ -524,7 +517,6 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { mutex input_mu_ TF_ACQUIRED_BEFORE(*mu_); std::unique_ptr input_impl_ TF_GUARDED_BY(input_mu_); const std::shared_ptr cond_var_; - const int64 buffer_size_min_; PrefetchAutotuner auto_tuner_ TF_GUARDED_BY(*mu_); std::deque buffer_ TF_GUARDED_BY(*mu_); std::unique_ptr prefetch_thread_ TF_GUARDED_BY(*mu_); @@ -550,10 +542,6 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { // Determines whether legacy autotuning should be used. const bool legacy_autotune_ = true; - // If autotune is enabled, determines the minimal value of `buffer_size` - // parameter. 
- const int64 buffer_size_min_ = 0; - TraceMeMetadata traceme_metadata_; }; @@ -565,9 +553,6 @@ PrefetchDatasetOp::PrefetchDatasetOp(OpKernelConstruction* ctx) if (ctx->HasAttr(kLegacyAutotune)) { OP_REQUIRES_OK(ctx, ctx->GetAttr(kLegacyAutotune, &legacy_autotune_)); } - if (ctx->HasAttr(kBufferSizeMin)) { - OP_REQUIRES_OK(ctx, ctx->GetAttr(kBufferSizeMin, &buffer_size_min_)); - } } void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, @@ -584,8 +569,8 @@ void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, metrics::RecordTFDataAutotune(kDatasetType); } - *output = new Dataset(ctx, input, buffer_size, slack_period_, - legacy_autotune_, buffer_size_min_); + *output = + new Dataset(ctx, input, buffer_size, slack_period_, legacy_autotune_); } namespace { diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.h b/tensorflow/core/kernels/data/prefetch_dataset_op.h index 227bd3f5113..999f002bf16 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.h +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.h @@ -31,7 +31,6 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel { static constexpr const char* const kOutputShapes = "output_shapes"; static constexpr const char* const kSlackPeriod = "slack_period"; static constexpr const char* const kLegacyAutotune = "legacy_autotune"; - static constexpr const char* const kBufferSizeMin = "buffer_size_min"; explicit PrefetchDatasetOp(OpKernelConstruction* ctx); @@ -43,7 +42,6 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel { class Dataset; int64 slack_period_ = 0; bool legacy_autotune_ = true; - int64 buffer_size_min_ = 0; }; } // namespace data diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc b/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc index ed466ddd550..737008bb4c9 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc @@ -60,8 +60,7 @@ class PrefetchDatasetParams : public DatasetParams { attr_vector->emplace_back(PrefetchDatasetOp::kSlackPeriod, slack_period_); attr_vector->emplace_back(PrefetchDatasetOp::kLegacyAutotune, legacy_autotune_); - attr_vector->emplace_back(PrefetchDatasetOp::kBufferSizeMin, - buffer_size_min_); + attr_vector->emplace_back("buffer_size_min", buffer_size_min_); return Status::OK(); } diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc index 4e239d0895c..69ad7ea3bb5 100644 --- a/tensorflow/core/kernels/data/window_dataset_op.cc +++ b/tensorflow/core/kernels/data/window_dataset_op.cc @@ -155,14 +155,17 @@ class WindowDatasetOp::Dataset : public DatasetBase { std::vector> window_elements; Status status = Status::OK(); { + const size_t target_size = TargetBufferSize(window_size, window_stride); + mutex_lock l(mu_); - if (!input_impl_ && buffer_.empty()) { + if (!input_impl_ && + (buffer_.empty() || + (dataset()->drop_remainder_ && buffer_.size() < target_size))) { *end_of_sequence = true; return Status::OK(); } // Add elements to the buffer. 
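// A minimal sketch of the drop_remainder condition introduced above: once the
// input is exhausted, a partially filled buffer should end the sequence
// instead of emitting a short window. The target-size formula
// (window_size - 1) * window_stride + 1 is assumed here for illustration and
// stands in for TargetBufferSize.
#include <cassert>
#include <cstddef>

size_t TargetBufferSizeSketch(size_t window_size, size_t window_stride) {
  return (window_size - 1) * window_stride + 1;
}

bool EndOfSequenceSketch(bool input_exhausted, size_t buffered,
                         bool drop_remainder, size_t window_size,
                         size_t window_stride) {
  if (!input_exhausted) return false;
  if (buffered == 0) return true;
  // With drop_remainder set, a buffer smaller than one full window is
  // discarded rather than emitted as a short window.
  return drop_remainder &&
         buffered < TargetBufferSizeSketch(window_size, window_stride);
}

int main() {
  assert(EndOfSequenceSketch(true, 3, /*drop_remainder=*/true, 5, 1));
  assert(!EndOfSequenceSketch(true, 3, /*drop_remainder=*/false, 5, 1));
  return 0;
}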
- size_t target_size = TargetBufferSize(window_size, window_stride); if (input_impl_) { *end_of_sequence = false; for (size_t i = buffer_.size(); i < target_size && !*end_of_sequence; diff --git a/tensorflow/core/kernels/eigen_pooling.h b/tensorflow/core/kernels/eigen_pooling.h index 7db4a69a8b3..b9c9e549b5d 100644 --- a/tensorflow/core/kernels/eigen_pooling.h +++ b/tensorflow/core/kernels/eigen_pooling.h @@ -131,8 +131,8 @@ SpatialMaxPooling(const Input& input, DenseIndex patchRows, .extract_image_patches( patchRows, patchCols, strideRows, strideCols, in_strideRows, in_strideCols, padding_type, - -Eigen::NumTraits::Scalar>::type>::highest()) + Eigen::NumTraits::Scalar>::type>::lowest()) .maximum(reduction_dims) .reshape(post_reduce_dims); } diff --git a/tensorflow/core/kernels/fingerprint_op.cc b/tensorflow/core/kernels/fingerprint_op.cc index 340dcf111a5..89bd54009a9 100644 --- a/tensorflow/core/kernels/fingerprint_op.cc +++ b/tensorflow/core/kernels/fingerprint_op.cc @@ -91,7 +91,12 @@ class FingerprintOp : public OpKernel { input.shape())); const int64 dim0 = input.shape().dim_size(0); - const int64 dim1 = input.shape().num_elements() / dim0; + int64 dim1; + if (dim0 == 0) { + dim1 = 0; + } else { + dim1 = input.shape().num_elements() / dim0; + } Tensor* output; OP_REQUIRES_OK(context, diff --git a/tensorflow/core/kernels/fingerprint_op_test.cc b/tensorflow/core/kernels/fingerprint_op_test.cc index e4824775902..2925dee56d9 100644 --- a/tensorflow/core/kernels/fingerprint_op_test.cc +++ b/tensorflow/core/kernels/fingerprint_op_test.cc @@ -61,6 +61,15 @@ class FingerprintOpTest : public OpsTestBase { Tensor method_; }; +TEST_F(FingerprintOpTest, Empty) { + Tensor tensor(DT_UINT8, {0}); + + TF_ASSERT_OK(MakeFingerprintOp(&tensor)); + TF_ASSERT_OK(RunOpKernel()); + EXPECT_EQ(GetOutput(0)->shape(), (TensorShape{0, 8})); + EXPECT_EQ(GetOutput(0)->tensor_data(), ""); +} + // This test detects changes in fingerprint method. TEST_F(FingerprintOpTest, GoldenValue) { Tensor tensor(DT_UINT8, {1, 3, 4, 5, 6, 7}); diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc index 00ac9be6dcd..d8e58093b07 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op.cc @@ -1241,15 +1241,15 @@ class FusedBatchNormOpBase : public OpKernel { // If use_reserved_space is false, we don't have 5th output. virtual void ComputeWithReservedSpace(OpKernelContext* context, bool use_reserved_space) { - const Tensor& x = context->input(0); + Tensor x = context->input(0); const Tensor& scale = context->input(1); const Tensor& offset = context->input(2); const Tensor& estimated_mean = context->input(3); const Tensor& estimated_variance = context->input(4); const Tensor* side_input = has_side_input_ ? 
&context->input(5) : nullptr; - OP_REQUIRES(context, x.dims() == 4, - errors::InvalidArgument("input must be 4-dimensional", + OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5, + errors::InvalidArgument("input must be 4 or 5-dimensional", x.shape().DebugString())); OP_REQUIRES(context, scale.dims() == 1, errors::InvalidArgument("scale must be 1-dimensional", @@ -1264,6 +1264,21 @@ class FusedBatchNormOpBase : public OpKernel { context, estimated_variance.dims() == 1, errors::InvalidArgument("estimated_variance must be 1-dimensional", estimated_variance.shape().DebugString())); + bool use_reshape = (x.dims() == 5); + auto x_shape = x.shape(); + TensorShape dest_shape; + if (use_reshape) { + const int64 in_batch = GetTensorDim(x, tensor_format_, 'N'); + int64 in_planes = GetTensorDim(x, tensor_format_, '0'); + int64 in_rows = GetTensorDim(x, tensor_format_, '1'); + int64 in_cols = GetTensorDim(x, tensor_format_, '2'); + const int64 in_depth = GetTensorDim(x, tensor_format_, 'C'); + dest_shape = ShapeFromFormat(tensor_format_, in_batch, + {{in_planes, in_rows * in_cols}}, in_depth); + OP_REQUIRES(context, x.CopyFrom(x, dest_shape), + errors::InvalidArgument("Error during tensor copy.")); + } + if (has_side_input_) { OP_REQUIRES(context, side_input->shape() == x.shape(), errors::InvalidArgument( @@ -1282,8 +1297,10 @@ class FusedBatchNormOpBase : public OpKernel { } Tensor* y = nullptr; + auto alloc_shape = use_reshape ? dest_shape : x_shape; OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( - {0}, 0, x.shape(), &y)); + {0}, 0, alloc_shape, &y)); + Tensor* batch_mean = nullptr; OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( {3}, 1, scale.shape(), &batch_mean)); @@ -1310,6 +1327,10 @@ class FusedBatchNormOpBase : public OpKernel { batch_mean, batch_var, saved_mean, saved_maybe_inv_var, tensor_format_, use_reserved_space); } + if (use_reshape) { + OP_REQUIRES(context, y->CopyFrom(*y, x_shape), + errors::InvalidArgument("Error during tensor copy.")); + } } private: @@ -1375,8 +1396,8 @@ class FusedBatchNormGradOpBase : public OpKernel { virtual void ComputeWithReservedSpace(OpKernelContext* context, bool use_reserved_space) { - const Tensor& y_backprop = context->input(0); - const Tensor& x = context->input(1); + Tensor y_backprop = context->input(0); + Tensor x = context->input(1); const Tensor& scale = context->input(2); // When is_training=True, batch mean and variance/inverted variance are // saved in the forward pass to be reused here. When is_training=False, @@ -1387,11 +1408,11 @@ class FusedBatchNormGradOpBase : public OpKernel { // saves inverted variance. 
const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4); - OP_REQUIRES(context, y_backprop.dims() == 4, - errors::InvalidArgument("input must be 4-dimensional", + OP_REQUIRES(context, y_backprop.dims() == 4 or y_backprop.dims() == 5, + errors::InvalidArgument("input must be 4 or 5-dimensional", y_backprop.shape().DebugString())); - OP_REQUIRES(context, x.dims() == 4, - errors::InvalidArgument("input must be 4-dimensional", + OP_REQUIRES(context, x.dims() == 4 or x.dims() == 5, + errors::InvalidArgument("input must be 4 or 5-dimensional", x.shape().DebugString())); OP_REQUIRES(context, scale.dims() == 1, errors::InvalidArgument("scale must be 1-dimensional", @@ -1404,10 +1425,27 @@ class FusedBatchNormGradOpBase : public OpKernel { errors::InvalidArgument( "saved variance must be 1-dimensional", saved_maybe_inv_var_or_pop_var.shape().DebugString())); + bool use_reshape = (x.dims() == 5); + auto x_shape = x.shape(); + TensorShape dest_shape; + if (use_reshape) { + const int64 in_batch = GetTensorDim(x, tensor_format_, 'N'); + int64 in_planes = GetTensorDim(x, tensor_format_, '0'); + int64 in_rows = GetTensorDim(x, tensor_format_, '1'); + int64 in_cols = GetTensorDim(x, tensor_format_, '2'); + const int64 in_depth = GetTensorDim(x, tensor_format_, 'C'); + dest_shape = ShapeFromFormat(tensor_format_, in_batch, + {{in_planes, in_rows * in_cols}}, in_depth); + OP_REQUIRES(context, x.CopyFrom(x, dest_shape), + errors::InvalidArgument("Error during tensor copy.")); + OP_REQUIRES(context, y_backprop.CopyFrom(y_backprop, dest_shape), + errors::InvalidArgument("Error during tensor copy.")); + } Tensor* x_backprop = nullptr; + auto alloc_shape = use_reshape ? dest_shape : x_shape; OP_REQUIRES_OK(context, - context->allocate_output(0, x.shape(), &x_backprop)); + context->allocate_output(0, alloc_shape, &x_backprop)); const TensorShape& scale_offset_shape = scale.shape(); Tensor* scale_backprop = nullptr; @@ -1441,15 +1479,20 @@ class FusedBatchNormGradOpBase : public OpKernel { offset_backprop, use_reserved_space, tensor_format_); } else { // Necessary layout conversion is currently done in python. 
- CHECK(tensor_format_ == FORMAT_NHWC) - << "The implementation of FusedBatchNormGrad with is_training=False " - "only support " - << "NHWC tensor format for now."; + OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC, + errors::InvalidArgument( + "The implementation of " + "FusedBatchNormGrad with is_training=False only support " + "NHWC tensor format for now.")); functor::FusedBatchNormFreezeGrad()( context, y_backprop, x, scale, saved_mean_or_pop_mean, saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop, offset_backprop); } + if (use_reshape) { + OP_REQUIRES(context, x_backprop->CopyFrom(*x_backprop, x_shape), + errors::InvalidArgument("Error during tensor copy.")); + } } private: diff --git a/tensorflow/core/kernels/image/BUILD b/tensorflow/core/kernels/image/BUILD index a72b226bad5..d60455df8fb 100644 --- a/tensorflow/core/kernels/image/BUILD +++ b/tensorflow/core/kernels/image/BUILD @@ -133,7 +133,7 @@ IMAGE_DEPS = [ "//tensorflow/core:jpeg_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:png_internal", + "//tensorflow/core/lib/png:png_io", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:eigen_helpers", @@ -418,10 +418,10 @@ cc_library( linkopts = ["-ldl"], visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:android_gif_internal", - "//tensorflow/core:android_jpeg_internal", - "//tensorflow/core:android_png_internal", + "//tensorflow/core:portable_gif_internal", + "//tensorflow/core:portable_jpeg_internal", "//tensorflow/core:portable_tensorflow_lib_lite", + "//tensorflow/core/lib/png:png_io", ], alwayslink = 1, ) diff --git a/tensorflow/core/kernels/linalg/BUILD b/tensorflow/core/kernels/linalg/BUILD index dabe8bd0b7d..0ceeb5f22ea 100644 --- a/tensorflow/core/kernels/linalg/BUILD +++ b/tensorflow/core/kernels/linalg/BUILD @@ -212,7 +212,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/kernels:fill_functor", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", ] + if_cuda([ "//tensorflow/core/platform/default/build_config:cublas_plugin", "//tensorflow/core/util:cuda_solvers", diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc index 53168105aba..cd9680c8212 100644 --- a/tensorflow/core/kernels/maxpooling_op.cc +++ b/tensorflow/core/kernels/maxpooling_op.cc @@ -1212,9 +1212,11 @@ class MaxPoolingNoMaskOp : public OpKernel { data_format_, tensor_in, out_shape, propagate_nans_); } else { +#if !defined(TENSORFLOW_USE_ROCM) OP_REQUIRES(context, padding_ != EXPLICIT, errors::Unimplemented("Explicit padding is not supported ", "when CUDNN is not enabled.")); +#endif Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); if (is_int8x4) { diff --git a/tensorflow/core/kernels/mkl/BUILD b/tensorflow/core/kernels/mkl/BUILD index 000a1212807..6689e0fb200 100644 --- a/tensorflow/core/kernels/mkl/BUILD +++ b/tensorflow/core/kernels/mkl/BUILD @@ -78,6 +78,7 @@ tf_cc_test_mkl( name = "mkl_quantized_conv_ops_perchannel_test", size = "small", srcs = ["mkl_quantized_conv_ops_perchannel_test.cc"], + linkstatic = 1, # Fixes dyld error on MacOS. deps = [ ":mkl_conv_op", "//tensorflow/core:array_ops_op_lib", @@ -92,6 +93,7 @@ tf_cc_test_mkl( name = "mkl_quantized_conv_ops_test", size = "small", srcs = ["mkl_quantized_conv_ops_test.cc"], + linkstatic = 1, # Fixes dyld error on MacOS. 
deps = [ ":mkl_conv_op", "//tensorflow/core:array_ops_op_lib", @@ -106,6 +108,7 @@ tf_cc_test_mkl( name = "mkl_qmatmul_op_test", size = "small", srcs = ["mkl_qmatmul_op_test.cc"], + linkstatic = 1, # Fixes dyld error on MacOS. deps = [ ":mkl_qmatmul_op", "//tensorflow/core:array_ops_op_lib", @@ -123,7 +126,7 @@ tf_mkl_kernel_library( srcs = ["mkl_quantize_op.cc"], deps = [ "//tensorflow/core/kernels:quantized_ops", - "//tensorflow/core:mkl_graph_util", + "//tensorflow/core/graph:mkl_graph_util", "@gemmlowp", ] + MKL_DEPS, ) @@ -132,6 +135,7 @@ tf_cc_test_mkl( name = "mkl_quantize_op_test", size = "small", srcs = ["mkl_quantize_op_test.cc"], + linkstatic = 1, # Fixes dyld error on MacOS. deps = [ ":mkl_quantize_op", "//tensorflow/core:array_ops_op_lib", @@ -144,6 +148,7 @@ tf_cc_test_mkl( name = "mkl_quantized_pooling_ops_test", size = "small", srcs = ["mkl_quantized_pooling_ops_test.cc"], + linkstatic = 1, # Fixes dyld error on MacOS. deps = [ ":mkl_pooling_ops", "//tensorflow/core:array_ops_op_lib", @@ -158,6 +163,7 @@ tf_cc_test_mkl( name = "mkl_quantized_concat_op_test", size = "small", srcs = ["mkl_quantized_concat_op_test.cc"], + linkstatic = 1, # Fixes dyld error on MacOS. deps = [ ":mkl_concat_op", "//tensorflow/core:array_ops_op_lib", @@ -249,7 +255,7 @@ tf_mkl_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:math_ops_op_lib", - "//tensorflow/core:mkl_graph_util", + "//tensorflow/core/graph:mkl_graph_util", "//tensorflow/core:nn_ops_op_lib", "//tensorflow/core/kernels:concat_lib_hdrs", "//tensorflow/core/kernels:conv_ops", diff --git a/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc index 754156c860a..c77601859cf 100644 --- a/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc @@ -40,7 +40,7 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -template +template class MklAvgPoolingOp : public MklPoolingForwardOpBase { public: explicit MklAvgPoolingOp(OpKernelConstruction* context) @@ -48,6 +48,7 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { // Workspace is an MKLDNN construct that is only used in Max Pooling. // So set workspace_enabled_ to false. 
this->workspace_enabled_ = false; + this->native_format_ = native_format; } void Compute(OpKernelContext* context) override { @@ -55,7 +56,8 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { const Tensor& input_tensor = MklGetInput(context, this->kInputTensorIndexInput); MklDnnShape dnn_shape_input; - GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input); + GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input, + this->native_format_); this->SanityCheckInput(context, input_tensor, dnn_shape_input); if (!context->status().ok()) return; @@ -116,7 +118,8 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { src_dims, output_dims_mkl_order, filter_dims, strides, padding_left, padding_right, ALGORITHM::pooling_avg_exclude_padding, pooling_prop_kind, - static_cast(this->data_format_mkldnn_), input_md); + static_cast(this->data_format_mkldnn_), input_md, + this->native_format_); #else MklPoolingParams fwdParams( src_dims, output_dims_mkl_order, filter_dims, strides, padding_left, @@ -174,11 +177,13 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { engine cpu_engine_ = engine(ENGINE_CPU, 0); }; // MklAvgPoolingOp -template +template class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { public: explicit MklAvgPoolingGradOp(OpKernelConstruction* context) - : MklPoolingBackwardOpBase(context) {} + : MklPoolingBackwardOpBase(context) { + this->native_format_ = native_format; + } void Compute(OpKernelContext* context) override { try { @@ -188,8 +193,10 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { MklGetInput(context, kInputTensorIndexInputGradient); MklDnnShape orig_input_mkl_shape, grad_mkl_shape; - GetMklShape(context, kInputTensorIndexInputShape, &orig_input_mkl_shape); - GetMklShape(context, kInputTensorIndexInputGradient, &grad_mkl_shape); + GetMklShape(context, kInputTensorIndexInputShape, &orig_input_mkl_shape, + this->native_format_); + GetMklShape(context, kInputTensorIndexInputGradient, &grad_mkl_shape, + this->native_format_); if (!context->status().ok()) return; // Used to allocate output_diff_src/diff_src. 
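// A minimal sketch of the native_format template parameter threaded through
// the pooling kernels above. The boolean selects, at compile time, between
// reading MKL layout metadata (and reordering between blocked and plain
// layouts when needed) and treating every input as an ordinary TF-layout
// tensor. The PlainTensor and LayoutMetadata types are placeholders, not the
// real Tensor/MklDnnShape classes.
#include <iostream>

struct PlainTensor { int rank = 4; };
struct LayoutMetadata { bool is_mkl_layout = false; };

template <bool native_format>
class PoolingKernelSketch {
 public:
  void Compute(const PlainTensor& input, const LayoutMetadata& meta) {
    if (native_format) {
      // Native path: inputs arrive in TF layout, so no metadata lookup and
      // no reorder step.
      std::cout << "rank=" << input.rank << " (native TF layout)\n";
    } else {
      // Layout-dependent path: consult the metadata and reorder if needed.
      std::cout << "rank=" << input.rank
                << (meta.is_mkl_layout ? " (blocked MKL layout)\n"
                                       : " (plain layout)\n");
    }
  }
};

int main() {
  PoolingKernelSketch</*native_format=*/true> native_op;
  PoolingKernelSketch</*native_format=*/false> layout_op;
  native_op.Compute(PlainTensor{}, LayoutMetadata{});
  layout_op.Compute(PlainTensor{}, LayoutMetadata{true});
  return 0;
}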
@@ -249,7 +256,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, strides, padding_left, padding_right, ALGORITHM::pooling_avg_exclude_padding, prop_kind::forward_training, - static_cast(this->data_format_mkldnn_), src_md); + static_cast(this->data_format_mkldnn_), src_md, + this->native_format_); #else MklPoolingParams bwdParams( orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, @@ -273,7 +281,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { std::shared_ptr pooling_bwd_pd = pooling_bwd->GetPoolingBwdPd(); T* diff_dst_data = nullptr; - if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, pooling_bwd_pd, + if (!this->native_format_ && + IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, pooling_bwd_pd, pooling_bwd)) { grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor); grad_dnn_data.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( @@ -307,36 +316,56 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { engine cpu_engine_ = engine(ENGINE_CPU, 0); }; // MklAvgPoolingGradOp -#define REGISTER_MKL_AVGPOOL3D_KERNELS(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("_MklAvgPool3D") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklAvgPoolingOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("_MklAvgPool3DGrad") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklAvgPoolingGradOp); +#define REGISTER_MKL_AVGPOOL3D_KERNELS(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklAvgPool3D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ + MklAvgPoolingOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklAvgPool3DGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ + MklAvgPoolingGradOp); \ + REGISTER_KERNEL_BUILDER(Name("_MklNativeAvgPool3D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklAvgPoolingOp); \ + REGISTER_KERNEL_BUILDER(Name("_MklNativeAvgPool3DGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklAvgPoolingGradOp); TF_CALL_float(REGISTER_MKL_AVGPOOL3D_KERNELS); TF_CALL_bfloat16(REGISTER_MKL_AVGPOOL3D_KERNELS); -#define REGISTER_MKL_AVGPOOL_KERNELS(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("_MklAvgPool") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklAvgPoolingOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("_MklAvgPoolGrad") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklAvgPoolingGradOp); +#define REGISTER_MKL_AVGPOOL_KERNELS(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklAvgPool") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ + MklAvgPoolingOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklAvgPoolGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ + MklAvgPoolingGradOp); \ + REGISTER_KERNEL_BUILDER(Name("_MklNativeAvgPool") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklAvgPoolingOp); \ + REGISTER_KERNEL_BUILDER(Name("_MklNativeAvgPoolGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + 
MklAvgPoolingGradOp); TF_CALL_float(REGISTER_MKL_AVGPOOL_KERNELS); TF_CALL_bfloat16(REGISTER_MKL_AVGPOOL_KERNELS); diff --git a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc index da5a239c224..fbba4116e3b 100644 --- a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc @@ -15,26 +15,16 @@ limitations under the License. // See docs in ../ops/math_ops.cc. -// This file uses both oneDNN and MKL CBLAS batched xGEMM for acceleration of -// Batch Matrix-Matrix Multiplication (MatMul) operations. -// We currently register this kernel only for oneDNN supported data -// types (float, bfloat16). This file can be built with and without the use of -// the binary MKL CBLAS calls, controlled by the macro INTEL_MKL_DNN_ONLY. -// If INTEL_MKL_DNN_ONLY is defined, only oneDNN is used. For cases not -// supported by oneDNN (ex. Batchmatmul with broadcasting) we fall back to the -// default CPU implementation. -// if INTEL_MKL_DNN_ONLY is not defined, both oneDNN and MKL CBLAS -// implementations are used. This is only temporary, once we are able handle all -// cases with oneDNN, CBLAS calls will be removed. +// This file uses oneDNN library for acceleration of Batch Matrix-Matrix +// Multiplication (MatMul) operations. We currently register this kernel only +// for oneDNN supported data types (float, bfloat16). The maximum number of +// dimensions (rank) for output tensor is 12 in oneDNN. If output tensor rank +// exceeds 12, we fall back to Eigen library based kernel. #define EIGEN_USE_THREADS #if defined(INTEL_MKL) -#include -#if !defined(INTEL_MKL_DNN_ONLY) -#include "mkl_cblas.h" -#endif // !INTEL_MKL_DNN_ONLY #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -100,8 +90,8 @@ class BatchMatMulMkl : public OpKernel { } // lhs and rhs can have different dimensions - const int ndims_lhs = lhs.dims(); - const int ndims_rhs = rhs.dims(); + const auto ndims_lhs = lhs.dims(); + const auto ndims_rhs = rhs.dims(); // Get broadcast info MatMulBCast bcast(lhs.shape().dim_sizes(), rhs.shape().dim_sizes()); @@ -111,16 +101,7 @@ class BatchMatMulMkl : public OpKernel { "In[0] and In[1] must have compatible batch dimensions: ", lhs.shape().DebugString(), " vs. ", rhs.shape().DebugString())); -#if defined(INTEL_MKL_DNN_ONLY) - if (bcast.IsBroadcastingRequired()) { - // Calling Eigen Kernel for broadcasting case and return. Eigen does - // not have BF16 support, so we have to fail graciously in that case. - eigen_batch_mm_v2_.Compute(ctx); - return; - } -#endif // INTEL_MKL_DNN_ONLY TensorShape out_shape = bcast.output_batch_shape(); - auto batch_size = bcast.output_batch_size(); auto lhs_rows = lhs.dim_size(ndims_lhs - 2); auto lhs_cols = lhs.dim_size(ndims_lhs - 1); @@ -137,6 +118,12 @@ class BatchMatMulMkl : public OpKernel { out_shape.AddDim(lhs_rows); out_shape.AddDim(rhs_cols); + // The maximum number of dimensions for a tensor in DNNL is 12. + OP_REQUIRES( + ctx, out_shape.dims() <= 12, + errors::InvalidArgument( + "Rank of output tensor must be <= 12, but is ", out_shape.dims(), + ". 
Current implementation supports upto rank 12 tensors.")); Tensor* out = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out)); @@ -149,75 +136,17 @@ class BatchMatMulMkl : public OpKernel { return; } - auto rhs_reshaped = rhs.template flat_inner_dims(); - auto lhs_reshaped = lhs.template flat_inner_dims(); - auto out_reshaped = out->template flat_inner_dims(); - const uint64 M = lhs_reshaped.dimension(adj_x_ ? 2 : 1); - const uint64 K = lhs_reshaped.dimension(adj_x_ ? 1 : 2); - const uint64 N = rhs_reshaped.dimension(adj_y_ ? 1 : 2); - - std::vector m_array(batch_size, M); - std::vector n_array(batch_size, N); - std::vector k_array(batch_size, K); - std::vector lda_array(batch_size, adj_x_ ? M : K); - std::vector ldb_array(batch_size, adj_y_ ? K : N); - std::vector ldc_array(batch_size, N); - std::vector group_size(1, batch_size); - - bool bcast_not_supported = false; -#if defined(INTEL_MKL_DNN_ONLY) - bcast_not_supported = true; -#endif // INTEL_MKL_DNN_ONLY - if (std::is_same::value || bcast_not_supported) { - // DNNL bfloat16 API requires a, b, and c as pointers to tensors - // represented as flat-byte array. - const Scalar* a = nullptr; - const Scalar* b = nullptr; - Scalar* c = nullptr; - a = &lhs_reshaped(0, 0, 0); - b = &rhs_reshaped(0, 0, 0); - OP_REQUIRES(ctx, !bcast.IsBroadcastingRequired(), - errors::Unimplemented("Broadcasting is not supported for " - "_MklBatchMatMul yet.")); - c = &out_reshaped(0, 0, 0); - // TODO(nhasabni): Use appropriate cast instead of passing addresses of - // a,b and c. - MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, m_array, n_array, - k_array, &a, lda_array, &b, ldb_array, &c, ldc_array, 1, - group_size, ctx); - } else { - std::vector a_array; - std::vector b_array; - std::vector c_array; - a_array.reserve(batch_size); - b_array.reserve(batch_size); - c_array.reserve(batch_size); - - if (!bcast.IsBroadcastingRequired()) { - for (int64 i = 0; i < batch_size; i++) { - a_array.push_back(&lhs_reshaped(i, 0, 0)); - b_array.push_back(&rhs_reshaped(i, 0, 0)); - c_array.push_back(&out_reshaped(i, 0, 0)); - } - } else { - // Broadcasting is needed, so get the mapping from flattened output - // batch indices to x's and y's flattened batch indices. - const std::vector& a_batch_indices = bcast.x_batch_indices(); - const std::vector& b_batch_indices = bcast.y_batch_indices(); - - for (int64 i = 0; i < batch_size; i++) { - a_array.push_back(&lhs_reshaped(a_batch_indices[i], 0, 0)); - b_array.push_back(&rhs_reshaped(b_batch_indices[i], 0, 0)); - c_array.push_back(&out_reshaped(i, 0, 0)); - } - } - - // MKL CBLAS API requires a, b, and c as array of pointers, where each - // pointer is to 2D matrix. - MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, m_array, n_array, - k_array, &a_array[0], lda_array, &b_array[0], ldb_array, - &c_array[0], ldc_array, 1, group_size, ctx); - } + // Compute parameters for DNNL matmul primitive. + auto params = CreateMatMulParams(lhs.shape(), rhs.shape(), out_shape); + // Create or retrieve matmul primitive from cache. + MklMatMulPrimitive* matmul_prim = + MklMatMulPrimitiveFactory::Get( + *params, false /* value for do_not_cache */); + // Execute matmul primitive. 
+ std::shared_ptr cpu_stream; + cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine())); + matmul_prim->Execute(lhs.flat().data(), rhs.flat().data(), + out->flat().data(), cpu_stream); } private: @@ -225,60 +154,78 @@ class BatchMatMulMkl : public OpKernel { bool adj_y_; BatchMatMulV2Op eigen_batch_mm_v2_; - void MklCblasGemmBatch( - const CBLAS_LAYOUT Layout, const bool TransA, const bool TransB, - const std::vector& M_Array, const std::vector& N_Array, - const std::vector& K_Array, const float** A_Array, - const std::vector& lda_Array, const float** B_Array, - const std::vector& ldb_Array, float** C_Array, - const std::vector& ldc_Array, const MKL_INT group_count, - const std::vector& group_size, OpKernelContext* ctx) { -#if !defined(INTEL_MKL_DNN_ONLY) - std::vector TransA_Array( - group_size[0], TransA ? CblasTrans : CblasNoTrans); - std::vector TransB_Array( - group_size[0], TransB ? CblasTrans : CblasNoTrans); - std::vector alpha_Array(group_size[0], 1.0); - std::vector beta_Array(group_size[0], 0.0); - cblas_sgemm_batch(Layout, &TransA_Array[0], &TransB_Array[0], &M_Array[0], - &N_Array[0], &K_Array[0], &alpha_Array[0], - reinterpret_cast(A_Array), &lda_Array[0], - reinterpret_cast(B_Array), &ldb_Array[0], - &beta_Array[0], reinterpret_cast(C_Array), - &ldc_Array[0], group_count, &group_size[0]); -#else - DCHECK(Layout == CblasRowMajor); - std::vector TransA_Array(group_size[0], TransA); - std::vector TransB_Array(group_size[0], TransB); - std::vector alpha_Array(group_size[0], 1.0); - std::vector beta_Array(group_size[0], 0.0); - dnnl_gemm_batch(TransA_Array, TransB_Array, M_Array, N_Array, - K_Array, alpha_Array, *A_Array, *B_Array, beta_Array, - *C_Array, group_count, group_size, ctx); -#endif // !INTEL_MKL_DNN_ONLY + using dims = dnnl::memory::dims; + + // This method makes the rank (ndims) of input same as the output by adding + // new axes to the input. For example, if input shape is [a, b, c, d] and + // output shape is [e, f, g, h, i, j], then the reshaped input would have a + // shape of [1, 1, a, b, c, d]. + void ExpandInputDimsToOutputShape(const TensorShape& input_shape, + const TensorShape& output_shape, + dims* reshaped_dims) { + auto ndims_input = input_shape.dims(); + auto ndims_output = output_shape.dims(); + auto dim_offset = ndims_output - ndims_input; + DCHECK(dim_offset > 0); + reshaped_dims->clear(); + reshaped_dims->resize(ndims_output, 1); + auto input_dims = input_shape.dim_sizes(); + for (int dim_idx = 0; dim_idx < ndims_input; ++dim_idx) + reshaped_dims->at(dim_idx + dim_offset) = input_dims[dim_idx]; } -// BatchMatMul BFloat16 support only exists in DNNL 1.2 onwards. -#if defined(ENABLE_MKLDNN_V1) && defined(ENABLE_INTEL_MKL_BFLOAT16) - void MklCblasGemmBatch( - const CBLAS_LAYOUT Layout, const bool TransA, const bool TransB, - const std::vector& M_Array, const std::vector& N_Array, - const std::vector& K_Array, const bfloat16** A_Array, - const std::vector& lda_Array, const bfloat16** B_Array, - const std::vector& ldb_Array, bfloat16** C_Array, - const std::vector& ldc_Array, const MKL_INT group_count, - const std::vector& group_size, OpKernelContext* ctx) { - DCHECK(Layout == CblasRowMajor); - std::vector TransA_Array(group_size[0], TransA); - std::vector TransB_Array(group_size[0], TransB); - std::vector alpha_Array(group_size[0], 1.0); - std::vector beta_Array(group_size[0], 0.0); - // TODO(nhasabni): Remove *A when we pass a, b, and c correctly. - // MKLDNN API does not require lda, ldb, and ldc. 
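// A minimal sketch of the rank-expansion step performed by
// ExpandInputDimsToOutputShape above: a lower-rank operand is left-padded
// with 1s so both matmul inputs have the same rank as the output, which is
// what the oneDNN matmul primitive expects when batch dimensions broadcast.
// Plain vectors stand in for dnnl::memory::dims.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> ExpandToRank(const std::vector<int64_t>& dims,
                                  size_t target_rank) {
  assert(dims.size() <= target_rank);
  std::vector<int64_t> expanded(target_rank, 1);  // New leading axes get 1.
  const size_t offset = target_rank - dims.size();
  for (size_t i = 0; i < dims.size(); ++i) expanded[offset + i] = dims[i];
  return expanded;
}

int main() {
  // [a, b, c, d] expanded to rank 6 becomes [1, 1, a, b, c, d], matching the
  // example given in the comment on ExpandInputDimsToOutputShape.
  const std::vector<int64_t> expanded = ExpandToRank({3, 4, 5, 6}, 6);
  const std::vector<int64_t> expected = {1, 1, 3, 4, 5, 6};
  assert(expanded == expected);
  return 0;
}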
- dnnl_gemm_batch( - TransA_Array, TransB_Array, M_Array, N_Array, K_Array, alpha_Array, - *A_Array, *B_Array, beta_Array, *C_Array, group_count, group_size, ctx); + + std::unique_ptr CreateMatMulParams( + const TensorShape& lhs_shape, const TensorShape& rhs_shape, + const TensorShape& out_shape) { + const auto ndims_lhs = lhs_shape.dims(); + const auto ndims_rhs = rhs_shape.dims(); + const auto ndims_out = out_shape.dims(); + auto lhs_dims = TFShapeToMklDnnDims(lhs_shape); + auto rhs_dims = TFShapeToMklDnnDims(rhs_shape); + auto out_dims = TFShapeToMklDnnDims(out_shape); + + // DNNL matmul_primitive requires ranks of inputs and output to be same. + // Create dnnl::memory::dims for inputs and output of same rank. + // It is assumed here that MatMulBCast object creates output_batch_shape as + // a conforming superset of input batch shapes, i.e., ndims_out >= + // ndims_lhs and ndims_out >= ndims_rhs. + if (ndims_lhs < ndims_out) { + ExpandInputDimsToOutputShape(lhs_shape, out_shape, &lhs_dims); + } + if (ndims_rhs < ndims_out) { + ExpandInputDimsToOutputShape(rhs_shape, out_shape, &rhs_dims); + } + + using dim = dnnl::memory::dim; + dim m; // number of rows in x + dim k; // number of columns in x + dim n; // number of columns in y + auto lhs_strides = CalculateTFStrides(lhs_dims); + auto rhs_strides = CalculateTFStrides(rhs_dims); + auto out_strides = CalculateTFStrides(out_dims); + + if (adj_x_) { + int m_idx = ndims_out - 1; + int k_idx = ndims_out - 2; + m = lhs_dims[m_idx]; + k = lhs_dims[k_idx]; + std::swap(lhs_dims[m_idx], lhs_dims[k_idx]); + lhs_strides[m_idx] = m; + lhs_strides[k_idx] = 1; + } + + if (adj_y_) { + int k_idx = ndims_out - 1; + int n_idx = ndims_out - 2; + k = rhs_dims[k_idx]; + n = rhs_dims[n_idx]; + std::swap(rhs_dims[k_idx], rhs_dims[n_idx]); + rhs_strides[k_idx] = k; + rhs_strides[n_idx] = 1; + } + return std::make_unique( + lhs_dims, rhs_dims, out_dims, lhs_strides, rhs_strides, out_strides); } -#endif // ENABLE_MKLDNN_V1 && ENABLE_INTEL_MKL_BFLOAT16 }; #define REGISTER_BATCH_MATMUL_MKL(TYPE) \ @@ -294,14 +241,11 @@ class BatchMatMulMkl : public OpKernel { .TypeConstraint("T") \ .Label(mkl_op_registry::kMklNameChangeOpLabel), \ BatchMatMulMkl) - #ifdef ENABLE_MKL TF_CALL_float(REGISTER_BATCH_MATMUL_MKL); TF_CALL_float(REGISTER_BATCH_MATMUL_MKL_V2); -#if defined(ENABLE_MKLDNN_V1) && defined(ENABLE_INTEL_MKL_BFLOAT16) TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL); TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL_V2); -#endif // ENABLE_MKLDNN_V1 && ENABLE_INTEL_MKL_BFLOAT16 #endif // ENABLE_MKL } // end namespace tensorflow diff --git a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc index 1e84d938ef7..5814eec0b76 100644 --- a/tensorflow/core/kernels/mkl/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl/mkl_conv_ops.cc @@ -490,8 +490,9 @@ class MklConvOp : public OpKernel { const int64 stride_c = GetTensorDim(strides_, data_format_, 'C'); OP_REQUIRES( context, stride_n == 1 && stride_c == 1, - errors::InvalidArgument("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); + errors::Unimplemented("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); is_filter_const_ = false; if (context->HasAttr("is_filter_const")) { @@ -939,13 +940,28 @@ class MklConvOp : public OpKernel { if (fuse_add_) { const Tensor& add_tensor = MklGetInput(context, kInputIndex_Add); MklDnnShape add_mkl_shape; - 
GetMklShape(context, kInputIndex_Add, &add_mkl_shape); - + GetMklShape(context, kInputIndex_Add, &add_mkl_shape, native_format); + if (native_format) { + // Forward the summand tensor to the output only if it has no other + // references, otherwise make a copy of it. + if (!context->forward_input_to_output_with_shape( + kInputIndex_Add, kOutputIndex_Dst, output_tf_shape, + output_tensor)) { + AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor, + output_tf_shape, *output_mkl_shape, + native_format); + bool result = + (*output_tensor)->CopyFrom(add_tensor, add_tensor.shape()); + DCHECK(result); + } + return; + } // Check if reorder is needed - if (add_mkl_shape == *output_mkl_shape) { - ForwardMklTensorInToOutWithMklShape(context, kInputIndex_Add, - kOutputIndex_Dst, add_mkl_shape); - *output_tensor = context->mutable_output(kOutputIndex_Dst); + if (add_mkl_shape == *output_mkl_shape && + ForwardMklTensorInToOutWithMklShape(context, kInputIndex_Add, + kOutputIndex_Dst, output_tensor, + add_mkl_shape, false)) { + return; } else { AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor, output_tf_shape, *output_mkl_shape, @@ -1282,14 +1298,14 @@ class MklConvOp : public OpKernel { // Base class for fused convolution forward operations template + bool pad_enabled, bool native_format> class MklFusedConvOp : public MklConvOp { + Tpadding, false, false, false, native_format> { public: explicit MklFusedConvOp(OpKernelConstruction* context) : MklConvOp(context) { + Tpadding, false, false, false, native_format>(context) { // Since we came here through the registration of _MklFusedConv2D, get // all information from 'fused_ops' and 'num_args' std::vector fused_ops; @@ -1403,14 +1419,17 @@ class MklFusedConvOp template + bool pad_enabled, bool bias_enabled, bool is_depthwise, + bool native_format> class MklFusedDepthwiseConvOp : public MklConvOp { + Tpadding, bias_enabled, false, is_depthwise, + native_format> { public: explicit MklFusedDepthwiseConvOp(OpKernelConstruction* context) : MklConvOp(context) { + Tpadding, bias_enabled, false, is_depthwise, native_format>( + context) { // Since we came here through the registration of // _MklFusedDepthwiseConv2dNative, get all // information from 'fused_ops' and 'num_args' @@ -1870,7 +1889,7 @@ class MklQuantizedConv2DSumReluOp GetMklShape(context, summand_idx, &summand_mkl_shape); auto dst_md = summand_mkl_shape.GetMklLayout(); - // TODO(mdfaijul): handle both non-MKL and MKL tensors + // TODO(intel-tf): Handle both non-MKL and MKL tensors if (summand_type == DT_QINT8) { OP_REQUIRES_OK( context, summand.BitcastFrom(summand, DT_QUINT8, summand.shape())); @@ -1879,9 +1898,13 @@ class MklQuantizedConv2DSumReluOp summand_mkl_shape.SetMklLayout(&dst_md); summand_mkl_shape.SetElemType(MklDnnType()); } - ForwardMklTensorInToOutWithMklShape(context, summand_idx, 0, - summand_mkl_shape); - *output_tensor = const_cast(&summand); + // TODO(intel-tf): Support cases when summand cannot be forwarded. 
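// A minimal sketch of the forward-or-copy handling of the fused Add summand
// above: if the summand buffer has no other consumers it can be handed to the
// output directly, otherwise a fresh output is allocated and the data copied.
// shared_ptr use counts stand in for TF's buffer reference counting and
// forward_input_to_output_with_shape.
#include <cassert>
#include <memory>
#include <vector>

using Buffer = std::vector<float>;

std::shared_ptr<Buffer> ForwardOrCopy(const std::shared_ptr<Buffer>& summand) {
  if (summand.use_count() == 1) {
    return summand;  // Sole owner: forwarding is safe, no copy needed.
  }
  return std::make_shared<Buffer>(*summand);  // Shared: materialize a copy.
}

int main() {
  auto summand = std::make_shared<Buffer>(Buffer{1.f, 2.f, 3.f});
  auto extra_ref = summand;         // Another consumer still reads the buffer.
  auto out1 = ForwardOrCopy(summand);
  assert(out1 != summand);          // Copy path was taken.
  extra_ref.reset();
  auto out2 = ForwardOrCopy(summand);
  assert(out2 == summand);          // Forward path: summand is now sole-owned.
  return 0;
}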
+ OP_REQUIRES( + context, + ForwardMklTensorInToOutWithMklShape( + context, summand_idx, 0, output_tensor, summand_mkl_shape, false), + errors::InvalidArgument( + "Summand cannot be forwarded in the current fusion.")); return; } MklConvOp("T") \ .Label(mkl_op_registry::kMklNameChangeOpLabel), \ - MklConvOp); + MklConvOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklNativeConv2DWithBias") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklConvOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklNativePadWithConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tpaddings") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklConvOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklNativePadWithConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tpaddings") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklConvOp); TF_CALL_float(REGISTER_MKL_CPU_2D); TF_CALL_bfloat16(REGISTER_MKL_CPU_2D); @@ -2474,7 +2517,14 @@ TF_CALL_bfloat16(REGISTER_MKL_CPU_2D); .TypeConstraint("T") \ .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ MklFusedDepthwiseConvOp); \ + true, false>); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklNativeFusedDepthwiseConv2dNative") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklFusedDepthwiseConvOp); \ REGISTER_KERNEL_BUILDER( \ Name("_MklNativeDepthwiseConv2dNative") \ .Device(DEVICE_CPU) \ @@ -2487,34 +2537,54 @@ TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_DEPTHWISE); // Note we are registering _MklFusedConv2D. // We check the fused_ops attributes to decide if bias is enabled or not. -#define REGISTER_MKL_CPU_2D_FUSED(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("_MklFusedConv2D") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklFusedConvOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("_MklPadWithFusedConv2D") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("Tpaddings") \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklFusedConvOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("_MklPadWithFusedConv2D") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tpaddings") \ - .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklFusedConvOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("__MklDummyPadWithFusedConv2D") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("Tpaddings") \ - .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklDummyOp); +#define REGISTER_MKL_CPU_2D_FUSED(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklFusedConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ + MklFusedConvOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklPadWithFusedConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("Tpaddings") \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ + MklFusedConvOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklPadWithFusedConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tpaddings") \ + .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ + MklFusedConvOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("__MklDummyPadWithFusedConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tpaddings") \ + .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ + MklDummyOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklNativeFusedConv2D") \ + 
.Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklFusedConvOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklNativePadWithFusedConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("Tpaddings") \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklFusedConvOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklNativePadWithFusedConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tpaddings") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklFusedConvOp); TF_CALL_float(REGISTER_MKL_CPU_2D_FUSED); TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_FUSED); diff --git a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc index 4b979bfc5fd..c641fd56275 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc @@ -775,7 +775,7 @@ class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory { // with MKL. This is different from default where the classes are // derived. Moves enabling to compile-time rather than runtime. template + bool is_batch_norm_ex = false, bool native_format = false> class MklFusedBatchNormOp : public OpKernel { public: explicit MklFusedBatchNormOp(OpKernelConstruction* context) @@ -835,7 +835,7 @@ class MklFusedBatchNormOp : public OpKernel { TensorShape tf_shape_src; MklDnnShape dnn_shape_src; - GetMklShape(context, kSrcIndex, &dnn_shape_src); + GetMklShape(context, kSrcIndex, &dnn_shape_src, native_format); if (dnn_shape_src.IsMklTensor()) { tf_shape_src = dnn_shape_src.GetTfShape(); @@ -986,7 +986,7 @@ class MklFusedBatchNormOp : public OpKernel { // Check if reorder is needed for src. const T* src_data = nullptr; std::shared_ptr bn_fwd_pd = bn_fwd->GetBatchNormFwdPd(); - if (IS_SRC_REORDER_NEEDED(src_md, bn_fwd_pd, bn_fwd)) { + if (!native_format && IS_SRC_REORDER_NEEDED(src_md, bn_fwd_pd, bn_fwd)) { src.SetUsrMem(src_md, &src_tensor); src.CheckReorderToOpMem( MEMORY_PD_WITHOUT_DATA(GET_SRC_DESC_FROM_OP_PD(bn_fwd_pd), @@ -997,7 +997,7 @@ class MklFusedBatchNormOp : public OpKernel { src_data = static_cast(const_cast(src_tensor.flat().data())); } - // Allocate output (dst) tensor; always set it as MKL-DNN layout + // Allocate output (dst) tensor MklDnnShape dnn_shape_dst; TensorShape tf_shape_dst; dnn_shape_dst.SetMklTensor(true); @@ -1008,8 +1008,11 @@ class MklFusedBatchNormOp : public OpKernel { : src_tensor.shape().dims(); dnn_shape_dst.SetTfLayout(ndims, src_dims, mkl_tensor_fmt); tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T)); + if (native_format) { + tf_shape_dst = dnn_shape_dst.GetTfShape(); + } AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst, - dnn_shape_dst); + dnn_shape_dst, native_format); U* weights_op_data = weights_data; U* mean_op_data = saved_mean_tensor->flat().data(); @@ -1097,7 +1100,7 @@ class MklFusedBatchNormOp : public OpKernel { MklDnnShape dnn_shape_dst; dnn_shape_dst.SetMklTensor(false); AllocateOutputSetMklShape(context, kDstIndex, dst_tensor, tf_shape_src, - dnn_shape_dst); + dnn_shape_dst, native_format); DCHECK(*dst_tensor); memset(const_cast((*dst_tensor)->tensor_data().data()), 0, (*dst_tensor)->tensor_data().size()); @@ -1135,7 +1138,8 @@ class MklFusedBatchNormOp : public OpKernel { MklDnnShape mkl_shape_batch_mean; mkl_shape_batch_mean.SetMklTensor(false); AllocateOutputSetMklShape(context, kBatchMeanIndex, batch_mean_tensor, - tf_shape_scale, mkl_shape_batch_mean); + 
tf_shape_scale, mkl_shape_batch_mean, + native_format); DCHECK(*batch_mean_tensor); // Set NAN mean value in case of empty input tensor @@ -1148,7 +1152,7 @@ class MklFusedBatchNormOp : public OpKernel { mkl_shape_batch_variance.SetMklTensor(false); AllocateOutputSetMklShape(context, kBatchVarianceIndex, batch_variance_tensor, tf_shape_scale, - mkl_shape_batch_variance); + mkl_shape_batch_variance, native_format); DCHECK(*batch_variance_tensor); // Set NAN variance value in case of empty input tensor @@ -1159,7 +1163,8 @@ class MklFusedBatchNormOp : public OpKernel { MklDnnShape mkl_shape_saved_mean; mkl_shape_saved_mean.SetMklTensor(false); AllocateOutputSetMklShape(context, kSavedMeanIndex, saved_mean_tensor, - tf_shape_scale, mkl_shape_saved_mean); + tf_shape_scale, mkl_shape_saved_mean, + native_format); DCHECK(*saved_mean_tensor); // Set 0 mean value in case of empty input tensor @@ -1170,7 +1175,7 @@ class MklFusedBatchNormOp : public OpKernel { mkl_shape_saved_variance.SetMklTensor(false); AllocateOutputSetMklShape(context, kSavedVarianceIndex, saved_variance_tensor, tf_shape_scale, - mkl_shape_saved_variance); + mkl_shape_saved_variance, native_format); DCHECK(*saved_variance_tensor); // Set 0 variance value in case of empty input tensor @@ -1185,13 +1190,14 @@ class MklFusedBatchNormOp : public OpKernel { mkl_shape_reserved_space.SetMklTensor(false); AllocateOutputSetMklShape(context, kReservedSpaceIndex, reserved_space_tensor, workspace_tf_shape, - mkl_shape_reserved_space); + mkl_shape_reserved_space, native_format); DCHECK((*reserved_space_tensor) != nullptr); } } }; -template +template class MklFusedBatchNormGradOp : public OpKernel { public: explicit MklFusedBatchNormGradOp(OpKernelConstruction* context) @@ -1227,8 +1233,8 @@ class MklFusedBatchNormGradOp : public OpKernel { : Tensor(); MklDnnShape dnn_shape_src, dnn_shape_diff_dst; - GetMklShape(context, kSrcIndex, &dnn_shape_src); - GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst); + GetMklShape(context, kSrcIndex, &dnn_shape_src, native_format); + GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst, native_format); TensorShape tf_shape_src, tf_shape_diff_dst; if (dnn_shape_diff_dst.IsMklTensor()) { @@ -1335,24 +1341,26 @@ class MklFusedBatchNormGradOp : public OpKernel { static_cast(const_cast(src_tensor.flat().data())); #ifdef ENABLE_MKLDNN_V1 - // MKL-DNN requires src and diff_dst to be in same memory layout, either - // blocked or native format. If these inputs are in different formats, - // convert the one in native format to blocked format as MKL-DNN gives - // better performance for blocked format. - if (dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) { - reorder_diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); - reorder_diff_dst.CheckReorderToOpMem( - MEMORY_PD_WITHOUT_DATA(src_md, cpu_engine_), context); - diff_dst_md = src_md; - diff_dst_data = - static_cast(reorder_diff_dst.GetOpMem().get_data_handle()); - } else if (!dnn_shape_src.IsMklTensor() && - dnn_shape_diff_dst.IsMklTensor()) { - reorder_src.SetUsrMem(src_md, &src_tensor); - reorder_src.CheckReorderToOpMem( - MEMORY_PD_WITHOUT_DATA(diff_dst_md, cpu_engine_), context); - src_md = diff_dst_md; - src_data = static_cast(reorder_src.GetOpMem().get_data_handle()); + if (!native_format) { + // MKL-DNN requires src and diff_dst to be in same memory layout, either + // blocked or native format. 
If these inputs are in different formats, + // convert the one in native format to blocked format as MKL-DNN gives + // better performance for blocked format. + if (dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) { + reorder_diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); + reorder_diff_dst.CheckReorderToOpMem( + MEMORY_PD_WITHOUT_DATA(src_md, cpu_engine_), context); + diff_dst_md = src_md; + diff_dst_data = + static_cast(reorder_diff_dst.GetOpMem().get_data_handle()); + } else if (!dnn_shape_src.IsMklTensor() && + dnn_shape_diff_dst.IsMklTensor()) { + reorder_src.SetUsrMem(src_md, &src_tensor); + reorder_src.CheckReorderToOpMem( + MEMORY_PD_WITHOUT_DATA(diff_dst_md, cpu_engine_), context); + src_md = diff_dst_md; + src_data = static_cast(reorder_src.GetOpMem().get_data_handle()); + } } #endif // ENABLE_MKLDNN_V1 @@ -1381,7 +1389,8 @@ class MklFusedBatchNormGradOp : public OpKernel { // Check if diff_dst input needs to be reordered std::shared_ptr bn_bwd_pd = bn_bwd->GetBatchNormBwdPd(); - if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bn_bwd_pd, bn_bwd)) { + if (!native_format && + IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bn_bwd_pd, bn_bwd)) { diff_dst.SetUsrMem(diff_dst_md, diff_dst_data); diff_dst.CheckReorderToOpMem( MEMORY_PD_WITHOUT_DATA(GET_DIFF_DST_DESC_FROM_OP_PD(bn_bwd_pd), @@ -1390,7 +1399,7 @@ class MklFusedBatchNormGradOp : public OpKernel { diff_dst_data = static_cast(diff_dst.GetOpMem().get_data_handle()); } - if (IS_SRC_REORDER_NEEDED(src_md, bn_bwd_pd, bn_bwd)) { + if (!native_format && IS_SRC_REORDER_NEEDED(src_md, bn_bwd_pd, bn_bwd)) { src.SetUsrMem(src_md, src_data); src.CheckReorderToOpMem( MEMORY_PD_WITHOUT_DATA(GET_SRC_DESC_FROM_OP_PD(bn_bwd_pd), @@ -1412,8 +1421,12 @@ class MklFusedBatchNormGradOp : public OpKernel { dnn_shape_diff_src.SetTfLayout(src_dims.size(), src_dims, mkl_tensor_fmt); dnn_shape_diff_src.SetTfDimOrder(src_dims.size(), tensor_format_); tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T)); + if (native_format) { + tf_shape_diff_src = dnn_shape_diff_src.GetTfShape(); + } AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor, - tf_shape_diff_src, dnn_shape_diff_src); + tf_shape_diff_src, dnn_shape_diff_src, + native_format); U* mean_data = static_cast(const_cast(saved_mean_tensor.flat().data())); @@ -1479,7 +1492,7 @@ class MklFusedBatchNormGradOp : public OpKernel { MklDnnShape dnn_shape_diff_src; dnn_shape_diff_src.SetMklTensor(false); AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor, - tf_shape_src, dnn_shape_diff_src); + tf_shape_src, dnn_shape_diff_src, native_format); auto diff_src_data = (*diff_src_tensor)->flat().data(); std::fill_n(diff_src_data, (*diff_src_tensor)->shape().num_elements(), static_cast(0)); @@ -1506,7 +1519,8 @@ class MklFusedBatchNormGradOp : public OpKernel { MklDnnShape mkl_shape_diff_scale; mkl_shape_diff_scale.SetMklTensor(false); AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor, - tf_shape_scale_shift, mkl_shape_diff_scale); + tf_shape_scale_shift, mkl_shape_diff_scale, + native_format); DCHECK(*diff_scale_tensor); auto diff_scale_data = (*diff_scale_tensor)->flat().data(); @@ -1516,7 +1530,8 @@ class MklFusedBatchNormGradOp : public OpKernel { MklDnnShape mkl_shape_diff_shift; mkl_shape_diff_shift.SetMklTensor(false); AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor, - tf_shape_scale_shift, mkl_shape_diff_shift); + tf_shape_scale_shift, mkl_shape_diff_shift, + native_format); DCHECK(*diff_shift_tensor); auto 
diff_shift_data = (*diff_shift_tensor)->flat().data(); @@ -1529,11 +1544,11 @@ class MklFusedBatchNormGradOp : public OpKernel { MklDnnShape mkl_shape_p; mkl_shape_p.SetMklTensor(false); AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}), - mkl_shape_p); + mkl_shape_p, native_format); std::fill_n(p1_tensor->flat().data(), p1_tensor->shape().num_elements(), static_cast(0)); AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}), - mkl_shape_p); + mkl_shape_p, native_format); std::fill_n(p2_tensor->flat().data(), p2_tensor->shape().num_elements(), static_cast(0)); } @@ -1547,7 +1562,13 @@ class MklFusedBatchNormGradOp : public OpKernel { .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklFusedBatchNormOp); + MklFusedBatchNormOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklNativeFusedBatchNorm") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklFusedBatchNormOp); TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_CPU); TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_CPU); @@ -1560,7 +1581,14 @@ TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_CPU); .TypeConstraint("T") \ .TypeConstraint("U") \ .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklFusedBatchNormOp); + MklFusedBatchNormOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklNativeFusedBatchNormV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("U") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklFusedBatchNormOp); REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(float, float); REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(bfloat16, float); @@ -1572,7 +1600,13 @@ REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(bfloat16, float); .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklFusedBatchNormGradOp); + MklFusedBatchNormGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklNativeFusedBatchNormGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklFusedBatchNormGradOp); TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU); TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU); @@ -1585,7 +1619,14 @@ TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU); .TypeConstraint("T") \ .TypeConstraint("U") \ .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklFusedBatchNormGradOp); + MklFusedBatchNormGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklNativeFusedBatchNormGradV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("U") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklFusedBatchNormGradOp); REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(float, float); REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(bfloat16, float); @@ -1594,21 +1635,35 @@ REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(bfloat16, float); // TODO: FusedBatchNormV3 has an additional output that is used to // hold intermediate results. This parameter functionality is // not implemented on CPU. 
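Note: every registration macro touched in this file follows the same shape: the existing blocked-layout kernel keeps kMklLayoutDependentOpLabel, and a second registration binds a "_MklNative*" op name to the same class with kMklNameChangeOpLabel and the native-format flag set. A condensed sketch with a hypothetical op and class name (the real template parameter lists are elided in this diff):

    // Hypothetical "_MklFoo" kernel; the last template argument stands for the
    // native_format flag threaded through this patch.
    REGISTER_KERNEL_BUILDER(Name("_MklFoo")
                                .Device(DEVICE_CPU)
                                .TypeConstraint<float>("T")
                                .Label(mkl_op_registry::kMklLayoutDependentOpLabel),
                            FooOp<CPUDevice, float, /*native_format=*/false>);
    REGISTER_KERNEL_BUILDER(Name("_MklNativeFoo")
                                .Device(DEVICE_CPU)
                                .TypeConstraint<float>("T")
                                .Label(mkl_op_registry::kMklNameChangeOpLabel),
                            FooOp<CPUDevice, float, /*native_format=*/true>);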
-#define REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(T, U) \ - REGISTER_KERNEL_BUILDER( \ - Name("_MklFusedBatchNormV3") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("U") \ - .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklFusedBatchNormOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("_MklFusedBatchNormEx") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .TypeConstraint("U") \ - .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklFusedBatchNormOp); +#define REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(T, U) \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklFusedBatchNormV3") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("U") \ + .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ + MklFusedBatchNormOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklFusedBatchNormEx") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("U") \ + .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ + MklFusedBatchNormOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklNativeFusedBatchNormV3") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("U") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklFusedBatchNormOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklNativeFusedBatchNormEx") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("U") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklFusedBatchNormOp); REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(float, float); REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(bfloat16, float); @@ -1632,7 +1687,14 @@ REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx") .TypeConstraint("T") \ .TypeConstraint("U") \ .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklFusedBatchNormGradOp); + MklFusedBatchNormGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklNativeFusedBatchNormGradV3") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("U") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklFusedBatchNormGradOp); REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(float, float); REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(bfloat16, float); diff --git a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op_test.cc b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op_test.cc index c9b52f3da94..60068a09d62 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op_test.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op_test.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/mkl_graph_util.h" #include "tensorflow/core/kernels/conv_ops_gpu.h" #include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/kernels/ops_util.h" @@ -275,22 +276,37 @@ class FusedBatchNormOpTest : public OpsTestBase { const bool is_training, Tensor* output, Tensor* batch_mean, Tensor* batch_var) { DataType dtype = DataTypeToEnum::v(); - TF_EXPECT_OK(NodeDefBuilder("MklFusedBatchNorm", "_MklFusedBatchNorm") - .Input(FakeInput(dtype)) - .Input(FakeInput(DT_FLOAT)) - .Input(FakeInput(DT_FLOAT)) - .Input(FakeInput(DT_FLOAT)) - .Input(FakeInput(DT_FLOAT)) - .Input(FakeInput(DT_UINT8)) - .Input(FakeInput(DT_UINT8)) - .Input(FakeInput(DT_UINT8)) - .Input(FakeInput(DT_UINT8)) - .Input(FakeInput(DT_UINT8)) - .Attr("exponential_avg_factor", exponential_avg_factor) - .Attr("epsilon", 0.001) - .Attr("is_training", is_training) - .Attr("_kernel", "MklLayoutDependentOp") - .Finalize(node_def())); + if (!NativeFormatEnabled()) { + TF_EXPECT_OK(NodeDefBuilder("MklFusedBatchNorm", "_MklFusedBatchNorm") + .Input(FakeInput(dtype)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Attr("exponential_avg_factor", exponential_avg_factor) + .Attr("epsilon", 0.001) + .Attr("is_training", is_training) + .Attr("_kernel", "MklLayoutDependentOp") + .Finalize(node_def())); + } else { + TF_EXPECT_OK(NodeDefBuilder("MklNativeFusedBatchNorm", + "_MklNativeFusedBatchNorm") + .Input(FakeInput(dtype)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Attr("exponential_avg_factor", exponential_avg_factor) + .Attr("epsilon", 0.001) + .Attr("is_training", is_training) + .Attr("_kernel", "MklNameChangeOp") + .Finalize(node_def())); + } TF_EXPECT_OK(InitOp()); AddInputFromArray(input.shape(), input.flat()); @@ -298,20 +314,29 @@ class FusedBatchNormOpTest : public OpsTestBase { AddInputFromArray(offset.shape(), offset.flat()); AddInputFromArray(mean.shape(), mean.flat()); AddInputFromArray(variance.shape(), variance.flat()); - for (int i = 0; i < 5; ++i) - AddInputFromArray(dummy_shape, dummy_tensor); + if (!NativeFormatEnabled()) { + for (int i = 0; i < 5; ++i) + AddInputFromArray(dummy_shape, dummy_tensor); + } TF_ASSERT_OK(RunOpKernel()); - CommonTestUtilities test_util; - test_util.PerformConversion(dtype, *GetOutput(0), *GetOutput(5), output); + if (!NativeFormatEnabled()) { + CommonTestUtilities test_util; + test_util.PerformConversion(dtype, *GetOutput(0), *GetOutput(5), + output); - CommonTestUtilities test_util_mean; - test_util_mean.PerformConversion(dtype, *GetOutput(1), *GetOutput(6), - batch_mean); + CommonTestUtilities test_util_mean; + test_util_mean.PerformConversion(dtype, *GetOutput(1), *GetOutput(6), + batch_mean); - CommonTestUtilities test_util_var; - test_util_var.PerformConversion(dtype, *GetOutput(2), *GetOutput(7), - batch_var); + CommonTestUtilities test_util_var; + test_util_var.PerformConversion(dtype, *GetOutput(2), *GetOutput(7), + batch_var); + } else { + *output = *GetOutput(0); + *batch_mean = *GetOutput(1); + *batch_var = *GetOutput(2); + } }; CommonTestUtilities::VerifyTensorsClose(exponential_avg_factor, 
diff --git a/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc b/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc index 8b48007321a..7bd47e9d014 100644 --- a/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc +++ b/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/mkl_graph_util.h" #include "tensorflow/core/kernels/conv_ops_gpu.h" #include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/kernels/ops_util.h" @@ -265,20 +266,34 @@ class MklFusedConv2DOpTest : public OpsTestBase { DataType dtype = DataTypeToEnum::v(); int num_args = static_cast(args.size()); - TF_EXPECT_OK(NodeDefBuilder("fused_conv_op", "_MklFusedConv2D") - .Input(FakeInput(dtype)) - .Input(FakeInput(dtype)) - .Input(FakeInput(num_args, dtype)) - .Input(FakeInput(DT_UINT8)) - .Input(FakeInput(DT_UINT8)) - .Input(FakeInput(num_args, DT_UINT8)) - .Attr("T", dtype) - .Attr("num_args", num_args) - .Attr("strides", {1, stride, stride, 1}) - .Attr("padding", "SAME") - .Attr("fused_ops", fused_ops) - .Attr("_kernel", "MklLayoutDependentOp") - .Finalize(node_def())); + if (!NativeFormatEnabled()) { + TF_EXPECT_OK(NodeDefBuilder("fused_conv_op", "_MklFusedConv2D") + .Input(FakeInput(dtype)) + .Input(FakeInput(dtype)) + .Input(FakeInput(num_args, dtype)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(num_args, DT_UINT8)) + .Attr("T", dtype) + .Attr("num_args", num_args) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", "SAME") + .Attr("fused_ops", fused_ops) + .Attr("_kernel", "MklLayoutDependentOp") + .Finalize(node_def())); + } else { + TF_EXPECT_OK(NodeDefBuilder("fused_conv_op", "_MklNativeFusedConv2D") + .Input(FakeInput(dtype)) + .Input(FakeInput(dtype)) + .Input(FakeInput(num_args, dtype)) + .Attr("T", dtype) + .Attr("num_args", num_args) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", "SAME") + .Attr("fused_ops", fused_ops) + .Attr("_kernel", "MklNameChangeOp") + .Finalize(node_def())); + } TF_EXPECT_OK(InitOp()); @@ -286,20 +301,26 @@ class MklFusedConv2DOpTest : public OpsTestBase { AddInputFromArray(filter.shape(), filter.flat()); for (const Tensor& arg : args) AddInputFromArray(arg.shape(), arg.flat()); - AddInputFromArray(dummy_shape, dummy_tensor); - AddInputFromArray(dummy_shape, dummy_tensor); - for (const Tensor& arg : args) + if (!NativeFormatEnabled()) { AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + for (const Tensor& arg : args) + AddInputFromArray(dummy_shape, dummy_tensor); + } TF_ASSERT_OK(RunOpKernel()); // Compare output to expected results const Tensor& output_tensor = *GetOutput(0); - // Index 2 will need to be changed if the number of outputs produced - // by MklConv2D change. - const Tensor& output_meta_tensor = *GetOutput(2); CommonTestUtilities test_util; - test_util.PerformConversion(dtype, output_tensor, output_meta_tensor, - output); + if (!NativeFormatEnabled()) { + // Index 2 will need to be changed if the number of outputs produced + // by MklConv2D change. + const Tensor& output_meta_tensor = *GetOutput(2); + test_util.PerformConversion(dtype, output_tensor, output_meta_tensor, + output); + } else { + *output = output_tensor; + } } // Verifies computing unfused ops in a graph is identical to FusedConv2D. 
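Note: the test changes in this file (and the analogous ones below for depthwise conv, pad+conv, matmul, and the filter/weight caches) all branch on NativeFormatEnabled(): blocked-layout graphs add DT_UINT8 meta inputs and convert outputs back to TF layout, while native-format graphs do neither. Roughly, with a hypothetical meta-output index:

    if (!NativeFormatEnabled()) {
      // Blocked layout: each data output has a companion metadata tensor that
      // PerformConversion uses to convert it back to plain TF layout.
      const int meta_index = 5;  // hypothetical; depends on the op's output arity
      CommonTestUtilities<T> util;
      util.PerformConversion(dtype, *GetOutput(0), *GetOutput(meta_index), output);
    } else {
      // Native format: outputs are already plain TF tensors.
      *output = *GetOutput(0);
    }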
@@ -545,21 +566,36 @@ class MklFusedDepthwiseConv2DOpTest : public OpsTestBase { DataType dtype = DataTypeToEnum::v(); int num_args = static_cast(args.size()); - TF_EXPECT_OK(NodeDefBuilder("fused_depthwise_conv_op", - "_MklFusedDepthwiseConv2dNative") - .Input(FakeInput(dtype)) - .Input(FakeInput(dtype)) - .Input(FakeInput(num_args, dtype)) - .Input(FakeInput(DT_UINT8)) - .Input(FakeInput(DT_UINT8)) - .Input(FakeInput(num_args, DT_UINT8)) - .Attr("T", dtype) - .Attr("num_args", num_args) - .Attr("strides", {1, stride, stride, 1}) - .Attr("padding", "SAME") - .Attr("fused_ops", fused_ops) - .Attr("_kernel", "MklLayoutDependentOp") - .Finalize(node_def())); + if (!NativeFormatEnabled()) { + TF_EXPECT_OK(NodeDefBuilder("fused_depthwise_conv_op", + "_MklFusedDepthwiseConv2dNative") + .Input(FakeInput(dtype)) + .Input(FakeInput(dtype)) + .Input(FakeInput(num_args, dtype)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(num_args, DT_UINT8)) + .Attr("T", dtype) + .Attr("num_args", num_args) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", "SAME") + .Attr("fused_ops", fused_ops) + .Attr("_kernel", "MklLayoutDependentOp") + .Finalize(node_def())); + } else { + TF_EXPECT_OK(NodeDefBuilder("fused_depthwise_conv_op", + "_MklNativeFusedDepthwiseConv2dNative") + .Input(FakeInput(dtype)) + .Input(FakeInput(dtype)) + .Input(FakeInput(num_args, dtype)) + .Attr("T", dtype) + .Attr("num_args", num_args) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", "SAME") + .Attr("fused_ops", fused_ops) + .Attr("_kernel", "MklNameChangeOp") + .Finalize(node_def())); + } TF_EXPECT_OK(InitOp()); @@ -567,20 +603,26 @@ class MklFusedDepthwiseConv2DOpTest : public OpsTestBase { AddInputFromArray(filter.shape(), filter.flat()); for (const Tensor& arg : args) AddInputFromArray(arg.shape(), arg.flat()); - AddInputFromArray(dummy_shape, dummy_tensor); - AddInputFromArray(dummy_shape, dummy_tensor); - for (const Tensor& arg : args) + if (!NativeFormatEnabled()) { AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + for (const Tensor& arg : args) + AddInputFromArray(dummy_shape, dummy_tensor); + } TF_ASSERT_OK(RunOpKernel()); // Compare output to expected results const Tensor& output_tensor = *GetOutput(0); - // Index 2 will need to be changed if the number of outputs produced - // by MklDepthwiseConv2D change. - const Tensor& output_meta_tensor = *GetOutput(2); CommonTestUtilities test_util; - test_util.PerformConversion(dtype, output_tensor, output_meta_tensor, - output); + if (!NativeFormatEnabled()) { + // Index 2 will need to be changed if the number of outputs produced + // by MklDepthwiseConv2D change. 
+ const Tensor& output_meta_tensor = *GetOutput(2); + test_util.PerformConversion(dtype, output_tensor, output_meta_tensor, + output); + } else { + *output = output_tensor; + } } // Verifies computing unfused ops in a graph is identical to @@ -750,35 +792,55 @@ class FusedPadConvOpTest : public OpsTestBase { 59, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); // Create a fused pad+conv2d node - TF_EXPECT_OK(NodeDefBuilder("fused_pad_conv_op", "_MklPadWithConv2D") - .Input(FakeInput(dtype)) // Input - .Input(FakeInput(dtype)) // Filter - .Input(FakeInput(DT_INT32)) // Padding - .Input(FakeInput(DT_UINT8)) // MKl second tensor - .Input(FakeInput(DT_UINT8)) // MKl second tensor - .Input(FakeInput(DT_UINT8)) // MKl second tensor - .Attr("padding", "VALID") - .Attr("data_format", data_format) - .Attr("T", dtype) - .Attr("strides", {1, stride, stride, 1}) - .Attr("_kernel", "MklLayoutDependentOp") - .Finalize(node_def())); + if (!NativeFormatEnabled()) { + TF_EXPECT_OK(NodeDefBuilder("fused_pad_conv_op", "_MklPadWithConv2D") + .Input(FakeInput(dtype)) // Input + .Input(FakeInput(dtype)) // Filter + .Input(FakeInput(DT_INT32)) // Padding + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Attr("padding", "VALID") + .Attr("data_format", data_format) + .Attr("T", dtype) + .Attr("strides", {1, stride, stride, 1}) + .Attr("_kernel", "MklLayoutDependentOp") + .Finalize(node_def())); + } else { + TF_EXPECT_OK( + NodeDefBuilder("fused_pad_conv_op", "_MklNativePadWithConv2D") + .Input(FakeInput(dtype)) // Input + .Input(FakeInput(dtype)) // Filter + .Input(FakeInput(DT_INT32)) // Padding + .Attr("padding", "VALID") + .Attr("data_format", data_format) + .Attr("T", dtype) + .Attr("strides", {1, stride, stride, 1}) + .Attr("_kernel", "MklNameChangeOp") + .Finalize(node_def())); + } TF_EXPECT_OK(InitOp()); // Setting up inputs and execute AddInputFromArray(image.shape(), image.flat()); AddInputFromArray(filter.shape(), filter.flat()); AddInputFromArray(padding.shape(), padding.flat()); - AddInputFromArray(dummy_shape, dummy_tensor); - AddInputFromArray(dummy_shape, dummy_tensor); - AddInputFromArray(dummy_shape, dummy_tensor); + if (!NativeFormatEnabled()) { + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + } TF_ASSERT_OK(RunOpKernel()); // Compare output to expected results const Tensor& first = *GetOutput(0); - const Tensor& second = *GetOutput(2); CommonTestUtilities test_util; - test_util.ConvertAndCompareIntegral(dtype, first, second, expected); + if (!NativeFormatEnabled()) { + const Tensor& second = *GetOutput(2); + test_util.ConvertAndCompareIntegral(dtype, first, second, expected); + } else { + test::ExpectTensorEqual(expected, first); + } } }; @@ -805,33 +867,52 @@ class FilterCacheTest : public OpsTestBase { const bool is_filter_const) { const int stride = 1; - TF_EXPECT_OK(NodeDefBuilder("conv2d_filter_cache", "_MklConv2D") - .Input(FakeInput(dtype)) // Input - .Input(FakeInput(dtype)) // Filter - .Input(FakeInput(DT_UINT8)) // MKl second tensor - .Input(FakeInput(DT_UINT8)) // MKl second tensor - .Attr("padding", "VALID") - .Attr("data_format", "NHWC") - .Attr("is_filter_const", is_filter_const) - .Attr("T", dtype) - .Attr("strides", {1, stride, stride, 1}) - .Attr("_kernel", "MklLayoutDependentOp") - .Finalize(node_def())); + if (!NativeFormatEnabled()) { + TF_EXPECT_OK(NodeDefBuilder("conv2d_filter_cache", 
"_MklConv2D") + .Input(FakeInput(dtype)) // Input + .Input(FakeInput(dtype)) // Filter + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Input(FakeInput(DT_UINT8)) // MKL second tensor + .Attr("padding", "VALID") + .Attr("data_format", "NHWC") + .Attr("is_filter_const", is_filter_const) + .Attr("T", dtype) + .Attr("strides", {1, stride, stride, 1}) + .Attr("_kernel", "MklLayoutDependentOp") + .Finalize(node_def())); + } else { + TF_EXPECT_OK(NodeDefBuilder("conv2d_filter_cache", "_MklNativeConv2D") + .Input(FakeInput(dtype)) // Input + .Input(FakeInput(dtype)) // Filter + .Attr("padding", "VALID") + .Attr("data_format", "NHWC") + .Attr("is_filter_const", is_filter_const) + .Attr("T", dtype) + .Attr("strides", {1, stride, stride, 1}) + .Attr("_kernel", "MklNameChangeOp") + .Finalize(node_def())); + } TF_EXPECT_OK(InitOp()); // Setting up inputs and execute AddInputFromArray(image.shape(), image.flat()); AddInputFromArray(filter.shape(), filter.flat()); - AddInputFromArray(dummy_shape, dummy_tensor); - AddInputFromArray(dummy_shape, dummy_tensor); + if (!NativeFormatEnabled()) { + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + } TF_ASSERT_OK(RunOpKernel()); // Compare outputs to expected results const Tensor& output = *GetOutput(0); - const Tensor& output_layout = *GetOutput(2); CommonTestUtilities conv_comp; - conv_comp.ConvertAndCompare(dtype, output, output_layout, expected); + if (!NativeFormatEnabled()) { + const Tensor& output_layout = *GetOutput(2); + conv_comp.ConvertAndCompare(dtype, output, output_layout, expected); + } else { + test::ExpectTensorEqual(expected, output); + } // TODO(bhavanis): For now, we rely on internal performance tests to // determine if filter data is being cached and reused. 
@@ -841,10 +922,14 @@ class FilterCacheTest : public OpsTestBase { // Compare output to expected results const Tensor& output_new = *GetOutput(0); - const Tensor& output_layout_new = *GetOutput(2); CommonTestUtilities conv_comp_new; - conv_comp_new.ConvertAndCompare(dtype, output_new, output_layout_new, - expected); + if (!NativeFormatEnabled()) { + const Tensor& output_layout_new = *GetOutput(2); + conv_comp_new.ConvertAndCompare(dtype, output_new, output_layout_new, + expected); + } else { + test::ExpectTensorEqual(expected, output_new); + } } }; @@ -926,39 +1011,61 @@ class MklFusedMatMulOpTest : public OpsTestBase { DataType dtype = DataTypeToEnum::v(); const int num_args = 1; - TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul") - .Input(FakeInput(dtype)) - .Input(FakeInput(dtype)) - .Input(FakeInput(num_args, dtype)) - .Input(FakeInput(DT_UINT8)) - .Input(FakeInput(DT_UINT8)) - .Input(FakeInput(num_args, DT_UINT8)) - .Attr("T", dtype) - .Attr("transpose_a", false) - .Attr("transpose_b", false) - .Attr("num_args", num_args) - .Attr("fused_ops", fused_ops) - .Attr("epsilon", 0.0001) - .Attr("_kernel", "MklLayoutDependentOp") - .Finalize(node_def())); + if (!NativeFormatEnabled()) { + TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul") + .Input(FakeInput(dtype)) + .Input(FakeInput(dtype)) + .Input(FakeInput(num_args, dtype)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(num_args, DT_UINT8)) + .Attr("T", dtype) + .Attr("transpose_a", false) + .Attr("transpose_b", false) + .Attr("num_args", num_args) + .Attr("fused_ops", fused_ops) + .Attr("epsilon", 0.0001) + .Attr("_kernel", "MklLayoutDependentOp") + .Finalize(node_def())); + } else { + TF_EXPECT_OK( + NodeDefBuilder("MklFusedMatMul", "_MklNativeFusedMatMul") + .Input(FakeInput(dtype)) + .Input(FakeInput(dtype)) + .Input(FakeInput(num_args, dtype)) + .Attr("T", dtype) + .Attr("transpose_a", false) + .Attr("transpose_b", false) + .Attr("num_args", num_args) + .Attr("fused_ops", fused_ops) + .Attr("epsilon", 0.0001) + .Attr("_kernel", "MklNameChangeOp") + .Finalize(node_def())); + } TF_EXPECT_OK(InitOp()); AddInputFromArray(input.shape(), input.flat()); AddInputFromArray(weight.shape(), weight.flat()); AddInputFromArray(bias.shape(), bias.flat()); - // Add MKL meta input for input, filter and bias. - AddInputFromArray(dummy_shape, dummy_tensor); - AddInputFromArray(dummy_shape, dummy_tensor); - AddInputFromArray(dummy_shape, dummy_tensor); + if (!NativeFormatEnabled()) { + // Add MKL meta input for input, filter and bias. 
+ AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + } TF_ASSERT_OK(RunOpKernel()); const Tensor& output_tensor = *GetOutput(0); - const Tensor& output_meta_tensor = *GetOutput(1); - CommonTestUtilities test_util; - test_util.PerformConversion(dtype, output_tensor, output_meta_tensor, - output); + if (!NativeFormatEnabled()) { + const Tensor& output_meta_tensor = *GetOutput(1); + CommonTestUtilities test_util; + test_util.PerformConversion(dtype, output_tensor, + output_meta_tensor, output); + } else { + *output = output_tensor; + } }; CommonTestUtilities::VerifyFusedMatrixClose(kInputChannel, kBatch, @@ -1033,21 +1140,36 @@ TEST_F(MklFusedMatMulCacheTest, WeightCached) { const int num_args = 1; const std::vector& fused_ops = {"BiasAdd"}; - TF_ASSERT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul") - .Input(FakeInput(DT_FLOAT)) - .Input(FakeInput(DT_FLOAT)) - .Input(FakeInput(num_args, DT_FLOAT)) - .Input(FakeInput(DT_UINT8)) - .Input(FakeInput(DT_UINT8)) - .Input(FakeInput(num_args, DT_UINT8)) - .Attr("T", DT_FLOAT) - .Attr("transpose_a", false) - .Attr("transpose_b", false) - .Attr("num_args", num_args) - .Attr("fused_ops", fused_ops) - .Attr("epsilon", 0.0001) - .Attr("_kernel", "MklLayoutDependentOp") - .Finalize(node_def())); + if (!NativeFormatEnabled()) { + TF_ASSERT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(num_args, DT_FLOAT)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(num_args, DT_UINT8)) + .Attr("T", DT_FLOAT) + .Attr("transpose_a", false) + .Attr("transpose_b", false) + .Attr("num_args", num_args) + .Attr("fused_ops", fused_ops) + .Attr("epsilon", 0.0001) + .Attr("_kernel", "MklLayoutDependentOp") + .Finalize(node_def())); + } else { + TF_ASSERT_OK(NodeDefBuilder("MklFusedMatMul", "_MklNativeFusedMatMul") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(num_args, DT_FLOAT)) + .Attr("T", DT_FLOAT) + .Attr("transpose_a", false) + .Attr("transpose_b", false) + .Attr("num_args", num_args) + .Attr("fused_ops", fused_ops) + .Attr("epsilon", 0.0001) + .Attr("_kernel", "MklNameChangeOp") + .Finalize(node_def())); + } TF_EXPECT_OK(InitOp()); // The tensor shape of (1,3) is selected to allow the mkldnn expected @@ -1063,10 +1185,12 @@ TEST_F(MklFusedMatMulCacheTest, WeightCached) { {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}); // Bias vector. AddInputFromArray(TensorShape({4}), {1, 2, 3, 4}); - // Add MKL meta input for input, filter and bias. - AddInputFromArray(dummy_shape, dummy_tensor); - AddInputFromArray(dummy_shape, dummy_tensor); - AddInputFromArray(dummy_shape, dummy_tensor); + if (!NativeFormatEnabled()) { + // Add MKL meta input for input, filter and bias. 
+ AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + } int64 start_time = Env::Default()->NowMicros(); TF_ASSERT_OK(RunOpKernel()); @@ -1079,9 +1203,13 @@ TEST_F(MklFusedMatMulCacheTest, WeightCached) { test::FillValues(&expected, {75, 82, 89, 96}); const Tensor& output = *GetOutput(0); - const Tensor& mkl_shape_tensor = *GetOutput(1); CommonTestUtilities test_util; - test_util.ConvertAndCompare(DT_FLOAT, output, mkl_shape_tensor, expected); + if (!NativeFormatEnabled()) { + const Tensor& mkl_shape_tensor = *GetOutput(1); + test_util.ConvertAndCompare(DT_FLOAT, output, mkl_shape_tensor, expected); + } else { + test::ExpectTensorNear(expected, output, 1e-5); + } // Test for the second time to use the cached weight start_time = Env::Default()->NowMicros(); @@ -1097,9 +1225,13 @@ TEST_F(MklFusedMatMulCacheTest, WeightCached) { // Compare the result with expected result CommonTestUtilities test_util_new; const Tensor& output_new = *GetOutput(0); - const Tensor& mkl_shape_tensor_new = *GetOutput(1); - test_util_new.ConvertAndCompare(DT_FLOAT, output_new, mkl_shape_tensor_new, - expected); + if (!NativeFormatEnabled()) { + const Tensor& mkl_shape_tensor_new = *GetOutput(1); + test_util_new.ConvertAndCompare(DT_FLOAT, output_new, mkl_shape_tensor_new, + expected); + } else { + test::ExpectTensorNear(expected, output_new, 1e-5); + } } class BiasCacheTest : public OpsTestBase { @@ -1364,22 +1496,38 @@ class MklPadWithFusedConv2DOpTest : public OpsTestBase { &padding, {0, 0, padding_list_[0], padding_list_[1], padding_list_[2], padding_list_[3], 0, 0}); - TF_EXPECT_OK(NodeDefBuilder("pad_fused_conv_op", "_MklPadWithFusedConv2D") - .Input(FakeInput(dtype)) - .Input(FakeInput(dtype)) - .Input(FakeInput(num_args, dtype)) - .Input(FakeInput(DT_INT32)) - .Input(FakeInput(DT_UINT8)) - .Input(FakeInput(DT_UINT8)) - .Input(FakeInput(num_args, DT_UINT8)) - .Input(FakeInput(DT_UINT8)) - .Attr("T", dtype) - .Attr("num_args", num_args) - .Attr("strides", {1, stride, stride, 1}) - .Attr("padding", "VALID") - .Attr("fused_ops", fused_ops) - .Attr("_kernel", "MklLayoutDependentOp") - .Finalize(node_def())); + if (!NativeFormatEnabled()) { + TF_EXPECT_OK(NodeDefBuilder("pad_fused_conv_op", "_MklPadWithFusedConv2D") + .Input(FakeInput(dtype)) + .Input(FakeInput(dtype)) + .Input(FakeInput(num_args, dtype)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(num_args, DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Attr("T", dtype) + .Attr("num_args", num_args) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", "VALID") + .Attr("fused_ops", fused_ops) + .Attr("_kernel", "MklLayoutDependentOp") + .Finalize(node_def())); + } else { + TF_EXPECT_OK( + NodeDefBuilder("pad_fused_conv_op", "_MklNativePadWithFusedConv2D") + .Input(FakeInput(dtype)) + .Input(FakeInput(dtype)) + .Input(FakeInput(num_args, dtype)) + .Input(FakeInput(DT_INT32)) + .Attr("T", dtype) + .Attr("num_args", num_args) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", "VALID") + .Attr("fused_ops", fused_ops) + .Attr("_kernel", "MklNameChangeOp") + .Finalize(node_def())); + } TF_EXPECT_OK(InitOp()); @@ -1388,19 +1536,25 @@ class MklPadWithFusedConv2DOpTest : public OpsTestBase { for (const Tensor& arg : args) AddInputFromArray(arg.shape(), arg.flat()); AddInputFromArray(padding.shape(), padding.flat()); - // Add MKL meta input for input, filter, pad and agrs. 
- for (int i = 0; i < args.size() + 3; ++i) - AddInputFromArray(dummy_shape, dummy_tensor); + if (!NativeFormatEnabled()) { + // Add MKL meta input for input, filter, pad and agrs. + for (int i = 0; i < args.size() + 3; ++i) + AddInputFromArray(dummy_shape, dummy_tensor); + } TF_ASSERT_OK(RunOpKernel()); // Compare output to expected results const Tensor& output_tensor = *GetOutput(0); - // Index 2 will need to be changed if the number of outputs produced - // by MklConv2D change. - const Tensor& output_meta_tensor = *GetOutput(2); CommonTestUtilities test_util; - test_util.PerformConversion(dtype, output_tensor, output_meta_tensor, - output); + if (!NativeFormatEnabled()) { + // Index 2 will need to be changed if the number of outputs produced + // by MklConv2D change. + const Tensor& output_meta_tensor = *GetOutput(2); + test_util.PerformConversion(dtype, output_tensor, output_meta_tensor, + output); + } else { + *output = output_tensor; + } } public: diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc index 1d087851ce3..905abbfeef2 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc +++ b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc @@ -27,7 +27,7 @@ limitations under the License. namespace tensorflow { // Fuse Operation -template +template class MklFusedMatMulOp : public MklDnnMatMulOpBase { public: explicit MklFusedMatMulOp(OpKernelConstruction* ctx) @@ -58,8 +58,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { MklDnnShape src_mkl_shape; MklDnnShape weight_mkl_shape; - GetMklShape(ctx, this->kInputIndexSrc, &src_mkl_shape); - GetMklShape(ctx, this->kInputIndexWeight, &weight_mkl_shape); + GetMklShape(ctx, this->kInputIndexSrc, &src_mkl_shape, native_format); + GetMklShape(ctx, this->kInputIndexWeight, &weight_mkl_shape, native_format); OP_REQUIRES(ctx, !weight_mkl_shape.IsMklTensor(), errors::InvalidArgument("Weight should not be in MKL Layout")); @@ -134,7 +134,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { MklDnnShape dst_mkl_shape; dst_mkl_shape.SetMklTensor(false); AllocateOutputSetMklShape(ctx, 0, &dst_tensor, dst_tensor_shape, - dst_mkl_shape); + dst_mkl_shape, native_format); } // if there's nothing to compute, just return. @@ -243,13 +243,18 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { }; // Register mkl kernels for supported operations and types. 
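Note: on the kernel side the same switch is a compile-time template flag. A condensed, hypothetical sketch of how it is threaded through the helpers used above (the real kernels have more template parameters, elided in this diff):

    // Illustrative only: native_format makes the shape helpers treat inputs and
    // outputs as plain TF tensors rather than MKL-layout tensor + metadata pairs.
    template <typename Device, typename T, bool native_format = false>
    class FooOp : public OpKernel {
     public:
      explicit FooOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
      void Compute(OpKernelContext* ctx) override {
        MklDnnShape src_shape;
        GetMklShape(ctx, 0, &src_shape, native_format);
        // Layout reorders are only relevant on the blocked (non-native) path,
        // hence the `!native_format &&` guards added throughout this patch.
        TensorShape tf_out_shape = ctx->input(0).shape();
        MklDnnShape out_mkl_shape;
        out_mkl_shape.SetMklTensor(false);
        Tensor* out = nullptr;
        // Skips allocating the metadata tensor when native_format is true.
        AllocateOutputSetMklShape(ctx, 0, &out, tf_out_shape, out_mkl_shape,
                                  native_format);
      }
    };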
-#define REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("_MklFusedMatMul") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklFusedMatMulOp); +#define REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklFusedMatMul") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ + MklFusedMatMulOp); \ + REGISTER_KERNEL_BUILDER(Name("_MklNativeFusedMatMul") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklFusedMatMulOp); TF_CALL_float(REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES); TF_CALL_bfloat16(REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES); diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h index e084b25f737..b77d033c9de 100644 --- a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h +++ b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h @@ -35,12 +35,6 @@ using mkldnn::stream; namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -#ifdef INTEL_MKL_DNN_ONLY -// Temporarily copying some definitions from mkl_cblas.h so the same code can -// be used when calling oneDNN or CBLAS batchmatmul in mkl_batch_matmul_op.cc. -typedef enum { CblasRowMajor, CblasColumnMajor } CBLAS_LAYOUT; -#define MKL_INT int -#endif // This structure aggregates multiple inputs to MklDnnMatMul* methods. struct MklDnnMatMulFwdParams { @@ -729,67 +723,6 @@ class MklMatMulPrimitiveFactory : public MklPrimitiveFactory { } }; -template -void dnnl_gemm_batch(const std::vector& transa, - const std::vector& transb, const std::vector& m, - const std::vector& n, const std::vector& k, - const std::vector& alpha, const T* a, const T* b, - const std::vector& beta, T* c, - const int group_count, const std::vector& group_size, - OpKernelContext* ctx = nullptr) { - // Current BatchMatMul support in Tensorflow is narrower than the one offered - // by MKL and MKL-DNN. Current BatchMatMul support in Tensorflow uses only 1 - // group of size equal to batch_size, and all MatMul parameters (m, n, k, - // alpha, beta) within that group are same. - DCHECK(group_size.size() == 1); - DCHECK(transa.size() == group_size[0]); - DCHECK(transb.size() == group_size[0]); - DCHECK(alpha.size() == group_size[0]); - DCHECK(beta.size() == group_size[0]); - DCHECK(m.size() == group_size[0]); - DCHECK(n.size() == group_size[0]); - DCHECK(k.size() == group_size[0]); - for (int64_t idx = 0; idx < group_size[0]; idx++) - DCHECK(transa[0] == transa[idx]); - for (int64_t idx = 0; idx < group_size[0]; idx++) - DCHECK(transb[0] == transb[idx]); - for (int64_t idx = 0; idx < group_size[0]; idx++) - DCHECK(alpha[0] == alpha[idx]); - for (int64_t idx = 0; idx < group_size[0]; idx++) - DCHECK(beta[0] == beta[idx]); - for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(m[0] == m[idx]); - for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(n[0] == n[idx]); - for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(k[0] == k[idx]); - - using dims = mkldnn::memory::dims; - // Prepare strides based on the transa and transb flags: transposed - // matrices have strides swapped BatchMatMul in MKL-DNN supports 3D metrices - // so far. That is why strides are 3D also. 
- dims a_sizes = dims{group_size[0], m[0], k[0]}; - dims b_sizes = dims{group_size[0], k[0], n[0]}; - dims c_sizes = dims{group_size[0], m[0], n[0]}; - dims a_strides = - !transa[0] ? dims{m[0] * k[0], k[0], 1} : dims{k[0] * m[0], 1, m[0]}; - dims b_strides = - !transb[0] ? dims{k[0] * n[0], n[0], 1} : dims{n[0] * k[0], 1, k[0]}; - dims c_strides = dims{m[0] * n[0], n[0], 1}; - - // MklMatMul uses const alpha and beta, make guarantee here to ensure - // they are never changed. - DCHECK_EQ(alpha, 1.0f); - DCHECK_EQ(beta, 0.f); - - MklMatMulParams params(a_sizes, b_sizes, c_sizes, a_strides, b_strides, - c_strides); - MklMatMulPrimitive* matmul_prim = - MklMatMulPrimitiveFactory::Get(params, 0); - - // Execute matmul primitive. - std::shared_ptr cpu_stream; - cpu_stream.reset(CreateStream(ctx, matmul_prim->GetEngine())); - matmul_prim->Execute(a, b, c, cpu_stream); -} - template void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k, float alpha, const T* a, int64_t lda, const T* b, int64_t ldb, diff --git a/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc index ca7ebd7fd12..441d5f5099f 100644 --- a/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl/mkl_maxpooling_op.cc @@ -44,7 +44,7 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; // An implementation of MaxPooling (forward). -template +template class MklMaxPoolingOp : public MklPoolingForwardOpBase { public: explicit MklMaxPoolingOp(OpKernelConstruction* context) @@ -52,6 +52,7 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { // In Max Pooling, MKL-DNN does not allow passing workspace as nullptr. // So we set workspace_enabled_ to true. this->workspace_enabled_ = true; + this->native_format_ = native_format; } void Compute(OpKernelContext* context) override { @@ -59,7 +60,8 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { const Tensor& input_tensor = MklGetInput(context, this->kInputTensorIndexInput); MklDnnShape dnn_shape_input; - GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input); + GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input, + this->native_format_); this->SanityCheckInput(context, input_tensor, dnn_shape_input); if (!context->status().ok()) return; @@ -143,7 +145,8 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { MklPoolingParams fwdParams( src_dims, output_dims_mkl_order, filter_dims, strides, padding_left, padding_right, ALGORITHM::pooling_max, pooling_prop_kind, - static_cast(this->data_format_mkldnn_), input_md); + static_cast(this->data_format_mkldnn_), input_md, + this->native_format_); #else MklPoolingParams fwdParams( src_dims, output_dims_mkl_order, filter_dims, strides, padding_left, @@ -229,7 +232,7 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { workspace_tf_shape.AddDim(workspace_bytes); AllocateOutputSetMklShape(context, kOutputTensorIndexWorkspace, &workspace_tensor, workspace_tf_shape, - workspace_mkl_shape); + workspace_mkl_shape, this->native_format_); DCHECK(workspace_tensor); dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor); } @@ -241,11 +244,13 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { // - The original output tensor // - Backprop tensor for output // It produces one output: backprop tensor for input. 
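Note: for the pooling ops the flag also has to reach the oneDNN primitive, so MklPoolingParams grows a native_format field (see mkl_pooling_ops_common.h/.cc further down); when it is set, the destination memory descriptor keeps the source's TF format instead of letting oneDNN pick an opaque blocked layout. A rough sketch of that choice, with `params` standing in for fwdParams/bwdParams:

    // Sketch of the descriptor selection in MklPoolingFwd/BwdPrimitive::Setup.
    auto dst_fmt = params.native_format ? params.src_format : MEMORY_FORMAT::any;
    context_.dst_md.reset(
        new memory::desc({params.dst_dims}, MklDnnType<T>(), dst_fmt));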
-template +template class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase { public: explicit MklMaxPoolingGradOp(OpKernelConstruction* context) - : MklPoolingBackwardOpBase(context) {} + : MklPoolingBackwardOpBase(context) { + this->native_format_ = native_format; + } void Compute(OpKernelContext* context) override { try { const Tensor& orig_input_tensor = @@ -255,8 +260,10 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase { const Tensor& workspace_tensor = MklGetInput(context, kInputTensorIndexWorkspace); MklDnnShape orig_input_mkl_shape, grad_mkl_shape; - GetMklShape(context, kInputTensorIndexOrigInput, &orig_input_mkl_shape); - GetMklShape(context, kInputTensorIndexGradient, &grad_mkl_shape); + GetMklShape(context, kInputTensorIndexOrigInput, &orig_input_mkl_shape, + this->native_format_); + GetMklShape(context, kInputTensorIndexGradient, &grad_mkl_shape, + this->native_format_); if (!context->status().ok()) return; MklDnnData grad_dnn_data(&cpu_engine_); @@ -312,7 +319,8 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase { orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, strides, padding_left, padding_right, ALGORITHM::pooling_max, prop_kind::forward_training, - static_cast(this->data_format_mkldnn_), src_md); + static_cast(this->data_format_mkldnn_), src_md, + this->native_format_); #else MklPoolingParams bwdParams( orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, @@ -335,7 +343,8 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase { std::shared_ptr pooling_bwd_pd = pooling_bwd->GetPoolingBwdPd(); T* diff_dst_data = nullptr; - if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, pooling_bwd_pd, + if (!this->native_format_ && + IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, pooling_bwd_pd, pooling_bwd)) { grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor); grad_dnn_data.CheckReorderToOpMem( @@ -389,36 +398,56 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase { engine cpu_engine_ = engine(ENGINE_CPU, 0); }; // MklMaxPoolingGradOp -#define REGISTER_MKL_MAXPOOL3D_KERNELS(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("_MklMaxPool3D") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklMaxPoolingOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("_MklMaxPool3DGrad") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklMaxPoolingGradOp); +#define REGISTER_MKL_MAXPOOL3D_KERNELS(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklMaxPool3D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ + MklMaxPoolingOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklMaxPool3DGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ + MklMaxPoolingGradOp); \ + REGISTER_KERNEL_BUILDER(Name("_MklNativeMaxPool3D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklMaxPoolingOp); \ + REGISTER_KERNEL_BUILDER(Name("_MklNativeMaxPool3DGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklMaxPoolingGradOp); TF_CALL_float(REGISTER_MKL_MAXPOOL3D_KERNELS); TF_CALL_bfloat16(REGISTER_MKL_MAXPOOL3D_KERNELS); -#define REGISTER_MKL_MAXPOOL_KERNELS(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("_MklMaxPool") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklMaxPoolingOp); \ 
- REGISTER_KERNEL_BUILDER( \ - Name("_MklMaxPoolGrad") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklMaxPoolingGradOp); +#define REGISTER_MKL_MAXPOOL_KERNELS(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklMaxPool") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ + MklMaxPoolingOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklMaxPoolGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ + MklMaxPoolingGradOp); \ + REGISTER_KERNEL_BUILDER(Name("_MklNativeMaxPool") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklMaxPoolingOp); \ + REGISTER_KERNEL_BUILDER(Name("_MklNativeMaxPoolGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklNameChangeOpLabel), \ + MklMaxPoolingGradOp); TF_CALL_float(REGISTER_MKL_MAXPOOL_KERNELS); TF_CALL_bfloat16(REGISTER_MKL_MAXPOOL_KERNELS); diff --git a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc index 9824fabce0e..2b18404d1cf 100644 --- a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc +++ b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.cc @@ -54,8 +54,9 @@ void MklPoolingFwdPrimitive::Setup(const MklPoolingParams& fwdParams) { #else context_.src_md.reset(new memory::desc(fwdParams.src_md.data)); #endif // !ENABLE_MKLDNN_V1 - context_.dst_md.reset(new memory::desc({fwdParams.dst_dims}, MklDnnType(), - MEMORY_FORMAT::any)); + context_.dst_md.reset(new memory::desc( + {fwdParams.dst_dims}, MklDnnType(), + fwdParams.native_format ? fwdParams.src_format : MEMORY_FORMAT::any)); #ifndef ENABLE_MKLDNN_V1 // Create a pooling descriptor. @@ -187,8 +188,9 @@ void MklPoolingBwdPrimitive::Setup(const MklPoolingParams& bwdParams) { {bwdParams.dst_dims}, MklDnnType(), bwdParams.src_format)); #else context_.src_md.reset(new memory::desc(bwdParams.src_md.data)); - context_.dst_md.reset(new memory::desc({bwdParams.dst_dims}, MklDnnType(), - MEMORY_FORMAT::any)); + context_.dst_md.reset(new memory::desc( + {bwdParams.dst_dims}, MklDnnType(), + bwdParams.native_format ? 
bwdParams.src_format : MEMORY_FORMAT::any)); #endif // !ENABLE_MKLDNN_V1 #ifndef ENABLE_MKLDNN_V1 diff --git a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h index 3a608a66c16..7034c4a224b 100644 --- a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h +++ b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h @@ -49,12 +49,14 @@ struct MklPoolingParams { mkldnn::prop_kind prop_kind; MEMORY_FORMAT src_format; memory::desc src_md; + bool native_format; MklPoolingParams(memory::dims src_dims, memory::dims dst_dims, memory::dims filter_dims, memory::dims strides, memory::dims padding_left, memory::dims padding_right, mkldnn::algorithm alg_kind, mkldnn::prop_kind prop_kind, - MEMORY_FORMAT src_format, memory::desc src_md) + MEMORY_FORMAT src_format, memory::desc src_md, + bool native_format) : src_dims(src_dims), dst_dims(dst_dims), filter_dims(filter_dims), @@ -64,7 +66,8 @@ struct MklPoolingParams { alg_kind(alg_kind), prop_kind(prop_kind), src_format(src_format), - src_md(src_md) {} + src_md(src_md), + native_format(native_format) {} }; template @@ -583,7 +586,8 @@ class MklPoolingOpBase : public OpKernel { output_tf_shape = MklDnnDimsToTFShape(output_dims_order); } AllocateOutputSetMklShape(context, kOutputIndex, output_tensor, - output_tf_shape, output_mkl_shape); + output_tf_shape, output_mkl_shape, + native_format_); DCHECK(output_tensor); } @@ -608,6 +612,7 @@ class MklPoolingOpBase : public OpKernel { // Either memory::format (MKL-DNN v-0.x) or memory::format_tag (MKL-DNN v-1.x) MEMORY_FORMAT data_format_mkldnn_; bool workspace_enabled_; + bool native_format_ = false; }; template @@ -671,8 +676,13 @@ class MklPoolingForwardOpBase : public MklPoolingOpBase { output_dims_mkl_order, output_tf_format); // Only allocate enough space for the elements we need. 
output_tf_shape.AddDim(this->GetNumTElements(dst_pd)); + + if (this->native_format_) { + output_tf_shape = output_mkl_shape.GetTfShape(); + } AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor, - output_tf_shape, output_mkl_shape); + output_tf_shape, output_mkl_shape, + this->native_format_); DCHECK(*output_tensor); } @@ -719,8 +729,12 @@ class MklPoolingBackwardOpBase : public MklPoolingOpBase { TensorShape output_tf_shape; output_tf_shape.AddDim(this->GetNumTElements(dst_pd)); + if (this->native_format_) { + output_tf_shape = output_mkl_shape.GetTfShape(); + } AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor, - output_tf_shape, output_mkl_shape); + output_tf_shape, output_mkl_shape, + this->native_format_); DCHECK(*output_tensor); } }; diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index 1c9956c5ffb..1d18e41d3ef 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -4,6 +4,7 @@ load( "//tensorflow/core/kernels/mlir_generated:build_defs.bzl", "gen_kernel_library", "if_mlir_generated_gpu_kernels_enabled", + "if_mlir_unranked_kernels_enabled", ) load( "//tensorflow:tensorflow.bzl", @@ -18,7 +19,6 @@ load( package( default_visibility = [ - "//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__", "//tensorflow/core/kernels:__subpackages__", ], licenses = ["notice"], # Apache 2.0 @@ -31,29 +31,60 @@ config_setting( }, ) +config_setting( + name = "mlir_use_unranked_kernels", + define_values = {"enable_unranked_kernels": "1"}, +) + +filegroup( + name = "kernel_srcs", + srcs = if_mlir_unranked_kernels_enabled( + [ + "unranked_op_gpu_abs.cc", + "unranked_op_gpu_tanh.cc", + "unranked_op_gpu_base.h", + "unranked_op_gpu_base.cc", + ], + [ + "cwise_op_gpu_abs.cc", + "cwise_op_gpu_base.cc", + "cwise_op_gpu_base.h", + "cwise_op_gpu_tanh.cc", + ], + ), +) + +cc_library( + name = "kernel_deps", + deps = if_mlir_unranked_kernels_enabled( + [ + ":abs_unranked_kernels", + ":tanh_unranked_kernels", + "//tensorflow/compiler/mlir/tools/kernel_gen:tf_cuda_runtime_wrappers", + "//tensorflow/compiler/mlir/tools/kernel_gen:tf_framework_c_interface", + ], + [ + ":abs_kernels", + ":tanh_kernels", + ], + ), +) + tf_kernel_library( name = "cwise_unary_op", # Technically these source files don't need --config=cuda or --config=rocm, # but we want to avoid building them if they are not needed. - srcs = if_cuda_or_rocm([ - "cwise_op_gpu_abs.cc", - "cwise_op_gpu_base.cc", - "cwise_op_gpu_base.h", - "cwise_op_gpu_tanh.cc", - ]), - # Building abs_kernels and tanh_kernels requires hexdump on the ROCM - # platform, which might not be available. 
+ srcs = if_cuda_or_rocm([":kernel_srcs"]), tags = ["no_rocm"], deps = if_cuda_or_rocm([ - ":abs_kernels", - ":tanh_kernels", + ":kernel_deps", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "//third_party/eigen3", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", ]), ) @@ -69,8 +100,6 @@ tf_cuda_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", - "//tensorflow/core/common_runtime:device", - "//tensorflow/core/common_runtime:device_factory", "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:ops_testutil", ], @@ -88,8 +117,6 @@ tf_cuda_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", - "//tensorflow/core/common_runtime:device", - "//tensorflow/core/common_runtime:device_factory", "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:ops_testutil", ], @@ -119,20 +146,9 @@ tf_cuda_cc_test( # ], # ) -gen_kernel_library( - name = "tanh", - same_shape = "0,1", - tile_size = "256", - types = [ - "f16", - "f32", - "f64", - ], - unroll_factors = "4", -) - gen_kernel_library( name = "abs", + generate_unranked = True, same_shape = "0,1", tile_size = "256", types = [ @@ -147,6 +163,7 @@ gen_kernel_library( gen_kernel_library( name = "ceil", + generate_unranked = True, same_shape = "0,1", tile_size = "256", types = [ @@ -157,8 +174,20 @@ gen_kernel_library( unroll_factors = "4", ) +gen_kernel_library( + name = "conj", + same_shape = "0,1", + tile_size = "256", + types = [ + "c64", + "c128", + ], + unroll_factors = "4", +) + gen_kernel_library( name = "cos", + generate_unranked = True, same_shape = "0,1", tile_size = "256", types = [ @@ -171,6 +200,7 @@ gen_kernel_library( gen_kernel_library( name = "exp", + generate_unranked = True, same_shape = "0,1", tile_size = "256", types = [ @@ -183,6 +213,44 @@ gen_kernel_library( gen_kernel_library( name = "floor", + generate_unranked = True, + same_shape = "0,1", + tile_size = "256", + types = [ + "f16", + "f32", + "f64", + ], + unroll_factors = "4", +) + +gen_kernel_library( + name = "imag", + tile_size = "256", + types = [ + "f32", + "f64", + ], + unroll_factors = "4", +) + +gen_kernel_library( + name = "invert", + generate_unranked = True, + same_shape = "0,1", + tile_size = "256", + types = [ + "i8", + "i16", + "i32", + "i64", + ], + unroll_factors = "4", +) + +gen_kernel_library( + name = "isfinite", + generate_unranked = True, same_shape = "0,1", tile_size = "256", types = [ @@ -195,6 +263,7 @@ gen_kernel_library( gen_kernel_library( name = "log", + generate_unranked = True, same_shape = "0,1", tile_size = "256", types = [ @@ -205,8 +274,18 @@ gen_kernel_library( unroll_factors = "4", ) +gen_kernel_library( + name = "logicalnot", + generate_unranked = True, + same_shape = "0,1", + tile_size = "256", + types = ["i1"], + unroll_factors = "4", +) + gen_kernel_library( name = "neg", + generate_unranked = True, same_shape = "0,1", tile_size = "256", types = [ @@ -217,8 +296,19 @@ gen_kernel_library( unroll_factors = "4", ) +gen_kernel_library( + name = "real", + tile_size = "256", + types = [ + "f32", + "f64", + ], + unroll_factors = "4", +) + gen_kernel_library( name = "rsqrt", + generate_unranked = True, same_shape = "0,1", tile_size = "256", types = [ @@ -229,8 +319,38 @@ gen_kernel_library( unroll_factors = "4", ) +gen_kernel_library( + name = "sign", + generate_unranked = True, + 
same_shape = "0,1", + tile_size = "256", + types = [ + # TODO(b/162577610): Add bf16, c64 and c128. + "f16", + "f32", + "f64", + "i32", + "i64", + ], + unroll_factors = "4", +) + gen_kernel_library( name = "sqrt", + generate_unranked = True, + same_shape = "0,1", + tile_size = "256", + types = [ + "f16", + "f32", + "f64", + ], + unroll_factors = "4", +) + +gen_kernel_library( + name = "tanh", + generate_unranked = True, same_shape = "0,1", tile_size = "256", types = [ diff --git a/tensorflow/core/kernels/mlir_generated/build_defs.bzl b/tensorflow/core/kernels/mlir_generated/build_defs.bzl index 5dc6fa527c0..5b4daac8820 100644 --- a/tensorflow/core/kernels/mlir_generated/build_defs.bzl +++ b/tensorflow/core/kernels/mlir_generated/build_defs.bzl @@ -6,6 +6,10 @@ load( "rocm_gpu_architectures", "rocm_is_configured", ) +load( + "//tensorflow/core/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) load( "//tensorflow/stream_executor:build_defs.bzl", "if_gpu_is_configured", @@ -43,7 +47,7 @@ def _gen_kernel_gpu_bin_impl(ctx): gpu_bins = [] for arch in ctx.attr.gpu_archs: - # TODO(b/152737872): 'compute_' should generate both SASS and PTX. + # TODO(b/170283783): 'compute_' should generate both SASS and PTX. arch = arch.replace("compute_", "sm_") filename = "%s.%s.bin" % (name, arch) gpu_bin = ctx.actions.declare_file(filename) @@ -53,8 +57,7 @@ def _gen_kernel_gpu_bin_impl(ctx): executable = ctx.executable._tool, arguments = cmd_args + [ "--tile_sizes=%s" % tile_sizes, - # For ROCM, remove the "gfx" prefix. For CUDA, remove the "sm_" prefix. - "--arch=%s" % arch[3:], + "--arch=%s" % arch, "--input=%s" % ctx.file.mlir_op.path, "--output=%s" % gpu_bin.path, ], @@ -151,12 +154,14 @@ def _gen_kernel_image_hdr_impl_rocm(ctx): outputs = [ctx.outputs.out], inputs = [fatbin], command = ( - ("echo 'static const unsigned char %s[] = {' > %s && " + - "hexdump -v -e \'/1 \"0x%%02x, \"\' %s | cat >> %s && " + + ("hex=`hexdump -v -e \'/1 \"0x%%02x, \"\' %s` && " + + "len=`echo $hex | wc -c` && " + + "echo 'static const unsigned char %s['$len' + 1] = {' > %s && " + + "echo $hex | cat >> %s && " + "echo '};' >> %s") % ( + fatbin.path, ctx.attr.symbol, ctx.outputs.out.path, - fatbin.path, ctx.outputs.out.path, ctx.outputs.out.path, ) @@ -195,13 +200,22 @@ def _gen_kernel_image_hdr(name, mlir_op, gpu_archs, tile_size, same_shape = None ) def _gen_mlir_op_impl(ctx): - ctx.actions.run_shell( + # In order to generate a ranked kernel we change *xelem_type to ?xelem_type + # and remove element type from the entry function name. 
+ convert_to_ranked = "" + if ctx.attr.unranked == False: + convert_to_ranked = "sed s/*x/?x/g | sed s/_elem_type//g |" + cmd = ctx.actions.run_shell( inputs = [ctx.file.template], outputs = [ctx.outputs.out], - command = "cat %s | sed s/elem_type/%s/g > %s" % ( - ctx.file.template.path, - ctx.attr.type, - ctx.outputs.out.path, + command = ( + ("cat %s | %s sed s/elem_type/%s/g | sed 's/c64/complex/g'" + + " | sed 's/c128/complex/g' > %s") % ( + ctx.file.template.path, + convert_to_ranked, + ctx.attr.type, + ctx.outputs.out.path, + ) ), ) @@ -212,18 +226,21 @@ _gen_mlir_op_rule = rule( "template": attr.label(mandatory = True, allow_single_file = True), "type": attr.string(mandatory = True), "out": attr.output(mandatory = True), + "unranked": attr.bool(mandatory = True), }, ) -def _gen_mlir_op(name, type): +def _gen_mlir_op(name, type, unranked): + tmpl_name = name.replace("_unranked", "") if unranked else name _gen_mlir_op_rule( name = "generate_{name}_{type}_mlir".format(name = name, type = type), - template = "op_definitions/{name}.mlir.tmpl".format(name = name), + template = "op_definitions/{name}.mlir.tmpl".format(name = tmpl_name), type = type, out = "{name}_{type}.mlir".format(name = name, type = type), + unranked = unranked, ) -def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None, unroll_factors = None, extra_args = []): +def gen_ranked_kernel_library(name, types, tile_size, tags = [], same_shape = None, unroll_factors = None, extra_args = []): """ Generate a library with kernels for a specific tensorflow op. Args: @@ -241,6 +258,7 @@ def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None, unr _gen_mlir_op( name = name, type = type, + unranked = False, ) _gen_kernel_image_hdr( name = "{name}_{type}_kernel".format(name = name, type = type), @@ -257,3 +275,115 @@ def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None, unr hdrs = if_gpu_is_configured([":{name}_{type}_kernel".format(name = name, type = type) for type in types]), tags = tags, ) + +################################################################################ +# Unranked kernels build rules. 
+################################################################################ + +def if_mlir_unranked_kernels_enabled(if_true, if_false = []): + return select({ + "//tensorflow/core/kernels/mlir_generated:mlir_use_unranked_kernels": if_true, + "//conditions:default": if_false, + }) + +def _gen_unranked_kernel_fatbin_impl(ctx): + name = ctx.attr.name + cmd_args = [] + if ctx.attr.unroll_factors: + cmd_args.append("--unroll_factors=%s" % ctx.attr.unroll_factors) + if ctx.attr.extra_args: + cmd_args.extend(ctx.attr.extra_args) + tile_sizes = ctx.attr.tile_size.replace("x", ",") + arch_flag = ",".join(ctx.attr.gpu_archs) + gpu_bin = ctx.outputs.output + ctx.actions.run( + inputs = [ctx.file.mlir_op], + outputs = [gpu_bin], + executable = ctx.executable._tool, + arguments = cmd_args + [ + "--tile_sizes=%s" % tile_sizes, + "--arch=%s" % arch_flag, + "--input=%s" % ctx.file.mlir_op.path, + "--output=%s" % gpu_bin.path, + ], + mnemonic = "compile", + ) + +_gen_unranked_kernel_fatbin_rule = rule( + attrs = { + "mlir_op": attr.label(mandatory = True, allow_single_file = True), + "output": attr.output(mandatory = True, doc = "The generated file"), + "tile_size": attr.string(mandatory = True), + "unroll_factors": attr.string(), + "gpu_archs": attr.string_list(mandatory = True), + "extra_args": attr.string_list(), + "_tool": attr.label( + executable = True, + default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_kernel"), + cfg = "host", + ), + }, + output_to_genfiles = True, + implementation = _gen_unranked_kernel_fatbin_impl, +) + +def gen_unranked_kernel_library(name, types, tile_size, tags = [], unroll_factors = None, extra_args = []): + """ Generate a library with unranked kernels for a specific tensorflow op. + + Args: + name: The name of the tensorflow op. + types: The types ("f16", "f32", "f64") for which a kernel should be generated. + tile_size: The tiling specification, e.g. "16x16". + unroll_factors: The unrolling specification, e.g. "4,4" + tags: The tags which should be added to the library. + extra_args: Extra arguments to pass to the generator tool. 
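+
+      Example:
+        gen_unranked_kernel_library(
+            name = "tanh_unranked",
+            types = ["f16", "f32", "f64"],
+            tile_size = "256",
+            unroll_factors = "4",
+        )
+      produces a ":tanh_unranked_kernels" cc_library that wraps one generated
+      kernel binary per listed type.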
+ """ + + if cuda_gpu_architectures(): + for type in types: + _gen_mlir_op( + name = name, + type = type, + unranked = True, + ) + _gen_unranked_kernel_fatbin_rule( + name = "{name}_{type}_kernel_generator".format(name = name, type = type), + mlir_op = "{name}_{type}.mlir".format(name = name, type = type), + output = "{name}_{type}.a".format(name = name, type = type), + gpu_archs = cuda_gpu_architectures(), + tile_size = tile_size, + unroll_factors = unroll_factors, + extra_args = extra_args, + ) + native.cc_import( + name = "{name}_{type}_kernel".format(name = name, type = type), + static_library = "{name}_{type}.a".format(name = name, type = type), + ) + + native.cc_library( + name = name + "_kernels", + deps = if_cuda_is_configured([":{name}_{type}_kernel".format(name = name, type = type) for type in types]), + linkstatic = 1, + tags = tags, + ) + +def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None, unroll_factors = None, extra_args = [], generate_ranked = True, generate_unranked = False): + if (generate_ranked): + gen_ranked_kernel_library( + name = name, + types = types, + tile_size = tile_size, + tags = tags, + same_shape = same_shape, + unroll_factors = unroll_factors, + extra_args = extra_args, + ) + if (generate_unranked): + gen_unranked_kernel_library( + name = name + "_unranked", + types = types, + tile_size = tile_size, + tags = tags, + unroll_factors = unroll_factors, + extra_args = extra_args, + ) diff --git a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.h b/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.h index 4905d21c299..995aa5390e4 100644 --- a/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.h +++ b/tensorflow/core/kernels/mlir_generated/cwise_op_gpu_base.h @@ -43,14 +43,12 @@ class MlirGeneratedUnaryOp : public OpKernel { absl::Mutex mu_; }; -#define GENERATE_OP_KERNEL_BASE(kernel_name) \ - class MlirGenerated##kernel_name##Op : public MlirGeneratedUnaryOp { \ - public: \ - MlirGenerated##kernel_name##Op(OpKernelConstruction* ctx, \ - absl::Span cubin_data) \ - : MlirGeneratedUnaryOp(ctx, \ - absl::AsciiStrToLower(#kernel_name "_kernel"), \ - cubin_data) {} \ +#define GENERATE_OP_KERNEL_BASE(kernel_name) \ + class MlirGenerated##kernel_name##Op : public MlirGeneratedUnaryOp { \ + public: \ + MlirGenerated##kernel_name##Op(OpKernelConstruction* ctx, \ + absl::Span cubin_data) \ + : MlirGeneratedUnaryOp(ctx, #kernel_name "_kernel", cubin_data) {} \ }; #define GENERATE_OP_KERNEL_FOR(kernel_name, data_type) \ diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/abs.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/abs.mlir.tmpl index d4c9bd5eaed..c0dcd77a9e7 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/abs.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/abs.mlir.tmpl @@ -1,5 +1,5 @@ -func @abs(%arg0: tensor) -> tensor { - %0 = "tf.Abs"(%arg0) { } - : (tensor) -> tensor - return %0 : tensor +func @Abs_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Abs"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/acos.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/acos.mlir.tmpl new file mode 100644 index 00000000000..4727685d597 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/acos.mlir.tmpl @@ -0,0 +1,5 @@ +func @Acos_elem_type(%arg0: 
tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Acos"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/acosh.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/acosh.mlir.tmpl new file mode 100644 index 00000000000..7d2fc844f7d --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/acosh.mlir.tmpl @@ -0,0 +1,5 @@ +func @Acosh_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Acosh"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/angle.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/angle.mlir.tmpl new file mode 100644 index 00000000000..8605741dfea --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/angle.mlir.tmpl @@ -0,0 +1,4 @@ +func @Angle(%arg0: tensor>) -> tensor { + %0 = "tf.Angle"(%arg0) : (tensor>) -> tensor + return %0 : tensor +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/asin.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/asin.mlir.tmpl new file mode 100644 index 00000000000..741ca1b145c --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/asin.mlir.tmpl @@ -0,0 +1,4 @@ +func @Asin(%arg0: tensor) -> tensor { + %0 = "tf.Asin"(%arg0) : (tensor) -> tensor + return %0 : tensor +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/atan.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/atan.mlir.tmpl new file mode 100644 index 00000000000..80e22f38dbe --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/atan.mlir.tmpl @@ -0,0 +1,4 @@ +func @Atan(%arg0: tensor) -> tensor { + %0 = "tf.Atan"(%arg0) : (tensor) -> tensor + return %0 : tensor +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/bias_add.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/bias_add.mlir.tmpl index f58b6c0c1cb..210df66516c 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/bias_add.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/bias_add.mlir.tmpl @@ -1,6 +1,6 @@ -func @bias_add(%arg0: tensor, - %arg1: tensor) -> tensor { - %0 = "tf.BiasAdd"(%arg0, %arg1) { } - : (tensor, tensor) -> tensor +func @BiasAdd(%arg0: tensor, %arg1: tensor) + -> tensor { + %0 = "tf.BiasAdd"(%arg0, %arg1) + : (tensor, tensor) -> tensor return %0 : tensor } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/ceil.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/ceil.mlir.tmpl index a02dd3bb14f..8eac1b6a602 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/ceil.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/ceil.mlir.tmpl @@ -1,5 +1,5 @@ -func @ceil(%arg0: tensor) -> tensor { - %0 = "tf.Ceil"(%arg0) { } - : (tensor) -> tensor - return %0 : tensor +func @Ceil_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Ceil"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/conj.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/conj.mlir.tmpl new file mode 100644 index 
00000000000..963a0740c6f --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/conj.mlir.tmpl @@ -0,0 +1,4 @@ +func @Conj(%arg0: tensor) -> tensor { + %0 = "tf.Conj"(%arg0) : (tensor) -> tensor + return %0 : tensor +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/cos.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/cos.mlir.tmpl index 3578ecc0c35..297cc82e2b0 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/cos.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/cos.mlir.tmpl @@ -1,5 +1,5 @@ -func @cos(%arg0: tensor) -> tensor { - %0 = "tf.Cos"(%arg0) { } - : (tensor) -> tensor - return %0 : tensor +func @Cos_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Cos"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/cosh.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/cosh.mlir.tmpl new file mode 100644 index 00000000000..937e8d21ee6 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/cosh.mlir.tmpl @@ -0,0 +1,5 @@ +func @Cosh_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Cosh"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/digamma.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/digamma.mlir.tmpl new file mode 100644 index 00000000000..3a14e107a51 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/digamma.mlir.tmpl @@ -0,0 +1,5 @@ +func @Digamma_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Digamma"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/erf.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/erf.mlir.tmpl new file mode 100644 index 00000000000..3299dedf434 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/erf.mlir.tmpl @@ -0,0 +1,5 @@ +func @Erf_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Erf"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/erfc.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/erfc.mlir.tmpl new file mode 100644 index 00000000000..f42cf6da336 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/erfc.mlir.tmpl @@ -0,0 +1,5 @@ +func @Erfc_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Erfc"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/exp.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/exp.mlir.tmpl index e4f450692b9..e9725c8b174 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/exp.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/exp.mlir.tmpl @@ -1,5 +1,5 @@ -func @exp(%arg0: tensor) -> tensor { - %0 = "tf.Exp"(%arg0) { } - : (tensor) -> tensor - return %0 : 
tensor +func @Exp_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Exp"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/expm1.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/expm1.mlir.tmpl new file mode 100644 index 00000000000..36abfa3e019 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/expm1.mlir.tmpl @@ -0,0 +1,5 @@ +func @Expm1_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Expm1"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/floor.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/floor.mlir.tmpl index c841cfb6501..fcd5ab74f5b 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/floor.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/floor.mlir.tmpl @@ -1,5 +1,5 @@ -func @floor(%arg0: tensor) -> tensor { - %0 = "tf.Floor"(%arg0) { } - : (tensor) -> tensor - return %0 : tensor +func @Floor_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Floor"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/imag.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/imag.mlir.tmpl new file mode 100644 index 00000000000..c68c85a798c --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/imag.mlir.tmpl @@ -0,0 +1,4 @@ +func @Imag(%arg0: tensor>) -> tensor { + %0 = "tf.Imag"(%arg0) : (tensor>) -> tensor + return %0 : tensor +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/invert.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/invert.mlir.tmpl new file mode 100644 index 00000000000..28287d34832 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/invert.mlir.tmpl @@ -0,0 +1,5 @@ +func @Invert_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Invert"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/isfinite.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/isfinite.mlir.tmpl new file mode 100644 index 00000000000..73784adec09 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/isfinite.mlir.tmpl @@ -0,0 +1,5 @@ +func @Isfinite_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xi1> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.IsFinite"(%arg0) : (tensor<*xelem_type>) -> tensor<*xi1> + return %0 : tensor<*xi1> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/isinf.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/isinf.mlir.tmpl new file mode 100644 index 00000000000..b8477acca4e --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/isinf.mlir.tmpl @@ -0,0 +1,5 @@ +func @Isinf_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.IsInf"(%arg0) : (tensor<*xelem_type>) -> tensor<*xi1> + return %0 : tensor<*xi1> +} diff --git 
a/tensorflow/core/kernels/mlir_generated/op_definitions/isnan.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/isnan.mlir.tmpl new file mode 100644 index 00000000000..ac0c09d22d4 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/isnan.mlir.tmpl @@ -0,0 +1,5 @@ +func @Isnan_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.IsNan"(%arg0) : (tensor<*xelem_type>) -> tensor<*xi1> + return %0 : tensor<*xi1> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/lgamma.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/lgamma.mlir.tmpl new file mode 100644 index 00000000000..ac3e44db23f --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/lgamma.mlir.tmpl @@ -0,0 +1,5 @@ +func @Lgamma_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Lgamma"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/log.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/log.mlir.tmpl index 71e5b3ef507..639e7ed5ffe 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/log.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/log.mlir.tmpl @@ -1,5 +1,5 @@ -func @log(%arg0: tensor) -> tensor { - %0 = "tf.Log"(%arg0) { } - : (tensor) -> tensor - return %0 : tensor +func @Log_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Log"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/log1p.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/log1p.mlir.tmpl new file mode 100644 index 00000000000..6604d5f763a --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/log1p.mlir.tmpl @@ -0,0 +1,5 @@ +func @Log1p_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Log1p"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/logicalnot.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/logicalnot.mlir.tmpl new file mode 100644 index 00000000000..74750fd4749 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/logicalnot.mlir.tmpl @@ -0,0 +1,5 @@ +func @Logicalnot_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.LogicalNot"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/neg.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/neg.mlir.tmpl index bdc6d929239..fbf0d4fbfb5 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/neg.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/neg.mlir.tmpl @@ -1,5 +1,5 @@ -func @neg(%arg0: tensor) -> tensor { - %0 = "tf.Neg"(%arg0) { } - : (tensor) -> tensor - return %0 : tensor +func @Neg_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Neg"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : 
tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/real.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/real.mlir.tmpl new file mode 100644 index 00000000000..600fbe563b8 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/real.mlir.tmpl @@ -0,0 +1,4 @@ +func @Real(%arg0: tensor>) -> tensor { + %0 = "tf.Real"(%arg0) : (tensor>) -> tensor + return %0 : tensor +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/reciprocal.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/reciprocal.mlir.tmpl new file mode 100644 index 00000000000..9548f4b2bc2 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/reciprocal.mlir.tmpl @@ -0,0 +1,5 @@ +func @Reciprocal_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Reciprocal"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/relu.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/relu.mlir.tmpl index 7e7082ff295..4aaacf33df0 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/relu.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/relu.mlir.tmpl @@ -1,5 +1,5 @@ -func @relu(%arg0: tensor) -> tensor { - %0 = "tf.Relu"(%arg0) { } - : (tensor) -> tensor - return %0 : tensor +func @Relu_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Relu"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/rint.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/rint.mlir.tmpl new file mode 100644 index 00000000000..174ebc2210e --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/rint.mlir.tmpl @@ -0,0 +1,5 @@ +func @Rint_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Rint"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/round.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/round.mlir.tmpl new file mode 100644 index 00000000000..a5d9b79841a --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/round.mlir.tmpl @@ -0,0 +1,5 @@ +func @Round_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Round"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/rsqrt.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/rsqrt.mlir.tmpl index 3c456778092..e2a3bf8a5df 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/rsqrt.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/rsqrt.mlir.tmpl @@ -1,5 +1,5 @@ -func @rsqrt(%arg0: tensor) -> tensor { - %0 = "tf.Rsqrt"(%arg0) { } - : (tensor) -> tensor - return %0 : tensor +func @Rsqrt_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Rsqrt"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> } diff --git 
a/tensorflow/core/kernels/mlir_generated/op_definitions/sigmoid.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/sigmoid.mlir.tmpl new file mode 100644 index 00000000000..7969fffdbce --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/sigmoid.mlir.tmpl @@ -0,0 +1,5 @@ +func @Sigmoid_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Sigmoid"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/sign.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/sign.mlir.tmpl new file mode 100644 index 00000000000..8642c8b0b3e --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/sign.mlir.tmpl @@ -0,0 +1,5 @@ +func @Sign_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Sign"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/sin.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/sin.mlir.tmpl new file mode 100644 index 00000000000..0e739740f93 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/sin.mlir.tmpl @@ -0,0 +1,4 @@ +func @Sin(%arg0: tensor) -> tensor { + %0 = "tf.Sin"(%arg0) : (tensor) -> tensor + return %0 : tensor +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/sinh.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/sinh.mlir.tmpl new file mode 100644 index 00000000000..4b110bae313 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/sinh.mlir.tmpl @@ -0,0 +1,5 @@ +func @Sinh_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Sinh"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/sqrt.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/sqrt.mlir.tmpl index 9d54326e032..2a2c1ce1d05 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/sqrt.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/sqrt.mlir.tmpl @@ -1,5 +1,5 @@ -func @sqrt(%arg0: tensor) -> tensor { - %0 = "tf.Sqrt"(%arg0) { } - : (tensor) -> tensor - return %0 : tensor +func @Sqrt_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Sqrt"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/square.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/square.mlir.tmpl new file mode 100644 index 00000000000..45bdc13e182 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/square.mlir.tmpl @@ -0,0 +1,5 @@ +func @Square_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Square"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/tan.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/tan.mlir.tmpl new file mode 100644 index 00000000000..1913e755b36 --- /dev/null +++ 
b/tensorflow/core/kernels/mlir_generated/op_definitions/tan.mlir.tmpl @@ -0,0 +1,5 @@ +func @Tan_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Tan"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> +} diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/tanh.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/tanh.mlir.tmpl index 58a9b61ef68..694e4ce10a4 100644 --- a/tensorflow/core/kernels/mlir_generated/op_definitions/tanh.mlir.tmpl +++ b/tensorflow/core/kernels/mlir_generated/op_definitions/tanh.mlir.tmpl @@ -1,5 +1,5 @@ -func @tanh(%arg0: tensor) -> tensor { - %0 = "tf.Tanh"(%arg0) { } - : (tensor) -> tensor - return %0 : tensor +func @Tanh_elem_type(%arg0: tensor<*xelem_type>) + -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} { + %0 = "tf.Tanh"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type> + return %0 : tensor<*xelem_type> } diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_abs.cc b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_abs.cc new file mode 100644 index 00000000000..586d73171f6 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_abs.cc @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h" + +namespace tensorflow { + +REGISTER_AND_GENERATE_KERNEL(Abs, f16, DT_HALF, Eigen::half); +REGISTER_AND_GENERATE_KERNEL(Abs, f32, DT_FLOAT, float); +REGISTER_AND_GENERATE_KERNEL(Abs, f64, DT_DOUBLE, double); +REGISTER_AND_GENERATE_KERNEL(Abs, i32, DT_INT32, int32); +REGISTER_AND_GENERATE_KERNEL(Abs, i64, DT_INT64, int64); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.cc b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.cc new file mode 100644 index 00000000000..bc4a36facd3 --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.cc @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h" + +#include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace { + +// A simple TensorBuffer implementation that allows us to create Tensors that +// take ownership of pre-allocated memory. +class MlirTensorBuffer : public TensorBuffer { + public: + MlirTensorBuffer(const void* ptr, size_t size, Allocator* allocator) + : TensorBuffer(const_cast(ptr)), + size_(size), + allocator_(allocator) {} + + ~MlirTensorBuffer() override { + if (data()) { + allocator_->DeallocateRaw(data()); + } + } + + size_t size() const override { return size_; } + + TensorBuffer* root_buffer() override { return this; } + + void FillAllocationDescription(AllocationDescription* proto) const override { + proto->set_allocated_bytes(size_); + } + + private: + size_t size_; + Allocator* allocator_; +}; + +} // namespace + +TensorBuffer* GetMlirTensorBuffer(const void* ptr, size_t size, + Allocator* allocator) { + return new MlirTensorBuffer(ptr, size, allocator); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h new file mode 100644 index 00000000000..457658948ed --- /dev/null +++ b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h @@ -0,0 +1,117 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_UNRANKED_OP_GPU_ABS_H_ +#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_UNRANKED_OP_GPU_ABS_H_ + +#include "mlir/ExecutionEngine/CRunnerUtils.h" // from @llvm-project +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +// Returns a pointer to an allocated MlirTensorBuffer that takes ownership of +// pre-allocated memory. +TensorBuffer* GetMlirTensorBuffer(const void* ptr, size_t size, + Allocator* allocator); + +template +::UnrankedMemRefType ConvertTensorToDescriptor(const Tensor& tensor) { + ::UnrankedMemRefType result; + result.rank = tensor.dims(); + result.descriptor = malloc(sizeof(void*) * (2 * result.rank + 3)); + + // Fill the descriptor. + void** pointers = static_cast(result.descriptor); + pointers[0] = tensor.data(); + pointers[1] = tensor.data(); + intptr_t* int_pointers = static_cast(result.descriptor); + int_pointers[2] = 0; + // Fill size. + for (int i = 0; i < result.rank; ++i) { + int_pointers[3 + i] = tensor.dim_size(i); + } + // Fill strides. 
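+  // Strides are filled innermost-dimension-first, so the result is a dense,
+  // row-major layout. For a 2x3 tensor, for example, the descriptor memory
+  // ends up as {data, data, 0, 2, 3, 3, 1}, i.e. {allocated ptr, aligned ptr,
+  // offset, sizes..., strides...}.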
+ int64_t stride = 1; + for (int i = result.rank - 1; i >= 0; --i) { + int_pointers[i + result.rank + 3] = stride; + stride *= tensor.dim_size(i); + } + return result; +} + +template +Tensor ConvertDescriptorToTensor( + ::UnrankedMemRefType unranked_descriptor, DataType tf_data_type, + Allocator* allocator) { + void* base_ptr = static_cast(unranked_descriptor.descriptor)[0]; + TensorShape result_shape; + intptr_t* pointers = static_cast(unranked_descriptor.descriptor); + for (int i = 0; i < unranked_descriptor.rank; ++i) { + result_shape.AddDim(pointers[3 + i]); + } + TensorBuffer* buffer = GetMlirTensorBuffer( + base_ptr, sizeof(ElemType) * result_shape.num_elements(), allocator); + + // Tensor takes ownership of the buffer. + Tensor tensor{tf_data_type, result_shape, buffer}; + // When Tensor is constructed, its ref-counter is incremented. We need to + // decrement it back. + buffer->Unref(); + return tensor; +} + +#define MLIR_FUNCTION(tf_op, mlir_type) _mlir_ciface_##tf_op##_##mlir_type + +// Generates a class derived from OpKernel with Compute function that converts +// input tensors to unranked memref descriptors and calls mlir-generated +// unranked kernel. The outputs are converted back to tensors using +// MlirTensorBuffer to take ownership of pre-allocated memory. +#define REGISTER_AND_GENERATE_KERNEL(tf_op, mlir_type, tf_data_type, \ + data_type) \ + extern "C" ::UnrankedMemRefType MLIR_FUNCTION(tf_op, mlir_type)( \ + tensorflow::OpKernelContext * ctx, \ + ::UnrankedMemRefType * arg); \ + \ + namespace { \ + class MlirUnranked##tf_op##mlir_type##Op : public OpKernel { \ + public: \ + MlirUnranked##tf_op##mlir_type##Op(OpKernelConstruction* ctx) \ + : OpKernel(ctx) {} \ + \ + void Compute(OpKernelContext* ctx) override { \ + const Tensor& input = ctx->input(0); \ + \ + auto input_desc = ConvertTensorToDescriptor(input); \ + auto result_desc = MLIR_FUNCTION(tf_op, mlir_type)(ctx, &input_desc); \ + free(input_desc.descriptor); \ + \ + tensorflow::AllocatorAttributes attrs; \ + auto* allocator = ctx->get_allocator(attrs); \ + \ + Tensor result_tensor = ConvertDescriptorToTensor( \ + result_desc, tf_data_type, allocator); \ + free(result_desc.descriptor); \ + ctx->set_output(0, result_tensor); \ + } \ + }; \ + } \ + \ + REGISTER_KERNEL_BUILDER( \ + Name(#tf_op).Device(DEVICE_GPU).TypeConstraint("T"), \ + MlirUnranked##tf_op##mlir_type##Op); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_UNRANKED_OP_GPU_ABS_H_ diff --git a/tensorflow/lite/micro/cortex_m4_generic/debug_log.cc b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_tanh.cc similarity index 64% rename from tensorflow/lite/micro/cortex_m4_generic/debug_log.cc rename to tensorflow/core/kernels/mlir_generated/unranked_op_gpu_tanh.cc index 615b5a0f393..206c0756e9c 100644 --- a/tensorflow/lite/micro/cortex_m4_generic/debug_log.cc +++ b/tensorflow/core/kernels/mlir_generated/unranked_op_gpu_tanh.cc @@ -13,14 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/lite/micro/debug_log.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/mlir_generated/unranked_op_gpu_base.h" -#ifndef TF_LITE_STRIP_ERROR_STRINGS -#include -#endif +namespace tensorflow { -extern "C" void DebugLog(const char* s) { -#ifndef TF_LITE_STRIP_ERROR_STRINGS - fprintf(stderr, "%s", s); -#endif -} +REGISTER_AND_GENERATE_KERNEL(Tanh, f16, DT_HALF, Eigen::half); +REGISTER_AND_GENERATE_KERNEL(Tanh, f32, DT_FLOAT, float); +REGISTER_AND_GENERATE_KERNEL(Tanh, f64, DT_DOUBLE, double); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc index bbe03dd39cb..a0c07f31b3d 100644 --- a/tensorflow/core/kernels/pooling_ops_common.cc +++ b/tensorflow/core/kernels/pooling_ops_common.cc @@ -463,6 +463,8 @@ void DnnPoolingGradOp::Compute( return; } + TensorFormat transformed_input_data_format = data_format; + #if CUDNN_VERSION < 7300 /// For now, cudnn does not support NHWC format, so we need to convert it /// to NCHW before calling cudnn. We need to get rid of this once it is done @@ -516,6 +518,7 @@ void DnnPoolingGradOp::Compute( functor::NHWCToNCHW()(context->eigen_device(), tensor_in->tensor(), transformed_input.tensor()); + transformed_input_data_format = FORMAT_NCHW; } if (tensor_out) { // For AvgPoolGrad, the original output tensor is not necessary. However, @@ -577,6 +580,8 @@ void DnnPoolingGradOp::Compute( int64 input_pad_left = 0; int64 input_pad_right = 0; + Tensor transformed_and_padded_input_backprop; + if (padding == EXPLICIT && (params.pad_top != params.pad_bottom || params.pad_left != params.pad_right)) { // Pad the input in the same way we did during the forward pass, so that @@ -588,7 +593,6 @@ void DnnPoolingGradOp::Compute( std::min(params.pad_left, params.pad_right); Tensor padded_input; - Tensor padded_input_backprop; const int64 padding_rows_diff = std::abs(params.pad_top - params.pad_bottom); const int64 padding_cols_diff = @@ -607,18 +611,18 @@ void DnnPoolingGradOp::Compute( << " stride_rows" << params.row_stride; OP_REQUIRES_OK( - context, - context->allocate_temp(DataTypeToEnum::value, - ShapeFromFormat(data_format, batch_size, - new_in_rows, new_in_cols, depth), - &padded_input)); + context, context->allocate_temp( + DataTypeToEnum::value, + ShapeFromFormat(transformed_input_data_format, batch_size, + new_in_rows, new_in_cols, depth), + &padded_input)); OP_REQUIRES_OK( - context, - context->allocate_temp(DataTypeToEnum::value, - ShapeFromFormat(data_format, batch_size, - new_in_rows, new_in_cols, depth), - &transformed_input_backprop)); + context, context->allocate_temp( + DataTypeToEnum::value, + ShapeFromFormat(transformed_input_data_format, batch_size, + new_in_rows, new_in_cols, depth), + &transformed_and_padded_input_backprop)); input_pad_top = params.pad_top - common_padding_rows; input_pad_bottom = params.pad_bottom - common_padding_rows; @@ -644,7 +648,8 @@ void DnnPoolingGradOp::Compute( To32Bit(const_transformed_input.tensor()), static_cast(input_pad_top), static_cast(input_pad_bottom), static_cast(input_pad_left), static_cast(input_pad_right), - To32Bit(padded_input.tensor()), data_format)); + To32Bit(padded_input.tensor()), + transformed_input_data_format)); transformed_input = padded_input; @@ -654,6 +659,8 @@ void DnnPoolingGradOp::Compute( << " horizontal padding set to: " << horizontal_padding; tensor_in_rows = 
new_in_rows; tensor_in_cols = new_in_cols; + } else { + transformed_and_padded_input_backprop = transformed_input_backprop; } /// Get ready to call cudnn @@ -690,9 +697,9 @@ void DnnPoolingGradOp::Compute( auto output_backprop_data = AsDeviceMemory(transformed_output_backprop.template flat().data(), transformed_output_backprop.template flat().size()); - auto input_backprop_data = - AsDeviceMemory(transformed_input_backprop.template flat().data(), - transformed_input_backprop.template flat().size()); + auto input_backprop_data = AsDeviceMemory( + transformed_and_padded_input_backprop.template flat().data(), + transformed_and_padded_input_backprop.template flat().size()); auto* stream = context->op_device_context()->stream(); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); @@ -722,6 +729,20 @@ void DnnPoolingGradOp::Compute( OP_REQUIRES(context, status, errors::Internal("dnn PoolBackward launch failed")); + if (padding == EXPLICIT && (params.pad_top != params.pad_bottom || + params.pad_left != params.pad_right)) { + // Remove the padding that was added to the input shape above. + functor::PadInput()( + context->eigen_device(), + To32Bit(const_cast(transformed_and_padded_input_backprop) + .tensor()), + {{static_cast(-input_pad_top), static_cast(-input_pad_left)}}, + {{static_cast(-input_pad_bottom), + static_cast(-input_pad_right)}}, + To32Bit(transformed_input_backprop.template tensor()), + transformed_input_data_format, T{}); + } + #if CUDNN_VERSION < 7300 if (data_format == FORMAT_NHWC) { /// Transform the output data from NCHW back to NHWC. @@ -732,18 +753,6 @@ void DnnPoolingGradOp::Compute( input_backprop->tensor()); } #endif // CUDNN_VERSION < 7300 - if (padding == EXPLICIT && (params.pad_top != params.pad_bottom || - params.pad_left != params.pad_right)) { - // Remove the padding that was added to the input shape above. - functor::PadInput()( - context->eigen_device(), - To32Bit(const_cast(transformed_input_backprop) - .tensor()), - {{static_cast(-input_pad_top), static_cast(-input_pad_left)}}, - {{static_cast(-input_pad_bottom), - static_cast(-input_pad_right)}}, - To32Bit(input_backprop->tensor()), data_format, T{}); - } } #define DEFINE_DNN_OPS(T) \ diff --git a/tensorflow/core/kernels/quantization_utils.h b/tensorflow/core/kernels/quantization_utils.h index fef3ed582b3..eaa29023a60 100644 --- a/tensorflow/core/kernels/quantization_utils.h +++ b/tensorflow/core/kernels/quantization_utils.h @@ -43,7 +43,8 @@ namespace tensorflow { // We have to be able to detect and handle overflows in int32, so this function // uses doubles and int64's to make sure we have enough room. template -int64 FloatToQuantizedUnclamped(float input, float range_min, float range_max) { +inline int64 FloatToQuantizedUnclamped(float input, float range_min, + float range_max) { const int64 lowest_quantized = static_cast(Eigen::NumTraits::lowest()); if (range_min == range_max) { @@ -60,6 +61,12 @@ int64 FloatToQuantizedUnclamped(float input, float range_min, float range_max) { return quantized; } +template <> +inline int64 FloatToQuantizedUnclamped(float input, float range_min, + float range_max) { + return -1; +} + // This converts the float into the final quantized type, clamping/saturating // any over or underflows. 
template diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op.cc b/tensorflow/core/kernels/quantize_and_dequantize_op.cc index 8f71d09c083..dec0262cf04 100644 --- a/tensorflow/core/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/core/kernels/quantize_and_dequantize_op.cc @@ -131,6 +131,75 @@ class QuantizeAndDequantizeV2Op : public OpKernel { bool narrow_range_; }; +// Implementation of QuantizeAndDequantizeV4GradientOp. +// When back-propagating the error through a quantized layer, the following +// paper gives evidence that clipped-ReLU is better than non-clipped: +// "Deep Learning with Low Precision by Half-wave Gaussian Quantization" +// http://zpascal.net/cvpr2017/Cai_Deep_Learning_With_CVPR_2017_paper.pdf +template +class QuantizeAndDequantizeV4GradientOp : public OpKernel { + public: + explicit QuantizeAndDequantizeV4GradientOp(OpKernelConstruction* ctx) + : OpKernel::OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& gradient = ctx->input(0); + const Tensor& input = ctx->input(1); + Tensor* input_backprop = nullptr; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, input.shape(), &input_backprop)); + + OP_REQUIRES( + ctx, input.IsSameSize(gradient), + errors::InvalidArgument("gradient and input must be the same size")); + const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_); + const Tensor& input_min_tensor = ctx->input(2); + const Tensor& input_max_tensor = ctx->input(3); + if (axis_ != -1) { + OP_REQUIRES( + ctx, input_min_tensor.dim_size(0) == depth, + errors::InvalidArgument("min has incorrect size, expected ", depth, + " was ", input_min_tensor.dim_size(0))); + OP_REQUIRES( + ctx, input_max_tensor.dim_size(0) == depth, + errors::InvalidArgument("max has incorrect size, expected ", depth, + " was ", input_max_tensor.dim_size(0))); + } + + TensorShape min_max_shape(input_min_tensor.shape()); + Tensor* input_min_backprop; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(1, min_max_shape, &input_min_backprop)); + + Tensor* input_max_backprop; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(2, min_max_shape, &input_max_backprop)); + + if (axis_ == -1) { + functor::QuantizeAndDequantizeOneScaleGradientFunctor f; + f(ctx->eigen_device(), gradient.template flat(), + input.template flat(), input_min_tensor.scalar(), + input_max_tensor.scalar(), input_backprop->template flat(), + input_min_backprop->template scalar(), + input_max_backprop->template scalar()); + } else { + functor::QuantizeAndDequantizePerChannelGradientFunctor f; + f(ctx->eigen_device(), + gradient.template flat_inner_outer_dims(axis_ - 1), + input.template flat_inner_outer_dims(axis_ - 1), + &input_min_tensor, &input_max_tensor, + input_backprop->template flat_inner_outer_dims(axis_ - 1), + input_min_backprop->template flat(), + input_max_backprop->template flat()); + } + } + + private: + int axis_; +}; + // Simulate quantization precision loss in a float tensor by: // 1. Quantize the tensor to fixed point numbers, which should match the target // quantization method when it is used in inference. 
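Note: the QuantizeAndDequantizeV4Grad kernels added in this change reduce to a simple per-element rule, visible in the functor implementations further down and in the unit test: the incoming gradient passes through wherever the corresponding input lies inside the quantization range and is zeroed elsewhere, while the range inputs themselves always receive a zero gradient. A minimal scalar sketch of that rule (the function name is hypothetical and only for illustration, not part of the patch):

// Sketch of the per-element behaviour of the V4 gradient functors.
float QdqV4GradElement(float grad, float x, float range_min, float range_max) {
  // Inside [range_min, range_max] the gradient passes through unchanged;
  // outside the range it is clipped to zero.
  return (x >= range_min && x <= range_max) ? grad : 0.0f;
}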
@@ -295,6 +364,43 @@ struct QuantizeAndDequantizePerChannelFunctor { input_max_tensor, round_mode, narrow_range, out); } }; + +template +struct QuantizeAndDequantizeOneScaleGradientFunctor { + void operator()(const CPUDevice& d, typename TTypes::ConstFlat gradient, + typename TTypes::ConstFlat input, + typename TTypes::ConstScalar input_min_tensor, + typename TTypes::ConstScalar input_max_tensor, + typename TTypes::Flat input_backprop, + typename TTypes::Scalar input_min_backprop, + typename TTypes::Scalar input_max_backprop) { + QuantizeAndDequantizeOneScaleGradientImpl::Compute( + d, gradient, input, input_min_tensor, input_max_tensor, input_backprop, + input_min_backprop, input_max_backprop); + } +}; + +template +struct QuantizeAndDequantizePerChannelGradientFunctor { + void operator()(const CPUDevice& d, + typename TTypes::ConstTensor gradient, + typename TTypes::ConstTensor input, + const Tensor* input_min_tensor, + const Tensor* input_max_tensor, + typename TTypes::Tensor input_backprop, + typename TTypes::Flat input_min_backprop, + typename TTypes::Flat input_max_backprop) { + QuantizeAndDequantizePerChannelGradientImpl::Compute( + d, gradient, input, input_min_tensor, input_max_tensor, input_backprop, + input_min_backprop, input_max_backprop); + } +}; + +template struct functor::QuantizeAndDequantizeOneScaleGradientFunctor; +template struct functor::QuantizeAndDequantizePerChannelGradientFunctor< + CPUDevice, double>; + } // namespace functor #define REGISTER_CPU_KERNEL(T) \ @@ -306,6 +412,14 @@ struct QuantizeAndDequantizePerChannelFunctor { .Device(DEVICE_CPU) \ .TypeConstraint("T"), \ QuantizeAndDequantizeV3Op); \ + REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + QuantizeAndDequantizeV2Op); \ + REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4Grad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + QuantizeAndDequantizeV4GradientOp); \ REGISTER_KERNEL_BUILDER( \ Name("QuantizeAndDequantize").Device(DEVICE_CPU).TypeConstraint("T"), \ QuantizeAndDequantizeOp); @@ -329,6 +443,18 @@ TF_CALL_double(REGISTER_CPU_KERNEL); .HostMemory("num_bits") \ .TypeConstraint("T"), \ QuantizeAndDequantizeV3Op); \ + REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4") \ + .Device(DEVICE_GPU) \ + .HostMemory("input_min") \ + .HostMemory("input_max") \ + .TypeConstraint("T"), \ + QuantizeAndDequantizeV2Op); \ + REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4Grad") \ + .Device(DEVICE_GPU) \ + .HostMemory("input_min") \ + .HostMemory("input_max") \ + .TypeConstraint("T"), \ + QuantizeAndDequantizeV4GradientOp); \ REGISTER_KERNEL_BUILDER( \ Name("QuantizeAndDequantize").Device(DEVICE_GPU).TypeConstraint("T"), \ QuantizeAndDequantizeOp); diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op.h b/tensorflow/core/kernels/quantize_and_dequantize_op.h index 4dd6e5c839b..c286a10a9c6 100644 --- a/tensorflow/core/kernels/quantize_and_dequantize_op.h +++ b/tensorflow/core/kernels/quantize_and_dequantize_op.h @@ -60,6 +60,28 @@ struct QuantizeAndDequantizePerChannelFunctor { typename TTypes::Tensor output); }; +template +struct QuantizeAndDequantizeOneScaleGradientFunctor { + void operator()(const Device& d, typename TTypes::ConstFlat gradient, + typename TTypes::ConstFlat input, + typename TTypes::ConstScalar input_min, + typename TTypes::ConstScalar input_max, + typename TTypes::Flat input_backprop, + typename TTypes::Scalar input_min_backprop, + typename TTypes::Scalar input_max_backprop); +}; + +template +struct 
QuantizeAndDequantizePerChannelGradientFunctor { + void operator()(const Device& d, typename TTypes::ConstTensor gradient, + typename TTypes::ConstTensor input, + const Tensor* input_min_tensor, + const Tensor* input_max_tensor, + typename TTypes::Tensor input_backprop, + typename TTypes::Flat input_min_backprop, + typename TTypes::Flat input_max_backprop); +}; + // The implementation below runs on both CPU and GPU. template ::Vec, @@ -249,6 +271,55 @@ struct QuantizeAndDequantizePerChannelImpl { } }; +template +struct QuantizeAndDequantizeOneScaleGradientImpl { + static void Compute(const Device& d, typename TTypes::ConstFlat gradient, + typename TTypes::ConstFlat input, + typename TTypes::ConstScalar input_min, + typename TTypes::ConstScalar input_max, + typename TTypes::Flat input_backprop, + typename TTypes::Scalar input_min_backprop, + typename TTypes::Scalar input_max_backprop) { + const T min_val = input_min(); + const T max_val = input_max(); + const auto in_range = + (input >= min_val && input <= max_val) + .select(input.constant(1.0f), input.constant(0.0f)); + input_backprop.device(d) = gradient * in_range; + input_min_backprop.device(d) = input_min_backprop.constant(0.0f); + input_max_backprop.device(d) = input_max_backprop.constant(0.0f); + } +}; + +template +struct QuantizeAndDequantizePerChannelGradientImpl { + static void Compute(const Device& d, + typename TTypes::ConstTensor gradient, + typename TTypes::ConstTensor input, + const Tensor* input_min_tensor, + const Tensor* input_max_tensor, + typename TTypes::Tensor input_backprop, + typename TTypes::Flat input_min_backprop, + typename TTypes::Flat input_max_backprop) { + using Index = typename tensorflow::TTypes::ConstTensor::Index; + auto input_min = input_min_tensor->vec(); + auto input_max = input_max_tensor->vec(); + int num_channels = input.dimension(1); + for (Index i = 0; i < num_channels; ++i) { + const auto gradient_chip = gradient.template chip<1>(i); + const auto input_chip = input.template chip<1>(i); + const T min_val = input_min(i); + const T max_val = input_max(i); + const auto in_range = + (input_chip >= min_val && input_chip <= max_val) + .select(input_chip.constant(1.0f), input_chip.constant(0.0f)); + input_backprop.template chip<1>(i).device(d) = gradient_chip * in_range; + } + input_min_backprop.device(d) = input_min_backprop.constant(0.0f); + input_max_backprop.device(d) = input_max_backprop.constant(0.0f); + } +}; + } // end of namespace functor } // end of namespace tensorflow diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op_gpu.cu.cc b/tensorflow/core/kernels/quantize_and_dequantize_op_gpu.cu.cc index f3bb41071cb..9f074535770 100644 --- a/tensorflow/core/kernels/quantize_and_dequantize_op_gpu.cu.cc +++ b/tensorflow/core/kernels/quantize_and_dequantize_op_gpu.cu.cc @@ -53,6 +53,37 @@ struct QuantizeAndDequantizePerChannelFunctor { } }; +template +struct QuantizeAndDequantizeOneScaleGradientFunctor { + void operator()(const GPUDevice& d, typename TTypes::ConstFlat gradient, + typename TTypes::ConstFlat input, + typename TTypes::ConstScalar input_min_tensor, + typename TTypes::ConstScalar input_max_tensor, + typename TTypes::Flat input_backprop, + typename TTypes::Scalar input_min_backprop, + typename TTypes::Scalar input_max_backprop) { + QuantizeAndDequantizeOneScaleGradientImpl::Compute( + d, gradient, input, input_min_tensor, input_max_tensor, input_backprop, + input_min_backprop, input_max_backprop); + } +}; + +template +struct QuantizeAndDequantizePerChannelGradientFunctor { + void 
operator()(const GPUDevice& d, + typename TTypes::ConstTensor gradient, + typename TTypes::ConstTensor input, + const Tensor* input_min_tensor, + const Tensor* input_max_tensor, + typename TTypes::Tensor input_backprop, + typename TTypes::Flat input_min_backprop, + typename TTypes::Flat input_max_backprop) { + QuantizeAndDequantizePerChannelGradientImpl::Compute( + d, gradient, input, input_min_tensor, input_max_tensor, input_backprop, + input_min_backprop, input_max_backprop); + } +}; + } // end namespace functor // Instantiate the GPU implementation for float and double. @@ -65,6 +96,15 @@ template struct functor::QuantizeAndDequantizePerChannelFunctor; +template struct functor::QuantizeAndDequantizeOneScaleGradientFunctor; +template struct functor::QuantizeAndDequantizeOneScaleGradientFunctor; +template struct functor::QuantizeAndDequantizePerChannelGradientFunctor< + GPUDevice, float>; +template struct functor::QuantizeAndDequantizePerChannelGradientFunctor< + GPUDevice, double>; + } // end namespace tensorflow #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc b/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc index 90764b0feb2..596ab13590a 100644 --- a/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc +++ b/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc @@ -362,6 +362,54 @@ TEST_P(ParameterizedQuantizeAndDequantizeTest, } } +// Verifies the Gradient. +TEST_P(ParameterizedQuantizeAndDequantizeTest, GradientV4_op) { + const int axis = GetParam(); + TF_ASSERT_OK(NodeDefBuilder("qdq_v4_grad_op", "QuantizeAndDequantizeV4Grad") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Attr("axis", axis) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + const std::vector dims = {2, 3, 4, 5}; + // Input gradient. (repeating 11 values multiplied by (slice_idx + 1)) + auto gradients = ScalePerSliceAlongAxis( + dims, axis, {1, -2, -3, 4, 5, 6, -7, -8, -9, -10, 11}); + AddInputFromArray(TensorShape(dims), gradients); + // Forward op inputs. (repeating 7 values multiplied by (slice_idx + 1)). + auto inputs = ScalePerSliceAlongAxis( + dims, axis, {-1, -0.5, 0, 0.3, 0.8, 0.55, 0.6}); + AddInputFromArray(TensorShape(dims), inputs); + const int num_slices = (axis == -1) ? 1 : dims[axis]; + const TensorShape range_shape = + (axis == -1) ? TensorShape({}) : TensorShape({num_slices}); + std::vector input_min_values(num_slices), input_max_values(num_slices); + for (int i = 0; i < num_slices; ++i) { + input_max_values[i] = 0.8f + i * 0.4f; + input_min_values[i] = -input_max_values[i]; + } + AddInputFromArray(range_shape, input_min_values); + AddInputFromArray(range_shape, input_max_values); + std::vector expected_vals(inputs.size()); + int minor_size = 1; + for (int i = axis + 1; i < dims.size(); ++i) { + minor_size *= dims[i]; + } + for (int i = 0; i < inputs.size(); ++i) { + int slice_idx = (i / minor_size) % num_slices; + expected_vals[i] = ((inputs[i] >= input_min_values[slice_idx]) && + (inputs[i] <= input_max_values[slice_idx])) + ? gradients[i] + : 0; + } + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_FLOAT, TensorShape(dims)); + test::FillValues(&expected, expected_vals); + test::ExpectTensorNear(expected, *GetOutput(0), 1e-5); +} + // Instantiate parameterized tests for axis = -1, 1, 3. 
INSTANTIATE_TEST_SUITE_P(All, ParameterizedQuantizeAndDequantizeTest, ::testing::Values(-1, 1, 3)); diff --git a/tensorflow/core/kernels/random_binomial_op.cc b/tensorflow/core/kernels/random_binomial_op.cc index e2487d68be8..3c4027765b7 100644 --- a/tensorflow/core/kernels/random_binomial_op.cc +++ b/tensorflow/core/kernels/random_binomial_op.cc @@ -255,7 +255,6 @@ struct RandomBinomialFunctor { } } else if (prob > T(0.5)) { T q = T(1) - prob; - double dcount = static_cast(count); double dq = static_cast(q); if (count * q >= T(10)) { for (int64 sample_idx = output_idx % samples_per_batch; diff --git a/tensorflow/core/kernels/random_poisson_op.cc b/tensorflow/core/kernels/random_poisson_op.cc index dcb7d6b0f0e..2c898d95749 100644 --- a/tensorflow/core/kernels/random_poisson_op.cc +++ b/tensorflow/core/kernels/random_poisson_op.cc @@ -150,6 +150,16 @@ struct PoissonFunctor { } continue; } + if (Eigen::numext::isinf(rate) && rate > CT(0)) { + // Fill the rest of the samples for the current rate value. + for (int64 sample_idx = output_idx % num_samples; + sample_idx < num_samples && output_idx < limit_output; + sample_idx++, output_idx++) { + U k = Eigen::NumTraits::infinity(); + samples_rate_output[sample_idx * num_rate] = k; + } + continue; + } // Transformed rejection due to Hormann. // // Given a CDF F(x), and G(x), a dominating distribution chosen such diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc index 210b994a0b8..cd6826a65a0 100644 --- a/tensorflow/core/kernels/relu_op.cc +++ b/tensorflow/core/kernels/relu_op.cc @@ -18,6 +18,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "tensorflow/core/kernels/relu_op.h" + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -68,7 +69,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS); SeluGradOp) // Elu and Selu only make sense with float or double. 
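The random_poisson_op.cc hunk above adds an early-out for a rate of +infinity: instead of entering the transformed-rejection loop (which would never terminate for an infinite rate), the remaining samples for that rate value are filled with infinity. A minimal self-contained sketch of that guard, not the kernel's actual sharded loop:

#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Mirrors the guard added to PoissonFunctor: a positive infinite rate
// short-circuits to +inf samples rather than running rejection sampling.
bool FillIfInfiniteRate(double rate, std::vector<double>* samples) {
  if (!(std::isinf(rate) && rate > 0)) return false;
  for (double& s : *samples) s = std::numeric_limits<double>::infinity();
  return true;
}

int main() {
  std::vector<double> samples(4);
  if (FillIfInfiniteRate(std::numeric_limits<double>::infinity(), &samples)) {
    std::printf("%g\n", samples[0]);  // prints: inf
  }
}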
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_ELU_KERNELS); +TF_CALL_FLOAT_TYPES(REGISTER_ELU_KERNELS); #undef REGISTER_ELU_KERNELS #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -208,5 +209,4 @@ REGISTER_KERNEL_BUILDER( #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM - } // namespace tensorflow diff --git a/tensorflow/core/kernels/rnn/BUILD b/tensorflow/core/kernels/rnn/BUILD index 0ac9b08ed15..104cc1eba03 100644 --- a/tensorflow/core/kernels/rnn/BUILD +++ b/tensorflow/core/kernels/rnn/BUILD @@ -32,9 +32,9 @@ tf_gpu_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", "//tensorflow/core/kernels:eigen_contraction_kernel", "//tensorflow/core/kernels:eigen_helpers", + "//tensorflow/core/platform:stream_executor", "//third_party/eigen3", ], ) diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index bdb07470c07..aaa0310cd1b 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -1461,20 +1461,23 @@ class ApplyProximalAdagradOp : public OpKernel { const Tensor& lr = ctx->input(2); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()) && - lr.scalar()() > static_cast(0), + (!std::is_same::value || + lr.scalar()() > static_cast(0)), errors::InvalidArgument("lr is not a positive scalar: ", lr.shape().DebugString())); const Tensor& l1 = ctx->input(3); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1.shape()) && - l1.scalar()() >= static_cast(0), + (!std::is_same::value || + l1.scalar()() >= static_cast(0)), errors::InvalidArgument("l1 regularization strength is not a " "non-negative scalar: ", l1.shape().DebugString())); const Tensor& l2 = ctx->input(4); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2.shape()) && - l2.scalar()() >= static_cast(0), + (!std::is_same::value || + l2.scalar()() >= static_cast(0)), errors::InvalidArgument("l2 regularization strength is not a " "non-negative scalar: ", l2.shape().DebugString())); @@ -1510,6 +1513,29 @@ class ApplyProximalAdagradOp : public OpKernel { REGISTER_KERNELS(CPU, float); REGISTER_KERNELS(CPU, double); + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +// Forward declarations of the functor specializations for GPU. 
+namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void ApplyProximalAdagrad::operator()( \ + const GPUDevice& d, typename TTypes::Flat var, \ + typename TTypes::Flat accum, typename TTypes::ConstScalar lr, \ + typename TTypes::ConstScalar l1, typename TTypes::ConstScalar l2, \ + typename TTypes::ConstFlat grad); \ + extern template struct ApplyProximalAdagrad; +DECLARE_GPU_SPEC(Eigen::half); +DECLARE_GPU_SPEC(float); +DECLARE_GPU_SPEC(double); +#undef DECLARE_GPU_SPEC +} // namespace functor + +REGISTER_KERNELS(GPU, Eigen::half); +REGISTER_KERNELS(GPU, float); +REGISTER_KERNELS(GPU, double); +#endif +#undef REGISTER_CPU_KERNELS #undef REGISTER_KERNELS namespace { @@ -2615,11 +2641,14 @@ class SparseApplyFtrlOp : public OpKernel { errors::InvalidArgument("indices must be one-dimensional")); const Tensor& lr = ctx->input(5); - OP_REQUIRES(ctx, - TensorShapeUtils::IsScalar(lr.shape()) && - lr.scalar()() > static_cast(0), - errors::InvalidArgument("lr is not a positive scalar: ", - lr.shape().DebugString())); + OP_REQUIRES( + ctx, + TensorShapeUtils::IsScalar(lr.shape()) && + (lr.scalar()() > static_cast(0) || + (multiply_linear_by_lr_ && lr.scalar()() >= static_cast(0))), + errors::InvalidArgument("lr is not a positive scalar (or zero if " + "multiply_linear_by_lr is set): ", + lr.shape().DebugString())); const Tensor& l1 = ctx->input(6); OP_REQUIRES(ctx, diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc index 5a0c6d5c168..64df24180ef 100644 --- a/tensorflow/core/kernels/training_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc @@ -184,6 +184,20 @@ __device__ std::complex impl_rsqrt(std::complex x) { return *(reinterpret_cast*>(&root)); } +template +__device__ T impl_fabs(T x) { + return fabs(x); +} +template <> +__device__ Eigen::half impl_fabs(Eigen::half x) { + return __float2half(fabs(__half2float(x))); +} + +template +__device__ T impl_sign(T x) { + return x == T(0) ? T(0) : x < T(0) ? T(-1) : T(1); +} + template __global__ __launch_bounds__(1024) void ApplyAdagradKernel(GpuLaunchConfig cfg, T* var, T* accum, @@ -207,6 +221,20 @@ __global__ __launch_bounds__(1024) void ApplyAdagradV2Kernel( } } +template +__global__ __launch_bounds__(1024) void ApplyProximalAdagradKernel( + GpuLaunchConfig cfg, T* var, T* accum, const T* lr, const T* l1, + const T* l2, const T* grad) { + GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) { + accum[i] += grad[i] * grad[i]; + T lr_scaled = lr[0] * impl_rsqrt(accum[i]); + T prox_var = var[i] - grad[i] * lr_scaled; + var[i] = impl_sign(prox_var) * + max(impl_fabs(prox_var) - lr_scaled * max(l1[0], T(0.f)), T(0.f)) / + (T(1.f) + l2[0] * lr_scaled); + } +} + template __global__ __launch_bounds__(1024) void ApplyAdadeltaKernel( GpuLaunchConfig cfg, T* var, T* accum, T* accum_update, const T* plr, @@ -326,6 +354,43 @@ struct ApplyAdagradV2 { #endif } }; + +template +struct ApplyProximalAdagrad { + void operator()(const GPUDevice& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstFlat grad) { +#if TENSORFLOW_USE_ROCM + wrap_kernel_call(ApplyProximalAdagradKernel, d, var, accum, lr, l1, l2, + grad); +#else + Eigen::array::Tensor::Index, 1> bcast; + bcast[0] = grad.dimension(0); + Eigen::Sizes<1> single; + // Fobos update per paper with Adagrad learning rate. + accum.device(d) += grad.square(); + // Adagrad learning rate. 
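+      // Proximal (FOBOS) step using the Adagrad learning rate computed below:
+      //   lr_adagrad = lr / sqrt(accum)
+      //   prox_var   = var - lr_adagrad * grad
+      //   var        = sign(prox_var)
+      //                * max(|prox_var| - lr_adagrad * max(l1, 0), 0)
+      //                / (1 + l2 * lr_adagrad)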
+ // The following is the GPU equivalent of the CPU version: + // auto learning_rate = accum.constant(lr()) * accum.rsqrt(); + auto lr_bcast = lr.reshape(single).broadcast(bcast); + auto l1_bcast = l1.reshape(single).broadcast(bcast); + auto l2_bcast = l2.reshape(single).broadcast(bcast); + auto learning_rate = lr_bcast * accum.rsqrt(); + auto prox_var = var; + // compute v = w - lr * grad. + prox_var.device(d) -= grad * learning_rate; + // compute sign(v) * max(|v| - lr * max(l1, 0), 0) + var.device(d) = prox_var.sign() * + (prox_var.abs() - learning_rate * l1_bcast.cwiseMax(T(0.f))) + .cwiseMax(T(0.f)) / + (var.constant(T(1.f)) + l2_bcast * learning_rate); +#endif + } +}; + template struct ApplyAdadelta { void operator()(const GPUDevice& d, typename TTypes::Flat var, @@ -811,6 +876,10 @@ template struct functor::ApplyAdagradV2; template struct functor::ApplyAdagradV2; #endif +template struct functor::ApplyProximalAdagrad; +template struct functor::ApplyProximalAdagrad; +template struct functor::ApplyProximalAdagrad; + template struct functor::ApplyAdadelta; template struct functor::ApplyAdadelta; template struct functor::ApplyAdadelta; diff --git a/tensorflow/core/lib/core/BUILD b/tensorflow/core/lib/core/BUILD index 886e3e9140c..1311b4a44d5 100644 --- a/tensorflow/core/lib/core/BUILD +++ b/tensorflow/core/lib/core/BUILD @@ -247,14 +247,6 @@ filegroup( visibility = ["//tensorflow/core:__pkg__"], ) -filegroup( - name = "legacy_lib_core_threadpool_options_header", - srcs = [ - "threadpool_options.h", - ], - visibility = ["//tensorflow/core:__pkg__"], -) - filegroup( name = "legacy_lib_proto_parsing_headers", srcs = [ @@ -279,6 +271,7 @@ filegroup( "stringpiece.h", "threadpool.h", "threadpool_interface.h", + "threadpool_options.h", ], visibility = ["//tensorflow/core:__pkg__"], ) diff --git a/tensorflow/core/lib/png/BUILD b/tensorflow/core/lib/png/BUILD index 63326737d8e..aca343c0dbd 100644 --- a/tensorflow/core/lib/png/BUILD +++ b/tensorflow/core/lib/png/BUILD @@ -5,8 +5,7 @@ load( package( default_visibility = [ - # tensorflow/core:lib effectively exposes all targets under tensorflow/core/lib/** - "//tensorflow/core:__pkg__", + "//tensorflow:__subpackages__", ], licenses = ["notice"], # Apache 2.0 ) diff --git a/tensorflow/core/nccl/BUILD b/tensorflow/core/nccl/BUILD index a076d388b70..9b1447f53c1 100644 --- a/tensorflow/core/nccl/BUILD +++ b/tensorflow/core/nccl/BUILD @@ -25,8 +25,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - cc_library( name = "nccl_lib", srcs = if_cuda_or_rocm([ @@ -45,12 +43,11 @@ cc_library( ]) + if_cuda_or_rocm([ "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/memory", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:connected_traceme", "//tensorflow/core/profiler/lib:annotated_traceme", diff --git a/tensorflow/core/nccl/collective_communicator.cc b/tensorflow/core/nccl/collective_communicator.cc index e71721ce080..56e2255ae99 100644 --- a/tensorflow/core/nccl/collective_communicator.cc +++ b/tensorflow/core/nccl/collective_communicator.cc @@ -69,8 +69,8 @@ void NcclCommunicator::Enqueue(std::shared_ptr col_ctx, StatusCallback done) { const CollectiveParams& col_params = col_ctx->col_params; const int num_global_devices = 
col_params.group.group_size; - const int num_local_devices = col_params.instance.num_devices_per_task.at( - col_params.instance.task_names[col_params.default_rank]); + const int num_local_devices = col_params.group.num_devices_per_task.at( + col_params.group.task_names[col_params.default_rank]); const string nccl_collective_key = NcclCollectiveKey(col_ctx->exec_key, col_ctx->step_id); auto* compute_stream = col_ctx->op_ctx->op_device_context()->stream(); @@ -84,7 +84,7 @@ void NcclCommunicator::Enqueue(std::shared_ptr col_ctx, col_params.source_rank); VLOG(1) << "NcclCommunicator::Enqueue type " << col_params.instance.type << " num_tasks " << col_params.group.num_tasks << " current task " - << col_params.instance.task_names[col_params.default_rank] + << col_params.group.task_names[col_params.default_rank] << " num local devices " << num_local_devices << " num global devices " << num_global_devices << " device " << col_ctx->device_name << " instance " diff --git a/tensorflow/core/nccl/nccl_manager.cc b/tensorflow/core/nccl/nccl_manager.cc index 43c9b229450..157c255a316 100644 --- a/tensorflow/core/nccl/nccl_manager.cc +++ b/tensorflow/core/nccl/nccl_manager.cc @@ -632,7 +632,7 @@ void NcclManager::RunCollective(Collective* collective) { // Wait to ensure that the kernel that produces the data in the input // tensor has finished running before the nccl kernel runs on the // communication stream. - nccl_stream->stream->ThenWaitFor(p->input_event.get()); + nccl_stream->stream->ThenWaitFor(p->tensor_stream); } if (p->root) { if (collective->root_rank == -1) { diff --git a/tensorflow/core/nccl/nccl_manager.h b/tensorflow/core/nccl/nccl_manager.h index b91ef86042a..88b8bc85663 100644 --- a/tensorflow/core/nccl/nccl_manager.h +++ b/tensorflow/core/nccl/nccl_manager.h @@ -27,7 +27,6 @@ limitations under the License. #endif #include "absl/container/flat_hash_map.h" -#include "absl/memory/memory.h" #if GOOGLE_CUDA #include "third_party/nccl/nccl.h" #elif TENSORFLOW_USE_ROCM @@ -77,7 +76,6 @@ class NcclManager { context(static_cast(info->default_context)), #endif input(input), - input_event(nullptr), output(output), global_rank(global_rank), done_callback(std::move(done_callback)), @@ -85,11 +83,6 @@ class NcclManager { DCHECK(executor != nullptr); DCHECK(event_mgr != nullptr); DCHECK(tensor_stream != nullptr); - if (input != nullptr) { - input_event = absl::make_unique(executor); - input_event->Init(); - tensor_stream->ThenRecordEvent(input_event.get()); - } } // StreamExecutor for the device. Expected to be live for process lifetime. @@ -118,10 +111,6 @@ class NcclManager { // called. Is NULL for participants that only receive data. const Tensor* input; - // Wait on this event rather than synchronizing on the entire stream. - // This allows greater concurrency between compute and nccl streams. - std::unique_ptr input_event; - // Owned by the caller, who must keep it live until `done_callback` is // called. Is NULL for participants that only send data. Tensor* output; diff --git a/tensorflow/core/ops/BUILD b/tensorflow/core/ops/BUILD new file mode 100644 index 00000000000..d50d43f0b72 --- /dev/null +++ b/tensorflow/core/ops/BUILD @@ -0,0 +1,539 @@ +# Description: +# Tensorflow default op definitions. 
+ +load( + "//tensorflow:tensorflow.bzl", + "if_chromiumos", + "tf_cc_test", +) + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") + +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + +# For platform specific build config +load( + "//tensorflow/core/platform:build_config.bzl", + "tf_kernel_tests_linkstatic", +) +load( + "//tensorflow/core/platform:rules_cc.bzl", + "cc_library", +) +load( + "//third_party/mkl:build_defs.bzl", + "if_mkl", +) + +# A lot of packages try to minimize binary size by depending on individual ops,\ +# so they need access here. +package( + default_visibility = [ + "//visibility:public", + ], + features = ["-parse_headers"], + licenses = ["notice"], # Apache 2.0 +) + +# Export the BUILD file so automated tooling can check licenses +exports_files([ + "BUILD", + "ops.pbtxt", +]) + +# Generates library per group of ops. +tf_gen_op_libs( + is_external = False, + op_lib_names = [ + "batch_ops", + "bitwise_ops", + "boosted_trees_ops", + "tensor_forest_ops", + "candidate_sampling_ops", + "checkpoint_ops", + "clustering_ops", + "collective_ops", + "control_flow_ops", + "count_ops", + "ctc_ops", + "data_flow_ops", + "dataset_ops", + "decode_proto_ops", + "encode_proto_ops", + "experimental_dataset_ops", + "function_ops", + "functional_ops", + "image_ops", + "io_ops", + "linalg_ops", + "list_ops", + "map_ops", + "lookup_ops", + "manip_ops", + "math_ops", + "mkl_nn_ops", + "nccl_ops", + "nn_ops", + "no_op", + "parsing_ops", + "random_grad", + "random_ops", + "special_math_ops", + "stateful_random_ops", + "remote_fused_graph_ops", + "rnn_ops", + "rpc_ops", + "scoped_allocator_ops", + "sdca_ops", + "set_ops", + "script_ops", + "sendrecv_ops", + "sparse_csr_matrix_ops", + "sparse_ops", + "spectral_ops", + "state_ops", + "stateless_random_ops", + "stateless_random_ops_v2", + "summary_ops", + "training_ops", + ], + sub_directory = "", + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_gen_op_libs( + is_external = False, + op_lib_names = [ + "logging_ops", + ], + sub_directory = "", + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + # TODO(b/162630222): remove this dependency. 
+ "//tensorflow/c/kernels:histogram_summary_op_lib", + "//tensorflow/c/kernels:merge_summary_op_lib", + "//tensorflow/c/kernels:summary_op_lib", + ], +) + +tf_gen_op_libs( + op_lib_names = [ + "string_ops", + ], + sub_directory = "", + deps = [ + "//tensorflow/core:lib_internal", + "//tensorflow/core:lib_proto_parsing", + "@com_google_absl//absl/strings", + ], +) + +tf_gen_op_libs( + op_lib_names = [ + "array_ops", + ], + sub_directory = "", + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_gen_op_libs( + op_lib_names = [ + "mkl_array_ops", + ], + sub_directory = "", + deps = ["//tensorflow/core:protos_all_cc"], +) + +tf_gen_op_libs( + op_lib_names = [ + "audio_ops", + ], + sub_directory = "", + deps = ["//tensorflow/core:lib"], +) + +tf_gen_op_libs( + op_lib_names = ["debug_ops"], + sub_directory = "", + deps = ["//tensorflow/core:lib"], +) + +tf_gen_op_libs( + is_external = False, + op_lib_names = [ + "resource_variable_ops", + ], + sub_directory = "", + deps = ["//tensorflow/core:lib"], +) + +tf_gen_op_libs( + op_lib_names = [ + "tpu_configuration_ops", + "tpu_cross_replica_ops", + "tpu_embedding_ops", + "tpu_embedding_load_retrieve_ops", + "tpu_functional_ops", + "tpu_heartbeat_ops", + "tpu_host_compute_ops", + "tpu_infeed_ops", + "tpu_outfeed_ops", + "tpu_ordinal_selector_ops", + "tpu_replication_ops", + ], + sub_directory = "", + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_cc", + "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc", + "//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils", + "//tensorflow/core/tpu:tpu_embedding_output_layout_utils", + ], +) + +cc_library( + name = "word2vec_ops", + srcs = ["word2vec_ops.cc"], + linkstatic = 1, + deps = ["//tensorflow/core:framework"], + alwayslink = 1, +) + +tf_gen_op_libs( + op_lib_names = [ + "cudnn_rnn_ops", + ], + sub_directory = "", + deps = [ + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "ragged_ops", + deps = [ + "//tensorflow/core:ragged_array_ops_op_lib", + "//tensorflow/core:ragged_conversion_ops_op_lib", + "//tensorflow/core:ragged_math_ops_op_lib", + ], +) + +tf_gen_op_libs( + op_lib_names = [ + "ragged_array_ops", + "ragged_conversion_ops", + "ragged_math_ops", + ], + sub_directory = "", + deps = ["//tensorflow/core/util:ragged_to_dense_util"], +) + +cc_library( + name = "ops", + deps = [ + ":array_ops_op_lib", + ":audio_ops_op_lib", + ":batch_ops_op_lib", + ":bitwise_ops_op_lib", + ":boosted_trees_ops_op_lib", + ":tensor_forest_ops_op_lib", + ":candidate_sampling_ops_op_lib", + ":checkpoint_ops_op_lib", + ":clustering_ops_op_lib", + ":collective_ops_op_lib", + ":control_flow_ops_op_lib", + ":count_ops_op_lib", + ":ctc_ops_op_lib", + ":cudnn_rnn_ops_op_lib", + ":data_flow_ops_op_lib", + ":dataset_ops_op_lib", + ":debug_ops_op_lib", + ":decode_proto_ops_op_lib", + ":encode_proto_ops_op_lib", + ":experimental_dataset_ops_op_lib", + ":function_ops_op_lib", + ":functional_ops_op_lib", + ":image_ops_op_lib", + ":io_ops_op_lib", + ":linalg_ops_op_lib", + ":list_ops_op_lib", + ":map_ops_op_lib", + ":logging_ops_op_lib", + ":lookup_ops_op_lib", + ":manip_ops_op_lib", + ":math_ops_op_lib", + ":nccl_ops_op_lib", + ":nn_ops_op_lib", + ":no_op_op_lib", + ":parsing_ops_op_lib", + ":ragged_ops", + ":random_ops_op_lib", + ":rnn_ops_op_lib", + ":special_math_ops_op_lib", + ":stateful_random_ops_op_lib", + 
":remote_fused_graph_ops_op_lib", + ":resource_variable_ops_op_lib", + ":rpc_ops_op_lib", + ":scoped_allocator_ops_op_lib", + ":script_ops_op_lib", + ":sdca_ops_op_lib", + ":sendrecv_ops_op_lib", + ":set_ops_op_lib", + ":sparse_csr_matrix_ops_op_lib", + ":sparse_ops_op_lib", + ":summary_ops_op_lib", + ":spectral_ops_op_lib", + ":state_ops_op_lib", + ":stateless_random_ops_op_lib", + ":stateless_random_ops_v2_op_lib", + ":string_ops_op_lib", + ":training_ops_op_lib", + ":word2vec_ops", + ] + if_chromiumos( + [], + # Non-tpu platforms don't need tpu dependency. + [ + ":tpu_configuration_ops_op_lib", + ":tpu_cross_replica_ops_op_lib", + ":tpu_embedding_ops_op_lib", + ":tpu_embedding_load_retrieve_ops_op_lib", + ":tpu_functional_ops_op_lib", + ":tpu_heartbeat_ops_op_lib", + ":tpu_host_compute_ops_op_lib", + ":tpu_infeed_ops_op_lib", + ":tpu_outfeed_ops_op_lib", + ":tpu_ordinal_selector_ops_op_lib", + ":tpu_replication_ops_op_lib", + ], + ) + if_mkl([ + ":mkl_array_ops_op_lib", + ":mkl_nn_ops_op_lib", + ]), + alwayslink = 1, +) + +cc_library( + name = "array_grad", + srcs = ["array_grad.cc"], + linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 + deps = [ + ":array_ops_op_lib", + "//tensorflow/c/kernels:bitcast_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +cc_library( + name = "functional_grad", + srcs = ["functional_grad.cc"], + linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 + deps = [ + ":functional_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +cc_library( + name = "math_grad", + srcs = [ + "math_grad.cc", + "random_grad.cc", + "stateless_random_grad.cc", + ], + linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 + deps = [ + ":math_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + +cc_library( + name = "nn_grad", + srcs = ["nn_grad.cc"], + linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ":nn_ops_op_lib", + ] + if_mkl([ + ":mkl_nn_ops_op_lib", + ]), + alwayslink = 1, +) + +filegroup( + name = "portable_op_registrations_and_gradients", + srcs = ["//tensorflow/c/kernels:android_all_ops"] + glob( + [ + "**/*.cc", + "**/*.h", + ], + exclude = [ + "**/*test.cc", + "**/*testutil*", + "**/*testlib*", + "**/*main.cc", + "**/tpu_*", + ], + ), +) + +tf_cc_test( + name = "cudnn_rnn_ops_test_cc", + size = "small", + srcs = [ + "cudnn_rnn_ops_test.cc", + ], + deps = [ + "//tensorflow/core", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "ops_array_grad_test", + size = "small", + srcs = ["array_grad_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":ops", + "//tensorflow/cc:cc_ops", + "//tensorflow/core", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:direct_session_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:array", + "//tensorflow/core/kernels:cwise_op", + 
"//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:math", + "//third_party/eigen3", + ], +) + +tf_cc_test( + name = "ops_math_grad_test", + size = "small", + srcs = ["math_grad_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = ["no_gpu"], + deps = [ + ":ops", + "//tensorflow/cc:cc_ops", + "//tensorflow/core", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:direct_session_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:array", + "//tensorflow/core/kernels:data_flow", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:math", + "//third_party/eigen3", + ], +) + +tf_cc_test( + name = "ops_remote_fused_graph_ops_test", + size = "small", + srcs = ["remote_fused_graph_ops_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":ops", + "//tensorflow/core", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:remote_fused_graph_ops", + ], +) + +tf_cc_test( + name = "ops_tests", + size = "small", + srcs = [ + "array_ops_test.cc", + "candidate_sampling_ops_test.cc", + "control_flow_ops_test.cc", + "ctc_ops_test.cc", + "data_flow_ops_test.cc", + "functional_ops_test.cc", + "image_ops_test.cc", + "io_ops_test.cc", + "linalg_ops_test.cc", + "lookup_ops_test.cc", + "math_ops_test.cc", + "nn_ops_test.cc", + "parsing_ops_test.cc", + "random_ops_test.cc", + "rnn_ops_test.cc", + "set_ops_test.cc", + "shape_function_test.cc", + "sparse_csr_matrix_ops_test.cc", + "sparse_ops_test.cc", + "spectral_ops_test.cc", + "state_ops_test.cc", + "string_ops_test.cc", + "training_ops_test.cc", + ], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":ops", + "//tensorflow/cc:cc_ops", + "//tensorflow/core", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//third_party/eigen3", + ], +) diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index bf3955b354b..2018f793741 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -2808,6 +2808,70 @@ REGISTER_OP("QuantizeAndDequantizeV2") return Status::OK(); }); +REGISTER_OP("QuantizeAndDequantizeV4") + .Input("input: T") + .Input("input_min: T") + .Input("input_max: T") + .Attr("signed_input: bool = true") + .Attr("num_bits: int = 8") + .Attr("range_given: bool = false") + .Output("output: T") + .Attr("T: {bfloat16, half, float, double}") + .Attr( + "round_mode: {'HALF_TO_EVEN', 'HALF_UP'} = " + "'HALF_TO_EVEN'") + .Attr("narrow_range: bool = false") + .Attr("axis: int = -1") + .SetShapeFn([](InferenceContext* c) { + int axis; + TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis)); + const int minmax_rank = (axis == -1) ? 
0 : 1; + ShapeHandle minmax; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), minmax_rank, &minmax)); + TF_RETURN_IF_ERROR(c->Merge(c->input(2), minmax, &minmax)); + if (axis != -1) { + ShapeHandle input; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input)); + DimensionHandle depth; + TF_RETURN_IF_ERROR( + c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth)); + } + c->set_output(0, c->input(0)); + return Status::OK(); + }); + +REGISTER_OP("QuantizeAndDequantizeV4Grad") + .Input("gradients: T") + .Input("input: T") + .Input("input_min: T") + .Input("input_max: T") + .Output("input_backprop: T") + .Output("input_min_backprop: T") + .Output("input_max_backprop: T") + .Attr("T: {bfloat16, half, float, double}") + .Attr("axis: int = -1") + .SetShapeFn([](InferenceContext* c) { + int axis; + TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis)); + const int minmax_rank = (axis == -1) ? 0 : 1; + ShapeHandle minmax; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), minmax_rank, &minmax)); + TF_RETURN_IF_ERROR(c->Merge(c->input(3), minmax, &minmax)); + if (axis != -1) { + ShapeHandle input; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), axis + 1, &input)); + DimensionHandle depth; + TF_RETURN_IF_ERROR( + c->Merge(c->Dim(minmax, 0), c->Dim(input, axis), &depth)); + } + ShapeHandle inputs; + TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs)); + c->set_output(0, inputs); + c->set_output(1, minmax); + c->set_output(2, minmax); + return Status::OK(); + }); + REGISTER_OP("QuantizeAndDequantizeV3") .Input("input: T") .Input("input_min: T") diff --git a/tensorflow/core/ops/collective_ops.cc b/tensorflow/core/ops/collective_ops.cc index fc9010ade34..ecaab00c91a 100644 --- a/tensorflow/core/ops/collective_ops.cc +++ b/tensorflow/core/ops/collective_ops.cc @@ -114,6 +114,7 @@ REGISTER_OP("CollectiveReduceV2") .Attr("merge_op: {'Min', 'Max', 'Mul', 'Add'}") .Attr("final_op: {'Id', 'Div'}") .Attr("communication_hint: string = 'auto'") + .Attr("timeout_seconds: float = 0") .SetIsStateful() .SetShapeFn(shape_inference::UnchangedShape); diff --git a/tensorflow/core/ops/compat/BUILD b/tensorflow/core/ops/compat/BUILD index 7bfe06f2806..aad8b8e550c 100644 --- a/tensorflow/core/ops/compat/BUILD +++ b/tensorflow/core/ops/compat/BUILD @@ -14,19 +14,17 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - cc_library( name = "op_compatibility_lib", srcs = ["op_compatibility_lib.cc"], hdrs = ["op_compatibility_lib.h"], visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:debug_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/ops:debug_ops_op_lib", ], ) @@ -37,7 +35,7 @@ tf_cc_test( "backwards_compatibility_test.cc", ], data = [ - "//tensorflow/core:ops/ops.pbtxt", + "//tensorflow/core/ops:ops.pbtxt", "//tensorflow/core/ops/compat/ops_history_v1:ops_history_v1_srcs", "//tensorflow/core/ops/compat/ops_history_v2:ops_history_v2_srcs", ] + glob([ diff --git a/tensorflow/core/ops/compat/ops_history_v2/CollectiveReduceV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/CollectiveReduceV2.pbtxt index dd39ac27f93..b2751cc59e8 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/CollectiveReduceV2.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/CollectiveReduceV2.pbtxt @@ -64,3 +64,76 @@ op { } is_stateful: true } +op { + name: "CollectiveReduceV2" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "group_size" + type: DT_INT32 + } + 
input_arg { + name: "group_key" + type: DT_INT32 + } + input_arg { + name: "instance_key" + type: DT_INT32 + } + output_arg { + name: "data" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_HALF + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "merge_op" + type: "string" + allowed_values { + list { + s: "Min" + s: "Max" + s: "Mul" + s: "Add" + } + } + } + attr { + name: "final_op" + type: "string" + allowed_values { + list { + s: "Id" + s: "Div" + } + } + } + attr { + name: "communication_hint" + type: "string" + default_value { + s: "auto" + } + } + attr { + name: "timeout_seconds" + type: "float" + default_value { + f: 0 + } + } + is_stateful: true +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/FusedBatchNormGradV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/FusedBatchNormGradV3.pbtxt index b1576ffb772..aa05a575bfe 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/FusedBatchNormGradV3.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/FusedBatchNormGradV3.pbtxt @@ -92,3 +92,99 @@ op { } } } +op { + name: "FusedBatchNormGradV3" + input_arg { + name: "y_backprop" + type_attr: "T" + } + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "scale" + type: DT_FLOAT + } + input_arg { + name: "reserve_space_1" + type_attr: "U" + } + input_arg { + name: "reserve_space_2" + type_attr: "U" + } + input_arg { + name: "reserve_space_3" + type_attr: "U" + } + output_arg { + name: "x_backprop" + type_attr: "T" + } + output_arg { + name: "scale_backprop" + type_attr: "U" + } + output_arg { + name: "offset_backprop" + type_attr: "U" + } + output_arg { + name: "reserve_space_4" + type_attr: "U" + } + output_arg { + name: "reserve_space_5" + type_attr: "U" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + } + } + } + attr { + name: "U" + type: "type" + allowed_values { + list { + type: DT_FLOAT + } + } + } + attr { + name: "epsilon" + type: "float" + default_value { + f: 0.0001 + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + s: "NDHWC" + s: "NCDHW" + } + } + } + attr { + name: "is_training" + type: "bool" + default_value { + b: true + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/FusedBatchNormV3.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/FusedBatchNormV3.pbtxt index 6df70795677..834ff00a0b3 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/FusedBatchNormV3.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/FusedBatchNormV3.pbtxt @@ -99,3 +99,106 @@ op { } } } +op { + name: "FusedBatchNormV3" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "scale" + type_attr: "U" + } + input_arg { + name: "offset" + type_attr: "U" + } + input_arg { + name: "mean" + type_attr: "U" + } + input_arg { + name: "variance" + type_attr: "U" + } + output_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "batch_mean" + type_attr: "U" + } + output_arg { + name: "batch_variance" + type_attr: "U" + } + output_arg { + name: "reserve_space_1" + type_attr: "U" + } + output_arg { + name: "reserve_space_2" + type_attr: "U" + } + output_arg { + name: "reserve_space_3" + type_attr: "U" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + } + } + } + attr { + name: "U" + type: "type" + allowed_values { + list { 
+ type: DT_FLOAT + } + } + } + attr { + name: "epsilon" + type: "float" + default_value { + f: 0.0001 + } + } + attr { + name: "exponential_avg_factor" + type: "float" + default_value { + f: 1 + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + s: "NDHWC" + s: "NCDHW" + } + } + } + attr { + name: "is_training" + type: "bool" + default_value { + b: true + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/QuantizeAndDequantizeV4.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/QuantizeAndDequantizeV4.pbtxt new file mode 100644 index 00000000000..2a49131faaf --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/QuantizeAndDequantizeV4.pbtxt @@ -0,0 +1,79 @@ +op { + name: "QuantizeAndDequantizeV4" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "input_min" + type_attr: "T" + } + input_arg { + name: "input_max" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "signed_input" + type: "bool" + default_value { + b: true + } + } + attr { + name: "num_bits" + type: "int" + default_value { + i: 8 + } + } + attr { + name: "range_given" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "round_mode" + type: "string" + default_value { + s: "HALF_TO_EVEN" + } + allowed_values { + list { + s: "HALF_TO_EVEN" + s: "HALF_UP" + } + } + } + attr { + name: "narrow_range" + type: "bool" + default_value { + b: false + } + } + attr { + name: "axis" + type: "int" + default_value { + i: -1 + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/QuantizeAndDequantizeV4Grad.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/QuantizeAndDequantizeV4Grad.pbtxt new file mode 100644 index 00000000000..0bbe87452b1 --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/QuantizeAndDequantizeV4Grad.pbtxt @@ -0,0 +1,50 @@ +op { + name: "QuantizeAndDequantizeV4Grad" + input_arg { + name: "gradients" + type_attr: "T" + } + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "input_min" + type_attr: "T" + } + input_arg { + name: "input_max" + type_attr: "T" + } + output_arg { + name: "input_backprop" + type_attr: "T" + } + output_arg { + name: "input_min_backprop" + type_attr: "T" + } + output_arg { + name: "input_max_backprop" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "axis" + type: "int" + default_value { + i: -1 + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/SnapshotDatasetV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/SnapshotDatasetV2.pbtxt index f0868514182..4356d954c8f 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/SnapshotDatasetV2.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/SnapshotDatasetV2.pbtxt @@ -58,3 +58,77 @@ op { has_minimum: true } } +op { + name: "SnapshotDatasetV2" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "path" + type: DT_STRING + } + input_arg { + name: "reader_func_other_args" + type_list_attr: "Treader_func_args" + } + input_arg { + name: "shard_func_other_args" + type_list_attr: "Tshard_func_args" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + 
has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "compression" + type: "string" + default_value { + s: "" + } + } + attr { + name: "reader_prefix" + type: "string" + default_value { + s: "" + } + } + attr { + name: "writer_prefix" + type: "string" + default_value { + s: "" + } + } + attr { + name: "reader_func" + type: "func" + } + attr { + name: "shard_func" + type: "func" + } + attr { + name: "Treader_func_args" + type: "list(type)" + has_minimum: true + } + attr { + name: "Tshard_func_args" + type: "list(type)" + has_minimum: true + } +} diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc index 2d012fdfa3e..9838398a6fd 100644 --- a/tensorflow/core/ops/experimental_dataset_ops.cc +++ b/tensorflow/core/ops/experimental_dataset_ops.cc @@ -966,6 +966,8 @@ REGISTER_OP("SnapshotDatasetV2") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") .Attr("compression: string = ''") + .Attr("reader_prefix: string = ''") + .Attr("writer_prefix: string = ''") .Attr("reader_func: func") .Attr("shard_func: func") .Attr("Treader_func_args: list(type) >= 0") diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc index 4ad676c37ea..91bcc3be49a 100644 --- a/tensorflow/core/ops/list_ops.cc +++ b/tensorflow/core/ops/list_ops.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { @@ -369,7 +370,7 @@ REGISTER_OP("TensorListFromTensor") &tensor_shape_except_first_dim)); c->set_output_handle_shapes_and_types( 0, std::vector{ - {element_shape, element_dtype}}); + {element_shape, element_dtype, ST_TENSOR_LIST}}); return Status::OK(); }); @@ -409,7 +410,7 @@ REGISTER_OP("TensorListReserve") TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype)); c->set_output_handle_shapes_and_types( 0, std::vector{ - {element_shape, element_dtype}}); + {element_shape, element_dtype, ST_TENSOR_LIST}}); return Status::OK(); }); @@ -481,7 +482,7 @@ REGISTER_OP("TensorListSetItem") c->set_output_handle_shapes_and_types(0, *handle_data); } else { c->set_output_handle_shapes_and_types( - 0, {{c->UnknownShape(), element_dtype}}); + 0, {{c->UnknownShape(), element_dtype, ST_TENSOR_LIST}}); } return Status::OK(); }); @@ -532,8 +533,8 @@ REGISTER_OP("TensorListScatter") shape_inference::ShapeHandle element_shape; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( 2, &element_shape)); - c->set_output_handle_shapes_and_types(0, - {{element_shape, element_dtype}}); + c->set_output_handle_shapes_and_types( + 0, {{element_shape, element_dtype, ST_TENSOR_LIST}}); c->set_output(0, c->Scalar()); return Status::OK(); }); @@ -552,8 +553,8 @@ REGISTER_OP("TensorListScatterV2") shape_inference::ShapeHandle element_shape; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( 2, &element_shape)); - c->set_output_handle_shapes_and_types(0, - {{element_shape, element_dtype}}); + c->set_output_handle_shapes_and_types( + 0, {{element_shape, element_dtype, ST_TENSOR_LIST}}); c->set_output(0, c->Scalar()); return Status::OK(); }); @@ -580,8 +581,8 @@ REGISTER_OP("TensorListScatterIntoExistingList") TF_RETURN_IF_ERROR(VerifyHandleData(c, *handle_data, element_dtype)); element_shape = 
GetElementShapeFromHandleData(*handle_data); } - c->set_output_handle_shapes_and_types(0, - {{element_shape, element_dtype}}); + c->set_output_handle_shapes_and_types( + 0, {{element_shape, element_dtype, ST_TENSOR_LIST}}); c->set_output(0, c->Scalar()); return Status::OK(); }); @@ -606,7 +607,7 @@ REGISTER_OP("TensorListConcatLists") bool handle_data_b_nonempty = handle_data_b && !handle_data_b->empty(); if (!(handle_data_a_nonempty || handle_data_b_nonempty)) { c->set_output_handle_shapes_and_types( - 0, {{c->UnknownShape(), element_dtype}}); + 0, {{c->UnknownShape(), element_dtype, ST_TENSOR_LIST}}); return Status::OK(); } shape_inference::ShapeAndType list_shape_type_a = diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc index f99c1f6922a..8948df2cef3 100644 --- a/tensorflow/core/ops/lookup_ops.cc +++ b/tensorflow/core/ops/lookup_ops.cc @@ -87,10 +87,35 @@ REGISTER_OP("LookupTableFind") return Status::OK(); }); +Status ValidateTableType(InferenceContext* c, + const ShapeAndType& key_shape_and_type, + const string& key_dtype_attr, + const ShapeAndType& value_shape_and_type, + const string& value_dtype_attr) { + DataType key_dtype; + TF_RETURN_IF_ERROR(c->GetAttr(key_dtype_attr, &key_dtype)); + if (key_shape_and_type.dtype != key_dtype) { + return errors::InvalidArgument( + "Trying to read value with wrong dtype. " + "Expected ", + DataTypeString(key_shape_and_type.dtype), " got ", + DataTypeString(key_dtype)); + } + DataType value_dtype; + TF_RETURN_IF_ERROR(c->GetAttr(value_dtype_attr, &value_dtype)); + if (value_shape_and_type.dtype != value_dtype) { + return errors::InvalidArgument( + "Trying to read value with wrong dtype. " + "Expected ", + DataTypeString(value_shape_and_type.dtype), " got ", + DataTypeString(value_dtype)); + } + return Status::OK(); +} + Status ValidateTableResourceHandle(InferenceContext* c, ShapeHandle keys, const string& key_dtype_attr, const string& value_dtype_attr, - bool is_lookup, ShapeAndType* output_shape_and_type) { auto* handle_data = c->input_handle_shapes_and_types(0); if (handle_data == nullptr || handle_data->size() != 2) { @@ -99,57 +124,35 @@ Status ValidateTableResourceHandle(InferenceContext* c, ShapeHandle keys, } else { const ShapeAndType& key_shape_and_type = (*handle_data)[0]; const ShapeAndType& value_shape_and_type = (*handle_data)[1]; - DataType key_dtype; - TF_RETURN_IF_ERROR(c->GetAttr(key_dtype_attr, &key_dtype)); - if (key_shape_and_type.dtype != key_dtype) { - return errors::InvalidArgument( - "Trying to read value with wrong dtype. " - "Expected ", - DataTypeString(key_shape_and_type.dtype), " got ", - DataTypeString(key_dtype)); - } - DataType value_dtype; - TF_RETURN_IF_ERROR(c->GetAttr(value_dtype_attr, &value_dtype)); - if (value_shape_and_type.dtype != value_dtype) { - return errors::InvalidArgument( - "Trying to read value with wrong dtype. 
" - "Expected ", - DataTypeString(value_shape_and_type.dtype), " got ", - DataTypeString(value_dtype)); - } + TF_RETURN_IF_ERROR(ValidateTableType(c, key_shape_and_type, key_dtype_attr, + value_shape_and_type, + value_dtype_attr)); output_shape_and_type->dtype = value_shape_and_type.dtype; - - if (is_lookup) { - if (c->RankKnown(key_shape_and_type.shape) && c->RankKnown(keys)) { - int keys_rank = c->Rank(keys); - int key_suffix_rank = c->Rank(key_shape_and_type.shape); - if (keys_rank < key_suffix_rank) { - return errors::InvalidArgument( - "Expected keys to have suffix ", - c->DebugString(key_shape_and_type.shape), - " but saw shape: ", c->DebugString(keys)); - } - for (int d = 0; d < key_suffix_rank; d++) { - // Ensure the suffix of keys match what's in the Table. - DimensionHandle dim = c->Dim(key_shape_and_type.shape, d); - TF_RETURN_IF_ERROR( - c->ReplaceDim(keys, keys_rank - key_suffix_rank + d, dim, &keys)); - } - std::vector keys_prefix_vec; - keys_prefix_vec.reserve(keys_rank - key_suffix_rank); - for (int d = 0; d < keys_rank - key_suffix_rank; ++d) { - keys_prefix_vec.push_back(c->Dim(keys, d)); - } - ShapeHandle keys_prefix = c->MakeShape(keys_prefix_vec); - TF_RETURN_IF_ERROR(c->Concatenate(keys_prefix, - value_shape_and_type.shape, - &output_shape_and_type->shape)); - } else { - output_shape_and_type->shape = c->UnknownShape(); + if (c->RankKnown(key_shape_and_type.shape) && c->RankKnown(keys)) { + int keys_rank = c->Rank(keys); + int key_suffix_rank = c->Rank(key_shape_and_type.shape); + if (keys_rank < key_suffix_rank) { + return errors::InvalidArgument( + "Expected keys to have suffix ", + c->DebugString(key_shape_and_type.shape), + " but saw shape: ", c->DebugString(keys)); } - } else { - TF_RETURN_IF_ERROR(c->Concatenate(keys, value_shape_and_type.shape, + for (int d = 0; d < key_suffix_rank; d++) { + // Ensure the suffix of keys match what's in the Table. 
+ DimensionHandle dim = c->Dim(key_shape_and_type.shape, d); + TF_RETURN_IF_ERROR( + c->ReplaceDim(keys, keys_rank - key_suffix_rank + d, dim, &keys)); + } + std::vector keys_prefix_vec; + keys_prefix_vec.reserve(keys_rank - key_suffix_rank); + for (int d = 0; d < keys_rank - key_suffix_rank; ++d) { + keys_prefix_vec.push_back(c->Dim(keys, d)); + } + ShapeHandle keys_prefix = c->MakeShape(keys_prefix_vec); + TF_RETURN_IF_ERROR(c->Concatenate(keys_prefix, value_shape_and_type.shape, &output_shape_and_type->shape)); + } else { + output_shape_and_type->shape = c->UnknownShape(); } } return Status::OK(); @@ -175,8 +178,7 @@ REGISTER_OP("LookupTableFindV2") c, /*keys=*/c->input(1), /*key_dtype_attr=*/"Tin", - /*value_dtype_attr=*/"Tout", - /*is_lookup=*/true, &value_shape_and_type)); + /*value_dtype_attr=*/"Tout", &value_shape_and_type)); c->set_output(0, value_shape_and_type.shape); return Status::OK(); @@ -268,16 +270,18 @@ REGISTER_OP("LookupTableExportV2") .SetShapeFn([](InferenceContext* c) { ShapeHandle handle; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - ShapeHandle keys = c->UnknownShapeOfRank(1); - ShapeAndType value_shape_and_type; - TF_RETURN_IF_ERROR(ValidateTableResourceHandle( - c, - /*keys=*/keys, - /*key_dtype_attr=*/"Tkeys", - /*value_dtype_attr=*/"Tvalues", - /*is_lookup=*/false, &value_shape_and_type)); - c->set_output(0, keys); - c->set_output(1, value_shape_and_type.shape); + auto* handle_data = c->input_handle_shapes_and_types(0); + if (handle_data != nullptr && handle_data->size() == 2) { + const ShapeAndType& key_shape_and_type = (*handle_data)[0]; + const ShapeAndType& value_shape_and_type = (*handle_data)[1]; + TF_RETURN_IF_ERROR(ValidateTableType(c, key_shape_and_type, + /*key_dtype_attr*/ "Tkeys", + value_shape_and_type, + /*value_dtype_attr*/ "Tvalues")); + } + // Different lookup tables have different output shapes. + c->set_output(0, c->UnknownShape()); + c->set_output(1, c->UnknownShape()); return Status::OK(); }); diff --git a/tensorflow/core/ops/lookup_ops_test.cc b/tensorflow/core/ops/lookup_ops_test.cc new file mode 100644 index 00000000000..ac899d59993 --- /dev/null +++ b/tensorflow/core/ops/lookup_ops_test.cc @@ -0,0 +1,88 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); + +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(LookupOpsTest, LookupTableFindV2_ShapeFn) { + ShapeInferenceTestOp op("LookupTableFindV2"); + INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[?];?;?"); + INFER_ERROR("Shape must be at most rank 1 but is rank 2", op, "[];?;[1,1]"); + TF_ASSERT_OK(NodeDefBuilder("test", "LookupTableFindV2") + .Input({"table_handle", 0, DT_RESOURCE}) + .Input({"keys", 0, DT_INT64}) + .Input({"default_value", 0, DT_FLOAT}) + .Attr("Tin", DT_INT64) + .Attr("Tout", DT_FLOAT) + .Finalize(&op.node_def)); + std::vector> types; + auto set_types = [&op, &types](DataType key_type, DataType value_type) { + types.emplace_back(); + auto& table = types.back(); + table.emplace_back("[3]", key_type); + table.emplace_back("[4]", value_type); + op.input_resource_handle_shapes_and_types = {&table, nullptr, nullptr}; + }; + // If there's no input handle shapes and types, output shape is unknown. + INFER_OK(op, "[];[?,3];[4]", "?"); + // Set input handle with mismatched key type. + set_types(DT_INT32, DT_FLOAT); + INFER_ERROR("read value with wrong dtype", op, "[];[?,3];[4]"); + // Set input handle with mismatched value type. + set_types(DT_INT64, DT_INT64); + INFER_ERROR("read value with wrong dtype", op, "[];[?,3];[4]"); + // Set input handle with matched types. + set_types(DT_INT64, DT_FLOAT); + INFER_OK(op, "[];[?,3];[4]", "[d1_0,4]"); + INFER_OK(op, "[];[1,3];[4]", "[d1_0,4]"); + INFER_OK(op, "[];[1,?];[4]", "[d1_0,4]"); +} + +TEST(LookupOpsTest, LookupTableExportV2_ShapeFn) { + ShapeInferenceTestOp op("LookupTableExportV2"); + TF_ASSERT_OK(NodeDefBuilder("test", "LookupTableExportV2") + .Input({"table_handle", 0, DT_RESOURCE}) + .Attr("Tkeys", DT_INT64) + .Attr("Tvalues", DT_FLOAT) + .Finalize(&op.node_def)); + std::vector> types; + auto set_types = [&op, &types](DataType key_type, DataType value_type) { + types.emplace_back(); + auto& table = types.back(); + table.emplace_back("[3]", key_type); + table.emplace_back("[4]", value_type); + op.input_resource_handle_shapes_and_types = {&table}; + }; + // Set input handle with mismatched key type. + set_types(DT_INT32, DT_FLOAT); + INFER_ERROR("read value with wrong dtype", op, "[]"); + // Set input handle with mismatched value type. + set_types(DT_INT64, DT_INT64); + INFER_ERROR("read value with wrong dtype", op, "[]"); + // Set input handle with matched types. + set_types(DT_INT64, DT_FLOAT); + INFER_OK(op, "[]", "?;?"); +} + +// TODO(b/169969017): add shape fn tests for rest of the ops. + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/ops/mkl_nn_ops.cc b/tensorflow/core/ops/mkl_nn_ops.cc index 53a0682dc46..1604527b941 100644 --- a/tensorflow/core/ops/mkl_nn_ops.cc +++ b/tensorflow/core/ops/mkl_nn_ops.cc @@ -186,6 +186,52 @@ REGISTER_OP("_MklFusedConv2D") is expected to create these operators. 
)doc"); +REGISTER_OP("_MklNativeFusedConv2D") + .Input("input: T") + .Input("filter: T") + .Input("args: num_args * T") + .Output("output: T") + .Attr("T: {bfloat16, float}") + .Attr("num_args: int >= 0") + .Attr("strides: list(int)") + .Attr("is_filter_const: bool = false") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Attr(GetExplicitPaddingsAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("fused_ops: list(string) = []") + // Attributes for the FusedBatchNorm ------------------------------------ // + .Attr("epsilon: float = 0.0001") + // Attributes for the LeakyRelu ----------------------------------------- // + .Attr("leakyrelu_alpha: float = 0.2") + // ---------------------------------------------------------------------- // + .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding) + .Doc(R"doc( +*NOTE*: Do not invoke this operator directly in Python. oneDNN graph transformer + is expected to create these operators. +)doc"); + +REGISTER_OP("_MklNativeConv2DWithBias") + .Input("input: T") + .Input("filter: T") + .Input("bias: T") + .Output("output: T") + .Attr("T: {bfloat16, float}") + .Attr("strides: list(int)") + .Attr("use_cudnn_on_gpu: bool = true") + .Attr("is_filter_const: bool = false") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .SetShapeFn(shape_inference::Conv2DShape) + .Doc(R"doc( +MKL version of Conv2D and BiasAdd operator. Uses oneDNN APIs to perform +2D convolution and add Bias to the output of convolution. + +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke this operator. +)doc"); + REGISTER_OP("_MklFusedDepthwiseConv2dNative") .Input("input: T") .Input("filter: T") @@ -212,6 +258,26 @@ REGISTER_OP("_MklFusedDepthwiseConv2dNative") // ---------------------------------------------------------------------- // .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape); +REGISTER_OP("_MklNativeFusedDepthwiseConv2dNative") + .Input("input: T") + .Input("filter: T") + .Input("args: num_args * T") + .Output("output: T") + .Attr("T: {bfloat16, float}") + .Attr("num_args: int >= 0") + .Attr("strides: list(int)") + .Attr("is_filter_const: bool = false") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("fused_ops: list(string) = []") + // Attributes for the FusedBatchNorm ------------------------------------ // + .Attr("epsilon: float = 0.0001") + // Attributes for the LeakyRelu ----------------------------------------- // + .Attr("leakyrelu_alpha: float = 0.2") + // ---------------------------------------------------------------------- // + .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape); + REGISTER_OP("_MklFusedMatMul") .Input("a: T") .Input("b: T") @@ -239,6 +305,29 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. 
)doc"); +REGISTER_OP("_MklNativeFusedMatMul") + .Input("a: T") + .Input("b: T") + .Input("args: num_args * T") + .Output("product: T") + .Attr("is_filter_const: bool = false") + .Attr("transpose_a: bool = false") + .Attr("transpose_b: bool = false") + .Attr("T: {bfloat16, float}") + .Attr("num_args: int >= 0") + .Attr("fused_ops: list(string) = []") + // Attributes for the FusedBatchNorm ----------- // + .Attr("epsilon: float = 0.0001") + // --------------------------------------------- // + .SetShapeFn(shape_inference::MatMulShape) + .Doc(R"doc( +oneDNN version of FusedMatMul operator that does not depend +on layout propagation. Uses oneDNN APIs to implement MatMul fusion. + +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke this one. +)doc"); + REGISTER_OP("__MklDummyPadWithFusedConv2D") .Input("input: T") .Input("filter: T") @@ -300,6 +389,205 @@ REGISTER_OP("_MklPadWithFusedConv2D") is expected to create these operators. )doc"); +REGISTER_OP("_MklNativePadWithFusedConv2D") + .Input("input: T") + .Input("filter: T") + .Input("args: num_args * T") + .Input("paddings: Tpaddings") + .Output("output: T") + .Attr("T: {bfloat16, float}") + .Attr("num_args: int >= 0") + .Attr("strides: list(int)") + .Attr("is_filter_const: bool = false") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("fused_ops: list(string) = []") + .Attr("Tpaddings: {int32, int64} = DT_INT32") + // Attributes for the FusedBatchNorm ------------------------------------ // + .Attr("epsilon: float = 0.0001") + // Attributes for the LeakyRelu ----------------------------------------- // + .Attr("leakyrelu_alpha: float = 0.2") + // ---------------------------------------------------------------------- // + .SetShapeFn(shape_inference::Conv2DShape) + .Doc(R"doc( +*NOTE*: Do not invoke this operator directly in Python. oneDNN graph transformer + is expected to create these operators. +)doc"); + +REGISTER_OP("_MklNativePadWithConv2D") + .Input("input: T") + .Input("filter: T") + .Input("paddings: Tpaddings") + .Output("output: T") + .Attr("T: {bfloat16, float}") + .Attr("strides: list(int)") + .Attr("use_cudnn_on_gpu: bool = true") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Attr("is_filter_const: bool = false") + .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("Tpaddings: {int32, int64} = DT_INT32") + .SetShapeFn(shape_inference::Conv2DShape) + .Doc(R"doc( +MKL version of Pad and Conv2D fusion that does not depend +on layout propagation. Uses oneDNN APIs to perform +the fusion. + +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("_MklNativeAvgPool") + .Input("value: T") + .Output("output: T") + .Attr("ksize: list(int) >= 4") + .Attr("strides: list(int) >= 4") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Attr("T: {float, half, double, bfloat16}") + .SetShapeFn(shape_inference::AvgPoolShape) + .Doc(R"doc( +oneDNN version of AvgPool operator that does not depend on layout +propagation. Uses oneDNN APIs to perform average pooling on the input. + +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. 
+)doc"); + +REGISTER_OP("_MklNativeAvgPoolGrad") + .Input("orig_input_shape: int32") + .Input("grad: T") + .Output("output: T") + .Attr("ksize: list(int) >= 4") + .Attr("strides: list(int) >= 4") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Attr("T: {float, half, double, bfloat16}") + .SetShapeFn(shape_inference::AvgPoolGradShape) + .Doc(R"doc( +oneDNN version of AvgPoolGrad operator that does not depend on layout +propagation. Uses oneDNN APIs to compute gradients of AvgPool operator. + +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("_MklNativeAvgPool3D") + .Input("value: T") + .Output("output: T") + .Attr("ksize: list(int) >= 5") + .Attr("strides: list(int) >= 5") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnet3dDataFormatAttrString()) + .Attr("T: {float, half, double, bfloat16}") + .SetShapeFn(shape_inference::Pool3DShape) + .Doc(R"doc( +oneDNN version of AvgPool3D operator that does not depend on layout +propagation. Uses oneDNN APIs to perform 3D average pooling on the input. + +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("_MklNativeAvgPool3DGrad") + .Input("orig_input_shape: int32") + .Input("grad: T") + .Output("output: T") + .Attr("ksize: list(int) >= 5") + .Attr("strides: list(int) >= 5") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnet3dDataFormatAttrString()) + .Attr("T: {float, half, double, bfloat16}") + .SetShapeFn(shape_inference::AvgPool3DGradShape) + .Doc(R"doc( +oneDNN version of AvgPool3DGrad operator that does not depend on layout +propagation. Uses oneDNN APIs to compute gradients of AvgPool3D function. + +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("_MklNativeMaxPool") + .Attr("T: {float, half, bfloat16} = DT_FLOAT") + .Attr("ksize: list(int) >= 4") + .Attr("strides: list(int) >= 4") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Attr(GetExplicitPaddingsAttrString()) + .Attr("workspace_enabled: bool = false") + .Input("input: T") + .Output("output: T") + .Output("workspace: uint8") + .SetShapeFn(shape_inference::MaxPoolShape) + .Doc(R"doc( +oneDNN version of MaxPool operator that does not depend +on layout propagation. Uses oneDNN APIs to perform max pooling +on the input. +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("_MklNativeMaxPoolGrad") + .Attr("T: {float, half, bfloat16} = DT_FLOAT") + .Attr("ksize: list(int) >= 4") + .Attr("strides: list(int) >= 4") + .Attr("workspace_enabled: bool = false") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Attr(GetExplicitPaddingsAttrString()) + .Input("orig_input: T") + .Input("orig_output: T") + .Input("grad: T") + .Input("workspace: uint8") + .Output("output: T") + .SetShapeFn(shape_inference::MaxPoolGradShape) + .Doc(R"doc( +oneDNN version of MaxPoolGrad that does not depend on layout propagation. +Uses oneDNN APIs to compute gradients of MaxPool operator. +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. 
+)doc"); + +REGISTER_OP("_MklNativeMaxPool3D") + .Input("input: T") + .Output("output: T") + .Output("workspace: uint8") + .Attr("ksize: list(int) >= 5") + .Attr("strides: list(int) >= 5") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnet3dDataFormatAttrString()) + .Attr("T: {half, bfloat16, float}") + .Attr("workspace_enabled: bool = false") + .SetShapeFn(shape_inference::Pool3DShape) + .Doc(R"doc( +oneDNN version of MaxPool3D operator that does not depend on layout propagation. +Uses oneDNN APIs to perform 3D max pooling on the input. +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("_MklNativeMaxPool3DGrad") + .Input("orig_input: TInput") + .Input("orig_output: TInput") + .Input("grad: T") + .Input("workspace: uint8") + .Output("output: T") + .Attr("ksize: list(int) >= 5") + .Attr("strides: list(int) >= 5") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnet3dDataFormatAttrString()) + .Attr("T: {half, bfloat16, float} = DT_FLOAT") + .Attr("TInput: {half, bfloat16, float} = DT_FLOAT") + .Attr("workspace_enabled: bool = false") + .SetShapeFn(shape_inference::MaxPool3DGradShape) + .Doc(R"doc( +oneDNN version of MaxPool3DGrad operator that does not depend on layout +propagation. Uses oneDNN APIs to compute gradients of MaxPool3D function. +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + REGISTER_OP("_MklQuantizedMaxPool") .Input("input: T") .Input("min_input: float") @@ -1542,6 +1830,170 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); +REGISTER_OP("_MklNativeFusedBatchNorm") + .Input("x: T") + .Input("scale: T") + .Input("offset: T") + .Input("mean: T") + .Input("variance: T") + .Output("y: T") + .Output("batch_mean: T") + .Output("batch_variance: T") + .Output("reserve_space_1: T") + .Output("reserve_space_2: T") + .Attr("T: numbertype") + .Attr("epsilon: float = 0.0001") + .Attr("data_format: string = 'NHWC'") + .Attr("exponential_avg_factor: float = 1.0") + .Attr("is_training: bool = true") + .SetShapeFn(shape_inference::FusedBatchNormShape) + .Doc(R"doc( +oneDNN version of FusedBatchNorm operator that does not depend on layout +propagation. Uses oneDNN APIs to perform fused batch normalization. + +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("_MklNativeFusedBatchNormGrad") + .Input("y_backprop: T") + .Input("x: T") + .Input("scale: T") + .Input("reserve_space_1: T") + .Input("reserve_space_2: T") + .Output("x_backprop: T") + .Output("scale_backprop: T") + .Output("offset_backprop: T") + .Output("reserve_space_3: T") + .Output("reserve_space_4: T") + .Attr("T: numbertype") + .Attr("epsilon: float = 0.0001") + .Attr("data_format: string = 'NHWC'") + .Attr("is_training: bool = true") + .SetShapeFn(shape_inference::FusedBatchNormGradShape) + .Doc(R"doc( +oneDNN version of FusedBatchNormGrad operator that does not depend +on layout propagation. Uses oneDNN APIs to compute gradients for fused +batch normalization. + +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. 
+)doc"); + +REGISTER_OP("_MklNativeFusedBatchNormV2") + .Input("x: T") + .Input("scale: U") + .Input("offset: U") + .Input("mean: U") + .Input("variance: U") + .Output("y: T") + .Output("batch_mean: U") + .Output("batch_variance: U") + .Output("reserve_space_1: U") + .Output("reserve_space_2: U") + .Attr("T: {bfloat16, float}") + .Attr("U: {float}") + .Attr("epsilon: float = 0.0001") + .Attr(GetConvnetDataFormatAttrString()) + .Attr("exponential_avg_factor: float = 1.0") + .Attr("is_training: bool = true") + .SetShapeFn(shape_inference::FusedBatchNormShape); + +REGISTER_OP("_MklNativeFusedBatchNormGradV2") + .Input("y_backprop: T") + .Input("x: T") + .Input("scale: float") + .Input("reserve_space_1: U") + .Input("reserve_space_2: U") + .Output("x_backprop: T") + .Output("scale_backprop: U") + .Output("offset_backprop: U") + .Output("reserve_space_3: U") + .Output("reserve_space_4: U") + .Attr("T: {bfloat16, float}") + .Attr("U: {float}") + .Attr("epsilon: float = 0.0001") + .Attr(GetConvnetDataFormatAttrString()) + .Attr("is_training: bool = true") + .SetShapeFn(shape_inference::FusedBatchNormGradShape); + +REGISTER_OP("_MklNativeFusedBatchNormV3") + .Input("x: T") + .Input("scale: U") + .Input("offset: U") + .Input("mean: U") + .Input("variance: U") + .Output("y: T") + .Output("batch_mean: U") + .Output("batch_variance: U") + .Output("reserve_space_1: U") + .Output("reserve_space_2: U") + .Output("reserve_space_3: U") + .Attr("T: {half, bfloat16, float}") + .Attr("U: {float}") + .Attr("epsilon: float = 0.0001") + .Attr(GetConvnetDataFormatAttrString()) + .Attr("exponential_avg_factor: float = 1.0") + .Attr("is_training: bool = true") + .SetShapeFn(shape_inference::FusedBatchNormShape) + .Doc( + R"doc(oneDNN version of FusedBatchNormV3 operator that does not depend + on layout propagation. Do not invoke this operator directly in Python. + Graph rewrite pass is expected to invoke this operator.)doc"); + +REGISTER_OP("_MklNativeFusedBatchNormGradV3") + .Input("y_backprop: T") + .Input("x: T") + .Input("scale: float") + .Input("reserve_space_1: U") + .Input("reserve_space_2: U") + .Input("reserve_space_3: U") + .Output("x_backprop: T") + .Output("scale_backprop: U") + .Output("offset_backprop: U") + .Output("reserve_space_4: U") + .Output("reserve_space_5: U") + .Attr("T: {half, bfloat16, float}") + .Attr("U: {float}") + .Attr("epsilon: float = 0.0001") + .Attr(GetConvnetDataFormatAttrString()) + .Attr("is_training: bool = true") + .SetShapeFn(shape_inference::FusedBatchNormGradShape) + .Doc( + R"doc(oneDNN version of FusedBatchNormGradV3 that does not depend + on layout propagation. Do not invoke this operator directly in Python. 
+ Graph rewrite pass is expected to invoke this operator.)doc"); + +REGISTER_OP("_MklNativeFusedBatchNormEx") + .Input("x: T") + .Input("scale: U") + .Input("offset: U") + .Input("mean: U") + .Input("variance: U") + .Input("side_input: num_side_inputs * T") + .Output("y: T") + .Output("batch_mean: U") + .Output("batch_variance: U") + .Output("reserve_space_1: U") + .Output("reserve_space_2: U") + .Output("reserve_space_3: U") + .Attr("T: {bfloat16, float}") + .Attr("U: {float}") + .Attr("epsilon: float = 0.0001") + .Attr("exponential_avg_factor: float = 1.0") + .Attr(GetConvnetDataFormatAttrString()) + .Attr("num_side_inputs: int >= 0 = 0") + .Attr("activation_mode: string = \"Identity\"") + .Attr("is_training: bool = true") + .SetShapeFn(shape_inference::FusedBatchNormShape) + .Doc(R"doc( +oneDNN version of FusedBatchNormEx operator that does not depend on layout propagation. +Uses oneDNN APIs to perform fused batch normalization and relu. + +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 2b6330db4aa..aeef3f6cf89 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -88,13 +88,7 @@ REGISTER_OP("AvgPoolGrad") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) .Attr("T: {half, bfloat16, float, double}") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle s; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); - TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); - c->set_output(0, s); - return Status::OK(); - }); + .SetShapeFn(shape_inference::AvgPoolGradShape); // -------------------------------------------------------------------------- @@ -221,7 +215,7 @@ REGISTER_OP("FusedBatchNormV3") .Attr("U: {float}") .Attr("epsilon: float = 0.0001") .Attr("exponential_avg_factor: float = 1.0") - .Attr(GetConvnetDataFormatAttrString()) + .Attr(GetConvnetDataFormat2D3DAttrString()) .Attr("is_training: bool = true") .SetShapeFn(shape_inference::FusedBatchNormV3Shape); @@ -308,7 +302,7 @@ REGISTER_OP("FusedBatchNormGradV3") .Attr("T: {half, bfloat16, float}") .Attr("U: {float}") .Attr("epsilon: float = 0.0001") - .Attr(GetConvnetDataFormatAttrString()) + .Attr(GetConvnetDataFormat2D3DAttrString()) .Attr("is_training: bool = true") .SetShapeFn(shape_inference::FusedBatchNormGradShape); // -------------------------------------------------------------------------- @@ -742,13 +736,7 @@ REGISTER_OP("AvgPool3DGrad") .Attr(GetPaddingAttrString()) .Attr(GetConvnet3dDataFormatAttrString()) .Attr("T: {half, bfloat16, float, double}") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle s; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); - TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); - c->set_output(0, s); - return Status::OK(); - }); + .SetShapeFn(shape_inference::AvgPool3DGradShape); // -------------------------------------------------------------------------- @@ -773,9 +761,7 @@ REGISTER_OP("MaxPool3DGrad") .Attr(GetConvnet3dDataFormatAttrString()) .Attr("T: {half, bfloat16, float} = DT_FLOAT") .Attr("TInput: {half, bfloat16, float} = DT_FLOAT") - .SetShapeFn([](InferenceContext* c) { - return UnchangedShapeWithRank(c, 5); - }); + .SetShapeFn(shape_inference::MaxPool3DGradShape); REGISTER_OP("MaxPool3DGradGrad") .Input("orig_input: T") @@ -879,9 +865,7 @@ REGISTER_OP("MaxPoolGrad") .Input("grad: T") .Output("output: T") .Attr("T: realnumbertype = 
DT_FLOAT") - .SetShapeFn([](InferenceContext* c) { - return UnchangedShapeWithRank(c, 4); - }); + .SetShapeFn(shape_inference::MaxPoolGradShape); REGISTER_OP("MaxPoolGradV2") .Attr(GetPaddingAttrString()) @@ -893,9 +877,7 @@ REGISTER_OP("MaxPoolGradV2") .Input("strides: int32") .Output("output: T") .Attr("T: realnumbertype = DT_FLOAT") - .SetShapeFn([](InferenceContext* c) { - return UnchangedShapeWithRank(c, 4); - }); + .SetShapeFn(shape_inference::MaxPoolGradShape); // TODO(b/150813181): Implement explicit padding. REGISTER_OP("MaxPoolGradGrad") @@ -2343,14 +2325,12 @@ REGISTER_OP("_MklMaxPoolGrad") .Input("mkl_workspace: uint8") .Output("output: T") .Output("mkl_output: uint8") - .SetShapeFn([](InferenceContext* c) { - return UnchangedShapeWithRank(c, 4); - }) + .SetShapeFn(shape_inference::MaxPoolGradShape) .Doc(R"doc( -MKL version of MaxPoolGrad. Uses MKL DNN APIs to compute gradients of +oneDNN version of MaxPoolGrad. Uses oneDNN APIs to compute gradients of MaxPool operator. -NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); @@ -2385,18 +2365,12 @@ REGISTER_OP("_MklAvgPoolGrad") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) .Attr("T: {float, half, double, bfloat16}") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle s; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); - TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); - c->set_output(0, s); - return Status::OK(); - }) + .SetShapeFn(shape_inference::AvgPoolGradShape) .Doc(R"doc( -MKL version of AvgPoolGrad operator. Uses MKL DNN APIs to compute gradients +oneDNN version of AvgPoolGrad operator. Uses oneDNN APIs to compute gradients of AvgPool function. -NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); @@ -2431,18 +2405,12 @@ REGISTER_OP("_MklAvgPool3DGrad") .Attr(GetPaddingAttrString()) .Attr(GetConvnet3dDataFormatAttrString()) .Attr("T: {float, half, double, bfloat16}") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle s; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); - TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); - c->set_output(0, s); - return Status::OK(); - }) + .SetShapeFn(shape_inference::AvgPool3DGradShape) .Doc(R"doc( -MKL version of AvgPool3DGrad operator. Uses MKL DNN APIs to compute gradients +oneDNN version of AvgPool3DGrad operator. Uses oneDNN APIs to compute gradients of AvgPool function. -NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); @@ -2486,14 +2454,12 @@ REGISTER_OP("_MklMaxPool3DGrad") .Attr("T: {half, bfloat16, float} = DT_FLOAT") .Attr("TInput: {half, bfloat16, float} = DT_FLOAT") .Attr("workspace_enabled: bool = false") - .SetShapeFn([](InferenceContext* c) { - return UnchangedShapeWithRank(c, 5); - }) + .SetShapeFn(shape_inference::MaxPool3DGradShape) .Doc(R"doc( -MKL version of MklPool3DGrad operator. Uses MKL DNN APIs to compute gradients -of MklPool function. +oneDNN version of MaxPool3DGrad operator. Uses oneDNN APIs to compute gradients +of MaxPool3D function. -NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +*NOTE*: Do not invoke this operator directly in Python. 
Graph rewrite pass is expected to invoke these operators. )doc"); @@ -2580,44 +2546,12 @@ REGISTER_OP("_MklFusedBatchNorm") .Attr("data_format: string = 'NHWC'") .Attr("exponential_avg_factor: float = 1.0") .Attr("is_training: bool = true") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle x; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x)); - - bool is_training; - TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training)); - int number_inputs = (is_training) ? 3 : 5; - string data_format; - TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format)); - DimensionHandle channel_dim = - (data_format == "NHWC") ? c->Dim(x, 3) : c->Dim(x, 1); - - // covers scale, offset, and if is_training is false, mean, variance - for (int i = 1; i < number_inputs; ++i) { - ShapeHandle vec; - TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec)); - TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim)); - } - - ShapeHandle y; - if (data_format == "NHWC") { - TF_RETURN_IF_ERROR(c->ReplaceDim(x, 3, channel_dim, &y)); - } else { - TF_RETURN_IF_ERROR(c->ReplaceDim(x, 1, channel_dim, &y)); - } - c->set_output(0, y); - ShapeHandle vector_shape = c->Vector(channel_dim); - c->set_output(1, vector_shape); - c->set_output(2, vector_shape); - c->set_output(3, vector_shape); - c->set_output(4, vector_shape); - return Status::OK(); - }) + .SetShapeFn(shape_inference::FusedBatchNormShape) .Doc(R"doc( -MKL version of FusedBatchNorm operator. Uses MKL DNN APIs to perform fused +oneDNN version of FusedBatchNorm operator. Uses oneDNN APIs to perform fused batch normalization. -NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); @@ -2646,60 +2580,12 @@ REGISTER_OP("_MklFusedBatchNormGrad") .Attr("epsilon: float = 0.0001") .Attr("data_format: string = 'NHWC'") .Attr("is_training: bool = true") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle y_backprop; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop)); - ShapeHandle x; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x)); - - bool is_training; - string data_format; - TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training)); - TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format)); - DimensionHandle channel_dim = (data_format == "NHWC") - ? c->Dim(y_backprop, 3) - : c->Dim(y_backprop, 1); - if (data_format == "NHWC") { - TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 3), &channel_dim)); - } else { - TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 1), &channel_dim)); - } - - // covers scale, mean (reserve_space_1), variance (reserve_space_2) - for (int i = 2; i < 5; ++i) { - ShapeHandle vec; - TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec)); - TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim)); - } - - ShapeHandle x_backprop; - if (data_format == "NHWC") { - TF_RETURN_IF_ERROR( - c->ReplaceDim(y_backprop, 3, channel_dim, &x_backprop)); - } else { - TF_RETURN_IF_ERROR( - c->ReplaceDim(y_backprop, 1, channel_dim, &x_backprop)); - } - c->set_output(0, x_backprop); - c->set_output(1, c->Vector(channel_dim)); - c->set_output(2, c->Vector(channel_dim)); - // Set the correct shapes for reserve_spaces - // so that gradients can be performed when - // the op is in a symbolic condition. 
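The pattern in this hunk and the surrounding ones is a consolidation: the per-op SetShapeFn lambdas that these registrations used to duplicate are replaced with shared helpers in shape_inference (AvgPoolGradShape, AvgPool3DGradShape, MaxPoolGradShape, MaxPool3DGradShape, FusedBatchNormShape, FusedBatchNormGradShape), so the standard and _Mkl variants of an op now validate shapes through one implementation. A sketch of what one such helper boils down to, mirroring the lambda removed above (the name is suffixed to mark it as illustrative, not the actual library function):

    #include "tensorflow/core/framework/shape_inference.h"
    #include "tensorflow/core/lib/core/errors.h"
    #include "tensorflow/core/lib/core/status.h"

    namespace tensorflow {
    namespace shape_inference {

    // Illustrative version of AvgPoolGradShape: the first input of AvgPoolGrad
    // is the original input shape as a 1-D tensor, so the output shape is
    // whatever that tensor describes, constrained to rank 4.
    Status AvgPoolGradShapeSketch(InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
      c->set_output(0, s);
      return Status::OK();
    }

    }  // namespace shape_inference
    }  // namespace tensorflow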
- if (is_training) { - c->set_output(3, c->Vector(0)); - c->set_output(4, c->Vector(0)); - } else { - c->set_output(3, c->Vector(channel_dim)); - c->set_output(4, c->Vector(channel_dim)); - } - return Status::OK(); - }) + .SetShapeFn(shape_inference::FusedBatchNormGradShape) .Doc(R"doc( -MKL version of FusedBatchNormGrad operator. Uses MKL DNN APIs to compute +oneDNN version of FusedBatchNormGrad operator. Uses oneDNN APIs to compute gradients for fused batch normalization. -NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index ac990696d5d..250dfc49f60 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -7791,6 +7791,13 @@ op { s: "auto" } } + attr { + name: "timeout_seconds" + type: "float" + default_value { + f: 0 + } + } is_stateful: true } op { @@ -17468,6 +17475,8 @@ op { list { s: "NHWC" s: "NCHW" + s: "NDHWC" + s: "NCDHW" } } } @@ -17666,6 +17675,8 @@ op { list { s: "NHWC" s: "NCHW" + s: "NDHWC" + s: "NCDHW" } } } @@ -29861,6 +29872,135 @@ op { } } } +op { + name: "QuantizeAndDequantizeV4" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "input_min" + type_attr: "T" + } + input_arg { + name: "input_max" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "signed_input" + type: "bool" + default_value { + b: true + } + } + attr { + name: "num_bits" + type: "int" + default_value { + i: 8 + } + } + attr { + name: "range_given" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "round_mode" + type: "string" + default_value { + s: "HALF_TO_EVEN" + } + allowed_values { + list { + s: "HALF_TO_EVEN" + s: "HALF_UP" + } + } + } + attr { + name: "narrow_range" + type: "bool" + default_value { + b: false + } + } + attr { + name: "axis" + type: "int" + default_value { + i: -1 + } + } +} +op { + name: "QuantizeAndDequantizeV4Grad" + input_arg { + name: "gradients" + type_attr: "T" + } + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "input_min" + type_attr: "T" + } + input_arg { + name: "input_max" + type_attr: "T" + } + output_arg { + name: "input_backprop" + type_attr: "T" + } + output_arg { + name: "input_min_backprop" + type_attr: "T" + } + output_arg { + name: "input_max_backprop" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "axis" + type: "int" + default_value { + i: -1 + } + } +} op { name: "QuantizeDownAndShrinkRange" input_arg { @@ -44424,6 +44564,20 @@ op { s: "" } } + attr { + name: "reader_prefix" + type: "string" + default_value { + s: "" + } + } + attr { + name: "writer_prefix" + type: "string" + default_value { + s: "" + } + } attr { name: "reader_func" type: "func" diff --git a/tensorflow/core/ops/tpu_embedding_ops.cc b/tensorflow/core/ops/tpu_embedding_ops.cc index 164d78e8e9e..792c85bb79b 100644 --- a/tensorflow/core/ops/tpu_embedding_ops.cc +++ b/tensorflow/core/ops/tpu_embedding_ops.cc @@ -70,7 +70,6 @@ REGISTER_OP("RecvTPUEmbeddingActivations") if (!config.ParseFromString(config_string)) { return errors::InvalidArgument("Malformed 
tpu_embedding_config."); } - tpu::AddDefaultEmbeddingOutputLayoutIfNeeded(&config); std::vector output_shapes; TF_RETURN_IF_ERROR(ComputeOutputTensorShapes(config, &output_shapes)); if (c->num_outputs() != output_shapes.size()) { diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index 7e2955ce044..0a324f31023 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -61,7 +61,9 @@ load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") package( - default_visibility = ["//tensorflow:__subpackages__"], + default_visibility = [ + "//visibility:public", + ], licenses = ["notice"], # Apache 2.0 ) @@ -250,6 +252,11 @@ cc_library( deps = tf_windows_aware_platform_deps("env"), ) +cc_library( + name = "env_impl", + deps = tf_windows_aware_platform_deps("env_impl"), +) + cc_library( name = "env_time", textual_hdrs = ["env_time.h"], @@ -382,8 +389,6 @@ cc_library( name = "mutex", compatible_with = get_compatible_with_portable(), textual_hdrs = ["mutex.h"], - # TODO(b/161569340): Short-term fix. Remove this visibility rule. - visibility = ["//tensorflow:__subpackages__"], deps = tf_platform_deps("mutex"), ) @@ -622,10 +627,6 @@ filegroup( srcs = [ "stacktrace_handler.h", ], - visibility = [ - "//tensorflow/core:__pkg__", - "//tensorflow/python:__pkg__", - ], ) cc_library( @@ -702,7 +703,6 @@ cc_library( cc_library( name = "strong_hash", hdrs = ["strong_hash.h"], - visibility = ["//visibility:public"], deps = [ ":platform", ":types", @@ -803,11 +803,6 @@ cc_library( name = "types", hdrs = ["types.h"], compatible_with = get_compatible_with_portable(), - # TODO(b/161569340): Short-term fix. Remove this visibility rule. - visibility = [ - "//tensorflow:__subpackages__", - "//tensorflow_text:__subpackages__", - ], deps = [ ":platform", ":bfloat16", @@ -940,7 +935,6 @@ tf_cuda_library( "stream_executor.h", ], features = ["-parse_headers"], - visibility = ["//tensorflow/core:__pkg__"], deps = [ "//tensorflow/core/platform/default/build_config:stream_executor", ], @@ -955,7 +949,6 @@ cc_library( "stream_executor_no_cuda.h", ], features = ["-parse_headers"], - visibility = ["//tensorflow/core:__pkg__"], deps = [ "//tensorflow/core/platform/default/build_config:stream_executor_no_cuda", ], @@ -970,7 +963,6 @@ cc_library( "//tensorflow:windows": [], "//conditions:default": ["-lm"], }), - visibility = ["//tensorflow/core:__pkg__"], deps = [ ":platform", ":stacktrace_handler", @@ -1013,6 +1005,7 @@ cc_library( filegroup( name = "tensor_float_32_hdr", srcs = ["tensor_float_32_utils.h"], + compatible_with = get_compatible_with_portable(), ) tf_cc_tests( @@ -1714,6 +1707,13 @@ bzl_library( srcs = [ "build_config_root.bzl", ] + tf_platform_alias("build_config_root.bzl"), + visibility = [ + "//learning/brain/google/xla/tests:__pkg__", + "//learning/brain/tfrc/runtime/tpu_driver:__subpackages__", + "//nlp/deleuze:__pkg__", + "//nlp/projects/minmt:__pkg__", + "//tensorflow:__subpackages__", + ], ) bzl_library( diff --git a/tensorflow/core/platform/build_config.bzl b/tensorflow/core/platform/build_config.bzl index 565f4962915..e58c70313ad 100644 --- a/tensorflow/core/platform/build_config.bzl +++ b/tensorflow/core/platform/build_config.bzl @@ -28,7 +28,6 @@ load( _tf_platform_deps = "tf_platform_deps", _tf_portable_deps_no_runtime = "tf_portable_deps_no_runtime", _tf_portable_proto_lib = "tf_portable_proto_lib", - _tf_profiler_client_deps = "tf_profiler_client_deps", _tf_proto_library = "tf_proto_library", 
_tf_proto_library_cc = "tf_proto_library_cc", _tf_protobuf_compiler_deps = "tf_protobuf_compiler_deps", @@ -82,7 +81,6 @@ tf_protos_grappler = _tf_protos_grappler tf_protos_grappler_impl = _tf_protos_grappler_impl tf_protos_profiler_impl = _tf_protos_profiler_impl tf_protos_profiler_service = _tf_protos_profiler_service -tf_profiler_client_deps = _tf_profiler_client_deps tf_py_clif_cc = _tf_py_clif_cc tf_pyclif_proto_library = _tf_pyclif_proto_library tf_resource_deps = _tf_resource_deps diff --git a/tensorflow/core/platform/build_config_root.bzl b/tensorflow/core/platform/build_config_root.bzl index b82e1041695..eea50174b38 100644 --- a/tensorflow/core/platform/build_config_root.bzl +++ b/tensorflow/core/platform/build_config_root.bzl @@ -5,7 +5,6 @@ load( _if_dynamic_kernels = "if_dynamic_kernels", _if_static = "if_static", _if_static_and_not_mobile = "if_static_and_not_mobile", - _register_extension_info = "register_extension_info", _tf_additional_grpc_deps_py = "tf_additional_grpc_deps_py", _tf_additional_license_deps = "tf_additional_license_deps", _tf_additional_plugin_deps = "tf_additional_plugin_deps", @@ -19,7 +18,6 @@ load( if_dynamic_kernels = _if_dynamic_kernels if_static = _if_static if_static_and_not_mobile = _if_static_and_not_mobile -register_extension_info = _register_extension_info tf_additional_grpc_deps_py = _tf_additional_grpc_deps_py tf_additional_license_deps = _tf_additional_license_deps tf_additional_plugin_deps = _tf_additional_plugin_deps diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 74553f33a3d..214ac4409d5 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "tensorflow/core/platform/file_statistics.h" #include "tensorflow/core/platform/strcat.h" #ifdef _WIN32 #include // for _mktemp @@ -405,6 +406,11 @@ typedef std::function StatusPoller; +// Function object declaration with params needed to poll upload status. +typedef std::function + GenerationGetter; + /// \brief GCS-based implementation of a writeable file. 
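GcsWritableFile already takes its upload protocol as injected std::function hooks (SessionCreator, ObjectUploader, StatusPoller); this change threads one more, GenerationGetter, through both constructors so the compose path can ask for the current object generation. The typedefs above lost their template arguments in rendering; judging from the lambdas passed in NewWritableFile and NewAppendableFile further down, the new hook has the signature sketched below. The fake shown is purely hypothetical, the kind of thing a unit test could inject.

    #include <functional>

    #include "tensorflow/core/lib/core/status.h"
    #include "tensorflow/core/platform/types.h"

    namespace tensorflow {

    // Assumed shape of the GenerationGetter hook, reconstructed from the
    // lambdas passed to GcsWritableFile below.
    using GenerationGetterSketch = std::function<Status(
        const string& fname, const string& bucket, const string& object,
        int64* generation)>;

    // Hypothetical fake for tests: always reports the same generation number.
    GenerationGetterSketch MakeFakeGenerationGetter(int64 fixed_generation) {
      return [fixed_generation](const string&, const string&, const string&,
                                int64* generation) {
        *generation = fixed_generation;
        return Status::OK();
      };
    }

    }  // namespace tensorflow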
/// /// Since GCS objects are immutable, this implementation writes to a local @@ -417,7 +423,8 @@ class GcsWritableFile : public WritableFile { std::function file_cache_erase, RetryConfig retry_config, bool compose_append, SessionCreator session_creator, - ObjectUploader object_uploader, StatusPoller status_poller) + ObjectUploader object_uploader, StatusPoller status_poller, + GenerationGetter generation_getter) : bucket_(bucket), object_(object), filesystem_(filesystem), @@ -429,7 +436,8 @@ class GcsWritableFile : public WritableFile { start_offset_(0), session_creator_(std::move(session_creator)), object_uploader_(std::move(object_uploader)), - status_poller_(std::move(status_poller)) { + status_poller_(std::move(status_poller)), + generation_getter_(std::move(generation_getter)) { // TODO: to make it safer, outfile_ should be constructed from an FD VLOG(3) << "GcsWritableFile: " << GetGcsPath(); if (GetTmpFilename(&tmp_content_filename_).ok()) { @@ -449,7 +457,8 @@ class GcsWritableFile : public WritableFile { std::function file_cache_erase, RetryConfig retry_config, bool compose_append, SessionCreator session_creator, - ObjectUploader object_uploader, StatusPoller status_poller) + ObjectUploader object_uploader, StatusPoller status_poller, + GenerationGetter generation_getter) : bucket_(bucket), object_(object), filesystem_(filesystem), @@ -461,7 +470,8 @@ class GcsWritableFile : public WritableFile { start_offset_(0), session_creator_(std::move(session_creator)), object_uploader_(std::move(object_uploader)), - status_poller_(std::move(status_poller)) { + status_poller_(std::move(status_poller)), + generation_getter_(std::move(generation_getter)) { VLOG(3) << "GcsWritableFile: " << GetGcsPath() << "with existing file " << tmp_content_filename; tmp_content_filename_ = tmp_content_filename; @@ -632,32 +642,40 @@ class GcsWritableFile : public WritableFile { /// Appends the data of append_object to the original object and deletes /// append_object. 
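The substantive change to AppendObject below is the precondition: before composing, the writer now fetches the object's current generation and sends it as ifGenerationMatch on the first source object, so a concurrent writer that has replaced the base object causes the compose to fail rather than clobber its data. The HTTP request is also rebuilt inside the retry lambda so each attempt gets a fresh request. A standalone sketch of the request body assembled below (helper name invented):

    #include <string>

    #include "tensorflow/core/platform/strcat.h"
    #include "tensorflow/core/platform/types.h"

    namespace tensorflow {

    // Sketch of the compose request body built in AppendObject. The
    // ifGenerationMatch precondition ties the compose to the generation that was
    // just read, so it fails if another writer replaced the object meanwhile.
    string ComposeRequestBody(const string& object, const string& append_object,
                              int64 generation) {
      return strings::StrCat(
          "{'sourceObjects': [{'name': '", object,
          "','objectPrecondition':{'ifGenerationMatch':", generation,
          "}},{'name': '", append_object, "'}]}");
    }

    }  // namespace tensorflow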
Status AppendObject(string append_object) { - VLOG(3) << "AppendObject: " << GetGcsPathWithObject(append_object) << " to " - << GetGcsPath(); - std::unique_ptr request; - TF_RETURN_IF_ERROR(filesystem_->CreateHttpRequest(&request)); + const string append_object_path = GetGcsPathWithObject(append_object); + VLOG(3) << "AppendObject: " << append_object_path << " to " << GetGcsPath(); - request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket_, "/o/", - request->EscapeString(object_), - "/compose")); + int64 generation = 0; + TF_RETURN_IF_ERROR( + generation_getter_(GetGcsPath(), bucket_, object_, &generation)); - const string request_body = - strings::StrCat("{'sourceObjects': [{'name': '", object_, - "'},{'name': '", append_object, "'}]}"); - request->SetTimeouts(timeouts_->connect, timeouts_->idle, - timeouts_->metadata); - request->AddHeader("content-type", "application/json"); - request->SetPostFromBuffer(request_body.c_str(), request_body.size()); - return RetryingUtils::CallWithRetries( - [&request, &append_object, this]() { + TF_RETURN_IF_ERROR(RetryingUtils::CallWithRetries( + [&append_object, &generation, this]() { + std::unique_ptr request; + TF_RETURN_IF_ERROR(filesystem_->CreateHttpRequest(&request)); + + request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket_, "/o/", + request->EscapeString(object_), + "/compose")); + + const string request_body = strings::StrCat( + "{'sourceObjects': [{'name': '", object_, + "','objectPrecondition':{'ifGenerationMatch':", generation, + "}},{'name': '", append_object, "'}]}"); + request->SetTimeouts(timeouts_->connect, timeouts_->idle, + timeouts_->metadata); + request->AddHeader("content-type", "application/json"); + request->SetPostFromBuffer(request_body.c_str(), request_body.size()); TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when composing to ", GetGcsPath()); - TF_RETURN_WITH_CONTEXT_IF_ERROR( - filesystem_->DeleteFile(GetGcsPathWithObject(append_object), - nullptr), - " when cleaning up."); return Status::OK(); }, + retry_config_)); + + return RetryingUtils::DeleteWithRetries( + [&append_object_path, this]() { + return filesystem_->DeleteFile(append_object_path, nullptr); + }, retry_config_); } @@ -712,6 +730,7 @@ class GcsWritableFile : public WritableFile { const SessionCreator session_creator_; const ObjectUploader object_uploader_; const StatusPoller status_poller_; + const GenerationGetter generation_getter_; }; class GcsReadOnlyMemoryRegion : public ReadOnlyMemoryRegion { @@ -1265,10 +1284,23 @@ Status GcsFileSystem::NewWritableFile(const string& fname, completed, uploaded); }; + auto generation_getter = [this](const string& fname, const string& bucket, + const string& object, int64* generation) { + GcsFileStat stat; + TF_RETURN_IF_ERROR(RetryingUtils::CallWithRetries( + [&fname, &bucket, &object, &stat, this]() { + return UncachedStatForObject(fname, bucket, object, &stat); + }, + retry_config_)); + *generation = stat.generation_number; + return Status::OK(); + }; + result->reset(new GcsWritableFile( bucket, object, this, &timeouts_, [this, fname]() { ClearFileCaches(fname); }, retry_config_, - compose_append_, session_creator, object_uploader, status_poller)); + compose_append_, session_creator, object_uploader, status_poller, + generation_getter)); return Status::OK(); } @@ -1329,13 +1361,26 @@ Status GcsFileSystem::NewAppendableFile(const string& fname, completed, uploaded); }; + auto generation_getter = [this](const string& fname, const string& bucket, + const string& object, int64* generation) { + GcsFileStat stat; 
+ TF_RETURN_IF_ERROR(RetryingUtils::CallWithRetries( + [&fname, &bucket, &object, &stat, this]() { + return UncachedStatForObject(fname, bucket, object, &stat); + }, + retry_config_)); + *generation = stat.generation_number; + return Status::OK(); + }; + // Create a writable file and pass the old content to it. string bucket, object; TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); result->reset(new GcsWritableFile( bucket, object, this, old_content_filename, &timeouts_, [this, fname]() { ClearFileCaches(fname); }, retry_config_, - compose_append_, session_creator, object_uploader, status_poller)); + compose_append_, session_creator, object_uploader, status_poller, + generation_getter)); return Status::OK(); } diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc index 34f86f5107c..6700b200e0a 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc @@ -3882,6 +3882,15 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { "Put body: ", contents[2], "\n"), ""), + // Fetch generation + new FakeHttpRequest( + "Uri: " + "https://www.googleapis.com/storage/v1/b/bucket/o/" + "some%2Fpath%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" + "Auth Token: fake_token\n" + "Timeouts: 5 1 10\n", + strings::StrCat("{\"size\": \"8\",\"generation\": \"1234\"," + "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), // Compose the new part at the end of the original object. new FakeHttpRequest("Uri: " "https://www.googleapis.com/storage/v1/b/bucket/o/" @@ -3890,7 +3899,9 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { "Timeouts: 5 1 10\n" "Header content-type: application/json\n" "Post body: {'sourceObjects': [{'name': " - "'some/path/appendable'},{'name': " + "'some/path/" + "appendable','objectPrecondition':{'" + "ifGenerationMatch':1234}},{'name': " "'some/path/.tmpcompose/appendable.18'}]}\n", ""), // Delete the temporary object. 
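Cleanup of the temporary compose object is also moved out of the compose lambda into RetryingUtils::DeleteWithRetries. The point of a delete-specific retry helper, as used here, is that a delete which appears to fail may in fact have succeeded with the response lost; on a retry the object then looks NotFound, and for a delete that outcome is success rather than an error. A self-contained illustration of that idea, not the actual RetryingUtils implementation:

    #include <functional>

    #include "tensorflow/core/lib/core/errors.h"
    #include "tensorflow/core/lib/core/status.h"

    namespace tensorflow {

    // Illustration only. A NotFound on the first attempt is a real error (there
    // was nothing to delete); a NotFound seen after a failed attempt most likely
    // means an earlier attempt did the work, so it is treated as success.
    Status DeleteWithRetriesSketch(const std::function<Status()>& delete_func,
                                   int max_attempts) {
      bool retried = false;
      Status status;
      for (int attempt = 0; attempt < max_attempts; ++attempt) {
        status = delete_func();
        if (retried && errors::IsNotFound(status)) return Status::OK();
        if (status.ok() || errors::IsNotFound(status)) return status;
        retried = true;
      }
      return status;  // retries exhausted
    }

    }  // namespace tensorflow

The real helper additionally waits between attempts according to the RetryConfig it is given.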
@@ -3918,6 +3929,15 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { "Put body: ", contents[3], "\n"), ""), + // Fetch generation + new FakeHttpRequest( + "Uri: " + "https://www.googleapis.com/storage/v1/b/bucket/o/" + "some%2Fpath%2Fappendable?fields=size%2Cgeneration%2Cupdated\n" + "Auth Token: fake_token\n" + "Timeouts: 5 1 10\n", + strings::StrCat("{\"size\": \"8\",\"generation\": \"4567\"," + "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest("Uri: " "https://www.googleapis.com/storage/v1/b/bucket/o/" "some%2Fpath%2Fappendable/compose\n" @@ -3925,7 +3945,9 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithCompose) { "Timeouts: 5 1 10\n" "Header content-type: application/json\n" "Post body: {'sourceObjects': [{'name': " - "'some/path/appendable'},{'name': " + "'some/path/" + "appendable','objectPrecondition':{'" + "ifGenerationMatch':4567}},{'name': " "'some/path/.tmpcompose/appendable.27'}]}\n", ""), new FakeHttpRequest("Uri: " diff --git a/tensorflow/core/platform/default/BUILD b/tensorflow/core/platform/default/BUILD index a1f23563cf5..849048e99be 100644 --- a/tensorflow/core/platform/default/BUILD +++ b/tensorflow/core/platform/default/BUILD @@ -74,7 +74,6 @@ cc_library( cc_library( name = "env", srcs = [ - "env.cc", "posix_file_system.cc", "posix_file_system.h", "//tensorflow/core/platform:env.cc", @@ -129,6 +128,21 @@ cc_library( ], ) +cc_library( + name = "env_impl", + srcs = [ + "env.cc", + ], + tags = [ + "manual", + "no_oss", + "nobuilder", + ], + deps = [ + ":env", + ], +) + cc_library( name = "env_time", srcs = ["env_time.cc"], diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 9a1068d5a3a..2471ed64464 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -1,7 +1,7 @@ # Platform-specific build configurations. 
load("@com_google_protobuf//:protobuf.bzl", "proto_gen") -load("//tensorflow:tensorflow.bzl", "clean_dep", "if_not_windows", "if_tpu") +load("//tensorflow:tensorflow.bzl", "clean_dep", "if_libtpu", "if_not_windows") load("//tensorflow/core/platform:build_config_root.bzl", "if_static") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") @@ -506,6 +506,7 @@ def tf_proto_library( js_codegen = "jspb", create_service = False, create_java_proto = False, + create_go_proto = False, create_grpc_library = False, make_default_target_header_only = False, exports = [], @@ -522,6 +523,7 @@ def tf_proto_library( create_java_proto, create_grpc_library, cc_stubby_versions, + create_go_proto, ) native.proto_library( @@ -617,9 +619,6 @@ def tf_protos_profiler_service(): clean_dep("//tensorflow/core/profiler:profiler_service_monitor_result_proto_cc_impl"), ] -def tf_profiler_client_deps(): - return [clean_dep("//tensorflow/core/profiler/rpc/client:profiler_client_headers")] - def tf_protos_grappler_impl(): return [clean_dep("//tensorflow/core/grappler/costs:op_performance_data_cc_impl")] @@ -814,4 +813,4 @@ def if_llvm_system_z_available(then, otherwise = []): }) def tf_tpu_dependencies(): - return if_tpu(["//tensorflow/core/tpu/kernels"]) + return if_libtpu(["//tensorflow/core/tpu/kernels"]) diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl index 6012b4db407..38aeb9fb6ba 100644 --- a/tensorflow/core/platform/default/build_config_root.bzl +++ b/tensorflow/core/platform/default/build_config_root.bzl @@ -71,6 +71,3 @@ def if_dynamic_kernels(extra_deps, otherwise = []): str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps, "//conditions:default": otherwise, }) - -def register_extension_info(**kwargs): - pass diff --git a/tensorflow/core/platform/default/distribute.bzl b/tensorflow/core/platform/default/distribute.bzl index b16d5e8cff7..058ed0c4ff8 100644 --- a/tensorflow/core/platform/default/distribute.bzl +++ b/tensorflow/core/platform/default/distribute.bzl @@ -1,9 +1,5 @@ """Build rules for tf.distribute testing.""" -load( - "//tensorflow/core/platform:build_config_root.bzl", - "register_extension_info", -) load("//tensorflow/python/tpu:tpu.bzl", _tpu_py_test = "tpu_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") @@ -81,8 +77,3 @@ def distribute_py_test( disable_v3 = disable_v3, disable_mlir_bridge = disable_mlir_bridge, ) - -register_extension_info( - extension_name = "distribute_py_test", - label_regex_for_dep = "{extension_name}", -) diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc index 1181fab3d9e..74195db7730 100644 --- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc +++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc @@ -191,12 +191,16 @@ Status HadoopFileSystem::Connect(StringPiece fname, hdfsFS* fs) { nn.empty() ? 
"default" : nn.c_str()); cacheKey += nn; } - if (connectionCache_.find(cacheKey) == connectionCache_.end()) { - connectionCache_[cacheKey] = libhdfs()->hdfsBuilderConnect(builder); - } - *fs = connectionCache_[cacheKey]; - if (*fs == nullptr) { - return errors::NotFound(strerror(errno)); + { + mutex_lock lock(mu_); + if (connectionCache_.find(cacheKey) == connectionCache_.end()) { + hdfsFS cacheFs = libhdfs()->hdfsBuilderConnect(builder); + if (cacheFs == nullptr) { + return errors::NotFound(strerror(errno)); + } + connectionCache_[cacheKey] = cacheFs; + } + *fs = connectionCache_[cacheKey]; } return Status::OK(); } diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.h b/tensorflow/core/platform/hadoop/hadoop_file_system.h index 580c4fc8c82..78569b63b1f 100644 --- a/tensorflow/core/platform/hadoop/hadoop_file_system.h +++ b/tensorflow/core/platform/hadoop/hadoop_file_system.h @@ -77,7 +77,8 @@ class HadoopFileSystem : public FileSystem { string TranslateName(const string& name) const override; private: - std::map connectionCache_; + mutex mu_; + std::map connectionCache_ TF_GUARDED_BY(mu_); Status Connect(StringPiece fname, hdfsFS* fs); }; diff --git a/tensorflow/core/platform/macros.h b/tensorflow/core/platform/macros.h index 4f8e49d2653..249ccb58032 100644 --- a/tensorflow/core/platform/macros.h +++ b/tensorflow/core/platform/macros.h @@ -93,6 +93,15 @@ limitations under the License. #define TF_ATTRIBUTE_ANNOTATE(str) #endif +// A variable declaration annotated with the `TF_CONST_INIT` attribute will +// not compile (on supported platforms) unless the variable has a constant +// initializer. +#if TF_HAS_CPP_ATTRIBUTE(clang::require_constant_initialization) +#define TF_CONST_INIT [[clang::require_constant_initialization]] +#else +#define TF_CONST_INIT +#endif + // Compilers can be told that a certain branch is not likely to be taken // (for instance, a CHECK failure), and use that information in static // analysis. Giving it this information can help it optimize for the diff --git a/tensorflow/core/platform/numbers.cc b/tensorflow/core/platform/numbers.cc index d4ea4bb52e9..e96ef6243cf 100644 --- a/tensorflow/core/platform/numbers.cc +++ b/tensorflow/core/platform/numbers.cc @@ -186,6 +186,14 @@ size_t DoubleToBuffer(double value, char* buffer) { // this assert. static_assert(DBL_DIG < 20, "DBL_DIG is too big"); + if (std::isnan(value)) { + int snprintf_result = snprintf(buffer, kFastToBufferSize, "%snan", + std::signbit(value) ? "-" : ""); + // Paranoid check to ensure we don't overflow the buffer. + DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize); + return snprintf_result; + } + if (std::abs(value) <= kDoublePrecisionCheckMax) { int snprintf_result = snprintf(buffer, kFastToBufferSize, "%.*g", DBL_DIG, value); @@ -365,6 +373,14 @@ size_t FloatToBuffer(float value, char* buffer) { // this assert. static_assert(FLT_DIG < 10, "FLT_DIG is too big"); + if (std::isnan(value)) { + int snprintf_result = snprintf(buffer, kFastToBufferSize, "%snan", + std::signbit(value) ? "-" : ""); + // Paranoid check to ensure we don't overflow the buffer. + DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize); + return snprintf_result; + } + int snprintf_result = snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG, value); diff --git a/tensorflow/core/platform/test.h b/tensorflow/core/platform/test.h index ba507837652..29fceb2d896 100644 --- a/tensorflow/core/platform/test.h +++ b/tensorflow/core/platform/test.h @@ -40,7 +40,7 @@ limitations under the License. 
// The advantages of using gmock matchers instead of self defined matchers are // better error messages, more maintainable tests and more test coverage. #if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) -#include "testing/base/public/gmock.h" +#include "testing/base/public/gmock.h" // IWYU pragma: export #else #include #include diff --git a/tensorflow/core/platform/windows/BUILD b/tensorflow/core/platform/windows/BUILD index 235900ac759..423d8fe5d35 100644 --- a/tensorflow/core/platform/windows/BUILD +++ b/tensorflow/core/platform/windows/BUILD @@ -20,7 +20,6 @@ package( cc_library( name = "env", srcs = [ - "env.cc", "windows_file_system.cc", "windows_file_system.h", "//tensorflow/core/platform:env.cc", @@ -77,6 +76,21 @@ cc_library( ], ) +cc_library( + name = "env_impl", + srcs = [ + "env.cc", + ], + tags = [ + "manual", + "no_oss", + "nobuilder", + ], + deps = [ + ":env", + ], +) + cc_library( name = "env_time", srcs = ["env_time.cc"], diff --git a/tensorflow/core/profiler/BUILD b/tensorflow/core/profiler/BUILD index 527d538b75c..84228af8d5e 100644 --- a/tensorflow/core/profiler/BUILD +++ b/tensorflow/core/profiler/BUILD @@ -89,9 +89,9 @@ cc_library( name = "profiler_impl", visibility = ["//tensorflow:__pkg__"], deps = [ - "//tensorflow/core/profiler/internal:annotation_stack_impl", - "//tensorflow/core/profiler/internal:profiler_factory_impl", - "//tensorflow/core/profiler/internal:traceme_recorder_impl", + "//tensorflow/core/profiler/internal/cpu:annotation_stack_impl", + "//tensorflow/core/profiler/internal/cpu:traceme_recorder_impl", + "//tensorflow/core/profiler/lib:profiler_factory_impl", "//tensorflow/core/profiler/lib:profiler_session_impl", ], alwayslink = True, @@ -135,7 +135,6 @@ cc_library( filegroup( name = "mobile_srcs", srcs = [ - "//tensorflow/core/profiler/internal:mobile_srcs", "//tensorflow/core/profiler/lib:mobile_srcs", ], visibility = ["//tensorflow/core:__pkg__"], diff --git a/tensorflow/core/profiler/builds/build_config.bzl b/tensorflow/core/profiler/builds/build_config.bzl index 7c1b0a06c06..bd20787b398 100644 --- a/tensorflow/core/profiler/builds/build_config.bzl +++ b/tensorflow/core/profiler/builds/build_config.bzl @@ -3,12 +3,17 @@ load( "//tensorflow/core/profiler/builds/oss:build_config.bzl", _tf_profiler_alias = "tf_profiler_alias", + _tf_profiler_pybind_cc_library_wrapper = "tf_profiler_pybind_cc_library_wrapper", ) tf_profiler_alias = _tf_profiler_alias +tf_profiler_pybind_cc_library_wrapper = _tf_profiler_pybind_cc_library_wrapper def if_profiler_oss(if_true, if_false = []): return select({ "//tensorflow/core/profiler/builds:profiler_build_oss": if_true, "//conditions:default": if_false, }) + +def tf_profiler_copts(): + return if_profiler_oss(["-Wc++14-compat"]) diff --git a/tensorflow/core/profiler/builds/oss/build_config.bzl b/tensorflow/core/profiler/builds/oss/build_config.bzl index 1dcfd0e3291..8849aa4cff4 100644 --- a/tensorflow/core/profiler/builds/oss/build_config.bzl +++ b/tensorflow/core/profiler/builds/oss/build_config.bzl @@ -3,5 +3,17 @@ TF profiler build macros for use in OSS. """ +load("//tensorflow:tensorflow.bzl", "cc_header_only_library") + def tf_profiler_alias(target_dir, name): return target_dir + "oss:" + name + +def tf_profiler_pybind_cc_library_wrapper(name, actual, **kwargs): + """Wrapper for cc_library used by tf_python_pybind_extension. + + This wrapper ensures that cc libraries headers are made available to pybind + code, without creating ODR violations in the dynamically linked case. 
The + symbols in these deps symbols should be linked to, and exported by, the core + pywrap_tensorflow_internal.so + """ + cc_header_only_library(name = name, deps = [actual], **kwargs) diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index ad5f95dc214..38f5da03838 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -1,5 +1,6 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_copts") package( default_visibility = ["//tensorflow/core/profiler:internal"], @@ -10,6 +11,7 @@ cc_library( name = "xplane_to_op_metrics_db", srcs = ["xplane_to_op_metrics_db.cc"], hdrs = ["xplane_to_op_metrics_db.h"], + copts = tf_profiler_copts(), deps = [ ":op_metrics_db_combiner", ":op_stack", @@ -57,6 +59,7 @@ cc_library( name = "op_metrics_db_combiner", srcs = ["op_metrics_db_combiner.cc"], hdrs = ["op_metrics_db_combiner.h"], + copts = tf_profiler_copts(), deps = [ "//tensorflow/core:lib", "//tensorflow/core/platform:protobuf", @@ -70,6 +73,7 @@ cc_library( name = "op_metrics_to_record", srcs = ["op_metrics_to_record.cc"], hdrs = ["op_metrics_to_record.h"], + copts = tf_profiler_copts(), deps = [ "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", "//tensorflow/core/profiler/utils:math_utils", @@ -81,6 +85,7 @@ cc_library( cc_library( name = "op_stack", hdrs = ["op_stack.h"], + copts = tf_profiler_copts(), deps = [ "//tensorflow/core:lib", ], @@ -90,6 +95,7 @@ cc_library( name = "op_stats_to_overview_page", srcs = ["op_stats_to_overview_page.cc"], hdrs = ["op_stats_to_overview_page.h"], + copts = tf_profiler_copts(), deps = [ ":op_metrics_to_record", ":op_stats_to_input_pipeline_analysis", @@ -104,6 +110,7 @@ cc_library( "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", "//tensorflow/core/profiler/protobuf:tf_function_proto_cc", "//tensorflow/core/profiler/utils:diagnostics", + "//tensorflow/core/profiler/utils:format_utils", "//tensorflow/core/profiler/utils:hardware_type_utils", "//tensorflow/core/profiler/utils:html_utils", "//tensorflow/core/profiler/utils:kernel_stats_utils", @@ -112,7 +119,82 @@ cc_library( "//tensorflow/core/profiler/utils:tf_op_utils", "//tensorflow/core/profiler/utils:time_utils", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "op_stats_to_pod_stats", + srcs = ["op_stats_to_pod_stats.cc"], + hdrs = ["op_stats_to_pod_stats.h"], + copts = tf_profiler_copts(), + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", + "//tensorflow/core/profiler/protobuf:pod_stats_proto_cc", + "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", + "//tensorflow/core/profiler/utils:diagnostics", + "//tensorflow/core/profiler/utils:event_span", + "//tensorflow/core/profiler/utils:time_utils", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "op_stats_to_pod_stats_test", + srcs = ["op_stats_to_pod_stats_test.cc"], + deps = [ + ":op_stats_to_pod_stats", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/profiler/protobuf:diagnostics_proto_cc", + "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", + "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", + "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", + 
"//tensorflow/core/profiler/utils:diagnostics", + "//tensorflow/core/profiler/utils:event_span", + "//tensorflow/core/profiler/utils:time_utils", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "op_stats_to_pod_viewer", + srcs = ["op_stats_to_pod_viewer.cc"], + hdrs = ["op_stats_to_pod_viewer.h"], + copts = tf_profiler_copts(), + deps = [ + ":op_stats_to_pod_stats", + "//tensorflow/core:lib", + "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", + "//tensorflow/core/profiler/protobuf:pod_stats_proto_cc", + "//tensorflow/core/profiler/protobuf:pod_viewer_proto_cc", + "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", + "//tensorflow/core/profiler/utils:diagnostics", + "//tensorflow/core/profiler/utils:event_span", + "//tensorflow/core/profiler/utils:time_utils", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "op_stats_to_pod_viewer_test", + srcs = ["op_stats_to_pod_viewer_test.cc"], + deps = [ + ":op_stats_to_pod_viewer", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/profiler/protobuf:diagnostics_proto_cc", + "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", + "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", + "//tensorflow/core/profiler/protobuf:pod_stats_proto_cc", + "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", + "//tensorflow/core/profiler/utils:diagnostics", + "//tensorflow/core/profiler/utils:event_span", + "//tensorflow/core/profiler/utils:time_utils", + "@com_google_absl//absl/strings", ], ) @@ -120,6 +202,7 @@ cc_library( name = "op_stats_to_input_pipeline_analysis", srcs = ["op_stats_to_input_pipeline_analysis.cc"], hdrs = ["op_stats_to_input_pipeline_analysis.h"], + copts = tf_profiler_copts(), deps = [ ":op_metrics_to_record", ":step_events_to_steps_db", @@ -133,6 +216,7 @@ cc_library( "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", "//tensorflow/core/profiler/utils:diagnostics", "//tensorflow/core/profiler/utils:event_span", + "//tensorflow/core/profiler/utils:format_utils", "//tensorflow/core/profiler/utils:hardware_type_utils", "//tensorflow/core/profiler/utils:html_utils", "//tensorflow/core/profiler/utils:math_utils", @@ -141,7 +225,6 @@ cc_library( "//tensorflow/core/util:stats_calculator_portable", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", ], ) @@ -149,6 +232,7 @@ cc_library( name = "op_stats_to_tf_stats", srcs = ["op_stats_to_tf_stats.cc"], hdrs = ["op_stats_to_tf_stats.h"], + copts = tf_profiler_copts(), deps = [ ":op_metrics_to_record", "//tensorflow/core:lib", @@ -189,6 +273,7 @@ cc_library( name = "step_events_to_steps_db", srcs = ["step_events_to_steps_db.cc"], hdrs = ["step_events_to_steps_db.h"], + copts = tf_profiler_copts(), deps = [ "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -204,11 +289,15 @@ cc_library( name = "trace_events_to_json", srcs = ["trace_events_to_json.cc"], hdrs = ["trace_events_to_json.h"], + copts = tf_profiler_copts(), deps = [ "//tensorflow/core:lib", + "//tensorflow/core/platform:protobuf", "//tensorflow/core/profiler/protobuf:trace_events_proto_cc", + "//tensorflow/core/profiler/utils:format_utils", + "//tensorflow/core/profiler/utils:time_utils", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@jsoncpp_git//:jsoncpp", ], ) @@ -230,6 +319,7 @@ cc_library( name = "xplane_to_op_stats", srcs = 
["xplane_to_op_stats.cc"], hdrs = ["xplane_to_op_stats.h"], + copts = tf_profiler_copts(), deps = [ ":op_metrics_db_combiner", ":op_stats_combiner", @@ -292,6 +382,7 @@ cc_library( name = "xplane_to_profile_response", srcs = ["xplane_to_profile_response.cc"], hdrs = ["xplane_to_profile_response.h"], + copts = tf_profiler_copts(), deps = [ ":op_stats_to_input_pipeline_analysis", ":op_stats_to_overview_page", @@ -344,6 +435,7 @@ cc_library( name = "xplane_to_step_events", srcs = ["xplane_to_step_events.cc"], hdrs = ["xplane_to_step_events.h"], + copts = tf_profiler_copts(), deps = [ "//tensorflow/core:lib", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", @@ -383,6 +475,7 @@ cc_library( name = "xplane_to_trace_events", srcs = ["xplane_to_trace_events.cc"], hdrs = ["xplane_to_trace_events.h"], + copts = tf_profiler_copts(), deps = [ "//tensorflow/core:lib", "//tensorflow/core/profiler/protobuf:trace_events_proto_cc", @@ -418,6 +511,7 @@ cc_library( name = "xplane_to_kernel_stats_db", srcs = ["xplane_to_kernel_stats_db.cc"], hdrs = ["xplane_to_kernel_stats_db.h"], + copts = tf_profiler_copts(), deps = [ "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -460,6 +554,7 @@ cc_library( name = "xplane_to_tf_functions", srcs = ["xplane_to_tf_functions.cc"], hdrs = ["xplane_to_tf_functions.h"], + copts = tf_profiler_copts(), deps = [ "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -504,6 +599,7 @@ cc_library( name = "xplane_to_memory_profile", srcs = ["xplane_to_memory_profile.cc"], hdrs = ["xplane_to_memory_profile.h"], + copts = tf_profiler_copts(), deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -520,7 +616,6 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", ], ) @@ -548,6 +643,7 @@ cc_library( name = "op_stats_combiner", srcs = ["op_stats_combiner.cc"], hdrs = ["op_stats_combiner.h"], + copts = tf_profiler_copts(), deps = [ ":op_metrics_db_combiner", ":xplane_to_tf_functions", @@ -568,10 +664,10 @@ cc_library( name = "post_process_single_host_xplane", srcs = ["post_process_single_host_xplane.cc"], hdrs = ["post_process_single_host_xplane.h"], + copts = tf_profiler_copts(), visibility = ["//tensorflow/core/profiler:internal"], deps = [ "//tensorflow/core:lib", - "//tensorflow/core/profiler/internal:profiler_factory", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "//tensorflow/core/profiler/utils:derived_timeline", "//tensorflow/core/profiler/utils:group_events", @@ -579,3 +675,67 @@ cc_library( "//tensorflow/core/profiler/utils:xplane_utils", ], ) + +cc_library( + name = "xplane_to_tools_data", + srcs = ["xplane_to_tools_data.cc"], + hdrs = ["xplane_to_tools_data.h"], + copts = tf_profiler_copts(), + deps = [ + ":op_stats_to_input_pipeline_analysis", + ":op_stats_to_overview_page", + ":op_stats_to_tf_stats", + ":xplane_to_memory_profile", + ":xplane_to_op_stats", + ":xplane_to_trace_events", + "//tensorflow/core:lib", + "//tensorflow/core/profiler/protobuf:input_pipeline_proto_cc", + "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", + "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", + "//tensorflow/core/profiler/protobuf:overview_page_proto_cc", + "//tensorflow/core/profiler/protobuf:tf_stats_proto_cc", + "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = 
"xplane_to_tf_data_stats", + srcs = ["xplane_to_tf_data_stats.cc"], + hdrs = ["xplane_to_tf_data_stats.h"], + copts = tf_profiler_copts(), + deps = [ + "//tensorflow/core:lib_internal", + "//tensorflow/core/platform:protobuf", + "//tensorflow/core/profiler/protobuf:tf_data_stats_proto_cc", + "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + "//tensorflow/core/profiler/utils:group_events", + "//tensorflow/core/profiler/utils:tf_op_utils", + "//tensorflow/core/profiler/utils:tf_xplane_visitor", + "//tensorflow/core/profiler/utils:timespan", + "//tensorflow/core/profiler/utils:xplane_schema", + "//tensorflow/core/profiler/utils:xplane_visitor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "xplane_to_tf_data_stats_test", + size = "small", + srcs = ["xplane_to_tf_data_stats_test.cc"], + tags = [ + "no_oss", # b/169705709 + ], + deps = [ + ":xplane_to_tf_data_stats", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/profiler/protobuf:tf_data_stats_proto_cc", + "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + "//tensorflow/core/profiler/utils:xplane_builder", + "//tensorflow/core/profiler/utils:xplane_schema", + "//tensorflow/core/profiler/utils:xplane_test_utils", + ], +) diff --git a/tensorflow/core/profiler/convert/op_stats_combiner.cc b/tensorflow/core/profiler/convert/op_stats_combiner.cc index d4ce7ec315f..5c492a6fa44 100644 --- a/tensorflow/core/profiler/convert/op_stats_combiner.cc +++ b/tensorflow/core/profiler/convert/op_stats_combiner.cc @@ -157,6 +157,10 @@ void CombineOpStats( // Combine tf-function stats. CombineTfFunctionDb(src.tf_function_db(), dst->mutable_tf_function_db()); + + // Combine the mapping from core ID to details. + CombineCoreIdMap(src_host_id, src.core_id_to_details(), + dst->mutable_core_id_to_details()); } } // namespace diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc index 0aea94b23d4..c339a43cefb 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc @@ -24,7 +24,6 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" @@ -38,6 +37,7 @@ limitations under the License. 
#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" #include "tensorflow/core/profiler/utils/diagnostics.h" #include "tensorflow/core/profiler/utils/event_span.h" +#include "tensorflow/core/profiler/utils/format_utils.h" #include "tensorflow/core/profiler/utils/hardware_type_utils.h" #include "tensorflow/core/profiler/utils/html_utils.h" #include "tensorflow/core/profiler/utils/math_utils.h" @@ -379,21 +379,18 @@ double RatioOfHostToDeviceTimeToStepTime( void DeviceCollectivesAnalysis(double device_collectives_percent, std::string* device_collectives_classification, std::string* device_collectives_statement) { - std::string percent_str = - absl::StrFormat("%.1lf", device_collectives_percent); - if (device_collectives_percent >= kHighlyDeviceCollectivesBoundThresholdInPercent) { *device_collectives_classification = "high"; *device_collectives_statement = - absl::StrCat(percent_str, + absl::StrCat(OneDigit(device_collectives_percent), " % of the total step time sampled is spent on 'Device " "Collective Communication'."); } else if (device_collectives_percent >= kModeratelyDeviceCollectivesBoundThresholdInPercent) { *device_collectives_classification = "moderate"; *device_collectives_statement = - absl::StrCat(percent_str, + absl::StrCat(OneDigit(device_collectives_percent), " % of the total step time sampled is spent on 'Device " "Collective Communication'."); } else { @@ -405,11 +402,10 @@ void DeviceCollectivesAnalysis(double device_collectives_percent, void KernelLaunchAnalysis(bool tfdata_used, double kernel_launch_percent, std::string* kernel_launch_classification, std::string* kernel_launch_statement) { - std::string percent_str = absl::StrFormat("%.1lf", kernel_launch_percent); if (kernel_launch_percent >= kHighlyKernelLaunchBoundThresholdInPercent) { *kernel_launch_classification = "high"; *kernel_launch_statement = absl::StrCat( - percent_str, + OneDigit(kernel_launch_percent), " % of the total step time sampled is spent on 'Kernel Launch'."); if (tfdata_used) { absl::StrAppend(kernel_launch_statement, kKernelLaunchTfDataContention); @@ -418,7 +414,7 @@ void KernelLaunchAnalysis(bool tfdata_used, double kernel_launch_percent, kModeratelyKernelLaunchBoundThresholdInPercent) { *kernel_launch_classification = "moderate"; *kernel_launch_statement = absl::StrCat( - percent_str, + OneDigit(kernel_launch_percent), " % of the total step time sampled is spent on 'Kernel Launch'."); if (tfdata_used) { absl::StrAppend(kernel_launch_statement, kKernelLaunchTfDataContention); @@ -437,15 +433,14 @@ void AllOtherAnalysis(bool all_other_reported, double all_other_percent, *all_other_statement = ""; return; } - std::string percent_str = absl::StrFormat("%.1lf", all_other_percent); if (all_other_percent >= kHighlyAllOtherBoundThresholdInPercent) { *all_other_classification = "high"; *all_other_statement = - absl::StrCat(percent_str, kAllOthersPythonExplanation); + absl::StrCat(OneDigit(all_other_percent), kAllOthersPythonExplanation); } else if (all_other_percent >= kModeratelyAllOtherBoundThresholdInPercent) { *all_other_classification = "moderate"; *all_other_statement = - absl::StrCat(percent_str, kAllOthersPythonExplanation); + absl::StrCat(OneDigit(all_other_percent), kAllOthersPythonExplanation); } else { *all_other_classification = "no"; *all_other_statement = ""; @@ -627,18 +622,18 @@ bool InputAnalysis(double input_percent, double all_other_percent, std::string* input_classification, std::string* input_statement) { absl::string_view non_input_time = "other time"; - std::string 
infeed_percent_str = absl::StrFormat("%.1lf", input_percent); if (input_percent >= kHighlyInfeedBoundThresholdInPercent) { *input_classification = "host"; *input_statement = absl::StrCat( - "Your program is HIGHLY input-bound because ", infeed_percent_str, + "Your program is HIGHLY input-bound because ", OneDigit(input_percent), "% of the total step time sampled is waiting for input. Therefore, you " "should first focus on reducing the input time."); return false; } else if (input_percent >= kModeratelyInfeedBoundThresholdInPercent) { *input_classification = "both"; *input_statement = absl::StrCat( - "Your program is MODERATELY input-bound because ", infeed_percent_str, + "Your program is MODERATELY input-bound because ", + OneDigit(input_percent), "% of the total step time sampled is waiting for input. Therefore, " "you would need to reduce both the input time and ", non_input_time, "."); @@ -647,41 +642,40 @@ bool InputAnalysis(double input_percent, double all_other_percent, // Input analysis says it is not input-bound, but "All-Other" time // is significant. It could still be input-bound (or Python overhead). *input_classification = "both"; - std::string all_other_percent_str = - absl::StrFormat("%.1lf", all_other_percent); *input_statement = absl::StrCat( "Your program is POTENTIALLY input-bound because ", - all_other_percent_str, + OneDigit(all_other_percent), "% of the total step time sampled is spent on 'All Others' time (which " "could be due to I/O or Python execution or both)."); return true; } else { // Defintely not input-bound. *input_classification = "device"; - *input_statement = absl::StrCat( - "Your program is NOT input-bound because only ", infeed_percent_str, - "% of the total step time sampled is waiting for " - "input. Therefore, you should focus on " - "reducing ", - non_input_time, "."); + *input_statement = + absl::StrCat("Your program is NOT input-bound because only ", + OneDigit(input_percent), + "% of the total step time sampled is waiting for " + "input. Therefore, you should focus on " + "reducing ", + non_input_time, "."); return false; } } void OutputAnalysis(double output_percent, std::string* output_classification, std::string* output_statement) { - string tc_outfeed_percent_str = absl::StrFormat("%.1lf", output_percent); if (output_percent >= kHighlyOutfeedBoundThresholdInPercent) { *output_classification = "host"; *output_statement = absl::StrCat( - "Your program is HIGHLY output-bound because ", tc_outfeed_percent_str, + "Your program is HIGHLY output-bound because ", + OneDigit(output_percent), "% of the total step time sampled is spent on output. Therefore, you " "should first focus on reducing the output time."); } else if (output_percent >= kModeratelyOutfeedBoundThresholdInPercent) { *output_classification = "both"; *output_statement = absl::StrCat( "Your program is MODERATELY output-bound because ", - tc_outfeed_percent_str, + OneDigit(output_percent), "% of the total step time sampled is spent on output. Therefore, " "you would need to reduce both the output time and other time."); } else { diff --git a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc index 8f0e920c7e6..cf0b7e6ad43 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "google/protobuf/any.pb.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/convert/op_metrics_to_record.h" #include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h" @@ -32,6 +31,7 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" #include "tensorflow/core/profiler/protobuf/tf_function.pb.h" #include "tensorflow/core/profiler/utils/diagnostics.h" +#include "tensorflow/core/profiler/utils/format_utils.h" #include "tensorflow/core/profiler/utils/hardware_type_utils.h" #include "tensorflow/core/profiler/utils/html_utils.h" #include "tensorflow/core/profiler/utils/kernel_stats_utils.h" @@ -118,7 +118,7 @@ std::string GeneratePrecisionStatement(const PrecisionStats& precision_stats) { (100.0 * precision_stats.compute_16bit_ps()) / total_compute_ps; if (percent_16bit < kLowPrecisionPercentThreshold) { return absl::StrCat( - "Only ", absl::StrFormat("%.1lf", percent_16bit), + "Only ", OneDigit(percent_16bit), "% of device computation is 16 bit. So you might want to replace " "more 32-bit Ops by 16-bit Ops to improve performance (if the " "reduced accuracy is acceptable)."); @@ -338,12 +338,10 @@ std::string EagerRecommendationHtml(double host_op_time_eager_percent, double device_op_time_eager_percent) { std::string recommendation = ""; if (host_op_time_eager_percent > kEagerReportThresholdInPercent) - absl::StrAppend(&recommendation, - absl::StrFormat("%.1f", host_op_time_eager_percent), + absl::StrAppend(&recommendation, OneDigit(host_op_time_eager_percent), "% of Op time on the host used eager execution. "); if (device_op_time_eager_percent > kEagerReportThresholdInPercent) - absl::StrAppend(&recommendation, - absl::StrFormat("%.1f", device_op_time_eager_percent), + absl::StrAppend(&recommendation, OneDigit(device_op_time_eager_percent), "% of Op time on the device used eager execution. "); if (!recommendation.empty()) absl::StrAppend(&recommendation, "Performance could be improved with ", @@ -358,7 +356,7 @@ std::string OutsideCompilationRecommendationHtml( kOutsideCompilationThresholdInPercent) return ""; return absl::StrCat( - absl::StrFormat("%.1lf", device_op_time_outside_compilation_percent), + OneDigit(device_op_time_outside_compilation_percent), " % of Op time on the device are for outside compilation. Performance " "could be improved by avoiding outside compilation."); } diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_stats.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_stats.cc new file mode 100644 index 00000000000..9e18cb1446a --- /dev/null +++ b/tensorflow/core/profiler/convert/op_stats_to_pod_stats.cc @@ -0,0 +1,96 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/core/profiler/convert/op_stats_to_pod_stats.h"
+
+#include "google/protobuf/any.pb.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
+#include "tensorflow/core/profiler/utils/diagnostics.h"
+#include "tensorflow/core/profiler/utils/event_span.h"
+#include "tensorflow/core/profiler/utils/time_utils.h"
+
+namespace tensorflow {
+namespace profiler {
+
+namespace {
+
+PodStatsRecord CreatePodStatsRecord(absl::string_view host_name,
+                                    const StepInfoResult& step_info) {
+  PodStatsRecord record;
+  GenericStepBreakdown generic;
+  bool success = step_info.step_breakdown().UnpackTo(&generic);
+  DCHECK(success);
+  record.set_host_name(string(host_name));
+  record.set_step_num(step_info.step_num());
+  record.set_total_duration_us(PicosToMicros(step_info.duration_ps()));
+  auto& step_breakdown_map = *record.mutable_step_breakdown_us();
+  std::vector<std::pair<uint64, absl::string_view>> metrics;
+
+  auto add_event = [&](GenericEventType type,
+                       std::initializer_list<EventType> event_list) {
+    uint64 ps = 0;
+    for (const auto& event_type : event_list) {
+      ps += gtl::FindWithDefault(generic.type_ps(), event_type, /*value=*/0);
+    }
+    step_breakdown_map[type] = PicosToMicros(ps);
+    metrics.emplace_back(ps, GetGenericEventTypeStr(type));
+  };
+
+  add_event(kDeviceCompute, {DEVICE_COMPUTE_32, DEVICE_COMPUTE_16});
+  add_event(kDeviceToDevice, {DEVICE_TO_DEVICE, DEVICE_WAIT_DEVICE});
+  add_event(kDeviceCollectives, {DEVICE_COLLECTIVES});
+  add_event(kHostCompute, {HOST_COMPUTE});
+  add_event(kHostPrepare, {HOST_PREPARE});
+  add_event(kInput, {HOST_WAIT_INPUT, HOST_TO_DEVICE, DEVICE_WAIT_HOST});
+  add_event(kOutput, {DEVICE_TO_HOST});
+  add_event(kCompile, {HOST_COMPILE});
+  add_event(kAllOthers, {UNKNOWN_TIME});
+
+  std::sort(metrics.begin(), metrics.end());
+  record.set_bottleneck(metrics.back().second.data(),
+                        metrics.back().second.size());
+  return record;
+}
+
+}  // namespace
+
+PodStatsDatabase ConvertOpStatsToPodStats(const OpStats& op_stats) {
+  PodStatsDatabase pod_stats_db;
+  const auto& core_id_map = op_stats.core_id_to_details();
+  for (int i = GenericEventType::kFirstGenericEventType;
+       i <= GenericEventType::kLastGenericEventType; i++) {
+    auto& event = *pod_stats_db.add_step_breakdown_events();
+    event.set_id(i);
+    absl::string_view type_str =
+        GetGenericEventTypeStr(static_cast<GenericEventType>(i));
+    event.set_name(type_str.data(), type_str.size());
+  }
+
+  for (const auto& step_sequence : op_stats.step_db().step_sequence()) {
+    for (const auto& entry : step_sequence.step_info_per_core()) {
+      const CoreDetails& details = core_id_map.at(entry.first);
+      *pod_stats_db.add_pod_stats_record() =
+          CreatePodStatsRecord(details.hostname(), entry.second);
+    }
+  }
+  PopulateStepDiagnostics(op_stats, pod_stats_db.mutable_diagnostics());
+  return pod_stats_db;
+}
+
+}  // namespace profiler
+}  // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_stats.h b/tensorflow/core/profiler/convert/op_stats_to_pod_stats.h
new file mode 100644
index 00000000000..bd3d74068d8
--- /dev/null
+++ b/tensorflow/core/profiler/convert/op_stats_to_pod_stats.h
@@ -0,0 +1,30 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_STATS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_STATS_H_ + +#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/pod_stats.pb.h" + +namespace tensorflow { +namespace profiler { + +PodStatsDatabase ConvertOpStatsToPodStats(const OpStats& op_stats); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_STATS_H_ diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_stats_test.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_stats_test.cc new file mode 100644 index 00000000000..4f922791e5f --- /dev/null +++ b/tensorflow/core/profiler/convert/op_stats_to_pod_stats_test.cc @@ -0,0 +1,123 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/profiler/convert/op_stats_to_pod_stats.h" + +#include "google/protobuf/any.pb.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" +#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" +#include "tensorflow/core/profiler/utils/diagnostics.h" +#include "tensorflow/core/profiler/utils/event_span.h" +#include "tensorflow/core/profiler/utils/time_utils.h" + +namespace tensorflow { +namespace profiler { +namespace { + +const double kMaxError = 1e-6; +constexpr int kStepNum = 2; +constexpr int kCoreId = 1001; +constexpr int kStepTimePs = 1000; +constexpr int kHostComputePs = 50; +constexpr int kHostCompilePs = 50; +constexpr int kHostToHostPs = 50; +constexpr int kHostToDevicePs = 50; +constexpr int kHostPreparePs = 50; +constexpr int kDeviceCollectivePs = 350; +constexpr int kHostWaitInputPs = 50; +constexpr int kDeviceToDevicePs = 50; +constexpr int kDeviceToHostPs = 50; +constexpr int kDeviceCompute32Ps = 50; +constexpr int kDeviceCompute16Ps = 50; +constexpr int kDeviceWaitDevicePs = 50; +constexpr int kDeviceWaitHostPs = 50; +constexpr int kUnknownTimePs = 50; +static constexpr char kHostname[] = "host:123"; + +void CreateOpStats(OpStats* op_stats) { + PerCoreStepInfo* info = op_stats->mutable_step_db()->add_step_sequence(); + info->set_step_num(kStepNum); + StepInfoResult& step_info = (*info->mutable_step_info_per_core())[kCoreId]; + step_info.set_step_num(kStepNum); + step_info.set_duration_ps(kStepTimePs); + GenericStepBreakdown breakdown; + auto& type_ps = *breakdown.mutable_type_ps(); + type_ps[HOST_COMPUTE] = kHostComputePs; + type_ps[HOST_COMPILE] = kHostCompilePs; + type_ps[HOST_TO_HOST] = kHostToHostPs; + type_ps[HOST_TO_DEVICE] = kHostToDevicePs; + type_ps[HOST_PREPARE] = kHostPreparePs; + type_ps[DEVICE_COLLECTIVES] = kDeviceCollectivePs; + type_ps[HOST_WAIT_INPUT] = kHostWaitInputPs; + type_ps[DEVICE_TO_DEVICE] = kDeviceToDevicePs; + type_ps[DEVICE_TO_HOST] = kDeviceToHostPs; + type_ps[DEVICE_COMPUTE_32] = kDeviceCompute32Ps; + type_ps[DEVICE_COMPUTE_16] = kDeviceCompute16Ps; + type_ps[DEVICE_WAIT_DEVICE] = kDeviceWaitDevicePs; + type_ps[DEVICE_WAIT_HOST] = kDeviceWaitHostPs; + type_ps[UNKNOWN_TIME] = kUnknownTimePs; + step_info.mutable_step_breakdown()->PackFrom(breakdown); + CoreDetails& details = (*op_stats->mutable_core_id_to_details())[kCoreId]; + details.set_hostname(kHostname); +} + +TEST(OpStatsToPodStats, GpuPodStats) { + OpStats op_stats; + CreateOpStats(&op_stats); + PodStatsDatabase pod_stats_db = ConvertOpStatsToPodStats(op_stats); + EXPECT_EQ(1, pod_stats_db.pod_stats_record_size()); + const PodStatsRecord& record = pod_stats_db.pod_stats_record(0); + EXPECT_EQ(kStepNum, record.step_num()); + EXPECT_EQ(kHostname, record.host_name()); + EXPECT_NEAR(PicosToMicros(kStepTimePs), record.total_duration_us(), + kMaxError); + const auto& breakdown = record.step_breakdown_us(); + EXPECT_NEAR(PicosToMicros(kDeviceCompute32Ps + kDeviceCompute16Ps), + breakdown.at(kDeviceCompute), kMaxError); + EXPECT_NEAR(PicosToMicros(kDeviceToDevicePs + kDeviceWaitDevicePs), + breakdown.at(kDeviceToDevice), kMaxError); + EXPECT_NEAR(PicosToMicros(kDeviceCollectivePs), + breakdown.at(kDeviceCollectives), kMaxError); + 
EXPECT_NEAR(PicosToMicros(kHostComputePs), breakdown.at(kHostCompute), + kMaxError); + EXPECT_NEAR(PicosToMicros(kHostPreparePs), breakdown.at(kHostPrepare), + kMaxError); + EXPECT_NEAR( + PicosToMicros(kHostWaitInputPs + kHostToDevicePs + kDeviceWaitHostPs), + breakdown.at(kInput), kMaxError); + EXPECT_NEAR(PicosToMicros(kDeviceToHostPs), breakdown.at(kOutput), kMaxError); + EXPECT_NEAR(PicosToMicros(kHostCompilePs), breakdown.at(kCompile), kMaxError); + EXPECT_NEAR(PicosToMicros(kUnknownTimePs), breakdown.at(kAllOthers), + kMaxError); + + EXPECT_EQ(GetGenericEventTypeStr(kDeviceCollectives), record.bottleneck()); +} + +TEST(OpStatsToPodStats, Diagnostics) { + OpStats op_stats; + op_stats.mutable_step_db()->set_use_incomplete_step(true); + PodStatsDatabase pod_stats_db = ConvertOpStatsToPodStats(op_stats); + EXPECT_EQ(1, pod_stats_db.diagnostics().warnings_size()); + EXPECT_EQ(kErrorIncompleteStep, pod_stats_db.diagnostics().warnings(0)); +} + +} // namespace +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.cc new file mode 100644 index 00000000000..c6d2a95f26d --- /dev/null +++ b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.cc @@ -0,0 +1,62 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h" + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/profiler/convert/op_stats_to_pod_stats.h" +#include "tensorflow/core/profiler/protobuf/pod_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" +#include "tensorflow/core/profiler/utils/diagnostics.h" + +namespace tensorflow { +namespace profiler { +namespace { + +PodStatsSequence ConvertOpStatsToPodStatsSequence(const OpStats& op_stats, + PodStatsDatabase pod_stats) { + PodStatsSequence result_db; + // PodStatsDatabase is created using the same iteration order below. + // Thus, we just need to move one record at a time. 
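+  // `i` indexes pod_stats.pod_stats_record() linearly; ConvertOpStatsToPodStats
+  // filled that array by walking step_sequence() and step_info_per_core() in
+  // this same order, so record `i` corresponds to the current (step, core) pair.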
+ int i = 0; + for (const auto& step_sequence : op_stats.step_db().step_sequence()) { + PodStatsMap* pod_stats_map = result_db.add_pod_stats_map(); + pod_stats_map->set_step_num(step_sequence.step_num()); + for (const auto& entry : step_sequence.step_info_per_core()) { + PodStatsRecord& record = + (*pod_stats_map->mutable_pod_stats_per_core())[entry.first]; + DCHECK_LE(i, pod_stats.pod_stats_record_size()); + record = std::move(*pod_stats.mutable_pod_stats_record(i++)); + } + } + return result_db; +} + +} // namespace + +PodViewerDatabase ConvertOpStatsToPodViewer(const OpStats& op_stats) { + PodViewerDatabase database; + database.set_device_type(op_stats.run_environment().device_type()); + PodStatsDatabase pod_stats = ConvertOpStatsToPodStats(op_stats); + database.mutable_step_breakdown_events()->Swap( + pod_stats.mutable_step_breakdown_events()); + *database.mutable_pod_stats_sequence() = + ConvertOpStatsToPodStatsSequence(op_stats, std::move(pod_stats)); + PopulateStepDiagnostics(op_stats, database.mutable_diagnostics()); + return database; +} + +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h new file mode 100644 index 00000000000..c45c9939375 --- /dev/null +++ b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_VIEWER_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_VIEWER_H_ + +#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/pod_viewer.pb.h" + +namespace tensorflow { +namespace profiler { + +PodViewerDatabase ConvertOpStatsToPodViewer(const OpStats& op_stats); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_VIEWER_H_ diff --git a/tensorflow/core/profiler/convert/op_stats_to_pod_viewer_test.cc b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer_test.cc new file mode 100644 index 00000000000..c63edffbfce --- /dev/null +++ b/tensorflow/core/profiler/convert/op_stats_to_pod_viewer_test.cc @@ -0,0 +1,134 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h" + +#include "google/protobuf/any.pb.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" +#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/pod_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" +#include "tensorflow/core/profiler/utils/diagnostics.h" +#include "tensorflow/core/profiler/utils/event_span.h" +#include "tensorflow/core/profiler/utils/time_utils.h" + +namespace tensorflow { +namespace profiler { +namespace { + +const double kMaxError = 1e-6; +constexpr int kStepNum = 2; +constexpr int kCoreId = 1001; +constexpr int kStepTimePs = 1000; +constexpr int kHostComputePs = 50; +constexpr int kHostCompilePs = 50; +constexpr int kHostToHostPs = 50; +constexpr int kHostToDevicePs = 50; +constexpr int kHostPreparePs = 50; +constexpr int kDeviceCollectivePs = 350; +constexpr int kHostWaitInputPs = 50; +constexpr int kDeviceToDevicePs = 50; +constexpr int kDeviceToHostPs = 50; +constexpr int kDeviceCompute32Ps = 50; +constexpr int kDeviceCompute16Ps = 50; +constexpr int kDeviceWaitDevicePs = 50; +constexpr int kDeviceWaitHostPs = 50; +constexpr int kUnknownTimePs = 50; +static constexpr char kHostname[] = "host:123"; + +void CreateOpStats(OpStats* op_stats) { + PerCoreStepInfo* info = op_stats->mutable_step_db()->add_step_sequence(); + info->set_step_num(kStepNum); + StepInfoResult& step_info = (*info->mutable_step_info_per_core())[kCoreId]; + step_info.set_step_num(kStepNum); + step_info.set_duration_ps(kStepTimePs); + GenericStepBreakdown breakdown; + auto& type_ps = *breakdown.mutable_type_ps(); + type_ps[HOST_COMPUTE] = kHostComputePs; + type_ps[HOST_COMPILE] = kHostCompilePs; + type_ps[HOST_TO_HOST] = kHostToHostPs; + type_ps[HOST_TO_DEVICE] = kHostToDevicePs; + type_ps[HOST_PREPARE] = kHostPreparePs; + type_ps[DEVICE_COLLECTIVES] = kDeviceCollectivePs; + type_ps[HOST_WAIT_INPUT] = kHostWaitInputPs; + type_ps[DEVICE_TO_DEVICE] = kDeviceToDevicePs; + type_ps[DEVICE_TO_HOST] = kDeviceToHostPs; + type_ps[DEVICE_COMPUTE_32] = kDeviceCompute32Ps; + type_ps[DEVICE_COMPUTE_16] = kDeviceCompute16Ps; + type_ps[DEVICE_WAIT_DEVICE] = kDeviceWaitDevicePs; + type_ps[DEVICE_WAIT_HOST] = kDeviceWaitHostPs; + type_ps[UNKNOWN_TIME] = kUnknownTimePs; + step_info.mutable_step_breakdown()->PackFrom(breakdown); + CoreDetails& details = (*op_stats->mutable_core_id_to_details())[kCoreId]; + details.set_hostname(kHostname); +} + +TEST(OpStatsToPodViewer, GpuPodViewer) { + OpStats op_stats; + CreateOpStats(&op_stats); + PodViewerDatabase pod_viewer_db = ConvertOpStatsToPodViewer(op_stats); + EXPECT_EQ(1, pod_viewer_db.pod_stats_sequence().pod_stats_map_size()); + const PodStatsMap& pod_stats_map = + pod_viewer_db.pod_stats_sequence().pod_stats_map(0); + EXPECT_EQ(kStepNum, pod_stats_map.step_num()); + const PodStatsRecord& record = pod_stats_map.pod_stats_per_core().at(kCoreId); + EXPECT_EQ(kStepNum, record.step_num()); + EXPECT_EQ(kHostname, record.host_name()); + EXPECT_NEAR(PicosToMicros(kStepTimePs), record.total_duration_us(), + kMaxError); + const auto& breakdown = record.step_breakdown_us(); + EXPECT_NEAR(PicosToMicros(kDeviceCompute32Ps + kDeviceCompute16Ps), + breakdown.at(kDeviceCompute), kMaxError); + 
EXPECT_NEAR(PicosToMicros(kDeviceToDevicePs + kDeviceWaitDevicePs), + breakdown.at(kDeviceToDevice), kMaxError); + EXPECT_NEAR(PicosToMicros(kDeviceCollectivePs), + breakdown.at(kDeviceCollectives), kMaxError); + EXPECT_NEAR(PicosToMicros(kHostComputePs), breakdown.at(kHostCompute), + kMaxError); + EXPECT_NEAR(PicosToMicros(kHostPreparePs), breakdown.at(kHostPrepare), + kMaxError); + EXPECT_NEAR( + PicosToMicros(kHostWaitInputPs + kHostToDevicePs + kDeviceWaitHostPs), + breakdown.at(kInput), kMaxError); + EXPECT_NEAR(PicosToMicros(kDeviceToHostPs), breakdown.at(kOutput), kMaxError); + EXPECT_NEAR(PicosToMicros(kHostCompilePs), breakdown.at(kCompile), kMaxError); + EXPECT_NEAR(PicosToMicros(kUnknownTimePs), breakdown.at(kAllOthers), + kMaxError); + + EXPECT_EQ(GetGenericEventTypeStr(kDeviceCollectives), record.bottleneck()); +} + +TEST(OpStatsToPodViewer, Diagnostics) { + OpStats op_stats; + op_stats.mutable_step_db()->set_use_incomplete_step(true); + PodViewerDatabase pod_viewer_db = ConvertOpStatsToPodViewer(op_stats); + EXPECT_EQ(1, pod_viewer_db.diagnostics().warnings_size()); + EXPECT_EQ(kErrorIncompleteStep, pod_viewer_db.diagnostics().warnings(0)); +} + +TEST(OpStatsToPodViewer, DeviceType) { + OpStats op_stats; + op_stats.mutable_run_environment()->set_device_type("GPU"); + PodViewerDatabase pod_viewer_db = ConvertOpStatsToPodViewer(op_stats); + EXPECT_EQ("GPU", pod_viewer_db.device_type()); +} + +} // namespace +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc b/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc index b379c2d2f5e..581a003eb38 100644 --- a/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc +++ b/tensorflow/core/profiler/convert/post_process_single_host_xplane.cc @@ -14,13 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/profiler/convert/post_process_single_host_xplane.h" -#if !defined(IS_MOBILE_PLATFORM) -#include "tensorflow/core/profiler/internal/profiler_factory.h" #include "tensorflow/core/profiler/utils/derived_timeline.h" #include "tensorflow/core/profiler/utils/group_events.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" -#endif namespace tensorflow { namespace profiler { @@ -59,10 +56,10 @@ void PostProcessSingleHostXSpace(XSpace* space, uint64 start_time_ns) { // 3. Sort each plane of the XSpace SortXSpace(space); // 4. Grouping (i.e. marking step number) events in the XSpace. - GroupMetadataMap group_metadata_map; - GroupTfEvents(space, &group_metadata_map); + EventForest event_forest; + GroupTfEvents(space, &event_forest); // 5. Generated miscellaneous derived time lines for device planes. - GenerateDerivedTimeLines(group_metadata_map, space); + GenerateDerivedTimeLines(event_forest.GetGroupMetadataMap(), space); } } // namespace profiler diff --git a/tensorflow/core/profiler/convert/step_events_to_steps_db.h b/tensorflow/core/profiler/convert/step_events_to_steps_db.h index bc2927f2df9..ca01e87c4d7 100644 --- a/tensorflow/core/profiler/convert/step_events_to_steps_db.h +++ b/tensorflow/core/profiler/convert/step_events_to_steps_db.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_STEP_EVENTS_TO_STEPS_DB_H_ #define TENSORFLOW_CORE_PROFILER_CONVERT_STEP_EVENTS_TO_STEPS_DB_H_ +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" #include "tensorflow/core/profiler/utils/event_span.h" @@ -23,7 +24,7 @@ limitations under the License. namespace tensorflow { namespace profiler { -ABSL_CONST_INIT extern const uint32 kDefaultGpuLocalCoreId; +TF_CONST_INIT extern const uint32 kDefaultGpuLocalCoreId; // Converts from overlapped Step-Events to StepDatabaseResult. StepDatabaseResult ConvertStepEventsToStepDb( diff --git a/tensorflow/core/profiler/convert/trace_events_to_json.cc b/tensorflow/core/profiler/convert/trace_events_to_json.cc index ad40292ceff..658f6d0c687 100644 --- a/tensorflow/core/profiler/convert/trace_events_to_json.cc +++ b/tensorflow/core/profiler/convert/trace_events_to_json.cc @@ -16,113 +16,110 @@ limitations under the License. #include "tensorflow/core/profiler/convert/trace_events_to_json.h" #include -#include +#include #include +#include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" #include "json/json.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" +#include "tensorflow/core/profiler/utils/format_utils.h" +#include "tensorflow/core/profiler/utils/time_utils.h" namespace tensorflow { namespace profiler { namespace { -constexpr double kPicosPerMicro = 1000000.0; - -inline void AppendEscapedName(string *json, const string &name) { - absl::StrAppend(json, "\"name\":", Json::valueToQuotedString(name.c_str())); +// Converts the given time from picoseconds to microseconds and then to a string +// using maximum precision. +inline std::string PicosToMicrosString(uint64 ps) { + return MaxPrecision(PicosToMicros(ps)); } -// Adds resource events for a single device. -void AddResourceMetadata(uint32 device_id, - const std::map &resources, - string *json) { - for (const auto &pair : resources) { - uint32 resource_id = pair.first; - const Resource &resource = *pair.second; - if (!resource.name().empty()) { - absl::StrAppendFormat(json, - R"({"ph":"M","pid":%u,"tid":%u,)" - R"("name":"thread_name","args":{)", - device_id, resource_id); - AppendEscapedName(json, resource.name()); - absl::StrAppend(json, "}},"); - } - uint32 sort_index = - resource.sort_index() ? resource.sort_index() : resource_id; - absl::StrAppendFormat( - json, - R"({"ph":"M","pid":%u,"tid":%u,)" - R"("name":"thread_sort_index","args":{"sort_index":%u}},)", - device_id, resource_id, sort_index); +// Escapes and quotes the given string. +inline std::string JsonString(const std::string& s) { + return Json::valueToQuotedString(s.c_str()); +} + +// Returns a vector of pointers to the elements in the given map, sorted by key. 
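+// Proto map iteration order is nondeterministic, so sorting by key keeps the
+// emitted JSON stable across runs (the removed code achieved the same effect by
+// copying into a std::map).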
+template +std::vector SortByKey(const Map& m) { + std::vector pairs; + pairs.reserve(m.size()); + for (const auto& pair : m) { + pairs.push_back(&pair); } + absl::c_sort(pairs, [](const typename Map::value_type* a, + const typename Map::value_type* b) { + return a->first < b->first; + }); + return pairs; } -void AddDeviceMetadata(const std::map &devices, - string *json) { - for (const auto &pair : devices) { - uint32 device_id = pair.first; - const Device &device = *pair.second; - if (!device.name().empty()) { - absl::StrAppendFormat(json, - R"({"ph":"M","pid":%u,"name":"process_name",)" - R"("args":{)", - device_id); - AppendEscapedName(json, device.name()); - absl::StrAppend(json, "}},"); - } - absl::StrAppendFormat(json, - R"({"ph":"M","pid":%u,"name":"process_sort_index",)" - R"("args":{"sort_index":%u}},)", - device_id, device_id); - // Convert to a std::map so that resources are sorted by the device id. - std::map sorted_resources; - for (const auto &pair : device.resources()) { - sorted_resources[pair.first] = &pair.second; - } - AddResourceMetadata(device_id, sorted_resources, json); +inline void AddDeviceMetadata(uint32 device_id, const Device& device, + std::string* json) { + if (!device.name().empty()) { + absl::StrAppend(json, R"({"ph":"M","pid":)", device_id, + R"(,"name":"process_name","args":{"name":)", + JsonString(device.name()), "}},"); } + absl::StrAppend(json, R"({"ph":"M","pid":)", device_id, + R"(,"name":"process_sort_index","args":{"sort_index":)", + device_id, "}},"); } -inline void AddTraceEvent(const TraceEvent &event, string *json) { - absl::StrAppendFormat(json, R"({"pid":%u,"tid":%u,"ts":%.17g,)", - event.device_id(), event.resource_id(), - event.timestamp_ps() / kPicosPerMicro); - AppendEscapedName(json, event.name()); - absl::StrAppend(json, ","); - uint64 duration_ps = - std::max(static_cast(event.duration_ps()), uint64{1}); - absl::StrAppendFormat(json, R"("ph":"X","dur":%.17g)", - duration_ps / kPicosPerMicro); +inline void AddResourceMetadata(uint32 device_id, uint32 resource_id, + const Resource& resource, std::string* json) { + if (!resource.name().empty()) { + absl::StrAppend(json, R"({"ph":"M","pid":)", device_id, R"(,"tid":)", + resource_id, R"(,"name":"thread_name","args":{"name":)", + JsonString(resource.name()), "}},"); + } + uint32 sort_index = + resource.sort_index() ? resource.sort_index() : resource_id; + absl::StrAppend(json, R"({"ph":"M","pid":)", device_id, R"(,"tid":)", + resource_id, R"(,"name":"thread_sort_index")", + R"(,"args":{"sort_index":)", sort_index, "}},"); +} + +inline void AddTraceEvent(const TraceEvent& event, string* json) { + auto duration_ps = std::max(event.duration_ps(), protobuf_uint64{1}); + absl::StrAppend(json, R"({"ph":"X","pid":)", event.device_id(), R"(,"tid":)", + event.resource_id(), R"(,"ts":)", + PicosToMicrosString(event.timestamp_ps()), R"(,"dur":)", + PicosToMicrosString(duration_ps), R"(,"name":)", + JsonString(event.name())); if (!event.args().empty()) { - std::map sorted_args(event.args().begin(), - event.args().end()); absl::StrAppend(json, R"(,"args":{)"); - for (const auto &arg : sorted_args) { - absl::StrAppend(json, Json::valueToQuotedString(arg.first.c_str()), ":", - Json::valueToQuotedString(arg.second.c_str()), ","); + for (const auto* arg : SortByKey(event.args())) { + absl::StrAppend(json, JsonString(arg->first), ":", + JsonString(arg->second), ","); } - // Removes the trailing comma. - json->pop_back(); - absl::StrAppend(json, "}"); + // Replace trailing comma with closing brace. 
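+    // The args loop above appended a trailing ',' after the last key/value
+    // pair, so the last character can simply be overwritten.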
+ json->back() = '}'; } absl::StrAppend(json, "},"); } } // namespace -string TraceEventsToJson(const Trace &trace) { - string json = R"({"displayTimeUnit":"ns","metadata":{"highres-ticks":true},)" - R"("traceEvents":[)"; - // Convert to a std::map so that devices are sorted by the device id. - std::map sorted_devices; - for (const auto &pair : trace.devices()) { - sorted_devices[pair.first] = &pair.second; +std::string TraceEventsToJson(const Trace& trace) { + std::string json = + R"({"displayTimeUnit":"ns","metadata":{"highres-ticks":true},)" + R"("traceEvents":[)"; + for (const auto* id_and_device : SortByKey(trace.devices())) { + uint32 device_id = id_and_device->first; + const Device& device = id_and_device->second; + AddDeviceMetadata(device_id, device, &json); + for (const auto* id_and_resource : SortByKey(device.resources())) { + uint32 resource_id = id_and_resource->first; + const Resource& resource = id_and_resource->second; + AddResourceMetadata(device_id, resource_id, resource, &json); + } } - AddDeviceMetadata(sorted_devices, &json); - for (const TraceEvent &event : trace.trace_events()) { + for (const TraceEvent& event : trace.trace_events()) { AddTraceEvent(event, &json); } // Add one fake event to avoid dealing with no-trailing-comma rule. diff --git a/tensorflow/core/profiler/convert/trace_events_to_json.h b/tensorflow/core/profiler/convert/trace_events_to_json.h index 16747fec737..15db9e2d954 100644 --- a/tensorflow/core/profiler/convert/trace_events_to_json.h +++ b/tensorflow/core/profiler/convert/trace_events_to_json.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_EVENTS_TO_JSON_H_ #define TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_EVENTS_TO_JSON_H_ +#include + #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" @@ -24,7 +26,7 @@ namespace profiler { // Converts trace events in the trace proto to a JSON string that can be // consumed by catapult trace viewer. -string TraceEventsToJson(const Trace &trace); +std::string TraceEventsToJson(const Trace& trace); } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/trace_events_to_json_test.cc b/tensorflow/core/profiler/convert/trace_events_to_json_test.cc index bf08a19e022..2181197e3fc 100644 --- a/tensorflow/core/profiler/convert/trace_events_to_json_test.cc +++ b/tensorflow/core/profiler/convert/trace_events_to_json_test.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/core/profiler/convert/trace_events_to_json.h" +#include + #include "json/json.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" @@ -24,13 +26,13 @@ namespace tensorflow { namespace profiler { namespace { -string ConvertTextFormattedTraceToJson(const string& trace_str) { +std::string ConvertTextFormattedTraceToJson(const std::string& trace_str) { Trace trace; - ::tensorflow::protobuf::TextFormat::ParseFromString(trace_str, &trace); + EXPECT_TRUE(protobuf::TextFormat::ParseFromString(trace_str, &trace)); return TraceEventsToJson(trace); } -Json::Value ToJsonValue(const string& json_str) { +Json::Value ToJsonValue(const std::string& json_str) { Json::Value json; Json::Reader reader; EXPECT_TRUE(reader.parse(json_str, json)); @@ -38,47 +40,48 @@ Json::Value ToJsonValue(const string& json_str) { } TEST(TraceEventsToJson, JsonConversion) { - string json_output = ConvertTextFormattedTraceToJson(R"( - devices { key: 2 value { + std::string json_output = ConvertTextFormattedTraceToJson(R"proto( + devices { + key: 2 + value { name: 'D2' device_id: 2 - resources { key: 2 value { - resource_id: 2 - name: 'R2.2' - } } - } } - devices { key: 1 value { + resources { + key: 2 + value { resource_id: 2 name: 'R2.2' } + } + } + } + devices { + key: 1 + value { name: 'D1' device_id: 1 - resources { key: 2 value { - resource_id: 1 - name: 'R1.2' - } } - } } + resources { + key: 2 + value { resource_id: 1 name: 'R1.2' } + } + } + } + trace_events { + device_id: 1 + resource_id: 2 + name: 'E1.2.1' + timestamp_ps: 100000 + duration_ps: 10000 + args { key: 'long_name' value: 'E1.2.1 long' } + args { key: 'arg2' value: 'arg2 val' } + } + trace_events { + device_id: 2 + resource_id: 2 + name: 'E2.2.1 # "comment"' + timestamp_ps: 105000 + } + )proto"); + Json::Value json = ToJsonValue(json_output); - trace_events { - device_id: 1 - resource_id: 2 - name: 'E1.2.1' - timestamp_ps: 100000 - duration_ps: 10000 - args { - key: 'long_name' - value: 'E1.2.1 long' - } - args { - key: 'arg2' - value: 'arg2 val' - } - } - trace_events { - device_id: 2 - resource_id: 2 - name: 'E2.2.1 # "comment"' - timestamp_ps: 105000 - } - )"); - string expected_json = R"( + Json::Value expected_json = ToJsonValue(R"( { "displayTimeUnit": "ns", "metadata": { "highres-ticks": true }, @@ -95,7 +98,6 @@ TEST(TraceEventsToJson, JsonConversion) { "args":{"name":"R2.2"}}, {"ph":"M", "pid":2, "tid":2, "name":"thread_sort_index", "args":{"sort_index":2}}, - { "ph" : "X", "pid" : 1, @@ -115,8 +117,9 @@ TEST(TraceEventsToJson, JsonConversion) { }, {} ] - })"; - EXPECT_EQ(ToJsonValue(json_output), ToJsonValue(expected_json)); + })"); + + EXPECT_EQ(json, expected_json); } } // namespace diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc b/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc index 3b67124ef27..4bd1aa9d49c 100644 --- a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc +++ b/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/strings/str_format.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/core/framework/types.h" @@ -153,7 +153,7 @@ MemoryProfile GenerateMemoryProfile(const XPlane* host_trace) { switch (stat.Type().value()) { case StatType::kIndexOnHost: case StatType::kDeviceOrdinal: - memory_id = absl::StrFormat("%d", stat.IntValue()); + memory_id = absl::StrCat(stat.IntValue()); break; case StatType::kAllocatorName: memory_id = std::string(stat.StrOrRefValue()); diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc new file mode 100644 index 00000000000..7abc8ac37e3 --- /dev/null +++ b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc @@ -0,0 +1,262 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h" + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/profiler/utils/group_events.h" +#include "tensorflow/core/profiler/utils/tf_op_utils.h" +#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h" +#include "tensorflow/core/profiler/utils/timespan.h" +#include "tensorflow/core/profiler/utils/xplane_schema.h" +#include "tensorflow/core/profiler/utils/xplane_visitor.h" + +namespace tensorflow { +namespace profiler { +namespace { + +// Returns true if the given iterator event is for a root iterator. +bool IsRootIteratorEvent(const XEventVisitor& iterator_event) { + std::vector split_result = + absl::StrSplit(iterator_event.Name(), "::"); + // The root iterator's name contains only its own name (no parent + // information). + return split_result.size() == 2; +} + +// Returns true if the given iterator event name is for an async iterator. 
+bool IsAsyncIterator(absl::string_view iterator_event_name) { + static auto* kAsyncIterators = new absl::flat_hash_set( + {"Prefetch", "ParallelInterleave", "ParallelMap", "ParseExample", + "MapAndBatch", "DataService", "LegacyParallelInterleave"}); + return kAsyncIterators->contains(iterator_event_name); +} + +void SetIteratorMetadata(int64 id, const XEventVisitor& event, + IteratorMetadata* metadata) { + metadata->set_id(id); + auto parent_id_stat = event.GetStat(StatType::kParentId); + if (parent_id_stat.has_value()) { + metadata->set_parent_id(parent_id_stat->IntValue()); + } + metadata->set_name(IteratorName(event.Name())); + metadata->set_long_name(event.Name().data(), event.Name().size()); + metadata->set_is_async(IsAsyncIterator(metadata->name())); + // TODO(b/161831651): Set params. +} + +// Returns the parent iterator's id if it is a root of a device input +// pipeline. +absl::optional FindDeviceInputPipeline(const XEventVisitor& event) { + if (event.Type() == HostEventType::kDeviceInputPipelineSecondIterator) { + auto parent_id_stat = event.GetStat(StatType::kParentId); + if (parent_id_stat.has_value()) return parent_id_stat->IntValue(); + } + return absl::nullopt; +} + +// Processes EventForest to do the following: +// (1) set iterator metadata +// (2) find root iterator events +// (3) find device input pipeline ids +void ProcessEventForest(const EventForest& event_forest, + absl::flat_hash_set* device_input_pipeline_ids, + absl::flat_hash_map>* + root_iterator_event_map, + TfDataStats* tf_data_stats) { + const EventNodeMap& event_node_map = event_forest.GetEventNodeMap(); + auto iterator_event_list = + gtl::FindOrNull(event_node_map, HostEventType::kIterator); + if (!iterator_event_list) return; + for (const auto& iterator_event : *iterator_event_list) { + const XEventVisitor& iterator_event_visitor = + iterator_event->GetEventVisitor(); + auto iterator_id_stat = iterator_event_visitor.GetStat(StatType::kStepId); + if (!iterator_id_stat.has_value()) continue; + int64 iterator_id = iterator_id_stat->IntValue(); + auto result = tf_data_stats->mutable_iterator_metadata()->insert( + {iterator_id, IteratorMetadata()}); + IteratorMetadata& metadata = result.first->second; + if (result.second) { + // First time processing this iterator. + SetIteratorMetadata(iterator_id, iterator_event_visitor, &metadata); + } + if (IsRootIteratorEvent(iterator_event_visitor)) { + // Record root iterator events. + (*root_iterator_event_map)[iterator_id].push_back(iterator_event.get()); + } + } + auto device_input_pipeline_second_iterator_events = gtl::FindOrNull( + event_node_map, HostEventType::kDeviceInputPipelineSecondIterator); + if (!device_input_pipeline_second_iterator_events) return; + for (const auto& iterator_event : + *device_input_pipeline_second_iterator_events) { + const XEventVisitor& iterator_event_visitor = + iterator_event->GetEventVisitor(); + auto iterator_id_stat = iterator_event_visitor.GetStat(StatType::kStepId); + if (!iterator_id_stat.has_value()) continue; + int64 iterator_id = iterator_id_stat->IntValue(); + auto result = tf_data_stats->mutable_iterator_metadata()->insert( + {iterator_id, IteratorMetadata()}); + IteratorMetadata& metadata = result.first->second; + if (result.second) { + // First time processing this iterator. + SetIteratorMetadata(iterator_id, iterator_event_visitor, &metadata); + // Find and record device input pipeline ids. 
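+      // The parent of a device input pipeline's second iterator is the root of
+      // that pipeline; remembering its id lets ProcessInputPipelines classify
+      // the root as a DEVICE (rather than HOST) pipeline later.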
+ absl::optional device_input_pipeline_id = + FindDeviceInputPipeline(iterator_event_visitor); + if (device_input_pipeline_id.has_value()) { + device_input_pipeline_ids->insert(*device_input_pipeline_id); + } + } + } +} + +void SetInputPipelineMetadata(int64 id, uint64 name_id, + bool is_device_input_pipeline, + InputPipelineMetadata* metadata) { + constexpr absl::string_view kHostInputPipelinePrefix = "Host:"; + constexpr absl::string_view kDeviceInputPipelinePrefix = "Device:"; + metadata->set_id(id); + if (is_device_input_pipeline) { + metadata->set_type(InputPipelineMetadata::DEVICE); + metadata->set_name(absl::StrCat(kDeviceInputPipelinePrefix, name_id)); + } else { + metadata->set_type(InputPipelineMetadata::HOST); + metadata->set_name(absl::StrCat(kHostInputPipelinePrefix, name_id)); + } +} + +void ProcessIteratorEvent(const EventNode& iterator_event, + InputPipelineStat* input_pipeline_stat, + bool is_blocking) { + const XEventVisitor& visitor = iterator_event.GetEventVisitor(); + auto iterator_id_stat = visitor.GetStat(StatType::kStepId); + if (!iterator_id_stat.has_value()) return; + int64 iterator_id = iterator_id_stat->IntValue(); + auto result = input_pipeline_stat->mutable_iterator_stats()->insert( + {iterator_id, IteratorStat()}); + IteratorStat& iterator_stat = result.first->second; + if (result.second) { + iterator_stat.set_id(iterator_id); + iterator_stat.set_start_time_ps(visitor.TimestampPs()); + } + iterator_stat.set_duration_ps(iterator_stat.duration_ps() + + visitor.DurationPs()); + int64 self_time_ps = visitor.DurationPs(); + tensorflow::profiler::Timespan self_time_span = visitor.GetTimespan(); + for (EventNode* child : iterator_event.GetChildren()) { + const XEventVisitor& child_visitor = child->GetEventVisitor(); + if (ParseTfOpFullname(child_visitor.Name()).category == Category::kTfData) { + int64 overlap_duration_ps = + self_time_span.OverlappedDurationPs(child_visitor.GetTimespan()); + ProcessIteratorEvent(*child, input_pipeline_stat, + is_blocking && overlap_duration_ps); + // Note: Assume no overlap between child events. 
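+      // Subtracting each tf.data child's overlapped duration leaves only the
+      // time this iterator spent in its own code (its self time).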
+ self_time_ps -= overlap_duration_ps; + } + } + iterator_stat.set_self_time_ps(iterator_stat.self_time_ps() + self_time_ps); + iterator_stat.set_is_blocking(iterator_stat.is_blocking() || is_blocking); + iterator_stat.set_num_calls(iterator_stat.num_calls() + 1); +} + +void SetBottleneckIteratorId(InputPipelineStat* input_pipeline_stat) { + int64 bottleneck_iterator_id = 0; + int64 max_self_time = 0; + for (const auto& pair : input_pipeline_stat->iterator_stats()) { + const auto& id = pair.first; + const auto& iterator_stat = pair.second; + if (iterator_stat.is_blocking() && + iterator_stat.self_time_ps() > max_self_time) { + bottleneck_iterator_id = id; + max_self_time = iterator_stat.self_time_ps(); + } + } + input_pipeline_stat->set_bottleneck_iterator_id(bottleneck_iterator_id); +} + +void ProcessInputPipelines( + const absl::flat_hash_set& device_input_pipeline_ids, + absl::flat_hash_map>* + root_iterator_event_map, + TfDataStats* tf_data_stats) { + auto* input_pipelines = tf_data_stats->mutable_input_pipelines(); + uint64 num_host_input_pipelines = 0; + uint64 num_device_input_pipelines = 0; + for (auto& id_and_events : *root_iterator_event_map) { + auto& root_iterator_id = id_and_events.first; + auto& root_iterator_events = id_and_events.second; + absl::c_sort(root_iterator_events, + [](const EventNode* lhs, const EventNode* rhs) { + return lhs->GetEventVisitor().DurationPs() > + rhs->GetEventVisitor().DurationPs(); + }); + auto result = + input_pipelines->insert({root_iterator_id, InputPipelineStats()}); + InputPipelineStats& input_pipeline_stats = result.first->second; + InputPipelineMetadata* metadata = input_pipeline_stats.mutable_metadata(); + if (result.second) { + bool is_device_input_pipeline = + device_input_pipeline_ids.contains(root_iterator_id); + uint64 name_id = is_device_input_pipeline ? 
num_device_input_pipelines++ + : num_host_input_pipelines++; + SetInputPipelineMetadata(root_iterator_id, name_id, + is_device_input_pipeline, metadata); + } + uint64 sum_latency_ps = 0; + uint64 min_latency_ps = UINT64_MAX; + uint64 max_latency_ps = 0; + for (const EventNode* root_iterator_event : root_iterator_events) { + InputPipelineStat* stat = input_pipeline_stats.add_stats(); + ProcessIteratorEvent(*root_iterator_event, stat, + /*is_blocking*/ true); + SetBottleneckIteratorId(stat); + uint64 latency_ps = root_iterator_event->GetEventVisitor().DurationPs(); + sum_latency_ps += latency_ps; + min_latency_ps = std::min(min_latency_ps, latency_ps); + max_latency_ps = std::max(max_latency_ps, latency_ps); + } + input_pipeline_stats.set_avg_latency_ps(sum_latency_ps / + root_iterator_events.size()); + input_pipeline_stats.set_min_latency_ps(min_latency_ps); + input_pipeline_stats.set_max_latency_ps(max_latency_ps); + } +} + +} // namespace + +TfDataStats ConvertXPlaneToTfDataStats(XPlane* host_plane) { + TfDataStats tf_data_stats; + EventForest event_forest; + event_forest.AddPlanes(CreateTfXPlaneVisitor, {host_plane}); + event_forest.ConnectEvents(); + event_forest.ConnectTfDataEvents(); + absl::flat_hash_set device_input_pipeline_ids; + absl::flat_hash_map> root_iterator_event_map; + ProcessEventForest(event_forest, &device_input_pipeline_ids, + &root_iterator_event_map, &tf_data_stats); + ProcessInputPipelines(device_input_pipeline_ids, &root_iterator_event_map, + &tf_data_stats); + return tf_data_stats; +} + +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h new file mode 100644 index 00000000000..486c198d735 --- /dev/null +++ b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_DATA_STATS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_DATA_STATS_H_ + +#include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +TfDataStats ConvertXPlaneToTfDataStats(XPlane* host_plane); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_DATA_STATS_H_ diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats_test.cc new file mode 100644 index 00000000000..5b597227c83 --- /dev/null +++ b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats_test.cc @@ -0,0 +1,356 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h" + +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tensorflow/core/profiler/utils/xplane_builder.h" +#include "tensorflow/core/profiler/utils/xplane_schema.h" +#include "tensorflow/core/profiler/utils/xplane_test_utils.h" + +namespace tensorflow { +namespace profiler { +namespace { + +using ::testing::EqualsProto; + +// Test with the following example dataset: +// dataset = tf.data.Dataset.range(8) +// dataset = dataset.prefetch(2) +// for _ in dataset: +// pass +TEST(XPlaneToTfDataStatsTest, HostInputPipeline) { + constexpr int64 kPrefetchIteratorId = 123; + constexpr int64 kRangeIteratorId = 456; + constexpr int64 kFirstElementId = 100; + constexpr int64 kSecondElementId = 200; + + XPlane host_plane; + XPlaneBuilder host_plane_builder(&host_plane); + host_plane_builder.ReserveLines(2); + + auto consumer_thread = host_plane_builder.GetOrCreateLine(0); + CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::Prefetch", 0, + 100, {{StatType::kStepId, kPrefetchIteratorId}}); + CreateXEvent(&host_plane_builder, &consumer_thread, + HostEventType::kPrefetchConsume, 80, 20, + {{StatType::kElementId, kFirstElementId}}); + CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::Prefetch", 200, + 20, {{StatType::kStepId, kPrefetchIteratorId}}); + CreateXEvent(&host_plane_builder, &consumer_thread, + HostEventType::kPrefetchConsume, 210, 10, + {{StatType::kElementId, kSecondElementId}}); + + auto producer_thread = host_plane_builder.GetOrCreateLine(1); + // Blocking producer. + CreateXEvent(&host_plane_builder, &producer_thread, + HostEventType::kPrefetchProduce, 0, 80, + {{StatType::kElementId, kFirstElementId}}); + CreateXEvent(&host_plane_builder, &producer_thread, + "Iterator::Prefetch::Range", 0, 80, + {{StatType::kStepId, kRangeIteratorId}, + {StatType::kParentId, kPrefetchIteratorId}}); + // Non-blocking producer. 
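+  // The second element is produced during [100, 180) ps but not consumed
+  // until [210, 220) ps, after production has already finished, so the
+  // consumer never waits on it; the expected proto below therefore reports
+  // the Range iterator of this call with is_blocking: false, unlike the
+  // first element above, which is produced while the consumer is waiting.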
+ CreateXEvent(&host_plane_builder, &producer_thread, + HostEventType::kPrefetchProduce, 100, 80, + {{StatType::kElementId, kSecondElementId}}); + CreateXEvent(&host_plane_builder, &producer_thread, + "Iterator::Prefetch::Range", 100, 80, + {{StatType::kStepId, kRangeIteratorId}, + {StatType::kParentId, kPrefetchIteratorId}}); + + TfDataStats tf_data_stats = ConvertXPlaneToTfDataStats(&host_plane); + EXPECT_THAT(tf_data_stats, EqualsProto(R"pb( + iterator_metadata: { + key: 123, + value: { + id: 123 + name: "Prefetch" + long_name: "Iterator::Prefetch" + is_async: true + } + } + iterator_metadata: { + key: 456, + value: { + id: 456 + parent_id: 123 + name: "Range" + long_name: "Iterator::Prefetch::Range" + is_async: false + } + } + input_pipelines { + key: 123, + value: { + metadata { id: 123 type: HOST name: "Host:0" } + avg_latency_ps: 60 + min_latency_ps: 20 + max_latency_ps: 100 + stats { + bottleneck_iterator_id: 456 + iterator_stats { + key: 123, + value: { + id: 123 + start_time_ps: 0 + duration_ps: 100 + self_time_ps: 20 + is_blocking: true + num_calls: 1 + } + } + iterator_stats { + key: 456, + value: { + id: 456 + start_time_ps: 0 + duration_ps: 80 + self_time_ps: 80 + is_blocking: true + num_calls: 1 + } + } + } + stats { + bottleneck_iterator_id: 123 + iterator_stats { + key: 123, + value: { + id: 123 + start_time_ps: 200 + duration_ps: 20 + self_time_ps: 20 + is_blocking: true + num_calls: 1 + } + } + iterator_stats { + key: 456, + value: { + id: 456 + start_time_ps: 100 + duration_ps: 80 + self_time_ps: 80 + is_blocking: false + num_calls: 1 + } + } + } + } + } + )pb")); +} + +TEST(XPlaneToTfDataStatsTest, DeviceInputPipeline) { + constexpr int64 kPrefetchIteratorId = 123; + constexpr int64 kRangeIteratorId = 456; + constexpr int64 kElementId = 100; + + XPlane host_plane; + XPlaneBuilder host_plane_builder(&host_plane); + host_plane_builder.ReserveLines(2); + + auto consumer_thread = host_plane_builder.GetOrCreateLine(0); + CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::Prefetch", 0, + 30, {{StatType::kStepId, kPrefetchIteratorId}}); + CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::Prefetch", 100, + 100, {{StatType::kStepId, kPrefetchIteratorId}}); + CreateXEvent(&host_plane_builder, &consumer_thread, + HostEventType::kPrefetchConsume, 180, 20, + {{StatType::kElementId, kElementId}}); + + auto producer_thread = host_plane_builder.GetOrCreateLine(1); + CreateXEvent(&host_plane_builder, &producer_thread, + HostEventType::kPrefetchProduce, 100, 80, + {{StatType::kElementId, kElementId}}); + CreateXEvent(&host_plane_builder, &producer_thread, + "Iterator::Prefetch::Generator", 100, 80, + {{StatType::kStepId, kRangeIteratorId}, + {StatType::kParentId, kPrefetchIteratorId}}); + + TfDataStats tf_data_stats = ConvertXPlaneToTfDataStats(&host_plane); + EXPECT_THAT(tf_data_stats, EqualsProto(R"pb( + iterator_metadata: { + key: 123, + value: { + id: 123 + name: "Prefetch" + long_name: "Iterator::Prefetch" + is_async: true + } + } + iterator_metadata: { + key: 456, + value: { + id: 456 + parent_id: 123 + name: "Generator" + long_name: "Iterator::Prefetch::Generator" + is_async: false + } + } + input_pipelines { + key: 123, + value: { + metadata { id: 123 type: DEVICE name: "Device:0" } + avg_latency_ps: 65 + min_latency_ps: 30 + max_latency_ps: 100 + stats { + bottleneck_iterator_id: 456 + iterator_stats { + key: 123, + value: { + id: 123 + start_time_ps: 100 + duration_ps: 100 + self_time_ps: 20 + is_blocking: true + num_calls: 1 + } + } + iterator_stats { 
+ key: 456, + value: { + id: 456 + start_time_ps: 100 + duration_ps: 80 + self_time_ps: 80 + is_blocking: true + num_calls: 1 + } + } + } + stats { + bottleneck_iterator_id: 123 + iterator_stats { + key: 123, + value: { + id: 123 + start_time_ps: 0 + duration_ps: 30 + self_time_ps: 30 + is_blocking: true + num_calls: 1 + } + } + } + } + } + )pb")); +} + +// Test with the following example dataset: +// dataset = tf.data.Dataset.range(8) +// dataset = dataset.map(lambda x: x + 1) +// dataset = dataset.batch(2) +// for _ in dataset: +// pass +TEST(XPlaneToTfDataStatsTest, MapAndBatch) { + constexpr int64 kMapAndBatchIteratorId = 123; + constexpr int64 kRangeIteratorId = 456; + constexpr int64 kElementId = 100; + + XPlane host_plane; + XPlaneBuilder host_plane_builder(&host_plane); + host_plane_builder.ReserveLines(2); + + XLineBuilder consumer_thread = host_plane_builder.GetOrCreateLine(0); + CreateXEvent(&host_plane_builder, &consumer_thread, "Iterator::MapAndBatch", + 0, 100, {{StatType::kStepId, kMapAndBatchIteratorId}}); + CreateXEvent(&host_plane_builder, &consumer_thread, + HostEventType::kMapAndBatchConsume, 80, 20, + {{StatType::kElementId, kElementId}}); + + XLineBuilder producer_thread = host_plane_builder.GetOrCreateLine(1); + CreateXEvent(&host_plane_builder, &producer_thread, + HostEventType::kMapAndBatchProduce, 0, 30, + {{StatType::kElementId, kElementId}}); + CreateXEvent(&host_plane_builder, &producer_thread, + "Iterator::MapAndBatch::Range", 0, 30, + {{StatType::kStepId, kRangeIteratorId}, + {StatType::kParentId, kMapAndBatchIteratorId}}); + CreateXEvent(&host_plane_builder, &producer_thread, + HostEventType::kMapAndBatchProduce, 40, 30, + {{StatType::kElementId, kElementId}}); + CreateXEvent(&host_plane_builder, &producer_thread, + "Iterator::MapAndBatch::Range", 40, 30, + {{StatType::kStepId, kRangeIteratorId}, + {StatType::kParentId, kMapAndBatchIteratorId}}); + + TfDataStats tf_data_stats = ConvertXPlaneToTfDataStats(&host_plane); + EXPECT_THAT(tf_data_stats, EqualsProto(R"pb( + iterator_metadata: { + key: 123, + value: { + id: 123 + name: "MapAndBatch" + long_name: "Iterator::MapAndBatch" + is_async: true + } + } + iterator_metadata: { + key: 456, + value: { + id: 456 + parent_id: 123 + name: "Range" + long_name: "Iterator::MapAndBatch::Range" + is_async: false + } + } + input_pipelines { + key: 123, + value: { + metadata { id: 123 type: HOST name: "Host:0" } + avg_latency_ps: 100 + min_latency_ps: 100 + max_latency_ps: 100 + stats { + bottleneck_iterator_id: 456 + iterator_stats { + key: 123, + value: { + id: 123 + start_time_ps: 0 + duration_ps: 100 + self_time_ps: 40 + is_blocking: true + num_calls: 1 + } + } + iterator_stats { + key: 456, + value: { + id: 456 + start_time_ps: 0 + duration_ps: 60 + self_time_ps: 60 + is_blocking: true + num_calls: 2 + } + } + } + } + } + )pb")); +} + +} // namespace +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc b/tensorflow/core/profiler/convert/xplane_to_tools_data.cc new file mode 100644 index 00000000000..59af75109d0 --- /dev/null +++ b/tensorflow/core/profiler/convert/xplane_to_tools_data.cc @@ -0,0 +1,181 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/profiler/convert/xplane_to_tools_data.h" + +#include +#include + +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h" +#include "tensorflow/core/profiler/convert/op_stats_to_overview_page.h" +#include "tensorflow/core/profiler/convert/op_stats_to_tf_stats.h" +#include "tensorflow/core/profiler/convert/xplane_to_memory_profile.h" +#include "tensorflow/core/profiler/convert/xplane_to_op_stats.h" +#include "tensorflow/core/profiler/convert/xplane_to_trace_events.h" +#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h" +#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/overview_page.pb.h" +#include "tensorflow/core/profiler/protobuf/tf_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +namespace { + +std::pair ConvertXSpaceToTraceEvents( + const std::vector& xspace_paths) { + if (xspace_paths.size() != 1) { + LOG(WARNING) << "Trace events tool expects only 1 XSpace path but gets " + << xspace_paths.size(); + return std::make_pair("", false); + } + + XSpace xspace; + Status status = ReadBinaryProto(Env::Default(), xspace_paths[0], &xspace); + if (!status.ok()) { + LOG(WARNING) << "Could not read XSpace for trace events: " + << xspace_paths[0]; + return std::make_pair("", false); + } + std::string content; + ConvertXSpaceToTraceEventsString(xspace, &content); + return std::make_pair(content, true); +} + +std::pair ConvertMultiXSpacesToOverviewPage( + const std::vector& xspace_paths) { + OpStatsOptions options; + options.generate_kernel_stats_db = true; + options.generate_op_metrics_db = true; + options.generate_step_db = true; + OpStats combined_op_stats; + Status status = ConvertMultiXSpacesToCombinedOpStats(xspace_paths, options, + &combined_op_stats); + if (!status.ok()) { + LOG(WARNING) << "Could not generate OpStats for overview page. Error: " + << status.error_message(); + return std::make_pair("", false); + } + // TODO(profiler): xspace should tell whether this is sampling mode. + return std::make_pair( + ConvertOpStatsToOverviewPage(combined_op_stats).SerializeAsString(), + true); +} + +std::pair ConvertMultiXSpacesToInputPipeline( + const std::vector& xspace_paths) { + OpStatsOptions options; + options.generate_op_metrics_db = true; + options.generate_step_db = true; + OpStats combined_op_stats; + Status status = ConvertMultiXSpacesToCombinedOpStats(xspace_paths, options, + &combined_op_stats); + if (!status.ok()) { + LOG(WARNING) << "Could not generate OpStats for input pipeline. 
Error: "
+                 << status.error_message();
+    return std::make_pair("", false);
+  }
+  return std::make_pair(ConvertOpStatsToInputPipelineAnalysis(combined_op_stats)
+                            .SerializeAsString(),
+                        true);
+}
+
+std::pair<std::string, bool> ConvertMultiXSpacesToTfStats(
+    const std::vector<std::string>& xspace_paths) {
+  OpStatsOptions options;
+  options.generate_op_metrics_db = true;
+  options.generate_kernel_stats_db = true;
+  OpStats combined_op_stats;
+  Status status = ConvertMultiXSpacesToCombinedOpStats(xspace_paths, options,
+                                                       &combined_op_stats);
+  if (!status.ok()) {
+    LOG(WARNING) << "Could not generate OpStats for tensorflow stats. Error: "
+                 << status.error_message();
+    return std::make_pair("", false);
+  }
+  return std::make_pair(
+      ConvertOpStatsToTfStats(combined_op_stats).SerializeAsString(), true);
+}
+
+std::pair<std::string, bool> ConvertMultiXSpacesToKernelStats(
+    const std::vector<std::string>& xspace_paths) {
+  OpStatsOptions options;
+  options.generate_kernel_stats_db = true;
+  OpStats combined_op_stats;
+  Status status = ConvertMultiXSpacesToCombinedOpStats(xspace_paths, options,
+                                                       &combined_op_stats);
+  if (!status.ok()) {
+    LOG(WARNING) << "Could not generate OpStats for kernel stats. Error: "
+                 << status.error_message();
+    return std::make_pair("", false);
+  }
+  return std::make_pair(combined_op_stats.kernel_stats_db().SerializeAsString(),
+                        true);
+}
+
+std::pair<std::string, bool> ConvertXSpaceToMemoryProfile(
+    const std::vector<std::string>& xspace_paths) {
+  if (xspace_paths.size() != 1) {
+    LOG(WARNING) << "Memory profile tool expects only 1 XSpace path but gets "
+                 << xspace_paths.size();
+    return std::make_pair("", false);
+  }
+  XSpace xspace;
+  Status status = ReadBinaryProto(Env::Default(), xspace_paths[0], &xspace);
+  if (!status.ok()) {
+    LOG(WARNING) << "Could not read XSpace for memory profile: "
+                 << xspace_paths[0];
+    return std::make_pair("", false);
+  }
+  std::string json_output;
+  status = ConvertXSpaceToMemoryProfileJson(xspace, &json_output);
+  if (!status.ok()) {
+    LOG(WARNING) << "Could not generate memory profile. Error: "
+                 << status.error_message();
+    return std::make_pair("", false);
+  }
+  return std::make_pair(json_output, true);
+}
+
+} // namespace
+
+std::pair<std::string, bool> ConvertMultiXSpacesToToolData(
+    const std::vector<std::string>& xspace_paths,
+    const absl::string_view tool_name) {
+  if (tool_name == "trace_viewer") {
+    return ConvertXSpaceToTraceEvents(xspace_paths);
+  } else if (tool_name == "overview_page") {
+    return ConvertMultiXSpacesToOverviewPage(xspace_paths);
+  } else if (tool_name == "input_pipeline_analyzer") {
+    return ConvertMultiXSpacesToInputPipeline(xspace_paths);
+  } else if (tool_name == "tensorflow_stats") {
+    return ConvertMultiXSpacesToTfStats(xspace_paths);
+  } else if (tool_name == "kernel_stats") {
+    return ConvertMultiXSpacesToKernelStats(xspace_paths);
+  } else if (tool_name == "memory_profile") {
+    return ConvertXSpaceToMemoryProfile(xspace_paths);
+  } else {
+    LOG(WARNING) << "Cannot find tool: " << tool_name << ". Please update to "
+                 << "the latest version of TensorFlow.";
+    return std::make_pair("", false);
+  }
+}
+
+} // namespace profiler
+} // namespace tensorflow
diff --git a/tensorflow/core/profiler/convert/xplane_to_tools_data.h b/tensorflow/core/profiler/convert/xplane_to_tools_data.h
new file mode 100644
index 00000000000..0f8155b01ba
--- /dev/null
+++ b/tensorflow/core/profiler/convert/xplane_to_tools_data.h
@@ -0,0 +1,38 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOLS_DATA_H_
+#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOLS_DATA_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+
+namespace tensorflow {
+namespace profiler {
+
+// Converts multiple XSpaces into tool-specific data.
+// Returns the serialized string of the tool-specific data and whether the
+// conversion was successful.
+std::pair<std::string, bool> ConvertMultiXSpacesToToolData(
+    const std::vector<std::string>& xspace_paths,
+    const absl::string_view tool_name);
+
+} // namespace profiler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOLS_DATA_H_
diff --git a/tensorflow/core/profiler/internal/BUILD b/tensorflow/core/profiler/internal/BUILD
index bbd73edffe6..336375f7d22 100644
--- a/tensorflow/core/profiler/internal/BUILD
+++ b/tensorflow/core/profiler/internal/BUILD
@@ -1,7 +1,5 @@
-load("//tensorflow:tensorflow.bzl", "filegroup")
 load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
 load("//tensorflow:tensorflow.bzl", "if_not_windows", "tf_cc_test")
-load("//tensorflow/core/platform:build_config_root.bzl", "if_static")
 
 package(
     default_visibility = ["//tensorflow:internal"],
@@ -24,7 +22,7 @@ cc_library(
         "//tensorflow/c:checkpoint_reader",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:regexp_internal",
+        "//tensorflow/core/platform:regexp",
         "//tensorflow/core/profiler:protos_all_cc",
         "//tensorflow/core/profiler:tfprof_options",
         "@com_google_absl//absl/strings",
@@ -55,7 +53,7 @@ cc_library(
         ":tfprof_utils",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:regexp_internal",
+        "//tensorflow/core/platform:regexp",
         "//tensorflow/core/profiler:protos_all_cc",
         "//tensorflow/core/profiler:tfprof_options",
         "@com_google_absl//absl/strings:str_format",
@@ -77,7 +75,7 @@ cc_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:regexp_internal",
+        "//tensorflow/core/platform:regexp",
         "//tensorflow/core/profiler:protos_all_cc",
         "//tensorflow/core/profiler:tfprof_options",
         "@com_google_absl//absl/strings:str_format",
@@ -97,7 +95,7 @@ cc_library(
         "//tensorflow/c:checkpoint_reader",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:regexp_internal",
+        "//tensorflow/core/platform:regexp",
         "//tensorflow/core/profiler:protos_all_cc",
         "//tensorflow/core/profiler:tfprof_options",
         "@com_google_absl//absl/strings",
@@ -121,7 +119,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:regexp_internal",
+        "//tensorflow/core/platform:regexp",
         "//tensorflow/core/profiler:protos_all_cc",
         "//tensorflow/core/profiler:tfprof_options",
         "@com_google_absl//absl/strings",
@@ -142,7 +140,7 @@ cc_library(
         "//tensorflow/c:checkpoint_reader",
         "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", - "//tensorflow/core:regexp_internal", + "//tensorflow/core/platform:regexp", "//tensorflow/core/profiler:protos_all_cc", "//tensorflow/core/profiler:tfprof_options", "@com_google_absl//absl/strings:str_format", @@ -178,7 +176,7 @@ cc_library( "//tensorflow/c:checkpoint_reader", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:regexp_internal", + "//tensorflow/core/platform:regexp", "//tensorflow/core/profiler:protos_all_cc", "//tensorflow/core/profiler:tfprof_options", "@com_google_absl//absl/strings", @@ -202,7 +200,7 @@ cc_library( "//tensorflow/c:checkpoint_reader", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:regexp_internal", + "//tensorflow/core/platform:regexp", "//tensorflow/core/profiler:protos_all_cc", "//tensorflow/core/profiler:tfprof_options", "@com_google_absl//absl/strings", @@ -257,7 +255,7 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:regexp_internal", + "//tensorflow/core/platform:regexp", "//tensorflow/core/profiler:tfprof_options", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -360,166 +358,6 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:regexp_internal", - ], -) - -cc_library( - name = "traceme_recorder", - hdrs = ["traceme_recorder.h"], - visibility = [ - "//perftools/accelerators/xprof/xprofilez:__subpackages__", - "//tensorflow/core/profiler:__subpackages__", - "//third_party/tf_runtime_google:__subpackages__", - ], - deps = [ - "@com_google_absl//absl/container:flat_hash_map", - "//tensorflow/core:lib", - ] + if_static([ - ":traceme_recorder_impl", - ]), -) - -cc_library( - name = "traceme_recorder_impl", - srcs = [ - "traceme_recorder.cc", - "traceme_recorder.h", - ], - visibility = [ - "//tensorflow/core/profiler:__pkg__", - "//tensorflow/python:__pkg__", - ], - deps = [ - "//tensorflow/core:lib", - "@com_google_absl//absl/container:flat_hash_map", - ], - alwayslink = True, -) - -tf_cc_test( - name = "traceme_recorder_test", - srcs = ["traceme_recorder_test.cc"], - deps = [ - ":traceme_recorder", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "profiler_interface", - hdrs = ["profiler_interface.h"], - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/profiler/protobuf:xplane_proto_cc", - ], -) - -cc_library( - name = "profiler_factory", - hdrs = ["profiler_factory.h"], - deps = [ - ":profiler_interface", - "//tensorflow/core/profiler:profiler_options_proto_cc", - ] + if_static([ - ":profiler_factory_impl", - ]), -) - -cc_library( - name = "profiler_factory_impl", - srcs = [ - "profiler_factory.cc", - "profiler_factory.h", - ], - visibility = [ - "//tensorflow/core/profiler:__pkg__", - ], - deps = [ - ":profiler_interface", - "//tensorflow/core:lib", - "//tensorflow/core/profiler:profiler_options_proto_cc", - ], - alwayslink = True, -) - -filegroup( - name = "mobile_srcs", - srcs = [ - "profiler_interface.h", - ], - visibility = ["//visibility:public"], -) - -cc_library( - name = "annotation_stack", - hdrs = ["annotation_stack.h"], - visibility = [ - "//perftools/accelerators/xprof/xprofilez:__subpackages__", - "//tensorflow/core/profiler:__subpackages__", - ], - deps = [ - "@com_google_absl//absl/strings", - 
"//tensorflow/core:lib", - ] + if_static([ - ":annotation_stack_impl", - ]), -) - -cc_library( - name = "annotation_stack_impl", - srcs = [ - "annotation_stack.cc", - "annotation_stack.h", - ], - visibility = [ - "//tensorflow/core/profiler:__pkg__", - "//tensorflow/python:__pkg__", - ], - deps = [ - "//tensorflow/core:lib", - "@com_google_absl//absl/strings", - ], - alwayslink = True, -) - -tf_cc_test( - name = "scoped_annotation_test", - size = "small", - srcs = ["scoped_annotation_test.cc"], - deps = [ - ":annotation_stack", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:lib_internal", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/profiler/lib:scoped_annotation", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "parse_annotation", - srcs = ["parse_annotation.cc"], - hdrs = ["parse_annotation.h"], - visibility = ["//visibility:public"], - deps = [ - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "parse_annotation_test", - srcs = ["parse_annotation_test.cc"], - deps = [ - ":parse_annotation", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_absl//absl/strings", + "//tensorflow/core/platform:regexp", ], ) diff --git a/tensorflow/core/profiler/internal/cpu/BUILD b/tensorflow/core/profiler/internal/cpu/BUILD index 6fb518d413e..552780dab79 100644 --- a/tensorflow/core/profiler/internal/cpu/BUILD +++ b/tensorflow/core/profiler/internal/cpu/BUILD @@ -1,5 +1,7 @@ +load("//tensorflow/core/platform:build_config_root.bzl", "if_static") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_copts") package( default_visibility = ["//tensorflow:internal"], @@ -10,12 +12,13 @@ cc_library( name = "host_tracer_utils", srcs = ["host_tracer_utils.cc"], hdrs = ["host_tracer_utils.h"], + copts = tf_profiler_copts(), visibility = ["//tensorflow/core/profiler:friends"], deps = [ + ":traceme_recorder", "//tensorflow/core:lib", - "//tensorflow/core/profiler/internal:parse_annotation", - "//tensorflow/core/profiler/internal:traceme_recorder", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + "//tensorflow/core/profiler/utils:parse_annotation", "//tensorflow/core/profiler/utils:tf_op_utils", "//tensorflow/core/profiler/utils:xplane_builder", "//tensorflow/core/profiler/utils:xplane_utils", @@ -27,14 +30,15 @@ cc_library( cc_library( name = "host_tracer", srcs = ["host_tracer.cc"], + copts = tf_profiler_copts(), deps = [ ":host_tracer_utils", + ":traceme_recorder", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler:profiler_options_proto_cc", - "//tensorflow/core/profiler/internal:profiler_factory", - "//tensorflow/core/profiler/internal:profiler_interface", - "//tensorflow/core/profiler/internal:traceme_recorder", + "//tensorflow/core/profiler/lib:profiler_factory", + "//tensorflow/core/profiler/lib:profiler_interface", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_utils", @@ -54,7 +58,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/profiler:profiler_options_proto_cc", - "//tensorflow/core/profiler/internal:profiler_interface", + "//tensorflow/core/profiler/lib:profiler_interface", "//tensorflow/core/profiler/lib:profiler_session", "//tensorflow/core/profiler/lib:traceme", 
"//tensorflow/core/profiler/protobuf:xplane_proto_cc", @@ -62,21 +66,91 @@ tf_cc_test( "//tensorflow/core/profiler/utils:xplane_visitor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@com_google_googletest//:gtest", ], ) +cc_library( + name = "traceme_recorder", + hdrs = ["traceme_recorder.h"], + copts = tf_profiler_copts(), + visibility = ["//tensorflow/core/profiler:internal"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "//tensorflow/core:lib", + ] + if_static([ + ":traceme_recorder_impl", + ]), +) + +cc_library( + name = "traceme_recorder_impl", + srcs = [ + "traceme_recorder.cc", + "traceme_recorder.h", + ], + copts = tf_profiler_copts(), + visibility = [ + "//tensorflow/core/profiler:__pkg__", + "//tensorflow/python:__pkg__", + ], + deps = [ + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + ], + alwayslink = True, +) + +tf_cc_test( + name = "traceme_recorder_test", + srcs = ["traceme_recorder_test.cc"], + deps = [ + ":traceme_recorder", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "annotation_stack", + hdrs = ["annotation_stack.h"], + copts = tf_profiler_copts(), + visibility = ["//tensorflow/core/profiler:internal"], + deps = [ + "@com_google_absl//absl/strings", + "//tensorflow/core:lib", + ] + if_static([ + ":annotation_stack_impl", + ]), +) + +cc_library( + name = "annotation_stack_impl", + srcs = [ + "annotation_stack.cc", + "annotation_stack.h", + ], + copts = tf_profiler_copts(), + visibility = ["//tensorflow/core/profiler:__pkg__"], + deps = [ + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], + alwayslink = True, +) + cc_library( name = "python_tracer", srcs = ["python_tracer.cc"], - copts = ["-fexceptions"], + copts = tf_profiler_copts() + ["-fexceptions"], features = ["-use_header_modules"], deps = [ "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler:profiler_options_proto_cc", - "//tensorflow/core/profiler/internal:profiler_factory", - "//tensorflow/core/profiler/internal:profiler_interface", + "//tensorflow/core/profiler/lib:profiler_factory", + "//tensorflow/core/profiler/lib:profiler_interface", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "//tensorflow/python/profiler/internal:python_hooks", ], @@ -86,14 +160,15 @@ cc_library( cc_library( name = "metadata_collector", srcs = ["metadata_collector.cc"], + copts = tf_profiler_copts(), deps = [ "//tensorflow/compiler/xla/service:hlo_proto_cc", "//tensorflow/compiler/xla/service/gpu:gpu_debug_info_manager", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler:profiler_options_proto_cc", - "//tensorflow/core/profiler/internal:profiler_factory", - "//tensorflow/core/profiler/internal:profiler_interface", + "//tensorflow/core/profiler/lib:profiler_factory", + "//tensorflow/core/profiler/lib:profiler_interface", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "//tensorflow/core/profiler/utils:xplane_builder", "//tensorflow/core/profiler/utils:xplane_schema", diff --git a/tensorflow/core/profiler/internal/annotation_stack.cc b/tensorflow/core/profiler/internal/cpu/annotation_stack.cc similarity index 95% rename from tensorflow/core/profiler/internal/annotation_stack.cc rename to tensorflow/core/profiler/internal/cpu/annotation_stack.cc index 4c15ca47c3d..5a4c34e84a0 100644 --- a/tensorflow/core/profiler/internal/annotation_stack.cc 
+++ b/tensorflow/core/profiler/internal/cpu/annotation_stack.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/profiler/internal/annotation_stack.h" +#include "tensorflow/core/profiler/internal/cpu/annotation_stack.h" #include diff --git a/tensorflow/core/profiler/internal/annotation_stack.h b/tensorflow/core/profiler/internal/cpu/annotation_stack.h similarity index 93% rename from tensorflow/core/profiler/internal/annotation_stack.h rename to tensorflow/core/profiler/internal/cpu/annotation_stack.h index e626c4c73cc..b9d865fbf96 100644 --- a/tensorflow/core/profiler/internal/annotation_stack.h +++ b/tensorflow/core/profiler/internal/cpu/annotation_stack.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ANNOTATION_STACK_H_ -#define TENSORFLOW_CORE_PROFILER_INTERNAL_ANNOTATION_STACK_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_CPU_ANNOTATION_STACK_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_CPU_ANNOTATION_STACK_H_ #include @@ -93,4 +93,4 @@ class AnnotationStack { } // namespace profiler } // namespace tensorflow -#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ANNOTATION_STACK_H_ +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_CPU_ANNOTATION_STACK_H_ diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer.cc b/tensorflow/core/profiler/internal/cpu/host_tracer.cc index 61cb75b2ab3..ba882067463 100644 --- a/tensorflow/core/profiler/internal/cpu/host_tracer.cc +++ b/tensorflow/core/profiler/internal/cpu/host_tracer.cc @@ -25,9 +25,9 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/internal/cpu/host_tracer_utils.h" -#include "tensorflow/core/profiler/internal/profiler_factory.h" -#include "tensorflow/core/profiler/internal/profiler_interface.h" -#include "tensorflow/core/profiler/internal/traceme_recorder.h" +#include "tensorflow/core/profiler/internal/cpu/traceme_recorder.h" +#include "tensorflow/core/profiler/lib/profiler_factory.h" +#include "tensorflow/core/profiler/lib/profiler_interface.h" #include "tensorflow/core/profiler/profiler_options.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc index 0e4c3dd7a9b..1b48df0a650 100644 --- a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc +++ b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include -#include #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/core/framework/step_stats.pb.h" @@ -25,7 +24,7 @@ limitations under the License. 
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/internal/profiler_interface.h" +#include "tensorflow/core/profiler/lib/profiler_interface.h" #include "tensorflow/core/profiler/lib/profiler_session.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/profiler/profiler_options.pb.h" diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc b/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc index 5d3e9ba1fc3..4f12776f581 100644 --- a/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc +++ b/tensorflow/core/profiler/internal/cpu/host_tracer_utils.cc @@ -21,9 +21,9 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/internal/parse_annotation.h" -#include "tensorflow/core/profiler/internal/traceme_recorder.h" +#include "tensorflow/core/profiler/internal/cpu/traceme_recorder.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tensorflow/core/profiler/utils/parse_annotation.h" #include "tensorflow/core/profiler/utils/tf_op_utils.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_utils.h b/tensorflow/core/profiler/internal/cpu/host_tracer_utils.h index fa5bf382c88..5770058cc6d 100644 --- a/tensorflow/core/profiler/internal/cpu/host_tracer_utils.h +++ b/tensorflow/core/profiler/internal/cpu/host_tracer_utils.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_CORE_PROFILER_INTERNAL_CPU_HOST_TRACER_UTILS_H_ #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/internal/traceme_recorder.h" +#include "tensorflow/core/profiler/internal/cpu/traceme_recorder.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" namespace tensorflow { diff --git a/tensorflow/core/profiler/internal/cpu/metadata_collector.cc b/tensorflow/core/profiler/internal/cpu/metadata_collector.cc index c9a593f101b..4538ca0a344 100644 --- a/tensorflow/core/profiler/internal/cpu/metadata_collector.cc +++ b/tensorflow/core/profiler/internal/cpu/metadata_collector.cc @@ -18,12 +18,24 @@ limitations under the License. 
#include #include +#if defined(__clang__) && __cplusplus >= 201703L // clang C++17 +#define TF_PROFILER_DISABLE_CXX17_WARNINGS \ + _Pragma("clang diagnostic push") \ + _Pragma("clang diagnostic ignored \"-Wc++98-c++11-c++14-compat\"") +#define TF_PROFILER_ENABLE_CXX17_WARNINGS _Pragma("clang diagnostic pop") +#else +#define TF_PROFILER_DISABLE_CXX17_WARNINGS +#define TF_PROFILER_ENABLE_CXX17_WARNINGS +#endif + +TF_PROFILER_DISABLE_CXX17_WARNINGS #include "tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h" +TF_PROFILER_ENABLE_CXX17_WARNINGS #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/profiler/internal/profiler_factory.h" -#include "tensorflow/core/profiler/internal/profiler_interface.h" +#include "tensorflow/core/profiler/lib/profiler_factory.h" +#include "tensorflow/core/profiler/lib/profiler_interface.h" #include "tensorflow/core/profiler/profiler_options.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" diff --git a/tensorflow/core/profiler/internal/cpu/python_tracer.cc b/tensorflow/core/profiler/internal/cpu/python_tracer.cc index 289122e8a16..00f5cac61f1 100644 --- a/tensorflow/core/profiler/internal/cpu/python_tracer.cc +++ b/tensorflow/core/profiler/internal/cpu/python_tracer.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/profiler/internal/profiler_factory.h" -#include "tensorflow/core/profiler/internal/profiler_interface.h" +#include "tensorflow/core/profiler/lib/profiler_factory.h" +#include "tensorflow/core/profiler/lib/profiler_interface.h" #include "tensorflow/core/profiler/profiler_options.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/protobuf/config.pb.h" diff --git a/tensorflow/core/profiler/internal/traceme_recorder.cc b/tensorflow/core/profiler/internal/cpu/traceme_recorder.cc similarity index 99% rename from tensorflow/core/profiler/internal/traceme_recorder.cc rename to tensorflow/core/profiler/internal/cpu/traceme_recorder.cc index 268585bde8c..57d36127493 100644 --- a/tensorflow/core/profiler/internal/traceme_recorder.cc +++ b/tensorflow/core/profiler/internal/cpu/traceme_recorder.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/profiler/internal/traceme_recorder.h" +#include "tensorflow/core/profiler/internal/cpu/traceme_recorder.h" #include diff --git a/tensorflow/core/profiler/internal/traceme_recorder.h b/tensorflow/core/profiler/internal/cpu/traceme_recorder.h similarity index 95% rename from tensorflow/core/profiler/internal/traceme_recorder.h rename to tensorflow/core/profiler/internal/cpu/traceme_recorder.h index 5fdea5bddbd..69246ac4aa9 100644 --- a/tensorflow/core/profiler/internal/traceme_recorder.h +++ b/tensorflow/core/profiler/internal/cpu/traceme_recorder.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TRACEME_RECORDER_H_ -#define TENSORFLOW_CORE_PROFILER_INTERNAL_TRACEME_RECORDER_H_ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_CPU_TRACEME_RECORDER_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_CPU_TRACEME_RECORDER_H_ #include #include @@ -120,4 +120,5 @@ class TraceMeRecorder { } // namespace profiler } // namespace tensorflow -#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TRACEME_RECORDER_H_ + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_CPU_TRACEME_RECORDER_H_ diff --git a/tensorflow/core/profiler/internal/traceme_recorder_test.cc b/tensorflow/core/profiler/internal/cpu/traceme_recorder_test.cc similarity index 98% rename from tensorflow/core/profiler/internal/traceme_recorder_test.cc rename to tensorflow/core/profiler/internal/cpu/traceme_recorder_test.cc index 8d7abc94e8f..b6ab78ed546 100644 --- a/tensorflow/core/profiler/internal/traceme_recorder_test.cc +++ b/tensorflow/core/profiler/internal/cpu/traceme_recorder_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/profiler/internal/traceme_recorder.h" +#include "tensorflow/core/profiler/internal/cpu/traceme_recorder.h" #include #include @@ -21,7 +21,6 @@ limitations under the License. #include #include -#include #include "absl/strings/str_cat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env_time.h" diff --git a/tensorflow/core/profiler/internal/gpu/BUILD b/tensorflow/core/profiler/internal/gpu/BUILD index dd8ff1e9a27..aa88db2b224 100644 --- a/tensorflow/core/profiler/internal/gpu/BUILD +++ b/tensorflow/core/profiler/internal/gpu/BUILD @@ -19,6 +19,7 @@ load( "//tensorflow/stream_executor:build_defs.bzl", "tf_additional_cupti_deps", ) +load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_copts") package( default_visibility = ["//tensorflow:internal"], @@ -28,7 +29,7 @@ package( tf_cuda_library( name = "device_tracer", srcs = tf_additional_device_tracer_srcs(), - copts = tf_copts(), + copts = tf_profiler_copts() + tf_copts(), cuda_deps = [ "//tensorflow/core/profiler/internal/gpu:cupti_tracer", "//tensorflow/core/profiler/internal/gpu:cupti_wrapper", @@ -38,11 +39,11 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/profiler/internal:annotation_stack", - "//tensorflow/core/profiler/internal:parse_annotation", - "//tensorflow/core/profiler/internal:profiler_factory", - "//tensorflow/core/profiler/internal:profiler_interface", + "//tensorflow/core/profiler/internal/cpu:annotation_stack", + "//tensorflow/core/profiler/lib:profiler_factory", + "//tensorflow/core/profiler/lib:profiler_interface", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + "//tensorflow/core/profiler/utils:parse_annotation", "//tensorflow/core/profiler/utils:xplane_builder", "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_utils", @@ -50,7 +51,6 @@ tf_cuda_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", ], alwayslink = 1, @@ -73,7 
+73,6 @@ tf_cc_test_gpu( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:direct_session", - "//tensorflow/core:direct_session_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:gpu_runtime", @@ -83,8 +82,9 @@ tf_cc_test_gpu( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "//tensorflow/core/common_runtime:direct_session_internal", "//tensorflow/core/kernels:ops_util", - "//tensorflow/core/profiler/internal:profiler_interface", + "//tensorflow/core/profiler/lib:profiler_interface", "//tensorflow/core/profiler/lib:profiler_session", "//tensorflow/core/profiler/utils:tf_xplane_visitor", "//tensorflow/core/profiler/utils:xplane_schema", @@ -95,7 +95,7 @@ tf_cc_test_gpu( tf_cuda_library( name = "cupti_interface", hdrs = if_cuda_is_configured_compat(["cupti_interface.h"]), - copts = tf_copts(), + copts = tf_profiler_copts() + tf_copts(), visibility = ["//visibility:public"], deps = [ "//tensorflow/core:lib", @@ -113,7 +113,7 @@ tf_cuda_library( name = "cupti_wrapper", srcs = if_cuda_is_configured_compat(["cupti_wrapper.cc"]), hdrs = if_cuda_is_configured_compat(["cupti_wrapper.h"]), - copts = tf_copts(), + copts = tf_profiler_copts() + tf_copts(), linkstatic = 1, visibility = ["//visibility:public"], deps = [ @@ -125,13 +125,13 @@ tf_cuda_library( name = "cupti_tracer", srcs = if_cuda_is_configured_compat(["cupti_tracer.cc"]), hdrs = if_cuda_is_configured_compat(["cupti_tracer.h"]), - copts = tf_copts(), + copts = tf_profiler_copts() + tf_copts(), visibility = ["//visibility:public"], deps = [ ":cupti_interface", ":cupti_utils", "//tensorflow/core:lib", - "//tensorflow/core/profiler/internal:annotation_stack", + "//tensorflow/core/profiler/internal/cpu:annotation_stack", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:node_hash_map", @@ -143,7 +143,7 @@ tf_cuda_library( tf_cuda_library( name = "cupti_utils", srcs = if_cuda_is_configured_compat(["cupti_utils.cc"]), - copts = tf_copts(), + copts = tf_profiler_copts() + tf_copts(), cuda_deps = [ ":cupti_interface", ":cupti_wrapper", diff --git a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc index bda7d5840ab..aedb1722fad 100644 --- a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc +++ b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mem.h" -#include "tensorflow/core/profiler/internal/annotation_stack.h" +#include "tensorflow/core/profiler/internal/cpu/annotation_stack.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/internal/gpu/device_tracer.cc b/tensorflow/core/profiler/internal/gpu/device_tracer.cc index 978f3a90cd3..ed675e712a1 100644 --- a/tensorflow/core/profiler/internal/gpu/device_tracer.cc +++ b/tensorflow/core/profiler/internal/gpu/device_tracer.cc @@ -24,7 +24,6 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/synchronization/mutex.h" #include "tensorflow/core/framework/step_stats.pb.h" @@ -35,13 +34,13 @@ limitations under the License. 
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/profiler/internal/annotation_stack.h" +#include "tensorflow/core/profiler/internal/cpu/annotation_stack.h" #include "tensorflow/core/profiler/internal/gpu/cupti_tracer.h" #include "tensorflow/core/profiler/internal/gpu/cupti_wrapper.h" -#include "tensorflow/core/profiler/internal/parse_annotation.h" -#include "tensorflow/core/profiler/internal/profiler_factory.h" -#include "tensorflow/core/profiler/internal/profiler_interface.h" +#include "tensorflow/core/profiler/lib/profiler_factory.h" +#include "tensorflow/core/profiler/lib/profiler_interface.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tensorflow/core/profiler/utils/parse_annotation.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" @@ -96,13 +95,12 @@ void CreateXEvent(const CuptiTracerEvent& event, XPlaneBuilder* plane, absl::StrCat("$$", static_cast(event.context_id))); } if (event.type == CuptiTracerEventType::Kernel) { - const std::string kernel_details = - absl::StrFormat("regs:%u shm:%u grid:%u,%u,%u block:%u,%u,%u", - event.kernel_info.registers_per_thread, - event.kernel_info.static_shared_memory_usage, - event.kernel_info.grid_x, event.kernel_info.grid_y, - event.kernel_info.grid_z, event.kernel_info.block_x, - event.kernel_info.block_y, event.kernel_info.block_z); + std::string kernel_details = absl::StrCat( + "regs:", event.kernel_info.registers_per_thread, + " shm:", event.kernel_info.static_shared_memory_usage, + " grid:", event.kernel_info.grid_x, ",", event.kernel_info.grid_y, ",", + event.kernel_info.grid_z, " block:", event.kernel_info.block_x, ",", + event.kernel_info.block_y, ",", event.kernel_info.block_z); xevent.AddStatValue(*plane->GetOrCreateStatMetadata( GetStatTypeStr(StatType::kKernelDetails)), *plane->GetOrCreateStatMetadata(kernel_details)); @@ -112,15 +110,15 @@ void CreateXEvent(const CuptiTracerEvent& event, XPlaneBuilder* plane, event.type == CuptiTracerEventType::MemcpyP2P || event.type == CuptiTracerEventType::MemcpyOther) { const auto& memcpy_info = event.memcpy_info; - std::string memcpy_details = - absl::StrFormat("size:%u dest:%u async:%u", memcpy_info.num_bytes, - memcpy_info.destination, memcpy_info.async); + std::string memcpy_details = absl::StrCat("size:", memcpy_info.num_bytes, + " dest:", memcpy_info.destination, + " async:", memcpy_info.async); xevent.AddStatValue(*plane->GetOrCreateStatMetadata( GetStatTypeStr(StatType::kMemcpyDetails)), memcpy_details); } else if (event.type == CuptiTracerEventType::MemoryAlloc) { std::string memalloc_details = - absl::StrFormat("num_bytes:%u", event.memalloc_info.num_bytes); + absl::StrCat("num_bytes:", event.memalloc_info.num_bytes); xevent.AddStatValue(*plane->GetOrCreateStatMetadata( GetStatTypeStr(StatType::kMemallocDetails)), memalloc_details); @@ -353,14 +351,14 @@ class CuptiTraceCollectorImpl : public CuptiTraceCollector { ns->set_node_name(activity_name); switch (event.type) { case CuptiTracerEventType::Kernel: { - ns->set_timeline_label(absl::StrFormat( - "%s regs:%u shm:%u grid:%u,%u,%u block:%u,%u,%u@@%s", - kernel_name, event.kernel_info.registers_per_thread, - event.kernel_info.static_shared_memory_usage, - event.kernel_info.grid_x, event.kernel_info.grid_y, - event.kernel_info.grid_z, event.kernel_info.block_x, - 
event.kernel_info.block_y, event.kernel_info.block_z, - event.annotation)); + ns->set_timeline_label(absl::StrCat( + kernel_name, " regs:", event.kernel_info.registers_per_thread, + " shm:", event.kernel_info.static_shared_memory_usage, + " grid: ", event.kernel_info.grid_x, ",", + event.kernel_info.grid_y, ",", event.kernel_info.grid_z, + " block:", event.kernel_info.block_x, ",", + event.kernel_info.block_y, ",", event.kernel_info.block_z, + "@@", event.annotation)); DeviceStepStats*& stream_dev_stats = stream_dev_stats_map[std::make_pair(event.stream_id, event.type)]; diff --git a/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc b/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc index 973167ff51b..55ccdbed977 100644 --- a/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc +++ b/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc @@ -34,7 +34,7 @@ limitations under the License. #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/internal/profiler_interface.h" +#include "tensorflow/core/profiler/lib/profiler_interface.h" #include "tensorflow/core/profiler/lib/profiler_session.h" #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD index fd209259c3b..10e5df84345 100644 --- a/tensorflow/core/profiler/lib/BUILD +++ b/tensorflow/core/profiler/lib/BUILD @@ -1,12 +1,12 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow/core/platform:build_config_root.bzl", "if_static") load("//tensorflow:tensorflow.bzl", "if_not_android", "tf_cc_test", "tf_cuda_library") - -# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "filegroup") - -# buildifier: disable=same-origin-load -load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper") +load( + "//tensorflow/core/profiler/builds:build_config.bzl", + "tf_profiler_copts", + "tf_profiler_pybind_cc_library_wrapper", +) package( default_visibility = [ @@ -16,13 +16,10 @@ package( licenses = ["notice"], # Apache 2.0 ) -tf_pybind_cc_library_wrapper( - name = "profiler_session_headers", - visibility = [ - "//tensorflow/core/profiler/rpc:__pkg__", - "//tensorflow/python/profiler/internal:__pkg__", - ], - deps = [":profiler_session"], +tf_profiler_pybind_cc_library_wrapper( + name = "profiler_session_for_pybind", + actual = ":profiler_session", + visibility = ["//tensorflow/python/profiler/internal:__pkg__"], ) cc_library( @@ -30,7 +27,7 @@ cc_library( hdrs = ["profiler_session.h"], visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/core/profiler/internal:profiler_interface", + ":profiler_interface", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "//tensorflow/core/profiler:profiler_options_proto_cc", "//tensorflow/core:lib", @@ -46,6 +43,7 @@ cc_library( "profiler_session.cc", "profiler_session.h", ], + copts = tf_profiler_copts(), visibility = [ "//tensorflow/core/profiler:__pkg__", "//tensorflow/python:__pkg__", @@ -53,7 +51,7 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core/platform", - "//tensorflow/core/profiler/internal:profiler_interface", + ":profiler_interface", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "//tensorflow/core/profiler:profiler_options_proto_cc", "@com_google_absl//absl/memory", @@ -61,7 +59,7 @@ cc_library( ] + 
if_not_android([ ":profiler_lock", "//tensorflow/core/profiler/convert:post_process_single_host_xplane", - "//tensorflow/core/profiler/internal:profiler_factory", + "//tensorflow/core/profiler/lib:profiler_factory", "//tensorflow/core/profiler/utils:derived_timeline", "//tensorflow/core/profiler/utils:group_events", "//tensorflow/core/profiler/utils:xplane_utils", @@ -70,57 +68,46 @@ cc_library( alwayslink = True, ) -tf_pybind_cc_library_wrapper( - name = "local_profiler_headers", - visibility = [ - "//tensorflow/core/profiler/rpc:__pkg__", - "//tensorflow/python/profiler/internal:__pkg__", - ], - deps = [":local_profiler"], -) - cc_library( - name = "local_profiler", - hdrs = ["local_profiler.h"], - visibility = ["//tensorflow/core/profiler:internal"], + name = "profiler_factory", + hdrs = ["profiler_factory.h"], deps = [ - "@com_google_absl//absl/memory", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform", + ":profiler_interface", "//tensorflow/core/profiler:profiler_options_proto_cc", - "//tensorflow/core/profiler/protobuf:xplane_proto_cc", - "//tensorflow/core/profiler/internal:profiler_interface", ] + if_static([ - ":local_profiler_impl", + ":profiler_factory_impl", ]), ) cc_library( - name = "local_profiler_impl", - srcs = ["local_profiler.cc"], - hdrs = ["local_profiler.h"], - visibility = ["//tensorflow/core/profiler:internal"], + name = "profiler_factory_impl", + srcs = [ + "profiler_factory.cc", + "profiler_factory.h", + ], + copts = tf_profiler_copts(), + visibility = [ + "//tensorflow/core/profiler:__pkg__", + ], deps = [ - ":profiler_lock", + ":profiler_interface", "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform", "//tensorflow/core/profiler:profiler_options_proto_cc", - "//tensorflow/core/profiler/convert:post_process_single_host_xplane", - "//tensorflow/core/profiler/internal:profiler_factory", - "//tensorflow/core/profiler/internal:profiler_interface", - "//tensorflow/core/profiler/protobuf:xplane_proto_cc", - "//tensorflow/core/profiler/utils:derived_timeline", - "//tensorflow/core/profiler/utils:group_events", - "//tensorflow/core/profiler/utils:xplane_schema", - "//tensorflow/core/profiler/utils:xplane_utils", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/time", ], alwayslink = True, ) +cc_library( + name = "profiler_interface", + hdrs = ["profiler_interface.h"], + copts = tf_profiler_copts(), + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + ], +) + tf_cuda_library( name = "profiler_backends", cuda_deps = [ @@ -140,10 +127,10 @@ tf_cuda_library( alwayslink = True, ) -tf_pybind_cc_library_wrapper( - name = "traceme_headers", +tf_profiler_pybind_cc_library_wrapper( + name = "traceme_for_pybind", + actual = ":traceme", visibility = ["//tensorflow/python/profiler/internal:__pkg__"], - deps = [":traceme"], ) cc_library( @@ -156,7 +143,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core/platform", ] + if_not_android([ - "//tensorflow/core/profiler/internal:traceme_recorder", + "//tensorflow/core/profiler/internal/cpu:traceme_recorder", ]), ) @@ -217,14 +204,30 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core/platform", ] + if_not_android([ - "//tensorflow/core/profiler/internal:annotation_stack", + "//tensorflow/core/profiler/internal/cpu:annotation_stack", ]), ) +tf_cc_test( + name = "scoped_annotation_test", + size = "small", + srcs = ["scoped_annotation_test.cc"], + deps 
= [ + ":scoped_annotation", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/profiler/internal/cpu:annotation_stack", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "profiler_lock", srcs = ["profiler_lock.cc"], hdrs = ["profiler_lock.h"], + copts = tf_profiler_copts(), visibility = ["//tensorflow/core/profiler:internal"], ) @@ -233,6 +236,7 @@ filegroup( srcs = [ "annotated_traceme.h", "connected_traceme.h", + "profiler_interface.h", "profiler_session.cc", "profiler_session.h", "scoped_annotation.h", diff --git a/tensorflow/core/profiler/lib/local_profiler.cc b/tensorflow/core/profiler/lib/local_profiler.cc deleted file mode 100644 index b3a8cedefb0..00000000000 --- a/tensorflow/core/profiler/lib/local_profiler.cc +++ /dev/null @@ -1,185 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/profiler/lib/local_profiler.h" - -#include - -#include "absl/memory/memory.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "tensorflow/core/platform/env_time.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/platform.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/convert/post_process_single_host_xplane.h" -#include "tensorflow/core/profiler/internal/profiler_factory.h" -#include "tensorflow/core/profiler/internal/profiler_interface.h" -#include "tensorflow/core/profiler/lib/profiler_lock.h" -#include "tensorflow/core/profiler/profiler_options.pb.h" -#include "tensorflow/core/profiler/protobuf/xplane.pb.h" -#include "tensorflow/core/profiler/utils/derived_timeline.h" -#include "tensorflow/core/profiler/utils/group_events.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tensorflow/core/protobuf/config.pb.h" -#include "tensorflow/core/protobuf/error_codes.pb.h" - -namespace tensorflow { -namespace profiler { - -/*static*/ std::unique_ptr LocalProfiler::Create( - const ProfileOptions& options, Status* out_status) { - auto profiler = absl::WrapUnique(new LocalProfiler(options)); - Status status = profiler->Init(); - if (out_status) { - *out_status = status; - } - if (!status.ok()) { - LOG(ERROR) << status; - return nullptr; - } - return profiler; -} - -LocalProfiler::LocalProfiler(ProfileOptions options) - : options_(std::move(options)) {} - -LocalProfiler::~LocalProfiler() { - mutex_lock lock(mutex_); - - for (auto& profiler : profilers_) { - profiler->Stop().IgnoreError(); - } - - if (active_) { - // Allow another LocalProfiler to be instantiated. 
- ReleaseProfilerLock(); - active_ = false; - } -} - -Status LocalProfiler::Init() { - mutex_lock lock(mutex_); - VLOG(1) << "Creating a LocalProfiler."; - - bool active_ = AcquireProfilerLock(); - if (!active_) { - return errors::Unavailable("Another LocalProfiler is active."); - } - - CreateProfilers(options_, &profilers_); - - VLOG(1) << "LocalProfiler initialized with " << profilers_.size() - << " profilers."; - return Status::OK(); -} - -Status LocalProfiler::Start() { - mutex_lock lock(mutex_); - VLOG(1) << "Starting all profilers."; - - if (!active_) { - return errors::FailedPrecondition("LocalProfiler is inactive."); - } - - if (start_time_ns_ != 0) { - return errors::FailedPrecondition("LocalProfiler is not restartable."); - } - - start_time_ns_ = EnvTime::NowNanos(); - - Status status; - for (auto& profiler : profilers_) { - Status start_status = profiler->Start(); - if (!start_status.ok()) { - LOG(WARNING) << "Encountered error while starting profiler: " - << start_status.ToString(); - } - status.Update(start_status); - } - - VLOG(1) << "Started all profilers."; - return status; -} - -Status LocalProfiler::Stop() { - mutex_lock lock(mutex_); - VLOG(1) << "Stopping all profilers."; - - if (!active_) { - return errors::FailedPrecondition("LocalProfiler is inactive."); - } - - if (start_time_ns_ == 0) { - return errors::FailedPrecondition( - "LocalProfiler needs to Start() before it can stop producing data."); - } - - Status status; - for (auto& profiler : profilers_) { - status.Update(profiler->Stop()); - } - - // Allow another LocalProfiler to be instantiated. - if (active_) { - ReleaseProfilerLock(); - active_ = false; - } - - VLOG(1) << "Stopped all profilers."; - return status; -} - -Status LocalProfiler::CollectData(XSpace* space) { - Status status; - uint64 data_start_time_ns; - - { - mutex_lock lock(mutex_); - VLOG(1) << "Collecting data from " << profilers_.size() << " profilers."; - - if (!active_) { - return errors::FailedPrecondition("LocalProfiler is inactive."); - } - - if (start_time_ns_ != 0) { - return errors::FailedPrecondition( - "LocalProfiler needs to Stop() before collecting data."); - } - - for (auto& profiler : profilers_) { - VLOG(3) << "Collecting data from " << typeid(*profiler).name(); - status.Update(profiler->CollectData(space)); - } - - profilers_.clear(); - - data_start_time_ns = start_time_ns_; - } - - PostProcessSingleHostXSpace(space, data_start_time_ns); - return status; -} - -Status LocalProfiler::CollectData(RunMetadata* run_metadata) { - return errors::Unimplemented( - "Collecting profiler data into RunMetaData is unsupported."); -} - -} // namespace profiler -} // namespace tensorflow diff --git a/tensorflow/core/profiler/lib/local_profiler.h b/tensorflow/core/profiler/lib/local_profiler.h deleted file mode 100644 index 1de71d13676..00000000000 --- a/tensorflow/core/profiler/lib/local_profiler.h +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CORE_PROFILER_LIB_LOCAL_PROFILER_H_ -#define TENSORFLOW_CORE_PROFILER_LIB_LOCAL_PROFILER_H_ - -#include -#include - -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/internal/profiler_interface.h" -#include "tensorflow/core/profiler/profiler_options.pb.h" -#include "tensorflow/core/profiler/protobuf/xplane.pb.h" -#include "tensorflow/core/protobuf/config.pb.h" - -namespace tensorflow { -namespace profiler { - -// LocalProfiler encapsulates multiple profiler backends that each implements. -// ProfilerInterface. -// Thread-safety: LocalProfiler is thread-safe. -class LocalProfiler : public ProfilerInterface { - public: - // Instantiates a LocalProfiler if there is not one already active. - // Returns null on errors, which will be indicated by the Status code. - static std::unique_ptr Create(const ProfileOptions& options, - Status* status); - - static ProfileOptions DefaultOptions() { - ProfileOptions options; - options.set_version(1); - options.set_device_tracer_level(1); - options.set_host_tracer_level(2); - options.set_device_type(ProfileOptions::UNSPECIFIED); - options.set_python_tracer_level(0); - options.set_enable_hlo_proto(false); - options.set_include_dataset_ops(true); - return options; - } - - // Starts all profilers. - Status Start() override TF_LOCKS_EXCLUDED(mutex_); - - // Stops all profilers. - Status Stop() override TF_LOCKS_EXCLUDED(mutex_); - - // Collects data from all profilers into XSpace. Post-process the XSpace - // (e.g., groups trace events per step). This is best effort profiling and - // XSpace may contain data collected before any errors occurred. - Status CollectData(XSpace* space) override TF_LOCKS_EXCLUDED(mutex_); - - // Unimplemented, do not use. This will be deprecated in future. - Status CollectData(RunMetadata* run_metadata) override; - - // Deletes an existing Profiler and enables starting a new one. - ~LocalProfiler() override; - - private: - // Constructs an instance of the class and starts profiling - explicit LocalProfiler(ProfileOptions options); - - // Neither copyable or movable. - LocalProfiler(const LocalProfiler&) = delete; - LocalProfiler& operator=(const LocalProfiler&) = delete; - - // Initializes LocalProfiler and sets ups all profilers. - Status Init(); - - mutex mutex_; - - std::vector> profilers_ - TF_GUARDED_BY(mutex_); - - // True if the LocalProfiler is active. - bool active_ TF_GUARDED_BY(mutex_) = false; - - // Time when Start() was called. - uint64 start_time_ns_ TF_GUARDED_BY(mutex_) = 0; - - ProfileOptions options_; -}; - -} // namespace profiler -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_PROFILER_LIB_LOCAL_PROFILER_H_ diff --git a/tensorflow/core/profiler/internal/profiler_factory.cc b/tensorflow/core/profiler/lib/profiler_factory.cc similarity index 92% rename from tensorflow/core/profiler/internal/profiler_factory.cc rename to tensorflow/core/profiler/lib/profiler_factory.cc index 5152e79bdc8..e5bb3836365 100644 --- a/tensorflow/core/profiler/internal/profiler_factory.cc +++ b/tensorflow/core/profiler/lib/profiler_factory.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/profiler/internal/profiler_factory.h" +#include "tensorflow/core/profiler/lib/profiler_factory.h" #include #include @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/profiler/internal/profiler_interface.h" +#include "tensorflow/core/profiler/lib/profiler_interface.h" #include "tensorflow/core/profiler/profiler_options.pb.h" namespace tensorflow { diff --git a/tensorflow/core/profiler/internal/profiler_factory.h b/tensorflow/core/profiler/lib/profiler_factory.h similarity index 81% rename from tensorflow/core/profiler/internal/profiler_factory.h rename to tensorflow/core/profiler/lib/profiler_factory.h index c223d7275d9..b8b8d67e1bf 100644 --- a/tensorflow/core/profiler/internal/profiler_factory.h +++ b/tensorflow/core/profiler/lib/profiler_factory.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_PROFILER_FACTORY_H_ -#define TENSORFLOW_CORE_PROFILER_INTERNAL_PROFILER_FACTORY_H_ +#ifndef TENSORFLOW_CORE_PROFILER_LIB_PROFILER_FACTORY_H_ +#define TENSORFLOW_CORE_PROFILER_LIB_PROFILER_FACTORY_H_ #include #include -#include "tensorflow/core/profiler/internal/profiler_interface.h" +#include "tensorflow/core/profiler/lib/profiler_interface.h" #include "tensorflow/core/profiler/profiler_options.pb.h" namespace tensorflow { @@ -35,4 +35,4 @@ void CreateProfilers(const ProfileOptions& options, } // namespace profiler } // namespace tensorflow -#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_PROFILER_FACTORY_H_ +#endif // TENSORFLOW_CORE_PROFILER_LIB_PROFILER_FACTORY_H_ diff --git a/tensorflow/core/profiler/internal/profiler_interface.h b/tensorflow/core/profiler/lib/profiler_interface.h similarity index 90% rename from tensorflow/core/profiler/internal/profiler_interface.h rename to tensorflow/core/profiler/lib/profiler_interface.h index 9fe85e38652..31a9b262ca9 100644 --- a/tensorflow/core/profiler/internal/profiler_interface.h +++ b/tensorflow/core/profiler/lib/profiler_interface.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_PROFILER_INTERFACE_H_ -#define TENSORFLOW_CORE_PROFILER_INTERNAL_PROFILER_INTERFACE_H_ +#ifndef TENSORFLOW_CORE_PROFILER_LIB_PROFILER_INTERFACE_H_ +#define TENSORFLOW_CORE_PROFILER_LIB_PROFILER_INTERFACE_H_ #include "tensorflow/core/platform/status.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" @@ -54,4 +54,4 @@ class ProfilerInterface { } // namespace profiler } // namespace tensorflow -#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_PROFILER_INTERFACE_H_ +#endif // TENSORFLOW_CORE_PROFILER_LIB_PROFILER_INTERFACE_H_ diff --git a/tensorflow/core/profiler/lib/profiler_session.cc b/tensorflow/core/profiler/lib/profiler_session.cc index d8958180e2a..f37cb12ebab 100644 --- a/tensorflow/core/profiler/lib/profiler_session.cc +++ b/tensorflow/core/profiler/lib/profiler_session.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/internal/profiler_interface.h" +#include "tensorflow/core/profiler/lib/profiler_interface.h" #include "tensorflow/core/profiler/profiler_options.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/protobuf/config.pb.h" @@ -32,7 +32,7 @@ limitations under the License. #if !defined(IS_MOBILE_PLATFORM) #include "tensorflow/core/profiler/convert/post_process_single_host_xplane.h" -#include "tensorflow/core/profiler/internal/profiler_factory.h" +#include "tensorflow/core/profiler/lib/profiler_factory.h" #include "tensorflow/core/profiler/lib/profiler_lock.h" #include "tensorflow/core/profiler/utils/derived_timeline.h" #include "tensorflow/core/profiler/utils/group_events.h" @@ -147,6 +147,7 @@ ProfilerSession::ProfilerSession(ProfileOptions options) } ProfilerSession::~ProfilerSession() { + VLOG(1) << "Profiler session stopping."; for (auto& profiler : profilers_) { profiler->Stop().IgnoreError(); } diff --git a/tensorflow/core/profiler/lib/profiler_session.h b/tensorflow/core/profiler/lib/profiler_session.h index 93541f501ce..c179f1710c9 100644 --- a/tensorflow/core/profiler/lib/profiler_session.h +++ b/tensorflow/core/profiler/lib/profiler_session.h @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/internal/profiler_interface.h" +#include "tensorflow/core/profiler/lib/profiler_interface.h" #include "tensorflow/core/profiler/profiler_options.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/protobuf/config.pb.h" diff --git a/tensorflow/core/profiler/lib/scoped_annotation.h b/tensorflow/core/profiler/lib/scoped_annotation.h index 2cad5fd4708..57c1731608a 100644 --- a/tensorflow/core/profiler/lib/scoped_annotation.h +++ b/tensorflow/core/profiler/lib/scoped_annotation.h @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/types.h" #if !defined(IS_MOBILE_PLATFORM) -#include "tensorflow/core/profiler/internal/annotation_stack.h" +#include "tensorflow/core/profiler/internal/cpu/annotation_stack.h" #endif namespace tensorflow { diff --git a/tensorflow/core/profiler/internal/scoped_annotation_test.cc b/tensorflow/core/profiler/lib/scoped_annotation_test.cc similarity index 98% rename from tensorflow/core/profiler/internal/scoped_annotation_test.cc rename to tensorflow/core/profiler/lib/scoped_annotation_test.cc index 50c1244b9ee..0e948cdbb02 100644 --- a/tensorflow/core/profiler/internal/scoped_annotation_test.cc +++ b/tensorflow/core/profiler/lib/scoped_annotation_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" -#include "tensorflow/core/profiler/internal/annotation_stack.h" +#include "tensorflow/core/profiler/internal/cpu/annotation_stack.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/lib/traceme.h b/tensorflow/core/profiler/lib/traceme.h index 526f6d5104d..ca3ea735765 100644 --- a/tensorflow/core/profiler/lib/traceme.h +++ b/tensorflow/core/profiler/lib/traceme.h @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/types.h" #if !defined(IS_MOBILE_PLATFORM) -#include "tensorflow/core/profiler/internal/traceme_recorder.h" +#include "tensorflow/core/profiler/internal/cpu/traceme_recorder.h" #endif #include "tensorflow/core/profiler/lib/traceme_encode.h" // IWYU pragma: export @@ -123,8 +123,11 @@ class TraceMe { // name_generator is templated, rather than a std::function to avoid // allocations std::function might make even if never called. // Example Usage: + // TraceMe trace_me([&]() { + // return StrCat("my_trace", id); + // } // TraceMe op_trace_me([&]() { - // return StrCat(op_name, ":", op_type); + // return TraceMeOp(op_name, op_type); // } // TraceMe trace_me_with_metadata([&value1]() { // return TraceMeEncode("my_trace", {{"key1", value1}, {"key2", 42}}); diff --git a/tensorflow/core/profiler/profiler_options.proto b/tensorflow/core/profiler/profiler_options.proto index 8b4fc3de6fc..7858f08c8ec 100644 --- a/tensorflow/core/profiler/profiler_options.proto +++ b/tensorflow/core/profiler/profiler_options.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package tensorflow; +// Next ID: 11 message ProfileOptions { // Some default value of option are not proto3 default value. Use this version // to determine if we should use default option value instead of proto3 @@ -50,5 +51,32 @@ message ProfileOptions { // Whether serialize hlo_proto when XLA is used. (version >= 1) bool enable_hlo_proto = 7; - // next-field: 8 + // The local profiler starts profiling at this Unix timestamp in nanoseconds. + uint64 start_timestamp_ns = 8; + + // The local profiler collects `duration_ms` milliseconds of data. If the + // value is 0, profiling continues until interrupted. + uint64 duration_ms = 9; + + // Directory to save profile data to. No-op when empty. + string repository_path = 10; +} + +// Options for remote profiler session manager. +// Next ID: 5 +message RemoteProfilerSessionManagerOptions { + // Options for each local profiler. + ProfileOptions profiler_options = 1; + + // List of servers to profile. Supported formats: host:port. 
+ repeated string service_addresses = 2; + + // Unix timestamp of when the session was started. + uint64 session_creation_timestamp_ns = 3; + + // Maximum time (in milliseconds) a profiling session manager waits for all + // profilers to finish after issuing gRPC request. If value is 0, session + // continues until interrupted. Otherwise, value must be greater than + // profiler_options.duration_ms. + uint64 max_session_duration_ms = 4; } diff --git a/tensorflow/core/profiler/profiler_service.proto b/tensorflow/core/profiler/profiler_service.proto index 69ccf04e304..f32b10f01bd 100644 --- a/tensorflow/core/profiler/profiler_service.proto +++ b/tensorflow/core/profiler/profiler_service.proto @@ -29,6 +29,7 @@ message ToolRequestOptions { bool save_to_repo = 3; } +// Next-ID: 9 message ProfileRequest { // In future, the caller will be able to customize when profiling starts and // stops. For now, it collects `duration_ms` milliseconds worth of data. @@ -61,7 +62,6 @@ message ProfileRequest { // In future, the caller will indicate which TF session is being profiled, and // only data relating to that program will be returned. For now, we assume // all activity during the profiling period is relevant. - // next-field: 9 } message ProfileToolData { @@ -73,6 +73,7 @@ message ProfileToolData { bytes data = 2; } +// Next-ID: 8 message ProfileResponse { // Data payload for each required tools. repeated ProfileToolData tool_data = 6; @@ -82,7 +83,6 @@ message ProfileResponse { bool empty_trace = 7; reserved 1, 2, 3, 4, 5; - // next-field: 8 } message TerminateRequest { @@ -92,6 +92,7 @@ message TerminateRequest { message TerminateResponse {} +// Next-ID: 4 message MonitorRequest { // Duration for which to profile between each update. uint64 duration_ms = 1; @@ -106,10 +107,9 @@ message MonitorRequest { int32 monitoring_level = 2; // True to display timestamp in monitoring result. bool timestamp = 3; - - // next-field: 4 } +// Next-ID: 11 message MonitorResponse { // Properly formatted string data that can be directly returned back to user. 
string data = 1; @@ -118,6 +118,4 @@ message MonitorResponse { ProfilerServiceMonitorResult monitor_result = 10; reserved 2, 3, 4, 5, 6, 7, 8, 9; - - // next-field: 11 } diff --git a/tensorflow/core/profiler/protobuf/BUILD b/tensorflow/core/profiler/protobuf/BUILD index 81bc222e119..02c8f2b6ad8 100644 --- a/tensorflow/core/profiler/protobuf/BUILD +++ b/tensorflow/core/profiler/protobuf/BUILD @@ -65,6 +65,27 @@ tf_proto_library( visibility = [":friends"], ) +tf_proto_library( + name = "pod_stats_proto", + srcs = ["pod_stats.proto"], + cc_api_version = 2, + protodeps = [ + ":diagnostics_proto", + ], + visibility = [":friends"], +) + +tf_proto_library( + name = "pod_viewer_proto", + srcs = ["pod_viewer.proto"], + cc_api_version = 2, + protodeps = [ + ":diagnostics_proto", + ":pod_stats_proto", + ], + visibility = [":friends"], +) + tf_proto_library( name = "steps_db_proto", srcs = ["steps_db.proto"], @@ -148,3 +169,10 @@ tf_proto_library( cc_api_version = 2, visibility = [":friends"], ) + +tf_proto_library( + name = "tf_data_stats_proto", + srcs = ["tf_data_stats.proto"], + cc_api_version = 2, + visibility = [":friends"], +) diff --git a/tensorflow/core/profiler/protobuf/op_stats.proto b/tensorflow/core/profiler/protobuf/op_stats.proto index 500de69048a..53d6499b401 100644 --- a/tensorflow/core/profiler/protobuf/op_stats.proto +++ b/tensorflow/core/profiler/protobuf/op_stats.proto @@ -90,6 +90,17 @@ message RunEnvironment { uint32 host_trace_level = 12; } +// Next ID: 7 +message CoreDetails { + string hostname = 1; + uint32 device_ordinal = 2; // unique within host, TPU core only + uint32 core_num = 3; // unique within chip per core type + uint32 local_chip_id = 4; // unique within host + uint32 global_chip_id = 5; // unique within mesh + uint32 global_core_id = 6; // unique within mesh, TPU core only +} + +// Next ID: 12 // Operator Statistics. message OpStats { // The database for the op metrics collected from the host over the entire @@ -110,6 +121,8 @@ message OpStats { KernelStatsDb kernel_stats_db = 6; // Statistics for all tf-functions. TfFunctionDb tf_function_db = 8; + // A map from core ID to details. + map core_id_to_details = 11; // Error and warning messages for diagnosing profiling issues. Diagnostics diagnostics = 9; reserved 7; diff --git a/tensorflow/core/profiler/protobuf/pod_stats.proto b/tensorflow/core/profiler/protobuf/pod_stats.proto new file mode 100644 index 00000000000..abe55a5e8e0 --- /dev/null +++ b/tensorflow/core/profiler/protobuf/pod_stats.proto @@ -0,0 +1,41 @@ +syntax = "proto3"; + +package tensorflow.profiler; + +import "tensorflow/core/profiler/protobuf/diagnostics.proto"; + +message StepBreakdownEvents { + int32 id = 1; + string name = 2; +} + +// A database of PodStats records. +message PodStatsDatabase { + // All PodStats records, one for each row in the PodStats tool. + repeated PodStatsRecord pod_stats_record = 1; + // Error and warning messages for diagnosing profiling issues. + Diagnostics diagnostics = 3; + // A map from event type number to event name string for step breakdown. + repeated StepBreakdownEvents step_breakdown_events = 4; + reserved 2; +} + +// Next ID: 20 +// There is one PodStatsRecord for each step traced on each compute node. +message PodStatsRecord { + // The host name where the trace was collected. + string host_name = 1; + // The TPU global chip id where the trace was collected. + int32 chip_id = 2; + // The TPU node id where the trace was collected. + int32 node_id = 3; + // The step number. 
+ uint32 step_num = 4; + // The step duration in micro-seconds. + double total_duration_us = 5; + // Breakdown the durations for each event type in micro-seconds. + map step_breakdown_us = 19; + // Indicates the bottleneck out of the above mentioned metrics. + string bottleneck = 14; + reserved 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18; +} diff --git a/tensorflow/core/profiler/protobuf/pod_viewer.proto b/tensorflow/core/profiler/protobuf/pod_viewer.proto new file mode 100644 index 00000000000..3cfaa096c54 --- /dev/null +++ b/tensorflow/core/profiler/protobuf/pod_viewer.proto @@ -0,0 +1,125 @@ +// This proto describes the format of the output profile file from +// the Pod Viewer tool. +syntax = "proto3"; + +package tensorflow.profiler; + +import "tensorflow/core/profiler/protobuf/diagnostics.proto"; +import "tensorflow/core/profiler/protobuf/pod_stats.proto"; + +// Describes the replica groups in a cross replica op (e.g., all-reduce and +// all-to-all). +message ReplicaGroup { + // The ids of the replicas that belongs to the same group. The ordering of the + // ids matters in some ops (e.g., all-to-all). + repeated int64 replica_ids = 1; +} + +message AllReduceOpInfo { + // Name of this OP. + string name = 1; + // Number of instances that this OP occurred. + uint32 occurrences = 2; + // The time in microseconds spent in this OP (averaged across all of its + // occurrences). + double duration_us = 3; + // Byte size of data transferred. + uint64 data_size = 4; + // Replica groups. + repeated ReplicaGroup replica_groups = 5; + // Description (e.g. XLA expression). + string description = 6; +} + +// Result proto for information in a step across all cores. +message PodStatsMap { + // The (micro) step number. + uint32 step_num = 1; + // A map from core_id to PodStatsRecord. + map pod_stats_per_core = 2; + // A database of channel info. + repeated ChannelInfo channel_db = 3; + // A map from core ID to program replica id. Replica id map could change + // during a profile session, but should stay stable within a step. + map core_id_to_replica_id_map = 4; + // A database of all reduce ops. + repeated AllReduceOpInfo all_reduce_op_db = 5; +} + +// A sequence of PodStatsMap for each step. +message PodStatsSequence { + repeated PodStatsMap pod_stats_map = 1; +} + +// Next ID: 14 +// Information about a send and recv channel. +message ChannelInfo { + // Id of the channel. + int64 channel_id = 1; + // Core ids of send ops. + repeated uint32 src_core_ids = 11; + // Core ids of recv ops. + repeated uint32 dst_core_ids = 12; + // Byte size of the data transferred. + uint64 data_size = 4; + // Duration from the beginning of send to the end of recv-done in + // microseconds. + double duration_us = 5; + // Number of occurrences of a channel. + uint32 occurrences = 6; + // Percentage of the link BW utilized over the peak link BW. + double utilization = 7; + // A list of hlo names associated with this channel id. + repeated string hlo_names = 8; + // Duration from the beginning of the recv-done to the beginning of send in + // microseconds. If the recv-done op starts after the beginning of the send + // op, the delay is zero. + double send_delay_us = 9; + // Description (e.g. XLA expression). + string description = 13; + + reserved 2, 3, 10; +} + +message PodViewerSummary { + repeated string warnings = 1; +} + +// Next ID: 8 +// Topology graph draws all the cores in the system in a 2-D rectangle or +// 3-D cube. It is hierarchically grouped by host, chip and core. 
+message PodViewerTopology { + // Number of cores in the x dimension of the rectangle/cube. + int32 x_dimension = 1; + // Number of cores in the y dimension of the rectangle/cube. + int32 y_dimension = 2; + // Number of cores in the z dimension of the cube. + int32 z_dimension = 3; + // Number of cores in the x dimension of each host. + int32 host_x_stride = 4; + // Number of cores in the y dimension of each host. + int32 host_y_stride = 5; + // Number of cores in the z dimension of each host. + int32 host_z_stride = 6; + // Number of cores per chip. + int32 num_cores_per_chip = 7; +} + +// Next ID: 12 +// A database of pod viewer records. +message PodViewerDatabase { + // The type of device used. + string device_type = 10; + // Pod level stats for each step. + PodStatsSequence pod_stats_sequence = 3; + // Top level summary of pod viewer. + PodViewerSummary summary = 7; + // Error and warning messages for diagnosing profiling issues. + Diagnostics diagnostics = 8; + // A map from event type number to event name string for step breakdown. + repeated StepBreakdownEvents step_breakdown_events = 9; + // Info to draw the topology graph. + PodViewerTopology topology = 11; + + reserved 1, 2, 4, 5, 6; +} diff --git a/tensorflow/core/profiler/protobuf/tf_data_stats.proto b/tensorflow/core/profiler/protobuf/tf_data_stats.proto new file mode 100644 index 00000000000..f4edc4144d4 --- /dev/null +++ b/tensorflow/core/profiler/protobuf/tf_data_stats.proto @@ -0,0 +1,89 @@ +// This proto describes the format of the output profile file from +// the tf.data stats tool. +syntax = "proto3"; + +package tensorflow.profiler; + +// Stat for iterator. +message IteratorStat { + // Id of the iterator. + int64 id = 1; + // Start time of the iterator's GetNext in ps. + int64 start_time_ps = 2; + // Duration of the iterator's GetNext in ps. + int64 duration_ps = 3; + // Self time of the iterator's GetNext in ps. It takes account into async + // iterators. It is calculated by subtracting the time overlapped with its + // child iterator's duration from the iterator's duration. + int64 self_time_ps = 4; + // Whether it is blocking the root iterator. An async iterator's child + // iterator may not block its parent iterator if it is executed in advance and + // does not overlap with the parent iterator. + bool is_blocking = 5; + // The number of times this iterator is called. For example, a batch + // iterator's child iterator may be called multiple times. + int64 num_calls = 6; +} + +// Metadata for iterator. +message IteratorMetadata { + // Id of the iterator. + int64 id = 1; + // Id of the parent iterator. + int64 parent_id = 2; + // Name of the iterator. + string name = 3; + // Long name of the iterator. + string long_name = 6; + // Whether it is an async iterator. + bool is_async = 4; + // Parameters of the iterator (e.g., num_parallel_calls). + map params = 5; +} + +// Stat and metadata for input pipeline. +message InputPipelineStat { + // Id of the blocking iterator with the longest self time. + int64 bottleneck_iterator_id = 2; + // Stats per iterator. + map iterator_stats = 1; +} + +// Metadata for input pipeline. +message InputPipelineMetadata { + // The distribution strategy creates one "host" input pipeline which actually + // runs tf.data user code. Also, it creates a "device" input pipeline per + // device (e.g., TensorCore) which takes an element from the host input + // pipeline and transfers it to the device. 
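+  // For example (editor's illustration): a host feeding eight device cores
+  // would typically produce one HOST input pipeline and eight DEVICE input
+  // pipelines.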
+ enum InputPipelineType { + HOST = 0; + DEVICE = 1; + } + // Id of the input pipeline which is set to the id of its root iterator. + int64 id = 1; + InputPipelineType type = 2; + string name = 4; + reserved 3; +} + +// Collection of metadata and stats of input pipeline. +message InputPipelineStats { + // Metadata of the input pipeline. + InputPipelineMetadata metadata = 1; + // Average latency (i.e., the root iterator's latency) of the input pipeline. + int64 avg_latency_ps = 3; + // Minimum latency of the input pipeline. + int64 min_latency_ps = 4; + // Maximum latency of the input pipeline. + int64 max_latency_ps = 5; + // Stats per call sorted by the root iterator's duration. + repeated InputPipelineStat stats = 2; +} + +// Collection of stats of tf.data input pipelines within a host. +message TfDataStats { + // Metadata per iterator. + map iterator_metadata = 2; + // Stats per input pipeline. + map input_pipelines = 1; +} diff --git a/tensorflow/core/profiler/rpc/BUILD b/tensorflow/core/profiler/rpc/BUILD index 3531e0dae9d..e54ffe0f615 100644 --- a/tensorflow/core/profiler/rpc/BUILD +++ b/tensorflow/core/profiler/rpc/BUILD @@ -1,13 +1,15 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_external_workspace_visible") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") # buildifier: disable=same-origin-load -load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper") # buildifier: disable=same-origin-load -load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_alias") +load( + "//tensorflow/core/profiler/builds:build_config.bzl", + "tf_profiler_alias", + "tf_profiler_copts", + "tf_profiler_pybind_cc_library_wrapper", +) package( - default_visibility = [ - "//tensorflow/core/profiler:internal", - ], + default_visibility = ["//tensorflow/core/profiler:internal"], licenses = ["notice"], # Apache 2.0 ) @@ -27,11 +29,12 @@ exports_files( visibility = ["//tensorflow/core/profiler/rpc:__subpackages__"], ) +# Linked to pywrap_tensorflow. 
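+# (Editor's note: presumably this keeps a single copy of the gRPC-dependent
+# service implementation linked into pywrap_tensorflow, mirroring the
+# ODR-violation comment on profiler_client_impl in rpc/client/BUILD further
+# down in this change.)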
cc_library( name = "profiler_service_impl", srcs = ["profiler_service_impl.cc"], hdrs = ["profiler_service_impl.h"], - features = ["-layering_check"], + copts = tf_profiler_copts(), visibility = tf_external_workspace_visible( [ "//tensorflow/core/data/service:__pkg__", @@ -42,19 +45,31 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core/profiler:profiler_service_proto_cc", - "//tensorflow/core/profiler/lib:profiler_session_headers", + "//tensorflow/core/profiler/lib:profiler_session", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + "//tensorflow/core/profiler/utils:file_system_utils", + "//tensorflow/core/profiler/utils:xplane_utils", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", tf_grpc_cc_dependency(), ], ) +tf_profiler_pybind_cc_library_wrapper( + name = "profiler_server_for_pybind", + actual = ":profiler_server_impl", + visibility = ["//tensorflow/python/profiler/internal:__pkg__"], +) + cc_library( name = "profiler_server_impl", srcs = ["profiler_server.cc"], hdrs = ["profiler_server.h"], + copts = tf_profiler_copts(), visibility = [ "//tensorflow/compiler/xla/python:__pkg__", + "//tensorflow/core/profiler:internal", "//tensorflow/python:__pkg__", "//tensorflow/python/profiler/internal:__pkg__", ], @@ -67,11 +82,3 @@ cc_library( ], alwayslink = True, ) - -tf_pybind_cc_library_wrapper( - name = "profiler_server_headers", - visibility = [ - "//tensorflow/python/profiler/internal:__pkg__", - ], - deps = [":profiler_server_impl"], -) diff --git a/tensorflow/core/profiler/rpc/client/BUILD b/tensorflow/core/profiler/rpc/client/BUILD index 5702e83d7c4..efe9e6eec9a 100644 --- a/tensorflow/core/profiler/rpc/client/BUILD +++ b/tensorflow/core/profiler/rpc/client/BUILD @@ -1,11 +1,16 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") -load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper") # buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core/profiler/builds:build_config.bzl", + "tf_profiler_copts", + "tf_profiler_pybind_cc_library_wrapper", +) # For platform specific build config load( "//tensorflow/core/platform:build_config.bzl", - "tf_profiler_client_deps", + "tf_protos_profiler_service", ) package( @@ -16,26 +21,30 @@ cc_library( name = "capture_profile", srcs = ["capture_profile.cc"], hdrs = ["capture_profile.h"], + copts = tf_profiler_copts(), visibility = [ "//tensorflow/python/profiler/internal:__pkg__", ], deps = [ + ":profiler_client_for_pybind", + ":remote_profiler_session_manager", ":save_profile", - "@com_google_absl//absl/strings", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core/profiler:profiler_analysis_proto_cc", "//tensorflow/core/profiler:profiler_options_proto_cc", "//tensorflow/core/profiler:profiler_service_proto_cc", "//tensorflow/core/profiler/convert:xplane_to_profile_response", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", - "//tensorflow/core:lib_internal", - ] + tf_profiler_client_deps(), + "@com_google_absl//absl/strings", + ], ) cc_library( name = "save_profile", srcs = ["save_profile.cc"], hdrs = ["save_profile.h"], + copts = tf_profiler_copts(), visibility = ["//tensorflow/core/profiler:internal"], deps = [ "//tensorflow/core:framework", @@ -48,23 +57,108 @@ cc_library( ], ) +tf_profiler_pybind_cc_library_wrapper( + name = "profiler_client_for_pybind", + actual = 
":profiler_client", +) + +cc_library( + name = "profiler_client", + hdrs = ["profiler_client.h"], + deps = [ + ":profiler_client_impl", + "//tensorflow/core:lib", + "//tensorflow/core/profiler:profiler_analysis_proto_cc", + "//tensorflow/core/profiler:profiler_service_proto_cc", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + ], +) + +# Linked to pywrap_tensorflow to avoid ODR violation due to tf_grpc_cc_dependency(). cc_library( name = "profiler_client_impl", - srcs = ["profiler_client.cc"], - hdrs = ["profiler_client.h"], + srcs = [ + "profiler_client.cc", + "profiler_client.h", + ], + copts = tf_profiler_copts(), visibility = ["//tensorflow/python:__pkg__"], deps = [ "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler:profiler_analysis_proto_cc", "//tensorflow/core/profiler:profiler_service_proto_cc", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", tf_grpc_cc_dependency(), ], alwayslink = True, ) -tf_pybind_cc_library_wrapper( - name = "profiler_client_headers", - visibility = ["//tensorflow/python/profiler/internal:__pkg__"], - deps = [":profiler_client_impl"], +cc_library( + name = "profiler_client_test_util", + testonly = 1, + hdrs = ["profiler_client_test_util.h"], + deps = [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core/profiler:profiler_options_proto_cc", + "//tensorflow/core/profiler/lib:profiler_session", + "//tensorflow/core/profiler/rpc:profiler_server_impl", + ] + tf_protos_profiler_service(), +) + +tf_cc_test( + name = "profiler_client_test", + srcs = ["profiler_client_test.cc"], + tags = ["external"], # So that test suite reruns unconditionally. + deps = [ + ":profiler_client", + ":profiler_client_impl", # for oss + ":profiler_client_test_util", + "@com_google_absl//absl/time", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/profiler/rpc:profiler_server_impl", + ] + tf_protos_profiler_service(), +) + +cc_library( + name = "remote_profiler_session_manager", + srcs = ["remote_profiler_session_manager.cc"], + hdrs = ["remote_profiler_session_manager.h"], + copts = tf_profiler_copts(), + deps = [ + ":profiler_client_for_pybind", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/profiler:profiler_options_proto_cc", + "//tensorflow/core/profiler/utils:time_utils", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + ], +) + +tf_cc_test( + name = "remote_profiler_session_manager_test", + srcs = ["remote_profiler_session_manager_test.cc"], + tags = ["external"], # So that test suite reruns unconditionally. 
+ deps = [ + ":profiler_client_impl", # for oss + ":profiler_client_test_util", + ":remote_profiler_session_manager", + "@com_google_absl//absl/time", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/profiler:profiler_options_proto_cc", + "//tensorflow/core/profiler/rpc:profiler_server_impl", + ] + tf_protos_profiler_service(), ) diff --git a/tensorflow/core/profiler/rpc/client/capture_profile.cc b/tensorflow/core/profiler/rpc/client/capture_profile.cc index d7707aff5c2..3f59d2ba265 100644 --- a/tensorflow/core/profiler/rpc/client/capture_profile.cc +++ b/tensorflow/core/profiler/rpc/client/capture_profile.cc @@ -30,12 +30,16 @@ limitations under the License. #include "tensorflow/core/profiler/profiler_options.pb.h" #include "tensorflow/core/profiler/profiler_service.pb.h" #include "tensorflow/core/profiler/rpc/client/profiler_client.h" +#include "tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h" #include "tensorflow/core/profiler/rpc/client/save_profile.h" namespace tensorflow { namespace profiler { namespace { +using ::tensorflow::profiler::RemoteProfilerSessionManager; +using Response = ::tensorflow::profiler::RemoteProfilerSessionManager::Response; + constexpr uint64 kMaxEvents = 1000000; const absl::string_view kXPlanePb = "xplane.pb"; @@ -48,17 +52,18 @@ MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level, return request; } -ProfileRequest PopulateProfileRequest(int duration_ms, - const std::string& repository_root, - const std::string& session_id, - const std::string& host_name, - const ProfileOptions& opts) { +ProfileRequest PopulateProfileRequest( + absl::string_view repository_root, absl::string_view session_id, + absl::string_view host_name, + const RemoteProfilerSessionManagerOptions& options) { ProfileRequest request; - request.set_duration_ms(duration_ms); + // TODO(b/169976117) Remove duration from request. + request.set_duration_ms(options.profiler_options().duration_ms()); request.set_max_events(kMaxEvents); - request.set_repository_root(repository_root); - request.set_session_id(session_id); - request.set_host_name(host_name); + request.set_repository_root(repository_root.data(), repository_root.size()); + request.set_session_id(session_id.data(), session_id.size()); + request.set_host_name(host_name.data(), host_name.size()); + // These tools are only used by TPU profiler. request.add_tools("trace_viewer"); request.add_tools("op_profile"); request.add_tools("input_pipeline"); @@ -68,21 +73,26 @@ ProfileRequest PopulateProfileRequest(int duration_ms, request.add_tools("overview_page"); request.add_tools("pod_viewer"); request.add_tools("tensorflow_stats"); - *request.mutable_opts() = opts; + // XPlane tool is only used by OSS profiler and safely ignored by TPU + // profiler. 
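+  // (Editor's note: kXPlanePb is the "xplane.pb" tool name declared near the
+  // top of this file.)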
+ request.add_tools(kXPlanePb.data(), kXPlanePb.size()); + *request.mutable_opts() = options.profiler_options(); return request; } NewProfileSessionRequest PopulateNewProfileSessionRequest( - const std::string& service_addr, const std::string& repository_root, - const std::vector& hostnames, int duration_ms, - const std::string& session_id, const ProfileOptions& opts) { + absl::string_view repository_root, absl::string_view session_id, + const RemoteProfilerSessionManagerOptions& opts) { NewProfileSessionRequest request; - std::vector parts = absl::StrSplit(service_addr, ':'); - *request.mutable_request() = PopulateProfileRequest( - duration_ms, repository_root, session_id, parts[0], opts); - request.set_repository_root(repository_root); - request.set_session_id(session_id); - for (const auto& hostname : hostnames) { + std::vector parts = + absl::StrSplit(opts.service_addresses(0), ':'); + DCHECK(!parts.empty()); + + *request.mutable_request() = + PopulateProfileRequest(repository_root, session_id, parts[0], opts); + request.set_repository_root(repository_root.data(), repository_root.size()); + request.set_session_id(session_id.data(), session_id.size()); + for (const auto& hostname : opts.service_addresses()) { request.add_hosts(hostname); } return request; @@ -99,44 +109,40 @@ inline bool ShouldRetryTracing(Status status) { status.error_message() == "Stream removed"); } -// If the ProfileResponse has single 'xplane.pb' tool, convert the xplane to -// other tools and add in ProfileResponse. Otherwise, the ProfileResponse is -// already converted, simply return. -Status ConvertXSpaceToToolsInProfileResponse(const ProfileRequest& request, - ProfileResponse* response) { - if (response->tool_data_size() != 1) return Status::OK(); - if (response->tool_data(0).name() != kXPlanePb) return Status::OK(); - XSpace xspace; - xspace.ParseFromString(response->tool_data(0).data()); - TF_RETURN_IF_ERROR(ConvertXSpaceToProfileResponse(xspace, request, response)); - return Status::OK(); -} +Status Profile(const std::string& repository_root, + const std::string& session_id, + const RemoteProfilerSessionManagerOptions& opts) { + Status status; + // Host name will be overwritten by RemoteProfilerSessionManager later. + ProfileRequest request = PopulateProfileRequest(repository_root, session_id, + /*host_name=*/"", opts); + auto session = RemoteProfilerSessionManager::Create(opts, request, status); + TF_RETURN_IF_ERROR(status); + // Expect one or more service addresses. + DCHECK_GT(opts.service_addresses_size(), 0); + std::vector responses = session->WaitForCompletion(); + // Expect responses to have the same size as clients. + DCHECK_EQ(responses.size(), opts.service_addresses_size()); -Status Profile(const std::string& service_addr, - const std::string& repository_root, int duration_ms, - const std::string& session_id, const ProfileOptions& opts) { - std::vector parts = absl::StrSplit(service_addr, ':'); - ProfileRequest request = PopulateProfileRequest(duration_ms, repository_root, - session_id, parts[0], opts); - ProfileResponse response; - TF_RETURN_IF_ERROR(ProfileGrpc(service_addr, request, &response)); - - if (!response.empty_trace()) { - TF_RETURN_IF_ERROR( - ConvertXSpaceToToolsInProfileResponse(request, &response)); - TF_RETURN_IF_ERROR(SaveProfile(repository_root, session_id, - request.host_name(), response, &std::cout)); - // Print this at the end so that it's not buried in irrelevant LOG messages. 
- std::cout - << "NOTE: using the trace duration " << duration_ms << "ms.\n" - << "Set an appropriate duration (with --duration_ms) if you " - "don't see a full step in your trace or the captured trace is too " - "large." - << std::endl; + bool has_trace_data = false; + for (const auto& client_response : responses) { + ProfileResponse& response = *client_response.profile_response; + if (response.empty_trace()) { + LOG(WARNING) << "No trace event is collected from " + << client_response.service_address; + } else { + has_trace_data = true; + } + if (!client_response.status.ok()) { + LOG(WARNING) << client_response.service_address << " returned " + << client_response.status; + } } - if (response.empty_trace()) { - return Status(error::Code::UNAVAILABLE, "No trace event is collected"); + if (!has_trace_data) { + return Status(error::Code::UNAVAILABLE, + "No trace event was collected because there were no responses" + " from clients or the responses did not have trace data."); } return Status::OK(); } @@ -144,52 +150,47 @@ Status Profile(const std::string& service_addr, // Start a new profiling session that include all the hosts included in // hostnames, for the time interval of duration_ms. Possibly save the profiling // result in the directory specified by repository_root and session_id. -Status NewSession(const std::string& service_addr, - const std::string& repository_root, - const std::vector& hostnames, int duration_ms, - const std::string& session_id, const ProfileOptions& opts) { - NewProfileSessionRequest request = PopulateNewProfileSessionRequest( - service_addr, repository_root, hostnames, duration_ms, session_id, opts); +Status NewSession(absl::string_view repository_root, + absl::string_view session_id, + const RemoteProfilerSessionManagerOptions& opts) { + NewProfileSessionRequest request = + PopulateNewProfileSessionRequest(repository_root, session_id, opts); NewProfileSessionResponse response; - TF_RETURN_IF_ERROR(NewSessionGrpc(service_addr, request, &response)); + TF_RETURN_IF_ERROR( + NewSessionGrpc(opts.service_addresses(0), request, &response)); std::cout << "Profile session succeed for host(s):" - << absl::StrJoin(hostnames, ",") << std::endl; + << absl::StrJoin(opts.service_addresses(), ",") << std::endl; if (response.empty_trace()) { - return Status(error::Code::UNAVAILABLE, "No trace event is collected"); + return errors::Unavailable("No trace event is collected"); } return Status::OK(); } } // namespace -// Starts tracing on a single or multiple hosts and saves the result in the -// given logdir. If no trace was collected, retries tracing for -// num_tracing_attempts. -Status Trace(const std::string& service_addr, const std::string& logdir, - const std::string& workers_list, int duration_ms, - int num_tracing_attempts, const ProfileOptions& opts) { +Status Trace(const std::string& logdir, int num_tracing_attempts, + const RemoteProfilerSessionManagerOptions& opts, + bool is_cloud_tpu_session) { + DCHECK_GT(opts.profiler_options().duration_ms(), 0); + DCHECK(!opts.service_addresses().empty()); + // Use the current timestamp as the run name. 
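+  // (Editor's note: i.e. a timestamp string such as "2020_10_14_10_30_05";
+  // the exact format is whatever GetCurrentTimeStampAsString() produces.)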
std::string session_id = GetCurrentTimeStampAsString(); - std::vector hostnames; - if (!workers_list.empty()) { - hostnames = absl::StrSplit(workers_list, ','); - } + std::string repository_root = GetTensorBoardProfilePluginDir(logdir); + auto duration_ms = opts.profiler_options().duration_ms(); TF_RETURN_IF_ERROR(MaybeCreateEmptyEventFile(logdir)); - std::string repository_root = - profiler::GetTensorBoardProfilePluginDir(logdir); - Status status = Status::OK(); + Status status; int remaining_attempts = num_tracing_attempts; while (true) { std::cout << "Starting to trace for " << duration_ms << " ms. " << "Remaining attempt(s): " << --remaining_attempts << std::endl; - if (hostnames.empty()) { - status = - Profile(service_addr, repository_root, duration_ms, session_id, opts); + + if (is_cloud_tpu_session) { + status = NewSession(repository_root, session_id, opts); } else { - status = NewSession(service_addr, repository_root, hostnames, duration_ms, - session_id, opts); + status = Profile(repository_root, session_id, opts); } if (remaining_attempts <= 0 || status.ok() || !ShouldRetryTracing(status)) break; @@ -223,11 +224,10 @@ Status ExportToTensorBoard(const XSpace& xspace, const std::string& logdir) { ProfileResponse response; ProfileRequest request = PopulateProfileRequest( - /*duration_ms=*/0, GetTensorBoardProfilePluginDir(logdir), - GetCurrentTimeStampAsString(), port::Hostname(), /*opts=*/{}); + GetTensorBoardProfilePluginDir(logdir), GetCurrentTimeStampAsString(), + port::Hostname(), /*options=*/{}); TF_RETURN_IF_ERROR( ConvertXSpaceToProfileResponse(xspace, request, &response)); - std::stringstream ss; // Record LOG messages. TF_RETURN_IF_ERROR(SaveProfile(request.repository_root(), request.session_id(), request.host_name(), diff --git a/tensorflow/core/profiler/rpc/client/capture_profile.h b/tensorflow/core/profiler/rpc/client/capture_profile.h index 771f1fee722..cb9183e28a7 100644 --- a/tensorflow/core/profiler/rpc/client/capture_profile.h +++ b/tensorflow/core/profiler/rpc/client/capture_profile.h @@ -36,12 +36,12 @@ Status Monitor(const std::string& service_addr, int duration_ms, int monitoring_level, bool display_timestamp, std::string* result); -// Starts tracing on a single or multiple hosts and saves the result in the -// given logdir. If no trace was collected, retries tracing for -// num_tracing_attempts. -Status Trace(const std::string& service_addr, const std::string& logdir, - const std::string& workers_list, int duration_ms, - int num_tracing_attempts, const ProfileOptions& opts); +// Starts tracing on a single or multiple hosts. Each host will save the result +// in the given logdir. If no trace was collected, retries tracing for +// num_tracing_attempts. Assumes that options have been validated. +Status Trace(const std::string& logdir, int num_tracing_attempts, + const RemoteProfilerSessionManagerOptions& opts, + bool is_cloud_tpu_session); } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/rpc/client/profiler_client.cc b/tensorflow/core/profiler/rpc/client/profiler_client.cc index 5eec02b2e95..f46075e8c44 100644 --- a/tensorflow/core/profiler/rpc/client/profiler_client.cc +++ b/tensorflow/core/profiler/rpc/client/profiler_client.cc @@ -17,6 +17,9 @@ limitations under the License. 
#include #include "grpcpp/grpcpp.h" +#include "absl/memory/memory.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/status.h" @@ -34,50 +37,126 @@ inline Status FromGrpcStatus(const ::grpc::Status& s) { } template -std::unique_ptr CreateStub(const std::string& service_addr) { +std::unique_ptr CreateStub( + const std::string& service_address) { ::grpc::ChannelArguments channel_args; channel_args.SetMaxReceiveMessageSize(std::numeric_limits::max()); // Default URI prefix is "dns:///" if not provided. auto channel = ::grpc::CreateCustomChannel( - service_addr, ::grpc::InsecureChannelCredentials(), channel_args); + service_address, ::grpc::InsecureChannelCredentials(), channel_args); if (!channel) { - LOG(ERROR) << "Unable to create channel" << service_addr; + LOG(ERROR) << "Unable to create channel" << service_address; } return T::NewStub(channel); } } // namespace -Status ProfileGrpc(const std::string& service_addr, +Status ProfileGrpc(const std::string& service_address, const ProfileRequest& request, ProfileResponse* response) { ::grpc::ClientContext context; std::unique_ptr stub = - CreateStub(service_addr); + CreateStub(service_address); TF_RETURN_IF_ERROR( FromGrpcStatus(stub->Profile(&context, request, response))); return Status::OK(); } -Status NewSessionGrpc(const std::string& service_addr, +Status NewSessionGrpc(const std::string& service_address, const NewProfileSessionRequest& request, NewProfileSessionResponse* response) { ::grpc::ClientContext context; std::unique_ptr stub = - CreateStub(service_addr); + CreateStub(service_address); TF_RETURN_IF_ERROR( FromGrpcStatus(stub->NewSession(&context, request, response))); return Status::OK(); } -Status MonitorGrpc(const std::string& service_addr, +Status MonitorGrpc(const std::string& service_address, const MonitorRequest& request, MonitorResponse* response) { ::grpc::ClientContext context; std::unique_ptr stub = - CreateStub(service_addr); + CreateStub(service_address); TF_RETURN_IF_ERROR( FromGrpcStatus(stub->Monitor(&context, request, response))); return Status::OK(); } +/*static*/ std::unique_ptr RemoteProfilerSession::Create( + std::string service_address, absl::Time deadline, + ProfileRequest profile_request) { + auto instance = absl::WrapUnique(new RemoteProfilerSession( + std::move(service_address), deadline, std::move(profile_request))); + instance->ProfileAsync(); + return instance; +} + +RemoteProfilerSession::RemoteProfilerSession(std::string service_address, + absl::Time deadline, + ProfileRequest profile_request) + : response_(absl::make_unique()), + service_address_(std::move(service_address)), + stub_(CreateStub(service_address_)), + deadline_(deadline), + profile_request_(std::move(profile_request)) { + response_->set_empty_trace(true); +} + +RemoteProfilerSession::~RemoteProfilerSession() { + Status dummy; + WaitForCompletion(dummy); + grpc_context_.TryCancel(); +} + +void RemoteProfilerSession::ProfileAsync() { + LOG(INFO) << "Asynchronous gRPC Profile() to " << service_address_; + grpc_context_.set_deadline(absl::ToChronoTime(deadline_)); + VLOG(1) << "Deadline set to " << deadline_; + rpc_ = stub_->AsyncProfile(&grpc_context_, profile_request_, &cq_); + // Connection failure will create lame channel whereby grpc_status_ will be an + // error. + rpc_->Finish(response_.get(), &grpc_status_, + static_cast(&status_on_completion_)); + VLOG(2) << "Asynchronous gRPC Profile() issued." 
<< absl::Now(); +} + +std::unique_ptr RemoteProfilerSession::WaitForCompletion( + Status& out_status) { + if (!response_) { + out_status = errors::FailedPrecondition( + "WaitForCompletion must only be called once."); + return nullptr; + } + LOG(INFO) << "Waiting for completion."; + + void* got_tag = nullptr; + bool ok = false; + // Next blocks until there is a response in the completion queue. Expect the + // completion queue to have exactly a single response because deadline is set + // and completion queue is only drained once at destruction time. + bool success = cq_.Next(&got_tag, &ok); + if (!success || !ok || got_tag == nullptr) { + out_status = + errors::Internal("Missing or invalid event from completion queue."); + return nullptr; + } + + VLOG(1) << "Writing out status."; + // For the event read from the completion queue, expect that got_tag points to + // the memory location of status_on_completion. + DCHECK_EQ(got_tag, &status_on_completion_); + // tagged status points to pre-allocated memory which is okay to overwrite. + status_on_completion_.Update(FromGrpcStatus(grpc_status_)); + if (status_on_completion_.code() == error::DEADLINE_EXCEEDED) { + LOG(WARNING) << status_on_completion_; + } else if (!status_on_completion_.ok()) { + LOG(ERROR) << status_on_completion_; + } + + out_status = status_on_completion_; + return std::move(response_); +} + } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/rpc/client/profiler_client.h b/tensorflow/core/profiler/rpc/client/profiler_client.h index d946d607e55..84c71a55d5f 100644 --- a/tensorflow/core/profiler/rpc/client/profiler_client.h +++ b/tensorflow/core/profiler/rpc/client/profiler_client.h @@ -17,6 +17,11 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_RPC_CLIENT_PROFILER_CLIENT_H_ #define TENSORFLOW_CORE_PROFILER_RPC_CLIENT_PROFILER_CLIENT_H_ +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/time/time.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/profiler/profiler_analysis.grpc.pb.h" #include "tensorflow/core/profiler/profiler_service.grpc.pb.h" @@ -24,16 +29,67 @@ limitations under the License. namespace tensorflow { namespace profiler { -Status ProfileGrpc(const std::string& service_addr, +// Note that tensorflow/tools/def_file_filter/symbols_pybind.txt is incompatible +// with absl::string_view. +Status ProfileGrpc(const std::string& service_address, const ProfileRequest& request, ProfileResponse* response); -Status NewSessionGrpc(const std::string& service_addr, +Status NewSessionGrpc(const std::string& service_address, const NewProfileSessionRequest& request, NewProfileSessionResponse* response); -Status MonitorGrpc(const std::string& service_addr, +Status MonitorGrpc(const std::string& service_address, const MonitorRequest& request, MonitorResponse* response); +class RemoteProfilerSession { + public: + // Creates an instance and starts a remote profiling session immediately. + // This is a non-blocking call and does not wait for a response. + // Response must outlive the instantiation. + static std::unique_ptr Create( + std::string service_address, absl::Time deadline, + ProfileRequest profile_request); + + // Not copyable or movable. 
+ RemoteProfilerSession(const RemoteProfilerSession&) = delete; + RemoteProfilerSession operator=(const RemoteProfilerSession&) = delete; + + ~RemoteProfilerSession(); + + absl::string_view GetServiceAddress() const { return service_address_; } + + // Blocks until a response has been received or until deadline expiry, + // whichever is first. Subsequent calls after the first will yield nullptr and + // an error status. + std::unique_ptr WaitForCompletion(Status& out_status); + + private: + explicit RemoteProfilerSession(std::string service_addr, absl::Time deadline, + ProfileRequest profile_request); + + // Starts a remote profiling session. This is a non-blocking call. + // Will be called exactly once during instantiation. + // RPC will write to response.profile_response eagerly. However, since + // response.status requires a conversion from grpc::Status, it can only be + // evaluated lazily at WaitForCompletion() time. + void ProfileAsync(); + + Status status_on_completion_; + std::unique_ptr response_; + // Client address and connection attributes. + std::string service_address_; + std::unique_ptr stub_; + absl::Time deadline_; + ::grpc::ClientContext grpc_context_; + std::unique_ptr<::grpc::ClientAsyncResponseReader> rpc_; + ::grpc::Status grpc_status_ = ::grpc::Status::OK; + + // Asynchronous completion queue states. + ::grpc::CompletionQueue cq_; + + ProfileRequest profile_request_; +}; + } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/rpc/client/profiler_client_test.cc b/tensorflow/core/profiler/rpc/client/profiler_client_test.cc new file mode 100644 index 00000000000..ae42c5cacd0 --- /dev/null +++ b/tensorflow/core/profiler/rpc/client/profiler_client_test.cc @@ -0,0 +1,152 @@ +/* Copyright 2020 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
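For orientation, a minimal usage sketch of the asynchronous client declared above (illustrative only, not part of the patch; the address, durations, and function name are placeholders). Create() issues the Profile RPC immediately; WaitForCompletion() blocks until the response arrives or the deadline expires:

#include <memory>

#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
#include "tensorflow/core/profiler/profiler_service.pb.h"
#include "tensorflow/core/profiler/rpc/client/profiler_client.h"

namespace tensorflow {
namespace profiler {

void ProfileOneHost() {
  ProfileRequest request;
  request.set_duration_ms(1000);
  *request.mutable_opts() = ProfilerSession::DefaultOptions();
  request.mutable_opts()->set_duration_ms(1000);
  // Allow some grace beyond the profiling duration before giving up.
  absl::Time deadline =
      absl::Now() + absl::Milliseconds(1000) + absl::Seconds(5);
  auto session =
      RemoteProfilerSession::Create("localhost:6009", deadline, request);
  Status status;
  std::unique_ptr<ProfileResponse> response =
      session->WaitForCompletion(status);
  if (!status.ok() || response->empty_trace()) {
    LOG(WARNING) << "No trace data collected: " << status;
  }
}

}  // namespace profiler
}  // namespace tensorflow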
+==============================================================================*/ +#include "tensorflow/core/profiler/rpc/client/profiler_client.h" + +#include +#include + +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/profiler_service.pb.h" +#include "tensorflow/core/profiler/rpc/client/profiler_client_test_util.h" + +namespace tensorflow { +namespace profiler { +namespace { + +using ::tensorflow::profiler::test::DurationApproxLess; +using ::tensorflow::profiler::test::DurationNear; +using ::tensorflow::profiler::test::StartServer; + +TEST(RemoteProfilerSession, Simple) { + absl::Duration duration = absl::Milliseconds(10); + ProfileRequest request; + std::string service_addr; + auto server = StartServer(duration, &service_addr, &request); + absl::Duration grace = absl::Seconds(1); + absl::Duration max_duration = duration + grace; + absl::Time approx_start = absl::Now(); + absl::Time deadline = approx_start + max_duration; + + auto remote_session = + RemoteProfilerSession::Create(service_addr, deadline, request); + + Status status; + auto response = remote_session->WaitForCompletion(status); + absl::Duration elapsed = absl::Now() - approx_start; + // At end of session this evaluates to true still. + EXPECT_TRUE(status.ok()); + // True because there was no workload traced and subsequently no XEvents. + EXPECT_TRUE(response->empty_trace()); + // XSpaces are serialized and not returned as tools in ProfileResponse. + EXPECT_EQ(response->tool_data_size(), 0); + EXPECT_THAT(elapsed, DurationApproxLess(max_duration)); +} + +TEST(RemoteProfilerSession, WaitNotCalled) { + absl::Duration duration = absl::Milliseconds(10); + ProfileRequest request; + std::string service_addr; + auto server = StartServer(duration, &service_addr, &request); + absl::Duration grace = absl::Seconds(1); + absl::Duration max_duration = duration + grace; + absl::Time approx_start = absl::Now(); + absl::Time deadline = approx_start + max_duration; + + auto remote_session = + RemoteProfilerSession::Create(service_addr, deadline, request); + absl::Duration elapsed = absl::Now() - approx_start; + + EXPECT_THAT(elapsed, DurationApproxLess(max_duration)); +} + +TEST(RemoteProfilerSession, Timeout) { + absl::Duration duration = absl::Milliseconds(10); + ProfileRequest request; + std::string service_addr; + auto server = StartServer(duration, &service_addr, &request); + // Expect this to fail immediately since deadline was set to the past, + auto remote_session = + RemoteProfilerSession::Create(service_addr, absl::Now(), request); + Status status; + auto response = remote_session->WaitForCompletion(status); + // At end of session we will have a timeout error. + EXPECT_TRUE(errors::IsDeadlineExceeded(status)); + // True because there was no workload traced and subsequently no XEvents. + EXPECT_TRUE(response->empty_trace()); + // XSpaces are serialized and not returned as tools in ProfileResponse. 
+ EXPECT_EQ(response->tool_data_size(), 0); +} + +TEST(RemoteProfilerSession, LongDeadline) { + absl::Duration duration = absl::Milliseconds(10); + ProfileRequest request; + std::string service_addr; + auto server = StartServer(duration, &service_addr, &request); + + absl::Time approx_start = absl::Now(); + absl::Duration grace = absl::Seconds(1000); + absl::Duration max_duration = duration + grace; + const absl::Time deadline = approx_start + max_duration; + + auto remote_session = + RemoteProfilerSession::Create(service_addr, deadline, request); + Status status; + auto response = remote_session->WaitForCompletion(status); + absl::Duration elapsed = absl::Now() - approx_start; + // At end of session this evaluates to true still. + EXPECT_TRUE(status.ok()); + // True because there was no workload traced and subsequently no XEvents. + EXPECT_TRUE(response->empty_trace()); + // XSpaces are serialized and not returned as tools in ProfileResponse. + EXPECT_EQ(response->tool_data_size(), 0); + // Elapsed time is near profiling duration despite long grace period. + EXPECT_THAT(elapsed, DurationNear(duration)); +} + +TEST(RemoteProfilerSession, LongDuration) { + absl::Duration duration = absl::Seconds(3); + ProfileRequest request; + std::string service_addr; + auto server = StartServer(duration, &service_addr, &request); + + absl::Time approx_start = absl::Now(); + // Empirically determined value. + absl::Duration grace = absl::Seconds(20); + absl::Duration max_duration = duration + grace; + const absl::Time deadline = approx_start + max_duration; + + auto remote_session = + RemoteProfilerSession::Create(service_addr, deadline, request); + Status status; + auto response = remote_session->WaitForCompletion(status); + absl::Duration elapsed = absl::Now() - approx_start; + // At end of session this evaluates to true still. + EXPECT_TRUE(status.ok()); + // True because there was no workload traced and subsequently no XEvents. + EXPECT_TRUE(response->empty_trace()); + // XSpaces are serialized and not returned as tools in ProfileResponse. + EXPECT_EQ(response->tool_data_size(), 0); + // Elapsed time takes longer to complete for larger traces. + EXPECT_THAT(elapsed, DurationApproxLess(max_duration)); +} + +} // namespace +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/rpc/client/profiler_client_test_util.h b/tensorflow/core/profiler/rpc/client/profiler_client_test_util.h new file mode 100644 index 00000000000..da10dcc6193 --- /dev/null +++ b/tensorflow/core/profiler/rpc/client/profiler_client_test_util.h @@ -0,0 +1,81 @@ +/* Copyright 2020 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +// GRPC client to perform on-demand profiling + +#ifndef TENSORFLOW_CORE_PROFILER_RPC_CLIENT_PROFILER_CLIENT_TEST_H_ +#define TENSORFLOW_CORE_PROFILER_RPC_CLIENT_PROFILER_CLIENT_TEST_H_ + +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/lib/profiler_session.h" +#include "tensorflow/core/profiler/profiler_options.pb.h" +#include "tensorflow/core/profiler/profiler_service.pb.h" +#include "tensorflow/core/profiler/rpc/profiler_server.h" + +namespace tensorflow { +namespace profiler { +namespace test { + +inline std::unique_ptr StartServer( + absl::Duration duration, std::string* service_address, + ProfileRequest* request = nullptr) { + auto profiler_server = absl::make_unique(); + int port = testing::PickUnusedPortOrDie(); + profiler_server->StartProfilerServer(port); + + DCHECK(service_address); + *service_address = absl::StrCat("localhost:", port); + + if (request) { + request->set_duration_ms(absl::ToInt64Milliseconds(duration)); + request->set_max_events(10000); + *request->mutable_opts() = ProfilerSession::DefaultOptions(); + request->mutable_opts()->set_duration_ms( + absl::ToInt64Milliseconds(duration)); + request->set_session_id("test_session"); + request->set_host_name(*service_address); + request->set_repository_root(testing::TmpDir()); + } + + LOG(INFO) << "Started " << *service_address << " at " << absl::Now(); + LOG(INFO) << "Duration: " << duration; + + return profiler_server; +} + +inline ::testing::Matcher DurationNear( + const absl::Duration duration, absl::Duration epsilon = absl::Seconds(1)) { + return ::testing::AllOf(::testing::Ge(duration - epsilon), + ::testing::Le(duration + epsilon)); +} + +inline ::testing::Matcher DurationApproxLess( + const absl::Duration duration, absl::Duration epsilon = absl::Seconds(1)) { + return ::testing::Le(duration + epsilon); +} + +} // namespace test +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_RPC_CLIENT_PROFILER_CLIENT_TEST_H_ diff --git a/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.cc b/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.cc new file mode 100644 index 00000000000..2eeffa292f0 --- /dev/null +++ b/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.cc @@ -0,0 +1,117 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "tensorflow/core/platform/env_time.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/rpc/client/profiler_client.h" +#include "tensorflow/core/profiler/utils/time_utils.h" + +namespace tensorflow { +namespace profiler { + +/*static*/ std::unique_ptr +RemoteProfilerSessionManager::Create( + const RemoteProfilerSessionManagerOptions& options, + const ProfileRequest& request, tensorflow::Status& out_status, + AddressResolver resolver) { + VLOG(1) << "Creating a RemoteProfilerSessionManager."; + auto session_manager = absl::WrapUnique( + new RemoteProfilerSessionManager(options, request, resolver)); + out_status = session_manager->Init(); + if (!out_status.ok()) { + return nullptr; + } + return session_manager; +} + +RemoteProfilerSessionManager::RemoteProfilerSessionManager( + RemoteProfilerSessionManagerOptions options, ProfileRequest request, + AddressResolver resolver) + : options_(std::move(options)), request_(std::move(request)) { + if (resolver) { + resolver_ = std::move(resolver); + } else { + resolver_ = [](absl::string_view addr) { return std::string(addr); }; + } +} + +RemoteProfilerSessionManager::~RemoteProfilerSessionManager() { + VLOG(2) << "Destroying RemoteProfilerSessionManager."; +} + +Status RemoteProfilerSessionManager::Init() { + mutex_lock lock(mutex_); + VLOG(1) << "SessionManager initializing."; + + const absl::Time session_created_ts = + absl::FromUnixNanos(options_.session_creation_timestamp_ns()); + const absl::Time deadline = + session_created_ts + + absl::Milliseconds(options_.max_session_duration_ms()); + + LOG(INFO) << "Deadline set to " << deadline + << " because max_session_duration_ms was " + << options_.max_session_duration_ms() + << " and session_creation_timestamp_ns was " + << options_.session_creation_timestamp_ns() << " [" + << session_created_ts << "]"; + + // Prepare a list of clients. + clients_.reserve(options_.service_addresses_size()); + + for (auto& service_address : options_.service_addresses()) { + std::string resolved_service_address = resolver_(service_address); + ProfileRequest request = request_; + request.set_host_name(resolved_service_address); + + // Creation also issues Profile RPC asynchronously. 
+ auto client = RemoteProfilerSession::Create( + std::move(resolved_service_address), deadline, std::move(request)); + clients_.push_back(std::move(client)); + } + + LOG(INFO) << "Issued Profile gRPC to " << clients_.size() << " clients"; + return Status::OK(); +} + +std::vector +RemoteProfilerSessionManager::WaitForCompletion() { + mutex_lock lock(mutex_); + std::vector remote_responses( + clients_.size()); + + for (int32 idx = 0; idx < clients_.size(); ++idx) { + auto& remote_response = remote_responses[idx]; + auto* client = clients_[idx].get(); + remote_response.profile_response = + client->WaitForCompletion(remote_response.status); + remote_response.service_address = std::string(client->GetServiceAddress()); + } + return remote_responses; +} + +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h b/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h new file mode 100644 index 00000000000..1dc240ad530 --- /dev/null +++ b/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h @@ -0,0 +1,84 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_RPC_CLIENT_REMOTE_PROFILER_SESSION_MANAGER_H_ +#define TENSORFLOW_CORE_PROFILER_RPC_CLIENT_REMOTE_PROFILER_SESSION_MANAGER_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/rpc/client/profiler_client.h" + +namespace tensorflow { +namespace profiler { + +using AddressResolver = std::function; + +// Manages one or more remote profiling sessions. +class RemoteProfilerSessionManager { + public: + struct Response { + std::string service_address; + std::unique_ptr profile_response; + Status status; + }; + // Instantiates a collection of RemoteProfilerSessions starts profiling on + // each of them immediately. Assumes that options have already been validated. + static std::unique_ptr Create( + const RemoteProfilerSessionManagerOptions& options, + const ProfileRequest& request, tensorflow::Status& out_status, + AddressResolver resolver = nullptr); + + // Awaits for responses from remote profiler sessions and returns them as a + // list. Subsequent calls beyond the first will yield a list of errors. + std::vector WaitForCompletion(); + + // Not copyable or movable. 
+ RemoteProfilerSessionManager(const RemoteProfilerSessionManager&) = delete; + RemoteProfilerSessionManager operator=(const RemoteProfilerSessionManager&) = + delete; + + ~RemoteProfilerSessionManager(); + + private: + explicit RemoteProfilerSessionManager( + RemoteProfilerSessionManagerOptions options, ProfileRequest request, + AddressResolver resolver); + + // Initialization of all client contexts. + Status Init(); + + mutex mutex_; + // Remote profiler session options. + RemoteProfilerSessionManagerOptions options_ TF_GUARDED_BY(mutex_); + ProfileRequest request_ TF_GUARDED_BY(mutex_); + // List of clients, each connects to a profiling service. + std::vector> clients_ + TF_GUARDED_BY(mutex_); + // Resolves an address into a format that gRPC understands. + AddressResolver resolver_ TF_GUARDED_BY(mutex_); +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_RPC_CLIENT_REMOTE_PROFILER_SESSION_MANAGER_H_ diff --git a/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager_test.cc b/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager_test.cc new file mode 100644 index 00000000000..c6d2d640eae --- /dev/null +++ b/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager_test.cc @@ -0,0 +1,162 @@ +/* Copyright 2020 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h" + +#include +#include +#include + +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/profiler_options.pb.h" +#include "tensorflow/core/profiler/profiler_service.pb.h" +#include "tensorflow/core/profiler/rpc/client/profiler_client_test_util.h" + +namespace tensorflow { +namespace profiler { +namespace { + +using ::tensorflow::profiler::test::DurationApproxLess; +using ::tensorflow::profiler::test::DurationNear; +using ::tensorflow::profiler::test::StartServer; +using ::tensorflow::testing::TmpDir; +using Response = tensorflow::profiler::RemoteProfilerSessionManager::Response; + +// Copied from capture_profile to not introduce a dependency. +ProfileRequest PopulateProfileRequest( + absl::string_view repository_root, absl::string_view session_id, + absl::string_view host_name, + const RemoteProfilerSessionManagerOptions& options) { + constexpr uint64 kMaxEvents = 1000000; + const absl::string_view kXPlanePb = "xplane.pb"; + ProfileRequest request; + // TODO(b/169976117) Remove duration from request. 
+ request.set_duration_ms(options.profiler_options().duration_ms()); + request.set_max_events(kMaxEvents); + request.set_repository_root(repository_root.data(), repository_root.size()); + request.set_session_id(session_id.data(), session_id.size()); + request.set_host_name(host_name.data(), host_name.size()); + // XPlane tool is only used by OSS profiler and safely ignored by TPU + // profiler. + request.add_tools(kXPlanePb.data(), kXPlanePb.size()); + *request.mutable_opts() = options.profiler_options(); + return request; +} + +TEST(RemoteProfilerSessionManagerTest, Simple) { + absl::Duration duration = absl::Milliseconds(30); + RemoteProfilerSessionManagerOptions options; + *options.mutable_profiler_options() = + tensorflow::ProfilerSession::DefaultOptions(); + options.mutable_profiler_options()->set_duration_ms( + absl::ToInt64Milliseconds(duration)); + + std::string service_address; + auto server = StartServer(duration, &service_address); + options.add_service_addresses(service_address); + absl::Time approx_start = absl::Now(); + absl::Duration grace = absl::Seconds(1); + absl::Duration max_duration = duration + grace; + options.set_max_session_duration_ms(absl::ToInt64Milliseconds(max_duration)); + options.set_session_creation_timestamp_ns(absl::ToUnixNanos(approx_start)); + + ProfileRequest request = + PopulateProfileRequest(TmpDir(), "session_id", service_address, options); + Status status; + auto sessions = + RemoteProfilerSessionManager::Create(options, request, status); + EXPECT_TRUE(status.ok()); + std::vector responses = sessions->WaitForCompletion(); + absl::Duration elapsed = absl::Now() - approx_start; + ASSERT_EQ(responses.size(), 1); + EXPECT_TRUE(responses.back().status.ok()); + EXPECT_TRUE(responses.back().profile_response->empty_trace()); + EXPECT_EQ(responses.back().profile_response->tool_data_size(), 0); + EXPECT_THAT(elapsed, DurationApproxLess(max_duration)); +} + +TEST(RemoteProfilerSessionManagerTest, ExpiredDeadline) { + absl::Duration duration = absl::Milliseconds(30); + RemoteProfilerSessionManagerOptions options; + *options.mutable_profiler_options() = + tensorflow::ProfilerSession::DefaultOptions(); + options.mutable_profiler_options()->set_duration_ms( + absl::ToInt64Milliseconds(duration)); + + std::string service_address; + auto server = StartServer(duration, &service_address); + options.add_service_addresses(service_address); + absl::Duration grace = absl::Seconds(1); + absl::Duration max_duration = duration + grace; + options.set_max_session_duration_ms(absl::ToInt64Milliseconds(max_duration)); + // This will create a deadline in the past. 
+ options.set_session_creation_timestamp_ns(0); + + absl::Time approx_start = absl::Now(); + ProfileRequest request = + PopulateProfileRequest(TmpDir(), "session_id", service_address, options); + Status status; + auto sessions = + RemoteProfilerSessionManager::Create(options, request, status); + EXPECT_TRUE(status.ok()); + std::vector responses = sessions->WaitForCompletion(); + absl::Duration elapsed = absl::Now() - approx_start; + EXPECT_THAT(elapsed, DurationNear(absl::Seconds(0))); + ASSERT_EQ(responses.size(), 1); + EXPECT_TRUE(errors::IsDeadlineExceeded(responses.back().status)); + EXPECT_TRUE(responses.back().profile_response->empty_trace()); + EXPECT_EQ(responses.back().profile_response->tool_data_size(), 0); +} + +TEST(RemoteProfilerSessionManagerTest, LongSession) { + absl::Duration duration = absl::Seconds(3); + RemoteProfilerSessionManagerOptions options; + *options.mutable_profiler_options() = + tensorflow::ProfilerSession::DefaultOptions(); + options.mutable_profiler_options()->set_duration_ms( + absl::ToInt64Milliseconds(duration)); + + std::string service_address; + auto server = StartServer(duration, &service_address); + options.add_service_addresses(service_address); + absl::Time approx_start = absl::Now(); + // Empirically determined value. + absl::Duration grace = absl::Seconds(20); + absl::Duration max_duration = duration + grace; + options.set_max_session_duration_ms(absl::ToInt64Milliseconds(max_duration)); + options.set_session_creation_timestamp_ns(absl::ToUnixNanos(approx_start)); + + ProfileRequest request = + PopulateProfileRequest(TmpDir(), "session_id", service_address, options); + Status status; + auto sessions = + RemoteProfilerSessionManager::Create(options, request, status); + EXPECT_TRUE(status.ok()); + std::vector responses = sessions->WaitForCompletion(); + absl::Duration elapsed = absl::Now() - approx_start; + ASSERT_EQ(responses.size(), 1); + EXPECT_TRUE(responses.back().status.ok()); + EXPECT_TRUE(responses.back().profile_response->empty_trace()); + EXPECT_EQ(responses.back().profile_response->tool_data_size(), 0); + EXPECT_THAT(elapsed, DurationApproxLess(max_duration)); +} + +} // namespace +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/rpc/oss/BUILD b/tensorflow/core/profiler/rpc/oss/BUILD index 7e30302b718..cd455e225b7 100644 --- a/tensorflow/core/profiler/rpc/oss/BUILD +++ b/tensorflow/core/profiler/rpc/oss/BUILD @@ -1,10 +1,9 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") +load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_copts") package( - default_visibility = [ - "//tensorflow/core/profiler:internal", - ], + default_visibility = ["//tensorflow/core/profiler:internal"], licenses = ["notice"], # Apache 2.0 ) @@ -14,6 +13,7 @@ cc_library( "grpc.cc", "//tensorflow/core/profiler/rpc:grpc.h", ], + copts = tf_profiler_copts(), deps = [ tf_grpc_cc_dependency(), ], diff --git a/tensorflow/core/profiler/rpc/profiler_service_impl.cc b/tensorflow/core/profiler/rpc/profiler_service_impl.cc index fa8810cd535..54eedb65fa0 100644 --- a/tensorflow/core/profiler/rpc/profiler_service_impl.cc +++ b/tensorflow/core/profiler/rpc/profiler_service_impl.cc @@ -20,6 +20,7 @@ limitations under the License. 
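For orientation, a sketch of driving the RemoteProfilerSessionManager introduced above directly with a custom AddressResolver (illustrative only, not part of the patch; the "dns:///" rewrite and function name are placeholders, and the options and request are assumed to be populated as in the tests above):

#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/profiler_options.pb.h"
#include "tensorflow/core/profiler/profiler_service.pb.h"
#include "tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h"

namespace tensorflow {
namespace profiler {

void ProfileThroughManager(const RemoteProfilerSessionManagerOptions& options,
                           const ProfileRequest& request) {
  // Hypothetical resolver: make the gRPC URI scheme explicit. When no resolver
  // is supplied, addresses are passed through unchanged.
  AddressResolver resolver = [](absl::string_view addr) {
    return absl::StrCat("dns:///", addr);
  };
  Status status;
  auto manager =
      RemoteProfilerSessionManager::Create(options, request, status, resolver);
  if (!status.ok()) return;
  // Each per-host RPC deadline is session_creation_timestamp_ns plus
  // max_session_duration_ms, so stamp the options with a current timestamp.
  for (const auto& response : manager->WaitForCompletion()) {
    LOG(INFO) << response.service_address << ": " << response.status;
  }
}

}  // namespace profiler
}  // namespace tensorflow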
#include "grpcpp/support/status.h" #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" +#include "absl/strings/str_replace.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env_time.h" #include "tensorflow/core/platform/errors.h" @@ -27,11 +28,12 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/profiler/internal/profiler_interface.h" #include "tensorflow/core/profiler/lib/profiler_session.h" #include "tensorflow/core/profiler/profiler_service.grpc.pb.h" #include "tensorflow/core/profiler/profiler_service.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tensorflow/core/profiler/utils/file_system_utils.h" +#include "tensorflow/core/profiler/utils/xplane_utils.h" namespace tensorflow { namespace profiler { @@ -39,15 +41,31 @@ namespace { const absl::string_view kXPlanePb = "xplane.pb"; -Status CollectDataToResponse(const ProfileRequest& req, - ProfilerSession* profiler, - ProfileResponse* response) { - profiler::XSpace xspace; +// Collects data in XSpace format. The data is saved to a repository +// unconditionally. +Status CollectDataToRepository(const ProfileRequest& request, + ProfilerSession* profiler, + ProfileResponse* response) { + response->set_empty_trace(true); + // Read the profile data into xspace. + XSpace xspace; TF_RETURN_IF_ERROR(profiler->CollectData(&xspace)); - auto* tool_data = response->add_tool_data(); - tool_data->set_name(kXPlanePb.data(), kXPlanePb.size()); - xspace.SerializeToString(tool_data->mutable_data()); - return Status::OK(); + VLOG(3) << "Collected XSpace to repository."; + response->set_empty_trace(IsEmpty(xspace)); + + std::string log_dir_path = + ProfilerJoinPath(request.repository_root(), request.session_id()); + VLOG(1) << "Creating " << log_dir_path; + TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(log_dir_path)); + + std::string file_name = absl::StrCat(request.host_name(), ".", kXPlanePb); + // Windows file names do not support colons. + absl::StrReplaceAll({{":", "_"}}, &file_name); + // Dumps profile data to //_. 
+ std::string out_path = ProfilerJoinPath(log_dir_path, file_name); + LOG(INFO) << "Collecting XSpace to repository: " << out_path; + + return WriteBinaryProto(Env::Default(), out_path, xspace); } class ProfilerServiceImpl : public grpc::ProfilerService::Service { @@ -69,7 +87,7 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service { } Env* env = Env::Default(); - for (uint64 i = 0; i < req->duration_ms(); ++i) { + for (uint64 i = 0; i < req->opts().duration_ms(); ++i) { env->SleepForMicroseconds(EnvTime::kMillisToMicros); if (ctx->IsCancelled()) { return ::grpc::Status::CANCELLED; @@ -81,7 +99,7 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service { } } - status = CollectDataToResponse(*req, profiler.get(), response); + status = CollectDataToRepository(*req, profiler.get(), response); if (!status.ok()) { return ::grpc::Status(::grpc::StatusCode::INTERNAL, status.error_message()); @@ -117,5 +135,4 @@ std::unique_ptr CreateProfilerService() { } } // namespace profiler - } // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD index 68697d32b9c..f7c6d5496d5 100644 --- a/tensorflow/core/profiler/utils/BUILD +++ b/tensorflow/core/profiler/utils/BUILD @@ -1,5 +1,6 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_copts") package( default_visibility = ["//tensorflow/core/profiler:internal"], @@ -17,7 +18,9 @@ cc_library( name = "diagnostics", srcs = ["diagnostics.cc"], hdrs = ["diagnostics.h"], + copts = tf_profiler_copts(), deps = [ + "//tensorflow/core:lib", "//tensorflow/core/profiler/protobuf:diagnostics_proto_cc", "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", @@ -30,6 +33,7 @@ cc_library( name = "event_span", srcs = ["event_span.cc"], hdrs = ["event_span.h"], + copts = tf_profiler_copts(), deps = [ ":timespan", "//tensorflow/core:lib", @@ -46,6 +50,7 @@ cc_library( name = "hardware_type_utils", srcs = ["hardware_type_utils.cc"], hdrs = ["hardware_type_utils.h"], + copts = tf_profiler_copts(), deps = [ "//tensorflow/core:lib", "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", @@ -58,6 +63,14 @@ cc_library( hdrs = ["math_utils.h"], ) +cc_library( + name = "format_utils", + hdrs = ["format_utils.h"], + deps = [ + "//tensorflow/core:lib", + ], +) + cc_library( name = "html_utils", hdrs = ["html_utils.h"], @@ -70,6 +83,7 @@ cc_library( name = "op_metrics_db_utils", srcs = ["op_metrics_db_utils.cc"], hdrs = ["op_metrics_db_utils.h"], + copts = tf_profiler_copts(), deps = [ ":math_utils", ":tf_op_utils", @@ -95,6 +109,7 @@ cc_library( name = "op_utils", srcs = ["op_utils.cc"], hdrs = ["op_utils.h"], + copts = tf_profiler_copts(), deps = [ ":op_metrics_db_utils", ":tf_op_utils", @@ -110,8 +125,10 @@ cc_library( name = "tf_op_utils", srcs = ["tf_op_utils.cc"], hdrs = ["tf_op_utils.h"], + copts = tf_profiler_copts(), deps = [ - "//tensorflow/core:regexp_internal", + "//tensorflow/core:lib", + "//tensorflow/core/platform:regexp", "@com_google_absl//absl/strings", ], ) @@ -131,6 +148,7 @@ tf_cc_test( cc_library( name = "timespan", hdrs = ["timespan.h"], + copts = tf_profiler_copts(), deps = [ ":time_utils", "//tensorflow/core:lib", @@ -151,6 +169,7 @@ tf_cc_test( cc_library( name = "time_utils", hdrs = ["time_utils.h"], + copts = tf_profiler_copts(), deps = [ "//tensorflow/core:lib", ], @@ -159,6 
+178,7 @@ cc_library( cc_library( name = "trace_utils", hdrs = ["trace_utils.h"], + copts = tf_profiler_copts(), deps = [ "//tensorflow/core:lib", ], @@ -168,6 +188,7 @@ cc_library( name = "xplane_builder", srcs = ["xplane_builder.cc"], hdrs = ["xplane_builder.h"], + copts = tf_profiler_copts(), visibility = [":friends"], deps = [ ":time_utils", @@ -196,8 +217,10 @@ cc_library( name = "xplane_schema", srcs = ["xplane_schema.cc"], hdrs = ["xplane_schema.h"], + copts = tf_profiler_copts(), visibility = [":friends"], deps = [ + ":tf_op_utils", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/container:flat_hash_map", @@ -210,6 +233,7 @@ cc_library( name = "xplane_utils", srcs = ["xplane_utils.cc"], hdrs = ["xplane_utils.h"], + copts = tf_profiler_copts(), visibility = [":friends"], deps = [ ":timespan", @@ -246,6 +270,7 @@ cc_library( testonly = True, srcs = ["xplane_test_utils.cc"], hdrs = ["xplane_test_utils.h"], + copts = tf_profiler_copts(), visibility = [":friends"], deps = [ ":xplane_builder", @@ -263,6 +288,7 @@ cc_library( name = "xplane_visitor", srcs = ["xplane_visitor.cc"], hdrs = ["xplane_visitor.h"], + copts = tf_profiler_copts(), visibility = [":friends"], deps = [ ":time_utils", @@ -279,6 +305,7 @@ cc_library( cc_library( name = "tf_xplane_visitor", hdrs = ["tf_xplane_visitor.h"], + copts = tf_profiler_copts(), visibility = [":friends"], deps = [ ":xplane_schema", @@ -291,9 +318,9 @@ cc_library( name = "group_events", srcs = ["group_events.cc"], hdrs = ["group_events.h"], + copts = tf_profiler_copts(), visibility = [":friends"], deps = [ - ":tf_op_utils", ":tf_xplane_visitor", ":xplane_builder", ":xplane_schema", @@ -336,6 +363,7 @@ cc_library( name = "cost_utils", srcs = ["cost_utils.cc"], hdrs = ["cost_utils.h"], + copts = tf_profiler_copts(), deps = [ ":tf_op_utils", ":xplane_schema", @@ -357,6 +385,7 @@ cc_library( name = "derived_timeline", srcs = ["derived_timeline.cc"], hdrs = ["derived_timeline.h"], + copts = tf_profiler_copts(), deps = [ ":group_events", ":tf_op_utils", @@ -402,6 +431,7 @@ cc_library( name = "kernel_stats_utils", srcs = ["kernel_stats_utils.cc"], hdrs = ["kernel_stats_utils.h"], + copts = tf_profiler_copts(), deps = [ "//tensorflow/core:lib", "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc", @@ -426,6 +456,7 @@ cc_library( name = "tfstreamz_utils", srcs = ["tfstreamz_utils.cc"], hdrs = ["tfstreamz_utils.h"], + copts = tf_profiler_copts(), deps = [ ":xplane_builder", "//tensorflow/core:lib", @@ -442,6 +473,7 @@ cc_library( name = "step_intersection", srcs = ["step_intersection.cc"], hdrs = ["step_intersection.h"], + copts = tf_profiler_copts(), deps = [ ":timespan", "//tensorflow/core:lib", @@ -467,8 +499,31 @@ tf_cc_test( cc_library( name = "file_system_utils", hdrs = ["file_system_utils.h"], + copts = tf_profiler_copts(), deps = [ "//tensorflow/core/platform", "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "parse_annotation", + srcs = ["parse_annotation.cc"], + hdrs = ["parse_annotation.h"], + copts = tf_profiler_copts(), + visibility = ["//tensorflow/core/profiler:friends"], + deps = [ + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "parse_annotation_test", + srcs = ["parse_annotation_test.cc"], + deps = [ + ":parse_annotation", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/core/profiler/utils/derived_timeline.cc b/tensorflow/core/profiler/utils/derived_timeline.cc index 
26b33f1e621..a78404af849 100644 --- a/tensorflow/core/profiler/utils/derived_timeline.cc +++ b/tensorflow/core/profiler/utils/derived_timeline.cc @@ -43,13 +43,6 @@ namespace tensorflow { namespace profiler { namespace { -// TODO(profiler): Once we capture HLO protos for xla/gpu, we should use that -// to look up tensorflow op name from hlo_module/hlo_op. -absl::string_view DummySymbolResolver(absl::string_view hlo_module, - absl::string_view hlo_op) { - return absl::string_view(); -} - const absl::string_view kAnnotationDelimiter = "::"; XEvent CreateXEvent(const XEventMetadata& metadata, int64 offset_ps, @@ -338,17 +331,17 @@ void DeriveEventsFromHostTrace(const XPlane* host_trace, void GenerateDerivedTimeLines(const GroupMetadataMap& group_metadata_map, XSpace* space, bool step_info_only) { + // TODO(profiler): Once we capture HLO protos for xla/gpu, we should use that + // to look up tensorflow op name from hlo_module/hlo_op. + auto dummy_symbol_resolver = [](absl::string_view hlo_module, + absl::string_view hlo_op) { + return absl::string_view(); + }; std::vector device_traces = FindMutablePlanesWithPrefix(space, kGpuPlanePrefix); - GenerateDerivedTimeLines(group_metadata_map, device_traces, step_info_only); -} - -void GenerateDerivedTimeLines(const GroupMetadataMap& group_metadata_map, - const std::vector& device_traces, - bool step_info_only) { for (XPlane* plane : device_traces) { - DeriveEventsFromAnnotations(DummySymbolResolver, group_metadata_map, plane, - step_info_only); + DeriveEventsFromAnnotations(dummy_symbol_resolver, group_metadata_map, + plane, step_info_only); } } diff --git a/tensorflow/core/profiler/utils/derived_timeline.h b/tensorflow/core/profiler/utils/derived_timeline.h index bf8280708fa..5fcf8b44710 100644 --- a/tensorflow/core/profiler/utils/derived_timeline.h +++ b/tensorflow/core/profiler/utils/derived_timeline.h @@ -98,9 +98,6 @@ void DeriveEventsFromHostTrace(const XPlane* host_trace, // derived timelines for the plane by calling DeriveEventsFromAnnotations. void GenerateDerivedTimeLines(const GroupMetadataMap& group_metadata_map, XSpace* space, bool step_info_only = false); -void GenerateDerivedTimeLines(const GroupMetadataMap& group_metadata_map, - const std::vector& device_traces, - bool step_info_only = false); } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/diagnostics.h b/tensorflow/core/profiler/utils/diagnostics.h index 3c96ebb0f06..80753ff4d51 100644 --- a/tensorflow/core/profiler/utils/diagnostics.h +++ b/tensorflow/core/profiler/utils/diagnostics.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_PROFILER_UTILS_ERRORS_H_ #include "absl/strings/string_view.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" #include "tensorflow/core/profiler/protobuf/op_stats.pb.h" @@ -24,15 +25,15 @@ namespace tensorflow { namespace profiler { // Error message that the visualization is based on incomplete step. -ABSL_CONST_INIT extern const absl::string_view kErrorIncompleteStep; +TF_CONST_INIT extern const absl::string_view kErrorIncompleteStep; // Error message that no step marker is seen and visualization contains no // step info. 
-ABSL_CONST_INIT extern const absl::string_view kErrorNoStepMarker; +TF_CONST_INIT extern const absl::string_view kErrorNoStepMarker; -ABSL_CONST_INIT extern const absl::string_view kNoDeviceTraceCollected; +TF_CONST_INIT extern const absl::string_view kNoDeviceTraceCollected; -ABSL_CONST_INIT extern const absl::string_view kStepsDropped; +TF_CONST_INIT extern const absl::string_view kStepsDropped; void PopulateStepDiagnostics(const OpStats& op_stats, Diagnostics* diag); diff --git a/tensorflow/core/profiler/utils/event_span.cc b/tensorflow/core/profiler/utils/event_span.cc index 137a798c7f8..3bc2505ea89 100644 --- a/tensorflow/core/profiler/utils/event_span.cc +++ b/tensorflow/core/profiler/utils/event_span.cc @@ -162,8 +162,35 @@ EventType ClassifyDeviceCompute(absl::string_view event_name, } } +constexpr int kNumGenericEventTypes = GenericEventType::kLastGenericEventType - + GenericEventType::kFirstGenericEventType + + 1; + +using GenericEventTypeStrMap = + absl::flat_hash_map<GenericEventType, std::string>; + +const GenericEventTypeStrMap& GetGenericEventTypeStrMap() { + static const auto* generic_event_type_str_map = new GenericEventTypeStrMap({ + {kDeviceCompute, "Device compute"}, + {kDeviceToDevice, "Device to device"}, + {kDeviceCollectives, "Device collective communication"}, + {kHostCompute, "Host compute"}, + {kHostPrepare, "Kernel launch"}, + {kInput, "Input"}, + {kOutput, "Output"}, + {kCompile, "Compilation"}, + {kAllOthers, "All others"}, + }); + DCHECK_EQ(generic_event_type_str_map->size(), kNumGenericEventTypes); + return *generic_event_type_str_map; +} + } // namespace +absl::string_view GetGenericEventTypeStr(GenericEventType event_type) { + return GetGenericEventTypeStrMap().at(event_type); +} + EventType ClassifyGpuEvent(absl::string_view event_name, absl::string_view tensor_shapes) { if (absl::StartsWithIgnoreCase(event_name, "MEMCPYHtoD")) @@ -210,6 +237,8 @@ std::string PrintEventType(EventType event_type) { return "host_to_device"; case HOST_PREPARE: return "host_prepare"; + case DEVICE_COLLECTIVES: + return "device_collectives"; case HOST_WAIT_INPUT: return "host_wait_input"; case DEVICE_TO_DEVICE: diff --git a/tensorflow/core/profiler/utils/event_span.h b/tensorflow/core/profiler/utils/event_span.h index e2c8f99a2e7..898d6ce7ad3 100644 --- a/tensorflow/core/profiler/utils/event_span.h +++ b/tensorflow/core/profiler/utils/event_span.h @@ -68,6 +68,30 @@ enum EventType { LAST_EVENT_TYPE = DEVICE_WAIT_HOST }; +// Generic event types that are shown to the user. +enum GenericEventType { + kFirstGenericEventType = 1, + // Device is computing. + kDeviceCompute = kFirstGenericEventType, + // Device-to-device communication. + kDeviceToDevice, + // Collective Ops such as All-Reduce and NCCL. + kDeviceCollectives, + // Host is computing. + kHostCompute, + // Host is preparing to launch a computation on device. + kHostPrepare, + // Device waiting for input from the host. + kInput, + // Device sending output to the host. + kOutput, + // Host is compiling. + kCompile, + // No recognized event associated with the time. + kAllOthers, + kLastGenericEventType = kAllOthers, +}; + // Contains the type and timespan of an event. struct EventTypeSpan { EventType type; // type of this event. @@ -197,6 +221,9 @@ EventType ClassifyGpuEvent(absl::string_view event_name, // Returns the name of the given EventType. std::string PrintEventType(EventType event_type); +// Returns the string of the given GenericEventType.
+absl::string_view GetGenericEventTypeStr(GenericEventType event_type); + // Returns a string that prints the given EventTypeSpan. std::string PrintEventTypeSpan(const EventTypeSpan& event_type_span); diff --git a/tensorflow/core/profiler/utils/format_utils.h b/tensorflow/core/profiler/utils/format_utils.h new file mode 100644 index 00000000000..ca989f0b2ec --- /dev/null +++ b/tensorflow/core/profiler/utils/format_utils.h @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_FORMAT_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_FORMAT_UTILS_H_ + +#include + +#include + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace profiler { +namespace internal { + +inline std::string FormatDouble(const char* fmt, double d) { + constexpr int kBufferSize = 32; + char buffer[kBufferSize]; + int result = snprintf(buffer, kBufferSize, fmt, d); + DCHECK(result > 0 && result < kBufferSize); + return std::string(buffer); +} + +} // namespace internal + +// Formats d with one digit after the decimal point. +inline std::string OneDigit(double d) { + return internal::FormatDouble("%.1f", d); +} + +// Formats d with maximum precision to allow parsing the result back to the same +// number. +inline std::string MaxPrecision(double d) { + return internal::FormatDouble("%.17g", d); +} + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_FORMAT_UTILS_H_ diff --git a/tensorflow/core/profiler/utils/group_events.cc b/tensorflow/core/profiler/utils/group_events.cc index cf1830b9d76..a78fe5a1513 100644 --- a/tensorflow/core/profiler/utils/group_events.cc +++ b/tensorflow/core/profiler/utils/group_events.cc @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/lib/connected_traceme.h" -#include "tensorflow/core/profiler/utils/tf_op_utils.h" #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" @@ -79,16 +78,7 @@ int64 GetEventType(bool is_host_plane, const EventNode& event) { } else if (absl::StartsWith(name, "ProcessBatch")) { return HostEventType::kProcessBatch; } - // TF op names. 
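For orientation, a small usage sketch of the formatting helpers added in format_utils.h above (illustrative only, not part of the patch; the values and function name are arbitrary):

#include <string>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/profiler/utils/format_utils.h"

namespace tensorflow {
namespace profiler {

void FormatUtilsExample() {
  // "12.3": one digit after the decimal point, for human-readable reports.
  std::string rounded = OneDigit(12.3456);
  // "%.17g" keeps enough digits that the printed value parses back to the
  // exact same double, e.g. for machine-readable output.
  std::string exact = MaxPrecision(1.0 / 3.0);
  LOG(INFO) << rounded << " vs " << exact;
}

}  // namespace profiler
}  // namespace tensorflow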
- Category category = ParseTfOpFullname(name).category; - switch (category) { - case Category::kTensorFlow: - return HostEventType::kTfOpRun; - case Category::kTfData: - return HostEventType::kIterator; - default: - return HostEventType::kUnknownHostEventType; - } + return HostEventType::kUnknownHostEventType; } } @@ -304,6 +294,11 @@ bool HasJaxEvent(const EventNodeMap& event_node_map) { return event_node_map.contains(HostEventType::kExecuteOnLocalDevices); } +bool IsIteratorEventType(absl::optional event_type) { + return event_type == HostEventType::kIterator || + event_type == HostEventType::kDeviceInputPipelineSecondIterator; +} + } // namespace EventNode::EventNode(const XPlaneVisitor* plane, XLine* raw_line, @@ -488,15 +483,14 @@ bool EventNode::StartsBefore(const EventNode& other) const { other.GetEventVisitor().TimestampPs(); } -void EventForest::ConnectIntraThread(const XPlaneVisitor& visitor, - XPlane* plane, +void EventForest::ConnectIntraThread(XPlane* plane, XPlaneVisitor* visitor, ContextGroupMap* context_groups) { // TODO(b/149095099): avoid string comparison. - bool is_host_plane = (visitor.Name() == kHostThreadsPlaneName); + bool is_host_plane = (visitor->Name() == kHostThreadsPlaneName); for (auto& line : *plane->mutable_lines()) { std::vector parent_nodes; for (auto& event : *line.mutable_events()) { - auto cur_node = absl::make_unique(&visitor, &line, &event); + auto cur_node = absl::make_unique(visitor, &line, &event); // Update `context_groups` for `ConnectInterThread`. SetContextGroup(cur_node.get(), context_groups); // Update `root_events_` for `CreateEventGroup`. @@ -578,14 +572,14 @@ void EventForest::ProcessLegacyRootEvents( } } -void EventForest::CreateEventGroup(GroupMetadataMap* group_metadata_map) { +void EventForest::CreateEventGroups() { // Handle inference batching profiles. if (event_node_map_.contains(HostEventType::kProcessBatch)) { // Assign group_id per batch. for (const auto& process_batch_node : event_node_map_[HostEventType::kProcessBatch]) { ProcessRootEvent(next_group_id_++, /*set_step_name=*/false, - process_batch_node.get(), group_metadata_map); + process_batch_node.get(), &group_metadata_map_); } HostEventType request_event_type = event_node_map_.contains(HostEventType::kBatchingSessionRun) @@ -596,15 +590,15 @@ void EventForest::CreateEventGroup(GroupMetadataMap* group_metadata_map) { // Assign group_id per request. for (const auto& request_event : *request_events) { ProcessRootEvent(next_group_id_++, /*set_step_name=*/false, - request_event.get(), group_metadata_map); + request_event.get(), &group_metadata_map_); // Also, set a helper stat for selected_group_ids. - request_event->AddSelectedGroupIds(*group_metadata_map); + request_event->AddSelectedGroupIds(group_metadata_map_); } } // Set a helper stat for selected_group_ids per batch. 
for (const auto& process_batch_node : event_node_map_[HostEventType::kProcessBatch]) { - process_batch_node->AddSelectedGroupIds(*group_metadata_map); + process_batch_node->AddSelectedGroupIds(group_metadata_map_); } return; } @@ -612,7 +606,7 @@ void EventForest::CreateEventGroup(GroupMetadataMap* group_metadata_map) { if (!HasJaxEvent(event_node_map_) && !tf_loop_root_events_.empty()) { for (EventNode* root_event : tf_loop_root_events_) { ProcessRootEvent(next_group_id_++, /*set_step_name=*/true, root_event, - group_metadata_map); + &group_metadata_map_); } return; } @@ -624,7 +618,7 @@ void EventForest::CreateEventGroup(GroupMetadataMap* group_metadata_map) { (!HasJaxEvent(event_node_map_) || !IsLegacyRootEvent(root_event->GetEventVisitor()))) { ProcessRootEvent(next_group_id_++, /*set_step_name=*/true, root_event, - group_metadata_map); + &group_metadata_map_); } } } @@ -732,7 +726,7 @@ void EventForest::ProcessWorker() { } } -void EventForest::ProcessModelIds(GroupMetadataMap* group_metadata_map) { +void EventForest::ProcessModelIds() { auto session_run_event_list = gtl::FindOrNull(event_node_map_, HostEventType::kSessionRun); if (!session_run_event_list) return; @@ -742,11 +736,45 @@ void EventForest::ProcessModelIds(GroupMetadataMap* group_metadata_map) { absl::optional model_id = session_run_event->GetEventVisitor().GetStat(StatType::kModelId); if (!model_id.has_value()) continue; - (*group_metadata_map)[*group_id].model_id = model_id->ToString(); + group_metadata_map_[*group_id].model_id = model_id->ToString(); } } -void EventForest::ProcessTfDataEvents() { +void EventForest::AddPlane( + const std::function visitor_factory, + XPlane* plane) { + CreateStatMetadata(plane); + planes_.push_back({plane, visitor_factory(plane)}); +} + +void EventForest::AddSpace( + const std::function visitor_factory, + XSpace* space) { + for (XPlane& plane : *space->mutable_planes()) { + AddPlane(visitor_factory, &plane); + } +} + +void EventForest::AddPlanes( + const std::function visitor_factory, + const std::vector& planes) { + for (XPlane* plane : planes) { + AddPlane(visitor_factory, plane); + } +} + +void EventForest::ConnectEvents( + const std::vector& connect_info_list) { + ContextGroupMap context_groups; + for (auto& plane_visitor : planes_) { + ConnectIntraThread(plane_visitor.first, &plane_visitor.second, + &context_groups); + } + ConnectInterThread(connect_info_list); + ConnectContextGroups(context_groups); +} + +void EventForest::ConnectTfDataEvents() { absl::flat_hash_map, std::vector> produce_iterator_map; @@ -765,8 +793,7 @@ void EventForest::ProcessTfDataEvents() { produce_event->GetEventVisitor().GetStat(StatType::kElementId); if (!element_id.has_value()) continue; for (EventNode* produce_iterator : produce_event->GetChildren()) { - if (IsDatasetOp(ParseTfOpFullname( - produce_iterator->GetEventVisitor().Name()))) { + if (IsIteratorEventType(produce_iterator->GetEventVisitor().Type())) { absl::optional iterator_id = produce_iterator->GetEventVisitor().GetStat(StatType::kParentId); if (!iterator_id.has_value()) break; @@ -799,8 +826,7 @@ void EventForest::ProcessTfDataEvents() { // parents. 
EventNode* consume_iterator = consume_event->GetParents().at(0); if (!consume_iterator || - !IsDatasetOp( - ParseTfOpFullname(consume_iterator->GetEventVisitor().Name()))) { + !IsIteratorEventType(consume_iterator->GetEventVisitor().Type())) { continue; } absl::optional iterator_id = @@ -819,38 +845,14 @@ void EventForest::ProcessTfDataEvents() { VLOG(1) << num_matched << " consumer iterators matched."; } -EventForest::EventForest( - const std::vector& connect_info_list, - const std::vector& root_event_types, - const std::function visitor_factory, - XSpace* space, GroupMetadataMap* group_metadata_map) { - ContextGroupMap context_groups; - visitors_.reserve(space->planes_size()); - for (auto& plane : *space->mutable_planes()) { - CreateStatMetadata(&plane); - visitors_.push_back(visitor_factory(&plane)); - ConnectIntraThread(visitors_.back(), &plane, &context_groups); - } - ConnectInterThread(connect_info_list); - ConnectContextGroups(context_groups); +void EventForest::GroupEvents(const std::vector& root_event_types) { ProcessTensorFlowLoop(); ProcessWorker(); ProcessLegacyRootEvents(root_event_types); - CreateEventGroup(group_metadata_map); + CreateEventGroups(); MarkEagerlyExecutedGpuKernels(); MarkEagerlyExecutedCpuTfOps(); - ProcessModelIds(group_metadata_map); -} - -EventForest::EventForest( - const std::function visitor_factory, - XPlane* plane) { - ContextGroupMap context_groups; - visitors_.reserve(1); - CreateStatMetadata(plane); - visitors_.push_back(visitor_factory(plane)); - ConnectIntraThread(visitors_.back(), plane, &context_groups); - ConnectContextGroups(context_groups); + ProcessModelIds(); } std::vector CreateInterThreadConnectInfoList() { @@ -867,16 +869,17 @@ std::vector CreateInterThreadConnectInfoList() { return connect_info_list; } -void GroupTfEvents(XSpace* space, GroupMetadataMap* group_metadata_map) { +void GroupTfEvents(XSpace* space, EventForest* event_forest) { std::vector connect_info_list = CreateInterThreadConnectInfoList(); - EventForest event_forest(connect_info_list, {}, CreateTfXPlaneVisitor, space, - group_metadata_map); + event_forest->AddSpace(CreateTfXPlaneVisitor, space); + event_forest->ConnectEvents(connect_info_list); + event_forest->GroupEvents(); } void GroupTfEvents(XSpace* space) { - GroupMetadataMap group_metadata_map; - GroupTfEvents(space, &group_metadata_map); + EventForest event_forest; + GroupTfEvents(space, &event_forest); } } // namespace profiler diff --git a/tensorflow/core/profiler/utils/group_events.h b/tensorflow/core/profiler/utils/group_events.h index dfca990e3f8..57519f361a6 100644 --- a/tensorflow/core/profiler/utils/group_events.h +++ b/tensorflow/core/profiler/utils/group_events.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_GROUP_EVENTS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_GROUP_EVENTS_H_ +#include #include #include #include @@ -160,23 +161,35 @@ using ContextGroupMap = absl::flat_hash_map< // events specified in root_event_types or marked by the semantic argument. 
class EventForest { public: - EventForest(const std::vector& connect_info_list, - const std::vector& root_event_types, - const std::function visitor_factory, - XSpace* space, GroupMetadataMap* group_metadata_map); + void AddSpace( + const std::function visitor_factory, + XSpace* space); - EventForest(const std::function visitor_factory, - XPlane* plane); + void AddPlanes( + const std::function visitor_factory, + const std::vector& planes); + + void ConnectEvents( + const std::vector& connect_info_list = {}); + + void ConnectTfDataEvents(); + + void GroupEvents(const std::vector& root_event_types = {}); const EventNodeMap& GetEventNodeMap() const { return event_node_map_; } - // Connects tf.data events across threads. - void ProcessTfDataEvents(); + const GroupMetadataMap& GetGroupMetadataMap() const { + return group_metadata_map_; + } private: + void AddPlane( + const std::function visitor_factory, + XPlane* plane); + // Creates an EventNode for each event in event_node_map and connect events // according to the nesting relationship within the thread. - void ConnectIntraThread(const XPlaneVisitor& visitor, XPlane* plane, + void ConnectIntraThread(XPlane* plane, XPlaneVisitor* visitor, ContextGroupMap* context_groups); // Connects events across threads according to connect_info_list. @@ -190,7 +203,7 @@ class EventForest { // used, each TF loop iteration becomes a root. Otherwise, top root events // (i.e., none of their ancestors is a root event) are used as roots. A new // group is created with all events reachable from a root. - void CreateEventGroup(GroupMetadataMap* group_metadata_map); + void CreateEventGroups(); // Sets the is_eager stat to true for the eagerly executed GPU kernel events. void MarkEagerlyExecutedGpuKernels(); @@ -202,17 +215,20 @@ class EventForest { // iteraton to `tf_loop_root_events_`. void ProcessTensorFlowLoop(); - // Processes the worker thread by grouping a FunctionRun with the following + // Processes the worker thread by connecting a FunctionRun with the following // eager ops (e.g., for Keras callback). void ProcessWorker(); // Adds model ids to group_metadata_map for inference profiles. - void ProcessModelIds(GroupMetadataMap* group_metadata_map); + void ProcessModelIds(); EventNodeMap event_node_map_; std::vector visitors_; + // std::deque for pointer stability. + std::deque> planes_; EventList root_events_; EventList tf_loop_root_events_; + GroupMetadataMap group_metadata_map_; int64 next_group_id_ = 0; }; @@ -220,7 +236,7 @@ std::vector CreateInterThreadConnectInfoList(); // Calls GroupEvents with connect_info_list and root_event_types specific to // TensorFlow. 
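The class above replaces the two old one-shot constructors with an explicit pipeline: add planes, connect events, then group them. A minimal, hypothetical caller, using only names declared in this header and in tf_xplane_visitor.h (the helper function itself is not part of the change):

```cpp
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/profiler/utils/group_events.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"

namespace tensorflow {
namespace profiler {

// Hypothetical helper that drives the refactored EventForest by hand.
// GroupTfEvents (declared just below) wraps the same sequence for the
// common TensorFlow case.
void GroupAndPrint(XSpace* space) {
  EventForest forest;
  // Each plane gets an XPlaneVisitor produced by the factory; the visitors
  // live inside the forest (planes_ is a std::deque for pointer stability).
  forest.AddSpace(CreateTfXPlaneVisitor, space);
  // Nest events within each thread, then connect them across threads.
  forest.ConnectEvents(CreateInterThreadConnectInfoList());
  // Assign group ids; the metadata now lives in the forest itself.
  forest.GroupEvents();
  for (const auto& id_and_metadata : forest.GetGroupMetadataMap()) {
    LOG(INFO) << "group " << id_and_metadata.first
              << " name=" << id_and_metadata.second.name;
  }
}

}  // namespace profiler
}  // namespace tensorflow
```

ConnectTfDataEvents() (the rename of ProcessTfDataEvents) can additionally be called on the same forest to link tf.data producer and consumer iterators across threads.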
-void GroupTfEvents(XSpace* space, GroupMetadataMap* group_metadata_map); +void GroupTfEvents(XSpace* space, EventForest* event_forest); void GroupTfEvents(XSpace* space); } // namespace profiler diff --git a/tensorflow/core/profiler/utils/group_events_test.cc b/tensorflow/core/profiler/utils/group_events_test.cc index 1e9650ea651..604485f03e5 100644 --- a/tensorflow/core/profiler/utils/group_events_test.cc +++ b/tensorflow/core/profiler/utils/group_events_test.cc @@ -68,15 +68,17 @@ TEST(GroupEventsTest, GroupGpuTraceLegacyRootTest) { CreateXEvent(&device_plane_builder, &stream, "matmul", 200, 300, {{StatType::kCorrelationId, kCorrelationId}}); - GroupMetadataMap group_metadata_map; - GroupTfEvents(&space, &group_metadata_map); + EventForest event_forest; + GroupTfEvents(&space, &event_forest); + const GroupMetadataMap& group_metadata_map = + event_forest.GetGroupMetadataMap(); XPlaneVisitor device_plane_visitor = CreateTfXPlaneVisitor(device_plane); EXPECT_EQ(device_plane->lines(0).events(0).stats_size(), 3); EXPECT_EQ(device_plane_visitor.GetStatType( device_plane->lines(0).events(0).stats(1)), StatType::kGroupId); EXPECT_EQ(group_metadata_map.size(), 1); - EXPECT_EQ(group_metadata_map[0].name, "train 123"); + EXPECT_EQ(group_metadata_map.at(0).name, "train 123"); } TEST(GroupEventsTest, GroupGpuTraceTest) { @@ -109,15 +111,17 @@ TEST(GroupEventsTest, GroupGpuTraceTest) { CreateXEvent(&device_plane_builder, &stream, "matmul", 200, 300, {{StatType::kCorrelationId, kCorrelationId}}); - GroupMetadataMap group_metadata_map; - GroupTfEvents(&space, &group_metadata_map); + EventForest event_forest; + GroupTfEvents(&space, &event_forest); + const GroupMetadataMap& group_metadata_map = + event_forest.GetGroupMetadataMap(); XPlaneVisitor device_plane_visitor = CreateTfXPlaneVisitor(device_plane); EXPECT_EQ(device_plane->lines(0).events(0).stats_size(), 3); EXPECT_EQ(device_plane_visitor.GetStatType( device_plane->lines(0).events(0).stats(1)), StatType::kGroupId); EXPECT_EQ(group_metadata_map.size(), 1); - EXPECT_EQ(group_metadata_map[0].name, "train 123"); + EXPECT_EQ(group_metadata_map.at(0).name, "train 123"); } TEST(GroupEventsTest, GroupTensorFlowLoopTest) { @@ -147,8 +151,10 @@ TEST(GroupEventsTest, GroupTensorFlowLoopTest) { CreateXEvent(&device_plane_builder, &stream, "matmul", 200, 300, {{StatType::kCorrelationId, kCorrelationId}}); - GroupMetadataMap group_metadata_map; - GroupTfEvents(&space, &group_metadata_map); + EventForest event_forest; + GroupTfEvents(&space, &event_forest); + const GroupMetadataMap& group_metadata_map = + event_forest.GetGroupMetadataMap(); XPlaneVisitor device_plane_visitor = CreateTfXPlaneVisitor(device_plane); EXPECT_EQ(device_plane->lines(0).events(0).stats_size(), 3); EXPECT_EQ(device_plane_visitor.GetStatType( @@ -156,7 +162,7 @@ TEST(GroupEventsTest, GroupTensorFlowLoopTest) { StatType::kGroupId); EXPECT_EQ(device_plane->lines(0).events(0).stats(1).int64_value(), 10); EXPECT_EQ(group_metadata_map.size(), 1); - EXPECT_EQ(group_metadata_map[10].name, "10"); + EXPECT_EQ(group_metadata_map.at(10).name, "10"); } // When there are multiple TF loops, group_id is assigned in the order of TF @@ -194,8 +200,10 @@ TEST(GroupEventsTest, GroupMultipleTensorFlowLoopsTest) { {{StatType::kStepId, kFirstStepId}, {StatType::kIterNum, kFirstIterNumStart + 1}}); - GroupMetadataMap group_metadata_map; - GroupTfEvents(&space, &group_metadata_map); + EventForest event_forest; + GroupTfEvents(&space, &event_forest); + const GroupMetadataMap& group_metadata_map = + 
event_forest.GetGroupMetadataMap(); EXPECT_EQ(group_metadata_map.size(), 4); EXPECT_TRUE(group_metadata_map.count(10)); EXPECT_TRUE(group_metadata_map.count(11)); @@ -230,8 +238,7 @@ TEST(GroupEventsTest, GroupFunctionalOp) { HostEventType::kExecutorStateProcess, 100, 150, {{StatType::kStepId, kFunctionStepId}}); - GroupMetadataMap group_metadata_map; - GroupTfEvents(&space, &group_metadata_map); + GroupTfEvents(&space); XPlaneVisitor host_plane_visitor = CreateTfXPlaneVisitor(host_plane); // Check that RemoteCallOp is grouped correctly so that all events belong // to the same group. @@ -609,16 +616,18 @@ TEST(GroupEventsTest, BatchingSessionTest) { {{StatType::kConsumerType, kBatchContextType}, {StatType::kConsumerId, kBatchContextId}}); - GroupMetadataMap group_metadata_map; - GroupTfEvents(&raw_space, &group_metadata_map); + EventForest event_forest; + GroupTfEvents(&raw_space, &event_forest); + const GroupMetadataMap& group_metadata_map = + event_forest.GetGroupMetadataMap(); EXPECT_EQ(group_metadata_map.size(), 3); // Check that the ProcessBatch group has two BatchingSessionRun groups as // parents. - EXPECT_EQ(group_metadata_map[0].parents.size(), 2); + EXPECT_EQ(group_metadata_map.at(0).parents.size(), 2); // Check that the BatchingSessionRun groups have one ProcessBatch group as a // child. - EXPECT_EQ(group_metadata_map[1].children.size(), 1); - EXPECT_EQ(group_metadata_map[2].children.size(), 1); + EXPECT_EQ(group_metadata_map.at(1).children.size(), 1); + EXPECT_EQ(group_metadata_map.at(2).children.size(), 1); // Chech that the events have the selected_group_ids stat set. uint64 num_checked = 0; CreateTfXPlaneVisitor(raw_plane).ForEachLine( diff --git a/tensorflow/core/profiler/utils/hardware_type_utils.cc b/tensorflow/core/profiler/utils/hardware_type_utils.cc index b7cae767813..3fc382f92a9 100644 --- a/tensorflow/core/profiler/utils/hardware_type_utils.cc +++ b/tensorflow/core/profiler/utils/hardware_type_utils.cc @@ -57,6 +57,17 @@ uint32 GetFmaMaxThroughputPerSMPerCycle(const DeviceCapabilities& device_cap) { n_fp32_cores = 64; n_tc_cores = 8; break; + case 8: + // Ampere + if (device_cap.compute_capability().minor() >= 6) { + // Ampere SM86 + n_fp32_cores = 128; + } else { + // Ampere SM80 + n_fp32_cores = 64; + } + n_tc_cores = 4; + break; default: LOG(ERROR) << "Invalid GPU compute capability."; break; diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils.cc b/tensorflow/core/profiler/utils/kernel_stats_utils.cc index b19d02ffa44..982434369ea 100644 --- a/tensorflow/core/profiler/utils/kernel_stats_utils.cc +++ b/tensorflow/core/profiler/utils/kernel_stats_utils.cc @@ -88,9 +88,11 @@ bool IsKernelUsingTensorCore(absl::string_view kernel_name) { // Some examples: volta_h884gemm, volta_fp16_s884gemm, // turing_fp16_s1688cudnn_fp16 bool possible_tensor_kernel = absl::StrContains(kernel_name, "884") || - absl::StrContains(kernel_name, "1688"); + absl::StrContains(kernel_name, "1688") || + absl::StrContains(kernel_name, "hmma") || + absl::StrContains(kernel_name, "xmma"); if (possible_tensor_kernel) { - VLOG(1) << "Possible tensor kernel: " << kernel_name << "\n"; + VLOG(3) << "Possible tensor kernel: " << kernel_name; } return (absl::StartsWith(kernel_name, "volta_i884") || @@ -104,7 +106,9 @@ bool IsKernelUsingTensorCore(absl::string_view kernel_name) { absl::StartsWith(kernel_name, "turing_s1688") || absl::StartsWith(kernel_name, "turing_fp16_i1688") || absl::StartsWith(kernel_name, "turing_fp16_h1688") || - absl::StartsWith(kernel_name, "turing_fp16_s1688")); + 
absl::StartsWith(kernel_name, "turing_fp16_s1688") || + absl::StrContains(kernel_name, "hmma") || + absl::StrContains(kernel_name, "xmma")); } // This list is not exhaustive. diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.h b/tensorflow/core/profiler/utils/op_metrics_db_utils.h index 7cb776abfe7..af00028169e 100644 --- a/tensorflow/core/profiler/utils/op_metrics_db_utils.h +++ b/tensorflow/core/profiler/utils/op_metrics_db_utils.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" @@ -27,7 +28,7 @@ namespace tensorflow { namespace profiler { // The name of OpMetrics to represent the idle time. -ABSL_CONST_INIT extern const absl::string_view kIdle; +TF_CONST_INIT extern const absl::string_view kIdle; // Helps build an op metrics database (borrowed). // Enables fast lookup of existing ops and prevents the creation of duplicate diff --git a/tensorflow/core/profiler/internal/parse_annotation.cc b/tensorflow/core/profiler/utils/parse_annotation.cc similarity index 98% rename from tensorflow/core/profiler/internal/parse_annotation.cc rename to tensorflow/core/profiler/utils/parse_annotation.cc index a4cdc09739d..ce948f28021 100644 --- a/tensorflow/core/profiler/internal/parse_annotation.cc +++ b/tensorflow/core/profiler/utils/parse_annotation.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/profiler/internal/parse_annotation.h" +#include "tensorflow/core/profiler/utils/parse_annotation.h" #include #include diff --git a/tensorflow/core/profiler/internal/parse_annotation.h b/tensorflow/core/profiler/utils/parse_annotation.h similarity index 87% rename from tensorflow/core/profiler/internal/parse_annotation.h rename to tensorflow/core/profiler/utils/parse_annotation.h index bb0f12217d3..f8b1fd5df9a 100644 --- a/tensorflow/core/profiler/internal/parse_annotation.h +++ b/tensorflow/core/profiler/utils/parse_annotation.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_PARSE_ANNOTATION_H_ -#define TENSORFLOW_CORE_PROFILER_INTERNAL_PARSE_ANNOTATION_H_ +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_PARSE_ANNOTATION_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_PARSE_ANNOTATION_H_ #include @@ -43,4 +43,4 @@ std::vector ParseAnnotationStack( } // namespace profiler } // namespace tensorflow -#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_PARSE_ANNOTATION_H_ +#endif // TENSORFLOW_CORE_PROFILER_UTILS_PARSE_ANNOTATION_H_ diff --git a/tensorflow/core/profiler/internal/parse_annotation_test.cc b/tensorflow/core/profiler/utils/parse_annotation_test.cc similarity index 98% rename from tensorflow/core/profiler/internal/parse_annotation_test.cc rename to tensorflow/core/profiler/utils/parse_annotation_test.cc index e5d876ac5af..3a88782c576 100644 --- a/tensorflow/core/profiler/internal/parse_annotation_test.cc +++ b/tensorflow/core/profiler/utils/parse_annotation_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/profiler/internal/parse_annotation.h" +#include "tensorflow/core/profiler/utils/parse_annotation.h" #include diff --git a/tensorflow/core/profiler/utils/tf_op_utils.h b/tensorflow/core/profiler/utils/tf_op_utils.h index af14e1ccb8e..0713e34faf4 100644 --- a/tensorflow/core/profiler/utils/tf_op_utils.h +++ b/tensorflow/core/profiler/utils/tf_op_utils.h @@ -21,15 +21,16 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/string_view.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { namespace profiler { // Special op types. -ABSL_CONST_INIT extern const absl::string_view kUnknownOp; -ABSL_CONST_INIT extern const absl::string_view kDatasetOp; -ABSL_CONST_INIT extern const absl::string_view kMemcpyHToDOp; -ABSL_CONST_INIT extern const absl::string_view kMemcpyDToHOp; +TF_CONST_INIT extern const absl::string_view kUnknownOp; +TF_CONST_INIT extern const absl::string_view kDatasetOp; +TF_CONST_INIT extern const absl::string_view kMemcpyHToDOp; +TF_CONST_INIT extern const absl::string_view kMemcpyDToHOp; enum class Category { kTensorFlow, diff --git a/tensorflow/core/profiler/utils/tf_xplane_visitor.h b/tensorflow/core/profiler/utils/tf_xplane_visitor.h index 17a7b94ef92..45459173a71 100644 --- a/tensorflow/core/profiler/utils/tf_xplane_visitor.h +++ b/tensorflow/core/profiler/utils/tf_xplane_visitor.h @@ -24,7 +24,8 @@ namespace tensorflow { namespace profiler { inline XPlaneVisitor CreateTfXPlaneVisitor(const XPlane* plane) { - return XPlaneVisitor(plane, {FindHostEventType}, {FindStatType}); + return XPlaneVisitor(plane, {FindHostEventType, FindTfOpEventType}, + {FindStatType}); } } // namespace profiler diff --git a/tensorflow/core/profiler/utils/xplane_schema.cc b/tensorflow/core/profiler/utils/xplane_schema.cc index dec10e327bc..78fd7af584c 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.cc +++ b/tensorflow/core/profiler/utils/xplane_schema.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/utils/tf_op_utils.h" namespace tensorflow { namespace profiler { @@ -231,6 +232,19 @@ absl::optional FindHostEventType(absl::string_view event_name) { return absl::nullopt; } +absl::optional FindTfOpEventType(absl::string_view event_name) { + // TF op names. + Category category = ParseTfOpFullname(event_name).category; + switch (category) { + case Category::kTensorFlow: + return HostEventType::kTfOpRun; + case Category::kTfData: + return HostEventType::kIterator; + default: + return absl::nullopt; + } +} + absl::string_view GetStatTypeStr(StatType stat_type) { return GetStatTypeStrMap().at(stat_type); } diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h index e73e810b7cd..8a9ac4fd278 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.h +++ b/tensorflow/core/profiler/utils/xplane_schema.h @@ -21,31 +21,32 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { namespace profiler { // Name of XPlane that contains TraceMe events. -ABSL_CONST_INIT extern const absl::string_view kHostThreadsPlaneName; +TF_CONST_INIT extern const absl::string_view kHostThreadsPlaneName; // Name prefix of XPlane that contains GPU events. -ABSL_CONST_INIT extern const absl::string_view kGpuPlanePrefix; +TF_CONST_INIT extern const absl::string_view kGpuPlanePrefix; // Name of XPlane that contains CUPTI driver API generated events. -ABSL_CONST_INIT extern const absl::string_view kCuptiDriverApiPlaneName; +TF_CONST_INIT extern const absl::string_view kCuptiDriverApiPlaneName; // Name of XPlane that contains profile metadata such as XLA debug info. -ABSL_CONST_INIT extern const absl::string_view kMetadataPlaneName; +TF_CONST_INIT extern const absl::string_view kMetadataPlaneName; // Name of XPlane that contains kpi related metrics. -ABSL_CONST_INIT extern const absl::string_view kTFStreamzPlaneName; +TF_CONST_INIT extern const absl::string_view kTFStreamzPlaneName; // Name of XPlane that contains events from python tracer. -ABSL_CONST_INIT extern const absl::string_view kPythonTracerPlaneName; +TF_CONST_INIT extern const absl::string_view kPythonTracerPlaneName; // Names of XLines that contain ML-level events. -ABSL_CONST_INIT extern const absl::string_view kStepLineName; -ABSL_CONST_INIT extern const absl::string_view kTensorFlowNameScopeLineName; -ABSL_CONST_INIT extern const absl::string_view kTensorFlowOpLineName; -ABSL_CONST_INIT extern const absl::string_view kXlaModuleLineName; -ABSL_CONST_INIT extern const absl::string_view kXlaOpLineName; -ABSL_CONST_INIT extern const absl::string_view kKernelLaunchLineName; +TF_CONST_INIT extern const absl::string_view kStepLineName; +TF_CONST_INIT extern const absl::string_view kTensorFlowNameScopeLineName; +TF_CONST_INIT extern const absl::string_view kTensorFlowOpLineName; +TF_CONST_INIT extern const absl::string_view kXlaModuleLineName; +TF_CONST_INIT extern const absl::string_view kXlaOpLineName; +TF_CONST_INIT extern const absl::string_view kKernelLaunchLineName; // Interesting event types (i.e., TraceMe names). 
enum HostEventType { @@ -208,6 +209,8 @@ inline bool IsHostEventType(HostEventType event_type, absl::optional FindHostEventType(absl::string_view event_name); +absl::optional FindTfOpEventType(absl::string_view event_name); + absl::string_view GetStatTypeStr(StatType stat_type); bool IsStatType(StatType stat_type, absl::string_view stat_name); diff --git a/tensorflow/core/profiler/utils/xplane_utils.cc b/tensorflow/core/profiler/utils/xplane_utils.cc index 825469f9eab..1389af2f16a 100644 --- a/tensorflow/core/profiler/utils/xplane_utils.cc +++ b/tensorflow/core/profiler/utils/xplane_utils.cc @@ -240,5 +240,16 @@ uint64 GetStartTimestampNs(const XPlane& plane) { return plane_timestamp; } +bool IsEmpty(const XSpace& space) { + for (const auto& plane : space.planes()) { + for (const auto& line : plane.lines()) { + if (!line.events().empty()) { + return false; + } + } + } + return true; +} + } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/xplane_utils.h b/tensorflow/core/profiler/utils/xplane_utils.h index 5cd5275e85e..eab2c8a8858 100644 --- a/tensorflow/core/profiler/utils/xplane_utils.h +++ b/tensorflow/core/profiler/utils/xplane_utils.h @@ -110,6 +110,9 @@ void MergePlanes(const XPlane& src_plane, XPlane* dst_plane); // timestamps. If zero line exists, return 0; uint64 GetStartTimestampNs(const XPlane& plane); +// Returns true if there are no XEvents. +bool IsEmpty(const XSpace& space); + } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/xplane_visitor.cc b/tensorflow/core/profiler/utils/xplane_visitor.cc index 42068b7c61a..626657a5c2d 100644 --- a/tensorflow/core/profiler/utils/xplane_visitor.cc +++ b/tensorflow/core/profiler/utils/xplane_visitor.cc @@ -81,38 +81,40 @@ XPlaneVisitor::XPlaneVisitor(const XPlane* plane, const TypeGetterList& event_type_getter_list, const TypeGetterList& stat_type_getter_list) : XStatsOwner(this, plane), plane_(plane) { - for (const auto& event_type_getter : event_type_getter_list) { - BuildEventTypeMap(plane, event_type_getter); - } - for (const auto& stat_type_getter : stat_type_getter_list) { - BuildStatTypeMap(plane, stat_type_getter); - } + BuildEventTypeMap(plane, event_type_getter_list); + BuildStatTypeMap(plane, stat_type_getter_list); } -void XPlaneVisitor::BuildEventTypeMap(const XPlane* plane, - const TypeGetter& event_type_getter) { +void XPlaneVisitor::BuildEventTypeMap( + const XPlane* plane, const TypeGetterList& event_type_getter_list) { for (const auto& event_metadata : plane->event_metadata()) { uint64 metadata_id = event_metadata.first; const auto& metadata = event_metadata.second; - absl::optional event_type = event_type_getter(metadata.name()); - if (event_type.has_value()) { - auto result = event_metadata_id_map_.emplace(metadata_id, *event_type); - DCHECK(result.second); // inserted - event_type_map_.emplace(*event_type, &metadata); + for (const auto& event_type_getter : event_type_getter_list) { + absl::optional event_type = event_type_getter(metadata.name()); + if (event_type.has_value()) { + auto result = event_metadata_id_map_.emplace(metadata_id, *event_type); + DCHECK(result.second); // inserted + event_type_map_.emplace(*event_type, &metadata); + break; + } } } } -void XPlaneVisitor::BuildStatTypeMap(const XPlane* plane, - const TypeGetter& stat_type_getter) { +void XPlaneVisitor::BuildStatTypeMap( + const XPlane* plane, const TypeGetterList& stat_type_getter_list) { for (const auto& stat_metadata : plane->stat_metadata()) { uint64 
metadata_id = stat_metadata.first; const auto& metadata = stat_metadata.second; - absl::optional stat_type = stat_type_getter(metadata.name()); - if (stat_type.has_value()) { - auto result = stat_metadata_id_map_.emplace(metadata_id, *stat_type); - DCHECK(result.second); // inserted - stat_type_map_.emplace(*stat_type, &metadata); + for (const auto& stat_type_getter : stat_type_getter_list) { + absl::optional stat_type = stat_type_getter(metadata.name()); + if (stat_type.has_value()) { + auto result = stat_metadata_id_map_.emplace(metadata_id, *stat_type); + DCHECK(result.second); // inserted + stat_type_map_.emplace(*stat_type, &metadata); + break; + } } } } diff --git a/tensorflow/core/profiler/utils/xplane_visitor.h b/tensorflow/core/profiler/utils/xplane_visitor.h index f0e4204e4d3..93830c0852a 100644 --- a/tensorflow/core/profiler/utils/xplane_visitor.h +++ b/tensorflow/core/profiler/utils/xplane_visitor.h @@ -239,9 +239,9 @@ class XPlaneVisitor : public XStatsOwner { private: void BuildEventTypeMap(const XPlane* plane, - const TypeGetter& event_type_getter); + const TypeGetterList& event_type_getter_list); void BuildStatTypeMap(const XPlane* plane, - const TypeGetter& stat_type_getter); + const TypeGetterList& stat_type_getter_list); const XPlane* plane_; diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index fec929a0b03..29e3e8a4ce3 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -568,6 +568,9 @@ message ConfigProto { // Session::Extend() may not be supported. bool optimize_for_static_graph = 12; + // This field will eventually be deprecated and replaced by + // mlir_bridge_rollout (b/166038521). + // // Whether to enable the MLIR-based TF->XLA bridge. // // This is a replacement to the existing bridge, and not ready for @@ -581,6 +584,22 @@ message ConfigProto { // to lower the encapsulated graph to a particular device. bool enable_mlir_bridge = 13; + // An enum that describes the state of the MLIR bridge rollout. + enum MlirBridgeRollout { + // If this field is left unspecified, the MLIR bridge may be selectively + // enabled on a per graph basis. + MLIR_BRIDGE_ROLLOUT_UNSPECIFIED = 0; + // Enabling the MLIR bridge enables it for all graphs in this session. + MLIR_BRIDGE_ROLLOUT_ENABLED = 1; + // Disabling the MLIR bridge disables it for all graphs in this session. + MLIR_BRIDGE_ROLLOUT_DISABLED = 2; + } + // This field is underdevelopment, for now use enable_mlir_bridge + // (b/166038521). + // + // Whether to enable the MLIR-based TF->XLA bridge. + MlirBridgeRollout mlir_bridge_rollout = 17; + // Whether to enable the MLIR-based Graph optimizations. // // This will become a part of standard Tensorflow graph optimization diff --git a/tensorflow/core/protobuf/saved_object_graph.proto b/tensorflow/core/protobuf/saved_object_graph.proto index f30b282b86c..a5b4cfbe823 100644 --- a/tensorflow/core/protobuf/saved_object_graph.proto +++ b/tensorflow/core/protobuf/saved_object_graph.proto @@ -124,6 +124,13 @@ message SavedBareConcreteFunction { repeated string argument_keywords = 2; // The prefix of `argument_keywords` which may be identified by position. int64 allowed_positional_arguments = 3; + // The spec of the function that this ConcreteFunction is traced from. This + // allows the ConcreteFunction to be called with nest structure inputs. This + // field may not be populated. If this field is absent, the concrete function + // can only be called with flat inputs. 
+ // TODO(b/169361281): support calling saved ConcreteFunction with structured + // inputs in C++ SavedModel API. + FunctionSpec function_spec = 4; } message SavedConstant { diff --git a/tensorflow/core/protobuf/tpu/BUILD b/tensorflow/core/protobuf/tpu/BUILD index 2cd25b4272e..364ebabd16d 100644 --- a/tensorflow/core/protobuf/tpu/BUILD +++ b/tensorflow/core/protobuf/tpu/BUILD @@ -9,8 +9,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - tf_proto_library( name = "tpu_embedding_configuration_proto", srcs = [ @@ -30,6 +28,7 @@ tf_proto_library( "optimization_parameters.proto", ], cc_api_version = 2, + protodeps = ["//tensorflow/compiler/xla/service:hlo_proto"], visibility = ["//visibility:public"], ) diff --git a/tensorflow/core/protobuf/tpu/optimization_parameters.proto b/tensorflow/core/protobuf/tpu/optimization_parameters.proto index b95574f7827..76c817957df 100644 --- a/tensorflow/core/protobuf/tpu/optimization_parameters.proto +++ b/tensorflow/core/protobuf/tpu/optimization_parameters.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package tensorflow.tpu; import "google/protobuf/wrappers.proto"; +import "tensorflow/compiler/xla/service/hlo.proto"; message ClippingLimits { google.protobuf.FloatValue lower = 1; // -inf if not set @@ -317,6 +318,34 @@ message FrequencyEstimatorParameters { float weight_exponent = 4; } +// A user-defined optimizer. +// The contained HLO program must take the following arguments in the following +// order: +// 1. gradients +// 2. table weights +// 3. slot variables +// 4. an optional scalar input that is passed in via the dynamic learning +// rate mechanism. +// +// It must return/end in a tuple op that contains the following values in the +// following order: +// 1. new table values +// 2. new slot variable value +// +// The program must have shape (1,1) with dtype float32 throughout and only use +// HLO that operate elementwise (e.g., no reduce, no variables, no control flow +// and no broadcasting outside of the single scalar input). +// The HLO program should be written as if it were a dense update. It will be +// called on each row that needs an update and will applied elementwise. +message UserDefinedProgramParameters { + xla.HloModuleProto program = 1; + // Padding values for the parameter and the slots, see + // StateVariableSpecification.padding_initial_value below for more details on + // how this should be set. One value is needed for the weights and one for + // each slot. + repeated float padding_values = 2; +} + // Status of using gradient accumulation (doing two passes over the input // gradients: one to accumulate them into a temporary array and another to apply // them using the actual optimization algorithm). The extra message is to wrap @@ -395,6 +424,7 @@ message OptimizationParameters { OnlineYogiParameters online_yogi = 20; ProximalYogiParameters proximal_yogi = 21; FrequencyEstimatorParameters frequency_estimator = 23; + UserDefinedProgramParameters user_defined_program = 24; } reserved 15; // Old use_gradient_accumulation. diff --git a/tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto b/tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto index 22be27795c7..038c7a1b8aa 100644 --- a/tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto +++ b/tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto @@ -87,9 +87,6 @@ message TPUEmbeddingConfiguration { // problem. 
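Looping back to the UserDefinedProgramParameters message introduced in optimization_parameters.proto above, here is a hypothetical way to populate it from C++. BuildElementwiseUpdateHlo is a placeholder rather than a real API (in practice the HLO module would come from an XLA builder), and the field accessors simply follow from standard protobuf codegen:

```cpp
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h"

// Placeholder for however the caller builds the (1,1) float32, elementwise
// HLO update program described by the proto comment above.
xla::HloModuleProto BuildElementwiseUpdateHlo();

tensorflow::tpu::OptimizationParameters MakeUserDefinedOptimizer() {
  tensorflow::tpu::OptimizationParameters params;
  auto* user_defined = params.mutable_user_defined_program();
  *user_defined->mutable_program() = BuildElementwiseUpdateHlo();
  // One padding value for the table weights plus one per slot variable,
  // in that order.
  user_defined->add_padding_values(0.0f);  // table weights
  user_defined->add_padding_values(0.1f);  // slot 0 (e.g. an accumulator)
  return params;
}
```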
bool pipeline_execution_with_tensor_core = 7; - // Extended output layout information; if not provided, a compatibility mode - // will use defaults that match the old layout. Providing a value for this - // field is EXPERIMENTAL and most ways of filling it will probably break. Do - // not set it unless you know what you are doing. - TPUEmbeddingOutputLayout output_layout = 8; + // Extended output layout information; deprecated and now ignored. + TPUEmbeddingOutputLayout output_layout = 8 [deprecated = true]; } diff --git a/tensorflow/core/protobuf/tpu/tpu_embedding_output_layout.proto b/tensorflow/core/protobuf/tpu/tpu_embedding_output_layout.proto index aed30b2f22f..25ee93fdd57 100644 --- a/tensorflow/core/protobuf/tpu/tpu_embedding_output_layout.proto +++ b/tensorflow/core/protobuf/tpu/tpu_embedding_output_layout.proto @@ -5,6 +5,8 @@ package tensorflow.tpu; // In the comments here, "layout" refers to the top-level EmbeddingOutputLayout // proto contained in the TPUEmbeddingConfiguration. +// This proto is deprecated and its contents are no longer used. + // The embedding output consists of a list of tensors, each specified by an // EmbeddingOutputTensor proto within the EmbeddingOutputLayout (the "output" // field). Each table and feature lookup is then placed into some number of @@ -15,6 +17,8 @@ package tensorflow.tpu; // EmbeddingOutputLayout. message TPUEmbeddingOutputLayout { + option deprecated = true; + // Location of one copy of the feature's data. message OutputLocation { // Which output tensor this copy of the feature will go into. Must be diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 1c514b3af33..5556ec07c4f 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 536 // Updated: 2020/9/26 +#define TF_GRAPH_DEF_VERSION 553 // Updated: 2020/10/13 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). 
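For the mlir_bridge_rollout field added to ConfigProto.Experimental a few hunks back, a short hypothetical client-side sketch; the nested-enum spellings assume standard protobuf C++ codegen for config.proto:

```cpp
#include "tensorflow/core/protobuf/config.pb.h"

// Hypothetical: opt every graph in a session into the MLIR bridge via the new
// rollout enum, leaving the older enable_mlir_bridge boolean untouched.
tensorflow::ConfigProto MakeMlirBridgeConfig() {
  tensorflow::ConfigProto config;
  config.mutable_experimental()->set_mlir_bridge_rollout(
      tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED);
  // MLIR_BRIDGE_ROLLOUT_UNSPECIFIED keeps per-graph selection;
  // MLIR_BRIDGE_ROLLOUT_DISABLED turns the bridge off for all graphs.
  return config;
}
```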
// diff --git a/tensorflow/core/summary/BUILD b/tensorflow/core/summary/BUILD index d39abb0b9ec..6fc4d0d5d56 100644 --- a/tensorflow/core/summary/BUILD +++ b/tensorflow/core/summary/BUILD @@ -108,8 +108,8 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:png_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/lib/png:png_io", ], ) diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index cd5beec6296..85586809014 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -27,6 +27,9 @@ cc_library( hdrs = ["tpu_embedding_optimization_parameters_utils.h"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_proto_cc", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", @@ -42,7 +45,6 @@ cc_library( hdrs = ["tpu_embedding_output_layout_utils.h"], visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:protos_all_cc", "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc", @@ -156,6 +158,7 @@ cc_library( deps = [ ":libtftpu_header", ":tpu_api", + ":tpu_api_dlsym_set_fn", ":tpu_compilation_device", ":tpu_config_c_api", ":tpu_executor_init_fns", @@ -174,10 +177,34 @@ cc_library( ], ) +# This is an alternative to "tpu_api_dlsym_initializer" that only initializes +# methods needed for the base TPU executor APIs (and thus has fewer deps). Do +# not link in both this and "tpu_api_dlsym_initializer". +cc_library( + name = "tpu_executor_dlsym_initializer", + srcs = ["tpu_executor_dlsym_initializer.cc"], + visibility = ["//visibility:public"], + deps = [ + ":tpu_api_dlsym_set_fn", + ":tpu_executor_init_fns", + "//tensorflow/core:lib", + "//tensorflow/stream_executor/tpu:tpu_computation_placer", + "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs", + ], + alwayslink = True, +) + +cc_library( + name = "tpu_api_dlsym_set_fn", + hdrs = ["tpu_api_dlsym_set_fn.h"], + visibility = ["//visibility:public"], +) + cc_library( name = "tpu_library_init_fns", hdrs = ["tpu_library_init_fns.inc"], visibility = ["//visibility:public"], + deps = [":tpu_executor_init_fns"], ) cc_library( @@ -284,6 +311,7 @@ cc_library( cc_library( name = "tpu_on_demand_compiler", srcs = ["tpu_on_demand_compiler.cc"], + visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 90107abf4b6..8de50acfd6c 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -5,13 +5,11 @@ load( "//tensorflow/core/platform:build_config.bzl", "tf_proto_library", ) +load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts") load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_kernel_library") # buildifier: disable=same-origin-load # Config setting to enable go/libtpu support. 
-WITH_TPU_SUPPORT = "//tensorflow:with_tpu_support" - -DEFAULT = "//conditions:default" package( default_visibility = [ @@ -44,10 +42,10 @@ cc_library( name = "tpu_compile_op_common", srcs = ["tpu_compile_op_common.cc"], hdrs = ["tpu_compile_op_common.h"], - deps = select({ - WITH_TPU_SUPPORT: [":tpu_compilation_metrics"], - DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_metrics"], - }) + [ + deps = if_libtpu( + [":tpu_compilation_metrics"], + ["//tensorflow/core/tpu/kernels:tpu_compilation_metrics"], + ) + [ ":tpu_compilation_cache_entry_unloader", ":tpu_compilation_cache_interface", ":tpu_compilation_metrics_hdrs", @@ -97,14 +95,10 @@ tf_kernel_library( name = "tpu_configuration_ops", srcs = ["tpu_configuration_ops.cc"], hdrs = ["tpu_configuration_ops.h"], - copts = select({ - WITH_TPU_SUPPORT: ["-DLIBTFTPU"], - DEFAULT: [], - }), - deps = select({ - WITH_TPU_SUPPORT: [":tpu_util"], - DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_util"], - }) + [ + deps = if_libtpu( + [":tpu_util"], + ["//tensorflow/core/tpu/kernels:tpu_util"], + ) + [ ":tpu_compilation_cache_factory", ":tpu_compilation_cache_interface", ":tpu_compilation_cache_local_lookup", @@ -346,10 +340,10 @@ cc_library( name = "tpu_compilation_cache_interface", srcs = ["tpu_compilation_cache_interface.cc"], hdrs = ["tpu_compilation_cache_interface.h"], - deps = select({ - WITH_TPU_SUPPORT: [":tpu_compilation_metrics"], - DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_metrics"], - }) + [ + deps = if_libtpu( + [":tpu_compilation_metrics"], + ["//tensorflow/core/tpu/kernels:tpu_compilation_metrics"], + ) + [ ":compiled_subgraph", ":tpu_compilation_cache_common_proto_cc", ":tpu_compilation_cache_entry", @@ -424,10 +418,7 @@ cc_library( cc_library( name = "tpu_compilation_metrics", srcs = ["tpu_compilation_metrics.cc"], - copts = select({ - WITH_TPU_SUPPORT: ["-DLIBTFTPU"], - DEFAULT: [], - }), + copts = tf_copts(), deps = [ ":tpu_compilation_metrics_hdrs", ], @@ -529,14 +520,11 @@ cc_library( cc_library( name = "tpu_compilation_cache_rpc_support_hdrs", hdrs = ["tpu_compilation_cache_rpc_support.h"], - copts = select({ - WITH_TPU_SUPPORT: ["-DLIBTFTPU"], - DEFAULT: [], - }), - deps = select({ - WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"], # build_cleaner: keep - DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"], # build_cleaner: keep - }) + [ + copts = tf_copts(), + deps = if_libtpu( + [":tpu_compilation_cache_proto_cc"], + ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"], + ) + [ ":tpu_compilation_cache_entry", ":tpu_compilation_cache_interface", ":tpu_compilation_cache_lookup", @@ -550,10 +538,7 @@ cc_library( cc_library( name = "tpu_compilation_cache_rpc_support", srcs = ["tpu_compilation_cache_rpc_support.cc"], - copts = select({ - WITH_TPU_SUPPORT: ["-DLIBTFTPU"], - DEFAULT: [], - }), + copts = tf_copts(), deps = [ ":tpu_compilation_cache_common_proto_cc", ":tpu_compilation_cache_proto_cc", @@ -572,14 +557,11 @@ cc_library( name = "tpu_compilation_cache_rpc_lookup", srcs = ["tpu_compilation_cache_rpc_lookup.cc"], hdrs = ["tpu_compilation_cache_rpc_lookup.h"], - copts = select({ - WITH_TPU_SUPPORT: ["-DLIBTFTPU"], - DEFAULT: [], - }), - deps = select({ - WITH_TPU_SUPPORT: [":tpu_compilation_cache_rpc_support"], - DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support"], - }) + [ + copts = tf_copts(), + deps = if_libtpu( + [":tpu_compilation_cache_rpc_support"], + ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support"], + ) + [ 
":tpu_compilation_cache_grpc", ":tpu_compilation_cache_interface", ":tpu_compilation_cache_lookup", @@ -617,14 +599,11 @@ cc_library( name = "tpu_compilation_cache_grpc", srcs = ["tpu_compilation_cache_grpc.cc"], hdrs = ["tpu_compilation_cache_grpc.h"], - copts = select({ - WITH_TPU_SUPPORT: ["-DLIBTFTPU"], - DEFAULT: [], - }), - deps = select({ - WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"], - DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"], - }) + [ + copts = tf_copts(), + deps = if_libtpu( + [":tpu_compilation_cache_proto_cc"], + ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto"], + ) + [ ":tpu_compilation_cache_common_proto_cc", tf_grpc_cc_dependency(), ], @@ -634,20 +613,17 @@ cc_library( name = "tpu_compilation_cache_service", srcs = ["tpu_compilation_cache_service.cc"], hdrs = ["tpu_compilation_cache_service.h"], - copts = select({ - WITH_TPU_SUPPORT: ["-DLIBTFTPU"], - DEFAULT: [], - }), - deps = select({ - WITH_TPU_SUPPORT: [ - ":tpu_compilation_cache_rpc_support", # build_cleaner: keep - ":tpu_compilation_cache_proto_cc", # build_cleaner: keep + copts = tf_copts(), + deps = if_libtpu( + [ + ":tpu_compilation_cache_rpc_support", + ":tpu_compilation_cache_proto_cc", ], - DEFAULT: [ - "//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support", # build_cleaner: keep - "//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto", # build_cleaner: keep + [ + "//tensorflow/core/tpu/kernels:tpu_compilation_cache_rpc_support", + "//tensorflow/core/tpu/kernels:tpu_compilation_cache_cc_proto", ], - }) + [ + ) + [ ":tpu_compilation_cache_common_proto_cc", ":tpu_compilation_cache_grpc", ":tpu_compilation_cache_interface", @@ -704,10 +680,7 @@ cc_library( name = "tpu_compile_op_impl", srcs = ["tpu_compile_op_impl.cc"], hdrs = ["tpu_compile_op_impl.h"], - copts = select({ - WITH_TPU_SUPPORT: ["-DLIBTFTPU"], - DEFAULT: [], - }), + copts = tf_copts(), deps = [ ":tpu_compilation_cache_key", ":tpu_compile_c_api_hdrs", @@ -952,14 +925,11 @@ cc_library( name = "tpu_pod_state", srcs = ["tpu_pod_state.cc"], hdrs = ["tpu_pod_state.h"], - copts = select({ - WITH_TPU_SUPPORT: ["-DLIBTFTPU"], - DEFAULT: [], - }), - deps = select({ - WITH_TPU_SUPPORT: [":tpu_util"], - DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_util"], - }) + [ + copts = tf_copts(), + deps = if_libtpu( + [":tpu_util"], + ["//tensorflow/core/tpu/kernels:tpu_util"], + ) + [ ":tpu_compilation_cache_service", "//tensorflow/c:tf_status", "//tensorflow/c:tf_status_helper", diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.cc index 207a60e7b48..c3aa62805c0 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.cc @@ -30,11 +30,11 @@ namespace tensorflow { namespace tpu { static const char* grpcTpuCompilationCacheService_method_names[] = { -#if defined(LIBTFTPU) +#if defined(LIBTPU_ON_GCE) "/tensorflow.tpu.TpuCompilationCacheServiceExternal/GetTpuProgram", -#else // LIBTFTPU +#else // LIBTPU_ON_GCE "/tensorflow.tpu.TpuCompilationCacheService/GetTpuProgram", -#endif // LIBTFTPU +#endif // LIBTPU_ON_GCE }; std::unique_ptr diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h index 324fc9e6f08..55877d15df2 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h @@ -35,7 +35,7 @@ limitations 
under the License. #include -#if defined(LIBTFTPU) +#if defined(LIBTPU_ON_GCE) #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" #else #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" // copybara" @@ -48,7 +48,7 @@ namespace grpc { class TpuCompilationCacheService final { public: using RequestType = ::tensorflow::tpu::GetTpuProgramRequest; -#if defined(LIBTFTPU) +#if defined(LIBTPU_ON_GCE) using ResponseType = ::tensorflow::tpu::GetTpuProgramResponseExternal; #else using ResponseType = ::tensorflow::tpu::GetTpuProgramResponse; @@ -59,7 +59,7 @@ class TpuCompilationCacheService final { enum class MethodId { kGetTpuProgram = 0 }; static constexpr char const* service_full_name() { -#if defined(LIBTFTPU) +#if defined(LIBTPU_ON_GCE) return "tensorflow.tpu.TpuCompilationCacheServiceExternal"; #else return "tensorflow.tpu.TpuCompilationCacheService"; diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc index 0d0a2ae1b61..5ddce57807d 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.cc @@ -461,9 +461,11 @@ Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper( << marked_for_eviction_size_ << " bytes)."; // Note that InitializeEntry() will Release/Reacquire mu_. entry = InitializeEntry(cache_key, compile_function, subgraph_key); + bool compilation_success = entry->tpu_program_group->program_count() > 0; TRACELITERAL("TPU host compilation cache: compilation done."); LOG(INFO) << strings::StrCat( - "TPU host compilation cache: compilation done for cache_key(", + "TPU host compilation cache: compilation ", + compilation_success ? "complete" : "failed", " for cache_key(", cache_key, "), session_name(", session_name, "), subgraph_key(", subgraph_key.debug_string, ")"); // If session_name is present, log some additional stats related to HBM diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.cc index 8b0fb674682..7846cc7bbb3 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.cc @@ -25,7 +25,7 @@ namespace tensorflow { namespace tpu { namespace { -#if defined(LIBTFTPU) +#if defined(LIBTPU_ON_GCE) using ResponseType = GetTpuProgramResponseExternal; #else using ResponseType = GetTpuProgramResponse; diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc index 9a6ca6be7e4..29ec8701a37 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/platform/casts.h" -#if defined(LIBTFTPU) +#if defined(LIBTPU_ON_GCE) #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" #endif #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" @@ -30,7 +30,7 @@ std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials() { return ::grpc::InsecureChannelCredentials(); // NOLINT } -#if defined(LIBTFTPU) +#if defined(LIBTPU_ON_GCE) template <> Status DeserializeRpcResponseToCacheEntry( absl::string_view local_proto_key, GetTpuProgramResponseExternal* response, @@ -156,6 +156,6 @@ xla::StatusOr> SerializeCacheEntryToBufferSlices( return std::vector<::grpc::Slice>{::grpc::Slice(encoded_header)}; } -#endif // LIBTFTPU +#endif // LIBTPU_ON_GCE } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_metrics.cc b/tensorflow/core/tpu/kernels/tpu_compilation_metrics.cc index e1a65ad0f32..ce982a1bd9a 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_metrics.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_metrics.cc @@ -19,7 +19,7 @@ namespace tpu { // TODO(henrytan): remove this once `TpuCompilationCache` migration to OSS is // completed. -#if defined(LIBTFTPU) +#if defined(LIBTPU_ON_GCE) /* static */ void TpuCompilationMetrics::IncrementCacheLookupCount( bool is_cache_hit, absl::string_view session_name) { @@ -36,7 +36,7 @@ void TpuCompilationMetrics::IncrementCompilationCount( absl::string_view session_name) { // A placeholder for tracking metrics. } -#endif // LIBTFTPU +#endif // LIBTPU_ON_GCE } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc index 34ea31759ce..73669c21f1e 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc @@ -589,7 +589,8 @@ Status TpuCompileOpKernelCommon::CompileLocallyAndFillHostCache( const std::string session_name = SessionNameFromMetadata(session_metadata); LOG(INFO) << "Compilation of " << key.prefix << " with session name " - << session_name << " took " << duration; + << session_name << " took " << duration << " and " + << (compile_status.ok() ? "succeeded" : "failed"); tpu_program_group->LogProgramMemorySummary(); metrics::UpdateXlaCompilationTime(absl::ToInt64Microseconds(duration)); TpuCompilationMetrics::IncrementCompilationCount(session_name); diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc index 8703dd818f5..270c2c53d7a 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc @@ -68,11 +68,11 @@ class TpuCompileOpImplFactory : public CompileOpImplFactory { } }; -#if defined(LIBTFTPU) +#if defined(LIBTPU_ON_GCE) REGISTER_MODULE_INITIALIZER(tpu_compile_op_impl_factory, { VLOG(1) << "register TpuCompileOpImplFactory()"; CompileOpImplFactory::Register(new TpuCompileOpImplFactory()); }); -#endif // LIBTFTPU +#endif // LIBTPU_ON_GCE } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_pod_state.cc b/tensorflow/core/tpu/kernels/tpu_pod_state.cc index 7b02998b343..898f02b28e9 100644 --- a/tensorflow/core/tpu/kernels/tpu_pod_state.cc +++ b/tensorflow/core/tpu/kernels/tpu_pod_state.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/tpu/tpu_api.h" -#if defined(LIBTFTPU) +#if defined(LIBTPU_ON_GCE) #include "tensorflow/core/tpu/kernels/tpu_util.h" #else #include "tensorflow/core/tpu/kernels/tpu_util.h" // copybara" @@ -54,7 +54,7 @@ xla::StatusOr> ConstructCacheService(ResourceMgr* rmgr, int serving_port, tpu::TpuCompilationCacheInterface* compilation_cache) { xla::StatusOr> server_builder; -#if defined(LIBTFTPU) +#if defined(LIBTPU_ON_GCE) server_builder = tpu::CreateServerBuilder(serving_port); #else server_builder = tpu::CreateServerBuilderGoogle(serving_port); diff --git a/tensorflow/core/tpu/libtftpu.h b/tensorflow/core/tpu/libtftpu.h index a4405df8205..9171af87061 100644 --- a/tensorflow/core/tpu/libtftpu.h +++ b/tensorflow/core/tpu/libtftpu.h @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #ifndef TENSORFLOW_CORE_TPU_LIBTFTPU_H_ #define TENSORFLOW_CORE_TPU_LIBTFTPU_H_ @@ -39,7 +41,7 @@ limitations under the License. extern "C" { #endif -TFTPU_CAPI_EXPORT void TfTpu_Initialize(); +TFTPU_CAPI_EXPORT void TfTpu_Initialize(bool init_library); #ifdef __cplusplus } diff --git a/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc b/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc index 4dc09770c38..e4d723305a9 100644 --- a/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc +++ b/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tpu/tpu_api_dlsym_set_fn.h" #if !defined(PLATFORM_GOOGLE) #include "tensorflow/core/tpu/tpu_api.h" #include "tensorflow/core/tpu/tpu_node_device.h" @@ -27,13 +28,6 @@ limitations under the License. #include "tensorflow/stream_executor/tpu/tpu_platform.h" #endif -#define TFTPU_SET_FN(Struct, FnName) \ - Struct->FnName##Fn = \ - reinterpret_cast(dlsym(library_handle, #FnName)); \ - if (!(Struct->FnName##Fn)) { \ - LOG(FATAL) << #FnName " not available in this library."; \ - return errors::Unimplemented(#FnName " not available in this library."); \ - } // Reminder: Update tpu_library_loader_windows.cc if you are adding new publicly // visible methods. @@ -55,10 +49,10 @@ Status InitializeTpuLibrary(void* library_handle) { // loaded. We do not want to register a TPU platform in XLA without the // supporting library providing the necessary APIs. if (s.ok()) { - void (*initialize_fn)(); + void (*initialize_fn)(bool init_library); initialize_fn = reinterpret_cast( dlsym(library_handle, "TfTpu_Initialize")); - (*initialize_fn)(); + (*initialize_fn)(/*init_library=*/true); RegisterTpuPlatform(); RegisterTpuSystemDevice(); diff --git a/tensorflow/core/tpu/tpu_api_dlsym_set_fn.h b/tensorflow/core/tpu/tpu_api_dlsym_set_fn.h new file mode 100644 index 00000000000..a1e13550d96 --- /dev/null +++ b/tensorflow/core/tpu/tpu_api_dlsym_set_fn.h @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_TPU_API_DLSYM_SET_FN_H_ +#define TENSORFLOW_CORE_TPU_TPU_API_DLSYM_SET_FN_H_ + +#define TFTPU_SET_FN(Struct, FnName) \ + Struct->FnName##Fn = \ + reinterpret_cast(dlsym(library_handle, #FnName)); \ + if (!(Struct->FnName##Fn)) { \ + LOG(FATAL) << #FnName " not available in this library."; \ + return errors::Unimplemented(#FnName " not available in this library."); \ + } + +#endif // TENSORFLOW_CORE_TPU_TPU_API_DLSYM_SET_FN_H_ diff --git a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc index 46633d22f9d..84d8ea70308 100644 --- a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc +++ b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/lib/core/errors.h" @@ -53,6 +56,8 @@ string GetOptimizationAlgorithmName(OptimizationAlgorithm alg) { return "ProximalYogi"; case OptimizationAlgorithm::kFrequencyEstimator: return "FrequencyEstimator"; + case OptimizationAlgorithm::kUserDefinedProgram: + return "UserDefinedProgram"; case OptimizationAlgorithm::PARAMETERS_NOT_SET: return "*** Not set ***"; } @@ -89,6 +94,8 @@ string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) { return "proximal Yogi"; case OptimizationAlgorithm::kFrequencyEstimator: return "frequency estimator"; + case OptimizationAlgorithm::kUserDefinedProgram: + return "UserDefinedProgram"; case OptimizationAlgorithm::PARAMETERS_NOT_SET: return "unknown (not specified)"; } @@ -143,6 +150,26 @@ Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params, case OptimizationAlgorithm::kFrequencyEstimator: *count = 1; return Status::OK(); + case OptimizationAlgorithm::kUserDefinedProgram: { + const xla::ProgramShapeProto& program_shape = + params.user_defined_program().program().host_program_shape(); + + const int num_inputs = program_shape.parameters_size(); + const int num_outputs = program_shape.result().tuple_shapes_size(); + + if ((num_inputs < 2) || ((num_inputs != num_outputs + 1) && + (num_inputs != num_outputs + 2))) { + return errors::InvalidArgument( + "User-defined TPU embedding optimizer program must have at least " + "two inputs and the number of outputs must be 1 or 2 less than the " + "number of inputs. 
Received ", + num_inputs, " input(s) and ", num_outputs, "output(s)."); + } + + *count = num_outputs - 1; + + return Status::OK(); + } case OptimizationAlgorithm::PARAMETERS_NOT_SET: return errors::InvalidArgument("No optimization algorithm specified"); } @@ -178,100 +205,109 @@ StateVariableSpecification MakeStandardStateVariableSpecification( Status GetOptimizationAlgorithmStateVariables( const OptimizationParameters& params, bool use_gradient_accumulation, std::vector* state_variables) { - // The first parameter set is always the weights themselves. - state_variables->push_back( - MakeStandardStateVariableSpecification("parameters", 0.0)); // The order of the returned parameters needs to match the offsets used by // the algorithm implementations in test_util.cc and // address_handler_program_creator.cc. + // The first parameter set is always the weights themselves. + auto add_state_variable = [&](const std::string& name, float value) { + state_variables->push_back( + MakeStandardStateVariableSpecification(name, value)); + }; switch (params.parameters_case()) { case OptimizationAlgorithm::kAdagrad: { - state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators", 0.1)); + add_state_variable("parameters", 0.0); + add_state_variable("accumulators", 0.1); break; } case OptimizationAlgorithm::kBoundedAdagrad: { - state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators", 0.1)); + add_state_variable("parameters", 0.0); + add_state_variable("accumulators", 0.1); break; } case OptimizationAlgorithm::kStochasticGradientDescent: { - // None. + add_state_variable("parameters", 0.0); break; } case OptimizationAlgorithm::kFtrl: { - state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators", 0.1)); - state_variables->push_back( - MakeStandardStateVariableSpecification("linears", 0.0)); + add_state_variable("parameters", 0.0); + add_state_variable("accumulators", 0.1); + add_state_variable("linears", 0.0); break; } case OptimizationAlgorithm::kAdam: { - state_variables->push_back( - MakeStandardStateVariableSpecification("momenta", 0.0)); - state_variables->push_back( - MakeStandardStateVariableSpecification("velocities", 0.0)); + add_state_variable("parameters", 0.0); + add_state_variable("momenta", 0.0); + add_state_variable("velocities", 0.0); break; } case OptimizationAlgorithm::kMomentum: { - state_variables->push_back( - MakeStandardStateVariableSpecification("momenta", 0.0)); + add_state_variable("parameters", 0.0); + add_state_variable("momenta", 0.0); break; } case OptimizationAlgorithm::kRmsProp: { - state_variables->push_back( - MakeStandardStateVariableSpecification("ms", 1.0)); - state_variables->push_back( - MakeStandardStateVariableSpecification("mom", 0.0)); + add_state_variable("parameters", 0.0); + add_state_variable("ms", 1.0); + add_state_variable("mom", 0.0); break; } case OptimizationAlgorithm::kCenteredRmsProp: { - state_variables->push_back( - MakeStandardStateVariableSpecification("ms", 1.0)); - state_variables->push_back( - MakeStandardStateVariableSpecification("mom", 0.0)); - state_variables->push_back( - MakeStandardStateVariableSpecification("mg", 0.0)); + add_state_variable("parameters", 0.0); + add_state_variable("ms", 1.0); + add_state_variable("mom", 0.0); + add_state_variable("mg", 0.0); break; } case OptimizationAlgorithm::kMdlAdagradLight: { - state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators", 0.1)); - state_variables->push_back( - 
MakeStandardStateVariableSpecification("weights", 0.0)); - state_variables->push_back( - MakeStandardStateVariableSpecification("benefits", 0.0)); + add_state_variable("parameters", 0.0); + add_state_variable("accumulators", 0.1); + add_state_variable("weights", 0.0); + add_state_variable("benefits", 0.0); break; } case OptimizationAlgorithm::kAdadelta: { - state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators", 0.0)); - state_variables->push_back( - MakeStandardStateVariableSpecification("updates", 0.0)); + add_state_variable("parameters", 0.0); + add_state_variable("accumulators", 0.0); + add_state_variable("updates", 0.0); break; } case OptimizationAlgorithm::kProximalAdagrad: { - state_variables->push_back( - MakeStandardStateVariableSpecification("accumulators", 0.1)); + add_state_variable("parameters", 0.0); + add_state_variable("accumulators", 0.1); break; } case OptimizationAlgorithm::kOnlineYogi: { - state_variables->push_back( - MakeStandardStateVariableSpecification("vs", 0.1)); - state_variables->push_back( - MakeStandardStateVariableSpecification("linears", 0.0)); + add_state_variable("parameters", 0.0); + add_state_variable("vs", 0.1); + add_state_variable("linears", 0.0); break; } case OptimizationAlgorithm::kProximalYogi: { - state_variables->push_back( - MakeStandardStateVariableSpecification("v", 0.1)); - state_variables->push_back( - MakeStandardStateVariableSpecification("m", 0.0)); + add_state_variable("parameters", 0.0); + add_state_variable("v", 0.1); + add_state_variable("m", 0.0); break; } case OptimizationAlgorithm::kFrequencyEstimator: { - state_variables->push_back( - MakeStandardStateVariableSpecification("last_hit_step", 0)); + add_state_variable("parameters", 0.0); + add_state_variable("last_hit_step", 0); + break; + } + case OptimizationAlgorithm::kUserDefinedProgram: { + add_state_variable("parameters", + params.user_defined_program().padding_values(0)); + int num_slots = -1; + TF_RETURN_IF_ERROR(GetBaseAuxiliaryParameterCount(params, &num_slots)); + if (num_slots + 1 != + params.user_defined_program().padding_values_size()) { + return errors::InvalidArgument( + "Number of slots does not agree with the number of padding values " + "specified."); + } + for (int i = 0; i < num_slots; ++i) { + add_state_variable(absl::StrCat("Slot_", i), + params.user_defined_program().padding_values(i + 1)); + } break; } case OptimizationAlgorithm::PARAMETERS_NOT_SET: { @@ -313,6 +349,7 @@ std::vector GetOptimizationAlgorithms() { OptimizationAlgorithm::kOnlineYogi, OptimizationAlgorithm::kProximalYogi, OptimizationAlgorithm::kFrequencyEstimator, + OptimizationAlgorithm::kUserDefinedProgram, }; } diff --git a/tensorflow/core/tpu/tpu_embedding_output_layout_utils.cc b/tensorflow/core/tpu/tpu_embedding_output_layout_utils.cc index 3a027757af7..6ecab5ee981 100644 --- a/tensorflow/core/tpu/tpu_embedding_output_layout_utils.cc +++ b/tensorflow/core/tpu/tpu_embedding_output_layout_utils.cc @@ -20,75 +20,17 @@ limitations under the License. namespace tensorflow { namespace tpu { -void AddDefaultEmbeddingOutputLayoutIfNeeded( - TPUEmbeddingConfiguration* config) { - if (config->has_output_layout()) { - // Model or previous step has already filled this in. - return; - } - - TPUEmbeddingOutputLayout* layout = config->mutable_output_layout(); - // Create output tensors. 
- for (const auto& table : config->table_descriptor()) { - TPUEmbeddingOutputLayout::EmbeddingOutputTensor* output = - layout->add_output(); - TPUEmbeddingOutputLayout::TwoDOutputTensor* two_d = output->mutable_two_d(); - two_d->set_dim1_size(table.dimension()); - two_d->set_dim0_size_per_sample(table.num_features()); - } - - // Create table output locations. - for (int table_id = 0; table_id < config->table_descriptor_size(); - ++table_id) { - TPUEmbeddingOutputLayout::TableDescriptor* output_table = - layout->add_table(); - const auto& table = config->table_descriptor(table_id); - for (int feature_index = 0; feature_index < table.num_features(); - ++feature_index) { - TPUEmbeddingOutputLayout::FeatureDescriptor* output_feature = - output_table->add_feature(); - TPUEmbeddingOutputLayout::OutputLocation* output_location = - output_feature->add_output_location(); - output_location->set_tensor_index(table_id); - output_location->set_dim0_offset(feature_index); - output_location->set_dim1_offset(0); - } - } -} - Status ComputeOutputTensorShapes(const TPUEmbeddingConfiguration& config, std::vector* shapes) { - if (!config.has_output_layout()) { - return errors::InvalidArgument( - "TPUEmbeddingConfiguration is missing output layout."); - } - const TPUEmbeddingOutputLayout& layout = config.output_layout(); int batch_size = config.batch_size_per_tensor_core(); - for (int i = 0; i < layout.output_size(); ++i) { - const auto& output = layout.output(i); + for (const TPUEmbeddingConfiguration::TableDescriptor& table : + config.table_descriptor()) { TensorShapeProto shape; - switch (output.output_format_case()) { - case TPUEmbeddingOutputLayout::EmbeddingOutputTensor::OutputFormatCase:: - kTwoD: { - auto* dim0 = shape.add_dim(); - dim0->set_size(output.two_d().dim0_size_per_sample() * batch_size); - auto* dim1 = shape.add_dim(); - dim1->set_size(output.two_d().dim1_size()); - break; - } - case TPUEmbeddingOutputLayout::EmbeddingOutputTensor::OutputFormatCase:: - OUTPUT_FORMAT_NOT_SET: { - return errors::InvalidArgument( - "Output layout in TPUEmbeddingConfiguration has unset embedding " - "output tensor format."); - } - default: { - return errors::InvalidArgument( - "Output layout in TPUEmbeddingConfiguration has invalid or " - "unhandled embedding output tensor format."); - } - } + auto* dim0 = shape.add_dim(); + dim0->set_size(table.num_features() * batch_size); + auto* dim1 = shape.add_dim(); + dim1->set_size(table.dimension()); shapes->push_back(shape); } return Status::OK(); diff --git a/tensorflow/core/tpu/tpu_embedding_output_layout_utils.h b/tensorflow/core/tpu/tpu_embedding_output_layout_utils.h index 5bff401b9d2..177ad8e3055 100644 --- a/tensorflow/core/tpu/tpu_embedding_output_layout_utils.h +++ b/tensorflow/core/tpu/tpu_embedding_output_layout_utils.h @@ -23,11 +23,7 @@ limitations under the License. namespace tensorflow { namespace tpu { -// Creates a default output layout for compatibility if none was provided by the -// model. -void AddDefaultEmbeddingOutputLayoutIfNeeded(TPUEmbeddingConfiguration* config); - -// Computes the shape of the output tensors from an output layout. +// Computes the shape of the output tensors from an embedding configuration. 
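The reworked `ComputeOutputTensorShapes` above no longer consults an explicit `output_layout`; each table's activation shape is taken straight from its descriptor, with dim0 equal to `num_features * batch_size_per_tensor_core` and dim1 equal to the embedding `dimension`. A minimal sketch of that arithmetic, using made-up example values (the batch size, feature count, and embedding width below are illustrative, not taken from this change):

```c++
#include <cstdio>

// Mirrors the shape rule used by ComputeOutputTensorShapes with plain ints
// instead of TPUEmbeddingConfiguration / TensorShapeProto.
int main() {
  const int batch_size_per_tensor_core = 128;  // assumed example value
  const int num_features = 3;                  // table.num_features()
  const int dimension = 64;                    // table.dimension()

  // dim0: one row per (sample, feature) pair; dim1: embedding width.
  const int dim0 = num_features * batch_size_per_tensor_core;  // 384
  const int dim1 = dimension;                                  // 64

  std::printf("per-table output shape: [%d, %d]\n", dim0, dim1);
  return 0;
}
```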
Status ComputeOutputTensorShapes( const TPUEmbeddingConfiguration& config, std::vector* shapes); diff --git a/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc b/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc new file mode 100644 index 00000000000..4d84781f4e3 --- /dev/null +++ b/tensorflow/core/tpu/tpu_executor_dlsym_initializer.cc @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// TODO(skye): this is largely a copy of tpu_api_dlsym_initializer.cc. Figure +// out how to deduplicate these files a little. + +#include + +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tpu/tpu_api_dlsym_set_fn.h" +#if !defined(PLATFORM_GOOGLE) +#include "tensorflow/core/tpu/tpu_executor_api.h" +#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" +#include "tensorflow/stream_executor/tpu/tpu_platform.h" +#endif + +namespace tensorflow { +namespace tpu { + +#if defined(PLATFORM_GOOGLE) +Status InitializeTpuLibrary(void* library_handle) { + return errors::Unimplemented("You must statically link in a TPU library."); +} +#else // PLATFORM_GOOGLE +#include "tensorflow/core/tpu/tpu_executor_init_fns.inc" + +Status InitializeTpuLibrary(void* library_handle) { + Status s = SetExecutorStructFn(library_handle); + + // TPU platform registration must only be performed after the library is + // loaded. We do not want to register a TPU platform in XLA without the + // supporting library providing the necessary APIs. 
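The dlsym-based initializers (`tpu_api_dlsym_initializer.cc` above and the new `tpu_executor_dlsym_initializer.cc` here) share one pattern: `dlopen` the `libtpu.so` shared object, resolve each C API symbol through the shared `TFTPU_SET_FN` macro, then look up and call the exported `TfTpu_Initialize` entry point before registering the TPU platform. A stripped-down sketch of that pattern, independent of the TensorFlow structs (only the library name, the `TfTpu_Initialize` symbol, and its `bool init_library` parameter come from this change; the rest is illustrative):

```c++
#include <dlfcn.h>
#include <cstdio>

int main() {
  // Loading is best-effort: if the library is absent, run without TPU support.
  void* handle = dlopen("libtpu.so", RTLD_NOW);
  if (!handle) {
    std::fprintf(stderr, "libtpu.so not found: %s\n", dlerror());
    return 0;
  }

  // Resolve the exported C symbol and cast it to the expected signature.
  using InitializeFn = void (*)(bool init_library);
  auto initialize_fn =
      reinterpret_cast<InitializeFn>(dlsym(handle, "TfTpu_Initialize"));
  if (!initialize_fn) {
    std::fprintf(stderr, "TfTpu_Initialize not available in this library.\n");
    return 1;
  }

  initialize_fn(/*init_library=*/true);
  // The real initializers then register the TPU platform and devices.
  return 0;
}
```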
+ if (s.ok()) { + void (*initialize_fn)(); + initialize_fn = reinterpret_cast( + dlsym(library_handle, "TfTpu_Initialize")); + (*initialize_fn)(); + + RegisterTpuPlatform(); + } + + return s; +} + +bool FindAndLoadTpuLibrary() { + void* library = dlopen("libtpu.so", RTLD_NOW); + if (library) { + InitializeTpuLibrary(library); + } + return true; +} + +static bool tpu_library_finder = FindAndLoadTpuLibrary(); +#endif // PLATFORM_GOOGLE + +} // namespace tpu +} // namespace tensorflow diff --git a/tensorflow/core/tpu/tpu_executor_init_fns.inc b/tensorflow/core/tpu/tpu_executor_init_fns.inc index 39f4aedb330..c2df15661f0 100644 --- a/tensorflow/core/tpu/tpu_executor_init_fns.inc +++ b/tensorflow/core/tpu/tpu_executor_init_fns.inc @@ -112,6 +112,8 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) { TFTPU_SET_FN(executor_fn, TpuTopology_LogicalDevicesPerHost); TFTPU_SET_FN(executor_fn, TpuTopology_LogicalDevicesPerChip); + TFTPU_SET_FN(executor_fn, TpuTopology_HostCount); + TFTPU_SET_FN(executor_fn, TpuTopology_ChipsPerHost); TFTPU_SET_FN(executor_fn, TpuTopology_ChipBounds_X); TFTPU_SET_FN(executor_fn, TpuTopology_ChipBounds_Y); TFTPU_SET_FN(executor_fn, TpuTopology_ChipBounds_Z); diff --git a/tensorflow/core/tpu/tpu_on_demand_compiler.cc b/tensorflow/core/tpu/tpu_on_demand_compiler.cc index 08153f62063..c34a13a45dc 100644 --- a/tensorflow/core/tpu/tpu_on_demand_compiler.cc +++ b/tensorflow/core/tpu/tpu_on_demand_compiler.cc @@ -119,7 +119,6 @@ class TpuExecutable : public TpuExecutableInterface { } ApiConverter::ToC(arg.shape(), &se_args[i]->dynamic_shape); - ApiConverter::ToC(arg.host_shape(), &se_args[i]->host_shape); const auto& unowned_indices = arg.unowned_indices(); se_args[i]->unowned_indices_size = unowned_indices.size(); se_args[i]->unowned_indices = new XLA_ShapeIndex[unowned_indices.size()]; @@ -142,7 +141,6 @@ class TpuExecutable : public TpuExecutableInterface { for (int i = 0; i < arguments.size(); ++i) { ApiConverter::Free(&se_args[i]->shape_tree.shape); ApiConverter::Free(&se_args[i]->dynamic_shape); - ApiConverter::Free(&se_args[i]->host_shape); delete[] se_args[i]->unowned_indices; delete[] se_args[i]->shape_tree.buffers; delete se_args[i]; @@ -278,16 +276,6 @@ class TpuCompiler : public Compiler { return HloModule::CreateFromProto(result_proto, module->config()); } - StatusOr< - std::tuple, std::unique_ptr>> - RunHloPassesAndBufferAssignement( - std::unique_ptr module, - stream_executor::StreamExecutor* executor, - stream_executor::DeviceMemoryAllocator* device_allocator) override { - return Unimplemented( - "This compiler does not support RunHloPassesAndBufferAssignment."); - } - StatusOr> RunBackend( std::unique_ptr module, stream_executor::StreamExecutor* executor, diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index ab909078cf0..b9342638592 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -453,6 +453,7 @@ cc_library( deps = [ ":test_log_proto_impl_cc", "//tensorflow/core/platform:env", + "//tensorflow/core/platform:env_impl", "//tensorflow/core/platform:errors", "//tensorflow/core/platform:macros", "//tensorflow/core/platform:mutex", @@ -757,7 +758,6 @@ tf_cc_tests( "//tensorflow/core", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:direct_session_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", @@ -767,6 +767,7 @@ tf_cc_tests( "//tensorflow/core:test", "//tensorflow/core:test_main", 
"//tensorflow/core:testlib", + "//tensorflow/core/common_runtime:direct_session_internal", "//tensorflow/core/kernels:ops_util", "//tensorflow/core/platform:regexp", "//third_party/eigen3", diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc index 03acb98989b..16f982f11b2 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.cc +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -2285,12 +2285,13 @@ Status ParseContextDenseFeatures(const FeatureProtosMap& context_features, context_features.find(c.feature_name)->second; TensorShape dense_shape, example_shape; DataType dtype = c.dtype; - const size_t expected_max_elements = feature.length; + const size_t data_max_elements = feature.length; if (!c.shape.AsTensorShape(&example_shape) || - expected_max_elements != example_shape.num_elements()) { + data_max_elements != example_shape.num_elements()) { return errors::InvalidArgument( - "Inconsistent number of elements for feature ", c.feature_name, ": ", - expected_max_elements, " vs ", dense_shape.num_elements()); + "Inconsistent max number of elements for feature ", c.feature_name, + ": expected ", example_shape.num_elements(), ", but found ", + data_max_elements); } if (is_batch) { dense_shape.AddDim(num_examples); @@ -2324,7 +2325,7 @@ Status ParseContextDenseFeatures(const FeatureProtosMap& context_features, EnableAliasing(&stream); num_elements += ParseFeature(dtype, &stream, &out, &out_offset); } - if (num_elements != expected_max_elements) { + if (num_elements != data_max_elements) { return errors::InvalidArgument( "Unexpected number of elements in example ", ExampleName(example_names, e)); diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index d2133244363..74f9cd100a8 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -893,21 +893,36 @@ inline void SetDummyMklDnnShapeOutput(OpKernelContext* context, AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output); } -inline void ForwardMklTensorInToOutWithMklShape(OpKernelContext* context, +// If the input tensor has ref count as 1, it is forwarded to the desired +// output port and the function returns true. In that case, it also allocates +// the serialized MklDnnShape object. Otherwise, the function returns false. 
+inline bool ForwardMklTensorInToOutWithMklShape(OpKernelContext* context, int idx_in, int idx_out, - const MklDnnShape& mkl_shape) { + Tensor** output, + const MklDnnShape& mkl_shape, + bool always_forward = true) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); int idx_data_in = GetTensorDataIndex(idx_in, num_inputs); int idx_data_out = GetTensorDataIndex(idx_out, num_outputs); - - AllocateOutputSetMklShape(context, idx_out, mkl_shape); - - if (IsRefType(context->input_dtype(idx_data_in))) { - context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out); + bool is_forwarded = false; + const Tensor& input_tensor = context->input(idx_data_in); + const auto output_shape = input_tensor.shape(); + if (always_forward) { + if (IsRefType(context->input_dtype(idx_data_in))) { + context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out); + } else { + context->set_output(idx_data_out, input_tensor); + } } else { - context->set_output(idx_data_out, context->input(idx_data_in)); + is_forwarded = context->forward_input_to_output_with_shape( + idx_data_in, idx_data_out, output_shape, output); } + if (is_forwarded || always_forward) { + AllocateOutputSetMklShape(context, idx_out, mkl_shape); + return true; + } + return false; } // Forward the MKL shape ONLY (used in elementwise and other ops where diff --git a/tensorflow/examples/adding_an_op/BUILD b/tensorflow/examples/adding_an_op/BUILD index 22f3e29b052..f40a310728c 100644 --- a/tensorflow/examples/adding_an_op/BUILD +++ b/tensorflow/examples/adding_an_op/BUILD @@ -14,8 +14,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - tf_custom_op_library( name = "zero_out_op_kernel_1.so", srcs = ["zero_out_op_kernel_1.cc"], diff --git a/tensorflow/examples/android/gradle/wrapper/gradle-wrapper.jar b/tensorflow/examples/android/gradle/wrapper/gradle-wrapper.jar deleted file mode 100644 index 13372aef5e2..00000000000 Binary files a/tensorflow/examples/android/gradle/wrapper/gradle-wrapper.jar and /dev/null differ diff --git a/tensorflow/examples/android/gradle/wrapper/gradle-wrapper.properties b/tensorflow/examples/android/gradle/wrapper/gradle-wrapper.properties deleted file mode 100644 index 4a0bf945ec9..00000000000 --- a/tensorflow/examples/android/gradle/wrapper/gradle-wrapper.properties +++ /dev/null @@ -1,6 +0,0 @@ -#Thu Jul 23 05:42:16 CEST 2020 -distributionBase=GRADLE_USER_HOME -distributionPath=wrapper/dists -zipStoreBase=GRADLE_USER_HOME -zipStorePath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-6.5.1-all.zip diff --git a/tensorflow/examples/android/res/drawable-hdpi/ic_action_info.png b/tensorflow/examples/android/res/drawable-hdpi/ic_action_info.png deleted file mode 100644 index 32bd1aabcab..00000000000 Binary files a/tensorflow/examples/android/res/drawable-hdpi/ic_action_info.png and /dev/null differ diff --git a/tensorflow/examples/android/res/drawable-hdpi/ic_launcher.png b/tensorflow/examples/android/res/drawable-hdpi/ic_launcher.png deleted file mode 100644 index b3113cd15c3..00000000000 Binary files a/tensorflow/examples/android/res/drawable-hdpi/ic_launcher.png and /dev/null differ diff --git a/tensorflow/examples/android/res/drawable-hdpi/tile.9.png b/tensorflow/examples/android/res/drawable-hdpi/tile.9.png deleted file mode 100644 index 135862883e2..00000000000 Binary files a/tensorflow/examples/android/res/drawable-hdpi/tile.9.png and /dev/null differ diff --git 
a/tensorflow/examples/android/res/drawable-mdpi/ic_action_info.png b/tensorflow/examples/android/res/drawable-mdpi/ic_action_info.png deleted file mode 100644 index 8efbbf8b3c4..00000000000 Binary files a/tensorflow/examples/android/res/drawable-mdpi/ic_action_info.png and /dev/null differ diff --git a/tensorflow/examples/android/res/drawable-mdpi/ic_launcher.png b/tensorflow/examples/android/res/drawable-mdpi/ic_launcher.png deleted file mode 100644 index 51f87ee6507..00000000000 Binary files a/tensorflow/examples/android/res/drawable-mdpi/ic_launcher.png and /dev/null differ diff --git a/tensorflow/examples/android/res/drawable-xhdpi/ic_action_info.png b/tensorflow/examples/android/res/drawable-xhdpi/ic_action_info.png deleted file mode 100644 index ba143ea7a80..00000000000 Binary files a/tensorflow/examples/android/res/drawable-xhdpi/ic_action_info.png and /dev/null differ diff --git a/tensorflow/examples/android/res/drawable-xhdpi/ic_launcher.png b/tensorflow/examples/android/res/drawable-xhdpi/ic_launcher.png deleted file mode 100644 index 6361d792dac..00000000000 Binary files a/tensorflow/examples/android/res/drawable-xhdpi/ic_launcher.png and /dev/null differ diff --git a/tensorflow/examples/android/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/examples/android/res/drawable-xxhdpi/ic_action_info.png deleted file mode 100644 index 394eb7e5349..00000000000 Binary files a/tensorflow/examples/android/res/drawable-xxhdpi/ic_action_info.png and /dev/null differ diff --git a/tensorflow/examples/android/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/examples/android/res/drawable-xxhdpi/ic_launcher.png deleted file mode 100644 index 2e27bec9785..00000000000 Binary files a/tensorflow/examples/android/res/drawable-xxhdpi/ic_launcher.png and /dev/null differ diff --git a/tensorflow/examples/ios/.gitignore b/tensorflow/examples/ios/.gitignore deleted file mode 100644 index dbabfb33bf1..00000000000 --- a/tensorflow/examples/ios/.gitignore +++ /dev/null @@ -1,7 +0,0 @@ -project.xcworkspace -xcuserdata -imagenet_comp_graph_label_strings.txt -tensorflow_inception_graph.pb -simple/data/LICENSE -camera/data/LICENSE -benchmark/data/LICENSE diff --git a/tensorflow/examples/ios/README.md b/tensorflow/examples/ios/README.md deleted file mode 100644 index 64412d25a00..00000000000 --- a/tensorflow/examples/ios/README.md +++ /dev/null @@ -1,194 +0,0 @@ -# TensorFlow iOS Examples - -This folder contains examples of how to build applications for iOS devices using TensorFlow. - -## Running the Samples using CocoaPod - - You'll need Xcode 7.3 or later. - - - There are currently three examples: simple, benchmark, and camera. For now, - you can download the sample code by cloning the main tensorflow repository - (we are planning to make the samples available as a separate repository - later). 
- - - From the root of the tensorflow folder, download - [Inception v1](https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip), - and extract the label and graph files into the data folders inside both the - simple and camera examples: - -```bash -mkdir -p ~/graphs -curl -o ~/graphs/inception5h.zip \ - https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip \ - && unzip ~/graphs/inception5h.zip -d ~/graphs/inception5h -cp ~/graphs/inception5h/* tensorflow/examples/ios/benchmark/data/ -cp ~/graphs/inception5h/* tensorflow/examples/ios/camera/data/ -cp ~/graphs/inception5h/* tensorflow/examples/ios/simple/data/ -``` - - - Change directory to one of the samples, download the TensorFlow-experimental - pod, and open the Xcode workspace. Observe: installing the pod can take a - long time since it is big (~450MB). For example, if you want to run the - simple example, then: -```bash -cd tensorflow/examples/ios/simple -pod install -open tf_simple_example.xcworkspace # obs, not the .xcodeproj directory -``` - - - Run the simple app in the simulator. You should see a single-screen app with - a "Run Model" button. Tap that, and you should see some debug output appear - below indicating that the example Grace Hopper image in directory data has - been analyzed, with a military uniform recognized. - - - Run the other samples using the same process. The camera example requires a - real device connected. Once you build and run that, you should get a live - camera view that you can point at objects to get real-time recognition - results. - -### Troubleshooting - - - Make sure you use the TensorFlow-experimental pod (and not TensorFlow). - - - The TensorFlow-experimental pod is current about ~450MB. The reason it is - so big is because we are bundling multiple platforms, and the pod includes - all TensorFlow functionality (e.g. operations). The final app size after - build is substantially smaller though (~25MB). Working with the complete - pod is convenient during development, but see below section on how you can - build your own custom TensorFlow library to reduce the size. - -### Creating Your own App - - - Create your own app using Xcode then add a file named Podfile at the project - root directory with the following content: -```bash -target 'YourProjectName' - pod 'TensorFlow-experimental' -``` - - - Then you run ```pod install``` to download and install the - TensorFlow-experimental pod, and finally perform - ```open YourProjectName.xcworkspace``` and add your code. - - - In your apps "Build Settings", make sure to add $(inherited) to sections - "Other Linker Flags", and "Header Search Paths". - - - That's it. If you want to create your custom TensorFlow iOS library, for - example to reduce binary footprint, see below section. - -## Building the TensorFlow iOS libraries from source - - - You'll need Xcode 7.3 or later, with the command-line tools installed. - - - Follow the instructions at - [tensorflow/contrib/makefile](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/makefile) - under "iOS" to compile a static library containing the core TensorFlow code. - - - You should see a single-screen app with a "Run Model" button. Tap that, and - you should see some debug output appear below indicating that the example - Grace Hopper image has been analyzed, with a military uniform recognized. - - - Once you have success there, make sure you have a real device connected and - open up the Xcode project in the `camera` subfolder. 
Once you build and run - that, you should get a live camera view that you can point at objects to get - real-time recognition results. - -### Troubleshooting - -If you're hitting problems, here's a checklist of common things to investigate: - - - Make sure that you've run the `build_all_ios.sh` script. - This will run `download_dependencies.sh`,`compile_ios_protobuf.sh` and `compile_ios_tensorflow.sh`. - (check each one if they have run successful.) - - - Check that you have version 7.3 of Xcode. - - - If there's a complaint about no Sessions registered, that means that the C++ - global constructors that TensorFlow relies on for registration haven't been - linked in properly. You'll have to make sure your project uses force_load, as - described below. - -### Creating your Own App from your source libraries - -You'll need to update various settings in your app to link against -TensorFlow. You can view them in the example projects, but here's a full -rundown: - - - The `compile_ios_tensorflow.sh` script builds a universal static library in - `tensorflow/contrib/makefile/gen/lib/libtensorflow-core.a`. You'll need to add - this to your linking build stage, and in Search Paths add - `tensorflow/contrib/makefile/gen/lib` to the Library Search Paths setting. - - - You'll also need to add `libprotobuf.a` and `libprotobuf-lite.a` from - `tensorflow/contrib/makefile/gen/protobuf_ios/lib` - and `nsync.a` from `tensorflow/contrib/makefile/downloads/nsync/builds/lipo.ios.c++11` - to your _Build Stages_ and _Library Search Paths_. - - - The _Header Search_ paths needs to contain: - - the root folder of tensorflow, - - `tensorflow/contrib/makefile/downloads/nsync/public` - - `tensorflow/contrib/makefile/downloads/protobuf/src` - - `tensorflow/contrib/makefile/downloads`, - - `tensorflow/contrib/makefile/downloads/eigen`, and - - `tensorflow/contrib/makefile/gen/proto`. - - - In the Linking section, you need to add `-force_load` followed by the path to - the TensorFlow static library in the _Other Linker_ Flags section. This ensures - that the global C++ objects that are used to register important classes - inside the library are not stripped out. To the linker, they can appear - unused because no other code references the variables, but in fact their - constructors have the important side effect of registering the class. - - - You'll need to include the Accelerate framework in the "Link Binary with - Libraries" build phase of your project. - - - C++11 support (or later) should be enabled by setting `C++ Language Dialect` to - `GNU++11` (or `GNU++14`), and `C++ Standard Library` to `libc++`. - - - The library doesn't currently support bitcode, so you'll need to disable that - in your project settings. - - - Remove any use of the `-all_load` flag in your project. The protocol buffers - libraries (full and lite versions) contain duplicate symbols, and the - `-all_load` flag will cause these duplicates to become link errors. If you - were using `-all_load` to avoid issues with Objective-C categories in static - libraries, you may be able to replace it with the `-ObjC` flag. - -### Reducing the binary size - -TensorFlow is a comparatively large library for a mobile device, so it will -increase the size of your app. Currently on iOS we see around a 11 MB binary -footprint per CPU architecture, though we're actively working on reducing that. 
-It can be tricky to set up the right configuration in your own app to keep the -size minimized, so if you do run into this issue we recommend you start by -looking at the simple example to examine its size. Here's how you do that: - - - Open the Xcode project in tensorflow/examples/ios/simple. - - - Make sure you've followed the steps above to get the data files. - - - Choose "Generic iOS Device" as the build configuration. - - - Select Product->Build. - - - Once the build's complete, open the Report Navigator and select the logs. - - - Near the bottom, you'll see a line saying "Touch tf_simple_example.app". - - - Expand that line using the icon on the right, and copy the first argument to - the Touch command. - - - Go to the terminal, type `ls -lah ` and then paste the path you copied. - - - For example it might look like `ls -lah /Users/petewarden/Library/Developer/Xcode/DerivedData/tf_simple_example-etdbksqytcnzeyfgdwiihzkqpxwr/Build/Products/Debug-iphoneos/tf_simple_example.app` - - - Running this command will show the size of the executable as the - `tf_simple_example` line. - -Right now you'll see a size of around 25 MB, since it's including two -architectures (armv7 and arm64). As a first step, you should make sure the size -increase you see in your own app is similar, and if it's larger, look at the -"Other Linker Flags" used in the Simple Xcode project settings to strip the -executable. - -For further optimization, please refer to the ["Optimization" section](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/makefile#optimization) -of the makefile instructions. diff --git a/tensorflow/examples/ios/benchmark/AppDelegate.h b/tensorflow/examples/ios/benchmark/AppDelegate.h deleted file mode 100644 index 94046d97282..00000000000 --- a/tensorflow/examples/ios/benchmark/AppDelegate.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import - -@interface AppDelegate : UIResponder - -@property(strong, nonatomic) UIWindow *window; - -@end diff --git a/tensorflow/examples/ios/benchmark/AppDelegate.mm b/tensorflow/examples/ios/benchmark/AppDelegate.mm deleted file mode 100644 index 23ffba0f7b8..00000000000 --- a/tensorflow/examples/ios/benchmark/AppDelegate.mm +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#import "AppDelegate.h" - -#import "BenchmarkViewController.h" - -@implementation AppDelegate - -- (BOOL)application:(UIApplication *)application - didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { - - UITabBarController *bar = [[UITabBarController alloc] init]; - [bar setViewControllers: - @[[[BenchmarkViewController alloc] init]]]; - bar.selectedIndex = 0; - self.window = [[UIWindow alloc] initWithFrame:[[UIScreen mainScreen] bounds]]; - self.window.rootViewController = bar; - [self.window makeKeyAndVisible]; - return YES; -} - -- (void)applicationWillResignActive:(UIApplication *)application {} - -- (void)applicationDidEnterBackground:(UIApplication *)application {} - -- (void)applicationWillEnterForeground:(UIApplication *)application {} - -- (void)applicationDidBecomeActive:(UIApplication *)application {} - -- (void)applicationWillTerminate:(UIApplication *)application {} - -@end diff --git a/tensorflow/examples/ios/benchmark/Benchmark-Info.plist b/tensorflow/examples/ios/benchmark/Benchmark-Info.plist deleted file mode 100644 index 0cdbf28a31b..00000000000 --- a/tensorflow/examples/ios/benchmark/Benchmark-Info.plist +++ /dev/null @@ -1,47 +0,0 @@ - - - - - CFBundleDevelopmentRegion - en - CFBundleDisplayName - tf_benchmark_example - CFBundleExecutable - tf_benchmark_example - CFBundleIdentifier - com.google.tf_benchmark_example - CFBundleInfoDictionaryVersion - 6.0 - CFBundleName - ios-app - CFBundlePackageType - APPL - CFBundleShortVersionString - 1.0 - CFBundleSignature - ???? - CFBundleVersion - 1.0 - LSRequiresIPhoneOS - - UILaunchStoryboardName - BenchmarkViewController - UIRequiredDeviceCapabilities - - armv7 - - UISupportedInterfaceOrientations - - UIInterfaceOrientationPortrait - UIInterfaceOrientationLandscapeLeft - UIInterfaceOrientationLandscapeRight - - UISupportedInterfaceOrientations~ipad - - UIInterfaceOrientationPortrait - UIInterfaceOrientationPortraitUpsideDown - UIInterfaceOrientationLandscapeLeft - UIInterfaceOrientationLandscapeRight - - - diff --git a/tensorflow/examples/ios/benchmark/BenchmarkViewController.h b/tensorflow/examples/ios/benchmark/BenchmarkViewController.h deleted file mode 100644 index c9cbc492809..00000000000 --- a/tensorflow/examples/ios/benchmark/BenchmarkViewController.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import - -@interface BenchmarkViewController : UIViewController - -- (IBAction)getUrl:(id)sender; - -@property(weak, nonatomic) IBOutlet UITextView *urlContentTextView; -@property(weak, nonatomic) IBOutlet UITextField *urlTextField; - -@end diff --git a/tensorflow/examples/ios/benchmark/BenchmarkViewController.mm b/tensorflow/examples/ios/benchmark/BenchmarkViewController.mm deleted file mode 100644 index 9fc5f6ded24..00000000000 --- a/tensorflow/examples/ios/benchmark/BenchmarkViewController.mm +++ /dev/null @@ -1,292 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import "BenchmarkViewController.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/public/session.h" -#include "tensorflow/core/util/stat_summarizer.h" - -#include "ios_image_load.h" - -NSString* RunInferenceOnImage(); - -namespace { -class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream { - public: - explicit IfstreamInputStream(const std::string& file_name) - : ifs_(file_name.c_str(), std::ios::in | std::ios::binary) {} - ~IfstreamInputStream() { ifs_.close(); } - - int Read(void* buffer, int size) { - if (!ifs_) { - return -1; - } - ifs_.read(static_cast(buffer), size); - return (int)ifs_.gcount(); - } - - private: - std::ifstream ifs_; -}; -} // namespace - -@interface BenchmarkViewController () -@end - -@implementation BenchmarkViewController { -} - -- (IBAction)getUrl:(id)sender { - NSString* inference_result = RunInferenceOnImage(); - self.urlContentTextView.text = inference_result; -} - -@end - -// Returns the top N confidence values over threshold in the provided vector, -// sorted by confidence in descending order. -static void GetTopN( - const Eigen::TensorMap, - Eigen::Aligned>& prediction, - const int num_results, const float threshold, - std::vector>* top_results) { - // Will contain top N results in ascending order. - std::priority_queue, std::vector>, - std::greater>> - top_result_pq; - - long count = prediction.size(); - for (int i = 0; i < count; ++i) { - const float value = prediction(i); - - // Only add it if it beats the threshold and has a chance at being in - // the top N. - if (value < threshold) { - continue; - } - - top_result_pq.push(std::pair(value, i)); - - // If at capacity, kick the smallest value out. - if (top_result_pq.size() > num_results) { - top_result_pq.pop(); - } - } - - // Copy to output vector and reverse into descending order. - while (!top_result_pq.empty()) { - top_results->push_back(top_result_pq.top()); - top_result_pq.pop(); - } - std::reverse(top_results->begin(), top_results->end()); -} - -bool PortableReadFileToProto(const std::string& file_name, - ::google::protobuf::MessageLite* proto) { - ::google::protobuf::io::CopyingInputStreamAdaptor stream( - new IfstreamInputStream(file_name)); - stream.SetOwnsCopyingStream(true); - // TODO(jiayq): the following coded stream is for debugging purposes to allow - // one to parse arbitrarily large messages for MessageLite. One most likely - // doesn't want to put protobufs larger than 64MB on Android, so we should - // eventually remove this and quit loud when a large protobuf is passed in. - ::google::protobuf::io::CodedInputStream coded_stream(&stream); - // Total bytes hard limit / warning limit are set to 1GB and 512MB - // respectively. 
- coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20); - return proto->ParseFromCodedStream(&coded_stream); -} - -NSString* FilePathForResourceName(NSString* name, NSString* extension) { - NSString* file_path = - [[NSBundle mainBundle] pathForResource:name ofType:extension]; - if (file_path == NULL) { - LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." - << [extension UTF8String] << "' in bundle."; - } - return file_path; -} - -// A utility function to get the current time in seconds, for simple profiling. -double time() { - timeval t; - gettimeofday(&t, nullptr); - return t.tv_sec + 1e-6 * t.tv_usec; -} - -// Runs the session with profiling enabled, and prints out details of the time -// that each node in the graph takes to the debug log. -tensorflow::Status BenchmarkInference( - tensorflow::Session* session, - const std::vector> inputs, - const std::vector& output_layer_names, - std::vector* output_layers, - tensorflow::StatSummarizer* stat_summarizer, double* average_time) { - tensorflow::Status run_status; - const int iterations_count = 20; - double total_time = 0.0; - tensorflow::RunOptions run_options; - run_options.set_trace_level(tensorflow::RunOptions::FULL_TRACE); - tensorflow::RunMetadata run_metadata; - for (int iteration = 0; iteration < (iterations_count + 1); ++iteration) { - const double start_time = time(); - run_status = session->Run(run_options, inputs, output_layer_names, {}, - output_layers, &run_metadata); - const double end_time = time(); - if (iteration != 0) { - total_time += end_time - start_time; - } - if (!run_status.ok()) { - LOG(ERROR) << "Running model failed: " << run_status; - tensorflow::LogAllRegisteredKernels(); - return run_status; - } - } - assert(run_metadata.has_step_stats()); - const tensorflow::StepStats& step_stats = run_metadata.step_stats(); - stat_summarizer->ProcessStepStats(step_stats); - stat_summarizer->PrintStepStats(); - - *average_time = total_time / iterations_count; - NSLog(@"Took %f seconds", *average_time); - - return tensorflow::Status::OK(); -} - -NSString* RunInferenceOnImage() { - tensorflow::SessionOptions options; - - tensorflow::Session* session_pointer = nullptr; - tensorflow::Status session_status = - tensorflow::NewSession(options, &session_pointer); - if (!session_status.ok()) { - std::string status_string = session_status.ToString(); - return [NSString - stringWithFormat:@"Session create failed - %s", status_string.c_str()]; - } - std::unique_ptr session(session_pointer); - LOG(INFO) << "Session created."; - - tensorflow::GraphDef tensorflow_graph; - LOG(INFO) << "Graph created."; - - NSString* network_path = - FilePathForResourceName(@"tensorflow_inception_graph", @"pb"); - PortableReadFileToProto([network_path UTF8String], &tensorflow_graph); - - LOG(INFO) << "Creating session."; - tensorflow::Status s = session->Create(tensorflow_graph); - if (!s.ok()) { - LOG(ERROR) << "Could not create TensorFlow Graph: " << s; - return @""; - } - - // Read the label list - NSString* labels_path = - FilePathForResourceName(@"imagenet_comp_graph_label_strings", @"txt"); - std::vector label_strings; - std::ifstream t; - t.open([labels_path UTF8String]); - std::string line; - while (t) { - std::getline(t, line); - label_strings.push_back(line); - } - t.close(); - - // Read the Grace Hopper image. 
- NSString* image_path = FilePathForResourceName(@"grace_hopper", @"jpg"); - int image_width; - int image_height; - int image_channels; - std::vector image_data = LoadImageFromFile( - [image_path UTF8String], &image_width, &image_height, &image_channels); - const int wanted_width = 224; - const int wanted_height = 224; - const int wanted_channels = 3; - const float input_mean = 117.0f; - const float input_std = 1.0f; - assert(image_channels >= wanted_channels); - tensorflow::Tensor image_tensor( - tensorflow::DT_FLOAT, - tensorflow::TensorShape( - {1, wanted_height, wanted_width, wanted_channels})); - auto image_tensor_mapped = image_tensor.tensor(); - tensorflow::uint8* in = image_data.data(); - float* out = image_tensor_mapped.data(); - for (int y = 0; y < wanted_height; ++y) { - const int in_y = (y * image_height) / wanted_height; - tensorflow::uint8* in_row = in + (in_y * image_width * image_channels); - float* out_row = out + (y * wanted_width * wanted_channels); - for (int x = 0; x < wanted_width; ++x) { - const int in_x = (x * image_width) / wanted_width; - tensorflow::uint8* in_pixel = in_row + (in_x * image_channels); - float* out_pixel = out_row + (x * wanted_channels); - for (int c = 0; c < wanted_channels; ++c) { - out_pixel[c] = (in_pixel[c] - input_mean) / input_std; - } - } - } - tensorflow::string input_layer = "input"; - tensorflow::string output_layer = "output"; - std::vector outputs; - tensorflow::StatSummarizer stat_summarizer(tensorflow_graph); - double average_time = 0.0; - BenchmarkInference(session.get(), {{input_layer, image_tensor}}, - {output_layer}, &outputs, &stat_summarizer, &average_time); - NSString* result = - [NSString stringWithFormat:@"Average time: %.4f seconds \n\n", average_time]; - - tensorflow::Tensor* output = &outputs[0]; - const int kNumResults = 5; - const float kThreshold = 0.1f; - std::vector> top_results; - GetTopN(output->flat(), kNumResults, kThreshold, &top_results); - - std::stringstream ss; - ss.precision(3); - for (const auto& result : top_results) { - const float confidence = result.first; - const int index = result.second; - - ss << index << " " << confidence << " "; - - // Write out the result as a string - if (index < label_strings.size()) { - // just for safety: theoretically, the output is under 1000 unless there - // is some numerical issues leading to a wrong prediction. - ss << label_strings[index]; - } else { - ss << "Prediction: " << index; - } - - ss << "\n"; - } - - LOG(INFO) << "Predictions: " << ss.str(); - - tensorflow::string predictions = ss.str(); - result = [NSString stringWithFormat:@"%@ - %s", result, predictions.c_str()]; - - return result; -} diff --git a/tensorflow/examples/ios/benchmark/BenchmarkViewController.xib b/tensorflow/examples/ios/benchmark/BenchmarkViewController.xib deleted file mode 100644 index 56c37080621..00000000000 --- a/tensorflow/examples/ios/benchmark/BenchmarkViewController.xib +++ /dev/null @@ -1,47 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/examples/ios/benchmark/Podfile b/tensorflow/examples/ios/benchmark/Podfile deleted file mode 100644 index e163d56e8d2..00000000000 --- a/tensorflow/examples/ios/benchmark/Podfile +++ /dev/null @@ -1,5 +0,0 @@ -platform :ios, '8.0' -inhibit_all_warnings! 
- -target 'tf_benchmark_example' - pod 'TensorFlow-experimental' diff --git a/tensorflow/examples/ios/benchmark/data/grace_hopper.jpg b/tensorflow/examples/ios/benchmark/data/grace_hopper.jpg deleted file mode 100644 index d2a427810f6..00000000000 Binary files a/tensorflow/examples/ios/benchmark/data/grace_hopper.jpg and /dev/null differ diff --git a/tensorflow/examples/ios/benchmark/ios_image_load.h b/tensorflow/examples/ios/benchmark/ios_image_load.h deleted file mode 100644 index 3f949846923..00000000000 --- a/tensorflow/examples/ios/benchmark/ios_image_load.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_EXAMPLES_IOS_BENCHMARK_IOS_IMAGE_LOAD_H_ -#define TENSORFLOW_EXAMPLES_IOS_BENCHMARK_IOS_IMAGE_LOAD_H_ - -#include - -#include "tensorflow/core/framework/types.h" - -std::vector LoadImageFromFile(const char* file_name, - int* out_width, - int* out_height, - int* out_channels); - -#endif // TENSORFLOW_EXAMPLES_IOS_BENCHMARK_IOS_IMAGE_LOAD_H_ diff --git a/tensorflow/examples/ios/benchmark/ios_image_load.mm b/tensorflow/examples/ios/benchmark/ios_image_load.mm deleted file mode 100644 index 64d1ea21cf2..00000000000 --- a/tensorflow/examples/ios/benchmark/ios_image_load.mm +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "ios_image_load.h" - -#include -#include -#include -#include - -#import -#import - -using tensorflow::uint8; - -std::vector LoadImageFromFile(const char* file_name, - int* out_width, int* out_height, - int* out_channels) { - FILE* file_handle = fopen(file_name, "rb"); - fseek(file_handle, 0, SEEK_END); - const size_t bytes_in_file = ftell(file_handle); - fseek(file_handle, 0, SEEK_SET); - std::vector file_data(bytes_in_file); - fread(file_data.data(), 1, bytes_in_file, file_handle); - fclose(file_handle); - CFDataRef file_data_ref = CFDataCreateWithBytesNoCopy(NULL, file_data.data(), - bytes_in_file, - kCFAllocatorNull); - CGDataProviderRef image_provider = - CGDataProviderCreateWithCFData(file_data_ref); - - const char* suffix = strrchr(file_name, '.'); - if (!suffix || suffix == file_name) { - suffix = ""; - } - CGImageRef image; - if (strcasecmp(suffix, ".png") == 0) { - image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); - } else if ((strcasecmp(suffix, ".jpg") == 0) || - (strcasecmp(suffix, ".jpeg") == 0)) { - image = CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); - } else { - CFRelease(image_provider); - CFRelease(file_data_ref); - fprintf(stderr, "Unknown suffix for file '%s'\n", file_name); - *out_width = 0; - *out_height = 0; - *out_channels = 0; - return std::vector(); - } - - const int width = (int)CGImageGetWidth(image); - const int height = (int)CGImageGetHeight(image); - const int channels = 4; - CGColorSpaceRef color_space = CGColorSpaceCreateDeviceRGB(); - const int bytes_per_row = (width * channels); - const int bytes_in_image = (bytes_per_row * height); - std::vector result(bytes_in_image); - const int bits_per_component = 8; - CGContextRef context = CGBitmapContextCreate(result.data(), width, height, - bits_per_component, bytes_per_row, color_space, - kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); - CGColorSpaceRelease(color_space); - CGContextDrawImage(context, CGRectMake(0, 0, width, height), image); - CGContextRelease(context); - CFRelease(image); - CFRelease(image_provider); - CFRelease(file_data_ref); - - *out_width = width; - *out_height = height; - *out_channels = channels; - return result; -} diff --git a/tensorflow/examples/ios/benchmark/main.mm b/tensorflow/examples/ios/benchmark/main.mm deleted file mode 100644 index d70550a7307..00000000000 --- a/tensorflow/examples/ios/benchmark/main.mm +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#import - -int main(int argc, char * argv[]) { - @autoreleasepool { - NSString *delegateClassName = @"AppDelegate"; - return UIApplicationMain(argc, argv, nil, delegateClassName); - } -} diff --git a/tensorflow/examples/ios/benchmark/tf_benchmark_example.xcodeproj/project.pbxproj b/tensorflow/examples/ios/benchmark/tf_benchmark_example.xcodeproj/project.pbxproj deleted file mode 100644 index d61b65ba614..00000000000 --- a/tensorflow/examples/ios/benchmark/tf_benchmark_example.xcodeproj/project.pbxproj +++ /dev/null @@ -1,388 +0,0 @@ -// !$*UTF8*$! -{ - archiveVersion = 1; - classes = { - }; - objectVersion = 46; - objects = { - -/* Begin PBXBuildFile section */ - 1C8BA8FD1EC682E700CCCC8C /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFC1CF4E68100C4259F /* main.mm */; }; - 1C8BA8FE1EC682E700CCCC8C /* AppDelegate.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */; }; - 1C8BA8FF1EC682E700CCCC8C /* BenchmarkViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFF1CF4E68100C4259F /* BenchmarkViewController.mm */; }; - 1C8BA9001EC682E700CCCC8C /* ios_image_load.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */; }; - 1C8BA9051EC682E700CCCC8C /* BenchmarkViewController.xib in Resources */ = {isa = PBXBuildFile; fileRef = 59A3D0001CF4E68100C4259F /* BenchmarkViewController.xib */; }; - 1C8BA9061EC682E700CCCC8C /* imagenet_comp_graph_label_strings.txt in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */; }; - 1C8BA9071EC682E700CCCC8C /* tensorflow_inception_graph.pb in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */; }; - 1C8BA9081EC682E700CCCC8C /* grace_hopper.jpg in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */; }; - 1CB1883E1ECCC0DC00C93EF7 /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CB1883D1ECCC0DC00C93EF7 /* CoreGraphics.framework */; }; - 1CB1883F1ECCC10D00C93EF7 /* UIKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1C7AC7FC1ECCBFE400EAE588 /* UIKit.framework */; }; - 1E0EBA4DF4C722C63814B257 /* libPods-tf_benchmark_example.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 8C4FE48552EFB73D066C66E9 /* libPods-tf_benchmark_example.a */; }; -/* End PBXBuildFile section */ - -/* Begin PBXFileReference section */ - 1C7AC7FC1ECCBFE400EAE588 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; }; - 1C8BA90C1EC682E700CCCC8C /* tf_benchmark_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tf_benchmark_example.app; sourceTree = BUILT_PRODUCTS_DIR; }; - 1CB1883B1ECCC09A00C93EF7 /* CoreFoundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreFoundation.framework; path = System/Library/Frameworks/CoreFoundation.framework; sourceTree = SDKROOT; }; - 1CB1883D1ECCC0DC00C93EF7 /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; }; - 59A3CFF11CF4E68100C4259F /* AppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; 
sourceTree = ""; }; - 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = AppDelegate.mm; sourceTree = ""; }; - 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = grace_hopper.jpg; sourceTree = ""; }; - 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = imagenet_comp_graph_label_strings.txt; sourceTree = ""; }; - 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */ = {isa = PBXFileReference; lastKnownFileType = file; path = tensorflow_inception_graph.pb; sourceTree = ""; }; - 59A3CFFA1CF4E68100C4259F /* ios_image_load.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ios_image_load.h; sourceTree = ""; }; - 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_image_load.mm; sourceTree = ""; }; - 59A3CFFC1CF4E68100C4259F /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = ""; }; - 59A3CFFD1CF4E68100C4259F /* Benchmark-Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = "Benchmark-Info.plist"; sourceTree = ""; }; - 59A3CFFE1CF4E68100C4259F /* BenchmarkViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = BenchmarkViewController.h; sourceTree = ""; }; - 59A3CFFF1CF4E68100C4259F /* BenchmarkViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = BenchmarkViewController.mm; sourceTree = ""; }; - 59A3D0001CF4E68100C4259F /* BenchmarkViewController.xib */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.xib; path = BenchmarkViewController.xib; sourceTree = ""; }; - 5FD1623E64FC0154A67E8DD5 /* Pods-tf_benchmark_example.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tf_benchmark_example.debug.xcconfig"; path = "Pods/Target Support Files/Pods-tf_benchmark_example/Pods-tf_benchmark_example.debug.xcconfig"; sourceTree = ""; }; - 8C4FE48552EFB73D066C66E9 /* libPods-tf_benchmark_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tf_benchmark_example.a"; sourceTree = BUILT_PRODUCTS_DIR; }; - DB6B3E596779C98202E84711 /* Pods-tf_benchmark_example.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tf_benchmark_example.release.xcconfig"; path = "Pods/Target Support Files/Pods-tf_benchmark_example/Pods-tf_benchmark_example.release.xcconfig"; sourceTree = ""; }; -/* End PBXFileReference section */ - -/* Begin PBXFrameworksBuildPhase section */ - 1C8BA9011EC682E700CCCC8C /* Frameworks */ = { - isa = PBXFrameworksBuildPhase; - buildActionMask = 2147483647; - files = ( - 1CB1883F1ECCC10D00C93EF7 /* UIKit.framework in Frameworks */, - 1CB1883E1ECCC0DC00C93EF7 /* CoreGraphics.framework in Frameworks */, - 1E0EBA4DF4C722C63814B257 /* libPods-tf_benchmark_example.a in Frameworks */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXFrameworksBuildPhase section */ - -/* Begin PBXGroup section */ - 2BD56010B574F539C2070A57 /* Pods */ = { - isa = PBXGroup; - children = ( - 
5FD1623E64FC0154A67E8DD5 /* Pods-tf_benchmark_example.debug.xcconfig */, - DB6B3E596779C98202E84711 /* Pods-tf_benchmark_example.release.xcconfig */, - ); - name = Pods; - sourceTree = ""; - }; - 591157921CF4011C00C31E3A = { - isa = PBXGroup; - children = ( - 59A3CFF11CF4E68100C4259F /* AppDelegate.h */, - 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */, - 59A3CFF31CF4E68100C4259F /* data */, - 59A3CFFA1CF4E68100C4259F /* ios_image_load.h */, - 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */, - 59A3CFFC1CF4E68100C4259F /* main.mm */, - 59A3CFFD1CF4E68100C4259F /* Benchmark-Info.plist */, - 59A3CFFE1CF4E68100C4259F /* BenchmarkViewController.h */, - 59A3CFFF1CF4E68100C4259F /* BenchmarkViewController.mm */, - 59A3D0001CF4E68100C4259F /* BenchmarkViewController.xib */, - 5911579C1CF4011C00C31E3A /* Products */, - 2BD56010B574F539C2070A57 /* Pods */, - 76A25A27041EB307BDFF0DD1 /* Frameworks */, - ); - sourceTree = ""; - }; - 5911579C1CF4011C00C31E3A /* Products */ = { - isa = PBXGroup; - children = ( - 1C8BA90C1EC682E700CCCC8C /* tf_benchmark_example.app */, - ); - name = Products; - sourceTree = ""; - }; - 59A3CFF31CF4E68100C4259F /* data */ = { - isa = PBXGroup; - children = ( - 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */, - 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */, - 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */, - ); - path = data; - sourceTree = ""; - }; - 76A25A27041EB307BDFF0DD1 /* Frameworks */ = { - isa = PBXGroup; - children = ( - 1CB1883D1ECCC0DC00C93EF7 /* CoreGraphics.framework */, - 1CB1883B1ECCC09A00C93EF7 /* CoreFoundation.framework */, - 1C7AC7FC1ECCBFE400EAE588 /* UIKit.framework */, - 8C4FE48552EFB73D066C66E9 /* libPods-tf_benchmark_example.a */, - ); - name = Frameworks; - sourceTree = ""; - }; -/* End PBXGroup section */ - -/* Begin PBXNativeTarget section */ - 1C8BA8FB1EC682E700CCCC8C /* tf_benchmark_example */ = { - isa = PBXNativeTarget; - buildConfigurationList = 1C8BA9091EC682E700CCCC8C /* Build configuration list for PBXNativeTarget "tf_benchmark_example" */; - buildPhases = ( - 0388D751057A257A12848245 /* [CP] Check Pods Manifest.lock */, - 1C8BA8FC1EC682E700CCCC8C /* Sources */, - 1C8BA9011EC682E700CCCC8C /* Frameworks */, - 1C8BA9041EC682E700CCCC8C /* Resources */, - 8999A303091D4E86202C2F64 /* [CP] Embed Pods Frameworks */, - A7B4B278BCC417B76A47ABB0 /* [CP] Copy Pods Resources */, - ); - buildRules = ( - ); - dependencies = ( - ); - name = tf_benchmark_example; - productName = benchmark; - productReference = 1C8BA90C1EC682E700CCCC8C /* tf_benchmark_example.app */; - productType = "com.apple.product-type.application"; - }; -/* End PBXNativeTarget section */ - -/* Begin PBXProject section */ - 591157931CF4011C00C31E3A /* Project object */ = { - isa = PBXProject; - attributes = { - LastUpgradeCheck = 0830; - ORGANIZATIONNAME = Google; - }; - buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_benchmark_example" */; - compatibilityVersion = "Xcode 3.2"; - developmentRegion = English; - hasScannedForEncodings = 0; - knownRegions = ( - en, - Base, - ); - mainGroup = 591157921CF4011C00C31E3A; - productRefGroup = 5911579C1CF4011C00C31E3A /* Products */; - projectDirPath = ""; - projectRoot = ""; - targets = ( - 1C8BA8FB1EC682E700CCCC8C /* tf_benchmark_example */, - ); - }; -/* End PBXProject section */ - -/* Begin PBXResourcesBuildPhase section */ - 1C8BA9041EC682E700CCCC8C /* Resources */ = { - isa = PBXResourcesBuildPhase; - buildActionMask = 2147483647; - files = ( - 
1C8BA9051EC682E700CCCC8C /* BenchmarkViewController.xib in Resources */, - 1C8BA9061EC682E700CCCC8C /* imagenet_comp_graph_label_strings.txt in Resources */, - 1C8BA9071EC682E700CCCC8C /* tensorflow_inception_graph.pb in Resources */, - 1C8BA9081EC682E700CCCC8C /* grace_hopper.jpg in Resources */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXResourcesBuildPhase section */ - -/* Begin PBXShellScriptBuildPhase section */ - 0388D751057A257A12848245 /* [CP] Check Pods Manifest.lock */ = { - isa = PBXShellScriptBuildPhase; - buildActionMask = 2147483647; - files = ( - ); - inputPaths = ( - ); - name = "[CP] Check Pods Manifest.lock"; - outputPaths = ( - ); - runOnlyForDeploymentPostprocessing = 0; - shellPath = /bin/sh; - shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n"; - showEnvVarsInLog = 0; - }; - 8999A303091D4E86202C2F64 /* [CP] Embed Pods Frameworks */ = { - isa = PBXShellScriptBuildPhase; - buildActionMask = 2147483647; - files = ( - ); - inputPaths = ( - ); - name = "[CP] Embed Pods Frameworks"; - outputPaths = ( - ); - runOnlyForDeploymentPostprocessing = 0; - shellPath = /bin/sh; - shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tf_benchmark_example/Pods-tf_benchmark_example-frameworks.sh\"\n"; - showEnvVarsInLog = 0; - }; - A7B4B278BCC417B76A47ABB0 /* [CP] Copy Pods Resources */ = { - isa = PBXShellScriptBuildPhase; - buildActionMask = 2147483647; - files = ( - ); - inputPaths = ( - ); - name = "[CP] Copy Pods Resources"; - outputPaths = ( - ); - runOnlyForDeploymentPostprocessing = 0; - shellPath = /bin/sh; - shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tf_benchmark_example/Pods-tf_benchmark_example-resources.sh\"\n"; - showEnvVarsInLog = 0; - }; -/* End PBXShellScriptBuildPhase section */ - -/* Begin PBXSourcesBuildPhase section */ - 1C8BA8FC1EC682E700CCCC8C /* Sources */ = { - isa = PBXSourcesBuildPhase; - buildActionMask = 2147483647; - files = ( - 1C8BA8FD1EC682E700CCCC8C /* main.mm in Sources */, - 1C8BA8FE1EC682E700CCCC8C /* AppDelegate.mm in Sources */, - 1C8BA8FF1EC682E700CCCC8C /* BenchmarkViewController.mm in Sources */, - 1C8BA9001EC682E700CCCC8C /* ios_image_load.mm in Sources */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXSourcesBuildPhase section */ - -/* Begin XCBuildConfiguration section */ - 1C8BA90A1EC682E700CCCC8C /* Debug */ = { - isa = XCBuildConfiguration; - baseConfigurationReference = 5FD1623E64FC0154A67E8DD5 /* Pods-tf_benchmark_example.debug.xcconfig */; - buildSettings = { - CODE_SIGN_IDENTITY = "iPhone Developer"; - ENABLE_BITCODE = NO; - HEADER_SEARCH_PATHS = "$(inherited)"; - INFOPLIST_FILE = "$(SRCROOT)/Benchmark-Info.plist"; - IPHONEOS_DEPLOYMENT_TARGET = 8.0; - LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; - LIBRARY_SEARCH_PATHS = ""; - OTHER_LDFLAGS = "$(inherited)"; - PRODUCT_BUNDLE_IDENTIFIER = "com.google.tf-benchmark-example"; - PRODUCT_NAME = "$(TARGET_NAME)"; - }; - name = Debug; - }; - 1C8BA90B1EC682E700CCCC8C /* Release */ = { - isa = XCBuildConfiguration; - baseConfigurationReference = DB6B3E596779C98202E84711 /* Pods-tf_benchmark_example.release.xcconfig */; - buildSettings = { - CODE_SIGN_IDENTITY = "iPhone Developer"; - ENABLE_BITCODE = NO; - HEADER_SEARCH_PATHS = "$(inherited)"; - INFOPLIST_FILE = 
"$(SRCROOT)/Benchmark-Info.plist"; - IPHONEOS_DEPLOYMENT_TARGET = 8.0; - LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; - LIBRARY_SEARCH_PATHS = ""; - ONLY_ACTIVE_ARCH = YES; - OTHER_LDFLAGS = "$(inherited)"; - PRODUCT_BUNDLE_IDENTIFIER = "com.google.tf-benchmark-example"; - PRODUCT_NAME = "$(TARGET_NAME)"; - }; - name = Release; - }; - 591157B01CF4011D00C31E3A /* Debug */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; - CLANG_CXX_LIBRARY = "libc++"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INFINITE_RECURSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; - COPY_PHASE_STRIP = NO; - DEBUG_INFORMATION_FORMAT = dwarf; - ENABLE_STRICT_OBJC_MSGSEND = YES; - ENABLE_TESTABILITY = YES; - GCC_C_LANGUAGE_STANDARD = gnu99; - GCC_DYNAMIC_NO_PIC = NO; - GCC_NO_COMMON_BLOCKS = YES; - GCC_OPTIMIZATION_LEVEL = 0; - GCC_PREPROCESSOR_DEFINITIONS = ( - "DEBUG=1", - "$(inherited)", - ); - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 8.0; - MTL_ENABLE_DEBUG_INFO = YES; - ONLY_ACTIVE_ARCH = YES; - SDKROOT = iphoneos; - TARGETED_DEVICE_FAMILY = "1,2"; - }; - name = Debug; - }; - 591157B11CF4011D00C31E3A /* Release */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; - CLANG_CXX_LIBRARY = "libc++"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INFINITE_RECURSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; - COPY_PHASE_STRIP = NO; - DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; - ENABLE_NS_ASSERTIONS = NO; - ENABLE_STRICT_OBJC_MSGSEND = YES; - GCC_C_LANGUAGE_STANDARD = gnu99; - GCC_NO_COMMON_BLOCKS = YES; - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 8.0; - MTL_ENABLE_DEBUG_INFO = NO; - SDKROOT = iphoneos; - TARGETED_DEVICE_FAMILY = "1,2"; - VALIDATE_PRODUCT = YES; - }; - name = Release; - }; -/* End XCBuildConfiguration section */ - -/* Begin XCConfigurationList section */ - 1C8BA9091EC682E700CCCC8C /* Build configuration list for PBXNativeTarget "tf_benchmark_example" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - 1C8BA90A1EC682E700CCCC8C /* Debug */, - 1C8BA90B1EC682E700CCCC8C /* Release */, - ); - 
defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; - 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_benchmark_example" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - 591157B01CF4011D00C31E3A /* Debug */, - 591157B11CF4011D00C31E3A /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; -/* End XCConfigurationList section */ - }; - rootObject = 591157931CF4011C00C31E3A /* Project object */; -} diff --git a/tensorflow/examples/ios/camera/CameraExampleAppDelegate.h b/tensorflow/examples/ios/camera/CameraExampleAppDelegate.h deleted file mode 100644 index 0039d5e7cab..00000000000 --- a/tensorflow/examples/ios/camera/CameraExampleAppDelegate.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import - -@interface CameraExampleAppDelegate : UIResponder - -@property(strong, nonatomic) UIWindow *window; - -@end diff --git a/tensorflow/examples/ios/camera/CameraExampleAppDelegate.m b/tensorflow/examples/ios/camera/CameraExampleAppDelegate.m deleted file mode 100644 index d134c2b591e..00000000000 --- a/tensorflow/examples/ios/camera/CameraExampleAppDelegate.m +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import "CameraExampleAppDelegate.h" - -@implementation CameraExampleAppDelegate - -@synthesize window = _window; - -- (BOOL)application:(UIApplication *)application - didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { - [self.window makeKeyAndVisible]; - return YES; -} - -- (void)applicationWillResignActive:(UIApplication *)application { - [[UIApplication sharedApplication] setIdleTimerDisabled:NO]; -} - -- (void)applicationDidEnterBackground:(UIApplication *)application { -} - -- (void)applicationWillEnterForeground:(UIApplication *)application { -} - -- (void)applicationDidBecomeActive:(UIApplication *)application { - [[UIApplication sharedApplication] setIdleTimerDisabled:YES]; -} - -- (void)applicationWillTerminate:(UIApplication *)application { -} - -@end diff --git a/tensorflow/examples/ios/camera/CameraExampleViewController.h b/tensorflow/examples/ios/camera/CameraExampleViewController.h deleted file mode 100644 index 0aefbc6eedb..00000000000 --- a/tensorflow/examples/ios/camera/CameraExampleViewController.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2015 Google Inc. 
All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import -#import - -#include -#include "tensorflow/core/public/session.h" -#include "tensorflow/core/util/memmapped_file_system.h" - -@interface CameraExampleViewController - : UIViewController { - IBOutlet UIView *previewView; - IBOutlet UISegmentedControl *camerasControl; - AVCaptureVideoPreviewLayer *previewLayer; - AVCaptureVideoDataOutput *videoDataOutput; - dispatch_queue_t videoDataOutputQueue; - AVCaptureStillImageOutput *stillImageOutput; - UIView *flashView; - UIImage *square; - BOOL isUsingFrontFacingCamera; - AVSpeechSynthesizer *synth; - NSMutableDictionary *oldPredictionValues; - NSMutableArray *labelLayers; - AVCaptureSession *session; - std::unique_ptr tf_session; - std::unique_ptr tf_memmapped_env; - std::vector labels; -} -@property(strong, nonatomic) CATextLayer *predictionTextLayer; - -- (IBAction)takePicture:(id)sender; -- (IBAction)switchCameras:(id)sender; - -@end diff --git a/tensorflow/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/examples/ios/camera/CameraExampleViewController.mm deleted file mode 100644 index d113d50ff8e..00000000000 --- a/tensorflow/examples/ios/camera/CameraExampleViewController.mm +++ /dev/null @@ -1,621 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import -#import -#import -#import -#import "CameraExampleViewController.h" - -#include - -#include "tensorflow_utils.h" - -// If you have your own model, modify this to the file name, and make sure -// you've added the file to your app resources too. -static NSString* model_file_name = @"tensorflow_inception_graph"; -static NSString* model_file_type = @"pb"; -// This controls whether we'll be loading a plain GraphDef proto, or a -// file created by the convert_graphdef_memmapped_format utility that wraps a -// GraphDef and parameter file that can be mapped into memory from file to -// reduce overall memory usage. -const bool model_uses_memory_mapping = false; -// If you have your own model, point this to the labels file. -static NSString* labels_file_name = @"imagenet_comp_graph_label_strings"; -static NSString* labels_file_type = @"txt"; -// These dimensions need to match those the model was trained with. 
-const int wanted_input_width = 224; -const int wanted_input_height = 224; -const int wanted_input_channels = 3; -const float input_mean = 117.0f; -const float input_std = 1.0f; -const std::string input_layer_name = "input"; -const std::string output_layer_name = "softmax1"; - -static void *AVCaptureStillImageIsCapturingStillImageContext = - &AVCaptureStillImageIsCapturingStillImageContext; - -@interface CameraExampleViewController (InternalMethods) -- (void)setupAVCapture; -- (void)teardownAVCapture; -@end - -@implementation CameraExampleViewController - -- (void)setupAVCapture { - NSError *error = nil; - - session = [AVCaptureSession new]; - if ([[UIDevice currentDevice] userInterfaceIdiom] == - UIUserInterfaceIdiomPhone) - [session setSessionPreset:AVCaptureSessionPreset640x480]; - else - [session setSessionPreset:AVCaptureSessionPresetPhoto]; - - AVCaptureDevice *device = - [AVCaptureDevice defaultDeviceWithMediaType:AVMediaTypeVideo]; - AVCaptureDeviceInput *deviceInput = - [AVCaptureDeviceInput deviceInputWithDevice:device error:&error]; - assert(error == nil); - - isUsingFrontFacingCamera = NO; - if ([session canAddInput:deviceInput]) [session addInput:deviceInput]; - - stillImageOutput = [AVCaptureStillImageOutput new]; - [stillImageOutput - addObserver:self - forKeyPath:@"capturingStillImage" - options:NSKeyValueObservingOptionNew - context:(void *)(AVCaptureStillImageIsCapturingStillImageContext)]; - if ([session canAddOutput:stillImageOutput]) - [session addOutput:stillImageOutput]; - - videoDataOutput = [AVCaptureVideoDataOutput new]; - - NSDictionary *rgbOutputSettings = [NSDictionary - dictionaryWithObject:[NSNumber numberWithInt:kCMPixelFormat_32BGRA] - forKey:(id)kCVPixelBufferPixelFormatTypeKey]; - [videoDataOutput setVideoSettings:rgbOutputSettings]; - [videoDataOutput setAlwaysDiscardsLateVideoFrames:YES]; - videoDataOutputQueue = - dispatch_queue_create("VideoDataOutputQueue", DISPATCH_QUEUE_SERIAL); - [videoDataOutput setSampleBufferDelegate:self queue:videoDataOutputQueue]; - - if ([session canAddOutput:videoDataOutput]) - [session addOutput:videoDataOutput]; - [[videoDataOutput connectionWithMediaType:AVMediaTypeVideo] setEnabled:YES]; - - previewLayer = [[AVCaptureVideoPreviewLayer alloc] initWithSession:session]; - [previewLayer setBackgroundColor:[[UIColor blackColor] CGColor]]; - [previewLayer setVideoGravity:AVLayerVideoGravityResizeAspect]; - CALayer *rootLayer = [previewView layer]; - [rootLayer setMasksToBounds:YES]; - [previewLayer setFrame:[rootLayer bounds]]; - [rootLayer addSublayer:previewLayer]; - [session startRunning]; - - if (error) { - NSString *title = [NSString stringWithFormat:@"Failed with error %d", (int)[error code]]; - UIAlertController *alertController = - [UIAlertController alertControllerWithTitle:title - message:[error localizedDescription] - preferredStyle:UIAlertControllerStyleAlert]; - UIAlertAction *dismiss = - [UIAlertAction actionWithTitle:@"Dismiss" style:UIAlertActionStyleDefault handler:nil]; - [alertController addAction:dismiss]; - [self presentViewController:alertController animated:YES completion:nil]; - [self teardownAVCapture]; - } -} - -- (void)teardownAVCapture { - [stillImageOutput removeObserver:self forKeyPath:@"isCapturingStillImage"]; - [previewLayer removeFromSuperlayer]; -} - -- (void)observeValueForKeyPath:(NSString *)keyPath - ofObject:(id)object - change:(NSDictionary *)change - context:(void *)context { - if (context == AVCaptureStillImageIsCapturingStillImageContext) { - BOOL isCapturingStillImage = - [[change 
objectForKey:NSKeyValueChangeNewKey] boolValue]; - - if (isCapturingStillImage) { - // do flash bulb like animation - flashView = [[UIView alloc] initWithFrame:[previewView frame]]; - [flashView setBackgroundColor:[UIColor whiteColor]]; - [flashView setAlpha:0.f]; - [[[self view] window] addSubview:flashView]; - - [UIView animateWithDuration:.4f - animations:^{ - [flashView setAlpha:1.f]; - }]; - } else { - [UIView animateWithDuration:.4f - animations:^{ - [flashView setAlpha:0.f]; - } - completion:^(BOOL finished) { - [flashView removeFromSuperview]; - flashView = nil; - }]; - } - } -} - -- (AVCaptureVideoOrientation)avOrientationForDeviceOrientation: - (UIDeviceOrientation)deviceOrientation { - AVCaptureVideoOrientation result = - (AVCaptureVideoOrientation)(deviceOrientation); - if (deviceOrientation == UIDeviceOrientationLandscapeLeft) - result = AVCaptureVideoOrientationLandscapeRight; - else if (deviceOrientation == UIDeviceOrientationLandscapeRight) - result = AVCaptureVideoOrientationLandscapeLeft; - return result; -} - -- (IBAction)takePicture:(id)sender { - if ([session isRunning]) { - [session stopRunning]; - [sender setTitle:@"Continue" forState:UIControlStateNormal]; - - flashView = [[UIView alloc] initWithFrame:[previewView frame]]; - [flashView setBackgroundColor:[UIColor whiteColor]]; - [flashView setAlpha:0.f]; - [[[self view] window] addSubview:flashView]; - - [UIView animateWithDuration:.2f - animations:^{ - [flashView setAlpha:1.f]; - } - completion:^(BOOL finished) { - [UIView animateWithDuration:.2f - animations:^{ - [flashView setAlpha:0.f]; - } - completion:^(BOOL finished) { - [flashView removeFromSuperview]; - flashView = nil; - }]; - }]; - - } else { - [session startRunning]; - [sender setTitle:@"Freeze Frame" forState:UIControlStateNormal]; - } -} - -+ (CGRect)videoPreviewBoxForGravity:(NSString *)gravity - frameSize:(CGSize)frameSize - apertureSize:(CGSize)apertureSize { - CGFloat apertureRatio = apertureSize.height / apertureSize.width; - CGFloat viewRatio = frameSize.width / frameSize.height; - - CGSize size = CGSizeZero; - if ([gravity isEqualToString:AVLayerVideoGravityResizeAspectFill]) { - if (viewRatio > apertureRatio) { - size.width = frameSize.width; - size.height = - apertureSize.width * (frameSize.width / apertureSize.height); - } else { - size.width = - apertureSize.height * (frameSize.height / apertureSize.width); - size.height = frameSize.height; - } - } else if ([gravity isEqualToString:AVLayerVideoGravityResizeAspect]) { - if (viewRatio > apertureRatio) { - size.width = - apertureSize.height * (frameSize.height / apertureSize.width); - size.height = frameSize.height; - } else { - size.width = frameSize.width; - size.height = - apertureSize.width * (frameSize.width / apertureSize.height); - } - } else if ([gravity isEqualToString:AVLayerVideoGravityResize]) { - size.width = frameSize.width; - size.height = frameSize.height; - } - - CGRect videoBox; - videoBox.size = size; - if (size.width < frameSize.width) - videoBox.origin.x = (frameSize.width - size.width) / 2; - else - videoBox.origin.x = (size.width - frameSize.width) / 2; - - if (size.height < frameSize.height) - videoBox.origin.y = (frameSize.height - size.height) / 2; - else - videoBox.origin.y = (size.height - frameSize.height) / 2; - - return videoBox; -} - -- (void)captureOutput:(AVCaptureOutput *)captureOutput -didOutputSampleBuffer:(CMSampleBufferRef)sampleBuffer - fromConnection:(AVCaptureConnection *)connection { - CVPixelBufferRef pixelBuffer = 
CMSampleBufferGetImageBuffer(sampleBuffer); - CFRetain(pixelBuffer); - [self runCNNOnFrame:pixelBuffer]; - CFRelease(pixelBuffer); -} - -- (void)runCNNOnFrame:(CVPixelBufferRef)pixelBuffer { - assert(pixelBuffer != NULL); - - OSType sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer); - int doReverseChannels; - if (kCVPixelFormatType_32ARGB == sourcePixelFormat) { - doReverseChannels = 1; - } else if (kCVPixelFormatType_32BGRA == sourcePixelFormat) { - doReverseChannels = 0; - } else { - assert(false); // Unknown source format - } - - const int sourceRowBytes = (int)CVPixelBufferGetBytesPerRow(pixelBuffer); - const int image_width = (int)CVPixelBufferGetWidth(pixelBuffer); - const int fullHeight = (int)CVPixelBufferGetHeight(pixelBuffer); - - CVPixelBufferLockFlags unlockFlags = kNilOptions; - CVPixelBufferLockBaseAddress(pixelBuffer, unlockFlags); - - unsigned char *sourceBaseAddr = - (unsigned char *)(CVPixelBufferGetBaseAddress(pixelBuffer)); - int image_height; - unsigned char *sourceStartAddr; - if (fullHeight <= image_width) { - image_height = fullHeight; - sourceStartAddr = sourceBaseAddr; - } else { - image_height = image_width; - const int marginY = ((fullHeight - image_width) / 2); - sourceStartAddr = (sourceBaseAddr + (marginY * sourceRowBytes)); - } - const int image_channels = 4; - - assert(image_channels >= wanted_input_channels); - tensorflow::Tensor image_tensor( - tensorflow::DT_FLOAT, - tensorflow::TensorShape( - {1, wanted_input_height, wanted_input_width, wanted_input_channels})); - auto image_tensor_mapped = image_tensor.tensor(); - tensorflow::uint8 *in = sourceStartAddr; - float *out = image_tensor_mapped.data(); - for (int y = 0; y < wanted_input_height; ++y) { - float *out_row = out + (y * wanted_input_width * wanted_input_channels); - for (int x = 0; x < wanted_input_width; ++x) { - const int in_x = (y * image_width) / wanted_input_width; - const int in_y = (x * image_height) / wanted_input_height; - tensorflow::uint8 *in_pixel = - in + (in_y * image_width * image_channels) + (in_x * image_channels); - float *out_pixel = out_row + (x * wanted_input_channels); - for (int c = 0; c < wanted_input_channels; ++c) { - out_pixel[c] = (in_pixel[c] - input_mean) / input_std; - } - } - } - - CVPixelBufferUnlockBaseAddress(pixelBuffer, unlockFlags); - - if (tf_session.get()) { - std::vector outputs; - tensorflow::Status run_status = tf_session->Run( - {{input_layer_name, image_tensor}}, {output_layer_name}, {}, &outputs); - if (!run_status.ok()) { - LOG(ERROR) << "Running model failed:" << run_status; - } else { - tensorflow::Tensor *output = &outputs[0]; - auto predictions = output->flat(); - - NSMutableDictionary *newValues = [NSMutableDictionary dictionary]; - for (int index = 0; index < predictions.size(); index += 1) { - const float predictionValue = predictions(index); - if (predictionValue > 0.05f) { - std::string label = labels[index % predictions.size()]; - NSString *labelObject = [NSString stringWithUTF8String:label.c_str()]; - NSNumber *valueObject = [NSNumber numberWithFloat:predictionValue]; - [newValues setObject:valueObject forKey:labelObject]; - } - } - dispatch_async(dispatch_get_main_queue(), ^(void) { - [self setPredictionValues:newValues]; - }); - } - } - CVPixelBufferUnlockBaseAddress(pixelBuffer, 0); -} - -- (void)dealloc { - [self teardownAVCapture]; -} - -// use front/back camera -- (IBAction)switchCameras:(id)sender { - AVCaptureDevicePosition desiredPosition; - if (isUsingFrontFacingCamera) - desiredPosition = 
AVCaptureDevicePositionBack; - else - desiredPosition = AVCaptureDevicePositionFront; - - for (AVCaptureDevice *d in - [AVCaptureDevice devicesWithMediaType:AVMediaTypeVideo]) { - if ([d position] == desiredPosition) { - [[previewLayer session] beginConfiguration]; - AVCaptureDeviceInput *input = - [AVCaptureDeviceInput deviceInputWithDevice:d error:nil]; - for (AVCaptureInput *oldInput in [[previewLayer session] inputs]) { - [[previewLayer session] removeInput:oldInput]; - } - [[previewLayer session] addInput:input]; - [[previewLayer session] commitConfiguration]; - break; - } - } - isUsingFrontFacingCamera = !isUsingFrontFacingCamera; -} - -- (void)didReceiveMemoryWarning { - [super didReceiveMemoryWarning]; -} - -- (void)viewDidLoad { - [super viewDidLoad]; - square = [UIImage imageNamed:@"squarePNG"]; - synth = [[AVSpeechSynthesizer alloc] init]; - labelLayers = [[NSMutableArray alloc] init]; - oldPredictionValues = [[NSMutableDictionary alloc] init]; - - tensorflow::Status load_status; - if (model_uses_memory_mapping) { - load_status = LoadMemoryMappedModel( - model_file_name, model_file_type, &tf_session, &tf_memmapped_env); - } else { - load_status = LoadModel(model_file_name, model_file_type, &tf_session); - } - if (!load_status.ok()) { - LOG(FATAL) << "Couldn't load model: " << load_status; - } - - tensorflow::Status labels_status = - LoadLabels(labels_file_name, labels_file_type, &labels); - if (!labels_status.ok()) { - LOG(FATAL) << "Couldn't load labels: " << labels_status; - } - [self setupAVCapture]; -} - -- (void)viewDidUnload { - [super viewDidUnload]; -} - -- (void)viewWillAppear:(BOOL)animated { - [super viewWillAppear:animated]; -} - -- (void)viewDidAppear:(BOOL)animated { - [super viewDidAppear:animated]; -} - -- (void)viewWillDisappear:(BOOL)animated { - [super viewWillDisappear:animated]; -} - -- (void)viewDidDisappear:(BOOL)animated { - [super viewDidDisappear:animated]; -} - -- (BOOL)shouldAutorotateToInterfaceOrientation: - (UIInterfaceOrientation)interfaceOrientation { - return (interfaceOrientation == UIInterfaceOrientationPortrait); -} - -- (BOOL)prefersStatusBarHidden { - return YES; -} - -- (void)setPredictionValues:(NSDictionary *)newValues { - const float decayValue = 0.75f; - const float updateValue = 0.25f; - const float minimumThreshold = 0.01f; - - NSMutableDictionary *decayedPredictionValues = - [[NSMutableDictionary alloc] init]; - for (NSString *label in oldPredictionValues) { - NSNumber *oldPredictionValueObject = - [oldPredictionValues objectForKey:label]; - const float oldPredictionValue = [oldPredictionValueObject floatValue]; - const float decayedPredictionValue = (oldPredictionValue * decayValue); - if (decayedPredictionValue > minimumThreshold) { - NSNumber *decayedPredictionValueObject = - [NSNumber numberWithFloat:decayedPredictionValue]; - [decayedPredictionValues setObject:decayedPredictionValueObject - forKey:label]; - } - } - oldPredictionValues = decayedPredictionValues; - - for (NSString *label in newValues) { - NSNumber *newPredictionValueObject = [newValues objectForKey:label]; - NSNumber *oldPredictionValueObject = - [oldPredictionValues objectForKey:label]; - if (!oldPredictionValueObject) { - oldPredictionValueObject = [NSNumber numberWithFloat:0.0f]; - } - const float newPredictionValue = [newPredictionValueObject floatValue]; - const float oldPredictionValue = [oldPredictionValueObject floatValue]; - const float updatedPredictionValue = - (oldPredictionValue + (newPredictionValue * updateValue)); - NSNumber 
*updatedPredictionValueObject = - [NSNumber numberWithFloat:updatedPredictionValue]; - [oldPredictionValues setObject:updatedPredictionValueObject forKey:label]; - } - NSArray *candidateLabels = [NSMutableArray array]; - for (NSString *label in oldPredictionValues) { - NSNumber *oldPredictionValueObject = - [oldPredictionValues objectForKey:label]; - const float oldPredictionValue = [oldPredictionValueObject floatValue]; - if (oldPredictionValue > 0.05f) { - NSDictionary *entry = @{ - @"label" : label, - @"value" : oldPredictionValueObject - }; - candidateLabels = [candidateLabels arrayByAddingObject:entry]; - } - } - NSSortDescriptor *sort = - [NSSortDescriptor sortDescriptorWithKey:@"value" ascending:NO]; - NSArray *sortedLabels = [candidateLabels - sortedArrayUsingDescriptors:[NSArray arrayWithObject:sort]]; - - const float leftMargin = 10.0f; - const float topMargin = 10.0f; - - const float valueWidth = 48.0f; - const float valueHeight = 26.0f; - - const float labelWidth = 246.0f; - const float labelHeight = 26.0f; - - const float labelMarginX = 5.0f; - const float labelMarginY = 5.0f; - - [self removeAllLabelLayers]; - - int labelCount = 0; - for (NSDictionary *entry in sortedLabels) { - NSString *label = [entry objectForKey:@"label"]; - NSNumber *valueObject = [entry objectForKey:@"value"]; - const float value = [valueObject floatValue]; - - const float originY = - (topMargin + ((labelHeight + labelMarginY) * labelCount)); - - const int valuePercentage = (int)roundf(value * 100.0f); - - const float valueOriginX = leftMargin; - NSString *valueText = [NSString stringWithFormat:@"%d%%", valuePercentage]; - - [self addLabelLayerWithText:valueText - originX:valueOriginX - originY:originY - width:valueWidth - height:valueHeight - alignment:kCAAlignmentRight]; - - const float labelOriginX = (leftMargin + valueWidth + labelMarginX); - - [self addLabelLayerWithText:[label capitalizedString] - originX:labelOriginX - originY:originY - width:labelWidth - height:labelHeight - alignment:kCAAlignmentLeft]; - - if ((labelCount == 0) && (value > 0.5f)) { - [self speak:[label capitalizedString]]; - } - - labelCount += 1; - if (labelCount > 4) { - break; - } - } -} - -- (void)removeAllLabelLayers { - for (CATextLayer *layer in labelLayers) { - [layer removeFromSuperlayer]; - } - [labelLayers removeAllObjects]; -} - -- (void)addLabelLayerWithText:(NSString *)text - originX:(float)originX - originY:(float)originY - width:(float)width - height:(float)height - alignment:(NSString *)alignment { - CFTypeRef font = (CFTypeRef) @"Menlo-Regular"; - const float fontSize = 20.0f; - - const float marginSizeX = 5.0f; - const float marginSizeY = 2.0f; - - const CGRect backgroundBounds = CGRectMake(originX, originY, width, height); - - const CGRect textBounds = - CGRectMake((originX + marginSizeX), (originY + marginSizeY), - (width - (marginSizeX * 2)), (height - (marginSizeY * 2))); - - CATextLayer *background = [CATextLayer layer]; - [background setBackgroundColor:[UIColor blackColor].CGColor]; - [background setOpacity:0.5f]; - [background setFrame:backgroundBounds]; - background.cornerRadius = 5.0f; - - [[self.view layer] addSublayer:background]; - [labelLayers addObject:background]; - - CATextLayer *layer = [CATextLayer layer]; - [layer setForegroundColor:[UIColor whiteColor].CGColor]; - [layer setFrame:textBounds]; - [layer setAlignmentMode:alignment]; - [layer setWrapped:YES]; - [layer setFont:font]; - [layer setFontSize:fontSize]; - layer.contentsScale = [[UIScreen mainScreen] scale]; - [layer setString:text]; 
- - [[self.view layer] addSublayer:layer]; - [labelLayers addObject:layer]; -} - -- (void)setPredictionText:(NSString *)text withDuration:(float)duration { - if (duration > 0.0) { - CABasicAnimation *colorAnimation = - [CABasicAnimation animationWithKeyPath:@"foregroundColor"]; - colorAnimation.duration = duration; - colorAnimation.fillMode = kCAFillModeForwards; - colorAnimation.removedOnCompletion = NO; - colorAnimation.fromValue = (id)[UIColor darkGrayColor].CGColor; - colorAnimation.toValue = (id)[UIColor whiteColor].CGColor; - colorAnimation.timingFunction = - [CAMediaTimingFunction functionWithName:kCAMediaTimingFunctionLinear]; - [self.predictionTextLayer addAnimation:colorAnimation - forKey:@"colorAnimation"]; - } else { - self.predictionTextLayer.foregroundColor = [UIColor whiteColor].CGColor; - } - - [self.predictionTextLayer removeFromSuperlayer]; - [[self.view layer] addSublayer:self.predictionTextLayer]; - [self.predictionTextLayer setString:text]; -} - -- (void)speak:(NSString *)words { - if ([synth isSpeaking]) { - return; - } - AVSpeechUtterance *utterance = - [AVSpeechUtterance speechUtteranceWithString:words]; - utterance.voice = [AVSpeechSynthesisVoice voiceWithLanguage:@"en-US"]; - utterance.rate = 0.75 * AVSpeechUtteranceDefaultSpeechRate; - [synth speakUtterance:utterance]; -} - -@end diff --git a/tensorflow/examples/ios/camera/Info.plist b/tensorflow/examples/ios/camera/Info.plist deleted file mode 100644 index 772fb38dcc9..00000000000 --- a/tensorflow/examples/ios/camera/Info.plist +++ /dev/null @@ -1,44 +0,0 @@ - - - - - CFBundleDevelopmentRegion - en - CFBundleDisplayName - tf_camera_example - CFBundleExecutable - ${EXECUTABLE_NAME} - CFBundleIdentifier - $(PRODUCT_BUNDLE_IDENTIFIER) - CFBundleInfoDictionaryVersion - 6.0 - CFBundleName - ${PRODUCT_NAME} - CFBundlePackageType - APPL - CFBundleShortVersionString - 1.0 - CFBundleSignature - ???? - CFBundleVersion - 1.0 - LSRequiresIPhoneOS - - NSCameraUsageDescription - Capture images to detect object - UIMainStoryboardFile - MainStoryboard_iPhone - UIRequiresFullScreen - - UIStatusBarHidden - - UISupportedInterfaceOrientations - - UIInterfaceOrientationPortrait - - UISupportedInterfaceOrientations~ipad - - UIInterfaceOrientationPortrait - - - diff --git a/tensorflow/examples/ios/camera/MainStoryboard_iPhone.storyboard b/tensorflow/examples/ios/camera/MainStoryboard_iPhone.storyboard deleted file mode 100644 index 0f10a22e415..00000000000 --- a/tensorflow/examples/ios/camera/MainStoryboard_iPhone.storyboard +++ /dev/null @@ -1,46 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/examples/ios/camera/Podfile b/tensorflow/examples/ios/camera/Podfile deleted file mode 100644 index 117828f0714..00000000000 --- a/tensorflow/examples/ios/camera/Podfile +++ /dev/null @@ -1,5 +0,0 @@ -platform :ios, '8.0' -inhibit_all_warnings! - -target 'tf_camera_example' - pod 'TensorFlow-experimental' diff --git a/tensorflow/examples/ios/camera/data/grace_hopper.jpg b/tensorflow/examples/ios/camera/data/grace_hopper.jpg deleted file mode 100644 index d2a427810f6..00000000000 Binary files a/tensorflow/examples/ios/camera/data/grace_hopper.jpg and /dev/null differ diff --git a/tensorflow/examples/ios/camera/ios_image_load.h b/tensorflow/examples/ios/camera/ios_image_load.h deleted file mode 100644 index 3de812c34b3..00000000000 --- a/tensorflow/examples/ios/camera/ios_image_load.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_EXAMPLES_IOS_CAMERA_IOS_IMAGE_LOAD_H_ -#define TENSORFLOW_EXAMPLES_IOS_CAMERA_IOS_IMAGE_LOAD_H_ - -#include - -#include "tensorflow/core/framework/types.h" - -std::vector LoadImageFromFile(const char* file_name, - int* out_width, - int* out_height, - int* out_channels); - -#endif // TENSORFLOW_EXAMPLES_IOS_CAMERA_IOS_IMAGE_LOAD_H_ diff --git a/tensorflow/examples/ios/camera/ios_image_load.mm b/tensorflow/examples/ios/camera/ios_image_load.mm deleted file mode 100644 index 64d1ea21cf2..00000000000 --- a/tensorflow/examples/ios/camera/ios_image_load.mm +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "ios_image_load.h" - -#include -#include -#include -#include - -#import -#import - -using tensorflow::uint8; - -std::vector LoadImageFromFile(const char* file_name, - int* out_width, int* out_height, - int* out_channels) { - FILE* file_handle = fopen(file_name, "rb"); - fseek(file_handle, 0, SEEK_END); - const size_t bytes_in_file = ftell(file_handle); - fseek(file_handle, 0, SEEK_SET); - std::vector file_data(bytes_in_file); - fread(file_data.data(), 1, bytes_in_file, file_handle); - fclose(file_handle); - CFDataRef file_data_ref = CFDataCreateWithBytesNoCopy(NULL, file_data.data(), - bytes_in_file, - kCFAllocatorNull); - CGDataProviderRef image_provider = - CGDataProviderCreateWithCFData(file_data_ref); - - const char* suffix = strrchr(file_name, '.'); - if (!suffix || suffix == file_name) { - suffix = ""; - } - CGImageRef image; - if (strcasecmp(suffix, ".png") == 0) { - image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); - } else if ((strcasecmp(suffix, ".jpg") == 0) || - (strcasecmp(suffix, ".jpeg") == 0)) { - image = CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); - } else { - CFRelease(image_provider); - CFRelease(file_data_ref); - fprintf(stderr, "Unknown suffix for file '%s'\n", file_name); - *out_width = 0; - *out_height = 0; - *out_channels = 0; - return std::vector(); - } - - const int width = (int)CGImageGetWidth(image); - const int height = (int)CGImageGetHeight(image); - const int channels = 4; - CGColorSpaceRef color_space = CGColorSpaceCreateDeviceRGB(); - const int bytes_per_row = (width * channels); - const int bytes_in_image = (bytes_per_row * height); - std::vector result(bytes_in_image); - const int bits_per_component = 8; - CGContextRef context = CGBitmapContextCreate(result.data(), width, height, - bits_per_component, bytes_per_row, color_space, - kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); - CGColorSpaceRelease(color_space); - CGContextDrawImage(context, CGRectMake(0, 0, width, height), image); - CGContextRelease(context); - CFRelease(image); - CFRelease(image_provider); - CFRelease(file_data_ref); - - *out_width = width; - *out_height = height; - *out_channels = channels; - return result; -} diff --git a/tensorflow/examples/ios/camera/main.mm b/tensorflow/examples/ios/camera/main.mm deleted file mode 100644 index 42eff697efc..00000000000 --- a/tensorflow/examples/ios/camera/main.mm +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
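Elsewhere in this diff (the deleted CameraExampleViewController.mm above), the decoded RGBA bytes are scaled down to the network's input size and normalized with (pixel - input_mean) / input_std before being copied into a float tensor. The following is a self-contained sketch of that conversion using the constants from the deleted example; NormalizeRGBAToModelInput is a hypothetical name, and the nearest-neighbour index math is written in the conventional orientation rather than the transposed form used in the original loop.

#include <vector>

// Convert an RGBA byte buffer into the flat float input expected by the
// Inception graph in the deleted example (224x224x3, mean 117, std 1).
std::vector<float> NormalizeRGBAToModelInput(const unsigned char* rgba,
                                             int width, int height) {
  const int wanted_width = 224;    // wanted_input_width in the example
  const int wanted_height = 224;   // wanted_input_height
  const int wanted_channels = 3;   // wanted_input_channels
  const float input_mean = 117.0f;
  const float input_std = 1.0f;
  const int in_channels = 4;       // LoadImageFromFile always produces RGBA

  std::vector<float> out(wanted_width * wanted_height * wanted_channels);
  for (int y = 0; y < wanted_height; ++y) {
    const int in_y = (y * height) / wanted_height;   // nearest-neighbour row
    for (int x = 0; x < wanted_width; ++x) {
      const int in_x = (x * width) / wanted_width;   // nearest-neighbour column
      const unsigned char* in_pixel =
          rgba + (in_y * width + in_x) * in_channels;
      float* out_pixel =
          out.data() + (y * wanted_width + x) * wanted_channels;
      for (int c = 0; c < wanted_channels; ++c) {
        // Drop the alpha channel and normalize each RGB component.
        out_pixel[c] = (in_pixel[c] - input_mean) / input_std;
      }
    }
  }
  return out;
}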
- -#import - -#import "CameraExampleAppDelegate.h" - -int main(int argc, char *argv[]) { - int retVal = 0; - - @autoreleasepool { - retVal = UIApplicationMain( - argc, argv, nil, NSStringFromClass([CameraExampleAppDelegate class])); - } - return retVal; -} diff --git a/tensorflow/examples/ios/camera/tensorflow_utils.h b/tensorflow/examples/ios/camera/tensorflow_utils.h deleted file mode 100644 index 78bdb82aae6..00000000000 --- a/tensorflow/examples/ios/camera/tensorflow_utils.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_TENSORFLOW_UTILS_H_ -#define TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_TENSORFLOW_UTILS_H_ - -#include -#include - -#include "tensorflow/core/public/session.h" -#include "tensorflow/core/util/memmapped_file_system.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" - -// Reads a serialized GraphDef protobuf file from the bundle, typically -// created with the freeze_graph script. Populates the session argument with a -// Session object that has the model loaded. -tensorflow::Status LoadModel(NSString* file_name, NSString* file_type, - std::unique_ptr* session); - -// Loads a model from a file that has been created using the -// convert_graphdef_memmapped_format tool. This bundles together a GraphDef -// proto together with a file that can be memory-mapped, containing the weight -// parameters for the model. This is useful because it reduces the overall -// memory pressure, since the read-only parameter regions can be easily paged -// out and don't count toward memory limits on iOS. -tensorflow::Status LoadMemoryMappedModel( - NSString* file_name, NSString* file_type, - std::unique_ptr* session, - std::unique_ptr* memmapped_env); - -// Takes a text file with a single label on each line, and returns a list. -tensorflow::Status LoadLabels(NSString* file_name, NSString* file_type, - std::vector* label_strings); - -// Sorts the results from a model execution, and returns the highest scoring. -void GetTopN(const Eigen::TensorMap, - Eigen::Aligned>& prediction, - const int num_results, const float threshold, - std::vector >* top_results); - -#endif // TENSORFLOW_CONTRIB_IOS_EXAMPLES_CAMERA_TENSORFLOW_UTILS_H_ diff --git a/tensorflow/examples/ios/camera/tensorflow_utils.mm b/tensorflow/examples/ios/camera/tensorflow_utils.mm deleted file mode 100644 index 56d1e53081d..00000000000 --- a/tensorflow/examples/ios/camera/tensorflow_utils.mm +++ /dev/null @@ -1,219 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import - -#include "tensorflow_utils.h" - -#include -#include -#include -#include -#include -#include - -namespace { - -// Helper class used to load protobufs efficiently. -class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream { - public: - explicit IfstreamInputStream(const std::string& file_name) - : ifs_(file_name.c_str(), std::ios::in | std::ios::binary) {} - ~IfstreamInputStream() { ifs_.close(); } - - int Read(void* buffer, int size) { - if (!ifs_) { - return -1; - } - ifs_.read(static_cast(buffer), size); - return ifs_.gcount(); - } - - private: - std::ifstream ifs_; -}; -} // namespace - -// Returns the top N confidence values over threshold in the provided vector, -// sorted by confidence in descending order. -void GetTopN(const Eigen::TensorMap, - Eigen::Aligned>& prediction, - const int num_results, const float threshold, - std::vector >* top_results) { - // Will contain top N results in ascending order. - std::priority_queue, - std::vector >, - std::greater > > - top_result_pq; - - const int count = prediction.size(); - for (int i = 0; i < count; ++i) { - const float value = prediction(i); - - // Only add it if it beats the threshold and has a chance at being in - // the top N. - if (value < threshold) { - continue; - } - - top_result_pq.push(std::pair(value, i)); - - // If at capacity, kick the smallest value out. - if (top_result_pq.size() > num_results) { - top_result_pq.pop(); - } - } - - // Copy to output vector and reverse into descending order. - while (!top_result_pq.empty()) { - top_results->push_back(top_result_pq.top()); - top_result_pq.pop(); - } - std::reverse(top_results->begin(), top_results->end()); -} - -bool PortableReadFileToProto(const std::string& file_name, - ::google::protobuf::MessageLite* proto) { - ::google::protobuf::io::CopyingInputStreamAdaptor stream( - new IfstreamInputStream(file_name)); - stream.SetOwnsCopyingStream(true); - ::google::protobuf::io::CodedInputStream coded_stream(&stream); - // Total bytes hard limit / warning limit are set to 1GB and 512MB - // respectively. - coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20); - return proto->ParseFromCodedStream(&coded_stream); -} - -NSString* FilePathForResourceName(NSString* name, NSString* extension) { - NSString* file_path = - [[NSBundle mainBundle] pathForResource:name ofType:extension]; - if (file_path == NULL) { - LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." 
- << [extension UTF8String] << "' in bundle."; - return nullptr; - } - return file_path; -} - -tensorflow::Status LoadModel(NSString* file_name, NSString* file_type, - std::unique_ptr* session) { - tensorflow::SessionOptions options; - - tensorflow::Session* session_pointer = nullptr; - tensorflow::Status session_status = - tensorflow::NewSession(options, &session_pointer); - if (!session_status.ok()) { - LOG(ERROR) << "Could not create TensorFlow Session: " << session_status; - return session_status; - } - session->reset(session_pointer); - - tensorflow::GraphDef tensorflow_graph; - - NSString* model_path = FilePathForResourceName(file_name, file_type); - if (!model_path) { - LOG(ERROR) << "Failed to find model proto at" << [file_name UTF8String] - << [file_type UTF8String]; - return tensorflow::errors::NotFound([file_name UTF8String], - [file_type UTF8String]); - } - const bool read_proto_succeeded = - PortableReadFileToProto([model_path UTF8String], &tensorflow_graph); - if (!read_proto_succeeded) { - LOG(ERROR) << "Failed to load model proto from" << [model_path UTF8String]; - return tensorflow::errors::NotFound([model_path UTF8String]); - } - - tensorflow::Status create_status = (*session)->Create(tensorflow_graph); - if (!create_status.ok()) { - LOG(ERROR) << "Could not create TensorFlow Graph: " << create_status; - return create_status; - } - - return tensorflow::Status::OK(); -} - -tensorflow::Status LoadMemoryMappedModel( - NSString* file_name, NSString* file_type, - std::unique_ptr* session, - std::unique_ptr* memmapped_env) { - NSString* network_path = FilePathForResourceName(file_name, file_type); - memmapped_env->reset( - new tensorflow::MemmappedEnv(tensorflow::Env::Default())); - tensorflow::Status mmap_status = - (memmapped_env->get())->InitializeFromFile([network_path UTF8String]); - if (!mmap_status.ok()) { - LOG(ERROR) << "MMap failed with " << mmap_status.error_message(); - return mmap_status; - } - - tensorflow::GraphDef tensorflow_graph; - tensorflow::Status load_graph_status = ReadBinaryProto( - memmapped_env->get(), - tensorflow::MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, - &tensorflow_graph); - if (!load_graph_status.ok()) { - LOG(ERROR) << "MMap load graph failed with " - << load_graph_status.error_message(); - return load_graph_status; - } - - tensorflow::SessionOptions options; - // Disable optimizations on this graph so that constant folding doesn't - // increase the memory footprint by creating new constant copies of the weight - // parameters. 
- options.config.mutable_graph_options() - ->mutable_optimizer_options() - ->set_opt_level(::tensorflow::OptimizerOptions::L0); - options.env = memmapped_env->get(); - - tensorflow::Session* session_pointer = nullptr; - tensorflow::Status session_status = - tensorflow::NewSession(options, &session_pointer); - if (!session_status.ok()) { - LOG(ERROR) << "Could not create TensorFlow Session: " << session_status; - return session_status; - } - - tensorflow::Status create_status = session_pointer->Create(tensorflow_graph); - if (!create_status.ok()) { - LOG(ERROR) << "Could not create TensorFlow Graph: " << create_status; - return create_status; - } - - session->reset(session_pointer); - - return tensorflow::Status::OK(); -} - -tensorflow::Status LoadLabels(NSString* file_name, NSString* file_type, - std::vector* label_strings) { - // Read the label list - NSString* labels_path = FilePathForResourceName(file_name, file_type); - if (!labels_path) { - LOG(ERROR) << "Failed to find model proto at" << [file_name UTF8String] - << [file_type UTF8String]; - return tensorflow::errors::NotFound([file_name UTF8String], - [file_type UTF8String]); - } - std::ifstream t; - t.open([labels_path UTF8String]); - std::string line; - while (t) { - std::getline(t, line); - label_strings->push_back(line); - } - t.close(); - return tensorflow::Status::OK(); -} diff --git a/tensorflow/examples/ios/camera/tf_camera_example.xcodeproj/project.pbxproj b/tensorflow/examples/ios/camera/tf_camera_example.xcodeproj/project.pbxproj deleted file mode 100644 index ee9fe57c792..00000000000 --- a/tensorflow/examples/ios/camera/tf_camera_example.xcodeproj/project.pbxproj +++ /dev/null @@ -1,412 +0,0 @@ -// !$*UTF8*$! -{ - archiveVersion = 1; - classes = { - }; - objectVersion = 46; - objects = { - -/* Begin PBXBuildFile section */ - 1C3C9DCB1ED3AB4200B8B5FA /* ios_image_load.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1C3C9DC91ED3AB4200B8B5FA /* ios_image_load.mm */; }; - 1C3C9DCC1ED3AB4200B8B5FA /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */; }; - 1C968D171ED3B8F20054F5C3 /* grace_hopper.jpg in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */; }; - 1C968D181ED3B8F20054F5C3 /* imagenet_comp_graph_label_strings.txt in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */; }; - 1C968D191ED3B8F20054F5C3 /* tensorflow_inception_graph.pb in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */; }; - 1C99111C1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */; }; - 1CA5EB931ED3ABFB00247A34 /* CoreMedia.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */; }; - 1CB47D491ED3AD1700DF7666 /* AVFoundation.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */; }; - 1CDB2D491ED3A9CD007929E9 /* CameraExampleAppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */; }; - 1CDB2D4A1ED3A9CD007929E9 /* CameraExampleViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */; }; - 1CDB2D4C1ED3A9CD007929E9 /* tensorflow_utils.mm in Sources */ = {isa = PBXBuildFile; fileRef = 
1CDB2D481ED3A9CD007929E9 /* tensorflow_utils.mm */; }; - 1CDB2D4E1ED3AA35007929E9 /* Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 1CDB2D4D1ED3AA35007929E9 /* Info.plist */; }; - 54DC6C3C5F734F3A58069F0C /* libPods-tf_camera_example.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 3BA8BF92C84895BFE59D8236 /* libPods-tf_camera_example.a */; }; -/* End PBXBuildFile section */ - -/* Begin PBXFileReference section */ - 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreImage.framework; path = System/Library/Frameworks/CoreImage.framework; sourceTree = SDKROOT; }; - 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; }; - 1C3C9DC81ED3AB4200B8B5FA /* ios_image_load.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ios_image_load.h; sourceTree = ""; }; - 1C3C9DC91ED3AB4200B8B5FA /* ios_image_load.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_image_load.mm; sourceTree = ""; }; - 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = ""; }; - 1C564C0D1ED3A92E00087306 /* tf_camera_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tf_camera_example.app; sourceTree = BUILT_PRODUCTS_DIR; }; - 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.storyboard; path = MainStoryboard_iPhone.storyboard; sourceTree = ""; }; - 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; }; - 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreMedia.framework; path = System/Library/Frameworks/CoreMedia.framework; sourceTree = SDKROOT; }; - 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = AVFoundation.framework; path = System/Library/Frameworks/AVFoundation.framework; sourceTree = SDKROOT; }; - 1CDB2D421ED3A9CD007929E9 /* CameraExampleAppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleAppDelegate.h; sourceTree = ""; }; - 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = CameraExampleAppDelegate.m; sourceTree = ""; }; - 1CDB2D441ED3A9CD007929E9 /* CameraExampleViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleViewController.h; sourceTree = ""; }; - 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = CameraExampleViewController.mm; sourceTree = ""; }; - 1CDB2D471ED3A9CD007929E9 /* tensorflow_utils.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = tensorflow_utils.h; sourceTree = ""; }; - 1CDB2D481ED3A9CD007929E9 /* tensorflow_utils.mm */ 
= {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = tensorflow_utils.mm; sourceTree = ""; }; - 1CDB2D4D1ED3AA35007929E9 /* Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; - 3BA8BF92C84895BFE59D8236 /* libPods-tf_camera_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tf_camera_example.a"; sourceTree = BUILT_PRODUCTS_DIR; }; - 3BC5BE4BBD09374D3E98F082 /* Pods-tf_camera_example.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tf_camera_example.debug.xcconfig"; path = "Pods/Target Support Files/Pods-tf_camera_example/Pods-tf_camera_example.debug.xcconfig"; sourceTree = ""; }; - 55ED318E8D29C8AFEF03DF1E /* Pods-tf_camera_example.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tf_camera_example.release.xcconfig"; path = "Pods/Target Support Files/Pods-tf_camera_example/Pods-tf_camera_example.release.xcconfig"; sourceTree = ""; }; - 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = grace_hopper.jpg; sourceTree = ""; }; - 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = imagenet_comp_graph_label_strings.txt; sourceTree = ""; }; - 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */ = {isa = PBXFileReference; lastKnownFileType = file; path = tensorflow_inception_graph.pb; sourceTree = ""; }; -/* End PBXFileReference section */ - -/* Begin PBXFrameworksBuildPhase section */ - 1C564C0A1ED3A92E00087306 /* Frameworks */ = { - isa = PBXFrameworksBuildPhase; - buildActionMask = 2147483647; - files = ( - 1CB47D491ED3AD1700DF7666 /* AVFoundation.framework in Frameworks */, - 1CA5EB931ED3ABFB00247A34 /* CoreMedia.framework in Frameworks */, - 54DC6C3C5F734F3A58069F0C /* libPods-tf_camera_example.a in Frameworks */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXFrameworksBuildPhase section */ - -/* Begin PBXGroup section */ - 24D7686C331131624F4454A0 /* Frameworks */ = { - isa = PBXGroup; - children = ( - 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */, - 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */, - 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */, - 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */, - 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */, - 3BA8BF92C84895BFE59D8236 /* libPods-tf_camera_example.a */, - ); - name = Frameworks; - sourceTree = ""; - }; - 3E9FC355632FB928EA23BEED /* Pods */ = { - isa = PBXGroup; - children = ( - 3BC5BE4BBD09374D3E98F082 /* Pods-tf_camera_example.debug.xcconfig */, - 55ED318E8D29C8AFEF03DF1E /* Pods-tf_camera_example.release.xcconfig */, - ); - name = Pods; - sourceTree = ""; - }; - 591157921CF4011C00C31E3A = { - isa = PBXGroup; - children = ( - 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */, - 1C3C9DC81ED3AB4200B8B5FA /* ios_image_load.h */, - 1C3C9DC91ED3AB4200B8B5FA /* ios_image_load.mm */, - 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */, - 1CDB2D4D1ED3AA35007929E9 /* Info.plist */, - 1CDB2D421ED3A9CD007929E9 /* CameraExampleAppDelegate.h */, - 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */, - 1CDB2D441ED3A9CD007929E9 /* CameraExampleViewController.h */, - 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */, - 
1CDB2D471ED3A9CD007929E9 /* tensorflow_utils.h */, - 1CDB2D481ED3A9CD007929E9 /* tensorflow_utils.mm */, - 59A3CFF31CF4E68100C4259F /* data */, - 5911579C1CF4011C00C31E3A /* Products */, - 3E9FC355632FB928EA23BEED /* Pods */, - 24D7686C331131624F4454A0 /* Frameworks */, - ); - sourceTree = ""; - }; - 5911579C1CF4011C00C31E3A /* Products */ = { - isa = PBXGroup; - children = ( - 1C564C0D1ED3A92E00087306 /* tf_camera_example.app */, - ); - name = Products; - sourceTree = ""; - }; - 59A3CFF31CF4E68100C4259F /* data */ = { - isa = PBXGroup; - children = ( - 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */, - 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */, - 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */, - ); - path = data; - sourceTree = ""; - }; -/* End PBXGroup section */ - -/* Begin PBXNativeTarget section */ - 1C564C0C1ED3A92E00087306 /* tf_camera_example */ = { - isa = PBXNativeTarget; - buildConfigurationList = 1C564C351ED3A92E00087306 /* Build configuration list for PBXNativeTarget "tf_camera_example" */; - buildPhases = ( - 66DAEAAEE9EF6550C3A061E0 /* [CP] Check Pods Manifest.lock */, - 1C564C091ED3A92E00087306 /* Sources */, - 1C564C0A1ED3A92E00087306 /* Frameworks */, - 1C564C0B1ED3A92E00087306 /* Resources */, - 00E875C3B066535AE6B77101 /* [CP] Embed Pods Frameworks */, - 5C2D02120E3E5E09567AA946 /* [CP] Copy Pods Resources */, - ); - buildRules = ( - ); - dependencies = ( - ); - name = tf_camera_example; - productName = tf_camera_example; - productReference = 1C564C0D1ED3A92E00087306 /* tf_camera_example.app */; - productType = "com.apple.product-type.application"; - }; -/* End PBXNativeTarget section */ - -/* Begin PBXProject section */ - 591157931CF4011C00C31E3A /* Project object */ = { - isa = PBXProject; - attributes = { - LastSwiftUpdateCheck = 0830; - LastUpgradeCheck = 0830; - ORGANIZATIONNAME = Google; - TargetAttributes = { - 1C564C0C1ED3A92E00087306 = { - CreatedOnToolsVersion = 8.3.2; - DevelopmentTeam = 5DRPWFQSHP; - ProvisioningStyle = Automatic; - }; - }; - }; - buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_camera_example" */; - compatibilityVersion = "Xcode 3.2"; - developmentRegion = English; - hasScannedForEncodings = 0; - knownRegions = ( - en, - Base, - ); - mainGroup = 591157921CF4011C00C31E3A; - productRefGroup = 5911579C1CF4011C00C31E3A /* Products */; - projectDirPath = ""; - projectRoot = ""; - targets = ( - 1C564C0C1ED3A92E00087306 /* tf_camera_example */, - ); - }; -/* End PBXProject section */ - -/* Begin PBXResourcesBuildPhase section */ - 1C564C0B1ED3A92E00087306 /* Resources */ = { - isa = PBXResourcesBuildPhase; - buildActionMask = 2147483647; - files = ( - 1C968D171ED3B8F20054F5C3 /* grace_hopper.jpg in Resources */, - 1C968D181ED3B8F20054F5C3 /* imagenet_comp_graph_label_strings.txt in Resources */, - 1C968D191ED3B8F20054F5C3 /* tensorflow_inception_graph.pb in Resources */, - 1C99111C1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard in Resources */, - 1CDB2D4E1ED3AA35007929E9 /* Info.plist in Resources */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXResourcesBuildPhase section */ - -/* Begin PBXShellScriptBuildPhase section */ - 00E875C3B066535AE6B77101 /* [CP] Embed Pods Frameworks */ = { - isa = PBXShellScriptBuildPhase; - buildActionMask = 2147483647; - files = ( - ); - inputPaths = ( - ); - name = "[CP] Embed Pods Frameworks"; - outputPaths = ( - ); - runOnlyForDeploymentPostprocessing = 0; - shellPath = /bin/sh; - shellScript = 
"\"${SRCROOT}/Pods/Target Support Files/Pods-tf_camera_example/Pods-tf_camera_example-frameworks.sh\"\n"; - showEnvVarsInLog = 0; - }; - 5C2D02120E3E5E09567AA946 /* [CP] Copy Pods Resources */ = { - isa = PBXShellScriptBuildPhase; - buildActionMask = 2147483647; - files = ( - ); - inputPaths = ( - ); - name = "[CP] Copy Pods Resources"; - outputPaths = ( - ); - runOnlyForDeploymentPostprocessing = 0; - shellPath = /bin/sh; - shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tf_camera_example/Pods-tf_camera_example-resources.sh\"\n"; - showEnvVarsInLog = 0; - }; - 66DAEAAEE9EF6550C3A061E0 /* [CP] Check Pods Manifest.lock */ = { - isa = PBXShellScriptBuildPhase; - buildActionMask = 2147483647; - files = ( - ); - inputPaths = ( - ); - name = "[CP] Check Pods Manifest.lock"; - outputPaths = ( - ); - runOnlyForDeploymentPostprocessing = 0; - shellPath = /bin/sh; - shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n"; - showEnvVarsInLog = 0; - }; -/* End PBXShellScriptBuildPhase section */ - -/* Begin PBXSourcesBuildPhase section */ - 1C564C091ED3A92E00087306 /* Sources */ = { - isa = PBXSourcesBuildPhase; - buildActionMask = 2147483647; - files = ( - 1CDB2D4C1ED3A9CD007929E9 /* tensorflow_utils.mm in Sources */, - 1C3C9DCB1ED3AB4200B8B5FA /* ios_image_load.mm in Sources */, - 1CDB2D4A1ED3A9CD007929E9 /* CameraExampleViewController.mm in Sources */, - 1CDB2D491ED3A9CD007929E9 /* CameraExampleAppDelegate.m in Sources */, - 1C3C9DCC1ED3AB4200B8B5FA /* main.mm in Sources */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXSourcesBuildPhase section */ - -/* Begin XCBuildConfiguration section */ - 1C564C361ED3A92E00087306 /* Debug */ = { - isa = XCBuildConfiguration; - baseConfigurationReference = 3BC5BE4BBD09374D3E98F082 /* Pods-tf_camera_example.debug.xcconfig */; - buildSettings = { - ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; - CLANG_ANALYZER_NONNULL = YES; - CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; - CLANG_WARN_DOCUMENTATION_COMMENTS = YES; - DEVELOPMENT_TEAM = 5DRPWFQSHP; - INFOPLIST_FILE = Info.plist; - IPHONEOS_DEPLOYMENT_TARGET = 10.3; - LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; - PRODUCT_BUNDLE_IDENTIFIER = "com.pf.tf-camera-example"; - PRODUCT_NAME = "$(TARGET_NAME)"; - SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG; - SWIFT_OPTIMIZATION_LEVEL = "-Onone"; - SWIFT_VERSION = 3.0; - }; - name = Debug; - }; - 1C564C371ED3A92E00087306 /* Release */ = { - isa = XCBuildConfiguration; - baseConfigurationReference = 55ED318E8D29C8AFEF03DF1E /* Pods-tf_camera_example.release.xcconfig */; - buildSettings = { - ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; - CLANG_ANALYZER_NONNULL = YES; - CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; - CLANG_WARN_DOCUMENTATION_COMMENTS = YES; - DEVELOPMENT_TEAM = 5DRPWFQSHP; - INFOPLIST_FILE = Info.plist; - IPHONEOS_DEPLOYMENT_TARGET = 10.3; - LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; - PRODUCT_BUNDLE_IDENTIFIER = "com.pf.tf-camera-example"; - PRODUCT_NAME = "$(TARGET_NAME)"; - SWIFT_OPTIMIZATION_LEVEL = "-Owholemodule"; - SWIFT_VERSION = 3.0; - }; - name = Release; - }; - 591157B01CF4011D00C31E3A /* Debug */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - 
CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; - CLANG_CXX_LIBRARY = "libc++"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INFINITE_RECURSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; - COPY_PHASE_STRIP = NO; - DEBUG_INFORMATION_FORMAT = dwarf; - ENABLE_STRICT_OBJC_MSGSEND = YES; - ENABLE_TESTABILITY = YES; - GCC_C_LANGUAGE_STANDARD = gnu99; - GCC_DYNAMIC_NO_PIC = NO; - GCC_NO_COMMON_BLOCKS = YES; - GCC_OPTIMIZATION_LEVEL = 0; - GCC_PREPROCESSOR_DEFINITIONS = ( - "DEBUG=1", - "$(inherited)", - ); - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 8.0; - MTL_ENABLE_DEBUG_INFO = YES; - ONLY_ACTIVE_ARCH = YES; - SDKROOT = iphoneos; - TARGETED_DEVICE_FAMILY = "1,2"; - }; - name = Debug; - }; - 591157B11CF4011D00C31E3A /* Release */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; - CLANG_CXX_LIBRARY = "libc++"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INFINITE_RECURSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; - COPY_PHASE_STRIP = NO; - DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; - ENABLE_NS_ASSERTIONS = NO; - ENABLE_STRICT_OBJC_MSGSEND = YES; - GCC_C_LANGUAGE_STANDARD = gnu99; - GCC_NO_COMMON_BLOCKS = YES; - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 8.0; - MTL_ENABLE_DEBUG_INFO = NO; - SDKROOT = iphoneos; - TARGETED_DEVICE_FAMILY = "1,2"; - VALIDATE_PRODUCT = YES; - }; - name = Release; - }; -/* End XCBuildConfiguration section */ - -/* Begin XCConfigurationList section */ - 1C564C351ED3A92E00087306 /* Build configuration list for PBXNativeTarget "tf_camera_example" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - 1C564C361ED3A92E00087306 /* Debug */, - 1C564C371ED3A92E00087306 /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; - 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_camera_example" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - 591157B01CF4011D00C31E3A /* Debug */, - 591157B11CF4011D00C31E3A /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; -/* End XCConfigurationList section */ - }; - rootObject = 591157931CF4011C00C31E3A /* Project object */; -} 
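For reference, a minimal sketch of how the tensorflow_utils helpers deleted above (LoadModel, LoadMemoryMappedModel, LoadLabels) are typically called from the camera example. This sketch is not part of the diff: the std::unique_ptr and std::vector element types are inferred from the helper bodies, and the resource names are taken from the bundled data files listed in the project, so treat them as assumptions.

    // Hedged Objective-C++ usage sketch; types inferred from the deleted helpers.
    std::unique_ptr<tensorflow::Session> session;
    std::vector<std::string> labels;

    // Load the frozen Inception graph bundled with the example into a new Session.
    tensorflow::Status load_status =
        LoadModel(@"tensorflow_inception_graph", @"pb", &session);
    if (!load_status.ok()) {
      LOG(ERROR) << "Couldn't load model: " << load_status;
    }

    // LoadMemoryMappedModel() is the alternative entry point for graphs converted
    // to the memory-mapped format; it additionally fills in a
    // std::unique_ptr<tensorflow::MemmappedEnv> that must outlive the Session.

    // Read the label strings used to annotate the top predictions.
    tensorflow::Status labels_status =
        LoadLabels(@"imagenet_comp_graph_label_strings", @"txt", &labels);
    if (!labels_status.ok()) {
      LOG(ERROR) << "Couldn't load labels: " << labels_status;
    }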
diff --git a/tensorflow/examples/ios/simple/AppDelegate.h b/tensorflow/examples/ios/simple/AppDelegate.h deleted file mode 100644 index 75b1f1da384..00000000000 --- a/tensorflow/examples/ios/simple/AppDelegate.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import - -@interface AppDelegate : UIResponder - -@property (strong, nonatomic) UIWindow *window; - -@end diff --git a/tensorflow/examples/ios/simple/AppDelegate.mm b/tensorflow/examples/ios/simple/AppDelegate.mm deleted file mode 100644 index 1e808eb976f..00000000000 --- a/tensorflow/examples/ios/simple/AppDelegate.mm +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import "AppDelegate.h" - -#import "RunModelViewController.h" - -@implementation AppDelegate - -- (BOOL)application:(UIApplication *)application - didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { - - UITabBarController *bar = [[UITabBarController alloc] init]; - [bar setViewControllers: - @[[[RunModelViewController alloc] init]]]; - bar.selectedIndex = 0; - self.window = [[UIWindow alloc] initWithFrame:[[UIScreen mainScreen] bounds]]; - self.window.rootViewController = bar; - [self.window makeKeyAndVisible]; - return YES; -} - -- (void)applicationWillResignActive:(UIApplication *)application {} - -- (void)applicationDidEnterBackground:(UIApplication *)application {} - -- (void)applicationWillEnterForeground:(UIApplication *)application {} - -- (void)applicationDidBecomeActive:(UIApplication *)application {} - -- (void)applicationWillTerminate:(UIApplication *)application {} - -@end diff --git a/tensorflow/examples/ios/simple/Podfile b/tensorflow/examples/ios/simple/Podfile deleted file mode 100644 index 1740ad64573..00000000000 --- a/tensorflow/examples/ios/simple/Podfile +++ /dev/null @@ -1,5 +0,0 @@ -platform :ios, '8.0' -inhibit_all_warnings! 
- -target 'tf_simple_example' - pod 'TensorFlow-experimental' diff --git a/tensorflow/examples/ios/simple/RunModel-Info.plist b/tensorflow/examples/ios/simple/RunModel-Info.plist deleted file mode 100644 index d0a8742456f..00000000000 --- a/tensorflow/examples/ios/simple/RunModel-Info.plist +++ /dev/null @@ -1,47 +0,0 @@ - - - - - CFBundleDevelopmentRegion - en - CFBundleDisplayName - tf_simple_example - CFBundleExecutable - tf_simple_example - CFBundleIdentifier - $(PRODUCT_BUNDLE_IDENTIFIER) - CFBundleInfoDictionaryVersion - 6.0 - CFBundleName - ios-app - CFBundlePackageType - APPL - CFBundleShortVersionString - 1.0 - CFBundleSignature - ???? - CFBundleVersion - 1.0 - LSRequiresIPhoneOS - - UILaunchStoryboardName - RunModelViewController - UIRequiredDeviceCapabilities - - armv7 - - UISupportedInterfaceOrientations - - UIInterfaceOrientationPortrait - UIInterfaceOrientationLandscapeLeft - UIInterfaceOrientationLandscapeRight - - UISupportedInterfaceOrientations~ipad - - UIInterfaceOrientationPortrait - UIInterfaceOrientationPortraitUpsideDown - UIInterfaceOrientationLandscapeLeft - UIInterfaceOrientationLandscapeRight - - - diff --git a/tensorflow/examples/ios/simple/RunModelViewController.h b/tensorflow/examples/ios/simple/RunModelViewController.h deleted file mode 100644 index 4e1a83ccf5a..00000000000 --- a/tensorflow/examples/ios/simple/RunModelViewController.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import - -@interface RunModelViewController : UIViewController - -- (IBAction)getUrl:(id)sender; - -@property (weak, nonatomic) IBOutlet UITextView *urlContentTextView; -@property (weak, nonatomic) IBOutlet UITextField *urlTextField; - -@end diff --git a/tensorflow/examples/ios/simple/RunModelViewController.mm b/tensorflow/examples/ios/simple/RunModelViewController.mm deleted file mode 100644 index c8ccb5c77b2..00000000000 --- a/tensorflow/examples/ios/simple/RunModelViewController.mm +++ /dev/null @@ -1,253 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#import "RunModelViewController.h" - -#include -#include -#include -#include -#include -#include - -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/public/session.h" - -#include "ios_image_load.h" - -NSString* RunInferenceOnImage(); - -namespace { -class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream { - public: - explicit IfstreamInputStream(const std::string& file_name) - : ifs_(file_name.c_str(), std::ios::in | std::ios::binary) {} - ~IfstreamInputStream() { ifs_.close(); } - - int Read(void* buffer, int size) { - if (!ifs_) { - return -1; - } - ifs_.read(static_cast(buffer), size); - return (int)ifs_.gcount(); - } - - private: - std::ifstream ifs_; -}; -} // namespace - -@interface RunModelViewController () -@end - -@implementation RunModelViewController { -} - -- (IBAction)getUrl:(id)sender { - NSString* inference_result = RunInferenceOnImage(); - self.urlContentTextView.text = inference_result; -} - -@end - -// Returns the top N confidence values over threshold in the provided vector, -// sorted by confidence in descending order. -static void GetTopN( - const Eigen::TensorMap, - Eigen::Aligned>& prediction, - const int num_results, const float threshold, - std::vector >* top_results) { - // Will contain top N results in ascending order. - std::priority_queue, - std::vector >, - std::greater > > top_result_pq; - - const long count = prediction.size(); - for (int i = 0; i < count; ++i) { - const float value = prediction(i); - - // Only add it if it beats the threshold and has a chance at being in - // the top N. - if (value < threshold) { - continue; - } - - top_result_pq.push(std::pair(value, i)); - - // If at capacity, kick the smallest value out. - if (top_result_pq.size() > num_results) { - top_result_pq.pop(); - } - } - - // Copy to output vector and reverse into descending order. - while (!top_result_pq.empty()) { - top_results->push_back(top_result_pq.top()); - top_result_pq.pop(); - } - std::reverse(top_results->begin(), top_results->end()); -} - - -bool PortableReadFileToProto(const std::string& file_name, - ::google::protobuf::MessageLite* proto) { - ::google::protobuf::io::CopyingInputStreamAdaptor stream( - new IfstreamInputStream(file_name)); - stream.SetOwnsCopyingStream(true); - // TODO(jiayq): the following coded stream is for debugging purposes to allow - // one to parse arbitrarily large messages for MessageLite. One most likely - // doesn't want to put protobufs larger than 64MB on Android, so we should - // eventually remove this and quit loud when a large protobuf is passed in. - ::google::protobuf::io::CodedInputStream coded_stream(&stream); - // Total bytes hard limit / warning limit are set to 1GB and 512MB - // respectively. - coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20); - return proto->ParseFromCodedStream(&coded_stream); -} - -NSString* FilePathForResourceName(NSString* name, NSString* extension) { - NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; - if (file_path == NULL) { - LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." 
- << [extension UTF8String] << "' in bundle."; - } - return file_path; -} - -NSString* RunInferenceOnImage() { - tensorflow::SessionOptions options; - - tensorflow::Session* session_pointer = nullptr; - tensorflow::Status session_status = tensorflow::NewSession(options, &session_pointer); - if (!session_status.ok()) { - std::string status_string = session_status.ToString(); - return [NSString stringWithFormat: @"Session create failed - %s", - status_string.c_str()]; - } - std::unique_ptr session(session_pointer); - LOG(INFO) << "Session created."; - - tensorflow::GraphDef tensorflow_graph; - LOG(INFO) << "Graph created."; - - NSString* network_path = FilePathForResourceName(@"tensorflow_inception_graph", @"pb"); - PortableReadFileToProto([network_path UTF8String], &tensorflow_graph); - - LOG(INFO) << "Creating session."; - tensorflow::Status s = session->Create(tensorflow_graph); - if (!s.ok()) { - LOG(ERROR) << "Could not create TensorFlow Graph: " << s; - return @""; - } - - // Read the label list - NSString* labels_path = FilePathForResourceName(@"imagenet_comp_graph_label_strings", @"txt"); - std::vector label_strings; - std::ifstream t; - t.open([labels_path UTF8String]); - std::string line; - while(t){ - std::getline(t, line); - label_strings.push_back(line); - } - t.close(); - - // Read the Grace Hopper image. - NSString* image_path = FilePathForResourceName(@"grace_hopper", @"jpg"); - int image_width; - int image_height; - int image_channels; - std::vector image_data = LoadImageFromFile( - [image_path UTF8String], &image_width, &image_height, &image_channels); - const int wanted_width = 224; - const int wanted_height = 224; - const int wanted_channels = 3; - const float input_mean = 117.0f; - const float input_std = 1.0f; - assert(image_channels >= wanted_channels); - tensorflow::Tensor image_tensor( - tensorflow::DT_FLOAT, - tensorflow::TensorShape({ - 1, wanted_height, wanted_width, wanted_channels})); - auto image_tensor_mapped = image_tensor.tensor(); - tensorflow::uint8* in = image_data.data(); - // tensorflow::uint8* in_end = (in + (image_height * image_width * image_channels)); - float* out = image_tensor_mapped.data(); - for (int y = 0; y < wanted_height; ++y) { - const int in_y = (y * image_height) / wanted_height; - tensorflow::uint8* in_row = in + (in_y * image_width * image_channels); - float* out_row = out + (y * wanted_width * wanted_channels); - for (int x = 0; x < wanted_width; ++x) { - const int in_x = (x * image_width) / wanted_width; - tensorflow::uint8* in_pixel = in_row + (in_x * image_channels); - float* out_pixel = out_row + (x * wanted_channels); - for (int c = 0; c < wanted_channels; ++c) { - out_pixel[c] = (in_pixel[c] - input_mean) / input_std; - } - } - } - - NSString* result = [network_path stringByAppendingString: @" - loaded!"]; - result = [NSString stringWithFormat: @"%@ - %lu, %s - %dx%d", result, - label_strings.size(), label_strings[0].c_str(), image_width, image_height]; - - std::string input_layer = "input"; - std::string output_layer = "output"; - std::vector outputs; - tensorflow::Status run_status = session->Run({{input_layer, image_tensor}}, - {output_layer}, {}, &outputs); - if (!run_status.ok()) { - LOG(ERROR) << "Running model failed: " << run_status; - tensorflow::LogAllRegisteredKernels(); - result = @"Error running model"; - return result; - } - tensorflow::string status_string = run_status.ToString(); - result = [NSString stringWithFormat: @"%@ - %s", result, - status_string.c_str()]; - - tensorflow::Tensor* output = &outputs[0]; - 
const int kNumResults = 5; - const float kThreshold = 0.1f; - std::vector > top_results; - GetTopN(output->flat(), kNumResults, kThreshold, &top_results); - - std::stringstream ss; - ss.precision(3); - for (const auto& result : top_results) { - const float confidence = result.first; - const int index = result.second; - - ss << index << " " << confidence << " "; - - // Write out the result as a string - if (index < label_strings.size()) { - // just for safety: theoretically, the output is under 1000 unless there - // is some numerical issues leading to a wrong prediction. - ss << label_strings[index]; - } else { - ss << "Prediction: " << index; - } - - ss << "\n"; - } - - LOG(INFO) << "Predictions: " << ss.str(); - - tensorflow::string predictions = ss.str(); - result = [NSString stringWithFormat: @"%@ - %s", result, - predictions.c_str()]; - - return result; -} diff --git a/tensorflow/examples/ios/simple/RunModelViewController.xib b/tensorflow/examples/ios/simple/RunModelViewController.xib deleted file mode 100644 index 93f334b9850..00000000000 --- a/tensorflow/examples/ios/simple/RunModelViewController.xib +++ /dev/null @@ -1,46 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/examples/ios/simple/data/grace_hopper.jpg b/tensorflow/examples/ios/simple/data/grace_hopper.jpg deleted file mode 100644 index d2a427810f6..00000000000 Binary files a/tensorflow/examples/ios/simple/data/grace_hopper.jpg and /dev/null differ diff --git a/tensorflow/examples/ios/simple/ios_image_load.h b/tensorflow/examples/ios/simple/ios_image_load.h deleted file mode 100644 index 0e0b771118b..00000000000 --- a/tensorflow/examples/ios/simple/ios_image_load.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ -#define TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ - -#include - -#include "tensorflow/core/framework/types.h" - -std::vector LoadImageFromFile(const char* file_name, - int* out_width, - int* out_height, - int* out_channels); - -#endif // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ diff --git a/tensorflow/examples/ios/simple/ios_image_load.mm b/tensorflow/examples/ios/simple/ios_image_load.mm deleted file mode 100644 index 64d1ea21cf2..00000000000 --- a/tensorflow/examples/ios/simple/ios_image_load.mm +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ios_image_load.h" - -#include -#include -#include -#include - -#import -#import - -using tensorflow::uint8; - -std::vector LoadImageFromFile(const char* file_name, - int* out_width, int* out_height, - int* out_channels) { - FILE* file_handle = fopen(file_name, "rb"); - fseek(file_handle, 0, SEEK_END); - const size_t bytes_in_file = ftell(file_handle); - fseek(file_handle, 0, SEEK_SET); - std::vector file_data(bytes_in_file); - fread(file_data.data(), 1, bytes_in_file, file_handle); - fclose(file_handle); - CFDataRef file_data_ref = CFDataCreateWithBytesNoCopy(NULL, file_data.data(), - bytes_in_file, - kCFAllocatorNull); - CGDataProviderRef image_provider = - CGDataProviderCreateWithCFData(file_data_ref); - - const char* suffix = strrchr(file_name, '.'); - if (!suffix || suffix == file_name) { - suffix = ""; - } - CGImageRef image; - if (strcasecmp(suffix, ".png") == 0) { - image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); - } else if ((strcasecmp(suffix, ".jpg") == 0) || - (strcasecmp(suffix, ".jpeg") == 0)) { - image = CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); - } else { - CFRelease(image_provider); - CFRelease(file_data_ref); - fprintf(stderr, "Unknown suffix for file '%s'\n", file_name); - *out_width = 0; - *out_height = 0; - *out_channels = 0; - return std::vector(); - } - - const int width = (int)CGImageGetWidth(image); - const int height = (int)CGImageGetHeight(image); - const int channels = 4; - CGColorSpaceRef color_space = CGColorSpaceCreateDeviceRGB(); - const int bytes_per_row = (width * channels); - const int bytes_in_image = (bytes_per_row * height); - std::vector result(bytes_in_image); - const int bits_per_component = 8; - CGContextRef context = CGBitmapContextCreate(result.data(), width, height, - bits_per_component, bytes_per_row, color_space, - kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); - CGColorSpaceRelease(color_space); - CGContextDrawImage(context, CGRectMake(0, 0, width, height), image); - CGContextRelease(context); - CFRelease(image); - CFRelease(image_provider); - CFRelease(file_data_ref); - - *out_width = width; - *out_height = height; - *out_channels = channels; - return result; -} diff --git a/tensorflow/examples/ios/simple/main.mm b/tensorflow/examples/ios/simple/main.mm deleted file mode 100644 index d70550a7307..00000000000 --- a/tensorflow/examples/ios/simple/main.mm +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2015 Google Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#import - -int main(int argc, char * argv[]) { - @autoreleasepool { - NSString *delegateClassName = @"AppDelegate"; - return UIApplicationMain(argc, argv, nil, delegateClassName); - } -} diff --git a/tensorflow/examples/ios/simple/tf_simple_example.xcodeproj/project.pbxproj b/tensorflow/examples/ios/simple/tf_simple_example.xcodeproj/project.pbxproj deleted file mode 100644 index 55c06e28fb3..00000000000 --- a/tensorflow/examples/ios/simple/tf_simple_example.xcodeproj/project.pbxproj +++ /dev/null @@ -1,404 +0,0 @@ -// !$*UTF8*$! -{ - archiveVersion = 1; - classes = { - }; - objectVersion = 46; - objects = { - -/* Begin PBXBuildFile section */ - 1C0D734B1ECCC460008C1DAB /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */; }; - 1CA45FFF1ECCC356002FA6A4 /* UIKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */; }; - 2530463E3C9A9D5FB9299C0E /* libPods-tf_simple_example.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */; }; - 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */; }; - 59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */; }; - 59A3D0051CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */; }; - 59A3D0071CF4E68100C4259F /* tensorflow_inception_graph.pb in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */; }; - 59A3D0081CF4E68100C4259F /* ios_image_load.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */; }; - 59A3D0091CF4E68100C4259F /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFC1CF4E68100C4259F /* main.mm */; }; - 59A3D00B1CF4E68100C4259F /* RunModelViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */; }; - 59A3D00C1CF4E68100C4259F /* RunModelViewController.xib in Resources */ = {isa = PBXBuildFile; fileRef = 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */; }; -/* End PBXBuildFile section */ - -/* Begin PBXFileReference section */ - 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreImage.framework; path = System/Library/Frameworks/CoreImage.framework; sourceTree = SDKROOT; }; - 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; }; - 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; }; - 5911579B1CF4011C00C31E3A /* tf_simple_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tf_simple_example.app; sourceTree = BUILT_PRODUCTS_DIR; }; - 59A3CFF11CF4E68100C4259F /* AppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; - 59A3CFF21CF4E68100C4259F /* 
AppDelegate.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = AppDelegate.mm; sourceTree = ""; }; - 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = grace_hopper.jpg; sourceTree = ""; }; - 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = imagenet_comp_graph_label_strings.txt; sourceTree = ""; }; - 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */ = {isa = PBXFileReference; lastKnownFileType = file; path = tensorflow_inception_graph.pb; sourceTree = ""; }; - 59A3CFFA1CF4E68100C4259F /* ios_image_load.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ios_image_load.h; sourceTree = ""; }; - 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_image_load.mm; sourceTree = ""; }; - 59A3CFFC1CF4E68100C4259F /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = ""; }; - 59A3CFFD1CF4E68100C4259F /* RunModel-Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = "RunModel-Info.plist"; sourceTree = ""; }; - 59A3CFFE1CF4E68100C4259F /* RunModelViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = RunModelViewController.h; sourceTree = ""; }; - 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = RunModelViewController.mm; sourceTree = ""; }; - 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.xib; path = RunModelViewController.xib; sourceTree = ""; }; - 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tf_simple_example.a"; sourceTree = BUILT_PRODUCTS_DIR; }; - 87ABECA6543FF90E81111A6D /* Pods-tf_simple_example.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tf_simple_example.release.xcconfig"; path = "Pods/Target Support Files/Pods-tf_simple_example/Pods-tf_simple_example.release.xcconfig"; sourceTree = ""; }; - 8C94FEE43FD467468C5B75AA /* Pods-tf_simple_example.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tf_simple_example.debug.xcconfig"; path = "Pods/Target Support Files/Pods-tf_simple_example/Pods-tf_simple_example.debug.xcconfig"; sourceTree = ""; }; -/* End PBXFileReference section */ - -/* Begin PBXFrameworksBuildPhase section */ - 591157981CF4011C00C31E3A /* Frameworks */ = { - isa = PBXFrameworksBuildPhase; - buildActionMask = 2147483647; - files = ( - 1C0D734B1ECCC460008C1DAB /* CoreGraphics.framework in Frameworks */, - 1CA45FFF1ECCC356002FA6A4 /* UIKit.framework in Frameworks */, - 2530463E3C9A9D5FB9299C0E /* libPods-tf_simple_example.a in Frameworks */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXFrameworksBuildPhase section */ - -/* Begin PBXGroup section */ - 24D7686C331131624F4454A0 /* Frameworks */ = { - isa = PBXGroup; - children = ( - 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */, - 1C0D73481ECCC41B008C1DAB /* 
CoreImage.framework */, - 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */, - 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */, - ); - name = Frameworks; - sourceTree = ""; - }; - 3E9FC355632FB928EA23BEED /* Pods */ = { - isa = PBXGroup; - children = ( - 8C94FEE43FD467468C5B75AA /* Pods-tf_simple_example.debug.xcconfig */, - 87ABECA6543FF90E81111A6D /* Pods-tf_simple_example.release.xcconfig */, - ); - name = Pods; - sourceTree = ""; - }; - 591157921CF4011C00C31E3A = { - isa = PBXGroup; - children = ( - 59A3CFF11CF4E68100C4259F /* AppDelegate.h */, - 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */, - 59A3CFF31CF4E68100C4259F /* data */, - 59A3CFFA1CF4E68100C4259F /* ios_image_load.h */, - 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */, - 59A3CFFC1CF4E68100C4259F /* main.mm */, - 59A3CFFD1CF4E68100C4259F /* RunModel-Info.plist */, - 59A3CFFE1CF4E68100C4259F /* RunModelViewController.h */, - 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */, - 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */, - 5911579C1CF4011C00C31E3A /* Products */, - 3E9FC355632FB928EA23BEED /* Pods */, - 24D7686C331131624F4454A0 /* Frameworks */, - ); - sourceTree = ""; - }; - 5911579C1CF4011C00C31E3A /* Products */ = { - isa = PBXGroup; - children = ( - 5911579B1CF4011C00C31E3A /* tf_simple_example.app */, - ); - name = Products; - sourceTree = ""; - }; - 59A3CFF31CF4E68100C4259F /* data */ = { - isa = PBXGroup; - children = ( - 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */, - 59A3CFF71CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt */, - 59A3CFF91CF4E68100C4259F /* tensorflow_inception_graph.pb */, - ); - path = data; - sourceTree = ""; - }; -/* End PBXGroup section */ - -/* Begin PBXNativeTarget section */ - 5911579A1CF4011C00C31E3A /* tf_simple_example */ = { - isa = PBXNativeTarget; - buildConfigurationList = 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_simple_example" */; - buildPhases = ( - 1CD07C1CEB04E50C5975C7BB /* [CP] Check Pods Manifest.lock */, - 591157971CF4011C00C31E3A /* Sources */, - 591157981CF4011C00C31E3A /* Frameworks */, - 591157991CF4011C00C31E3A /* Resources */, - 0EABEF9F31578BDA8CA9D2A7 /* [CP] Embed Pods Frameworks */, - 96DDF9E6E35958387A215092 /* [CP] Copy Pods Resources */, - ); - buildRules = ( - ); - dependencies = ( - ); - name = tf_simple_example; - productName = tf_ios_makefile_example; - productReference = 5911579B1CF4011C00C31E3A /* tf_simple_example.app */; - productType = "com.apple.product-type.application"; - }; -/* End PBXNativeTarget section */ - -/* Begin PBXProject section */ - 591157931CF4011C00C31E3A /* Project object */ = { - isa = PBXProject; - attributes = { - LastUpgradeCheck = 0830; - ORGANIZATIONNAME = Google; - TargetAttributes = { - 5911579A1CF4011C00C31E3A = { - CreatedOnToolsVersion = 7.2; - DevelopmentTeam = 85Z3VXS37U; - }; - }; - }; - buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_simple_example" */; - compatibilityVersion = "Xcode 3.2"; - developmentRegion = English; - hasScannedForEncodings = 0; - knownRegions = ( - en, - Base, - ); - mainGroup = 591157921CF4011C00C31E3A; - productRefGroup = 5911579C1CF4011C00C31E3A /* Products */; - projectDirPath = ""; - projectRoot = ""; - targets = ( - 5911579A1CF4011C00C31E3A /* tf_simple_example */, - ); - }; -/* End PBXProject section */ - -/* Begin PBXResourcesBuildPhase section */ - 591157991CF4011C00C31E3A /* Resources */ = { - isa = PBXResourcesBuildPhase; - buildActionMask = 2147483647; - files = ( 
- 59A3D00C1CF4E68100C4259F /* RunModelViewController.xib in Resources */, - 59A3D0051CF4E68100C4259F /* imagenet_comp_graph_label_strings.txt in Resources */, - 59A3D0071CF4E68100C4259F /* tensorflow_inception_graph.pb in Resources */, - 59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXResourcesBuildPhase section */ - -/* Begin PBXShellScriptBuildPhase section */ - 0EABEF9F31578BDA8CA9D2A7 /* [CP] Embed Pods Frameworks */ = { - isa = PBXShellScriptBuildPhase; - buildActionMask = 2147483647; - files = ( - ); - inputPaths = ( - ); - name = "[CP] Embed Pods Frameworks"; - outputPaths = ( - ); - runOnlyForDeploymentPostprocessing = 0; - shellPath = /bin/sh; - shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tf_simple_example/Pods-tf_simple_example-frameworks.sh\"\n"; - showEnvVarsInLog = 0; - }; - 1CD07C1CEB04E50C5975C7BB /* [CP] Check Pods Manifest.lock */ = { - isa = PBXShellScriptBuildPhase; - buildActionMask = 2147483647; - files = ( - ); - inputPaths = ( - ); - name = "[CP] Check Pods Manifest.lock"; - outputPaths = ( - ); - runOnlyForDeploymentPostprocessing = 0; - shellPath = /bin/sh; - shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n"; - showEnvVarsInLog = 0; - }; - 96DDF9E6E35958387A215092 /* [CP] Copy Pods Resources */ = { - isa = PBXShellScriptBuildPhase; - buildActionMask = 2147483647; - files = ( - ); - inputPaths = ( - ); - name = "[CP] Copy Pods Resources"; - outputPaths = ( - ); - runOnlyForDeploymentPostprocessing = 0; - shellPath = /bin/sh; - shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tf_simple_example/Pods-tf_simple_example-resources.sh\"\n"; - showEnvVarsInLog = 0; - }; -/* End PBXShellScriptBuildPhase section */ - -/* Begin PBXSourcesBuildPhase section */ - 591157971CF4011C00C31E3A /* Sources */ = { - isa = PBXSourcesBuildPhase; - buildActionMask = 2147483647; - files = ( - 59A3D0091CF4E68100C4259F /* main.mm in Sources */, - 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */, - 59A3D00B1CF4E68100C4259F /* RunModelViewController.mm in Sources */, - 59A3D0081CF4E68100C4259F /* ios_image_load.mm in Sources */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXSourcesBuildPhase section */ - -/* Begin XCBuildConfiguration section */ - 591157B01CF4011D00C31E3A /* Debug */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; - CLANG_CXX_LIBRARY = "libc++"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INFINITE_RECURSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; - COPY_PHASE_STRIP = NO; - DEBUG_INFORMATION_FORMAT = dwarf; - ENABLE_STRICT_OBJC_MSGSEND = YES; - ENABLE_TESTABILITY = YES; - GCC_C_LANGUAGE_STANDARD = gnu99; - GCC_DYNAMIC_NO_PIC = NO; - GCC_NO_COMMON_BLOCKS = YES; - GCC_OPTIMIZATION_LEVEL = 0; - 
GCC_PREPROCESSOR_DEFINITIONS = ( - "DEBUG=1", - "$(inherited)", - ); - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 8.0; - MTL_ENABLE_DEBUG_INFO = YES; - ONLY_ACTIVE_ARCH = YES; - SDKROOT = iphoneos; - TARGETED_DEVICE_FAMILY = "1,2"; - }; - name = Debug; - }; - 591157B11CF4011D00C31E3A /* Release */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; - CLANG_CXX_LIBRARY = "libc++"; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INFINITE_RECURSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; - COPY_PHASE_STRIP = NO; - DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; - ENABLE_NS_ASSERTIONS = NO; - ENABLE_STRICT_OBJC_MSGSEND = YES; - GCC_C_LANGUAGE_STANDARD = gnu99; - GCC_NO_COMMON_BLOCKS = YES; - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 8.0; - MTL_ENABLE_DEBUG_INFO = NO; - SDKROOT = iphoneos; - TARGETED_DEVICE_FAMILY = "1,2"; - VALIDATE_PRODUCT = YES; - }; - name = Release; - }; - 591157B31CF4011D00C31E3A /* Debug */ = { - isa = XCBuildConfiguration; - baseConfigurationReference = 8C94FEE43FD467468C5B75AA /* Pods-tf_simple_example.debug.xcconfig */; - buildSettings = { - CLANG_DEBUG_INFORMATION_LEVEL = default; - CODE_SIGN_IDENTITY = "iPhone Developer"; - ENABLE_BITCODE = NO; - GCC_ENABLE_CPP_EXCEPTIONS = YES; - GCC_ENABLE_CPP_RTTI = YES; - HEADER_SEARCH_PATHS = "$(inherited)"; - INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist"; - IPHONEOS_DEPLOYMENT_TARGET = 9.2; - LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; - LIBRARY_SEARCH_PATHS = ""; - OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; - OTHER_LDFLAGS = "$(inherited)"; - PRODUCT_BUNDLE_IDENTIFIER = "com.google.tf-simple-example"; - PRODUCT_NAME = "$(TARGET_NAME)"; - SEPARATE_STRIP = NO; - }; - name = Debug; - }; - 591157B41CF4011D00C31E3A /* Release */ = { - isa = XCBuildConfiguration; - baseConfigurationReference = 87ABECA6543FF90E81111A6D /* Pods-tf_simple_example.release.xcconfig */; - buildSettings = { - CLANG_DEBUG_INFORMATION_LEVEL = default; - CODE_SIGN_IDENTITY = "iPhone Developer"; - ENABLE_BITCODE = NO; - GCC_ENABLE_CPP_EXCEPTIONS = YES; - GCC_ENABLE_CPP_RTTI = YES; - HEADER_SEARCH_PATHS = "$(inherited)"; - INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist"; - IPHONEOS_DEPLOYMENT_TARGET = 9.2; - LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; - LIBRARY_SEARCH_PATHS = ""; - ONLY_ACTIVE_ARCH = YES; - OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; - OTHER_LDFLAGS = "$(inherited)"; - PRODUCT_BUNDLE_IDENTIFIER = "com.google.tf-simple-example"; - PRODUCT_NAME = "$(TARGET_NAME)"; - SEPARATE_STRIP = NO; - }; - name = Release; - }; -/* End XCBuildConfiguration 
section */ - -/* Begin XCConfigurationList section */ - 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tf_simple_example" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - 591157B01CF4011D00C31E3A /* Debug */, - 591157B11CF4011D00C31E3A /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; - 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_simple_example" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - 591157B31CF4011D00C31E3A /* Debug */, - 591157B41CF4011D00C31E3A /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; -/* End XCConfigurationList section */ - }; - rootObject = 591157931CF4011C00C31E3A /* Project object */; -} diff --git a/tensorflow/examples/label_image/BUILD b/tensorflow/examples/label_image/BUILD index 7c3a6dca1b2..3e1772dc21c 100644 --- a/tensorflow/examples/label_image/BUILD +++ b/tensorflow/examples/label_image/BUILD @@ -8,8 +8,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - exports_files(["data/grace_hopper.jpg"]) tf_cc_binary( diff --git a/tensorflow/examples/multibox_detector/BUILD b/tensorflow/examples/multibox_detector/BUILD index 4b1e18954fa..5a7a4cbeb71 100644 --- a/tensorflow/examples/multibox_detector/BUILD +++ b/tensorflow/examples/multibox_detector/BUILD @@ -8,8 +8,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - tf_cc_binary( name = "detect_objects", srcs = [ diff --git a/tensorflow/examples/saved_model/integration_tests/BUILD b/tensorflow/examples/saved_model/integration_tests/BUILD deleted file mode 100644 index 0bfbee1ee2a..00000000000 --- a/tensorflow/examples/saved_model/integration_tests/BUILD +++ /dev/null @@ -1,86 +0,0 @@ -load("//tensorflow/core/platform/default:distribute.bzl", "distribute_py_test") - -package( - licenses = ["notice"], # Apache 2.0 -) - -exports_files(["LICENSE"]) - -py_library( - name = "integration_scripts", - srcs = [ - "deploy_mnist_cnn.py", - "export_mnist_cnn.py", - "export_rnn_cell.py", - "export_simple_text_embedding.py", - "export_text_rnn_model.py", - "integration_scripts.py", - "use_mnist_cnn.py", - "use_model_in_sequential_keras.py", - "use_rnn_cell.py", - "use_text_embedding_in_dataset.py", - "use_text_rnn_model.py", - ], - visibility = ["//tensorflow:internal"], - deps = [ - ":distribution_strategy_utils", - ":mnist_util", - "//tensorflow:tensorflow_py", - "@absl_py//absl/testing:parameterized", - ], -) - -py_library( - name = "mnist_util", - srcs = ["mnist_util.py"], - visibility = ["//tensorflow:internal"], - deps = [ - "//tensorflow:tensorflow_py", - ], -) - -py_library( - name = "distribution_strategy_utils", - srcs = ["distribution_strategy_utils.py"], - visibility = ["//tensorflow:internal"], - deps = [ - "//tensorflow/python/distribute:strategy_combinations", - ], -) - -distribute_py_test( - name = "saved_model_test", - srcs = [ - "saved_model_test.py", - ], - shard_count = 4, - tags = [ - "no_pip", # b/131697937 and b/132196869 - "noasan", # forge input size exceeded - "nomsan", # forge input size exceeded - "notsan", # forge input size exceeded - ], - tpu_tags = [ - "no_oss", # Test infra collision (b/157754990) - ], - deps = [ - ":distribution_strategy_utils", - ":integration_scripts", - "//tensorflow:tensorflow_py", - "//tensorflow/python:framework_combinations", - "//tensorflow/python/distribute:combinations", - "@absl_py//absl/testing:parameterized", - ], -) - -# 
b/132234211: Target added to support internal test target that runs the test -# in an environment that has the extra dependencies required to test integration -# with non core tensorflow packages. -py_library( - name = "saved_model_test_lib", - srcs = [ - "saved_model_test.py", - ], - visibility = ["//tensorflow:internal"], - deps = [":integration_scripts"], -) diff --git a/tensorflow/examples/saved_model/integration_tests/deploy_mnist_cnn.py b/tensorflow/examples/saved_model/integration_tests/deploy_mnist_cnn.py deleted file mode 100644 index cc5cb6b2a6c..00000000000 --- a/tensorflow/examples/saved_model/integration_tests/deploy_mnist_cnn.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Deploys a SavedModel with an MNIST classifier to TFLite.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl import app -from absl import flags -import numpy as np -import tensorflow.compat.v2 as tf - -from tensorflow.examples.saved_model.integration_tests import mnist_util - -FLAGS = flags.FLAGS - -flags.DEFINE_string( - 'saved_model_dir', None, - 'Directory of the SavedModel to deploy.') -flags.DEFINE_bool( - 'use_fashion_mnist', False, - 'Use Fashion MNIST (products) instead of the real MNIST (digits).') -flags.DEFINE_bool( - 'fast_test_mode', False, - 'Limit amount of test data for running in unit tests.') -flags.DEFINE_string( - 'tflite_output_file', None, - 'The filename of the .tflite model file to write (optional).') -flags.DEFINE_bool( - 'reload_as_keras_model', True, - 'Also test tf.keras.models.load_model() on --saved_model_dir.') - - -def main(argv): - del argv - - # First convert the SavedModel in a pristine environment. - converter = tf.lite.TFLiteConverter.from_saved_model(FLAGS.saved_model_dir) - lite_model_content = converter.convert() - # Here is how you can save it for actual deployment. - if FLAGS.tflite_output_file: - with open(FLAGS.tflite_output_file, 'wb') as outfile: - outfile.write(lite_model_content) - # For testing, the TFLite model can be executed like this. - interpreter = tf.lite.Interpreter(model_content=lite_model_content) - def lite_model(images): - interpreter.allocate_tensors() - interpreter.set_tensor(interpreter.get_input_details()[0]['index'], images) - interpreter.invoke() - return interpreter.get_tensor(interpreter.get_output_details()[0]['index']) - - # Load the SavedModel again for use as a test baseline. - imported = tf.saved_model.load(FLAGS.saved_model_dir) - def tf_model(images): - output_dict = imported.signatures['serving_default'](tf.constant(images)) - logits, = output_dict.values() # Unpack single value. - return logits - - # Compare model outputs on the test inputs. 
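The deleted `deploy_mnist_cnn.py` above converts a SavedModel to TFLite and checks the interpreter's outputs against the original model. Below is a minimal, self-contained sketch of that convert-and-compare pattern; the SavedModel path, the `(1, 28, 28, 1)` input shape, and the single-output `serving_default` signature are assumptions, not values taken from this change.

```python
import numpy as np
import tensorflow as tf

saved_model_dir = '/tmp/mnist_feature_extractor'  # hypothetical export directory

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_bytes = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]['index']
output_index = interpreter.get_output_details()[0]['index']

x = np.zeros((1, 28, 28, 1), dtype=np.float32)  # one fake MNIST-shaped image
interpreter.set_tensor(input_index, x)
interpreter.invoke()
y_lite = interpreter.get_tensor(output_index)

# Baseline from the original SavedModel (assumes a single-output signature).
imported = tf.saved_model.load(saved_model_dir)
y_tf, = imported.signatures['serving_default'](tf.constant(x)).values()

# rtol=0 with a small atol suits outputs bounded in [0, 1], such as softmax scores.
np.testing.assert_allclose(y_lite, y_tf, rtol=0, atol=1e-5)
```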
- (_, _), (x_test, _) = mnist_util.load_reshaped_data( - use_fashion_mnist=FLAGS.use_fashion_mnist, - fake_tiny_data=FLAGS.fast_test_mode) - for i, x in enumerate(x_test): - x = x[None, ...] # Make batch of size 1. - y_lite = lite_model(x) - y_tf = tf_model(x) - # This numpy primitive uses plain `raise` and works outside tf.TestCase. - # Model outputs are probabilities that sum to 1, so atol makes sense here. - np.testing.assert_allclose( - y_lite, y_tf, rtol=0, atol=1e-5, - err_msg='Mismatch with TF Lite at test example %d' % i) - - # Test that the SavedModel loads correctly with v1 load APIs as well. - with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as session: - tf.compat.v1.saved_model.load( - session, - [tf.compat.v1.saved_model.SERVING], - FLAGS.saved_model_dir) - - # The SavedModel actually was a Keras Model; test that it also loads as that. - if FLAGS.reload_as_keras_model: - keras_model = tf.keras.models.load_model(FLAGS.saved_model_dir) - for i, x in enumerate(x_test): - x = x[None, ...] # Make batch of size 1. - y_tf = tf_model(x) - y_keras = keras_model(x) - # This numpy primitive uses plain `raise` and works outside tf.TestCase. - # Model outputs are probabilities that sum to 1, so atol makes sense here. - np.testing.assert_allclose( - y_tf, y_keras, rtol=0, atol=1e-5, - err_msg='Mismatch with Keras at test example %d' % i) - -if __name__ == '__main__': - app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/distribution_strategy_utils.py b/tensorflow/examples/saved_model/integration_tests/distribution_strategy_utils.py deleted file mode 100644 index eabe4cf3802..00000000000 --- a/tensorflow/examples/saved_model/integration_tests/distribution_strategy_utils.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utils related to tf.distribute.strategy.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import sys - -from tensorflow.python.distribute import strategy_combinations - -_strategies = [ - strategy_combinations.one_device_strategy, - strategy_combinations.mirrored_strategy_with_one_cpu, - strategy_combinations.mirrored_strategy_with_one_gpu, - strategy_combinations.mirrored_strategy_with_gpu_and_cpu, - strategy_combinations.mirrored_strategy_with_two_gpus, - strategy_combinations.tpu_strategy, -] - -# The presence of GPU strategies upsets TPU initialization, -# despite their test instances being skipped early. This is a workaround -# for b/145386854. 
-if "test_tpu" in sys.argv[0]: - _strategies = [s for s in _strategies if "GPU" not in str(s)] - - -named_strategies = collections.OrderedDict( - [(None, None)] + - [(str(s), s) for s in _strategies] -) - - -class MaybeDistributionScope(object): - """Provides a context allowing no distribution strategy.""" - - @staticmethod - def from_name(name): - return MaybeDistributionScope(named_strategies[name].strategy if name - else None) - - def __init__(self, distribution): - self._distribution = distribution - self._scope = None - - def __enter__(self): - if self._distribution: - self._scope = self._distribution.scope() - self._scope.__enter__() - - def __exit__(self, exc_type, value, traceback): - if self._distribution: - self._scope.__exit__(exc_type, value, traceback) - self._scope = None diff --git a/tensorflow/examples/saved_model/integration_tests/export_mnist_cnn.py b/tensorflow/examples/saved_model/integration_tests/export_mnist_cnn.py deleted file mode 100644 index ea1c138d164..00000000000 --- a/tensorflow/examples/saved_model/integration_tests/export_mnist_cnn.py +++ /dev/null @@ -1,204 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Exports a convolutional feature extractor for MNIST in SavedModel format. - -The feature extractor is a convolutional neural network plus a hidden layer -that gets trained as part of an MNIST classifier and then written to a -SavedModel (without the classification layer). From there, use_mnist_cnn.py -picks it up for transfer learning. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl import app -from absl import flags -import tensorflow.compat.v2 as tf - -from tensorflow.examples.saved_model.integration_tests import mnist_util -from tensorflow.python.util import tf_decorator -from tensorflow.python.util import tf_inspect - -FLAGS = flags.FLAGS - -flags.DEFINE_string( - 'export_dir', None, - 'Directory of exported SavedModel.') -flags.DEFINE_integer( - 'epochs', 10, - 'Number of epochs to train.') -flags.DEFINE_bool( - 'use_keras_save_api', False, - 'Uses tf.keras.models.save_model() on the feature extractor ' - 'instead of tf.saved_model.save() on a manually wrapped version. 
' - 'With this, the exported model has no hparams.') -flags.DEFINE_bool( - 'fast_test_mode', False, - 'Shortcut training for running in unit tests.') -flags.DEFINE_bool( - 'export_print_hparams', False, - 'If true, the exported function will print its effective hparams.') - - -def make_feature_extractor(l2_strength, dropout_rate): - """Returns a Keras Model to compute a feature vector from MNIST images.""" - regularizer = lambda: tf.keras.regularizers.l2(l2_strength) - net = inp = tf.keras.Input(mnist_util.INPUT_SHAPE) - net = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', name='conv1', - kernel_regularizer=regularizer())(net) - net = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', name='conv2', - kernel_regularizer=regularizer())(net) - net = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), name='pool1')(net) - net = tf.keras.layers.Dropout(dropout_rate, name='dropout1')(net) - net = tf.keras.layers.Flatten(name='flatten')(net) - net = tf.keras.layers.Dense(10, activation='relu', name='dense1', - kernel_regularizer=regularizer())(net) - return tf.keras.Model(inputs=inp, outputs=net) - - -def set_feature_extractor_hparams(model, dropout_rate): - model.get_layer('dropout1').rate = dropout_rate - - -def make_classifier(feature_extractor, l2_strength, dropout_rate=0.5): - """Returns a Keras Model to classify MNIST using feature_extractor.""" - regularizer = lambda: tf.keras.regularizers.l2(l2_strength) - net = inp = tf.keras.Input(mnist_util.INPUT_SHAPE) - net = feature_extractor(net) - net = tf.keras.layers.Dropout(dropout_rate)(net) - net = tf.keras.layers.Dense(mnist_util.NUM_CLASSES, activation='softmax', - kernel_regularizer=regularizer())(net) - return tf.keras.Model(inputs=inp, outputs=net) - - -def wrap_keras_model_for_export(model, batch_input_shape, - set_hparams, default_hparams): - """Wraps `model` for saving and loading as SavedModel.""" - # The primary input to the module is a Tensor with a batch of images. - # Here we determine its spec. - inputs_spec = tf.TensorSpec(shape=batch_input_shape, dtype=tf.float32) - - # The module also accepts certain hparams as optional Tensor inputs. - # Here, we cut all the relevant slices from `default_hparams` - # (and don't worry if anyone accidentally modifies it later). - if default_hparams is None: default_hparams = {} - hparam_keys = list(default_hparams.keys()) - hparam_defaults = tuple(default_hparams.values()) - hparams_spec = {name: tf.TensorSpec.from_tensor(tf.constant(value)) - for name, value in default_hparams.items()} - - # The goal is to save a function with this argspec... - argspec = tf_inspect.FullArgSpec( - args=(['inputs', 'training'] + hparam_keys), - defaults=((False,) + hparam_defaults), - varargs=None, varkw=None, - kwonlyargs=[], kwonlydefaults=None, - annotations={}) - # ...and this behavior: - def call_fn(inputs, training, *args): - if FLAGS.export_print_hparams: - args = [tf.keras.backend.print_tensor(args[i], 'training=%s and %s=' - % (training, hparam_keys[i])) - for i in range(len(args))] - kwargs = dict(zip(hparam_keys, args)) - if kwargs: set_hparams(model, **kwargs) - return model(inputs, training=training) - - # We cannot spell out `args` in def statement for call_fn, but since - # tf.function uses tf_inspect, we can use tf_decorator to wrap it with - # the desired argspec. - def wrapped(*args, **kwargs): # TODO(arnoegw): Can we use call_fn itself? 
- return call_fn(*args, **kwargs) - traced_call_fn = tf.function( - tf_decorator.make_decorator(call_fn, wrapped, decorator_argspec=argspec)) - - # Now we need to trigger traces for all supported combinations of the - # non-Tensor-value inputs. - for training in (True, False): - traced_call_fn.get_concrete_function(inputs_spec, training, **hparams_spec) - - # Finally, we assemble the object for tf.saved_model.save(). - obj = tf.train.Checkpoint() - obj.__call__ = traced_call_fn - obj.trainable_variables = model.trainable_variables - obj.variables = model.trainable_variables + model.non_trainable_variables - # Make tf.functions for the regularization terms of the loss. - obj.regularization_losses = [_get_traced_loss(model, i) - for i in range(len(model.losses))] - return obj - - -def _get_traced_loss(model, i): - """Returns tf.function for model.losses[i] with a trace for zero args. - - The intended usage is - [_get_traced_loss(model, i) for i in range(len(model.losses))] - This is better than - [tf.function(lambda: model.losses[i], input_signature=[]) for i ...] - because it avoids capturing a loop index in a lambda, and removes any - chance of deferring the trace. - - Args: - model: a Keras Model. - i: an integer between from 0 up to but to len(model.losses). - """ - f = tf.function(lambda: model.losses[i]) - _ = f.get_concrete_function() - return f - - -def main(argv): - del argv - - # Build a complete classifier model using a feature extractor. - default_hparams = dict(dropout_rate=0.25) - l2_strength = 0.01 # Not a hparam for inputs -> outputs. - feature_extractor = make_feature_extractor(l2_strength=l2_strength, - **default_hparams) - classifier = make_classifier(feature_extractor, l2_strength=l2_strength) - - # Train the complete model. - (x_train, y_train), (x_test, y_test) = mnist_util.load_reshaped_data( - fake_tiny_data=FLAGS.fast_test_mode) - classifier.compile(loss=tf.keras.losses.categorical_crossentropy, - optimizer=tf.keras.optimizers.SGD(), - metrics=['accuracy']) - classifier.fit(x_train, y_train, - batch_size=128, - epochs=FLAGS.epochs, - verbose=1, - validation_data=(x_test, y_test)) - - # Save the feature extractor to a framework-agnostic SavedModel for reuse. - # Note that the feature_extractor object has not been compiled or fitted, - # so it does not contain an optimizer and related state. - if FLAGS.use_keras_save_api: - # Use Keras' built-in way of creating reusable SavedModels. - # This has no support for adjustable hparams at this time (July 2019). - # (We could also call tf.saved_model.save(feature_extractor, ...), - # point is we're passing a Keras model, not a plain Checkpoint.) - tf.keras.models.save_model(feature_extractor, FLAGS.export_dir) - else: - # Assemble a reusable SavedModel manually, with adjustable hparams. - exportable = wrap_keras_model_for_export(feature_extractor, - (None,) + mnist_util.INPUT_SHAPE, - set_feature_extractor_hparams, - default_hparams) - tf.saved_model.save(exportable, FLAGS.export_dir) - - -if __name__ == '__main__': - app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/export_rnn_cell.py b/tensorflow/examples/saved_model/integration_tests/export_rnn_cell.py deleted file mode 100644 index 6a2853f0617..00000000000 --- a/tensorflow/examples/saved_model/integration_tests/export_rnn_cell.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
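A note on the `_get_traced_loss` helper in the deleted exporter above: routing the loss index through a function parameter matters because Python closures bind names late. A tiny sketch of the pitfall it avoids (plain Python, no TensorFlow needed):

```python
# Lambdas built directly in a loop all close over the same variable...
fns = [lambda: i for i in range(3)]
print([f() for f in fns])   # [2, 2, 2] -- every lambda sees the final value of i

# ...whereas passing the index into a separate function binds its current value.
def make_fn(i):
  return lambda: i

fns = [make_fn(i) for i in range(3)]
print([f() for f in fns])   # [0, 1, 2]
```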
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Export an RNN cell in SavedModel format.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl import app -from absl import flags -import numpy as np - -import tensorflow.compat.v2 as tf - -FLAGS = flags.FLAGS - -flags.DEFINE_string("export_dir", None, "Directory to export SavedModel.") - - -def main(argv): - del argv - - root = tf.train.Checkpoint() - # Create a cell and attach to our trackable. - root.rnn_cell = tf.keras.layers.LSTMCell(units=10, recurrent_initializer=None) - - # Wrap the rnn_cell.__call__ function and assign to next_state. - root.next_state = tf.function(root.rnn_cell.__call__) - - # Wrap the rnn_cell.get_initial_function using a decorator and assign to an - # attribute with the same name. - @tf.function(input_signature=[tf.TensorSpec([None, None], tf.float32)]) - def get_initial_state(tensor): - return root.rnn_cell.get_initial_state(tensor, None, None) - - root.get_initial_state = get_initial_state - - # Construct an initial_state, then call next_state explicitly to trigger a - # trace for serialization (we need an explicit call, because next_state has - # not been annotated with an input_signature). - initial_state = root.get_initial_state( - tf.constant(np.random.uniform(size=[3, 10]).astype(np.float32))) - root.next_state( - tf.constant(np.random.uniform(size=[3, 19]).astype(np.float32)), - initial_state) - - tf.saved_model.save(root, FLAGS.export_dir) - - -if __name__ == "__main__": - app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/export_simple_text_embedding.py b/tensorflow/examples/saved_model/integration_tests/export_simple_text_embedding.py deleted file mode 100644 index 891a8f1c7e2..00000000000 --- a/tensorflow/examples/saved_model/integration_tests/export_simple_text_embedding.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
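The deleted `export_rnn_cell.py` above relies on the rule that a `tf.function` is only exportable once it has a concrete trace: either give it an `input_signature`, or call it once with representative inputs before saving. A minimal sketch of both options, with a hypothetical export directory:

```python
import tensorflow as tf

obj = tf.train.Checkpoint()

# Option 1: an input_signature makes the function exportable without any call.
@tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
def double(x):
  return x * 2.0

obj.double = double

# Option 2: with no signature, call the function once so a concrete trace
# exists for the shapes and dtypes that should be serialized.
obj.triple = tf.function(lambda x: x * 3.0)
obj.triple(tf.zeros([2, 4]))

tf.saved_model.save(obj, '/tmp/traced_fns')  # hypothetical directory
```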
-# ============================================================================== -"""Text embedding model stored as a SavedModel.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import tempfile - -from absl import app -from absl import flags - -import tensorflow.compat.v2 as tf - -FLAGS = flags.FLAGS - -flags.DEFINE_string("export_dir", None, "Directory to export SavedModel.") - - -def write_vocabulary_file(vocabulary): - """Write temporary vocab file for module construction.""" - tmpdir = tempfile.mkdtemp() - vocabulary_file = os.path.join(tmpdir, "tokens.txt") - with tf.io.gfile.GFile(vocabulary_file, "w") as f: - for entry in vocabulary: - f.write(entry + "\n") - return vocabulary_file - - -class TextEmbeddingModel(tf.train.Checkpoint): - """Text embedding model. - - A text embeddings model that takes a sentences on input and outputs the - sentence embedding. - """ - - def __init__(self, vocabulary, emb_dim, oov_buckets): - super(TextEmbeddingModel, self).__init__() - self._oov_buckets = oov_buckets - self._total_size = len(vocabulary) + oov_buckets - # Assign the table initializer to this instance to ensure the asset - # it depends on is saved with the SavedModel. - self._table_initializer = tf.lookup.TextFileInitializer( - write_vocabulary_file(vocabulary), tf.string, - tf.lookup.TextFileIndex.WHOLE_LINE, tf.int64, - tf.lookup.TextFileIndex.LINE_NUMBER) - self._table = tf.lookup.StaticVocabularyTable( - self._table_initializer, num_oov_buckets=self._oov_buckets) - self.embeddings = tf.Variable( - tf.random.uniform(shape=[self._total_size, emb_dim])) - self.variables = [self.embeddings] - self.trainable_variables = self.variables - - def _tokenize(self, sentences): - # Perform a minimalistic text preprocessing by removing punctuation and - # splitting on spaces. - normalized_sentences = tf.strings.regex_replace( - input=sentences, pattern=r"\pP", rewrite="") - normalized_sentences = tf.reshape(normalized_sentences, [-1]) - sparse_tokens = tf.strings.split(normalized_sentences, " ").to_sparse() - - # Deal with a corner case: there is one empty sentence. - sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant("")) - # Deal with a corner case: all sentences are empty. - sparse_tokens = tf.sparse.reset_shape(sparse_tokens) - sparse_token_ids = self._table.lookup(sparse_tokens.values) - - return (sparse_tokens.indices, sparse_token_ids, sparse_tokens.dense_shape) - - @tf.function(input_signature=[tf.TensorSpec([None], tf.dtypes.string)]) - def __call__(self, sentences): - token_ids, token_values, token_dense_shape = self._tokenize(sentences) - - return tf.nn.safe_embedding_lookup_sparse( - embedding_weights=self.embeddings, - sparse_ids=tf.sparse.SparseTensor(token_ids, token_values, - token_dense_shape), - sparse_weights=None, - combiner="sqrtn") - - -def main(argv): - del argv - - vocabulary = ["cat", "is", "on", "the", "mat"] - module = TextEmbeddingModel(vocabulary=vocabulary, emb_dim=10, oov_buckets=10) - tf.saved_model.save(module, FLAGS.export_dir) - - -if __name__ == "__main__": - app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/export_text_rnn_model.py b/tensorflow/examples/saved_model/integration_tests/export_text_rnn_model.py deleted file mode 100644 index 9b9f5925588..00000000000 --- a/tensorflow/examples/saved_model/integration_tests/export_text_rnn_model.py +++ /dev/null @@ -1,193 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. 
All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Text RNN model stored as a SavedModel.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl import app -from absl import flags - -import tensorflow.compat.v2 as tf - -FLAGS = flags.FLAGS - -flags.DEFINE_string("export_dir", None, "Directory to export SavedModel.") - - -class TextRnnModel(tf.train.Checkpoint): - """Text RNN model. - - A full generative text RNN model that can train and decode sentences from a - starting word. - """ - - def __init__(self, vocab, emb_dim, buckets, state_size): - super(TextRnnModel, self).__init__() - self._buckets = buckets - self._lstm_cell = tf.keras.layers.LSTMCell(units=state_size) - self._rnn_layer = tf.keras.layers.RNN( - self._lstm_cell, return_sequences=True) - self._embeddings = tf.Variable(tf.random.uniform(shape=[buckets, emb_dim])) - self._logit_layer = tf.keras.layers.Dense(buckets) - self._set_up_vocab(vocab) - - def _tokenize(self, sentences): - # Perform a minimalistic text preprocessing by removing punctuation and - # splitting on spaces. - normalized_sentences = tf.strings.regex_replace( - input=sentences, pattern=r"\pP", rewrite="") - sparse_tokens = tf.strings.split(normalized_sentences, " ").to_sparse() - - # Deal with a corner case: there is one empty sentence. - sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant("")) - # Deal with a corner case: all sentences are empty. - sparse_tokens = tf.sparse.reset_shape(sparse_tokens) - - return (sparse_tokens.indices, sparse_tokens.values, - sparse_tokens.dense_shape) - - def _set_up_vocab(self, vocab_tokens): - # TODO(vbardiovsky): Currently there is no real vocabulary, because - # saved_model serialization does not support trackable resources. Add a real - # vocabulary when it does. - vocab_list = ["UNK"] * self._buckets - for vocab_token in vocab_tokens: - index = self._words_to_indices(vocab_token).numpy() - vocab_list[index] = vocab_token - # This is a variable representing an inverse index. 
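The vocabulary handling above hashes words into a fixed number of buckets and keeps a dense string variable as the inverse index, since SavedModel serialization did not support a real trackable lookup table at the time. A small sketch of that round trip; the bucket count and token list are made up for illustration:

```python
import tensorflow as tf

buckets = 100
vocab_tokens = ["hello", "there", "how", "are", "you"]

# Words map to ids by hashing, so no lookup table is needed...
ids = tf.strings.to_hash_bucket(tf.constant(vocab_tokens), buckets)

# ...and a dense inverse index maps ids back to words.
vocab_list = ["UNK"] * buckets
for token, idx in zip(vocab_tokens, ids.numpy()):
  vocab_list[idx] = token
vocab_tensor = tf.Variable(vocab_list)

# Round-trips the known tokens (up to hash collisions); unseen ids yield "UNK".
print(tf.gather(vocab_tensor, ids).numpy())
```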
- self._vocab_tensor = tf.Variable(vocab_list) - - def _indices_to_words(self, indices): - return tf.gather(self._vocab_tensor, indices) - - def _words_to_indices(self, words): - return tf.strings.to_hash_bucket(words, self._buckets) - - @tf.function(input_signature=[tf.TensorSpec([None], tf.dtypes.string)]) - def train(self, sentences): - token_ids, token_values, token_dense_shape = self._tokenize(sentences) - tokens_sparse = tf.sparse.SparseTensor( - indices=token_ids, values=token_values, dense_shape=token_dense_shape) - tokens = tf.sparse.to_dense(tokens_sparse, default_value="") - - sparse_lookup_ids = tf.sparse.SparseTensor( - indices=tokens_sparse.indices, - values=self._words_to_indices(tokens_sparse.values), - dense_shape=tokens_sparse.dense_shape) - lookup_ids = tf.sparse.to_dense(sparse_lookup_ids, default_value=0) - - # Targets are the next word for each word of the sentence. - tokens_ids_seq = lookup_ids[:, 0:-1] - tokens_ids_target = lookup_ids[:, 1:] - - tokens_prefix = tokens[:, 0:-1] - - # Mask determining which positions we care about for a loss: all positions - # that have a valid non-terminal token. - mask = tf.logical_and( - tf.logical_not(tf.equal(tokens_prefix, "")), - tf.logical_not(tf.equal(tokens_prefix, ""))) - - input_mask = tf.cast(mask, tf.int32) - - with tf.GradientTape() as t: - sentence_embeddings = tf.nn.embedding_lookup(self._embeddings, - tokens_ids_seq) - - lstm_initial_state = self._lstm_cell.get_initial_state( - sentence_embeddings) - - lstm_output = self._rnn_layer( - inputs=sentence_embeddings, initial_state=lstm_initial_state) - - # Stack LSTM outputs into a batch instead of a 2D array. - lstm_output = tf.reshape(lstm_output, [-1, self._lstm_cell.output_size]) - - logits = self._logit_layer(lstm_output) - - targets = tf.reshape(tokens_ids_target, [-1]) - weights = tf.cast(tf.reshape(input_mask, [-1]), tf.float32) - - losses = tf.nn.sparse_softmax_cross_entropy_with_logits( - labels=targets, logits=logits) - - # Final loss is the mean loss for all token losses. 
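The reduction that follows divides the weighted sum of per-token losses by the sum of the weights, so padded positions neither contribute loss nor dilute the mean. A toy example with made-up numbers:

```python
import tensorflow as tf

losses = tf.constant([0.5, 1.5, 2.0, 0.7])   # per-token losses, flattened
weights = tf.constant([1.0, 1.0, 1.0, 0.0])  # last position is padding

final_loss = tf.math.divide(
    tf.reduce_sum(tf.multiply(losses, weights)),
    tf.reduce_sum(weights))
print(final_loss.numpy())  # ~1.333: the mean over the three valid tokens only
```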
- final_loss = tf.math.divide( - tf.reduce_sum(tf.multiply(losses, weights)), - tf.reduce_sum(weights), - name="final_loss") - - watched = t.watched_variables() - gradients = t.gradient(final_loss, watched) - - for w, g in zip(watched, gradients): - w.assign_sub(g) - - return final_loss - - @tf.function - def decode_greedy(self, sequence_length, first_word): - initial_state = self._lstm_cell.get_initial_state( - dtype=tf.float32, batch_size=1) - - sequence = [first_word] - current_word = first_word - current_id = tf.expand_dims(self._words_to_indices(current_word), 0) - current_state = initial_state - - for _ in range(sequence_length): - token_embeddings = tf.nn.embedding_lookup(self._embeddings, current_id) - lstm_outputs, current_state = self._lstm_cell(token_embeddings, - current_state) - lstm_outputs = tf.reshape(lstm_outputs, [-1, self._lstm_cell.output_size]) - logits = self._logit_layer(lstm_outputs) - softmax = tf.nn.softmax(logits) - - next_ids = tf.math.argmax(softmax, axis=1) - next_words = self._indices_to_words(next_ids)[0] - - current_id = next_ids - current_word = next_words - sequence.append(current_word) - - return sequence - - -def main(argv): - del argv - - sentences = [" hello there ", " how are you doing today "] - vocab = [ - "", "", "hello", "there", "how", "are", "you", "doing", "today" - ] - - module = TextRnnModel(vocab=vocab, emb_dim=10, buckets=100, state_size=128) - - for _ in range(100): - _ = module.train(tf.constant(sentences)) - - # We have to call this function explicitly if we want it exported, because it - # has no input_signature in the @tf.function decorator. - decoded = module.decode_greedy( - sequence_length=10, first_word=tf.constant("")) - _ = [d.numpy() for d in decoded] - - tf.saved_model.save(module, FLAGS.export_dir) - - -if __name__ == "__main__": - app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/integration_scripts.py b/tensorflow/examples/saved_model/integration_tests/integration_scripts.py deleted file mode 100644 index 6f1ccfa2f05..00000000000 --- a/tensorflow/examples/saved_model/integration_tests/integration_scripts.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utility to write SavedModel integration tests. - -SavedModel testing requires isolation between the process that creates and -consumes it. This file helps doing that by relaunching the same binary that -calls `assertCommandSucceeded` with an environment flag indicating what source -file to execute. That binary must start by calling `MaybeRunScriptInstead`. - -This allows to wire this into existing building systems without having to depend -on data dependencies. And as so allow to keep a fixed binary size and allows -interop with GPU tests. 
-""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import importlib -import os -import subprocess -import sys - -from absl import app -import tensorflow.compat.v2 as tf - -from tensorflow.python.platform import tf_logging as logging - - -class TestCase(tf.test.TestCase): - """Base class to write SavedModel integration tests.""" - - def assertCommandSucceeded(self, script_name, **flags): - """Runs an integration test script with given flags.""" - run_script = sys.argv[0] - if run_script.endswith(".py"): - command_parts = [sys.executable, run_script] - else: - command_parts = [run_script] - command_parts.append("--alsologtostderr") # For visibility in sponge. - for flag_key, flag_value in flags.items(): - command_parts.append("--%s=%s" % (flag_key, flag_value)) - - env = dict(TF2_BEHAVIOR="enabled", SCRIPT_NAME=script_name) - logging.info("Running %s with added environment variables %s" % - (command_parts, env)) - subprocess.check_call(command_parts, env=dict(os.environ, **env)) - - -def MaybeRunScriptInstead(): - if "SCRIPT_NAME" in os.environ: - # Append current path to import path and execute `SCRIPT_NAME` main. - sys.path.extend([os.path.dirname(__file__)]) - module_name = os.environ["SCRIPT_NAME"] - retval = app.run(importlib.import_module(module_name).main) # pylint: disable=assignment-from-no-return - sys.exit(retval) diff --git a/tensorflow/examples/saved_model/integration_tests/mnist_util.py b/tensorflow/examples/saved_model/integration_tests/mnist_util.py deleted file mode 100644 index 9770c849603..00000000000 --- a/tensorflow/examples/saved_model/integration_tests/mnist_util.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Convenience wrapper around Keras' MNIST and Fashion MNIST data.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow.compat.v2 as tf - -INPUT_SHAPE = (28, 28, 1) -NUM_CLASSES = 10 - - -def _load_random_data(num_train_and_test): - return ((np.random.randint(0, 256, (num, 28, 28), dtype=np.uint8), - np.random.randint(0, 10, (num,), dtype=np.int64)) - for num in num_train_and_test) - - -def load_reshaped_data(use_fashion_mnist=False, fake_tiny_data=False): - """Returns MNIST or Fashion MNIST or fake train and test data.""" - load = ((lambda: _load_random_data([128, 128])) if fake_tiny_data else - tf.keras.datasets.fashion_mnist.load_data if use_fashion_mnist else - tf.keras.datasets.mnist.load_data) - (x_train, y_train), (x_test, y_test) = load() - return ((_prepare_image(x_train), _prepare_label(y_train)), - (_prepare_image(x_test), _prepare_label(y_test))) - - -def _prepare_image(x): - """Converts images to [n,h,w,c] format in range [0,1].""" - return x[..., None].astype('float32') / 255. 
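The deleted `mnist_util.py` reshapes raw `uint8` images into the NHWC float format Keras expects and one-hot encodes the labels. A minimal sketch of the same transformation on a fake two-image batch:

```python
import numpy as np
import tensorflow as tf

x = np.random.randint(0, 256, (2, 28, 28), dtype=np.uint8)  # two fake images
y = np.array([3, 7], dtype=np.int64)

x_prep = x[..., None].astype('float32') / 255.  # -> (2, 28, 28, 1), values in [0, 1]
y_prep = tf.keras.utils.to_categorical(y, 10)   # -> (2, 10) one-hot rows
print(x_prep.shape, y_prep.shape)
```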
- - -def _prepare_label(y): - """Conerts labels to one-hot encoding.""" - return tf.keras.utils.to_categorical(y, NUM_CLASSES) diff --git a/tensorflow/examples/saved_model/integration_tests/saved_model_test.py b/tensorflow/examples/saved_model/integration_tests/saved_model_test.py deleted file mode 100644 index 434d5ed4ad5..00000000000 --- a/tensorflow/examples/saved_model/integration_tests/saved_model_test.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""SavedModel integration tests.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os - -from absl.testing import parameterized -import tensorflow.compat.v2 as tf - -from tensorflow.examples.saved_model.integration_tests import distribution_strategy_utils as ds_utils -from tensorflow.examples.saved_model.integration_tests import integration_scripts as scripts -from tensorflow.python.distribute import combinations as distribute_combinations -from tensorflow.python.framework import combinations - - -class SavedModelTest(scripts.TestCase, parameterized.TestCase): - - def __init__(self, method_name="runTest", has_extra_deps=False): - super(SavedModelTest, self).__init__(method_name) - self.has_extra_deps = has_extra_deps - - def skipIfMissingExtraDeps(self): - """Skip test if it requires extra dependencies. - - b/132234211: The extra dependencies are not available in all environments - that run the tests, e.g. "tensorflow_hub" is not available from tests - within "tensorflow" alone. Those tests are instead run by another - internal test target. - """ - if not self.has_extra_deps: - self.skipTest("Missing extra dependencies") - - def test_text_rnn(self): - export_dir = self.get_temp_dir() - self.assertCommandSucceeded("export_text_rnn_model", export_dir=export_dir) - self.assertCommandSucceeded("use_text_rnn_model", model_dir=export_dir) - - def test_rnn_cell(self): - export_dir = self.get_temp_dir() - self.assertCommandSucceeded("export_rnn_cell", export_dir=export_dir) - self.assertCommandSucceeded("use_rnn_cell", model_dir=export_dir) - - def test_text_embedding_in_sequential_keras(self): - self.skipIfMissingExtraDeps() - export_dir = self.get_temp_dir() - self.assertCommandSucceeded( - "export_simple_text_embedding", export_dir=export_dir) - self.assertCommandSucceeded( - "use_model_in_sequential_keras", model_dir=export_dir) - - def test_text_embedding_in_dataset(self): - export_dir = self.get_temp_dir() - self.assertCommandSucceeded( - "export_simple_text_embedding", export_dir=export_dir) - self.assertCommandSucceeded( - "use_text_embedding_in_dataset", model_dir=export_dir) - - TEST_MNIST_CNN_GENERATE_KWARGS = dict( - combinations=( - combinations.combine( - # Test all combinations with tf.saved_model.save(). 
- # Test all combinations using tf.keras.models.save_model() - # for both the reusable and the final full model. - use_keras_save_api=True, - named_strategy=list(ds_utils.named_strategies.values()), - retrain_flag_value=["true", "false"], - regularization_loss_multiplier=[None, 2], # Test for b/134528831. - ) + combinations.combine( - # Test few critcial combinations with raw tf.saved_model.save(), - # including export of a reusable SavedModel that gets assembled - # manually, including support for adjustable hparams. - use_keras_save_api=False, - named_strategy=None, - retrain_flag_value=["true", "false"], - regularization_loss_multiplier=[None, 2], # Test for b/134528831. - )), - test_combinations=(distribute_combinations.GPUCombination(), - distribute_combinations.TPUCombination())) - - @combinations.generate(**TEST_MNIST_CNN_GENERATE_KWARGS) - def test_mnist_cnn(self, use_keras_save_api, named_strategy, - retrain_flag_value, regularization_loss_multiplier): - - self.skipIfMissingExtraDeps() - - fast_test_mode = True - temp_dir = self.get_temp_dir() - feature_extrator_dir = os.path.join(temp_dir, "mnist_feature_extractor") - full_model_dir = os.path.join(temp_dir, "full_model") - - self.assertCommandSucceeded( - "export_mnist_cnn", - fast_test_mode=fast_test_mode, - export_dir=feature_extrator_dir, - use_keras_save_api=use_keras_save_api) - - use_kwargs = dict(fast_test_mode=fast_test_mode, - input_saved_model_dir=feature_extrator_dir, - retrain=retrain_flag_value, - output_saved_model_dir=full_model_dir, - use_keras_save_api=use_keras_save_api) - if named_strategy: - use_kwargs["strategy"] = str(named_strategy) - if regularization_loss_multiplier is not None: - use_kwargs[ - "regularization_loss_multiplier"] = regularization_loss_multiplier - self.assertCommandSucceeded("use_mnist_cnn", **use_kwargs) - - self.assertCommandSucceeded( - "deploy_mnist_cnn", - fast_test_mode=fast_test_mode, - saved_model_dir=full_model_dir) - - -if __name__ == "__main__": - scripts.MaybeRunScriptInstead() - tf.test.main() diff --git a/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py b/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py deleted file mode 100644 index ae45a02a59b..00000000000 --- a/tensorflow/examples/saved_model/integration_tests/use_mnist_cnn.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Imports a convolutional feature extractor for MNIST in SavedModel format. - -This program picks up the SavedModel written by export_mnist_cnn.py and -uses the feature extractor contained in it to do classification on either -classic MNIST (digits) or Fashion MNIST (thumbnails of apparel). Optionally, -it trains the feature extractor further as part of the new classifier. 
-As expected, that makes training slower but does not help much for the -original training dataset but helps a lot for transfer to the other dataset. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl import app -from absl import flags -import tensorflow.compat.v2 as tf -import tensorflow_hub as hub - -from tensorflow.examples.saved_model.integration_tests import distribution_strategy_utils as ds_utils -from tensorflow.examples.saved_model.integration_tests import mnist_util - -FLAGS = flags.FLAGS - -flags.DEFINE_string( - 'input_saved_model_dir', None, - 'Directory of the reusable SavedModel that is imported into this program.') -flags.DEFINE_integer( - 'epochs', 5, - 'Number of epochs to train.') -flags.DEFINE_bool( - 'retrain', False, - 'If set, the imported SavedModel is trained further.') -flags.DEFINE_float( - 'dropout_rate', None, - 'If set, dropout rate passed to the SavedModel. ' - 'Requires a SavedModel with support for adjustable hyperparameters.') -flags.DEFINE_float( - 'regularization_loss_multiplier', None, - 'If set, multiplier for the regularization losses in the SavedModel.') -flags.DEFINE_bool( - 'use_fashion_mnist', False, - 'Use Fashion MNIST (products) instead of the real MNIST (digits). ' - 'With this, --retrain gains a lot.') -flags.DEFINE_bool( - 'fast_test_mode', False, - 'Shortcut training for running in unit tests.') -flags.DEFINE_string( - 'output_saved_model_dir', None, - 'Directory of the SavedModel that was exported for reuse.') -flags.DEFINE_bool( - 'use_keras_save_api', False, - 'Uses tf.keras.models.save_model() instead of tf.saved_model.save().') -flags.DEFINE_string('strategy', None, - 'Name of the distribution strategy to use.') - - -def make_feature_extractor(saved_model_path, trainable, - regularization_loss_multiplier): - """Load a pre-trained feature extractor and wrap it for use in Keras.""" - if regularization_loss_multiplier is not None: - # TODO(b/63257857): Scaling regularization losses requires manual loading - # and modification of the SavedModel - obj = tf.saved_model.load(saved_model_path) - def _scale_one_loss(l): # Separate def avoids lambda capture of loop var. - f = tf.function(lambda: tf.multiply(regularization_loss_multiplier, l())) - _ = f.get_concrete_function() - return f - obj.regularization_losses = [_scale_one_loss(l) - for l in obj.regularization_losses] - # The modified object is then passed to hub.KerasLayer instead of the - # string handle. That prevents it from saving a Keras config (b/134528831). - handle = obj - else: - # If possible, we exercise the more common case of passing a string handle - # such that hub.KerasLayer can save a Keras config (b/134528831). 
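The deleted `use_mnist_cnn.py` builds its classifier around `hub.KerasLayer`, usually with a plain string handle as discussed above. A condensed sketch of that transfer-learning setup; the SavedModel path and the `dropout_rate` hparam are assumptions and only work if the exported model actually exposes them:

```python
import tensorflow as tf
import tensorflow_hub as hub

feature_extractor = hub.KerasLayer(
    '/tmp/mnist_feature_extractor',       # hypothetical string handle
    trainable=True,                       # False freezes the imported weights
    arguments=dict(dropout_rate=0.1))     # only if the SavedModel accepts this hparam

inputs = tf.keras.Input((28, 28, 1))
features = feature_extractor(inputs)
outputs = tf.keras.layers.Dense(10, activation='softmax')(features)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
```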
- handle = saved_model_path - - arguments = {} - if FLAGS.dropout_rate is not None: - arguments['dropout_rate'] = FLAGS.dropout_rate - - return hub.KerasLayer(handle, trainable=trainable, arguments=arguments) - - -def make_classifier(feature_extractor, l2_strength=0.01, dropout_rate=0.5): - """Returns a Keras Model to classify MNIST using feature_extractor.""" - regularizer = lambda: tf.keras.regularizers.l2(l2_strength) - net = inp = tf.keras.Input(mnist_util.INPUT_SHAPE) - net = feature_extractor(net) - if dropout_rate: - net = tf.keras.layers.Dropout(dropout_rate)(net) - net = tf.keras.layers.Dense(mnist_util.NUM_CLASSES, activation='softmax', - kernel_regularizer=regularizer())(net) - return tf.keras.Model(inputs=inp, outputs=net) - - -def main(argv): - del argv - - with ds_utils.MaybeDistributionScope.from_name(FLAGS.strategy): - feature_extractor = make_feature_extractor( - FLAGS.input_saved_model_dir, - FLAGS.retrain, - FLAGS.regularization_loss_multiplier) - model = make_classifier(feature_extractor) - - model.compile(loss=tf.keras.losses.categorical_crossentropy, - optimizer=tf.keras.optimizers.SGD(), - metrics=['accuracy']) - - # Train the classifier (possibly on a different dataset). - (x_train, y_train), (x_test, y_test) = mnist_util.load_reshaped_data( - use_fashion_mnist=FLAGS.use_fashion_mnist, - fake_tiny_data=FLAGS.fast_test_mode) - print('Training on %s with %d trainable and %d untrainable variables.' % - ('Fashion MNIST' if FLAGS.use_fashion_mnist else 'MNIST', - len(model.trainable_variables), len(model.non_trainable_variables))) - model.fit(x_train, y_train, - batch_size=128, - epochs=FLAGS.epochs, - verbose=1, - validation_data=(x_test, y_test)) - - if FLAGS.output_saved_model_dir: - if FLAGS.use_keras_save_api: - tf.keras.models.save_model(model, FLAGS.output_saved_model_dir) - else: - tf.saved_model.save(model, FLAGS.output_saved_model_dir) - - -if __name__ == '__main__': - app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/use_model_in_sequential_keras.py b/tensorflow/examples/saved_model/integration_tests/use_model_in_sequential_keras.py deleted file mode 100644 index 47a10fbb608..00000000000 --- a/tensorflow/examples/saved_model/integration_tests/use_model_in_sequential_keras.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
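`MaybeDistributionScope` above only adds a "no strategy" option on top of the standard TF2 pattern: create the model and optimizer inside `strategy.scope()` so their variables are placed and replicated by the strategy. A plain sketch of that pattern, with an arbitrary toy model:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # or OneDeviceStrategy, TPUStrategy, ...

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(10, activation='softmax', input_shape=(784,)),
  ])
  model.compile(optimizer='sgd', loss='categorical_crossentropy')

# model.fit(...) is then called as usual; the strategy handles replication.
```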
-# ============================================================================== -"""Load and use text embedding module in sequential Keras.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tempfile - -from absl import app -from absl import flags - -import numpy as np -import tensorflow.compat.v2 as tf -import tensorflow_hub as hub - -FLAGS = flags.FLAGS - -flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.") - - -def train(fine_tuning): - """Build a Keras model and train with mock data.""" - features = np.array(["my first sentence", "my second sentence"]) - labels = np.array([1, 0]) - dataset = tf.data.Dataset.from_tensor_slices((features, labels)) - - module = tf.saved_model.load(FLAGS.model_dir) - - # Create the sequential keras model. - l = tf.keras.layers - model = tf.keras.Sequential() - model.add(l.Reshape((), batch_input_shape=[None, 1], dtype=tf.string)) - # TODO(b/124219898): output_shape should be optional. - model.add(hub.KerasLayer(module, output_shape=[10], trainable=fine_tuning)) - model.add(l.Dense(100, activation="relu")) - model.add(l.Dense(50, activation="relu")) - model.add(l.Dense(1, activation="sigmoid")) - - model.compile( - optimizer="adam", - loss="binary_crossentropy", - metrics=["accuracy"], - # TODO(b/124446120): Remove after fixed. - run_eagerly=True) - - model.fit_generator(generator=dataset.batch(1), epochs=5) - - # This is testing that a model using a SavedModel can be re-exported again, - # e.g. to catch issues such as b/142231881. - tf.saved_model.save(model, tempfile.mkdtemp()) - - -def main(argv): - del argv - - train(fine_tuning=False) - train(fine_tuning=True) - - -if __name__ == "__main__": - app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/use_rnn_cell.py b/tensorflow/examples/saved_model/integration_tests/use_rnn_cell.py deleted file mode 100644 index 2caca306cac..00000000000 --- a/tensorflow/examples/saved_model/integration_tests/use_rnn_cell.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
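The deleted Keras scripts drive training through `model.fit_generator(generator=dataset.batch(1), epochs=5)`. In TF 2.x, `Model.fit` accepts a `tf.data.Dataset` directly, and later releases deprecate `fit_generator`; the equivalent call on the same mock data would be:

```python
import numpy as np
import tensorflow as tf

features = np.array(['my first sentence', 'my second sentence'])
labels = np.array([1, 0])
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)

# With a model built as in the script above:
# model.fit(dataset, epochs=5)
```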
-# ============================================================================== -"""Load and use an RNN cell stored as a SavedModel.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tempfile - -from absl import app -from absl import flags -import numpy as np -import tensorflow.compat.v2 as tf - -FLAGS = flags.FLAGS - -flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.") - - -def main(argv): - del argv - cell = tf.saved_model.load(FLAGS.model_dir) - - initial_state = cell.get_initial_state( - tf.constant(np.random.uniform(size=[3, 10]).astype(np.float32))) - - cell.next_state( - tf.constant(np.random.uniform(size=[3, 19]).astype(np.float32)), - initial_state) - - # This is testing that a model using a SavedModel can be re-exported again, - # e.g. to catch issues such as b/142231881. - tf.saved_model.save(cell, tempfile.mkdtemp()) - - -if __name__ == "__main__": - app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/use_text_embedding_in_dataset.py b/tensorflow/examples/saved_model/integration_tests/use_text_embedding_in_dataset.py deleted file mode 100644 index be147a86d4c..00000000000 --- a/tensorflow/examples/saved_model/integration_tests/use_text_embedding_in_dataset.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Load and use text embedding module in a Dataset map function.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tempfile - -from absl import app -from absl import flags - -import numpy as np -import tensorflow.compat.v2 as tf - -FLAGS = flags.FLAGS - -flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.") - - -def train(): - """Build a Keras model and train with mock data.""" - module = tf.saved_model.load(FLAGS.model_dir) - def _map_fn(features, labels): - features = tf.expand_dims(features, 0) - features = module(features) - features = tf.squeeze(features, 0) - return features, labels - - features = np.array(["my first sentence", "my second sentence"]) - labels = np.array([1, 0]) - dataset = tf.data.Dataset.from_tensor_slices((features, labels)).map(_map_fn) - - # Create the sequential keras model. - l = tf.keras.layers - model = tf.keras.Sequential() - model.add(l.Dense(10, activation="relu")) - model.add(l.Dense(1, activation="sigmoid")) - - model.compile( - optimizer="adam", - loss="binary_crossentropy", - metrics=["accuracy"]) - - model.fit_generator(generator=dataset.batch(10), epochs=5) - - # This is testing that a model using a SavedModel can be re-exported again, - # e.g. to catch issues such as b/142231881. 
- tf.saved_model.save(model, tempfile.mkdtemp()) - - -def main(argv): - del argv - - train() - - -if __name__ == "__main__": - tf.enable_v2_behavior() - app.run(main) diff --git a/tensorflow/examples/saved_model/integration_tests/use_text_rnn_model.py b/tensorflow/examples/saved_model/integration_tests/use_text_rnn_model.py deleted file mode 100644 index a3c0f230976..00000000000 --- a/tensorflow/examples/saved_model/integration_tests/use_text_rnn_model.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Load and use RNN model stored as a SavedModel.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tempfile - -from absl import app -from absl import flags -import tensorflow.compat.v2 as tf - -FLAGS = flags.FLAGS - -flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.") - - -def main(argv): - del argv - - sentences = [ - " sentence ", " second sentence ", " third sentence" - ] - - model = tf.saved_model.load(FLAGS.model_dir) - model.train(tf.constant(sentences)) - decoded = model.decode_greedy( - sequence_length=10, first_word=tf.constant("")) - _ = [d.numpy() for d in decoded] - - # This is testing that a model using a SavedModel can be re-exported again, - # e.g. to catch issues such as b/142231881. - tf.saved_model.save(model, tempfile.mkdtemp()) - -if __name__ == "__main__": - app.run(main) diff --git a/tensorflow/examples/speech_commands/BUILD b/tensorflow/examples/speech_commands/BUILD index b4b07c368b4..8e41e36296f 100644 --- a/tensorflow/examples/speech_commands/BUILD +++ b/tensorflow/examples/speech_commands/BUILD @@ -7,10 +7,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files([ - "LICENSE", -]) - py_binary( name = "accuracy_utils_py", srcs = ["accuracy_utils.py"], diff --git a/tensorflow/examples/wav_to_spectrogram/BUILD b/tensorflow/examples/wav_to_spectrogram/BUILD index 9aeaf9414ae..1fb74dab01f 100644 --- a/tensorflow/examples/wav_to_spectrogram/BUILD +++ b/tensorflow/examples/wav_to_spectrogram/BUILD @@ -8,8 +8,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - cc_library( name = "wav_to_spectrogram_lib", srcs = ["wav_to_spectrogram.cc"], diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 7a00daf4e26..f60f4daf071 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -38,6 +38,179 @@ func makeOutputList(op *tf.Operation, start int, output string) ([]tf.Output, in return list, start + size, nil } +// Operator that connects the output of an XLA computation to other consumer graph nodes. 
+func XlaClusterOutput(scope *Scope, input tf.Output) (outputs tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "XlaClusterOutput", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// An op used by XLA SPMD partitioner to switch from manual partitioning to +// +// automatic partitioning. It converts the shard-shaped, manually partitioned input +// into full-shaped tensor to be partitioned automatically with the same sharding +// used by manual partitioning. +func XlaSpmdShardToFullShape(scope *Scope, input tf.Output, manual_sharding string, full_shape tf.Shape) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"manual_sharding": manual_sharding, "full_shape": full_shape} + opspec := tf.OpSpec{ + Type: "XlaSpmdShardToFullShape", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Wraps the XLA Sort operator, documented at +// +// https://www.tensorflow.org/performance/xla/operation_semantics#sort +// . +// +// Sorts a tensor. Currently only sorts in ascending order are supported. +// +// Arguments: +// input: A `Tensor` of type T. +// +// Returns A `Tensor` of type T. +func XlaSort(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "XlaSort", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Receives the named tensor from another XLA computation. Wraps the XLA Recv +// +// operator documented at +// https://www.tensorflow.org/performance/xla/operation_semantics#recv . +// +// Arguments: +// dtype: The type of the tensor. +// tensor_name: A string key that identifies the channel. +// shape: The shape of the tensor. +// +// Returns The tensor to receive. +func XlaRecv(scope *Scope, dtype tf.DataType, tensor_name string, shape tf.Shape) (tensor tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype, "tensor_name": tensor_name, "shape": shape} + opspec := tf.OpSpec{ + Type: "XlaRecv", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Wraps the XLA DynamicSlice operator, documented at +// +// https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice +// . +// +// DynamicSlice extracts a sub-array from the input array at dynamic +// start_indices. The size of the slice in each dimension is passed in +// size_indices, which specify the end point of exclusive slice intervals in each +// dimension -- [start, start + size). The shape of start_indices must have rank 1, +// with dimension size equal to the rank of operand. +// +// Arguments: +// input: A `Tensor` of type T. +// start_indices: List of N integers containing the slice size for each +// dimension. Each value must be strictly greater than zero, and start + size +// must be less than or equal to the size of the dimension to avoid +// implementation defined behavior. +// +func XlaDynamicSlice(scope *Scope, input tf.Output, start_indices tf.Output, size_indices tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "XlaDynamicSlice", + Input: []tf.Input{ + input, start_indices, size_indices, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Set a bound for the given input value as a hint to Xla compiler, +// +// returns the same value. 
+func XlaSetBound(scope *Scope, input tf.Output, bound tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "XlaSetBound", + Input: []tf.Input{ + input, bound, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Wraps the XLA DotGeneral operator, documented at +// +// https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral +// . +// +// Arguments: +// lhs: the LHS tensor +// rhs: the RHS tensor +// dimension_numbers: a serialized xla::DotDimensionNumbers proto. +// precision_config: a serialized xla::PrecisionConfig proto. +func XlaDot(scope *Scope, lhs tf.Output, rhs tf.Output, dimension_numbers string, precision_config string) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dimension_numbers": dimension_numbers, "precision_config": precision_config} + opspec := tf.OpSpec{ + Type: "XlaDot", + Input: []tf.Input{ + lhs, rhs, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Output a fact about factorials. +func Fact(scope *Scope) (fact tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Fact", + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // FakeQuantWithMinMaxVarsGradientAttr is an optional argument to FakeQuantWithMinMaxVarsGradient. type FakeQuantWithMinMaxVarsGradientAttr func(optionalAttr) @@ -96,6 +269,43 @@ func FakeQuantWithMinMaxVarsGradient(scope *Scope, gradients tf.Output, inputs t return op.Output(0), op.Output(1), op.Output(2) } +// Computes the eigen decomposition of a batch of self-adjoint matrices +// +// (Note: Only real inputs are supported). +// +// Computes the eigenvalues and eigenvectors of the innermost M-by-N matrices in +// tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]). +// +// Arguments: +// a: the input tensor. +// max_iter: maximum number of sweep update, i.e., the whole lower triangular +// part or upper triangular part based on parameter lower. Heuristically, it has +// been argued that approximately log(min (M, N)) sweeps are needed in practice +// (Ref: Golub & van Loan "Matrix Computation"). +// epsilon: the tolerance ratio. +// precision_config: a serialized xla::PrecisionConfig proto. +// +// Returns: +// s: Singular values. The values are sorted in reverse order of magnitude, so +// s[..., 0] is the largest value, s[..., 1] is the second largest, etc. +// u: Left singular vectors. +// v: Right singular vectors. +func XlaSvd(scope *Scope, a tf.Output, max_iter int64, epsilon float32, precision_config string) (s tf.Output, u tf.Output, v tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"max_iter": max_iter, "epsilon": epsilon, "precision_config": precision_config} + opspec := tf.OpSpec{ + Type: "XlaSvd", + Input: []tf.Input{ + a, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + // FakeQuantWithMinMaxArgsGradientAttr is an optional argument to FakeQuantWithMinMaxArgsGradient. 
type FakeQuantWithMinMaxArgsGradientAttr func(optionalAttr) @@ -158,6 +368,34 @@ func FakeQuantWithMinMaxArgsGradient(scope *Scope, gradients tf.Output, inputs t return op.Output(0) } +// Helper operator for performing XLA-style broadcasts +// +// Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to +// whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules +// for binary operators. +// +// Arguments: +// lhs: the LHS input tensor +// rhs: the RHS input tensor +// broadcast_dims: an XLA-style broadcast dimension specification +// +// Returns: +// lhs_output: the broadcasted LHS tensor +// rhs_output: the broadcasted RHS tensor +func XlaBroadcastHelper(scope *Scope, lhs tf.Output, rhs tf.Output, broadcast_dims tf.Output) (lhs_output tf.Output, rhs_output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "XlaBroadcastHelper", + Input: []tf.Input{ + lhs, rhs, broadcast_dims, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + // Subtracts sparse `updates` from an existing tensor according to `indices`. // // This operation creates a new tensor by subtracting sparse `updates` from the @@ -357,6 +595,40 @@ func QuantizedReshape(scope *Scope, tensor tf.Output, shape tf.Output, input_min return op.Output(0), op.Output(1), op.Output(2) } +// QuantizeAndDequantizeV4GradAttr is an optional argument to QuantizeAndDequantizeV4Grad. +type QuantizeAndDequantizeV4GradAttr func(optionalAttr) + +// QuantizeAndDequantizeV4GradAxis sets the optional axis attribute to value. +// If not specified, defaults to -1 +func QuantizeAndDequantizeV4GradAxis(value int64) QuantizeAndDequantizeV4GradAttr { + return func(m optionalAttr) { + m["axis"] = value + } +} + +// Returns the gradient of `QuantizeAndDequantizeV4`. +// +// Returns a gradient of 1 for inputs that are within the quantization range, +// or 0 otherwise. +func QuantizeAndDequantizeV4Grad(scope *Scope, gradients tf.Output, input tf.Output, input_min tf.Output, input_max tf.Output, optional ...QuantizeAndDequantizeV4GradAttr) (input_backprop tf.Output, input_min_backprop tf.Output, input_max_backprop tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizeAndDequantizeV4Grad", + Input: []tf.Input{ + gradients, input, input_min, input_max, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + // QuantizeAndDequantizeV2Attr is an optional argument to QuantizeAndDequantizeV2. type QuantizeAndDequantizeV2Attr func(optionalAttr) @@ -1438,6 +1710,45 @@ func TensorStridedSliceUpdate(scope *Scope, input tf.Output, begin tf.Output, en return op.Output(0) } +// Computes the eigen decomposition of a batch of self-adjoint matrices +// +// (Note: Only real inputs are supported). +// +// Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in +// tensor such that tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i], for +// i=0...N-1. +// +// Arguments: +// a: the input tensor. +// lower: a boolean specifies whether the calculation is done with the lower +// triangular part or the upper triangular part. +// max_iter: maximum number of sweep update, i.e., the whole lower triangular +// part or upper triangular part based on parameter lower. 
Heuristically, it has +// been argued that approximately logN sweeps are needed in practice (Ref: Golub & +// van Loan "Matrix Computation"). +// epsilon: the tolerance ratio. +// +// Returns: +// w: The eigenvalues in ascending order, each repeated according to its +// multiplicity. +// v: The column v[..., :, i] is the normalized eigenvector corresponding to the +// eigenvalue w[..., i]. +func XlaSelfAdjointEig(scope *Scope, a tf.Output, lower bool, max_iter int64, epsilon float32) (w tf.Output, v tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"lower": lower, "max_iter": max_iter, "epsilon": epsilon} + opspec := tf.OpSpec{ + Type: "XlaSelfAdjointEig", + Input: []tf.Input{ + a, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + // Ensures that the tensor's shape matches the expected shape. // // Raises an error if the input tensor's shape does not match the specified shape. @@ -4343,6 +4654,31 @@ func NearestNeighbors(scope *Scope, points tf.Output, centers tf.Output, k tf.Ou return op.Output(0), op.Output(1) } +// Sends the named tensor to another XLA computation. Wraps the XLA Send operator +// +// documented at +// https://www.tensorflow.org/performance/xla/operation_semantics#send . +// +// Arguments: +// tensor: The tensor to send. +// tensor_name: A string key that identifies the channel. +// +// Returns the created operation. +func XlaSend(scope *Scope, tensor tf.Output, tensor_name string) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"tensor_name": tensor_name} + opspec := tf.OpSpec{ + Type: "XlaSend", + Input: []tf.Input{ + tensor, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + // Returns the index of a data point that should be added to the seed set. // // Entries in distances are assumed to be squared distances of candidate points to @@ -7528,6 +7864,33 @@ func ResourceAccumulatorApplyGradient(scope *Scope, handle tf.Output, local_step return scope.AddOperation(opspec) } +// Wraps the XLA Pad operator, documented at +// +// https://www.tensorflow.org/performance/xla/operation_semantics#pad +// . +// +// Arguments: +// input: A `Tensor` of type T. +// padding_value: A scalar `Tensor` of type T. +// padding_low: the padding to apply at the start of each input dimensions +// padding_high: the padding to apply at the end of each input dimension. +// padding_interior: the padding to apply between each input element. +// +// Returns A `Tensor` of type T. +func XlaPad(scope *Scope, input tf.Output, padding_value tf.Output, padding_low tf.Output, padding_high tf.Output, padding_interior tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "XlaPad", + Input: []tf.Input{ + input, padding_value, padding_low, padding_high, padding_interior, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Updates the accumulator with a new value for global_step. // // Logs warning if the accumulator's value is already higher than @@ -8015,6 +8378,80 @@ func MultiDeviceIteratorToStringHandle(scope *Scope, multi_device_iterator tf.Ou return op.Output(0) } +// QuantizeAndDequantizeV4Attr is an optional argument to QuantizeAndDequantizeV4. +type QuantizeAndDequantizeV4Attr func(optionalAttr) + +// QuantizeAndDequantizeV4SignedInput sets the optional signed_input attribute to value. 
+// If not specified, defaults to true +func QuantizeAndDequantizeV4SignedInput(value bool) QuantizeAndDequantizeV4Attr { + return func(m optionalAttr) { + m["signed_input"] = value + } +} + +// QuantizeAndDequantizeV4NumBits sets the optional num_bits attribute to value. +// If not specified, defaults to 8 +func QuantizeAndDequantizeV4NumBits(value int64) QuantizeAndDequantizeV4Attr { + return func(m optionalAttr) { + m["num_bits"] = value + } +} + +// QuantizeAndDequantizeV4RangeGiven sets the optional range_given attribute to value. +// If not specified, defaults to false +func QuantizeAndDequantizeV4RangeGiven(value bool) QuantizeAndDequantizeV4Attr { + return func(m optionalAttr) { + m["range_given"] = value + } +} + +// QuantizeAndDequantizeV4RoundMode sets the optional round_mode attribute to value. +// If not specified, defaults to "HALF_TO_EVEN" +func QuantizeAndDequantizeV4RoundMode(value string) QuantizeAndDequantizeV4Attr { + return func(m optionalAttr) { + m["round_mode"] = value + } +} + +// QuantizeAndDequantizeV4NarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func QuantizeAndDequantizeV4NarrowRange(value bool) QuantizeAndDequantizeV4Attr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + +// QuantizeAndDequantizeV4Axis sets the optional axis attribute to value. +// If not specified, defaults to -1 +func QuantizeAndDequantizeV4Axis(value int64) QuantizeAndDequantizeV4Attr { + return func(m optionalAttr) { + m["axis"] = value + } +} + +// Returns the gradient of `QuantizeAndDequantizeV4`. +// +// This is almost identical to QuantizeAndDequantizeV2, except that it returns a +// gradient of 1 for inputs that are within the quantization range, or 0 otherwise. +func QuantizeAndDequantizeV4(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, optional ...QuantizeAndDequantizeV4Attr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizeAndDequantizeV4", + Input: []tf.Input{ + input, input_min, input_max, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Gets next element for the provided shard number. // // Arguments: @@ -8137,6 +8574,39 @@ func BoostedTreesCalculateBestFeatureSplit(scope *Scope, node_id_range tf.Output return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4), op.Output(5), op.Output(6) } +// Wraps the XLA DynamicUpdateSlice operator, documented at +// +// https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice +// . +// +// XlaDynamicUpdateSlice generates a result which is the value of the `input` +// operand, with a slice update overwritten at `indices`. The shape of `update` +// determines the shape of the sub-array of the result which is updated. The shape +// of indices must be rank == 1, with dimension size equal to the rank of `input`. +// +// Handling of out-of-bounds slice indices is implementation-defined. +// +// Arguments: +// input: A `Tensor` of type T. +// update: A `Tensor` of type T. Same rank as `input`. +// indices: A vector of indices into `input`. Must have length equal to the rank of +// `input`. +// +// Returns A `Tensor` of type T. 
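+//
+// For example, a sketch using these Go bindings (the concrete values are
+// illustrative only):
+//
+// ```
+// s := op.NewScope()
+// input := op.Const(s.SubScope("input"), []int32{0, 0, 0, 0, 0, 0, 0, 0})
+// update := op.Const(s.SubScope("update"), []int32{9, 10, 11})
+// indices := op.Const(s.SubScope("indices"), []int32{2})
+// out := op.XlaDynamicUpdateSlice(s, input, update, indices)
+// // out evaluates to [0 0 9 10 11 0 0 0]: `update` overwrites `input` starting at index 2.
+// ```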
+func XlaDynamicUpdateSlice(scope *Scope, input tf.Output, update tf.Output, indices tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "XlaDynamicUpdateSlice", + Input: []tf.Input{ + input, update, indices, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ModelDatasetAttr is an optional argument to ModelDataset. type ModelDatasetAttr func(optionalAttr) @@ -13934,61 +14404,6 @@ func ExperimentalRebatchDataset(scope *Scope, input_dataset tf.Output, num_repli return op.Output(0) } -// TextLineReaderV2Attr is an optional argument to TextLineReaderV2. -type TextLineReaderV2Attr func(optionalAttr) - -// TextLineReaderV2SkipHeaderLines sets the optional skip_header_lines attribute to value. -// -// value: Number of lines to skip from the beginning of every file. -// If not specified, defaults to 0 -func TextLineReaderV2SkipHeaderLines(value int64) TextLineReaderV2Attr { - return func(m optionalAttr) { - m["skip_header_lines"] = value - } -} - -// TextLineReaderV2Container sets the optional container attribute to value. -// -// value: If non-empty, this reader is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func TextLineReaderV2Container(value string) TextLineReaderV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// TextLineReaderV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this reader is named in the given bucket -// with this shared_name. Otherwise, the node name is used instead. -// If not specified, defaults to "" -func TextLineReaderV2SharedName(value string) TextLineReaderV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A Reader that outputs the lines of a file delimited by '\n'. -// -// Returns The handle to reference the Reader. -func TextLineReaderV2(scope *Scope, optional ...TextLineReaderV2Attr) (reader_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TextLineReaderV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Saves the input tensors to disk. // // The size of `tensor_names` must match the number of tensors in `data`. `data[i]` @@ -14757,6 +15172,10 @@ func QrFullMatrices(value bool) QrAttr { // Computes the QR decomposition of each inner matrix in `tensor` such that // `tensor[..., :, :] = q[..., :, :] * r[..., :,:])` // +// Currently, the gradient for the QR decomposition is well-defined only when +// the first `P` columns of the inner matrix are linearly independent, where +// `P` is the minimum of `M` and `N`, the 2 inner-most dimmensions of `tensor`. +// // ```python // # a is a tensor. // # q is a tensor of orthonormal matrices. @@ -15831,6 +16250,62 @@ func TensorSummary(scope *Scope, tensor tf.Output, optional ...TensorSummaryAttr return op.Output(0) } +// Outputs a `Summary` protocol buffer with a histogram. +// +// The generated +// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) +// has one summary value containing a histogram for `values`. +// +// This op reports an `InvalidArgument` error if any value is not finite. +// +// Arguments: +// tag: Scalar. Tag to use for the `Summary.Value`. +// values: Any shape. Values to use to build the histogram. +// +// Returns Scalar. Serialized `Summary` protocol buffer. 
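+//
+// A minimal usage sketch with these Go bindings (the tag and values below are
+// illustrative assumptions):
+//
+// ```
+// s := op.NewScope()
+// tag := op.Const(s.SubScope("tag"), "activations")
+// values := op.Const(s.SubScope("values"), []float32{0.5, 1.5, 2.5, 2.5})
+// summary := op.HistogramSummary(s, tag, values)  // scalar string holding the serialized Summary proto
+// ```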
+func HistogramSummary(scope *Scope, tag tf.Output, values tf.Output) (summary tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "HistogramSummary", + Input: []tf.Input{ + tag, values, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Merges summaries. +// +// This op creates a +// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) +// protocol buffer that contains the union of all the values in the input +// summaries. +// +// When the Op is run, it reports an `InvalidArgument` error if multiple values +// in the summaries to merge use the same tag. +// +// Arguments: +// inputs: Can be of any shape. Each must contain serialized `Summary` protocol +// buffers. +// +// Returns Scalar. Serialized `Summary` protocol buffer. +func MergeSummary(scope *Scope, inputs []tf.Output) (summary tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MergeSummary", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Computes the gradient for the sqrt of `x` wrt its input. // // Specifically, `grad = dy * 0.5 / y`, where `y = sqrt(x)`, and `dy` @@ -18807,6 +19282,14 @@ func CollectiveReduceV2CommunicationHint(value string) CollectiveReduceV2Attr { } } +// CollectiveReduceV2TimeoutSeconds sets the optional timeout_seconds attribute to value. +// If not specified, defaults to 0 +func CollectiveReduceV2TimeoutSeconds(value float32) CollectiveReduceV2Attr { + return func(m optionalAttr) { + m["timeout_seconds"] = value + } +} + // Mutually reduces multiple tensors of identical type and shape. func CollectiveReduceV2(scope *Scope, input tf.Output, group_size tf.Output, group_key tf.Output, instance_key tf.Output, merge_op string, final_op string, optional ...CollectiveReduceV2Attr) (data tf.Output) { if scope.Err() != nil { @@ -20963,73 +21446,36 @@ func DivNoNan(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { // scattered onto an existing tensor (as opposed to a zero-tensor). If the memory // for the existing tensor cannot be re-used, a copy is made and updated. // -// If `indices` contains duplicates, then their updates are accumulated (summed). +// If `indices` contains duplicates, then we pick the last update for the index. // -// **WARNING**: The order in which updates are applied is nondeterministic, so the -// output will be nondeterministic if `indices` contains duplicates -- because -// of some numerical approximation issues, numbers summed in different order -// may yield different results. +// If an out of bound index is found on CPU, an error is returned. +// +// **WARNING**: There are some GPU specific semantics for this operation. +// - If an out of bound index is found, the index is ignored. +// - The order in which updates are applied is nondeterministic, so the output +// will be nondeterministic if `indices` contains duplicates. // // `indices` is an integer tensor containing indices into a new tensor of shape -// `shape`. The last dimension of `indices` can be at most the rank of `shape`: +// `shape`. // -// indices.shape[-1] <= shape.rank +// * `indices` must have at least 2 axes: `(num_updates, index_depth)`. 
+// * The last axis of `indices` is how deep to index into `tensor` so this index +// depth must be less than the rank of `tensor`: `indices.shape[-1] <= tensor.ndim` // -// The last dimension of `indices` corresponds to indices into elements -// (if `indices.shape[-1] = shape.rank`) or slices -// (if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of -// `shape`. `updates` is a tensor with shape +// if `indices.shape[-1] = tensor.rank` this Op indexes and updates scalar elements. +// if `indices.shape[-1] < tensor.rank` it indexes and updates slices of the input +// `tensor`. // -// indices.shape[:-1] + shape[indices.shape[-1]:] +// Each `update` has a rank of `tensor.rank - indices.shape[-1]`. +// The overall shape of `updates` is: // -// The simplest form of scatter is to insert individual elements in a tensor by -// index. For example, say we want to insert 4 scattered elements in a rank-1 -// tensor with 8 elements. +// ``` +// indices.shape[:-1] + tensor.shape[indices.shape[-1]:] +// ``` // -//
-// -//
+// For usage examples see the python [tf.tensor_scatter_nd_update]( +// https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) function // -// In Python, this scatter operation would look like this: -// -// >>> indices = tf.constant([[4], [3], [1], [7]]) -// >>> updates = tf.constant([9, 10, 11, 12]) -// >>> tensor = tf.ones([8], dtype=tf.int32) -// >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates)) -// tf.Tensor([ 1 11 1 10 9 1 1 12], shape=(8,), dtype=int32) -// -// We can also, insert entire slices of a higher rank tensor all at once. For -// example, if we wanted to insert two slices in the first dimension of a -// rank-3 tensor with two matrices of new values. -// -// In Python, this scatter operation would look like this: -// -// >>> indices = tf.constant([[0], [2]]) -// >>> updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6], -// ... [7, 7, 7, 7], [8, 8, 8, 8]], -// ... [[5, 5, 5, 5], [6, 6, 6, 6], -// ... [7, 7, 7, 7], [8, 8, 8, 8]]]) -// >>> tensor = tf.ones([4, 4, 4], dtype=tf.int32) -// >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()) -// [[[5 5 5 5] -// [6 6 6 6] -// [7 7 7 7] -// [8 8 8 8]] -// [[1 1 1 1] -// [1 1 1 1] -// [1 1 1 1] -// [1 1 1 1]] -// [[5 5 5 5] -// [6 6 6 6] -// [7 7 7 7] -// [8 8 8 8]] -// [[1 1 1 1] -// [1 1 1 1] -// [1 1 1 1] -// [1 1 1 1]]] -// -// Note that on CPU, if an out of bound index is found, an error is returned. -// On GPU, if an out of bound index is found, the index is ignored. // // Arguments: // tensor: Tensor to copy/update. @@ -21825,6 +22271,187 @@ func Sin(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } +// ResourceSparseApplyAdadeltaAttr is an optional argument to ResourceSparseApplyAdadelta. +type ResourceSparseApplyAdadeltaAttr func(optionalAttr) + +// ResourceSparseApplyAdadeltaUseLocking sets the optional use_locking attribute to value. +// +// value: If True, updating of the var and accum tensors will be protected by +// a lock; otherwise the behavior is undefined, but may exhibit less contention. +// If not specified, defaults to false +func ResourceSparseApplyAdadeltaUseLocking(value bool) ResourceSparseApplyAdadeltaAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// var: Should be from a Variable(). +// +// Arguments: +// +// accum: Should be from a Variable(). +// accum_update: : Should be from a Variable(). +// lr: Learning rate. Must be a scalar. +// rho: Decay factor. Must be a scalar. +// epsilon: Constant factor. Must be a scalar. +// grad: The gradient. +// indices: A vector of indices into the first dimension of var and accum. +// +// Returns the created operation. +func ResourceSparseApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdadeltaAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceSparseApplyAdadelta", + Input: []tf.Input{ + var_, accum, accum_update, lr, rho, epsilon, grad, indices, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingADAMParametersGradAccumDebug. 
+type RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr func(optionalAttr) + +// RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableId sets the optional table_id attribute to value. +// If not specified, defaults to -1 +func RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// RetrieveTPUEmbeddingADAMParametersGradAccumDebugConfig sets the optional config attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingADAMParametersGradAccumDebugConfig(value string) RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["config"] = value + } +} + +// Retrieve ADAM embedding parameters with debug support. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns: +// parameters: Parameter parameters updated by the ADAM optimization algorithm. +// momenta: Parameter momenta updated by the ADAM optimization algorithm. +// velocities: Parameter velocities updated by the ADAM optimization algorithm. +// gradient_accumulators: Parameter gradient_accumulators updated by the ADAM optimization algorithm. +func RetrieveTPUEmbeddingADAMParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr) (parameters tf.Output, momenta tf.Output, velocities tf.Output, gradient_accumulators tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingADAMParametersGradAccumDebug", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2), op.Output(3) +} + +// ResourceApplyAdamAttr is an optional argument to ResourceApplyAdam. +type ResourceApplyAdamAttr func(optionalAttr) + +// ResourceApplyAdamUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var, m, and v tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyAdamUseLocking(value bool) ResourceApplyAdamAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// ResourceApplyAdamUseNesterov sets the optional use_nesterov attribute to value. +// +// value: If `True`, uses the nesterov update. +// If not specified, defaults to false +func ResourceApplyAdamUseNesterov(value bool) ResourceApplyAdamAttr { + return func(m optionalAttr) { + m["use_nesterov"] = value + } +} + +// Update '*var' according to the Adam algorithm. 
+// +// $$\text{lr}_t := \mathrm{learning_rate} * \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$ +// $$m_t := \beta_1 * m_{t-1} + (1 - \beta_1) * g$$ +// $$v_t := \beta_2 * v_{t-1} + (1 - \beta_2) * g * g$$ +// $$\text{variable} := \text{variable} - \text{lr}_t * m_t / (\sqrt{v_t} + \epsilon)$$ +// +// Arguments: +// var_: Should be from a Variable(). +// m: Should be from a Variable(). +// v: Should be from a Variable(). +// beta1_power: Must be a scalar. +// beta2_power: Must be a scalar. +// lr: Scaling factor. Must be a scalar. +// beta1: Momentum factor. Must be a scalar. +// beta2: Momentum factor. Must be a scalar. +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyAdam(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, beta2_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdamAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyAdam", + Input: []tf.Input{ + var_, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// Computes sigmoid of `x` element-wise. +// +// Specifically, `y = 1 / (1 + exp(-x))`. +func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Sigmoid", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // PrintAttr is an optional argument to Print. type PrintAttr func(optionalAttr) @@ -24704,6 +25331,37 @@ func FractionalMaxPoolGrad(scope *Scope, orig_input tf.Output, orig_output tf.Ou return op.Output(0) } +// Wraps the XLA ConvGeneralDilated operator, documented at +// +// https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution +// . +// +// Arguments: +// lhs: the input tensor +// rhs: the kernel tensor +// window_strides: the inter-window strides +// padding: the padding to apply at the start and end of each input dimensions +// lhs_dilation: dilation to apply between input elements +// rhs_dilation: dilation to apply between kernel elements +// feature_group_count: number of feature groups for grouped convolution. +// dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto. +// precision_config: a serialized xla::PrecisionConfig proto. +func XlaConv(scope *Scope, lhs tf.Output, rhs tf.Output, window_strides tf.Output, padding tf.Output, lhs_dilation tf.Output, rhs_dilation tf.Output, feature_group_count tf.Output, dimension_numbers string, precision_config string) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dimension_numbers": dimension_numbers, "precision_config": precision_config} + opspec := tf.OpSpec{ + Type: "XlaConv", + Input: []tf.Input{ + lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, feature_group_count, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // NthElementAttr is an optional argument to NthElement. type NthElementAttr func(optionalAttr) @@ -25552,6 +26210,37 @@ func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, return op.Output(0) } +// Takes the packed uint32 input and unpacks the input to uint8 to do +// +// Dequantization on device. 
+// +// Arguments: +// input: Input tensors whose types is uint32, shape is [d0, ..., dn]. +// min_range: The minimum scalar value possibly produced for the input. +// max_range: The maximum scalar value possibly produced for the input. +// mode: String to determine the dequantize mode in {"MIN_COMBINED", "MIN_FIRST", "SCALED"}. +// transpose_output: Boolean to determine if output is transposed. transpose_output +// is faster when input is large and rank of input is higher than 1. +// +// Returns Output tensors whose types is bloat16. If transpose_output is true, +// output shape is [dn * 4, dn-1, ..., d1, d0]. If transpose_output +// is false, output shape is [d0,..., dn * 4]. +func XlaDequantize(scope *Scope, input tf.Output, min_range float32, max_range float32, mode string, transpose_output bool) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"min_range": min_range, "max_range": max_range, "mode": mode, "transpose_output": transpose_output} + opspec := tf.OpSpec{ + Type: "XlaDequantize", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // MaxPoolGradV2Attr is an optional argument to MaxPoolGradV2. type MaxPoolGradV2Attr func(optionalAttr) @@ -27393,6 +28082,28 @@ func FusedBatchNormGrad(scope *Scope, y_backprop tf.Output, x tf.Output, scale t return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4) } +// An op used by XLA SPMD partitioner to switch from automatic partitioning to +// +// manual partitioning. It annotates the input (full-shape, to be automatically +// partitioned) with the same sharding used by manual partitioning, and outputs a +// shard-shaped tensor to be consumed by later manually-partitioned ops. If the +// shape is not evenly partitionable, the padding region will be masked with 0s. +func XlaSpmdFullToShardShape(scope *Scope, input tf.Output, manual_sharding string) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"manual_sharding": manual_sharding} + opspec := tf.OpSpec{ + Type: "XlaSpmdFullToShardShape", + Input: []tf.Input{ + input, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // DecodeCSVAttr is an optional argument to DecodeCSV. type DecodeCSVAttr func(optionalAttr) @@ -27970,6 +28681,92 @@ func DecodePaddedRaw(scope *Scope, input_bytes tf.Output, fixed_length tf.Output return op.Output(0) } +// Elementwise computes the bitwise left-shift of `x` and `y`. +// +// If `y` is negative, or greater than or equal to the width of `x` in bits the +// result is implementation defined. 
+// +// Example: +// +// ```python +// import tensorflow as tf +// from tensorflow.python.ops import bitwise_ops +// import numpy as np +// dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64] +// +// for dtype in dtype_list: +// lhs = tf.constant([-1, -5, -3, -14], dtype=dtype) +// rhs = tf.constant([5, 0, 7, 11], dtype=dtype) +// +// left_shift_result = bitwise_ops.left_shift(lhs, rhs) +// +// print(left_shift_result) +// +// # This will print: +// # tf.Tensor([ -32 -5 -128 0], shape=(4,), dtype=int8) +// # tf.Tensor([ -32 -5 -384 -28672], shape=(4,), dtype=int16) +// # tf.Tensor([ -32 -5 -384 -28672], shape=(4,), dtype=int32) +// # tf.Tensor([ -32 -5 -384 -28672], shape=(4,), dtype=int64) +// +// lhs = np.array([-2, 64, 101, 32], dtype=np.int8) +// rhs = np.array([-1, -5, -3, -14], dtype=np.int8) +// bitwise_ops.left_shift(lhs, rhs) +// # +// ``` +// +func LeftShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "LeftShift", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Generates a feature cross from a list of tensors, and returns it as a +// RaggedTensor. See `tf.ragged.cross` for more details. +// +// Arguments: +// ragged_values: The values tensor for each RaggedTensor input. +// ragged_row_splits: The row_splits tensor for each RaggedTensor input. +// sparse_indices: The indices tensor for each SparseTensor input. +// sparse_values: The values tensor for each SparseTensor input. +// sparse_shape: The dense_shape tensor for each SparseTensor input. +// dense_inputs: The tf.Tensor inputs. +// input_order: String specifying the tensor type for each input. The `i`th character in +// this string specifies the type of the `i`th input, and is one of: 'R' (ragged), +// 'D' (dense), or 'S' (sparse). This attr is used to ensure that the crossed +// values are combined in the order of the inputs from the call to tf.ragged.cross. +// +// +// +// +// +// +// Returns: +// output_values: The `values` for the returned `RaggedTensor`. +// output_row_splits: The `row_splits` for the returned `RaggedTensor`. +func RaggedCross(scope *Scope, ragged_values []tf.Output, ragged_row_splits []tf.Output, sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shape []tf.Output, dense_inputs []tf.Output, input_order string, hashed_output bool, num_buckets int64, hash_key int64, out_values_type tf.DataType, out_row_splits_type tf.DataType) (output_values tf.Output, output_row_splits tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"input_order": input_order, "hashed_output": hashed_output, "num_buckets": num_buckets, "hash_key": hash_key, "out_values_type": out_values_type, "out_row_splits_type": out_row_splits_type} + opspec := tf.OpSpec{ + Type: "RaggedCross", + Input: []tf.Input{ + tf.OutputList(ragged_values), tf.OutputList(ragged_row_splits), tf.OutputList(sparse_indices), tf.OutputList(sparse_values), tf.OutputList(sparse_shape), tf.OutputList(dense_inputs), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + // BatchMatMulAttr is an optional argument to BatchMatMul. type BatchMatMulAttr func(optionalAttr) @@ -28851,82 +29648,6 @@ func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...Ra return op.Output(0) } -// Computes the derivative of a Gamma random sample w.r.t. `alpha`. 
-func RandomGammaGrad(scope *Scope, alpha tf.Output, sample tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RandomGammaGrad", - Input: []tf.Input{ - alpha, sample, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RandomShuffleAttr is an optional argument to RandomShuffle. -type RandomShuffleAttr func(optionalAttr) - -// RandomShuffleSeed sets the optional seed attribute to value. -// -// value: If either `seed` or `seed2` are set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, it is seeded by a -// random seed. -// If not specified, defaults to 0 -func RandomShuffleSeed(value int64) RandomShuffleAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomShuffleSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomShuffleSeed2(value int64) RandomShuffleAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// Randomly shuffles a tensor along its first dimension. -// -// The tensor is shuffled along dimension 0, such that each `value[j]` is mapped -// to one and only one `output[i]`. For example, a mapping that might occur for a -// 3x2 tensor is: -// -// ``` -// [[1, 2], [[5, 6], -// [3, 4], ==> [1, 2], -// [5, 6]] [3, 4]] -// ``` -// -// Arguments: -// value: The tensor to be shuffled. -// -// Returns A tensor of same shape and type as `value`, shuffled along its first -// dimension. -func RandomShuffle(scope *Scope, value tf.Output, optional ...RandomShuffleAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RandomShuffle", - Input: []tf.Input{ - value, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Creates a dataset that takes a Bernoulli sample of the contents of another dataset. // // There is no transformation in the `tf.data` Python API for creating this dataset. @@ -29391,6 +30112,45 @@ func BlockLSTMV2(scope *Scope, seq_len_max tf.Output, x tf.Output, cs_prev tf.Ou return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4), op.Output(5), op.Output(6) } +// Return a tensor with the same shape and contents as the input tensor or value. +func Identity(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Identity", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Outputs a `Summary` protocol buffer with scalar values. +// +// The input `tags` and `values` must have the same shape. The generated summary +// has a summary value for each tag-value pair in `tags` and `values`. +// +// Arguments: +// tags: Tags for the summary. +// values: Same shape as `tags. Values for the summary. +// +// Returns Scalar. Serialized `Summary` protocol buffer. +func ScalarSummary(scope *Scope, tags tf.Output, values tf.Output) (summary tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ScalarSummary", + Input: []tf.Input{ + tags, values, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ResourceSparseApplyProximalAdagradAttr is an optional argument to ResourceSparseApplyProximalAdagrad. 
type ResourceSparseApplyProximalAdagradAttr func(optionalAttr) @@ -29458,45 +30218,6 @@ func Neg(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } -// Return a tensor with the same shape and contents as the input tensor or value. -func Identity(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Identity", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Outputs a `Summary` protocol buffer with scalar values. -// -// The input `tags` and `values` must have the same shape. The generated summary -// has a summary value for each tag-value pair in `tags` and `values`. -// -// Arguments: -// tags: Tags for the summary. -// values: Same shape as `tags. Values for the summary. -// -// Returns Scalar. Serialized `Summary` protocol buffer. -func ScalarSummary(scope *Scope, tags tf.Output, values tf.Output) (summary tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ScalarSummary", - Input: []tf.Input{ - tags, values, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Concatenates tensors along one dimension. // // Arguments: @@ -29663,33 +30384,6 @@ func WriteAudioSummary(scope *Scope, writer tf.Output, step tf.Output, tag tf.Ou return scope.AddOperation(opspec) } -// Outputs a `Summary` protocol buffer with a histogram. -// -// The generated -// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) -// has one summary value containing a histogram for `values`. -// -// This op reports an `InvalidArgument` error if any value is not finite. -// -// Arguments: -// tag: Scalar. Tag to use for the `Summary.Value`. -// values: Any shape. Values to use to build the histogram. -// -// Returns Scalar. Serialized `Summary` protocol buffer. -func HistogramSummary(scope *Scope, tag tf.Output, values tf.Output) (summary tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "HistogramSummary", - Input: []tf.Input{ - tag, values, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Performs gradient updates of embedding tables. // // Arguments: @@ -29922,6 +30616,61 @@ func GRUBlockCellGrad(scope *Scope, x tf.Output, h_prev tf.Output, w_ru tf.Outpu return op.Output(0), op.Output(1), op.Output(2), op.Output(3) } +// TextLineReaderV2Attr is an optional argument to TextLineReaderV2. +type TextLineReaderV2Attr func(optionalAttr) + +// TextLineReaderV2SkipHeaderLines sets the optional skip_header_lines attribute to value. +// +// value: Number of lines to skip from the beginning of every file. +// If not specified, defaults to 0 +func TextLineReaderV2SkipHeaderLines(value int64) TextLineReaderV2Attr { + return func(m optionalAttr) { + m["skip_header_lines"] = value + } +} + +// TextLineReaderV2Container sets the optional container attribute to value. +// +// value: If non-empty, this reader is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func TextLineReaderV2Container(value string) TextLineReaderV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// TextLineReaderV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this reader is named in the given bucket +// with this shared_name. Otherwise, the node name is used instead. 
+// If not specified, defaults to "" +func TextLineReaderV2SharedName(value string) TextLineReaderV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// A Reader that outputs the lines of a file delimited by '\n'. +// +// Returns The handle to reference the Reader. +func TextLineReaderV2(scope *Scope, optional ...TextLineReaderV2Attr) (reader_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TextLineReaderV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Encode audio data using the WAV file format. // // This operation will generate a string suitable to be saved out to create a .wav @@ -30261,6 +31010,32 @@ func StatefulUniform(scope *Scope, resource tf.Output, algorithm tf.Output, shap return op.Output(0) } +// Wraps the XLA Gather operator documented at +// +// https://www.tensorflow.org/xla/operation_semantics#gather +// +// Arguments: +// operand: The array we're gathering from. +// start_indices: Array containing the starting indices of the slices we gather. +// slice_sizes: slice_sizes[i] is the bounds for the slice on dimension i. +// dimension_numbers: A serialized xla::GatherDimensionNumbers proto. +// indices_are_sorted: Boolean indicating if the indices are sorted. +func XlaGather(scope *Scope, operand tf.Output, start_indices tf.Output, slice_sizes tf.Output, dimension_numbers string, indices_are_sorted bool) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dimension_numbers": dimension_numbers, "indices_are_sorted": indices_are_sorted} + opspec := tf.OpSpec{ + Type: "XlaGather", + Input: []tf.Input{ + operand, start_indices, slice_sizes, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // QuantizedConv2DAttr is an optional argument to QuantizedConv2D. type QuantizedConv2DAttr func(optionalAttr) @@ -31552,6 +32327,34 @@ func AssignVariableOp(scope *Scope, resource tf.Output, value tf.Output) (o *tf. return scope.AddOperation(opspec) } +// Wraps the XLA Sort operator, documented at +// +// https://www.tensorflow.org/performance/xla/operation_semantics#sort +// . +// +// Sorts a tensor. Currently only sorts in ascending order are supported. +// +// Arguments: +// keys: A `Tensor` of type K. +// values: A `Tensor` of type V. +// +// Returns: +// sorted_keys: A `Tensor` of type K. +// sorted_values: A `Tensor` of type V. +func XlaKeyValueSort(scope *Scope, keys tf.Output, values tf.Output) (sorted_keys tf.Output, sorted_values tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "XlaKeyValueSort", + Input: []tf.Input{ + keys, values, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + // Asserts that compilation succeeded. This op produces no output and closes the // // device during failure to ensure all pending device interactions fail. @@ -31691,6 +32494,18 @@ func VarHandleOp(scope *Scope, dtype tf.DataType, shape tf.Shape, optional ...Va return op.Output(0) } +// Replica ID. +func XlaReplicaId(scope *Scope) (id tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "XlaReplicaId", + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Returns conj(x - y)(x - y) element-wise. // // *NOTE*: `SquaredDifference` supports broadcasting. 
More about broadcasting @@ -31888,6 +32703,21 @@ func BoostedTreesQuantileStreamResourceHandleOp(scope *Scope, optional ...Booste return op.Output(0) } +// An op which shards the input based on the given sharding attribute. +func XlaSharding(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "XlaSharding", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // EagerPyFuncAttr is an optional argument to EagerPyFunc. type EagerPyFuncAttr func(optionalAttr) @@ -34789,6 +35619,307 @@ func SparseCrossV2(scope *Scope, indices []tf.Output, values []tf.Output, shapes return op.Output(0), op.Output(1), op.Output(2) } +// Pads a tensor with mirrored values. +// +// This operation pads a `input` with mirrored values according to the `paddings` +// you specify. `paddings` is an integer tensor with shape `[n, 2]`, where n is +// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates +// how many values to add before the contents of `input` in that dimension, and +// `paddings[D, 1]` indicates how many values to add after the contents of `input` +// in that dimension. Both `paddings[D, 0]` and `paddings[D, 1]` must be no greater +// than `input.dim_size(D)` (or `input.dim_size(D) - 1`) if `copy_border` is true +// (if false, respectively). +// +// The padded size of each dimension D of the output is: +// +// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` +// +// For example: +// +// ``` +// # 't' is [[1, 2, 3], [4, 5, 6]]. +// # 'paddings' is [[1, 1]], [2, 2]]. +// # 'mode' is SYMMETRIC. +// # rank of 't' is 2. +// pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2] +// [2, 1, 1, 2, 3, 3, 2] +// [5, 4, 4, 5, 6, 6, 5] +// [5, 4, 4, 5, 6, 6, 5]] +// ``` +// +// Arguments: +// input: The input tensor to be padded. +// paddings: A two-column matrix specifying the padding sizes. The number of +// rows must be the same as the rank of `input`. +// mode: Either `REFLECT` or `SYMMETRIC`. In reflect mode the padded regions +// do not include the borders, while in symmetric mode the padded regions +// do include the borders. For example, if `input` is `[1, 2, 3]` and `paddings` +// is `[0, 2]`, then the output is `[1, 2, 3, 2, 1]` in reflect mode, and +// it is `[1, 2, 3, 3, 2]` in symmetric mode. +// +// Returns The padded tensor. +func MirrorPad(scope *Scope, input tf.Output, paddings tf.Output, mode string) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"mode": mode} + opspec := tf.OpSpec{ + Type: "MirrorPad", + Input: []tf.Input{ + input, paddings, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// TensorArrayV3Attr is an optional argument to TensorArrayV3. +type TensorArrayV3Attr func(optionalAttr) + +// TensorArrayV3ElementShape sets the optional element_shape attribute to value. +// +// value: The expected shape of an element, if known. Used to +// validate the shapes of TensorArray elements. If this shape is not +// fully specified, gathering zero-size TensorArrays is an error. +// If not specified, defaults to +func TensorArrayV3ElementShape(value tf.Shape) TensorArrayV3Attr { + return func(m optionalAttr) { + m["element_shape"] = value + } +} + +// TensorArrayV3DynamicSize sets the optional dynamic_size attribute to value. +// +// value: A boolean that determines whether writes to the TensorArray +// are allowed to grow the size. By default, this is not allowed. 
+// If not specified, defaults to false +func TensorArrayV3DynamicSize(value bool) TensorArrayV3Attr { + return func(m optionalAttr) { + m["dynamic_size"] = value + } +} + +// TensorArrayV3ClearAfterRead sets the optional clear_after_read attribute to value. +// +// value: If true (default), Tensors in the TensorArray are cleared +// after being read. This disables multiple read semantics but allows early +// release of memory. +// If not specified, defaults to true +func TensorArrayV3ClearAfterRead(value bool) TensorArrayV3Attr { + return func(m optionalAttr) { + m["clear_after_read"] = value + } +} + +// TensorArrayV3IdenticalElementShapes sets the optional identical_element_shapes attribute to value. +// +// value: If true (default is false), then all +// elements in the TensorArray will be expected to have have identical shapes. +// This allows certain behaviors, like dynamically checking for +// consistent shapes on write, and being able to fill in properly +// shaped zero tensors on stack -- even if the element_shape attribute +// is not fully defined. +// If not specified, defaults to false +func TensorArrayV3IdenticalElementShapes(value bool) TensorArrayV3Attr { + return func(m optionalAttr) { + m["identical_element_shapes"] = value + } +} + +// TensorArrayV3TensorArrayName sets the optional tensor_array_name attribute to value. +// +// value: Overrides the name used for the temporary tensor_array +// resource. Default value is the name of the 'TensorArray' op (which +// is guaranteed unique). +// If not specified, defaults to "" +func TensorArrayV3TensorArrayName(value string) TensorArrayV3Attr { + return func(m optionalAttr) { + m["tensor_array_name"] = value + } +} + +// An array of Tensors of given size. +// +// Write data via Write and read via Read or Pack. +// +// Arguments: +// size: The size of the array. +// dtype: The type of the elements on the tensor_array. +// +// Returns: +// handle: The handle to the TensorArray. +// flow: A scalar used to control gradient flow. +func TensorArrayV3(scope *Scope, size tf.Output, dtype tf.DataType, optional ...TensorArrayV3Attr) (handle tf.Output, flow tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "TensorArrayV3", + Input: []tf.Input{ + size, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) +} + +// MatrixSolveLsAttr is an optional argument to MatrixSolveLs. +type MatrixSolveLsAttr func(optionalAttr) + +// MatrixSolveLsFast sets the optional fast attribute to value. +// If not specified, defaults to true +func MatrixSolveLsFast(value bool) MatrixSolveLsAttr { + return func(m optionalAttr) { + m["fast"] = value + } +} + +// Solves one or more linear least-squares problems. +// +// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions +// form real or complex matrices of size `[M, N]`. `Rhs` is a tensor of the same +// type as `matrix` and shape `[..., M, K]`. +// The output is a tensor shape `[..., N, K]` where each output matrix solves +// each of the equations +// `matrix[..., :, :]` * `output[..., :, :]` = `rhs[..., :, :]` +// in the least squares sense. 
+// +// We use the following notation for (complex) matrix and right-hand sides +// in the batch: +// +// `matrix`=\\(A \in \mathbb{C}^{m \times n}\\), +// `rhs`=\\(B \in \mathbb{C}^{m \times k}\\), +// `output`=\\(X \in \mathbb{C}^{n \times k}\\), +// `l2_regularizer`=\\(\lambda \in \mathbb{R}\\). +// +// If `fast` is `True`, then the solution is computed by solving the normal +// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then +// \\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares +// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + \lambda ||Z||_F^2\\). +// If \\(m \lt n\\) then `output` is computed as +// \\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the +// minimum-norm solution to the under-determined linear system, i.e. +// \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\), +// subject to \\(A Z = B\\). Notice that the fast path is only numerically stable +// when \\(A\\) is numerically full rank and has a condition number +// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or \\(\lambda\\) is +// sufficiently large. +// +// If `fast` is `False` an algorithm based on the numerically robust complete +// orthogonal decomposition is used. This computes the minimum-norm +// least-squares solution, even when \\(A\\) is rank deficient. This path is +// typically 6-7 times slower than the fast path. If `fast` is `False` then +// `l2_regularizer` is ignored. +// +// Arguments: +// matrix: Shape is `[..., M, N]`. +// rhs: Shape is `[..., M, K]`. +// l2_regularizer: Scalar tensor. +// +// @compatibility(numpy) +// Equivalent to np.linalg.lstsq +// @end_compatibility +// +// Returns Shape is `[..., N, K]`. +func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MatrixSolveLs", + Input: []tf.Input{ + matrix, rhs, l2_regularizer, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Generates sparse cross from a list of sparse and dense tensors. +// +// The op takes two lists, one of 2D `SparseTensor` and one of 2D `Tensor`, each +// representing features of one feature column. It outputs a 2D `SparseTensor` with +// the batchwise crosses of these features. +// +// For example, if the inputs are +// +// inputs[0]: SparseTensor with shape = [2, 2] +// [0, 0]: "a" +// [1, 0]: "b" +// [1, 1]: "c" +// +// inputs[1]: SparseTensor with shape = [2, 1] +// [0, 0]: "d" +// [1, 0]: "e" +// +// inputs[2]: Tensor [["f"], ["g"]] +// +// then the output will be +// +// shape = [2, 2] +// [0, 0]: "a_X_d_X_f" +// [1, 0]: "b_X_e_X_g" +// [1, 1]: "c_X_e_X_g" +// +// if hashed_output=true then the output will be +// +// shape = [2, 2] +// [0, 0]: FingerprintCat64( +// Fingerprint64("f"), FingerprintCat64( +// Fingerprint64("d"), Fingerprint64("a"))) +// [1, 0]: FingerprintCat64( +// Fingerprint64("g"), FingerprintCat64( +// Fingerprint64("e"), Fingerprint64("b"))) +// [1, 1]: FingerprintCat64( +// Fingerprint64("g"), FingerprintCat64( +// Fingerprint64("e"), Fingerprint64("c"))) +// +// Arguments: +// indices: 2-D. Indices of each input `SparseTensor`. +// values: 1-D. values of each `SparseTensor`. +// shapes: 1-D. Shapes of each `SparseTensor`. +// dense_inputs: 2-D. 
Columns represented by dense `Tensor`. +// hashed_output: If true, returns the hash of the cross instead of the string. +// This will allow us avoiding string manipulations. +// num_buckets: It is used if hashed_output is true. +// output = hashed_value%num_buckets if num_buckets > 0 else hashed_value. +// hash_key: Specify the hash_key that will be used by the `FingerprintCat64` +// function to combine the crosses fingerprints. +// +// +// +// Returns: +// output_indices: 2-D. Indices of the concatenated `SparseTensor`. +// output_values: 1-D. Non-empty values of the concatenated or hashed +// `SparseTensor`. +// output_shape: 1-D. Shape of the concatenated `SparseTensor`. +func SparseCross(scope *Scope, indices []tf.Output, values []tf.Output, shapes []tf.Output, dense_inputs []tf.Output, hashed_output bool, num_buckets int64, hash_key int64, out_type tf.DataType, internal_type tf.DataType) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"hashed_output": hashed_output, "num_buckets": num_buckets, "hash_key": hash_key, "out_type": out_type, "internal_type": internal_type} + opspec := tf.OpSpec{ + Type: "SparseCross", + Input: []tf.Input{ + tf.OutputList(indices), tf.OutputList(values), tf.OutputList(shapes), tf.OutputList(dense_inputs), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + // Generate a glob pattern matching all sharded file names. func ShardedFilespec(scope *Scope, basename tf.Output, num_shards tf.Output) (filename tf.Output) { if scope.Err() != nil { @@ -35148,29 +36279,6 @@ func HistogramFixedWidth(scope *Scope, values tf.Output, value_range tf.Output, return op.Output(0) } -// Creates a dataset that uses a custom thread pool to compute `input_dataset`. -// -// Arguments: -// -// thread_pool: A resource produced by the ThreadPoolHandle op. -// -// -func ThreadPoolDataset(scope *Scope, input_dataset tf.Output, thread_pool tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "ThreadPoolDataset", - Input: []tf.Input{ - input_dataset, thread_pool, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Bitcasts a tensor from one type to another without copying data. // // Given a tensor `input`, this operation returns a tensor that has the same buffer @@ -35239,6 +36347,29 @@ func Bitcast(scope *Scope, input tf.Output, type_ tf.DataType) (output tf.Output return op.Output(0) } +// Creates a dataset that uses a custom thread pool to compute `input_dataset`. +// +// Arguments: +// +// thread_pool: A resource produced by the ThreadPoolHandle op. +// +// +func ThreadPoolDataset(scope *Scope, input_dataset tf.Output, thread_pool tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "ThreadPoolDataset", + Input: []tf.Input{ + input_dataset, thread_pool, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ResourceApplyAdagradDAAttr is an optional argument to ResourceApplyAdagradDA. 
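// A minimal usage sketch for the MatrixSolveLs wrapper above (not part of the
// generated wrappers file). It solves an over-determined system A x = b in the
// least-squares sense; op.NewScope, op.Const and the tf session helpers are
// assumed from the rest of the tensorflow/go bindings.
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()

	// A is 3x2 and b is 3x1, so the system is over-determined.
	matrix := op.Const(s.SubScope("matrix"), [][]float64{{1, 0}, {0, 1}, {1, 1}})
	rhs := op.Const(s.SubScope("rhs"), [][]float64{{1}, {2}, {3}})
	// l2_regularizer is a scalar; 0 disables regularization.
	l2 := op.Const(s.SubScope("l2"), float64(0))

	x := op.MatrixSolveLs(s, matrix, rhs, l2, op.MatrixSolveLsFast(true))

	graph, err := s.Finalize()
	if err != nil {
		panic(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{x}, nil)
	if err != nil {
		panic(err)
	}
	// The normal equations give x = [1, 2]^T for this system.
	fmt.Println(out[0].Value())
}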
type ResourceApplyAdagradDAAttr func(optionalAttr) @@ -35860,6 +36991,67 @@ func SparseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf return op.Output(0), op.Output(1), op.Output(2) } +// RandomShuffleAttr is an optional argument to RandomShuffle. +type RandomShuffleAttr func(optionalAttr) + +// RandomShuffleSeed sets the optional seed attribute to value. +// +// value: If either `seed` or `seed2` are set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, it is seeded by a +// random seed. +// If not specified, defaults to 0 +func RandomShuffleSeed(value int64) RandomShuffleAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomShuffleSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomShuffleSeed2(value int64) RandomShuffleAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// Randomly shuffles a tensor along its first dimension. +// +// The tensor is shuffled along dimension 0, such that each `value[j]` is mapped +// to one and only one `output[i]`. For example, a mapping that might occur for a +// 3x2 tensor is: +// +// ``` +// [[1, 2], [[5, 6], +// [3, 4], ==> [1, 2], +// [5, 6]] [3, 4]] +// ``` +// +// Arguments: +// value: The tensor to be shuffled. +// +// Returns A tensor of same shape and type as `value`, shuffled along its first +// dimension. +func RandomShuffle(scope *Scope, value tf.Output, optional ...RandomShuffleAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RandomShuffle", + Input: []tf.Input{ + value, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Selects elements from `x` or `y`, depending on `condition`. // // The `x`, and `y` tensors must all have the same shape, and the @@ -37533,6 +38725,21 @@ func RaggedTensorToTensor(scope *Scope, shape tf.Output, values tf.Output, defau return op.Output(0) } +// Computes the derivative of a Gamma random sample w.r.t. `alpha`. +func RandomGammaGrad(scope *Scope, alpha tf.Output, sample tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RandomGammaGrad", + Input: []tf.Input{ + alpha, sample, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // LRNGradAttr is an optional argument to LRNGrad. type LRNGradAttr func(optionalAttr) @@ -38356,435 +39563,6 @@ func StatelessMultinomial(scope *Scope, logits tf.Output, num_samples tf.Output, return op.Output(0) } -// QuantizedMatMulWithBiasAndReluAttr is an optional argument to QuantizedMatMulWithBiasAndRelu. -type QuantizedMatMulWithBiasAndReluAttr func(optionalAttr) - -// QuantizedMatMulWithBiasAndReluToutput sets the optional Toutput attribute to value. -// If not specified, defaults to DT_QINT32 -func QuantizedMatMulWithBiasAndReluToutput(value tf.DataType) QuantizedMatMulWithBiasAndReluAttr { - return func(m optionalAttr) { - m["Toutput"] = value - } -} - -// QuantizedMatMulWithBiasAndReluTransposeA sets the optional transpose_a attribute to value. -// -// value: If true, `a` is transposed before multiplication. 
-// If not specified, defaults to false -func QuantizedMatMulWithBiasAndReluTransposeA(value bool) QuantizedMatMulWithBiasAndReluAttr { - return func(m optionalAttr) { - m["transpose_a"] = value - } -} - -// QuantizedMatMulWithBiasAndReluTransposeB sets the optional transpose_b attribute to value. -// -// value: If true, `b` is transposed before multiplication. -// If not specified, defaults to false -func QuantizedMatMulWithBiasAndReluTransposeB(value bool) QuantizedMatMulWithBiasAndReluAttr { - return func(m optionalAttr) { - m["transpose_b"] = value - } -} - -// QuantizedMatMulWithBiasAndReluInputQuantMode sets the optional input_quant_mode attribute to value. -// -// value: Input data quantization mode. Either MIN_FIRST(default) or SCALED. -// If not specified, defaults to "MIN_FIRST" -func QuantizedMatMulWithBiasAndReluInputQuantMode(value string) QuantizedMatMulWithBiasAndReluAttr { - return func(m optionalAttr) { - m["input_quant_mode"] = value - } -} - -// Perform a quantized matrix multiplication of `a` by the matrix `b` with bias -// add and relu fusion. -// -// The inputs must be two-dimensional matrices and 1D bias vector. And the inner -// dimension of `a` (after being transposed if `transpose_a` is non-zero) must -// match the outer dimension of `b` (after being transposed if `transposed_b` is -// non-zero). Then do broadcast add operation with bias values on the matrix -// multiplication result. The bias size must match inner dimension of `b`. Then do -// relu activation to get non-negative result. -// -// Arguments: -// a: A matrix to be multiplied. Must be a two-dimensional tensor of type `quint8`. -// b: A matrix to be multiplied and must be a two-dimensional tensor of type `qint8`. -// bias: A 1D bias tensor with size matching with inner dimension of `b` (after being -// transposed if `transposed_b` is non-zero). -// min_a: The float value that the lowest quantized `a` value represents. -// max_a: The float value that the highest quantized `a` value represents. -// min_b: The float value that the lowest quantized `b` value represents. -// max_b: The float value that the highest quantized `b` value represents. -// -// Returns: -// out -// min_out: The float value that the lowest quantized output value represents. -// max_out: The float value that the highest quantized output value represents. -func QuantizedMatMulWithBiasAndRelu(scope *Scope, a tf.Output, b tf.Output, bias tf.Output, min_a tf.Output, max_a tf.Output, min_b tf.Output, max_b tf.Output, optional ...QuantizedMatMulWithBiasAndReluAttr) (out tf.Output, min_out tf.Output, max_out tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "QuantizedMatMulWithBiasAndRelu", - Input: []tf.Input{ - a, b, bias, min_a, max_a, min_b, max_b, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingMomentumParametersGradAccumDebug. -type RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr func(optionalAttr) - -// RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableId sets the optional table_id attribute to value. 
-// If not specified, defaults to -1 -func RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// RetrieveTPUEmbeddingMomentumParametersGradAccumDebugConfig sets the optional config attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingMomentumParametersGradAccumDebugConfig(value string) RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["config"] = value - } -} - -// Retrieve Momentum embedding parameters with debug support. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns: -// parameters: Parameter parameters updated by the Momentum optimization algorithm. -// momenta: Parameter momenta updated by the Momentum optimization algorithm. -// gradient_accumulators: Parameter gradient_accumulators updated by the Momentum optimization algorithm. -func RetrieveTPUEmbeddingMomentumParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr) (parameters tf.Output, momenta tf.Output, gradient_accumulators tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingMomentumParametersGradAccumDebug", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - -// StatelessRandomUniformFullIntAttr is an optional argument to StatelessRandomUniformFullInt. -type StatelessRandomUniformFullIntAttr func(optionalAttr) - -// StatelessRandomUniformFullIntDtype sets the optional dtype attribute to value. -// -// value: The type of the output. -// If not specified, defaults to DT_UINT64 -func StatelessRandomUniformFullIntDtype(value tf.DataType) StatelessRandomUniformFullIntAttr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Outputs deterministic pseudorandom random integers from a uniform distribution. -// -// The generated values are uniform integers covering the whole range of `dtype`. -// -// The outputs are a deterministic function of `shape` and `seed`. -// -// Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). -// -// Returns Random values with specified shape. 
-func StatelessRandomUniformFullInt(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomUniformFullIntAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StatelessRandomUniformFullInt", - Input: []tf.Input{ - shape, seed, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// ResourceApplyRMSPropAttr is an optional argument to ResourceApplyRMSProp. -type ResourceApplyRMSPropAttr func(optionalAttr) - -// ResourceApplyRMSPropUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var, ms, and mom tensors is protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyRMSPropUseLocking(value bool) ResourceApplyRMSPropAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// Update '*var' according to the RMSProp algorithm. -// -// Note that in dense implementation of this algorithm, ms and mom will -// update even if the grad is zero, but in this sparse implementation, ms -// and mom will not update in iterations during which the grad is zero. -// -// mean_square = decay * mean_square + (1-decay) * gradient ** 2 -// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) -// -// ms <- rho * ms_{t-1} + (1-rho) * grad * grad -// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) -// var <- var - mom -// -// Arguments: -// var_: Should be from a Variable(). -// ms: Should be from a Variable(). -// mom: Should be from a Variable(). -// lr: Scaling factor. Must be a scalar. -// rho: Decay rate. Must be a scalar. -// -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. -// -// Returns the created operation. -func ResourceApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyRMSPropAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyRMSProp", - Input: []tf.Input{ - var_, ms, mom, lr, rho, momentum, epsilon, grad, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// MaxPool3DGradAttr is an optional argument to MaxPool3DGrad. -type MaxPool3DGradAttr func(optionalAttr) - -// MaxPool3DGradDataFormat sets the optional data_format attribute to value. -// -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func MaxPool3DGradDataFormat(value string) MaxPool3DGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes gradients of 3D max pooling function. -// -// Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. 
The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -func MaxPool3DGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPool3DGrad", - Input: []tf.Input{ - orig_input, orig_output, grad, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that executes a SQL query and emits rows of the result set. -// -// Arguments: -// driver_name: The database type. Currently, the only supported type is 'sqlite'. -// data_source_name: A connection string to connect to the database. -// query: A SQL query to execute. -// -// -func SqlDataset(scope *Scope, driver_name tf.Output, data_source_name tf.Output, query tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "SqlDataset", - Input: []tf.Input{ - driver_name, data_source_name, query, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Outputs deterministic pseudorandom random integers from a uniform distribution. -// -// The generated values follow a uniform distribution in the range `[minval, maxval)`. -// -// The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxval`. -// -// Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). -// minval: Minimum value (inclusive, scalar). -// maxval: Maximum value (exclusive, scalar). -// -// Returns Random values with specified shape. -func StatelessRandomUniformInt(scope *Scope, shape tf.Output, seed tf.Output, minval tf.Output, maxval tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "StatelessRandomUniformInt", - Input: []tf.Input{ - shape, seed, minval, maxval, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Returns a batched diagonal tensor with a given batched diagonal values. -// -// Given a `diagonal`, this operation returns a tensor with the `diagonal` and -// everything else padded with zeros. The diagonal is computed as follows: -// -// Assume `diagonal` has `k` dimensions `[I, J, K, ..., N]`, then the output is a -// tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where: -// -// `output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]`. -// -// For example: -// -// ``` -// # 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]] -// -// and diagonal.shape = (2, 4) -// -// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0] -// [0, 2, 0, 0] -// [0, 0, 3, 0] -// [0, 0, 0, 4]], -// [[5, 0, 0, 0] -// [0, 6, 0, 0] -// [0, 0, 7, 0] -// [0, 0, 0, 8]]] -// -// which has shape (2, 4, 4) -// ``` -// -// Arguments: -// diagonal: Rank `k`, where `k >= 1`. -// -// Returns Rank `k+1`, with `output.shape = diagonal.shape + [diagonal.shape[-1]]`. 
-func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MatrixDiag", - Input: []tf.Input{ - diagonal, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal. -type StatelessTruncatedNormalAttr func(optionalAttr) - -// StatelessTruncatedNormalDtype sets the optional dtype attribute to value. -// -// value: The type of the output. -// If not specified, defaults to DT_FLOAT -func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAttr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Outputs deterministic pseudorandom values from a truncated normal distribution. -// -// The generated values follow a normal distribution with mean 0 and standard -// deviation 1, except that values whose magnitude is more than 2 standard -// deviations from the mean are dropped and re-picked. -// -// The outputs are a deterministic function of `shape` and `seed`. -// -// Arguments: -// shape: The shape of the output tensor. -// seed: 2 seeds (shape [2]). -// -// Returns Random values with specified shape. -func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessTruncatedNormalAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "StatelessTruncatedNormal", - Input: []tf.Input{ - shape, seed, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Returns a copy of the input tensor. func Snapshot(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { @@ -40969,187 +41747,6 @@ func DatasetToGraph(scope *Scope, input_dataset tf.Output, optional ...DatasetTo return op.Output(0) } -// ResourceSparseApplyAdadeltaAttr is an optional argument to ResourceSparseApplyAdadelta. -type ResourceSparseApplyAdadeltaAttr func(optionalAttr) - -// ResourceSparseApplyAdadeltaUseLocking sets the optional use_locking attribute to value. -// -// value: If True, updating of the var and accum tensors will be protected by -// a lock; otherwise the behavior is undefined, but may exhibit less contention. -// If not specified, defaults to false -func ResourceSparseApplyAdadeltaUseLocking(value bool) ResourceSparseApplyAdadeltaAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// var: Should be from a Variable(). -// -// Arguments: -// -// accum: Should be from a Variable(). -// accum_update: : Should be from a Variable(). -// lr: Learning rate. Must be a scalar. -// rho: Decay factor. Must be a scalar. -// epsilon: Constant factor. Must be a scalar. -// grad: The gradient. -// indices: A vector of indices into the first dimension of var and accum. -// -// Returns the created operation. 
-func ResourceSparseApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output, accum_update tf.Output, lr tf.Output, rho tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyAdadeltaAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceSparseApplyAdadelta", - Input: []tf.Input{ - var_, accum, accum_update, lr, rho, epsilon, grad, indices, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// Computes sigmoid of `x` element-wise. -// -// Specifically, `y = 1 / (1 + exp(-x))`. -func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Sigmoid", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingADAMParametersGradAccumDebug. -type RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr func(optionalAttr) - -// RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableId sets the optional table_id attribute to value. -// If not specified, defaults to -1 -func RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_id"] = value - } -} - -// RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableName sets the optional table_name attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingADAMParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["table_name"] = value - } -} - -// RetrieveTPUEmbeddingADAMParametersGradAccumDebugConfig sets the optional config attribute to value. -// If not specified, defaults to "" -func RetrieveTPUEmbeddingADAMParametersGradAccumDebugConfig(value string) RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr { - return func(m optionalAttr) { - m["config"] = value - } -} - -// Retrieve ADAM embedding parameters with debug support. -// -// An op that retrieves optimization parameters from embedding to host -// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up -// the correct embedding table configuration. For example, this op is -// used to retrieve updated parameters before saving a checkpoint. -// -// Returns: -// parameters: Parameter parameters updated by the ADAM optimization algorithm. -// momenta: Parameter momenta updated by the ADAM optimization algorithm. -// velocities: Parameter velocities updated by the ADAM optimization algorithm. -// gradient_accumulators: Parameter gradient_accumulators updated by the ADAM optimization algorithm. 
-func RetrieveTPUEmbeddingADAMParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingADAMParametersGradAccumDebugAttr) (parameters tf.Output, momenta tf.Output, velocities tf.Output, gradient_accumulators tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RetrieveTPUEmbeddingADAMParametersGradAccumDebug", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2), op.Output(3) -} - -// ResourceApplyAdamAttr is an optional argument to ResourceApplyAdam. -type ResourceApplyAdamAttr func(optionalAttr) - -// ResourceApplyAdamUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var, m, and v tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyAdamUseLocking(value bool) ResourceApplyAdamAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// ResourceApplyAdamUseNesterov sets the optional use_nesterov attribute to value. -// -// value: If `True`, uses the nesterov update. -// If not specified, defaults to false -func ResourceApplyAdamUseNesterov(value bool) ResourceApplyAdamAttr { - return func(m optionalAttr) { - m["use_nesterov"] = value - } -} - -// Update '*var' according to the Adam algorithm. -// -// $$\text{lr}_t := \mathrm{learning_rate} * \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$ -// $$m_t := \beta_1 * m_{t-1} + (1 - \beta_1) * g$$ -// $$v_t := \beta_2 * v_{t-1} + (1 - \beta_2) * g * g$$ -// $$\text{variable} := \text{variable} - \text{lr}_t * m_t / (\sqrt{v_t} + \epsilon)$$ -// -// Arguments: -// var_: Should be from a Variable(). -// m: Should be from a Variable(). -// v: Should be from a Variable(). -// beta1_power: Must be a scalar. -// beta2_power: Must be a scalar. -// lr: Scaling factor. Must be a scalar. -// beta1: Momentum factor. Must be a scalar. -// beta2: Momentum factor. Must be a scalar. -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. -// -// Returns the created operation. -func ResourceApplyAdam(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, beta2_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdamAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyAdam", - Input: []tf.Input{ - var_, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - // ResizeNearestNeighborAttr is an optional argument to ResizeNearestNeighbor. type ResizeNearestNeighborAttr func(optionalAttr) @@ -43806,35 +44403,110 @@ func SparseFillEmptyRowsGrad(scope *Scope, reverse_index_map tf.Output, grad_val return op.Output(0), op.Output(1) } -// Merges summaries. +// MaxPool3DGradAttr is an optional argument to MaxPool3DGrad. +type MaxPool3DGradAttr func(optionalAttr) + +// MaxPool3DGradDataFormat sets the optional data_format attribute to value. 
// -// This op creates a -// [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) -// protocol buffer that contains the union of all the values in the input -// summaries. -// -// When the Op is run, it reports an `InvalidArgument` error if multiple values -// in the summaries to merge use the same tag. +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func MaxPool3DGradDataFormat(value string) MaxPool3DGradAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes gradients of 3D max pooling function. // // Arguments: -// inputs: Can be of any shape. Each must contain serialized `Summary` protocol -// buffers. -// -// Returns Scalar. Serialized `Summary` protocol buffer. -func MergeSummary(scope *Scope, inputs []tf.Output) (summary tf.Output) { +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +func MaxPool3DGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "MergeSummary", + Type: "MaxPool3DGrad", Input: []tf.Input{ - tf.OutputList(inputs), + orig_input, orig_output, grad, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } +// ResourceApplyRMSPropAttr is an optional argument to ResourceApplyRMSProp. +type ResourceApplyRMSPropAttr func(optionalAttr) + +// ResourceApplyRMSPropUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var, ms, and mom tensors is protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. +// If not specified, defaults to false +func ResourceApplyRMSPropUseLocking(value bool) ResourceApplyRMSPropAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// Update '*var' according to the RMSProp algorithm. +// +// Note that in dense implementation of this algorithm, ms and mom will +// update even if the grad is zero, but in this sparse implementation, ms +// and mom will not update in iterations during which the grad is zero. +// +// mean_square = decay * mean_square + (1-decay) * gradient ** 2 +// Delta = learning_rate * gradient / sqrt(mean_square + epsilon) +// +// ms <- rho * ms_{t-1} + (1-rho) * grad * grad +// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +// var <- var - mom +// +// Arguments: +// var_: Should be from a Variable(). +// ms: Should be from a Variable(). +// mom: Should be from a Variable(). +// lr: Scaling factor. Must be a scalar. +// rho: Decay rate. Must be a scalar. 
+// +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyRMSPropAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyRMSProp", + Input: []tf.Input{ + var_, ms, mom, lr, rho, momentum, epsilon, grad, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + // Reshapes a SparseTensor to represent values in a new dense shape. // // This operation has the same semantics as reshape on the represented dense @@ -44118,52 +44790,6 @@ func ResourceApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear t return scope.AddOperation(opspec) } -// Constructs a tensor by tiling a given tensor. -// -// This operation creates a new tensor by replicating `input` `multiples` times. -// The output tensor's i'th dimension has `input.dims(i) * multiples[i]` elements, -// and the values of `input` are replicated `multiples[i]` times along the 'i'th -// dimension. For example, tiling `[a b c d]` by `[2]` produces -// `[a b c d a b c d]`. -// -// >>> a = tf.constant([[1,2,3],[4,5,6]], tf.int32) -// >>> b = tf.constant([1,2], tf.int32) -// >>> tf.tile(a, b) -// -// >>> c = tf.constant([2,1], tf.int32) -// >>> tf.tile(a, c) -// -// >>> d = tf.constant([2,2], tf.int32) -// >>> tf.tile(a, d) -// -// -// Arguments: -// input: 1-D or higher. -// multiples: 1-D. Length must be the same as the number of dimensions in `input` -func Tile(scope *Scope, input tf.Output, multiples tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Tile", - Input: []tf.Input{ - input, multiples, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes softmax cross entropy cost and gradients to backpropagate. // // Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept @@ -44263,405 +44889,52 @@ func ResourceApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf return scope.AddOperation(opspec) } -// Elementwise computes the bitwise left-shift of `x` and `y`. +// Constructs a tensor by tiling a given tensor. // -// If `y` is negative, or greater than or equal to the width of `x` in bits the -// result is implementation defined. +// This operation creates a new tensor by replicating `input` `multiples` times. +// The output tensor's i'th dimension has `input.dims(i) * multiples[i]` elements, +// and the values of `input` are replicated `multiples[i]` times along the 'i'th +// dimension. For example, tiling `[a b c d]` by `[2]` produces +// `[a b c d a b c d]`. 
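// A minimal usage sketch for the Tile wrapper whose documentation begins just
// above (the function itself follows below); not part of the generated wrappers
// file. It mirrors the 2x3 -> 4x6 example from the doc comment, and assumes
// op.NewScope, op.Const and Scope.Finalize from the rest of the tensorflow/go
// bindings.
package main

import (
	"fmt"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()

	// Same values as the doc comment: a 2x3 matrix tiled twice along each axis,
	// giving a 4x6 result when the graph is run.
	input := op.Const(s.SubScope("input"), [][]int32{{1, 2, 3}, {4, 5, 6}})
	multiples := op.Const(s.SubScope("multiples"), []int32{2, 2})
	tiled := op.Tile(s, input, multiples)
	fmt.Println(tiled.Shape())

	if _, err := s.Finalize(); err != nil {
		panic(err)
	}
	// tf is imported for the tf.Output/tf.Graph types used by the bindings.
	var _ tf.Output = tiled
}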
// -// Example: +// >>> a = tf.constant([[1,2,3],[4,5,6]], tf.int32) +// >>> b = tf.constant([1,2], tf.int32) +// >>> tf.tile(a, b) +// +// >>> c = tf.constant([2,1], tf.int32) +// >>> tf.tile(a, c) +// +// >>> d = tf.constant([2,2], tf.int32) +// >>> tf.tile(a, d) +// // -// ```python -// import tensorflow as tf -// from tensorflow.python.ops import bitwise_ops -// import numpy as np -// dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64] -// -// for dtype in dtype_list: -// lhs = tf.constant([-1, -5, -3, -14], dtype=dtype) -// rhs = tf.constant([5, 0, 7, 11], dtype=dtype) -// -// left_shift_result = bitwise_ops.left_shift(lhs, rhs) -// -// print(left_shift_result) -// -// # This will print: -// # tf.Tensor([ -32 -5 -128 0], shape=(4,), dtype=int8) -// # tf.Tensor([ -32 -5 -384 -28672], shape=(4,), dtype=int16) -// # tf.Tensor([ -32 -5 -384 -28672], shape=(4,), dtype=int32) -// # tf.Tensor([ -32 -5 -384 -28672], shape=(4,), dtype=int64) -// -// lhs = np.array([-2, 64, 101, 32], dtype=np.int8) -// rhs = np.array([-1, -5, -3, -14], dtype=np.int8) -// bitwise_ops.left_shift(lhs, rhs) -// # -// ``` -// -func LeftShift(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Arguments: +// input: 1-D or higher. +// multiples: 1-D. Length must be the same as the number of dimensions in `input` +func Tile(scope *Scope, input tf.Output, multiples tf.Output) (output tf.Output) { if scope.Err() != nil { return } opspec := tf.OpSpec{ - Type: "LeftShift", + Type: "Tile", Input: []tf.Input{ - x, y, + input, multiples, }, } op := scope.AddOperation(opspec) return op.Output(0) } -// Generates a feature cross from a list of tensors, and returns it as a -// RaggedTensor. See `tf.ragged.cross` for more details. -// -// Arguments: -// ragged_values: The values tensor for each RaggedTensor input. -// ragged_row_splits: The row_splits tensor for each RaggedTensor input. -// sparse_indices: The indices tensor for each SparseTensor input. -// sparse_values: The values tensor for each SparseTensor input. -// sparse_shape: The dense_shape tensor for each SparseTensor input. -// dense_inputs: The tf.Tensor inputs. -// input_order: String specifying the tensor type for each input. The `i`th character in -// this string specifies the type of the `i`th input, and is one of: 'R' (ragged), -// 'D' (dense), or 'S' (sparse). This attr is used to ensure that the crossed -// values are combined in the order of the inputs from the call to tf.ragged.cross. -// -// -// -// -// -// -// Returns: -// output_values: The `values` for the returned `RaggedTensor`. -// output_row_splits: The `row_splits` for the returned `RaggedTensor`. 
-func RaggedCross(scope *Scope, ragged_values []tf.Output, ragged_row_splits []tf.Output, sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shape []tf.Output, dense_inputs []tf.Output, input_order string, hashed_output bool, num_buckets int64, hash_key int64, out_values_type tf.DataType, out_row_splits_type tf.DataType) (output_values tf.Output, output_row_splits tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"input_order": input_order, "hashed_output": hashed_output, "num_buckets": num_buckets, "hash_key": hash_key, "out_values_type": out_values_type, "out_row_splits_type": out_row_splits_type} - opspec := tf.OpSpec{ - Type: "RaggedCross", - Input: []tf.Input{ - tf.OutputList(ragged_values), tf.OutputList(ragged_row_splits), tf.OutputList(sparse_indices), tf.OutputList(sparse_values), tf.OutputList(sparse_shape), tf.OutputList(dense_inputs), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// Output a fact about factorials. -func Fact(scope *Scope) (fact tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Fact", - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Pads a tensor with mirrored values. -// -// This operation pads a `input` with mirrored values according to the `paddings` -// you specify. `paddings` is an integer tensor with shape `[n, 2]`, where n is -// the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates -// how many values to add before the contents of `input` in that dimension, and -// `paddings[D, 1]` indicates how many values to add after the contents of `input` -// in that dimension. Both `paddings[D, 0]` and `paddings[D, 1]` must be no greater -// than `input.dim_size(D)` (or `input.dim_size(D) - 1`) if `copy_border` is true -// (if false, respectively). -// -// The padded size of each dimension D of the output is: -// -// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` -// -// For example: -// -// ``` -// # 't' is [[1, 2, 3], [4, 5, 6]]. -// # 'paddings' is [[1, 1]], [2, 2]]. -// # 'mode' is SYMMETRIC. -// # rank of 't' is 2. -// pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2] -// [2, 1, 1, 2, 3, 3, 2] -// [5, 4, 4, 5, 6, 6, 5] -// [5, 4, 4, 5, 6, 6, 5]] -// ``` -// -// Arguments: -// input: The input tensor to be padded. -// paddings: A two-column matrix specifying the padding sizes. The number of -// rows must be the same as the rank of `input`. -// mode: Either `REFLECT` or `SYMMETRIC`. In reflect mode the padded regions -// do not include the borders, while in symmetric mode the padded regions -// do include the borders. For example, if `input` is `[1, 2, 3]` and `paddings` -// is `[0, 2]`, then the output is `[1, 2, 3, 2, 1]` in reflect mode, and -// it is `[1, 2, 3, 3, 2]` in symmetric mode. -// -// Returns The padded tensor. -func MirrorPad(scope *Scope, input tf.Output, paddings tf.Output, mode string) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"mode": mode} - opspec := tf.OpSpec{ - Type: "MirrorPad", - Input: []tf.Input{ - input, paddings, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// TensorArrayV3Attr is an optional argument to TensorArrayV3. -type TensorArrayV3Attr func(optionalAttr) - -// TensorArrayV3ElementShape sets the optional element_shape attribute to value. -// -// value: The expected shape of an element, if known. Used to -// validate the shapes of TensorArray elements. 
If this shape is not -// fully specified, gathering zero-size TensorArrays is an error. -// If not specified, defaults to -func TensorArrayV3ElementShape(value tf.Shape) TensorArrayV3Attr { - return func(m optionalAttr) { - m["element_shape"] = value - } -} - -// TensorArrayV3DynamicSize sets the optional dynamic_size attribute to value. -// -// value: A boolean that determines whether writes to the TensorArray -// are allowed to grow the size. By default, this is not allowed. -// If not specified, defaults to false -func TensorArrayV3DynamicSize(value bool) TensorArrayV3Attr { - return func(m optionalAttr) { - m["dynamic_size"] = value - } -} - -// TensorArrayV3ClearAfterRead sets the optional clear_after_read attribute to value. -// -// value: If true (default), Tensors in the TensorArray are cleared -// after being read. This disables multiple read semantics but allows early -// release of memory. -// If not specified, defaults to true -func TensorArrayV3ClearAfterRead(value bool) TensorArrayV3Attr { - return func(m optionalAttr) { - m["clear_after_read"] = value - } -} - -// TensorArrayV3IdenticalElementShapes sets the optional identical_element_shapes attribute to value. -// -// value: If true (default is false), then all -// elements in the TensorArray will be expected to have have identical shapes. -// This allows certain behaviors, like dynamically checking for -// consistent shapes on write, and being able to fill in properly -// shaped zero tensors on stack -- even if the element_shape attribute -// is not fully defined. -// If not specified, defaults to false -func TensorArrayV3IdenticalElementShapes(value bool) TensorArrayV3Attr { - return func(m optionalAttr) { - m["identical_element_shapes"] = value - } -} - -// TensorArrayV3TensorArrayName sets the optional tensor_array_name attribute to value. -// -// value: Overrides the name used for the temporary tensor_array -// resource. Default value is the name of the 'TensorArray' op (which -// is guaranteed unique). -// If not specified, defaults to "" -func TensorArrayV3TensorArrayName(value string) TensorArrayV3Attr { - return func(m optionalAttr) { - m["tensor_array_name"] = value - } -} - -// An array of Tensors of given size. -// -// Write data via Write and read via Read or Pack. -// -// Arguments: -// size: The size of the array. -// dtype: The type of the elements on the tensor_array. -// -// Returns: -// handle: The handle to the TensorArray. -// flow: A scalar used to control gradient flow. -func TensorArrayV3(scope *Scope, size tf.Output, dtype tf.DataType, optional ...TensorArrayV3Attr) (handle tf.Output, flow tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "TensorArrayV3", - Input: []tf.Input{ - size, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1) -} - -// MatrixSolveLsAttr is an optional argument to MatrixSolveLs. -type MatrixSolveLsAttr func(optionalAttr) - -// MatrixSolveLsFast sets the optional fast attribute to value. -// If not specified, defaults to true -func MatrixSolveLsFast(value bool) MatrixSolveLsAttr { - return func(m optionalAttr) { - m["fast"] = value - } -} - -// Solves one or more linear least-squares problems. -// -// `matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions -// form real or complex matrices of size `[M, N]`. `Rhs` is a tensor of the same -// type as `matrix` and shape `[..., M, K]`. 
-// The output is a tensor shape `[..., N, K]` where each output matrix solves -// each of the equations -// `matrix[..., :, :]` * `output[..., :, :]` = `rhs[..., :, :]` -// in the least squares sense. -// -// We use the following notation for (complex) matrix and right-hand sides -// in the batch: -// -// `matrix`=\\(A \in \mathbb{C}^{m \times n}\\), -// `rhs`=\\(B \in \mathbb{C}^{m \times k}\\), -// `output`=\\(X \in \mathbb{C}^{n \times k}\\), -// `l2_regularizer`=\\(\lambda \in \mathbb{R}\\). -// -// If `fast` is `True`, then the solution is computed by solving the normal -// equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then -// \\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares -// problem \\(X = \mathrm{argmin}_{Z \in \Re^{n \times k} } ||A Z - B||_F^2 + \lambda ||Z||_F^2\\). -// If \\(m \lt n\\) then `output` is computed as -// \\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the -// minimum-norm solution to the under-determined linear system, i.e. -// \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\), -// subject to \\(A Z = B\\). Notice that the fast path is only numerically stable -// when \\(A\\) is numerically full rank and has a condition number -// \\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or \\(\lambda\\) is -// sufficiently large. -// -// If `fast` is `False` an algorithm based on the numerically robust complete -// orthogonal decomposition is used. This computes the minimum-norm -// least-squares solution, even when \\(A\\) is rank deficient. This path is -// typically 6-7 times slower than the fast path. If `fast` is `False` then -// `l2_regularizer` is ignored. -// -// Arguments: -// matrix: Shape is `[..., M, N]`. -// rhs: Shape is `[..., M, K]`. -// l2_regularizer: Scalar tensor. -// -// @compatibility(numpy) -// Equivalent to np.linalg.lstsq -// @end_compatibility -// -// Returns Shape is `[..., N, K]`. -func MatrixSolveLs(scope *Scope, matrix tf.Output, rhs tf.Output, l2_regularizer tf.Output, optional ...MatrixSolveLsAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MatrixSolveLs", - Input: []tf.Input{ - matrix, rhs, l2_regularizer, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Generates sparse cross from a list of sparse and dense tensors. -// -// The op takes two lists, one of 2D `SparseTensor` and one of 2D `Tensor`, each -// representing features of one feature column. It outputs a 2D `SparseTensor` with -// the batchwise crosses of these features. 
-// -// For example, if the inputs are -// -// inputs[0]: SparseTensor with shape = [2, 2] -// [0, 0]: "a" -// [1, 0]: "b" -// [1, 1]: "c" -// -// inputs[1]: SparseTensor with shape = [2, 1] -// [0, 0]: "d" -// [1, 0]: "e" -// -// inputs[2]: Tensor [["f"], ["g"]] -// -// then the output will be -// -// shape = [2, 2] -// [0, 0]: "a_X_d_X_f" -// [1, 0]: "b_X_e_X_g" -// [1, 1]: "c_X_e_X_g" -// -// if hashed_output=true then the output will be -// -// shape = [2, 2] -// [0, 0]: FingerprintCat64( -// Fingerprint64("f"), FingerprintCat64( -// Fingerprint64("d"), Fingerprint64("a"))) -// [1, 0]: FingerprintCat64( -// Fingerprint64("g"), FingerprintCat64( -// Fingerprint64("e"), Fingerprint64("b"))) -// [1, 1]: FingerprintCat64( -// Fingerprint64("g"), FingerprintCat64( -// Fingerprint64("e"), Fingerprint64("c"))) -// -// Arguments: -// indices: 2-D. Indices of each input `SparseTensor`. -// values: 1-D. values of each `SparseTensor`. -// shapes: 1-D. Shapes of each `SparseTensor`. -// dense_inputs: 2-D. Columns represented by dense `Tensor`. -// hashed_output: If true, returns the hash of the cross instead of the string. -// This will allow us avoiding string manipulations. -// num_buckets: It is used if hashed_output is true. -// output = hashed_value%num_buckets if num_buckets > 0 else hashed_value. -// hash_key: Specify the hash_key that will be used by the `FingerprintCat64` -// function to combine the crosses fingerprints. -// -// -// -// Returns: -// output_indices: 2-D. Indices of the concatenated `SparseTensor`. -// output_values: 1-D. Non-empty values of the concatenated or hashed -// `SparseTensor`. -// output_shape: 1-D. Shape of the concatenated `SparseTensor`. -func SparseCross(scope *Scope, indices []tf.Output, values []tf.Output, shapes []tf.Output, dense_inputs []tf.Output, hashed_output bool, num_buckets int64, hash_key int64, out_type tf.DataType, internal_type tf.DataType) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"hashed_output": hashed_output, "num_buckets": num_buckets, "hash_key": hash_key, "out_type": out_type, "internal_type": internal_type} - opspec := tf.OpSpec{ - Type: "SparseCross", - Input: []tf.Input{ - tf.OutputList(indices), tf.OutputList(values), tf.OutputList(shapes), tf.OutputList(dense_inputs), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - // Creates and returns an empty tensor list. // // All list elements must be tensors of dtype element_dtype and shape compatible @@ -44706,6 +44979,188 @@ func ConfigureTPUEmbedding(scope *Scope, config string) (o *tf.Operation) { return scope.AddOperation(opspec) } +// QuantizedMatMulWithBiasAndReluAttr is an optional argument to QuantizedMatMulWithBiasAndRelu. +type QuantizedMatMulWithBiasAndReluAttr func(optionalAttr) + +// QuantizedMatMulWithBiasAndReluToutput sets the optional Toutput attribute to value. +// If not specified, defaults to DT_QINT32 +func QuantizedMatMulWithBiasAndReluToutput(value tf.DataType) QuantizedMatMulWithBiasAndReluAttr { + return func(m optionalAttr) { + m["Toutput"] = value + } +} + +// QuantizedMatMulWithBiasAndReluTransposeA sets the optional transpose_a attribute to value. +// +// value: If true, `a` is transposed before multiplication. 
+// If not specified, defaults to false +func QuantizedMatMulWithBiasAndReluTransposeA(value bool) QuantizedMatMulWithBiasAndReluAttr { + return func(m optionalAttr) { + m["transpose_a"] = value + } +} + +// QuantizedMatMulWithBiasAndReluTransposeB sets the optional transpose_b attribute to value. +// +// value: If true, `b` is transposed before multiplication. +// If not specified, defaults to false +func QuantizedMatMulWithBiasAndReluTransposeB(value bool) QuantizedMatMulWithBiasAndReluAttr { + return func(m optionalAttr) { + m["transpose_b"] = value + } +} + +// QuantizedMatMulWithBiasAndReluInputQuantMode sets the optional input_quant_mode attribute to value. +// +// value: Input data quantization mode. Either MIN_FIRST(default) or SCALED. +// If not specified, defaults to "MIN_FIRST" +func QuantizedMatMulWithBiasAndReluInputQuantMode(value string) QuantizedMatMulWithBiasAndReluAttr { + return func(m optionalAttr) { + m["input_quant_mode"] = value + } +} + +// Perform a quantized matrix multiplication of `a` by the matrix `b` with bias +// add and relu fusion. +// +// The inputs must be two-dimensional matrices and 1D bias vector. And the inner +// dimension of `a` (after being transposed if `transpose_a` is non-zero) must +// match the outer dimension of `b` (after being transposed if `transposed_b` is +// non-zero). Then do broadcast add operation with bias values on the matrix +// multiplication result. The bias size must match inner dimension of `b`. Then do +// relu activation to get non-negative result. +// +// Arguments: +// a: A matrix to be multiplied. Must be a two-dimensional tensor of type `quint8`. +// b: A matrix to be multiplied and must be a two-dimensional tensor of type `qint8`. +// bias: A 1D bias tensor with size matching with inner dimension of `b` (after being +// transposed if `transposed_b` is non-zero). +// min_a: The float value that the lowest quantized `a` value represents. +// max_a: The float value that the highest quantized `a` value represents. +// min_b: The float value that the lowest quantized `b` value represents. +// max_b: The float value that the highest quantized `b` value represents. +// +// Returns: +// out +// min_out: The float value that the lowest quantized output value represents. +// max_out: The float value that the highest quantized output value represents. +func QuantizedMatMulWithBiasAndRelu(scope *Scope, a tf.Output, b tf.Output, bias tf.Output, min_a tf.Output, max_a tf.Output, min_b tf.Output, max_b tf.Output, optional ...QuantizedMatMulWithBiasAndReluAttr) (out tf.Output, min_out tf.Output, max_out tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "QuantizedMatMulWithBiasAndRelu", + Input: []tf.Input{ + a, b, bias, min_a, max_a, min_b, max_b, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingMomentumParametersGradAccumDebug. +type RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr func(optionalAttr) + +// RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableId sets the optional table_id attribute to value. 
+// If not specified, defaults to -1 +func RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableId(value int64) RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_id"] = value + } +} + +// RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableName sets the optional table_name attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingMomentumParametersGradAccumDebugTableName(value string) RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["table_name"] = value + } +} + +// RetrieveTPUEmbeddingMomentumParametersGradAccumDebugConfig sets the optional config attribute to value. +// If not specified, defaults to "" +func RetrieveTPUEmbeddingMomentumParametersGradAccumDebugConfig(value string) RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr { + return func(m optionalAttr) { + m["config"] = value + } +} + +// Retrieve Momentum embedding parameters with debug support. +// +// An op that retrieves optimization parameters from embedding to host +// memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up +// the correct embedding table configuration. For example, this op is +// used to retrieve updated parameters before saving a checkpoint. +// +// Returns: +// parameters: Parameter parameters updated by the Momentum optimization algorithm. +// momenta: Parameter momenta updated by the Momentum optimization algorithm. +// gradient_accumulators: Parameter gradient_accumulators updated by the Momentum optimization algorithm. +func RetrieveTPUEmbeddingMomentumParametersGradAccumDebug(scope *Scope, num_shards int64, shard_id int64, optional ...RetrieveTPUEmbeddingMomentumParametersGradAccumDebugAttr) (parameters tf.Output, momenta tf.Output, gradient_accumulators tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"num_shards": num_shards, "shard_id": shard_id} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RetrieveTPUEmbeddingMomentumParametersGradAccumDebug", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + +// StatelessRandomUniformFullIntAttr is an optional argument to StatelessRandomUniformFullInt. +type StatelessRandomUniformFullIntAttr func(optionalAttr) + +// StatelessRandomUniformFullIntDtype sets the optional dtype attribute to value. +// +// value: The type of the output. +// If not specified, defaults to DT_UINT64 +func StatelessRandomUniformFullIntDtype(value tf.DataType) StatelessRandomUniformFullIntAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Outputs deterministic pseudorandom random integers from a uniform distribution. +// +// The generated values are uniform integers covering the whole range of `dtype`. +// +// The outputs are a deterministic function of `shape` and `seed`. +// +// Arguments: +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). +// +// Returns Random values with specified shape. 
+func StatelessRandomUniformFullInt(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomUniformFullIntAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StatelessRandomUniformFullInt", + Input: []tf.Input{ + shape, seed, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Shuts down a running distributed TPU system. // // The op returns an error if no system is running. @@ -46522,6 +46977,98 @@ func SparseToSparseSetOperation(scope *Scope, set1_indices tf.Output, set1_value return op.Output(0), op.Output(1), op.Output(2) } +// Returns a batched diagonal tensor with a given batched diagonal values. +// +// Given a `diagonal`, this operation returns a tensor with the `diagonal` and +// everything else padded with zeros. The diagonal is computed as follows: +// +// Assume `diagonal` has `k` dimensions `[I, J, K, ..., N]`, then the output is a +// tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where: +// +// `output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]`. +// +// For example: +// +// ``` +// # 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]] +// +// and diagonal.shape = (2, 4) +// +// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0] +// [0, 2, 0, 0] +// [0, 0, 3, 0] +// [0, 0, 0, 4]], +// [[5, 0, 0, 0] +// [0, 6, 0, 0] +// [0, 0, 7, 0] +// [0, 0, 0, 8]]] +// +// which has shape (2, 4, 4) +// ``` +// +// Arguments: +// diagonal: Rank `k`, where `k >= 1`. +// +// Returns Rank `k+1`, with `output.shape = diagonal.shape + [diagonal.shape[-1]]`. +func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MatrixDiag", + Input: []tf.Input{ + diagonal, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal. +type StatelessTruncatedNormalAttr func(optionalAttr) + +// StatelessTruncatedNormalDtype sets the optional dtype attribute to value. +// +// value: The type of the output. +// If not specified, defaults to DT_FLOAT +func StatelessTruncatedNormalDtype(value tf.DataType) StatelessTruncatedNormalAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Outputs deterministic pseudorandom values from a truncated normal distribution. +// +// The generated values follow a normal distribution with mean 0 and standard +// deviation 1, except that values whose magnitude is more than 2 standard +// deviations from the mean are dropped and re-picked. +// +// The outputs are a deterministic function of `shape` and `seed`. +// +// Arguments: +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). +// +// Returns Random values with specified shape. +func StatelessTruncatedNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessTruncatedNormalAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "StatelessTruncatedNormal", + Input: []tf.Input{ + shape, seed, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugAttr is an optional argument to RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug. 
type RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugAttr func(optionalAttr) @@ -46863,6 +47410,153 @@ func LoadTPUEmbeddingMomentumParametersGradAccumDebug(scope *Scope, parameters t return scope.AddOperation(opspec) } +// DecodeCompressedAttr is an optional argument to DecodeCompressed. +type DecodeCompressedAttr func(optionalAttr) + +// DecodeCompressedCompressionType sets the optional compression_type attribute to value. +// +// value: A scalar containing either (i) the empty string (no +// compression), (ii) "ZLIB", or (iii) "GZIP". +// If not specified, defaults to "" +func DecodeCompressedCompressionType(value string) DecodeCompressedAttr { + return func(m optionalAttr) { + m["compression_type"] = value + } +} + +// Decompress strings. +// +// This op decompresses each element of the `bytes` input `Tensor`, which +// is assumed to be compressed using the given `compression_type`. +// +// The `output` is a string `Tensor` of the same shape as `bytes`, +// each element containing the decompressed data from the corresponding +// element in `bytes`. +// +// Arguments: +// bytes: A Tensor of string which is compressed. +// +// Returns A Tensor with the same shape as input `bytes`, uncompressed +// from bytes. +func DecodeCompressed(scope *Scope, bytes tf.Output, optional ...DecodeCompressedAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodeCompressed", + Input: []tf.Input{ + bytes, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// NonDeterministicIntsAttr is an optional argument to NonDeterministicInts. +type NonDeterministicIntsAttr func(optionalAttr) + +// NonDeterministicIntsDtype sets the optional dtype attribute to value. +// +// value: The type of the output. +// If not specified, defaults to DT_INT64 +func NonDeterministicIntsDtype(value tf.DataType) NonDeterministicIntsAttr { + return func(m optionalAttr) { + m["dtype"] = value + } +} + +// Non-deterministically generates some integers. +// +// This op may use some OS-provided source of non-determinism (e.g. an RNG), so each execution will give different results. +// +// Arguments: +// shape: The shape of the output tensor. +// +// Returns Non-deterministic integer values with specified shape. +func NonDeterministicInts(scope *Scope, shape tf.Output, optional ...NonDeterministicIntsAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "NonDeterministicInts", + Input: []tf.Input{ + shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// MultinomialAttr is an optional argument to Multinomial. +type MultinomialAttr func(optionalAttr) + +// MultinomialSeed sets the optional seed attribute to value. +// +// value: If either seed or seed2 is set to be non-zero, the internal random number +// generator is seeded by the given seed. Otherwise, a random seed is used. +// If not specified, defaults to 0 +func MultinomialSeed(value int64) MultinomialAttr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// MultinomialSeed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. 
+// If not specified, defaults to 0 +func MultinomialSeed2(value int64) MultinomialAttr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// MultinomialOutputDtype sets the optional output_dtype attribute to value. +// If not specified, defaults to DT_INT64 +func MultinomialOutputDtype(value tf.DataType) MultinomialAttr { + return func(m optionalAttr) { + m["output_dtype"] = value + } +} + +// Draws samples from a multinomial distribution. +// +// Arguments: +// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` +// represents the unnormalized log probabilities for all classes. +// num_samples: 0-D. Number of independent samples to draw for each row slice. +// +// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]` +// contains the drawn class labels with range `[0, num_classes)`. +func Multinomial(scope *Scope, logits tf.Output, num_samples tf.Output, optional ...MultinomialAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Multinomial", + Input: []tf.Input{ + logits, num_samples, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // SerializeSparseAttr is an optional argument to SerializeSparse. type SerializeSparseAttr func(optionalAttr) @@ -46902,6 +47596,26 @@ func SerializeSparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Ou return op.Output(0) } +// An op which supports basic einsum op with 2 inputs and 1 output. +// +// This op has better TPU performance since it doesn't have explicitly reshape and +// transpose operations as tf.einsum does. +func XlaEinsum(scope *Scope, a tf.Output, b tf.Output, equation string) (product tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"equation": equation} + opspec := tf.OpSpec{ + Type: "XlaEinsum", + Input: []tf.Input{ + a, b, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Extracts the average gradient in the given ConditionalAccumulator. // // The op blocks until sufficient (i.e., more than num_required) @@ -48425,6 +49139,57 @@ func ConfigureDistributedTPU(scope *Scope, optional ...ConfigureDistributedTPUAt return op.Output(0) } +// Creates a dataset that executes a SQL query and emits rows of the result set. +// +// Arguments: +// driver_name: The database type. Currently, the only supported type is 'sqlite'. +// data_source_name: A connection string to connect to the database. +// query: A SQL query to execute. +// +// +func SqlDataset(scope *Scope, driver_name tf.Output, data_source_name tf.Output, query tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "SqlDataset", + Input: []tf.Input{ + driver_name, data_source_name, query, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Outputs deterministic pseudorandom random integers from a uniform distribution. +// +// The generated values follow a uniform distribution in the range `[minval, maxval)`. +// +// The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxval`. +// +// Arguments: +// shape: The shape of the output tensor. +// seed: 2 seeds (shape [2]). +// minval: Minimum value (inclusive, scalar). 
+// maxval: Maximum value (exclusive, scalar). +// +// Returns Random values with specified shape. +func StatelessRandomUniformInt(scope *Scope, shape tf.Output, seed tf.Output, minval tf.Output, maxval tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "StatelessRandomUniformInt", + Input: []tf.Input{ + shape, seed, minval, maxval, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // LoadTPUEmbeddingRMSPropParametersAttr is an optional argument to LoadTPUEmbeddingRMSPropParameters. type LoadTPUEmbeddingRMSPropParametersAttr func(optionalAttr) @@ -48949,106 +49714,6 @@ func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (out return op.Output(0) } -// MultinomialAttr is an optional argument to Multinomial. -type MultinomialAttr func(optionalAttr) - -// MultinomialSeed sets the optional seed attribute to value. -// -// value: If either seed or seed2 is set to be non-zero, the internal random number -// generator is seeded by the given seed. Otherwise, a random seed is used. -// If not specified, defaults to 0 -func MultinomialSeed(value int64) MultinomialAttr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// MultinomialSeed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func MultinomialSeed2(value int64) MultinomialAttr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// MultinomialOutputDtype sets the optional output_dtype attribute to value. -// If not specified, defaults to DT_INT64 -func MultinomialOutputDtype(value tf.DataType) MultinomialAttr { - return func(m optionalAttr) { - m["output_dtype"] = value - } -} - -// Draws samples from a multinomial distribution. -// -// Arguments: -// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]` -// represents the unnormalized log probabilities for all classes. -// num_samples: 0-D. Number of independent samples to draw for each row slice. -// -// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]` -// contains the drawn class labels with range `[0, num_classes)`. -func Multinomial(scope *Scope, logits tf.Output, num_samples tf.Output, optional ...MultinomialAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Multinomial", - Input: []tf.Input{ - logits, num_samples, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// NonDeterministicIntsAttr is an optional argument to NonDeterministicInts. -type NonDeterministicIntsAttr func(optionalAttr) - -// NonDeterministicIntsDtype sets the optional dtype attribute to value. -// -// value: The type of the output. -// If not specified, defaults to DT_INT64 -func NonDeterministicIntsDtype(value tf.DataType) NonDeterministicIntsAttr { - return func(m optionalAttr) { - m["dtype"] = value - } -} - -// Non-deterministically generates some integers. -// -// This op may use some OS-provided source of non-determinism (e.g. an RNG), so each execution will give different results. -// -// Arguments: -// shape: The shape of the output tensor. -// -// Returns Non-deterministic integer values with specified shape. 
-func NonDeterministicInts(scope *Scope, shape tf.Output, optional ...NonDeterministicIntsAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "NonDeterministicInts", - Input: []tf.Input{ - shape, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Creates a dataset that caches elements from `input_dataset`. // // A CacheDataset will iterate over the input_dataset, and store tensors. If the @@ -49796,53 +50461,6 @@ func AutoShardDataset(scope *Scope, input_dataset tf.Output, num_workers tf.Outp return op.Output(0) } -// DecodeCompressedAttr is an optional argument to DecodeCompressed. -type DecodeCompressedAttr func(optionalAttr) - -// DecodeCompressedCompressionType sets the optional compression_type attribute to value. -// -// value: A scalar containing either (i) the empty string (no -// compression), (ii) "ZLIB", or (iii) "GZIP". -// If not specified, defaults to "" -func DecodeCompressedCompressionType(value string) DecodeCompressedAttr { - return func(m optionalAttr) { - m["compression_type"] = value - } -} - -// Decompress strings. -// -// This op decompresses each element of the `bytes` input `Tensor`, which -// is assumed to be compressed using the given `compression_type`. -// -// The `output` is a string `Tensor` of the same shape as `bytes`, -// each element containing the decompressed data from the corresponding -// element in `bytes`. -// -// Arguments: -// bytes: A Tensor of string which is compressed. -// -// Returns A Tensor with the same shape as input `bytes`, uncompressed -// from bytes. -func DecodeCompressed(scope *Scope, bytes tf.Output, optional ...DecodeCompressedAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DecodeCompressed", - Input: []tf.Input{ - bytes, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // DecodeJpegAttr is an optional argument to DecodeJpeg. type DecodeJpegAttr func(optionalAttr) diff --git a/tensorflow/java/maven/spark-tensorflow-connector/pom.xml b/tensorflow/java/maven/spark-tensorflow-connector/pom.xml index f40090ac45d..19f5e29da2b 100644 --- a/tensorflow/java/maven/spark-tensorflow-connector/pom.xml +++ b/tensorflow/java/maven/spark-tensorflow-connector/pom.xml @@ -35,7 +35,7 @@ 1.8 2.4.5 2.7.3 - 4.11 + 4.13.1 diff --git a/tensorflow/java/maven/tensorflow-hadoop/pom.xml b/tensorflow/java/maven/tensorflow-hadoop/pom.xml index e900d81e5da..675a3369cf1 100644 --- a/tensorflow/java/maven/tensorflow-hadoop/pom.xml +++ b/tensorflow/java/maven/tensorflow-hadoop/pom.xml @@ -16,7 +16,7 @@ 1.6 2.6.0 3.5.1 - 4.11 + 4.13.1 diff --git a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java index 3c9a678cf56..97de99cb75e 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java +++ b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java @@ -51,7 +51,7 @@ final class NativeLibrary { // (1) The native library has already been statically loaded, OR // (2) The required native code has been statically linked (through a custom launcher), OR // (3) The native code is part of another library (such as an application-level library) - // that has already been loaded. 
For example, tensorflow/examples/android and + // that has already been loaded. For example, tensorflow/tools/android/test and // tensorflow/tools/android/inference_interface include the required native code in // differently named libraries. // diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 754a86e5916..c4a0b5d525d 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -233,9 +233,7 @@ cc_library( "interpreter.cc", "interpreter_builder.cc", "model_builder.cc", - "mutable_op_resolver.cc", "optional_debug_tools.cc", - "stderr_reporter.cc", ], hdrs = FRAMEWORK_LIB_HDRS, compatible_with = get_compatible_with_portable(), @@ -251,14 +249,17 @@ cc_library( ":kernel_api", ":memory_planner", ":minimal_logging", + ":mutable_op_resolver", ":shared_library", ":simple_memory_arena", + ":stderr_reporter", ":string", ":type_to_tflitetype", ":util", ":version", "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", + "//tensorflow/lite/core/api:verifier", "//tensorflow/lite/delegates:status", "//tensorflow/lite/delegates/nnapi:nnapi_delegate", "//tensorflow/lite/experimental/resource", @@ -266,6 +267,7 @@ cc_library( "//tensorflow/lite/nnapi:nnapi_implementation", "//tensorflow/lite/profiling:platform_profiler", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", ], alwayslink = 1, ) @@ -293,6 +295,7 @@ cc_library( ":version", "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", + "//tensorflow/lite/core/api:verifier", "//tensorflow/lite/delegates/nnapi:nnapi_delegate", "//tensorflow/lite/experimental/resource", "//tensorflow/lite/nnapi:nnapi_implementation", @@ -300,6 +303,70 @@ cc_library( ], ) +cc_library( + name = "error_reporter", + hdrs = ["error_reporter.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + visibility = [ + "//tensorflow/lite:__subpackages__", + "//tensorflow_lite_support:__subpackages__", + ], + deps = [ + "//tensorflow/lite:stderr_reporter", + "//tensorflow/lite/core/api:error_reporter", + ], +) + +cc_library( + name = "stderr_reporter", + srcs = ["stderr_reporter.cc"], + hdrs = ["stderr_reporter.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + visibility = [ + "//tensorflow/lite:__subpackages__", + "//tensorflow_lite_support:__subpackages__", + ], + deps = [ + ":minimal_logging", + "//tensorflow/lite/c:common", + "//tensorflow/lite/core/api:error_reporter", + ], +) + +cc_library( + name = "op_resolver", + hdrs = ["op_resolver.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + visibility = [ + "//tensorflow/lite:__subpackages__", + "//tensorflow_lite_support:__subpackages__", + ], + deps = [ + "//tensorflow/lite:mutable_op_resolver", + "//tensorflow/lite/core/api:op_resolver", + ], +) + +cc_library( + name = "mutable_op_resolver", + srcs = ["mutable_op_resolver.cc"], + hdrs = ["mutable_op_resolver.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + visibility = [ + "//tensorflow/lite:__subpackages__", + "//tensorflow_lite_support:__subpackages__", + ], + deps = [ + ":util", + "//tensorflow/lite/core/api:op_resolver", + "//tensorflow/lite/schema:schema_fbs", + ], +) + cc_library( name = "string_util", srcs = ["string_util.cc"], @@ -669,6 +736,13 @@ cc_library( hdrs = ["core/macros.h"], ) +cc_library( + name = "stateful_error_reporter", + hdrs = ["stateful_error_reporter.h"], + 
compatible_with = get_compatible_with_portable(), + deps = ["//tensorflow/lite/core/api"], +) + # Shared lib target for convenience, pulls in the core runtime and builtin ops. # Note: This target is not yet finalized, and the exact set of exported (C/C++) # APIs is subject to change. The output library name is platform dependent: diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt index ac7e373f18c..77a79f11559 100644 --- a/tensorflow/lite/CMakeLists.txt +++ b/tensorflow/lite/CMakeLists.txt @@ -20,8 +20,6 @@ # This has only been tested on Windows, Linux and macOS. # # The following are not currently supported: -# - GPU acceleration -# - Android # - iOS # - Micro backend # - Tests @@ -39,12 +37,12 @@ set(TENSORFLOW_SOURCE_DIR "" CACHE PATH ) if(NOT TENSORFLOW_SOURCE_DIR) get_filename_component(TENSORFLOW_SOURCE_DIR - "${CMAKE_SOURCE_DIR}/../../" + "${CMAKE_CURRENT_LIST_DIR}/../../" ABSOLUTE ) endif() set(TF_SOURCE_DIR "${TENSORFLOW_SOURCE_DIR}/tensorflow") -set(TFLITE_SOURCE_DIR "${CMAKE_SOURCE_DIR}") +set(TFLITE_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}") set(CMAKE_MODULE_PATH "${TFLITE_SOURCE_DIR}/tools/cmake/modules" ${CMAKE_MODULE_PATH} @@ -60,7 +58,7 @@ option(TFLITE_ENABLE_RUY "Enable experimental RUY integration" OFF) option(TFLITE_ENABLE_RESOURCE "Enable experimental support for resources" ON) option(TFLITE_ENABLE_NNAPI "Enable NNAPI (Android only)." ON) option(TFLITE_ENABLE_MMAP "Enable MMAP (unsupported on Windows)" ON) -option(TFLITE_ENABLE_GPU "Enable GPU (not supported)" OFF) +option(TFLITE_ENABLE_GPU "Enable GPU" OFF) # This must be enabled when converting from TF models with SELECT_TF_OPS # enabled. # https://www.tensorflow.org/lite/guide/ops_select#converting_the_model @@ -192,9 +190,60 @@ if(TFLITE_ENABLE_FLEX) ) endif() if(TFLITE_ENABLE_GPU) - # Implementation is under delegates/gpu. 
- message(FATAL_ERROR - "GPU acceleration is not currently supported in CMake builds" + find_package(opencl_headers REQUIRED) + find_package(vulkan_headers REQUIRED) + populate_tflite_source_vars( + "delegates/gpu/cl" TFLITE_DELEGATES_GPU_CL_SRCS + FILTER "(_test|gl_interop|egl_sync)\\.(cc|h)$" + ) + populate_tflite_source_vars( + "delegates/gpu/cl/kernels" TFLITE_DELEGATES_GPU_CL_KERNELS_SRCS + FILTER "(_test)\\.(cc|h)$" + ) + populate_tflite_source_vars( + "delegates/gpu/cl/kernels/special" + TFLITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_SRCS + FILTER "(_test)\\.(cc|h)$" + ) + populate_tflite_source_vars( + "delegates/gpu/cl/selectors" TFLITE_DELEGATES_GPU_CL_SELECTORS_SRCS + FILTER "(_test)\\.(cc|h)$" + ) + populate_tflite_source_vars( + "delegates/gpu/common" TFLITE_DELEGATES_GPU_COMMON_SRCS + FILTER "(_test)\\.(cc|h)$" + ) + populate_tflite_source_vars( + "delegates/gpu/common/default" TFLITE_DELEGATES_GPU_COMMON_DEFAULT_SRCS + FILTER "(_test)\\.(cc|h)$" + ) + populate_tflite_source_vars( + "delegates/gpu/common/memory_management" + TFLITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_SRCS + FILTER "(_test)\\.(cc|h)$" + ) + populate_tflite_source_vars( + "delegates/gpu/common/transformations" + TFLITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_SRCS + FILTER "(_test)\\.(cc|h)$" + ) + list(APPEND TFLITE_DELEGATES_GPU_SRCS + ${TFLITE_SOURCE_DIR}/delegates/gpu/api.cc + ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate.cc + ${TFLITE_DELEGATES_GPU_CL_SRCS} + ${TFLITE_DELEGATES_GPU_CL_KERNELS_SRCS} + ${TFLITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_SRCS} + ${TFLITE_DELEGATES_GPU_CL_SELECTORS_SRCS} + ${TFLITE_SOURCE_DIR}/delegates/gpu/cl/selectors/default/default_selector.cc + ${TFLITE_DELEGATES_GPU_COMMON_SRCS} + ${TFLITE_DELEGATES_GPU_COMMON_DEFAULT_SRCS} + ${TFLITE_DELEGATES_GPU_COMMON_MEMORY_MANAGEMENT_SRCS} + ${TFLITE_DELEGATES_GPU_COMMON_TRANSFORMATIONS_SRCS} + ) + list(APPEND TFLITE_TARGET_PUBLIC_OPTIONS "-DCL_DELEGATE_NO_GL" "-DEGL_NO_X11") + list(APPEND TFLITE_TARGET_DEPENDENCIES + absl::any + absl::flat_hash_map ) endif() if(_TFLITE_ENABLE_NNAPI) @@ -266,10 +315,13 @@ populate_tflite_source_vars("kernels/internal/reference/sparse_ops" ) # Common include directories +set(TFLITE_INCLUDE_DIRS + "${TENSORFLOW_SOURCE_DIR}" + "${TFLITE_FLATBUFFERS_SCHEMA_DIR}" +) include_directories( BEFORE - "${TENSORFLOW_SOURCE_DIR}" - "${TFLITE_FLATBUFFERS_SCHEMA_DIR}" + ${TFLITE_INCLUDE_DIRS} ) # TFLite library @@ -278,6 +330,7 @@ add_library(tensorflow-lite ${TFLITE_CORE_SRCS} ${TFLITE_C_SRCS} ${TFLITE_DELEGATES_FLEX_SRCS} + ${TFLITE_DELEGATES_GPU_SRCS} ${TFLITE_DELEGATES_NNAPI_SRCS} ${TFLITE_DELEGATES_SRCS} ${TFLITE_DELEGATES_XNNPACK_SRCS} @@ -294,6 +347,13 @@ add_library(tensorflow-lite ${TFLITE_KERNEL_SRCS} ${TFLITE_NNAPI_SRCS} ${TFLITE_SRCS} + ${TFLITE_SOURCE_DIR}/profiling/platform_profiler.cc + ${TFLITE_SOURCE_DIR}/schema/schema_utils.cc + ${TFLITE_SOURCE_DIR}/tools/optimize/sparsity/format_converter.cc +) +target_include_directories(tensorflow-lite + PUBLIC + ${TFLITE_INCLUDE_DIRS} ) target_link_libraries(tensorflow-lite PUBLIC @@ -353,14 +413,24 @@ endif() # TFLITE_ENABLE_XNNPACK if(CMAKE_SYSTEM_NAME MATCHES "Android") list(APPEND TFLITE_BENCHMARK_SRCS ${TFLITE_SOURCE_DIR}/profiling/atrace_profiler.cc - ${TFLITE_SOURCE_DIR}/tools/delegates/nnapi_delegate_provider.cc ) + if(_TFLITE_ENABLE_NNAPI) + list(APPEND TFLITE_BENCHMARK_SRCS + ${TFLITE_SOURCE_DIR}/tools/delegates/nnapi_delegate_provider.cc + ) + endif() # _TFLITE_ENABLE_NNAPI list(APPEND TFLITE_BENCHMARK_LIBS ${ANDROID_LOG_LIB} absl::strings ) endif() # Android 
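# Illustrative note (not part of this change): with the GPU sources wired in above, a
# GPU-enabled benchmark build might be configured roughly as follows, assuming a host
# with the required OpenCL/Vulkan headers available:
#   cmake ../tensorflow/lite -DTFLITE_ENABLE_GPU=ON
#   cmake --build . -j --target benchmark_model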
+if(TFLITE_ENABLE_GPU) + list(APPEND TFLITE_BENCHMARK_SRCS + ${TFLITE_SOURCE_DIR}/tools/delegates/gpu_delegate_provider.cc + ) +endif() # TFLITE_ENABLE_GPU + add_executable(benchmark_model EXCLUDE_FROM_ALL ${TFLITE_BENCHMARK_SRCS} diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h index 85140289ac1..a37607f6260 100644 --- a/tensorflow/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -24,7 +24,8 @@ extern "C" { #endif // __cplusplus // The enum for builtin operators. -// Note: CUSTOM and DELEGATE are 2 special ops which are not real built-in ops. +// Note: CUSTOM, DELEGATE, and PLACEHOLDER_FOR_GREATER_OP_CODES are 3 special +// ops which are not real built-in ops. typedef enum { kTfLiteBuiltinAdd = 0, kTfLiteBuiltinAveragePool2d = 1, @@ -153,6 +154,7 @@ typedef enum { kTfLiteBuiltinDensify = 124, kTfLiteBuiltinSegmentSum = 125, kTfLiteBuiltinBatchMatmul = 126, + kTfLiteBuiltinPlaceholderForGreaterOpCodes = 127, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/lite/c/c_api_experimental.h b/tensorflow/lite/c/c_api_experimental.h index 8d635f32a3a..bfbdd9c8fdd 100644 --- a/tensorflow/lite/c/c_api_experimental.h +++ b/tensorflow/lite/c/c_api_experimental.h @@ -24,6 +24,8 @@ extern "C" { #endif // __cplusplus /// Resets all variable tensors to zero. +/// +/// WARNING: This is an experimental API and subject to change. TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterResetVariableTensors( TfLiteInterpreter* interpreter); @@ -39,6 +41,8 @@ TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterResetVariableTensors( /// /// Code that uses this function should NOT call /// `TfLiteInterpreterOptionsSetOpResolver' on the same options object. +/// +/// WARNING: This is an experimental API and subject to change. TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddBuiltinOp( TfLiteInterpreterOptions* options, TfLiteBuiltinOperator op, const TfLiteRegistration* registration, int32_t min_version, @@ -56,6 +60,8 @@ TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddBuiltinOp( /// /// Code that uses this function should NOT call /// `TfLiteInterpreterOptionsSetOpResolver' on the same options object. +/// +/// WARNING: This is an experimental API and subject to change. TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddCustomOp( TfLiteInterpreterOptions* options, const char* name, const TfLiteRegistration* registration, int32_t min_version, @@ -72,6 +78,8 @@ TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddCustomOp( /// Code that uses this function should NOT call /// `TfLiteInterpreterOptionsAddBuiltin' or /// `TfLiteInterpreterOptionsAddCustomOp' on the same options object. +/// +/// WARNING: This is an experimental API and subject to change. void TfLiteInterpreterOptionsSetOpResolver( TfLiteInterpreterOptions* options, const TfLiteRegistration* (*find_builtin_op)(void* user_data, @@ -99,11 +107,15 @@ void TfLiteInterpreterOptionsSetOpResolver( /// /// NOTE: The client *must* explicitly allocate tensors before attempting to /// access input tensor data or invoke the interpreter. +/// +/// WARNING: This is an experimental API and subject to change. TFL_CAPI_EXPORT extern TfLiteInterpreter* TfLiteInterpreterCreateWithSelectedOps(const TfLiteModel* model, const TfLiteInterpreterOptions* options); /// Enable or disable the NN API for the interpreter (true to enable). +/// +/// WARNING: This is an experimental API and subject to change. 
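///
/// Illustrative usage (sketch only, not part of the original header comments):
///
///   TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
///   TfLiteInterpreterOptionsSetUseNNAPI(options, true);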
TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetUseNNAPI( TfLiteInterpreterOptions* options, bool enable); diff --git a/tensorflow/lite/c/common.h b/tensorflow/lite/c/common.h index 31405dfb998..8917c254825 100644 --- a/tensorflow/lite/c/common.h +++ b/tensorflow/lite/c/common.h @@ -421,7 +421,7 @@ typedef struct TfLiteCustomAllocation { size_t bytes; } TfLiteCustomAllocation; -// An tensor in the interpreter system which is a wrapper around a buffer of +// A tensor in the interpreter system which is a wrapper around a buffer of // data including a dimensionality (or NULL if not currently defined). #ifndef TF_LITE_STATIC_MEMORY typedef struct TfLiteTensor { diff --git a/tensorflow/lite/core/api/BUILD b/tensorflow/lite/core/api/BUILD index 451117106fe..93661996e9e 100644 --- a/tensorflow/lite/core/api/BUILD +++ b/tensorflow/lite/core/api/BUILD @@ -3,16 +3,14 @@ load("//tensorflow/lite/micro:build_def.bzl", "micro_copts") load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") package( - default_visibility = ["//visibility:public"], + default_visibility = ["//visibility:private"], licenses = ["notice"], # Apache 2.0 ) cc_library( name = "api", srcs = [ - "error_reporter.cc", "flatbuffer_conversions.cc", - "op_resolver.cc", "tensor_utils.cc", ], hdrs = [ @@ -24,16 +22,67 @@ cc_library( ], compatible_with = get_compatible_with_portable(), copts = tflite_copts() + micro_copts(), + visibility = ["//visibility:public"], deps = [ + ":error_reporter", + ":op_resolver", "@flatbuffers//:runtime_cc", "//tensorflow/lite/c:common", # TODO(b/158301698): consider moving internal:compatibility to a more # central location. "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", ], ) +# We define separate targets for "op_resolver" and "error_reporter", +# even though those headers are also exported by the "api" target, +# so that targets which only want to depend on these small abstract base +# class modules can express more fine-grained dependencies without +# pulling in tensor_utils and flatbuffer_conversions. 
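+# Illustrative example (not part of this change): a downstream rule that only needs the
+# error-reporter interface could then depend on the fine-grained target directly, e.g.
+#
+#   cc_library(
+#       name = "my_error_reporter",   # hypothetical client target
+#       srcs = ["my_error_reporter.cc"],
+#       deps = ["//tensorflow/lite/core/api:error_reporter"],
+#   )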
+ +cc_library( + name = "op_resolver", + srcs = ["op_resolver.cc"], + hdrs = ["op_resolver.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts() + micro_copts(), + visibility = [ + "//tensorflow/lite:__subpackages__", + "//tensorflow_lite_support:__subpackages__", + ], + deps = [ + ":error_reporter", + "//tensorflow/lite/c:common", + "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", + "@flatbuffers//:runtime_cc", + ], +) + +cc_library( + name = "error_reporter", + srcs = ["error_reporter.cc"], + hdrs = ["error_reporter.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts() + micro_copts(), + visibility = [ + "//tensorflow/lite:__subpackages__", + "//tensorflow_lite_support:__subpackages__", + ], + deps = [], +) + +cc_library( + name = "verifier", + hdrs = ["verifier.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts() + micro_copts(), + visibility = ["//visibility:public"], + deps = [":error_reporter"], +) + cc_test( name = "error_reporter_test", size = "small", @@ -50,6 +99,7 @@ cc_test( srcs = ["op_resolver_test.cc"], deps = [ ":api", + "//tensorflow/lite/schema:schema_utils", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 4049aea485c..77621c3f2fd 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -803,6 +803,8 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_DENSIFY: case BuiltinOperator_SEGMENT_SUM: return kTfLiteOk; + case BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES: + return kTfLiteError; } return kTfLiteError; } // NOLINT[readability/fn_size] diff --git a/tensorflow/lite/core/api/op_resolver.cc b/tensorflow/lite/core/api/op_resolver.cc index c239d9ed23e..c5dffb63549 100644 --- a/tensorflow/lite/core/api/op_resolver.cc +++ b/tensorflow/lite/core/api/op_resolver.cc @@ -18,6 +18,7 @@ limitations under the License. #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/schema/schema_utils.h" namespace tflite { @@ -26,7 +27,7 @@ TfLiteStatus GetRegistrationFromOpCode( ErrorReporter* error_reporter, const TfLiteRegistration** registration) { TfLiteStatus status = kTfLiteOk; *registration = nullptr; - auto builtin_code = opcode->builtin_code(); + auto builtin_code = GetBuiltinCode(opcode); int version = opcode->version(); if (builtin_code > BuiltinOperator_MAX || diff --git a/tensorflow/lite/core/api/op_resolver_test.cc b/tensorflow/lite/core/api/op_resolver_test.cc index 4dfca5c971a..44acc92ba8c 100644 --- a/tensorflow/lite/core/api/op_resolver_test.cc +++ b/tensorflow/lite/core/api/op_resolver_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/schema/schema_utils.h" namespace tflite { namespace { diff --git a/tensorflow/lite/core/api/verifier.h b/tensorflow/lite/core/api/verifier.h new file mode 100644 index 00000000000..ca1cfb044bd --- /dev/null +++ b/tensorflow/lite/core/api/verifier.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// \file +/// Abstract interface for verifying a model. +#ifndef TENSORFLOW_LITE_CORE_API_VERIFIER_H_ +#define TENSORFLOW_LITE_CORE_API_VERIFIER_H_ + +#include "tensorflow/lite/core/api/error_reporter.h" + +namespace tflite { + +/// Abstract interface that verifies whether a given model is legit. +/// It facilitates the use-case to verify and build a model without loading it +/// twice. +/// (See also "tensorflow/lite/tools/verifier.h".) +class TfLiteVerifier { + public: + /// Returns true if the model is legit. + virtual bool Verify(const char* data, int length, + ErrorReporter* reporter) = 0; + virtual ~TfLiteVerifier() {} +}; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_CORE_API_VERIFIER_H_ diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index ebbf20706ca..2b9246a1100 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -86,6 +86,7 @@ template bool HasDynamicTensorImpl(const TfLiteContext& context, const TensorIntArray& int_array) { for (int i : int_array) { + if (i == kTfLiteOptionalTensor) continue; const TfLiteTensor& tensor = context.tensors[i]; if (tensor.allocation_type == kTfLiteDynamic) { return true; diff --git a/tensorflow/lite/delegates/BUILD b/tensorflow/lite/delegates/BUILD index 955870b6281..d106ae4a738 100644 --- a/tensorflow/lite/delegates/BUILD +++ b/tensorflow/lite/delegates/BUILD @@ -86,6 +86,7 @@ cc_test( "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "//tensorflow/lite/testing:util", "//third_party/eigen3", "@com_google_googletest//:gtest", diff --git a/tensorflow/lite/delegates/delegate_test.cc b/tensorflow/lite/delegates/delegate_test.cc index 5d8c46f814b..b70ebdcc3aa 100644 --- a/tensorflow/lite/delegates/delegate_test.cc +++ b/tensorflow/lite/delegates/delegate_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/testing/util.h" #include "tensorflow/lite/util.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD index 13deea6cede..098159d9d26 100644 --- a/tensorflow/lite/delegates/flex/BUILD +++ b/tensorflow/lite/delegates/flex/BUILD @@ -129,8 +129,6 @@ tf_cc_test( srcs = ["delegate_test.cc"], tags = [ "no_gpu", # GPU + flex is not officially supported. 
- "no_mac", # b/169276266 - "nomsan", # b/169431678 ], deps = [ ":delegate", @@ -294,6 +292,6 @@ cc_library( deps = [ "//tensorflow/core:portable_gif_internal", "//tensorflow/core:portable_jpeg_internal", - "//tensorflow/core:portable_png_internal", + "//tensorflow/core/lib/png:png_io", ], ) diff --git a/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc b/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc index 3728c8d24c5..eee1c99ed58 100644 --- a/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc +++ b/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc @@ -68,6 +68,10 @@ const std::set& GetFlexAllowlist() { "AvgPoolGrad", "BatchMatMul", "BatchMatMulV2", + "BatchMatrixDiag", + "BatchMatrixDiagPart", + "BatchMatrixInverse", + "BatchMatrixSetDiag", "BatchNormWithGlobalNormalization", "BatchNormWithGlobalNormalizationGrad", "BatchToSpace", @@ -76,7 +80,19 @@ const std::set& GetFlexAllowlist() { "BiasAddGrad", "BiasAddV1", "Bincount", + "Bitcast", + "BitwiseAnd", + "BitwiseOr", + "BitwiseXor", "BoostedTreesBucketize", + "BoostedTreesCreateQuantileStreamResource", + "BoostedTreesFlushQuantileSummaries", + "BoostedTreesMakeQuantileSummaries", + "BoostedTreesQuantileStreamResourceAddSummaries", + "BoostedTreesQuantileStreamResourceDeserialize", + "BoostedTreesQuantileStreamResourceFlush", + "BoostedTreesQuantileStreamResourceGetBucketBoundaries", + "BoostedTreesQuantileStreamResourceHandleOp", "BroadcastArgs", "BroadcastGradientArgs", "BroadcastTo", @@ -86,6 +102,7 @@ const std::set& GetFlexAllowlist() { "Cast", "Ceil", "CheckNumerics", + "CheckNumericsV2", "CombinedNonMaxSuppression", "Complex", "ComplexAbs", @@ -100,6 +117,9 @@ const std::set& GetFlexAllowlist() { "Conv2DBackpropFilter", "Conv2DBackpropInput", "Conv3D", + "Conv3DBackpropFilter", + "Conv3DBackpropFilterV2", + "Conv3DBackpropInput", "Conv3DBackpropInputV2", "Cos", "Cosh", @@ -108,6 +128,7 @@ const std::set& GetFlexAllowlist() { "CropAndResizeGradImage", "Cumprod", "Cumsum", + "CumulativeLogsumexp", "DataFormatDimMap", "DataFormatVecPermute", "DebugGradientIdentity", @@ -129,7 +150,10 @@ const std::set& GetFlexAllowlist() { "Dequantize", "DestroyTemporaryVariable", "Diag", + "DiagPart", "Dilation2D", + "Dilation2DBackpropFilter", + "Dilation2DBackpropInput", "Div", "DivNoNan", "DynamicPartition", @@ -138,6 +162,7 @@ const std::set& GetFlexAllowlist() { "Elu", "EluGrad", "Empty", + "EmptyTensorList", "EncodeBase64", "EncodeJpeg", "EncodeJpegVariableQuality", @@ -183,6 +208,7 @@ const std::set& GetFlexAllowlist() { "GetSessionTensor", "Greater", "GreaterEqual", + "HistogramSummary", "IFFT", "IFFT2D", "IFFT3D", @@ -193,6 +219,7 @@ const std::set& GetFlexAllowlist() { "IdentityN", "Imag", "ImageProjectiveTransformV2", + "ImageProjectiveTransformV3", "ImmutableConst", "InTopK", "InTopKV2", @@ -201,13 +228,16 @@ const std::set& GetFlexAllowlist() { "InplaceUpdate", "Inv", "InvGrad", + "Invert", "InvertPermutation", + "IsBoostedTreesQuantileStreamResourceInitialized", "IsFinite", "IsNan", "IsVariableInitialized", "LRN", "LeakyRelu", "LeakyReluGrad", + "LeftShift", "Less", "LessEqual", "LinSpace", @@ -220,6 +250,9 @@ const std::set& GetFlexAllowlist() { "LoopCond", "MatMul", "MatrixDiag", + "MatrixDiagPart", + "MatrixDiagPartV2", + "MatrixDiagPartV3", "MatrixDiagV2", "MatrixDiagV3", "MatrixInverse", @@ -229,6 +262,8 @@ const std::set& GetFlexAllowlist() { "Max", "MaxPool", "MaxPool3D", + "MaxPool3DGrad", + "MaxPool3DGradGrad", "MaxPoolGrad", "MaxPoolGradGrad", "MaxPoolGradGradV2", @@ -239,6 +274,7 @@ const std::set& 
GetFlexAllowlist() { "Maximum", "Mean", "Merge", + "MergeSummary", "MergeV2Checkpoints", "Mfcc", "Min", @@ -255,6 +291,7 @@ const std::set& GetFlexAllowlist() { "NonMaxSuppressionV2", "NonMaxSuppressionV3", "NonMaxSuppressionV4", + "NonMaxSuppressionV5", "NonMaxSuppressionWithOverlaps", "NotEqual", "OneHot", @@ -264,15 +301,18 @@ const std::set& GetFlexAllowlist() { "PadV2", "PaddingFIFOQueue", "PaddingFIFOQueueV2", + "ParallelConcat", "ParallelDynamicStitch", "ParseExample", "ParseExampleV2", "ParseSequenceExample", + "ParseSequenceExampleV2", "ParseSingleExample", "ParseSingleSequenceExample", "Placeholder", "PlaceholderV2", "PlaceholderWithDefault", + "PopulationCount", "Pow", "PreventGradient", "Print", @@ -314,10 +354,13 @@ const std::set& GetFlexAllowlist() { "RFFT2D", "RFFT3D", "RaggedBincount", + "RaggedGather", "RaggedRange", "RaggedTensorToSparse", "RaggedTensorToTensor", "RandomGamma", + "RandomPoisson", + "RandomPoissonV2", "RandomStandardNormal", "RandomUniform", "RandomUniformInt", @@ -327,6 +370,7 @@ const std::set& GetFlexAllowlist() { "RealDiv", "Reciprocal", "ReciprocalGrad", + "Recv", "ReduceJoin", "RefEnter", "RefExit", @@ -354,22 +398,31 @@ const std::set& GetFlexAllowlist() { "ResourceApplyAdagradDA", "ResourceApplyAdagradV2", "ResourceApplyAdam", + "ResourceApplyAdamWithAmsgrad", "ResourceApplyAddSign", "ResourceApplyCenteredRMSProp", "ResourceApplyFtrl", "ResourceApplyFtrlV2", "ResourceApplyGradientDescent", + "ResourceApplyKerasMomentum", "ResourceApplyMomentum", "ResourceApplyPowerSign", "ResourceApplyProximalAdagrad", "ResourceApplyProximalGradientDescent", "ResourceApplyRMSProp", + "ResourceScatterNdAdd", + "ResourceScatterNdMax", + "ResourceScatterNdMin", + "ResourceScatterNdSub", + "ResourceScatterNdUpdate", "ResourceSparseApplyAdadelta", "ResourceSparseApplyAdagrad", "ResourceSparseApplyAdagradDA", + "ResourceSparseApplyAdagradV2", "ResourceSparseApplyCenteredRMSProp", "ResourceSparseApplyFtrl", "ResourceSparseApplyFtrlV2", + "ResourceSparseApplyKerasMomentum", "ResourceSparseApplyMomentum", "ResourceSparseApplyProximalAdagrad", "ResourceSparseApplyProximalGradientDescent", @@ -381,14 +434,23 @@ const std::set& GetFlexAllowlist() { "Reverse", "ReverseSequence", "ReverseV2", + "RightShift", "Round", "Rsqrt", "RsqrtGrad", + "SampleDistortedBoundingBox", "SampleDistortedBoundingBoxV2", "Save", "SaveSlices", "SaveV2", + "ScalarSummary", "ScatterNd", + "ScatterNdAdd", + "ScatterNdMax", + "ScatterNdMin", + "ScatterNdNonAliasingAdd", + "ScatterNdSub", + "ScatterNdUpdate", "SegmentMax", "SegmentMean", "SegmentMin", @@ -398,6 +460,7 @@ const std::set& GetFlexAllowlist() { "SelectV2", "Selu", "SeluGrad", + "Send", "Shape", "ShapeN", "ShardedFilename", @@ -421,6 +484,7 @@ const std::set& GetFlexAllowlist() { "SparseApplyAdadelta", "SparseApplyAdagrad", "SparseApplyAdagradDA", + "SparseApplyAdagradV2", "SparseApplyCenteredRMSProp", "SparseApplyFtrl", "SparseApplyFtrlV2", @@ -459,12 +523,14 @@ const std::set& GetFlexAllowlist() { "StackPush", "StackPushV2", "StackV2", + "StatelessMultinomial", "StatelessRandomGammaV2", "StatelessRandomNormal", "StatelessRandomPoisson", "StatelessRandomUniform", "StatelessRandomUniformFullInt", "StatelessRandomUniformInt", + "StatelessSampleDistortedBoundingBox", "StatelessTruncatedNormal", "StaticRegexReplace", "StopGradient", @@ -475,6 +541,7 @@ const std::set& GetFlexAllowlist() { "StringLower", "StringSplit", "StringSplitV2", + "StringStrip", "StringToHashBucket", "StringToHashBucketFast", "StringToHashBucketStrong", @@ -520,11 +587,31 
@@ const std::set& GetFlexAllowlist() { "TensorArrayWrite", "TensorArrayWriteV2", "TensorArrayWriteV3", + "TensorListConcat", + "TensorListConcatLists", + "TensorListConcatV2", + "TensorListElementShape", + "TensorListFromTensor", + "TensorListGather", + "TensorListGetItem", + "TensorListLength", + "TensorListPopBack", + "TensorListPushBack", + "TensorListPushBackBatch", + "TensorListReserve", + "TensorListResize", + "TensorListScatter", + "TensorListScatterIntoExistingList", + "TensorListScatterV2", + "TensorListSetItem", + "TensorListSplit", + "TensorListStack", "TensorScatterAdd", "TensorScatterMax", "TensorScatterMin", "TensorScatterSub", "TensorScatterUpdate", + "TensorStridedSliceUpdate", "Tile", "TileGrad", "Timestamp", @@ -546,21 +633,30 @@ const std::set& GetFlexAllowlist() { "UnsortedSegmentMin", "UnsortedSegmentProd", "UnsortedSegmentSum", + "UnwrapDatasetVariant", "Variable", "VariableV2", "Where", + "WrapDatasetVariant", "Xdivy", + "Xlog1py", "Xlogy", "ZerosLike", "_Arg", "_ArrayToList", + "_DeviceArg", + "_DeviceRetval", + "_FusedConv2D", "_HostCast", "_HostRecv", "_HostSend", "_ListToArray", + "_ParallelConcatStart", + "_ParallelConcatUpdate", "_Recv", "_Retval", "_Send", + "_SwitchN", // go/keep-sorted end }); return *allowlisted_flex_ops; diff --git a/tensorflow/lite/delegates/flex/delegate_test.cc b/tensorflow/lite/delegates/flex/delegate_test.cc index e50ea1d9b0e..6450848bf0e 100644 --- a/tensorflow/lite/delegates/flex/delegate_test.cc +++ b/tensorflow/lite/delegates/flex/delegate_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/delegates/flex/delegate.h" #include +#include #include #include @@ -308,7 +309,7 @@ TEST_F(DelegateTest, MultiThreaded) { TEST_F(DelegateTest, TF_AcquireFlexDelegate) { auto TF_AcquireFlexDelegate = reinterpret_cast( - SharedLibrary::GetLibrarySymbol(nullptr, "TF_AcquireFlexDelegate")); + SharedLibrary::GetSymbol("TF_AcquireFlexDelegate")); ASSERT_TRUE(TF_AcquireFlexDelegate); auto delegate_ptr = TF_AcquireFlexDelegate(); ASSERT_TRUE(delegate_ptr != nullptr); @@ -361,6 +362,7 @@ TEST_F(DelegateTest, StaticOutputRFFT) { // Define inputs. 
SetShape(0, {3, 512}); + SetValues(0, std::vector(3 * 512, 1.0f)); ASSERT_TRUE(Invoke()); diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index 28b6cc7671a..c616b081829 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -286,7 +286,7 @@ cc_library( ":cl_command_queue", ":cl_context", ":cl_device", - ":cl_kernel", + ":device_info", ":precision", ":program_cache", ":tensor", diff --git a/tensorflow/lite/delegates/gpu/cl/api.cc b/tensorflow/lite/delegates/gpu/cl/api.cc index 01d32aa9206..dfd8a8680fc 100644 --- a/tensorflow/lite/delegates/gpu/cl/api.cc +++ b/tensorflow/lite/delegates/gpu/cl/api.cc @@ -585,6 +585,8 @@ class InferenceBuilderImpl : public InferenceBuilder { if (options.usage == InferenceUsage::FAST_SINGLE_ANSWER) { create_info.hints.Add(ModelHints::kReduceKernelsCount); create_info.hints.Add(ModelHints::kFastTuning); + } else if (options.usage == InferenceUsage::SUSTAINED_SPEED) { + create_info.hints.Add(ModelHints::kAllowSpecialKernels); } RETURN_IF_ERROR(context_->InitFromGraph(create_info, graph, environment_)); @@ -601,8 +603,8 @@ class InferenceBuilderImpl : public InferenceBuilder { absl::make_unique(environment_, context_.get()); #endif - inputs_ = LinkTensors(graph, graph.inputs()); - outputs_ = LinkTensors(graph, graph.outputs()); + inputs_ = LinkTensors(context_->GetInputIds(), AccessType::READ); + outputs_ = LinkTensors(context_->GetOutputIds(), AccessType::WRITE); return absl::OkStatus(); } @@ -675,12 +677,13 @@ class InferenceBuilderImpl : public InferenceBuilder { if (GetRelativeImportance(options, InferencePriority::MIN_LATENCY, InferencePriority::MIN_MEMORY_USAGE) == PriorityImportance::HIGHER) { - preferred_storage_types = {GetFastestStorageType(environment_->device()), - TensorStorageType::BUFFER}; - } else { preferred_storage_types = { - GetStorageTypeWithMinimalMemoryConsumption(environment_->device()), + GetFastestStorageType(environment_->device().GetInfo()), TensorStorageType::BUFFER}; + } else { + preferred_storage_types = {GetStorageTypeWithMinimalMemoryConsumption( + environment_->device().GetInfo()), + TensorStorageType::BUFFER}; } for (TensorStorageType storage_type : preferred_storage_types) { @@ -718,15 +721,13 @@ class InferenceBuilderImpl : public InferenceBuilder { } // Links internal tensors with external user-facing objects. - std::vector LinkTensors(const GraphFloat32& graph, - const std::vector& values) { + std::vector LinkTensors(const std::vector& ids, + AccessType access) { std::vector links; - links.reserve(values.size()); - for (const auto& value : values) { - TensorObjectDef def = TensorToDef(*context_->GetTensor(value->id)); - AccessType access = - graph.IsGraphInput(value->id) ? 
AccessType::READ : AccessType::WRITE; - links.push_back({value->id, access, def, def}); + links.reserve(ids.size()); + for (const auto& id : ids) { + TensorObjectDef def = TensorToDef(*context_->GetTensor(id)); + links.push_back({id, access, def, def}); } return links; } diff --git a/tensorflow/lite/delegates/gpu/cl/arguments.cc b/tensorflow/lite/delegates/gpu/cl/arguments.cc index b7e6b08616e..526a09f18f9 100644 --- a/tensorflow/lite/delegates/gpu/cl/arguments.cc +++ b/tensorflow/lite/delegates/gpu/cl/arguments.cc @@ -843,11 +843,16 @@ absl::Status Arguments::AllocateObjects(CLContext* context) { for (auto& t : objects_) { RETURN_IF_ERROR( t.second.descriptor->CreateGPUObject(context, &t.second.obj_ptr)); - t.second.descriptor->Release(); } return absl::OkStatus(); } +void Arguments::ReleaseCPURepresentation() { + for (auto& t : objects_) { + t.second.descriptor->Release(); + } +} + absl::Status Arguments::AddObjectArgs() { for (auto& t : objects_) { AddGPUResources(t.first, t.second.descriptor->GetGPUResources()); diff --git a/tensorflow/lite/delegates/gpu/cl/arguments.h b/tensorflow/lite/delegates/gpu/cl/arguments.h index 4636a06db6f..2f87f8ae6a0 100644 --- a/tensorflow/lite/delegates/gpu/cl/arguments.h +++ b/tensorflow/lite/delegates/gpu/cl/arguments.h @@ -33,7 +33,15 @@ namespace tflite { namespace gpu { namespace cl { -class Arguments { +class ArgumentsBinder { + public: + virtual absl::Status SetInt(const std::string& name, int value) = 0; + virtual absl::Status SetFloat(const std::string& name, float value) = 0; + virtual absl::Status SetHalf(const std::string& name, half value) = 0; + virtual ~ArgumentsBinder() = default; +}; + +class Arguments : public ArgumentsBinder { public: Arguments() = default; void AddFloat(const std::string& name, float value = 0.0f); @@ -44,9 +52,9 @@ class Arguments { void AddObject(const std::string& name, GPUObjectDescriptorPtr&& descriptor_ptr); - absl::Status SetInt(const std::string& name, int value); - absl::Status SetFloat(const std::string& name, float value); - absl::Status SetHalf(const std::string& name, half value); + absl::Status SetInt(const std::string& name, int value) override; + absl::Status SetFloat(const std::string& name, float value) override; + absl::Status SetHalf(const std::string& name, half value) override; absl::Status SetObjectRef(const std::string& name, const GPUObject* object); absl::Status Bind(cl_kernel kernel, int offset = 0); @@ -55,6 +63,7 @@ class Arguments { absl::Status Merge(Arguments&& args, const std::string& postfix); absl::Status AllocateObjects(CLContext* context); + void ReleaseCPURepresentation(); absl::Status TransformToCLCode( const DeviceInfo& device_info, const std::map& linkables, std::string* code); @@ -65,6 +74,8 @@ class Arguments { Arguments(const Arguments&) = delete; Arguments& operator=(const Arguments&) = delete; + ~Arguments() override = default; + private: void AddBuffer(const std::string& name, const GPUBufferDescriptor& desc); void AddImage2D(const std::string& name, const GPUImage2DDescriptor& desc); diff --git a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc index 5f27d9827a2..10937cfc56b 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.cc @@ -56,14 +56,15 @@ void CLCommandQueue::Release() { } } -absl::Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size, - CLEvent* event) { +absl::Status 
CLCommandQueue::Dispatch(const CLKernel& kernel, + const int3& work_groups_count, + const int3& work_group_size, + CLEvent* event) { std::vector local(3); std::vector global(3); for (int i = 0; i < 3; ++i) { local[i] = work_group_size[i]; - global[i] = AlignByN(grid[i], work_group_size[i]); + global[i] = work_groups_count[i] * work_group_size[i]; } cl_event resulting_event; const int error_code = clEnqueueNDRangeKernel( @@ -80,9 +81,10 @@ absl::Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid, return absl::OkStatus(); } -absl::Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size) { - return DispatchImplicit(kernel, grid, work_group_size, nullptr); +absl::Status CLCommandQueue::Dispatch(const CLKernel& kernel, + const int3& work_groups_count, + const int3& work_group_size) { + return Dispatch(kernel, work_groups_count, work_group_size, nullptr); } absl::Status CLCommandQueue::EnqueueEvent(CLEvent* event) { @@ -191,12 +193,13 @@ void ProfilingCommandQueue::SetEventsLabel(const std::string& name) { void ProfilingCommandQueue::ResetMeasurements() { events_.clear(); } -absl::Status ProfilingCommandQueue::DispatchImplicit(const CLKernel& kernel, - int3 grid, - int3 work_group_size) { +absl::Status ProfilingCommandQueue::Dispatch(const CLKernel& kernel, + const int3& work_groups_count, + const int3& work_group_size) { events_.push_back(CLEvent()); - RETURN_IF_ERROR(CLCommandQueue::DispatchImplicit( - kernel, grid, work_group_size, &events_[events_.size() - 1])); + RETURN_IF_ERROR(CLCommandQueue::Dispatch(kernel, work_groups_count, + work_group_size, + &events_[events_.size() - 1])); events_.back().SetName(current_label_); return absl::OkStatus(); } @@ -213,14 +216,15 @@ ProfilingInfo ProfilingCommandQueue::GetProfilingInfo() const { } absl::Status ProfilingCommandQueue::GetBestWorkGroupIndex( - const CLKernel& kernel, const DeviceInfo& device_info, const int3& grid, + const CLKernel& kernel, const DeviceInfo& device_info, + const std::vector& work_groups_count, const std::vector& work_group_sizes, int* index) { // Some Adreno 3xx can have wrong numbers for some events const bool possible_bug_with_events = device_info.IsAdreno3xx(); events_.resize(work_group_sizes.size()); for (int i = 0; i < work_group_sizes.size(); ++i) { - RETURN_IF_ERROR(CLCommandQueue::DispatchImplicit( - kernel, grid, work_group_sizes[i], &events_[i])); + RETURN_IF_ERROR(CLCommandQueue::Dispatch(kernel, work_groups_count[i], + work_group_sizes[i], &events_[i])); // reducing the speed of memory leak on Mali for some kernels if (device_info.IsMali() && i % 8 == 7) { @@ -330,7 +334,11 @@ absl::Duration ProfilingInfo::GetTotalTime() const { std::string ProfilingInfo::GetDetailedReport() const { std::string result; - std::map timing; + struct OpStatistic { + int count; + double total_time; + }; + std::map statistics; result += "Per kernel timing(" + std::to_string(dispatches.size()) + " kernels):\n"; for (const auto& dispatch : dispatches) { @@ -338,16 +346,22 @@ std::string ProfilingInfo::GetDetailedReport() const { std::to_string(absl::ToDoubleMilliseconds(dispatch.duration)) + " ms\n"; auto name = dispatch.label.substr(0, dispatch.label.find(" ")); - if (timing.find(name) != timing.end()) { - timing[name] += absl::ToDoubleMilliseconds(dispatch.duration); + if (statistics.find(name) != statistics.end()) { + statistics[name].count++; + statistics[name].total_time += + absl::ToDoubleMilliseconds(dispatch.duration); } else { - timing[name] = 
absl::ToDoubleMilliseconds(dispatch.duration); + statistics[name].count = 1; + statistics[name].total_time = + absl::ToDoubleMilliseconds(dispatch.duration); } } result += "--------------------\n"; result += "Accumulated time per operation type:\n"; - for (auto& t : timing) { - result += " " + t.first + " - " + std::to_string(t.second) + " ms\n"; + for (auto& t : statistics) { + auto stat = t.second; + result += " " + t.first + "(x" + std::to_string(stat.count) + ") - " + + std::to_string(stat.total_time) + " ms\n"; } result += "--------------------\n"; result += "Ideal total time: " + diff --git a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h index 178e3b21a1e..519b87640e7 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_command_queue.h @@ -74,14 +74,15 @@ class CLCommandQueue { cl_command_queue queue() const { return queue_; } - virtual absl::Status DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size); + virtual absl::Status Dispatch(const CLKernel& kernel, + const int3& work_groups_count, + const int3& work_group_size); + + absl::Status Dispatch(const CLKernel& kernel, const int3& work_groups_count, + const int3& work_group_size, CLEvent* event); absl::Status EnqueueEvent(CLEvent* event); - absl::Status DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size, CLEvent* event); - absl::Status EnqueueWriteImage(cl_mem memory, int3 region, const void* data); absl::Status EnqueueReadImage(cl_mem memory, int3 region, void* data); @@ -110,13 +111,13 @@ class ProfilingCommandQueue : public CLCommandQueue { ProfilingCommandQueue(const ProfilingCommandQueue&) = delete; ProfilingCommandQueue& operator=(const ProfilingCommandQueue&) = delete; - absl::Status DispatchImplicit(const CLKernel& kernel, int3 grid, - int3 work_group_size) override; + absl::Status Dispatch(const CLKernel& kernel, const int3& work_groups_count, + const int3& work_group_size) override; // will write index for fastest work_group among work_group_sizes absl::Status GetBestWorkGroupIndex(const CLKernel& kernel, const DeviceInfo& device_info, - const int3& grid, + const std::vector& work_groups_count, const std::vector& work_group_sizes, int* index); diff --git a/tensorflow/lite/delegates/gpu/cl/cl_device.h b/tensorflow/lite/delegates/gpu/cl/cl_device.h index e7cd274661d..79335a61aff 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_device.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_device.h @@ -73,6 +73,7 @@ class CLDevice { bool SupportsOneLayerTextureArray() const; void DisableOneLayerTextureArray(); + const DeviceInfo& GetInfo() const { return info_; } // We update device info during context creation, so as supported texture // formats can be requested from context only. mutable DeviceInfo info_; diff --git a/tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h b/tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h new file mode 100644 index 00000000000..8a12bf2a9db --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h @@ -0,0 +1,207 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// automatically generated by the FlatBuffers compiler, do not modify + + +#ifndef FLATBUFFERS_GENERATED_COMPILEDPROGRAMCACHE_TFLITE_GPU_CL_DATA_H_ +#define FLATBUFFERS_GENERATED_COMPILEDPROGRAMCACHE_TFLITE_GPU_CL_DATA_H_ + +#include "flatbuffers/flatbuffers.h" + +namespace tflite { +namespace gpu { +namespace cl { +namespace data { + +struct Program; + +struct CompiledCache; + +struct Program FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_FINGERPRINT = 4, + VT_BINARY = 6 + }; + uint64_t fingerprint() const { + return GetField(VT_FINGERPRINT, 0); + } + const flatbuffers::Vector *binary() const { + return GetPointer *>(VT_BINARY); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FINGERPRINT) && + VerifyOffset(verifier, VT_BINARY) && + verifier.VerifyVector(binary()) && + verifier.EndTable(); + } +}; + +struct ProgramBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_fingerprint(uint64_t fingerprint) { + fbb_.AddElement(Program::VT_FINGERPRINT, fingerprint, 0); + } + void add_binary(flatbuffers::Offset> binary) { + fbb_.AddOffset(Program::VT_BINARY, binary); + } + explicit ProgramBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ProgramBuilder &operator=(const ProgramBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateProgram( + flatbuffers::FlatBufferBuilder &_fbb, + uint64_t fingerprint = 0, + flatbuffers::Offset> binary = 0) { + ProgramBuilder builder_(_fbb); + builder_.add_fingerprint(fingerprint); + builder_.add_binary(binary); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateProgramDirect( + flatbuffers::FlatBufferBuilder &_fbb, + uint64_t fingerprint = 0, + const std::vector *binary = nullptr) { + auto binary__ = binary ? 
_fbb.CreateVector(*binary) : 0; + return tflite::gpu::cl::data::CreateProgram( + _fbb, + fingerprint, + binary__); +} + +struct CompiledCache FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DRIVER_VERSION = 4, + VT_PROGRAMS = 6 + }; + const flatbuffers::String *driver_version() const { + return GetPointer(VT_DRIVER_VERSION); + } + const flatbuffers::Vector> *programs() const { + return GetPointer> *>(VT_PROGRAMS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DRIVER_VERSION) && + verifier.VerifyString(driver_version()) && + VerifyOffset(verifier, VT_PROGRAMS) && + verifier.VerifyVector(programs()) && + verifier.VerifyVectorOfTables(programs()) && + verifier.EndTable(); + } +}; + +struct CompiledCacheBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_driver_version(flatbuffers::Offset driver_version) { + fbb_.AddOffset(CompiledCache::VT_DRIVER_VERSION, driver_version); + } + void add_programs(flatbuffers::Offset>> programs) { + fbb_.AddOffset(CompiledCache::VT_PROGRAMS, programs); + } + explicit CompiledCacheBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + CompiledCacheBuilder &operator=(const CompiledCacheBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateCompiledCache( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset driver_version = 0, + flatbuffers::Offset>> programs = 0) { + CompiledCacheBuilder builder_(_fbb); + builder_.add_programs(programs); + builder_.add_driver_version(driver_version); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateCompiledCacheDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *driver_version = nullptr, + const std::vector> *programs = nullptr) { + auto driver_version__ = driver_version ? _fbb.CreateString(driver_version) : 0; + auto programs__ = programs ? 
_fbb.CreateVector>(*programs) : 0; + return tflite::gpu::cl::data::CreateCompiledCache( + _fbb, + driver_version__, + programs__); +} + +inline const tflite::gpu::cl::data::CompiledCache *GetCompiledCache(const void *buf) { + return flatbuffers::GetRoot(buf); +} + +inline const tflite::gpu::cl::data::CompiledCache *GetSizePrefixedCompiledCache(const void *buf) { + return flatbuffers::GetSizePrefixedRoot(buf); +} + +inline const char *CompiledCacheIdentifier() { + return "AFCM"; +} + +inline bool CompiledCacheBufferHasIdentifier(const void *buf) { + return flatbuffers::BufferHasIdentifier( + buf, CompiledCacheIdentifier()); +} + +inline bool VerifyCompiledCacheBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifyBuffer(CompiledCacheIdentifier()); +} + +inline bool VerifySizePrefixedCompiledCacheBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifySizePrefixedBuffer(CompiledCacheIdentifier()); +} + +inline const char *CompiledCacheExtension() { + return "jetbin"; +} + +inline void FinishCompiledCacheBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.Finish(root, CompiledCacheIdentifier()); +} + +inline void FinishSizePrefixedCompiledCacheBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.FinishSizePrefixed(root, CompiledCacheIdentifier()); +} + +} // namespace data +} // namespace cl +} // namespace gpu +} // namespace tflite + +#endif // FLATBUFFERS_GENERATED_COMPILEDPROGRAMCACHE_TFLITE_GPU_CL_DATA_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/environment.cc b/tensorflow/lite/delegates/gpu/cl/environment.cc index 785e88299a7..5b06b307133 100644 --- a/tensorflow/lite/delegates/gpu/cl/environment.cc +++ b/tensorflow/lite/delegates/gpu/cl/environment.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" #include "tensorflow/lite/delegates/gpu/cl/util.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" @@ -26,59 +25,6 @@ namespace tflite { namespace gpu { namespace cl { namespace { - -std::string GetKernelOneLayerTextureArray() { - return R"( - -__kernel void main_function(__write_only image2d_array_t dst) { - int X = (int)(get_global_id(0)); - int Y = (int)(get_global_id(1)); - - write_imagef(dst, (int4)(X, Y, 0, 0), (float4)(2.0, 2.0, 2.0, 2.0)); -} -)"; -} - -// Some Adreno < 600 have bug with one layer texture array. b/131099086 -// If we have one layer texture array and will write smt from kernel to this -// texture, we will get zeroes instead of actual values. -// The same kernel will work, if we use texture array with more than one layer. -// With help of this code we can detect this bug. 
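The compiled_program_cache_generated.h schema added above (a Program holding a fingerprint and binary, wrapped in a CompiledCache keyed by driver version) is ordinary FlatBuffers generated code, so it can be exercised with the standard builder API. The following is a minimal sketch only; the fingerprint, driver string, and binary bytes are placeholders rather than values the delegate actually produces.

#include <cstdint>
#include <vector>

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h"

void BuildCompiledCacheSketch() {
  flatbuffers::FlatBufferBuilder builder;
  // Placeholder bytes; a real entry would carry an OpenCL program binary.
  std::vector<uint8_t> binary = {0xDE, 0xAD, 0xBE, 0xEF};
  std::vector<flatbuffers::Offset<tflite::gpu::cl::data::Program>> programs;
  programs.push_back(tflite::gpu::cl::data::CreateProgramDirect(
      builder, /*fingerprint=*/0x1234, &binary));
  auto cache = tflite::gpu::cl::data::CreateCompiledCacheDirect(
      builder, /*driver_version=*/"placeholder-driver", &programs);
  // Finishing through the generated helper stamps the "AFCM" file identifier.
  tflite::gpu::cl::data::FinishCompiledCacheBuffer(builder, cache);

  // Reading back: verify the identifier, then access the root table.
  flatbuffers::Verifier verifier(builder.GetBufferPointer(), builder.GetSize());
  if (tflite::gpu::cl::data::VerifyCompiledCacheBuffer(verifier)) {
    const auto* root =
        tflite::gpu::cl::data::GetCompiledCache(builder.GetBufferPointer());
    (void)root;  // root->driver_version(), root->programs(), ...
  }
}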
-absl::Status CheckKernelSupportOfOneLayerTextureArray(Environment* env, - bool* result) { - // No bug on Adreno 6xx - if (env->device().info_.adreno_info.gpu_version >= 600) { - *result = true; - return absl::OkStatus(); - } - CLKernel kernel; - RETURN_IF_ERROR(env->program_cache()->GetOrCreateCLKernel( - GetKernelOneLayerTextureArray(), "main_function", env->context(), - env->device(), &kernel)); - - Tensor tensor; - const BHWC shape(1, 4, 4, 4); - RETURN_IF_ERROR(CreateTensor( - env->context(), shape, - {DataType::FLOAT32, TensorStorageType::TEXTURE_ARRAY, Layout::HWC}, - &tensor)); - RETURN_IF_ERROR(kernel.SetMemory(0, tensor.GetMemoryPtr())); - RETURN_IF_ERROR(env->queue()->DispatchImplicit(kernel, {4, 4, 1}, {4, 4, 1})); - TensorFloat32 tensor_gpu; - tensor_gpu.shape = shape; - tensor_gpu.data.resize(shape.DimensionsProduct()); - RETURN_IF_ERROR(tensor.ReadData(env->queue(), &tensor_gpu)); - - *result = true; - for (int i = 0; i < 64; ++i) { - if (tensor_gpu.data[i] != 2.0) { - *result = false; - break; - } - } - return absl::OkStatus(); -} - absl::Status CreateEnvironment(Environment* result, bool shared, cl_context_properties egl_context, cl_context_properties egl_display) { @@ -99,16 +45,7 @@ absl::Status CreateEnvironment(Environment* result, bool shared, *result = Environment(std::move(gpu), std::move(context), std::move(queue), std::move(profiling_queue)); - if (result->device().IsAdreno() && result->device().SupportsTextureArray()) { - bool supports_one_layer; - RETURN_IF_ERROR( - CheckKernelSupportOfOneLayerTextureArray(result, &supports_one_layer)); - if (!supports_one_layer) { - result->GetDevicePtr()->DisableOneLayerTextureArray(); - } - } - - return absl::OkStatus(); + return result->Init(); } } // namespace @@ -141,10 +78,12 @@ Environment& Environment::operator=(Environment&& environment) { absl::Status Environment::Init() { if (device().IsAdreno() && device().SupportsTextureArray()) { - bool supports_one_layer; - RETURN_IF_ERROR( - CheckKernelSupportOfOneLayerTextureArray(this, &supports_one_layer)); - if (!supports_one_layer) { + // Some Adreno < 600 have bug with one layer texture array. b/131099086 + // If we have one layer texture array and will write smt from kernel to this + // texture, we will get zeroes instead of actual values. + // The same kernel will work, if we use texture array with more than one + // layer. + if (device().info_.adreno_info.gpu_version < 600) { GetDevicePtr()->DisableOneLayerTextureArray(); } } @@ -232,54 +171,54 @@ bool Environment::IsSupported(TensorStorageType storage_type) const { return false; } -TensorStorageType GetFastestStorageType(const CLDevice& gpu) { - if (gpu.IsAdreno()) { - if (gpu.IsAdreno6xxOrHigher()) { +TensorStorageType GetFastestStorageType(const DeviceInfo& gpu_info) { + if (gpu_info.IsAdreno()) { + if (gpu_info.IsAdreno6xxOrHigher()) { return TensorStorageType::TEXTURE_ARRAY; } else { return TensorStorageType::TEXTURE_2D; } - } else if (gpu.IsPowerVR()) { + } else if (gpu_info.IsPowerVR()) { return TensorStorageType::TEXTURE_2D; - } else if (gpu.IsMali()) { - const MaliInfo mali_info = gpu.info_.mali_info; + } else if (gpu_info.IsMali()) { + const MaliInfo mali_info = gpu_info.mali_info; if (mali_info.IsMaliT8xx() || mali_info.IsBifrostGen3() || mali_info.IsValhall()) { return TensorStorageType::TEXTURE_2D; } else { return TensorStorageType::BUFFER; } - } else if (gpu.IsNvidia()) { - return gpu.SupportsImageBuffer() ? 
TensorStorageType::IMAGE_BUFFER - : TensorStorageType::BUFFER; - } else if (gpu.IsAMD()) { - return gpu.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER - : TensorStorageType::BUFFER; - } else if (gpu.IsIntel()) { + } else if (gpu_info.IsNvidia()) { + return gpu_info.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER + : TensorStorageType::BUFFER; + } else if (gpu_info.IsAMD()) { + return gpu_info.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER + : TensorStorageType::BUFFER; + } else if (gpu_info.IsIntel()) { return TensorStorageType::BUFFER; } return TensorStorageType::BUFFER; } TensorStorageType GetStorageTypeWithMinimalMemoryConsumption( - const CLDevice& gpu) { - if (gpu.IsAdreno()) { - if (gpu.IsAdreno3xx() || gpu.IsAdreno4xx()) { + const DeviceInfo& gpu_info) { + if (gpu_info.IsAdreno()) { + if (gpu_info.IsAdreno3xx() || gpu_info.IsAdreno4xx()) { return TensorStorageType::BUFFER; } else { return TensorStorageType::IMAGE_BUFFER; } - } else if (gpu.IsPowerVR()) { + } else if (gpu_info.IsPowerVR()) { return TensorStorageType::BUFFER; - } else if (gpu.IsMali()) { + } else if (gpu_info.IsMali()) { return TensorStorageType::BUFFER; - } else if (gpu.IsNvidia()) { - return gpu.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER - : TensorStorageType::BUFFER; - } else if (gpu.IsAMD()) { - return gpu.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER - : TensorStorageType::BUFFER; - } else if (gpu.IsIntel()) { + } else if (gpu_info.IsNvidia()) { + return gpu_info.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER + : TensorStorageType::BUFFER; + } else if (gpu_info.IsAMD()) { + return gpu_info.SupportsImageBuffer() ? TensorStorageType::IMAGE_BUFFER + : TensorStorageType::BUFFER; + } else if (gpu_info.IsIntel()) { return TensorStorageType::BUFFER; } return TensorStorageType::BUFFER; diff --git a/tensorflow/lite/delegates/gpu/cl/environment.h b/tensorflow/lite/delegates/gpu/cl/environment.h index 640f2d8cac3..1f5b4befdce 100644 --- a/tensorflow/lite/delegates/gpu/cl/environment.h +++ b/tensorflow/lite/delegates/gpu/cl/environment.h @@ -19,9 +19,9 @@ limitations under the License. 
#include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h" #include "tensorflow/lite/delegates/gpu/cl/cl_context.h" #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" +#include "tensorflow/lite/delegates/gpu/cl/device_info.h" #include "tensorflow/lite/delegates/gpu/cl/precision.h" #include "tensorflow/lite/delegates/gpu/cl/program_cache.h" -#include "tensorflow/lite/delegates/gpu/cl/tensor.h" #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -75,9 +75,9 @@ class Environment { ProgramCache program_cache_; }; -TensorStorageType GetFastestStorageType(const CLDevice& gpu); +TensorStorageType GetFastestStorageType(const DeviceInfo& gpu_info); TensorStorageType GetStorageTypeWithMinimalMemoryConsumption( - const CLDevice& gpu); + const DeviceInfo& gpu_info); absl::Status CreateEnvironment(Environment* result); diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc index a6f12ea3508..b834bbfffef 100644 --- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc +++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc @@ -182,6 +182,10 @@ absl::Status InferenceContext::InitFromGraph( RETURN_IF_ERROR(Compile(creation_context)); RETURN_IF_ERROR(UpdateParams()); + for (auto& node : nodes_) { + node.operation->args_.ReleaseCPURepresentation(); + } + TuningParameters tuning_parameters; tuning_parameters.queue = env->profiling_queue(); tuning_parameters.info = &env->device().info_; @@ -274,10 +278,12 @@ absl::Status InferenceContext::ConvertOperations(const DeviceInfo& device_info, if (consumed_nodes.find(node.id) != consumed_nodes.end()) { continue; } + std::string op_name = node.operation.type + " " + std::to_string(node.id); GPUOperationsSubgraph gpu_subgraph; if (hints.Check(ModelHints::kAllowSpecialKernels) && GPUSubgraphFromGraph(device_info, precision_, graph, node.id, - tensor_descriptors, &consumed_nodes, &gpu_subgraph) + tensor_descriptors, &consumed_nodes, &gpu_subgraph, + &op_name) .ok()) { // Mapping of subgraph (set of nodes) to GPU operations. Should happen // before straigtforward mapping. 
@@ -346,7 +352,7 @@ absl::Status InferenceContext::ConvertOperations(const DeviceInfo& device_info, cl_node.outputs[j] = mapping_to_global_ids[-(id + 1)]; } } - cl_node.name = node.operation.type + " " + std::to_string(node.id); + cl_node.name = op_name; nodes_.push_back(std::move(cl_node)); } } diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.h b/tensorflow/lite/delegates/gpu/cl/inference_context.h index d30f9c49174..da687ffa3b5 100644 --- a/tensorflow/lite/delegates/gpu/cl/inference_context.h +++ b/tensorflow/lite/delegates/gpu/cl/inference_context.h @@ -89,6 +89,9 @@ class InferenceContext { absl::Status GetOutputTensor(ValueId id, CLCommandQueue* queue, TensorFloat32* result); + const std::vector& GetInputIds() const { return input_ids_; } + const std::vector& GetOutputIds() const { return output_ids_; } + private: enum TensorMemoryType { STRONG_SHAPE = 0, BUFFER = 1, VARIABLE = 2 }; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD index 5aa92e0cc81..7bce013a895 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD @@ -239,50 +239,6 @@ cc_test( ], ) -cc_library( - name = "conv_texture", - srcs = ["conv_texture.cc"], - hdrs = ["conv_texture.h"], - deps = [ - ":gpu_operation", - ":util", - ":work_group_picking", - "//tensorflow/lite/delegates/gpu/cl:cl_command_queue", - "//tensorflow/lite/delegates/gpu/cl:cl_context", - "//tensorflow/lite/delegates/gpu/cl:linear_storage", - "//tensorflow/lite/delegates/gpu/cl:precision", - "//tensorflow/lite/delegates/gpu/cl:tensor", - "//tensorflow/lite/delegates/gpu/cl:tensor_type", - "//tensorflow/lite/delegates/gpu/cl:texture2d", - "//tensorflow/lite/delegates/gpu/cl:util", - "//tensorflow/lite/delegates/gpu/common:data_type", - "//tensorflow/lite/delegates/gpu/common:operations", - "//tensorflow/lite/delegates/gpu/common:shape", - "//tensorflow/lite/delegates/gpu/common:status", - "//tensorflow/lite/delegates/gpu/common:tensor", - "//tensorflow/lite/delegates/gpu/common:types", - "//tensorflow/lite/delegates/gpu/common:winograd_util", - "@com_google_absl//absl/strings", - ], -) - -cc_test( - name = "conv_texture_test", - srcs = ["conv_texture_test.cc"], - linkstatic = True, - tags = tf_gpu_tests_tags() + [ - "linux", - "local", - ], - deps = [ - ":cl_test", - ":conv_texture", - "//tensorflow/lite/delegates/gpu/common:operations", - "//tensorflow/lite/delegates/gpu/common:status", - "@com_google_googletest//:gtest_main", - ], -) - cc_library( name = "conv_weights_converter", srcs = ["conv_weights_converter.cc"], @@ -1397,7 +1353,6 @@ test_suite( "conv_buffer_1x1_test", "conv_constants_test", "conv_powervr_test", - "conv_texture_test", "convolution_transposed_3x3_thin_test", "convolution_transposed_4x4_test", "convolution_transposed_test", diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc index 0112241117e..efe97f9931b 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/cl_test.cc @@ -55,6 +55,7 @@ absl::Status ExecuteGPUOperation(const std::vector& src_cpu, RETURN_IF_ERROR(operation->Compile(creation_context)); RETURN_IF_ERROR(operation->UpdateParams()); + operation->args_.ReleaseCPURepresentation(); RETURN_IF_ERROR(operation->AddToQueue(creation_context.queue)); RETURN_IF_ERROR(creation_context.queue->WaitForCompletion()); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc 
b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc index 786248e67d7..8952504bda0 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.cc @@ -222,6 +222,9 @@ ConvPowerVR& ConvPowerVR::operator=(ConvPowerVR&& operation) { } void ConvPowerVR::GenerateCode(const DeviceInfo& device_info) { + if (conv_params_.linear_spatial) { + grid_dimension_ = 2; + } const bool stride_correction = definition_.IsBatchSupported() && stride_.x != 1; code_ = @@ -230,40 +233,50 @@ void ConvPowerVR::GenerateCode(const DeviceInfo& device_info) { device_info.IsPowerVR()) { compiler_options_.push_back(CompilerOptions::POWERVR_FP16); } - if (conv_params_.IsPrivateMemBroadcast()) { + if (conv_params_.IsPrivateMemBroadcast() && device_info.IsCL20OrHigher()) { compiler_options_.push_back(CompilerOptions::CL_2_0); } + bool kernel_is_trivial = + conv_params_.x_kernel_is_1 && conv_params_.y_kernel_is_1; + if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) { + kernel_is_trivial = kernel_is_trivial & conv_params_.z_kernel_is_1; + } + if (device_info.IsAdreno3xx() && + definition_.precision == CalculationsPrecision::F16 && + kernel_is_trivial) { + compiler_options_.push_back(CompilerOptions::ADRENO_FULL_SIMD_LINE); + } } -absl::Status ConvPowerVR::BindArguments() { +absl::Status ConvPowerVR::BindArguments(ArgumentsBinder* args) { if (!conv_params_.x_kernel_is_1) { - RETURN_IF_ERROR(args_.SetInt("stride_x", stride_.x)); - RETURN_IF_ERROR(args_.SetInt("padding_x", padding_.x * src_[0]->Batch())); - RETURN_IF_ERROR(args_.SetInt("kernel_size_x", kernel_size_.x)); - RETURN_IF_ERROR(args_.SetInt("dilation_x", dilation_.x * src_[0]->Batch())); + RETURN_IF_ERROR(args->SetInt("stride_x", stride_.x)); + RETURN_IF_ERROR(args->SetInt("padding_x", padding_.x * src_[0]->Batch())); + RETURN_IF_ERROR(args->SetInt("kernel_size_x", kernel_size_.x)); + RETURN_IF_ERROR(args->SetInt("dilation_x", dilation_.x * src_[0]->Batch())); } if (!conv_params_.y_kernel_is_1) { - RETURN_IF_ERROR(args_.SetInt("stride_y", stride_.y)); - RETURN_IF_ERROR(args_.SetInt("padding_y", padding_.y)); - RETURN_IF_ERROR(args_.SetInt("kernel_size_y", kernel_size_.y)); - RETURN_IF_ERROR(args_.SetInt("dilation_y", dilation_.y)); + RETURN_IF_ERROR(args->SetInt("stride_y", stride_.y)); + RETURN_IF_ERROR(args->SetInt("padding_y", padding_.y)); + RETURN_IF_ERROR(args->SetInt("kernel_size_y", kernel_size_.y)); + RETURN_IF_ERROR(args->SetInt("dilation_y", dilation_.y)); } if (definition_.src_tensors[0].HasAxis(Axis::DEPTH) && !conv_params_.z_kernel_is_1) { - RETURN_IF_ERROR(args_.SetInt("stride_z", stride_.z)); - RETURN_IF_ERROR(args_.SetInt("padding_z", padding_.z)); - RETURN_IF_ERROR(args_.SetInt("kernel_size_z", kernel_size_.z)); - RETURN_IF_ERROR(args_.SetInt("dilation_z", dilation_.z)); + RETURN_IF_ERROR(args->SetInt("stride_z", stride_.z)); + RETURN_IF_ERROR(args->SetInt("padding_z", padding_.z)); + RETURN_IF_ERROR(args->SetInt("kernel_size_z", kernel_size_.z)); + RETURN_IF_ERROR(args->SetInt("dilation_z", dilation_.z)); } if (conv_params_.linear_spatial) { const int grid_x = DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(), conv_params_.block_size.x); - RETURN_IF_ERROR(args_.SetInt("task_size_x", grid_x)); + RETURN_IF_ERROR(args->SetInt("task_size_x", grid_x)); } if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) { const int task_size_y = DivideRoundUp(dst_[0]->Height(), conv_params_.block_size.y); - RETURN_IF_ERROR(args_.SetInt("task_size_y", task_size_y)); + 
RETURN_IF_ERROR(args->SetInt("task_size_y", task_size_y)); } return absl::OkStatus(); } @@ -284,23 +297,13 @@ int3 ConvPowerVR::GetGridSize() const { if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) { grid_x *= task_size_z; } - wg.x = DivideRoundUp(grid_x, work_group_size_.x); - wg.y = DivideRoundUp(task_size_s, work_group_size_.y); - return int3( - wg[conv_params_.work_group_launch_order[0]] * work_group_size_.x, - wg[conv_params_.work_group_launch_order[1]] * work_group_size_.y, 1); + return int3(grid_x, task_size_s, 1); } else { int grid_y = task_size_y; if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) { grid_y *= task_size_z; } - wg.x = DivideRoundUp(task_size_x, work_group_size_.x); - wg.y = DivideRoundUp(grid_y, work_group_size_.y); - wg.z = DivideRoundUp(task_size_s, work_group_size_.z); - return int3( - wg[conv_params_.work_group_launch_order[0]] * work_group_size_.x, - wg[conv_params_.work_group_launch_order[1]] * work_group_size_.y, - wg[conv_params_.work_group_launch_order[2]] * work_group_size_.z); + return int3(task_size_x, grid_y, task_size_s); } } @@ -315,14 +318,8 @@ void ConvPowerVR::GetPossibleKernelWorkGroups( work_groups->push_back(work_group_size_); return; } - if (conv_params_.work_group_launch_order[0] == 0 && - conv_params_.work_group_launch_order[1] == 1 && - conv_params_.work_group_launch_order[2] == 2) { - GetPossibleWorkGroupsConv(tuning_type, device_info, kernel_info, grid_size_, - work_groups); - } else { - work_groups->push_back(work_group_size_); - } + GetPossibleWorkGroupsConv(tuning_type, device_info, kernel_info, grid_size_, + work_groups); } std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info, @@ -453,6 +450,8 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info, if (use_simd_broadcast) { if (device_info.cl_version == OpenCLVersion::CL_2_0) { c += "#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n"; + } else if (device_info.SupportsExtension("cl_intel_subgroups")) { + c += "#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n"; } } const int4 block_size = conv_params.block_size; @@ -490,9 +489,9 @@ std::string ConvPowerVR::GenerateConv(const DeviceInfo& device_info, } c += "__kernel void main_function(\n"; c += "$0) {\n"; - c += GenerateBlockCoords( - conv_params.block_size, conv_params.work_group_launch_order, - conv_params.linear_spatial, src_def.HasAxis(Axis::DEPTH)); + c += GenerateBlockCoords(conv_params.block_size, work_group_launch_order_, + conv_params.linear_spatial, + src_def.HasAxis(Axis::DEPTH)); if (!late_oob_check) { c += " if (" + dst_oob_check + ") {\n"; c += " return;\n"; @@ -1028,12 +1027,12 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( if (device_info.IsNvidia()) { if (different_weights_for_height) { work_group_size_ = int3(32, 1, 1); - conv_params.work_group_launch_order = int3(2, 0, 1); + work_group_launch_order_ = int3(2, 0, 1); conv_params.fixed_work_group_size = true; } else { conv_params.linear_spatial = true; work_group_size_ = int3(32, 1, 1); - conv_params.work_group_launch_order = int3(1, 0, 2); + work_group_launch_order_ = int3(1, 0, 2); conv_params.fixed_work_group_size = true; } conv_params.block_size = int4(2, 1, 1, 4); @@ -1073,12 +1072,12 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( } else if (device_info.IsPowerVR()) { if (different_weights_for_height) { work_group_size_ = int3(32, 1, 1); - conv_params.work_group_launch_order = int3(2, 0, 1); + work_group_launch_order_ = int3(2, 0, 1); conv_params.fixed_work_group_size = true; } else { 
conv_params.linear_spatial = true; work_group_size_ = int3(32, 1, 1); - conv_params.work_group_launch_order = int3(1, 0, 2); + work_group_launch_order_ = int3(1, 0, 2); conv_params.fixed_work_group_size = true; } conv_params.weights_data_type = @@ -1121,11 +1120,11 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( } else if (device_info.IsAMD()) { if (different_weights_for_height) { work_group_size_ = int3(32, 1, 1); - conv_params.work_group_launch_order = int3(2, 0, 1); + work_group_launch_order_ = int3(2, 0, 1); conv_params.fixed_work_group_size = true; } else { work_group_size_ = int3(8, 4, 1); - conv_params.work_group_launch_order = int3(2, 0, 1); + work_group_launch_order_ = int3(2, 0, 1); conv_params.fixed_work_group_size = true; } @@ -1184,13 +1183,22 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( conv_params.src_depth_loop_size = 4; } work_group_size_ = int3(4, 4, 1); - conv_params.work_group_launch_order = int3(0, 1, 2); + work_group_launch_order_ = int3(0, 1, 2); conv_params.fixed_work_group_size = false; conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM; } else if (device_info.IsAdreno()) { - conv_params.block_size = int4(2, 2, 1, 1); + conv_params.block_size = int4(2, 2, 1, 2); + if (device_info.IsAdreno3xx()) { + if (definition.precision == CalculationsPrecision::F16) { + conv_params.block_size = int4(2, 2, 1, 2); + } else if (definition.precision == CalculationsPrecision::F32_F16) { + conv_params.block_size = int4(2, 1, 1, 2); + } else { // F32 + conv_params.block_size = int4(2, 2, 1, 1); + } + } work_group_size_ = int3(8, 2, 1); - conv_params.work_group_launch_order = int3(0, 1, 2); + work_group_launch_order_ = int3(0, 1, 2); conv_params.fixed_work_group_size = false; conv_params.src_depth_loop_size = 1; if (definition.src_tensors.size() == 2) { @@ -1202,21 +1210,23 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( } else if (device_info.IsIntel()) { if (different_weights_for_height) { work_group_size_ = int3(16, 1, 1); - conv_params.work_group_launch_order = int3(0, 1, 2); + work_group_launch_order_ = int3(0, 1, 2); conv_params.fixed_work_group_size = true; } else { conv_params.linear_spatial = true; work_group_size_ = int3(16, 1, 1); - conv_params.work_group_launch_order = int3(0, 1, 2); + work_group_launch_order_ = int3(0, 1, 2); conv_params.fixed_work_group_size = true; } conv_params.block_size = int4(1, 1, 1, 4); conv_params.src_depth_loop_size = 1; int sub_group_size = 16; + const bool supports_subgroups = + device_info.SupportsExtension("cl_khr_subgroups") || + device_info.SupportsExtension("cl_intel_subgroups"); if (definition.precision != CalculationsPrecision::F32_F16 && - device_info.SupportsExtension("cl_khr_subgroups") && + supports_subgroups && device_info.SupportsExtension("cl_intel_required_subgroup_size") && - device_info.IsCL20OrHigher() && device_info.SupportsSubGroupWithSize(sub_group_size)) { conv_params.weights_upload_type = WeightsUploadType::PRIVATE_MEM_SIMD_BROADCAST; @@ -1240,7 +1250,7 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParams( } else { conv_params.block_size = int4(1, 1, 1, 4); work_group_size_ = int3(8, 2, 1); - conv_params.work_group_launch_order = int3(0, 1, 2); + work_group_launch_order_ = int3(0, 1, 2); conv_params.fixed_work_group_size = false; conv_params.src_depth_loop_size = 1; conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h index 
a26e5770ce8..30e412cd923 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h @@ -47,7 +47,7 @@ class ConvPowerVR : public GPUOperation { TuningType tuning_type, const DeviceInfo& device_info, const KernelInfo& kernel_info, std::vector* work_groups) const override; - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; ConvWeightsDescription GetConvWeightsDescription() const { @@ -83,7 +83,6 @@ class ConvPowerVR : public GPUOperation { // F32_F16 precision mode DataType weights_data_type; // used for weights and biases int4 block_size; // WHDS - int3 work_group_launch_order; bool fixed_work_group_size; bool linear_spatial; // spatial dimensions are Width/Height/Depth bool different_weights_for_height; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.cc deleted file mode 100644 index bff328772d7..00000000000 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.cc +++ /dev/null @@ -1,461 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h" - -#include -#include -#include - -#include "absl/strings/substitute.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" -#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" -#include "tensorflow/lite/delegates/gpu/cl/precision.h" -#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" - -namespace tflite { -namespace gpu { -namespace cl { -namespace { -bool UseFP16SIMD(const DeviceInfo& device_info, CalculationsPrecision precision, - bool kernel1x1) { - if (!device_info.IsAdreno()) { - return false; - } - switch (precision) { - case CalculationsPrecision::F32: - case CalculationsPrecision::F32_F16: - return false; - case CalculationsPrecision::F16: - return device_info.IsAdreno3xx() && kernel1x1; - } -} -} // namespace - -ConvTexture::ConvTexture(const OperationDef& definition, - const Convolution2DAttributes& attr) - : GPUOperation(definition), - kernel_size_(attr.weights.shape.w, attr.weights.shape.h), - stride_(attr.strides.w, attr.strides.h), - padding_(-attr.padding.prepended.w, -attr.padding.prepended.h), - dilation_(attr.dilations.w, attr.dilations.h), - different_weights_for_height_(false), - block_size_(2, 2, 2) { - work_group_size_ = int3(4, 4, 2); -} - -ConvTexture::ConvTexture(const OperationDef& definition) - : GPUOperation(definition), - kernel_size_(1, 1), - stride_(1, 1), - padding_(0, 0), - dilation_(1, 1), - different_weights_for_height_(false), - block_size_(4, 1, 2) { - work_group_size_ = int3(16, 1, 2); -} - -ConvTexture::ConvTexture(ConvTexture&& operation) - : GPUOperation(std::move(operation)), - 
kernel_size_(operation.kernel_size_), - stride_(operation.stride_), - padding_(operation.padding_), - dilation_(operation.dilation_), - different_weights_for_height_(operation.different_weights_for_height_), - block_size_(operation.block_size_) {} - -ConvTexture& ConvTexture::operator=(ConvTexture&& operation) { - if (this != &operation) { - std::swap(kernel_size_, operation.kernel_size_); - std::swap(stride_, operation.stride_); - std::swap(padding_, operation.padding_); - std::swap(dilation_, operation.dilation_); - std::swap(different_weights_for_height_, - operation.different_weights_for_height_); - std::swap(block_size_, operation.block_size_); - GPUOperation::operator=(std::move(operation)); - } - return *this; -} - -std::string ConvTexture::GenerateConvCode(const OperationDef& op_def, - const int3& block_size, bool is1x1, - bool adreno4xx_optimization, - bool stride_correction, - bool different_weights_for_height) { - auto src_desc = op_def.src_tensors[0]; - src_desc.SetTextureAddressMode(TextureAddressMode::ZERO); - if (op_def.IsBatchSupported()) { - src_desc.SetStateVar("BatchedWidth", "true"); - } - AddSrcTensor("src_tensor", src_desc); - - auto dst_desc = op_def.dst_tensors[0]; - if (op_def.IsBatchSupported()) { - dst_desc.SetStateVar("BatchedWidth", "true"); - } - AddDstTensor("dst_tensor", dst_desc); - - if (!is1x1) { - args_.AddInt("kernel_size_x"); - args_.AddInt("kernel_size_y"); - args_.AddInt("dilation_x"); - args_.AddInt("dilation_y"); - } - args_.AddInt("stride_x"); - args_.AddInt("stride_y"); - args_.AddInt("padding_x"); - args_.AddInt("padding_y"); - - const auto src_tensor_type = op_def.src_tensors[0].storage_type; - const bool is_buffer = src_tensor_type == TensorStorageType::IMAGE_BUFFER || - src_tensor_type == TensorStorageType::BUFFER; - - std::vector xs(block_size.x); - for (int x = 0; x < block_size.x; ++x) { - xs[x] = std::to_string(x); - } - - std::vector ys(block_size.y); - for (int y = 0; y < block_size.y; ++y) { - ys[y] = std::to_string(y); - } - - std::vector zs(block_size.z); - for (int z = 0; z < block_size.z; ++z) { - zs[z] = std::to_string(z); - } - - std::string c = GetCommonDefines(op_def.precision); - for (int z = 0; z < block_size.z; ++z) { - const std::string f0 = std::to_string(z * 4 + 0); - const std::string f1 = std::to_string(z * 4 + 1); - const std::string f2 = std::to_string(z * 4 + 2); - const std::string f3 = std::to_string(z * 4 + 3); - switch (op_def.precision) { - case CalculationsPrecision::F32: - case CalculationsPrecision::F16: - c += "#define CONV" + zs[z] + "(R, S) \\\n"; - c += "R += S.x * f" + f0 + "; \\\n"; - c += "R += S.y * f" + f1 + "; \\\n"; - c += "R += S.z * f" + f2 + "; \\\n"; - c += "R += S.w * f" + f3 + "; \n"; - break; - case CalculationsPrecision::F32_F16: - c += "#define CONV" + zs[z] + "(R, S) \\\n"; - c += "R += convert_float4(S.x * f" + f0 + " + S.y * f" + f1 + - " + S.z * f" + f2 + " + S.w * f" + f3 + ");\n"; - break; - } - } - - c += "__kernel void main_function(\n"; - c += "$0) {\n"; - c += " int X = get_global_id(0) * " + std::to_string(block_size.x) + ";\n"; - c += " int Y = get_global_id(1) * " + std::to_string(block_size.y) + ";\n"; - c += " int Z = get_global_id(2) * " + std::to_string(block_size.z) + ";\n"; - c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() " - "|| Z >= args.dst_tensor.Slices()) return;\n"; - std::vector s_x(block_size.x); - std::vector s_y(block_size.y); - for (int x = 0; x < block_size.x; ++x) { - if (stride_correction) { - c += " int xc" + xs[x] + " = " + - 
GetXStrideCorrected("X + " + xs[x], "args.src_tensor.Batch()", - "args.stride_x", "args.padding_x") + - ";\n"; - } else { - c += " int xc" + xs[x] + " = (X +" + xs[x] + - ") * args.stride_x + args.padding_x;\n"; - } - s_x[x] = is1x1 ? "xc" + xs[x] : "cx" + xs[x]; - } - for (int y = 0; y < block_size.y; ++y) { - c += " int yc" + ys[y] + " = (Y +" + ys[y] + - ") * args.stride_y + args.padding_y;\n"; - s_y[y] = is1x1 ? "yc" + ys[y] : "cy" + ys[y]; - } - for (int i = 0; i < block_size.x * block_size.y * block_size.z; ++i) { - c += " ACCUM_FLT4 r" + std::to_string(i) + - " = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n"; - } - std::string f_y = is1x1 ? "s" : "filter_offset"; - if (different_weights_for_height) { - f_y = "Y * args.src_tensor.Slices() + s"; - } - if (!is1x1) { - for (int x = 0; x < block_size.x; ++x) { - c += " int cx" + xs[x] + ";\n"; - } - for (int y = 0; y < block_size.y; ++y) { - c += " int cy" + ys[y] + ";\n"; - } - c += " int filter_offset = 0;\n"; - c += " for (int y = 0; y < args.kernel_size_y; ++y) {\n"; - for (int y = 0; y < block_size.y; ++y) { - c += " cy" + ys[y] + " = y * args.dilation_y + yc" + ys[y] + ";\n"; - } - if (is_buffer) { - for (int y = 0; y < block_size.y; ++y) { - c += " bool in_y" + ys[y] + " = cy" + ys[y] + " >= 0 && cy" + ys[y] + - " < args.src_tensor.Height();\n"; - if (src_tensor_type == TensorStorageType::BUFFER) { - c += " cy" + ys[y] + " = clamp(cy" + ys[y] + - ", 0, args.src_tensor.Height() - 1);\n"; - } - } - } - c += " for (int x = 0; x < args.kernel_size_x; ++x) {\n"; - for (int x = 0; x < block_size.x; ++x) { - c += " cx" + xs[x] + " = x * args.dilation_x + xc" + xs[x] + ";\n"; - } - if (is_buffer) { - for (int x = 0; x < block_size.x; ++x) { - c += " bool in_x" + xs[x] + " = cx" + xs[x] + " >= 0 && cx" + xs[x] + - " < args.src_tensor.Width();\n"; - if (src_tensor_type == TensorStorageType::BUFFER) { - c += " cx" + xs[x] + " = clamp(cx" + xs[x] + - ", 0, args.src_tensor.Width() - 1);\n"; - } - } - for (int x = 0; x < block_size.x; ++x) { - for (int y = 0; y < block_size.y; ++y) { - const std::string id = std::to_string(y * block_size.x + x); - if (src_tensor_type == TensorStorageType::IMAGE_BUFFER) { - c += absl::Substitute( - " int addr_$0 = select(-1, cy$2 * args.src_tensor.Width() + " - "cx$1, (in_x$1 " - "&& " - "in_y$2));\n", - y * block_size.x + x, x, y); - c += absl::Substitute( - " int dz_$0 = select(0, args.src_tensor.Width() * " - "args.src_tensor.Height(), (in_x$1 && " - "in_y$2));\n", - y * block_size.x + x, x, y); - } else { - c += absl::Substitute( - " int addr_$0 = cy$2 * args.src_tensor.Width() + cx$1;\n", - y * block_size.x + x, x, y); - } - } - } - if (src_tensor_type == TensorStorageType::BUFFER) { - c += " int dz = args.src_tensor.Width() * args.src_tensor.Height();\n"; - } - } - } else if (is_buffer) { - for (int y = 0; y < block_size.y; ++y) { - c += " bool in_y" + ys[y] + " = yc" + ys[y] + " >= 0 && yc" + ys[y] + - " < args.src_tensor.Height();\n"; - } - for (int x = 0; x < block_size.x; ++x) { - c += " bool in_x" + xs[x] + " = xc" + xs[x] + " >= 0 && xc" + xs[x] + - " < args.src_tensor.Width();\n"; - } - for (int x = 0; x < block_size.x; ++x) { - for (int y = 0; y < block_size.y; ++y) { - const std::string id = std::to_string(y * block_size.x + x); - if (src_tensor_type == TensorStorageType::IMAGE_BUFFER) { - c += absl::Substitute( - " int addr_$0 = select(-1, yc$2 * args.src_tensor.Width() + " - "xc$1, (in_x$1 && " - "in_y$2));\n", - y * block_size.x + x, x, y); - c += absl::Substitute( - " int dz_$0 = select(0, 
args.src_tensor.Width() * " - "args.src_tensor.Height(), (in_x$1 && " - "in_y$2));\n", - y * block_size.x + x, x, y); - } else { - c += absl::Substitute( - " int addr_$0 = yc$2 * args.src_tensor.Width() + xc$1;\n", - y * block_size.x + x, x, y); - } - } - } - if (src_tensor_type == TensorStorageType::BUFFER) { - c += " int dz = args.src_tensor.Width() * args.src_tensor.Height();\n"; - } - } - c += " for (int s = 0; s < args.src_tensor.Slices(); ++s) {\n"; - if (is_buffer) { - if (src_tensor_type == TensorStorageType::IMAGE_BUFFER) { - for (int index = 0; index < block_size.x * block_size.y; ++index) { - const std::string id = std::to_string(index); - c += - " FLT4 src" + id + " = args.src_tensor.Read(addr_" + id + ");\n"; - } - } else { - for (int x = 0; x < block_size.x; ++x) { - for (int y = 0; y < block_size.y; ++y) { - const std::string id = std::to_string(y * block_size.x + x); - c += " FLT4 src" + id + " = args.src_tensor.Read(addr_" + id + - ") * (FLT)(in_x" + xs[x] + " && in_y" + ys[y] + "); addr_" + id + - " += dz;\n"; - } - } - } - } - for (int z = 0; z < block_size.z; ++z) { - c += absl::Substitute(R"( FLT4 f$2 = args.weights0.Read($0, $1); - FLT4 f$3 = args.weights1.Read($0, $1); - FLT4 f$4 = args.weights2.Read($0, $1); - FLT4 f$5 = args.weights3.Read($0, $1); -)", - "Z + " + zs[z], f_y, z * 4 + 0, z * 4 + 1, z * 4 + 2, - z * 4 + 3); - } - if (!is_buffer) { - for (int x = 0; x < block_size.x; ++x) { - for (int y = 0; y < block_size.y; ++y) { - const std::string id = std::to_string(y * block_size.x + x); - c += " FLT4 src" + id + " = args.src_tensor.Read(" + s_x[x] + ", " + - s_y[y] + ", s);\n"; - } - } - } - for (int z = 0; z < block_size.z; ++z) { - for (int i = 0; i < block_size.x * block_size.y; ++i) { - c += " CONV" + zs[z] + "(r" + - std::to_string(i + z * block_size.x * block_size.y) + ", src" + - std::to_string(i) + ");\n"; - } - } - if (!is1x1) { - c += " filter_offset++;\n"; - } - if (is_buffer) { - if (src_tensor_type == TensorStorageType::IMAGE_BUFFER) { - for (int index = 0; index < block_size.x * block_size.y; ++index) { - const std::string id = std::to_string(index); - c += " addr_" + id + " += dz_" + id + ";\n"; - } - } - } - c += " }\n"; // args.src_tensor.Slices() - if (!is1x1) { - c += " }\n"; // kernel_size_x - c += " }\n"; // kernel_size_y - } - // when is1x1 && adreno4xx_optimization is true, xc0 == X and yc0 == Y - std::string dst_x = is1x1 && adreno4xx_optimization ? "xc0" : "X"; - std::string dst_y = is1x1 && adreno4xx_optimization ? 
"yc0" : "Y"; - for (int z = 0; z < block_size.z; ++z) { - c += " if (Z < args.dst_tensor.Slices()) {\n"; - c += " FLT4 bias_val = args.biases.Read(Z);\n"; - for (int y = 0; y < block_size.y; ++y) { - for (int x = 0; x < block_size.x; ++x) { - const std::string id = - std::to_string((z * block_size.y + y) * block_size.x + x); - c += " {\n"; - c += " int xc = " + dst_x + " + " + xs[x] + ";\n"; - c += " int yc = " + dst_y + " + " + ys[y] + ";\n"; - c += " if (xc < args.dst_tensor.Width() && yc < " - "args.dst_tensor.Height()) {\n"; - c += " FLT4 res = TO_FLT4(r" + id + ") + bias_val;\n"; - c += " args.dst_tensor.Write(res, xc, yc, Z);\n"; - c += " }\n"; - c += " }\n"; - } - } - c += " }\n"; - c += " Z++;\n"; - } - c += "}\n"; - return c; -} - -void ConvTexture::GenerateCode(const DeviceInfo& device_info) { - auto storage_type = definition_.GetPrimaryStorageType(); - bool is1x1 = kernel_size_.x == 1 && kernel_size_.y == 1; - bool adreno4xx_optimization = - stride_.x == 1 && stride_.y == 1 && padding_.x == 0 && padding_.y == 0 && - device_info.IsAdreno4xx() && - storage_type == TensorStorageType::TEXTURE_ARRAY && - definition_.precision == CalculationsPrecision::F16; - const bool stride_correction = - definition_.IsBatchSupported() && stride_.x != 1; - code_ = - GenerateConvCode(definition_, block_size_, is1x1, adreno4xx_optimization, - stride_correction, different_weights_for_height_); - - if (UseFP16SIMD(device_info, definition_.precision, is1x1)) { - compiler_options_.push_back(CompilerOptions::ADRENO_FULL_SIMD_LINE); - } -} - -absl::Status ConvTexture::BindArguments() { - if (!(kernel_size_.x == 1 && kernel_size_.y == 1)) { - RETURN_IF_ERROR(args_.SetInt("kernel_size_x", kernel_size_.x)); - RETURN_IF_ERROR(args_.SetInt("kernel_size_y", kernel_size_.y)); - RETURN_IF_ERROR(args_.SetInt("dilation_x", dilation_.x * src_[0]->Batch())); - RETURN_IF_ERROR(args_.SetInt("dilation_y", dilation_.y)); - } - RETURN_IF_ERROR(args_.SetInt("stride_x", stride_.x)); - RETURN_IF_ERROR(args_.SetInt("stride_y", stride_.y)); - RETURN_IF_ERROR(args_.SetInt("padding_x", padding_.x * src_[0]->Batch())); - RETURN_IF_ERROR(args_.SetInt("padding_y", padding_.y)); - return absl::OkStatus(); -} - -int3 ConvTexture::GetGridSize() const { - const int grid_x = - DivideRoundUp(dst_[0]->Width() * dst_[0]->Batch(), block_size_.x); - const int grid_y = DivideRoundUp(dst_[0]->Height(), block_size_.y); - const int grid_z = DivideRoundUp(dst_[0]->Slices(), block_size_.z); - return int3(grid_x, grid_y, grid_z); -} - -void ConvTexture::GetPossibleKernelWorkGroups( - TuningType tuning_type, const DeviceInfo& device_info, - const KernelInfo& kernel_info, std::vector* work_groups) const { - GetPossibleWorkGroupsConv(tuning_type, device_info, kernel_info, grid_size_, - work_groups); -} - -ConvTexture CreateConvTexture(const DeviceInfo& device_info, - const OperationDef& definition, - const Convolution2DAttributes& attr) { - ConvTexture result(definition, attr); - result.GenerateCode(device_info); - result.UploadData(attr.weights, attr.bias); - return result; -} - -ConvTexture CreateConvTexture(const DeviceInfo& device_info, - const OperationDef& definition, - const FullyConnectedAttributes& attr) { - ConvTexture result(definition); - result.GenerateCode(device_info); - result.UploadData(attr.weights, attr.bias); - return result; -} - -ConvTexture CreateConvTextureWino4x4To6x6(const DeviceInfo& device_info, - const OperationDef& definition, - const Convolution2DAttributes& attr) { - ConvTexture result(definition); - 
result.different_weights_for_height_ = true; - result.block_size_ = {4, 1, 2}; - result.GenerateCode(device_info); - result.UploadDataForWinograd4x4To6x6(attr.weights); - return result; -} - -} // namespace cl -} // namespace gpu -} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h deleted file mode 100644 index b1889265930..00000000000 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h +++ /dev/null @@ -1,193 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_TEXTURE_H_ -#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_TEXTURE_H_ - -#include - -#include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h" -#include "tensorflow/lite/delegates/gpu/cl/cl_context.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" -#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" -#include "tensorflow/lite/delegates/gpu/cl/tensor.h" -#include "tensorflow/lite/delegates/gpu/cl/texture2d.h" -#include "tensorflow/lite/delegates/gpu/cl/util.h" -#include "tensorflow/lite/delegates/gpu/common/data_type.h" -#include "tensorflow/lite/delegates/gpu/common/operations.h" -#include "tensorflow/lite/delegates/gpu/common/shape.h" -#include "tensorflow/lite/delegates/gpu/common/status.h" -#include "tensorflow/lite/delegates/gpu/common/tensor.h" -#include "tensorflow/lite/delegates/gpu/common/types.h" -#include "tensorflow/lite/delegates/gpu/common/winograd_util.h" - -namespace tflite { -namespace gpu { -namespace cl { - -// This convolution process BLOCK_SIZE(XxYxZ) of FLT4 values per thread. 
-class ConvTexture : public GPUOperation { - public: - ConvTexture() = default; - void GetPossibleKernelWorkGroups( - TuningType tuning_type, const DeviceInfo& device_info, - const KernelInfo& kernel_info, - std::vector* work_groups) const override; - absl::Status BindArguments() override; - int3 GetGridSize() const override; - - // Move only - ConvTexture(ConvTexture&& operation); - ConvTexture& operator=(ConvTexture&& operation); - ConvTexture(const ConvTexture&) = delete; - ConvTexture& operator=(const ConvTexture&) = delete; - - private: - friend ConvTexture CreateConvTexture(const DeviceInfo& device_info, - const OperationDef& definition, - const Convolution2DAttributes& attr); - friend ConvTexture CreateConvTexture(const DeviceInfo& device_info, - const OperationDef& definition, - const FullyConnectedAttributes& attr); - - friend ConvTexture CreateConvTextureWino4x4To6x6( - const DeviceInfo& device_info, const OperationDef& definition, - const Convolution2DAttributes& attr); - - ConvTexture(const OperationDef& definition, - const Convolution2DAttributes& attr); - explicit ConvTexture(const OperationDef& definition); - template - void UploadData(const tflite::gpu::Tensor& weights, - const tflite::gpu::Tensor& biases); - - template - void UploadDataForWinograd4x4To6x6( - const tflite::gpu::Tensor& weights); - - template - void UploadWeights(const tflite::gpu::Tensor& weights); - - void GenerateCode(const DeviceInfo& device_info); - - std::string GenerateConvCode(const OperationDef& op_def, - const int3& block_size, bool is1x1, - bool adreno4xx_optimization, - bool stride_correction, - bool different_weights_for_height); - - int2 kernel_size_; - int2 stride_; - int2 padding_; - int2 dilation_; - - // By default in 2d convolution we have the same weights for WH dims, but in - // some cases we need separate weights for H dimension and convolution kernel - // requires very small modifications to support it. - bool different_weights_for_height_; - - int3 block_size_ = int3(2, 2, 2); -}; - -template -void ConvTexture::UploadData(const tflite::gpu::Tensor& weights, - const tflite::gpu::Tensor& biases) { - UploadWeights(weights); - - TensorLinearDescriptor desc; - desc.storage_type = LinearStorageType::TEXTURE_2D; - desc.element_type = definition_.GetDataType(); - desc.UploadLinearData(biases); - args_.AddObject("biases", - absl::make_unique(std::move(desc))); -} - -template -void ConvTexture::UploadDataForWinograd4x4To6x6( - const tflite::gpu::Tensor& weights) { - tflite::gpu::Tensor wino_weights; - RearrangeWeightsToWinograd4x4To6x6Weights(weights, &wino_weights); - UploadWeights(wino_weights); - - tflite::gpu::Tensor bias; - bias.shape = Linear(1); - bias.data = {0.0f}; - TensorLinearDescriptor desc; - desc.storage_type = LinearStorageType::TEXTURE_2D; - desc.element_type = definition_.GetDataType(); - desc.UploadLinearData(bias); - args_.AddObject("biases", - absl::make_unique(std::move(desc))); -} - -template -void ConvTexture::UploadWeights(const tflite::gpu::Tensor& weights) { - int dst_depth = DivideRoundUp(weights.shape.o, 4); - dst_depth = AlignByN(dst_depth, block_size_.z); - const int src_depth = DivideRoundUp(weights.shape.i, 4); - const int kernel_x = weights.shape.w; - const int kernel_y = weights.shape.h; - - const bool f32_weights = definition_.precision == CalculationsPrecision::F32; - DataType data_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; - - const int elements_count = dst_depth * src_depth * kernel_x * kernel_y * 4; - const int float4_size = f32_weights ? 
sizeof(float4) : sizeof(half4); - - std::vector data(float4_size * elements_count); - - if (f32_weights) { - float4* ptr = reinterpret_cast(data.data()); - RearrangeWeightsToI4HWIOOGroupO4(weights, block_size_.z, - absl::MakeSpan(ptr, elements_count)); - } else { - half4* ptr = reinterpret_cast(data.data()); - RearrangeWeightsToI4HWIOOGroupO4(weights, block_size_.z, - absl::MakeSpan(ptr, elements_count)); - } - - const int texture_width = dst_depth; - const int texture_height = src_depth * kernel_x * kernel_y; - const int sub_size = float4_size * texture_width * texture_height; - for (int i = 0; i < 4; ++i) { - Texture2DDescriptor desc; - desc.element_type = data_type; - desc.size = int2(texture_width, texture_height); - desc.data.resize(sub_size); - memcpy(desc.data.data(), data.data() + sub_size * i, sub_size); - const std::string name = "weights" + std::to_string(i); - args_.AddObject(name, - absl::make_unique(std::move(desc))); - } -} - -ConvTexture CreateConvTexture(const DeviceInfo& device_info, - const OperationDef& definition, - const Convolution2DAttributes& attr); - -ConvTexture CreateConvTexture(const DeviceInfo& device_info, - const OperationDef& definition, - const FullyConnectedAttributes& attr); - -ConvTexture CreateConvTextureWino4x4To6x6(const DeviceInfo& device_info, - const OperationDef& definition, - const Convolution2DAttributes& attr); - -} // namespace cl -} // namespace gpu -} // namespace tflite - -#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_CONV_TEXTURE_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture_test.cc deleted file mode 100644 index 2a92573b689..00000000000 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_texture_test.cc +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h" - -#include - -#include -#include -#include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h" -#include "tensorflow/lite/delegates/gpu/common/operations.h" -#include "tensorflow/lite/delegates/gpu/common/status.h" - -using ::testing::FloatNear; -using ::testing::Pointwise; - -namespace tflite { -namespace gpu { -namespace cl { -namespace { - -TEST_F(OpenCLOperationTest, ConvTextureSimpleWeights) { - TensorFloat32 src_tensor; - src_tensor.shape = BHWC(1, 2, 2, 2); - src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; - - Convolution2DAttributes attr; - attr.padding.prepended = HW(0, 0); - attr.padding.appended = HW(1, 1); - attr.strides = HW(1, 1); - attr.dilations = HW(1, 1); - attr.weights.shape = OHWI(1, 2, 2, 2); - attr.weights.data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; - attr.bias.shape = Linear(1); - attr.bias.data = {0.0f}; - - for (auto storage : env_.GetSupportedStorages()) { - for (auto precision : env_.GetSupportedPrecisions()) { - const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f; - OperationDef op_def; - op_def.precision = precision; - auto data_type = DeduceDataTypeFromPrecision(precision); - op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); - op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); - TensorFloat32 dst_tensor; - ConvTexture operation = - CreateConvTexture(creation_context_.GetDeviceInfo(), op_def, attr); - ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, - BHWC(1, 2, 2, 1), &dst_tensor)); - EXPECT_THAT(dst_tensor.data, - Pointwise(FloatNear(eps), {28.0f, 18.0f, 22.0f, 13.0f})); - } - } -} - -TEST_F(OpenCLOperationTest, ConvTexture) { - TensorFloat32 src_tensor; - src_tensor.shape = BHWC(1, 2, 2, 2); - src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}; - - Convolution2DAttributes attr; - attr.padding.prepended = HW(0, 0); - attr.padding.appended = HW(1, 1); - attr.strides = HW(1, 1); - attr.dilations = HW(1, 1); - attr.weights.shape = OHWI(2, 2, 2, 2); - attr.weights.data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, - 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f}; - attr.bias.shape = Linear(2); - attr.bias.data = {0.5f, -0.5f}; - - for (auto storage : env_.GetSupportedStorages()) { - for (auto precision : env_.GetSupportedPrecisions()) { - const float eps = precision == CalculationsPrecision::F32 ? 
1e-6f : 1e-3f; - OperationDef op_def; - op_def.precision = precision; - auto data_type = DeduceDataTypeFromPrecision(precision); - op_def.src_tensors.push_back({data_type, storage, Layout::HWC}); - op_def.dst_tensors.push_back({data_type, storage, Layout::HWC}); - TensorFloat32 dst_tensor; - ConvTexture operation = - CreateConvTexture(creation_context_.GetDeviceInfo(), op_def, attr); - ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation, - BHWC(1, 2, 2, 2), &dst_tensor)); - EXPECT_THAT(dst_tensor.data, - Pointwise(FloatNear(eps), {168.5f, 391.5f, 80.5f, 223.5f, - 60.5f, 235.5f, 20.5f, 123.5f})); - } - } -} - -} // namespace -} // namespace cl -} // namespace gpu -} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.cc b/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.cc index d6e17ce2a86..521cbefd885 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.cc @@ -110,12 +110,12 @@ std::string ConverterToConvWeights::GetConverterToConvWeightsCode( return c; } -absl::Status ConverterToConvWeights::BindArguments() { +absl::Status ConverterToConvWeights::BindArguments(ArgumentsBinder* args) { float4 mask = GetMaskForLastPlane(src_[0]->Channels()); - RETURN_IF_ERROR(args_.SetFloat("mask_x", mask.x)); - RETURN_IF_ERROR(args_.SetFloat("mask_y", mask.y)); - RETURN_IF_ERROR(args_.SetFloat("mask_z", mask.z)); - return args_.SetFloat("mask_w", mask.w); + RETURN_IF_ERROR(args->SetFloat("mask_x", mask.x)); + RETURN_IF_ERROR(args->SetFloat("mask_y", mask.y)); + RETURN_IF_ERROR(args->SetFloat("mask_z", mask.z)); + return args->SetFloat("mask_w", mask.w); } int3 ConverterToConvWeights::GetGridSize() const { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h b/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h index fe814d296fa..3c7314ea6c9 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h @@ -31,7 +31,7 @@ class ConverterToConvWeights : public GPUOperation { public: ConverterToConvWeights(const OperationDef& definition, const ConvWeightsDescription& conv_weights_desc); - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; // Move only diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc index d52efb43a08..77ac946637d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc @@ -46,9 +46,11 @@ class OpenClConverterImpl : public TensorObjectConverter { RETURN_IF_ERROR(kernel_.SetMemoryAuto(buffer_mem)); RETURN_IF_ERROR(args_.SetObjectRef("tensor", tensor)); RETURN_IF_ERROR(args_.Bind(kernel_.kernel(), kernel_.GetBindingCounter())); - int3 grid = int3(tensor->Width() * tensor->Batch(), tensor->Height(), - tensor->Slices()); - return queue_->DispatchImplicit(kernel_, grid, {16, 8, 1}); + const int3 grid = int3(tensor->Width() * tensor->Batch(), tensor->Height(), + tensor->Slices()); + const int3 work_group_size = {16, 8, 1}; + const int3 work_groups_count = GetWorkGroupsCount(grid, work_group_size); + return queue_->Dispatch(kernel_, work_groups_count, work_group_size); } Arguments args_; @@ -63,47 +65,168 @@ bool IsSupportedDataType(DataType type) { return type == 
DataType::FLOAT16 || type == DataType::FLOAT32; } -// Implements conversion from OpenCL-specific tensor layout to BHWC. -class FromTensorConverter : public OpenClConverterImpl { +bool IsBHWCOpenCLBuffer(const ObjectDef& def) { + return IsSupportedDataType(def.data_type) && + def.object_type == ObjectType::OPENCL_BUFFER && + def.data_layout == DataLayout::BHWC; +} + +bool IsOpenCLTensor(const ObjectDef& def) { + const bool is_buffer_tensor = def.object_type == ObjectType::OPENCL_BUFFER && + def.data_layout == DataLayout::DHWC4; + const bool is_image2d_tensor = + def.object_type == ObjectType::OPENCL_TEXTURE && + def.data_layout == DataLayout::HDWC4; + const bool is_image2d_array_tensor = + def.object_type == ObjectType::OPENCL_TEXTURE && + def.data_layout == DataLayout::DHWC4; + const bool is_single_image_tensor = + def.object_type == ObjectType::OPENCL_TEXTURE && + def.data_layout == DataLayout::BHWC; + return IsSupportedDataType(def.data_type) && + (is_buffer_tensor || is_image2d_tensor || is_image2d_array_tensor || + is_single_image_tensor); +} + +absl::Status GetOpenCLMemory(const TensorObject& obj, cl_mem* memory) { + auto texture = absl::get_if(&obj); + auto buffer = absl::get_if(&obj); + if (texture && texture->memobj) { + *memory = texture->memobj; + } else if (buffer && buffer->memobj) { + *memory = buffer->memobj; + } else { + return absl::InvalidArgumentError("Missing OpenCL object."); + } + return absl::OkStatus(); +} + +// Implements conversion from OpenCL tensor to another OpenCL tensor. +class TensorToTensorConverter : public OpenClConverterImpl { public: static bool IsSupported(const ObjectDef& input, const ObjectDef& output) { - return IsSupportedDataType(input.data_type) && - IsSupportedDataType(output.data_type) && - // Output is always Buffer/(BHWC|DHWC4) - output.object_type == ObjectType::OPENCL_BUFFER && - (output.data_layout == DataLayout::BHWC || - output.data_layout == DataLayout::DHWC4) && - // Texture2D/HDWC4 -> - ((input.object_type == ObjectType::OPENCL_TEXTURE && - input.data_layout == DataLayout::HDWC4) || - // SingleTextureArray/BHWC -> - (input.object_type == ObjectType::OPENCL_TEXTURE && - input.data_layout == DataLayout::BHWC) || - // TextureArray/DHWC4 -> - (input.object_type == ObjectType::OPENCL_TEXTURE && - input.data_layout == DataLayout::DHWC4) || - // Buffer/DHWC4 -> - (input.object_type == ObjectType::OPENCL_BUFFER && - input.data_layout == DataLayout::DHWC4)); + return IsOpenCLTensor(input) && IsOpenCLTensor(output); } - std::pair GetToDhwc4Kernel( - const TensorObjectDef& input_def, - const TensorObjectDef& output_def) const { - return std::make_pair("__global " + - ToCLDataType(output_def.object_def.data_type, 4) + - "* dst", - "dst[(d * args.tensor.Height() + y) * " - "args.tensor.Width() + x] = input;"); + absl::Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def, + Environment* environment) final { + src_tensor_descriptor_.layout = Layout::BHWC; + src_tensor_descriptor_.storage_type = ToTensorStorageType( + input_def.object_def.object_type, input_def.object_def.data_layout); + src_tensor_descriptor_.data_type = input_def.object_def.data_type; + args_.AddObjectRef( + "src_tensor", AccessType::READ, + absl::make_unique(src_tensor_descriptor_)); + + dst_tensor_descriptor_.layout = Layout::BHWC; + dst_tensor_descriptor_.storage_type = ToTensorStorageType( + output_def.object_def.object_type, output_def.object_def.data_layout); + dst_tensor_descriptor_.data_type = output_def.object_def.data_type; + 
args_.AddObjectRef( + "dst_tensor", AccessType::WRITE, + absl::make_unique(dst_tensor_descriptor_)); + + const bool need_fp16_support = + input_def.object_def.data_type == DataType::FLOAT16 || + output_def.object_def.data_type == DataType::FLOAT16; + const std::string out_data_type = + ToCLDataType(output_def.object_def.data_type); + std::string shader_src; + if (need_fp16_support) { + shader_src += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"; + } + shader_src += + R"(__kernel void tensor_to_tensor($0) { + int linear_id = get_global_id(0); + int x = linear_id / args.dst_tensor.Batch(); + int b = linear_id % args.dst_tensor.Batch(); + int y = get_global_id(1); + int d = get_global_id(2); + if (x >= args.dst_tensor.Width() || y >= args.dst_tensor.Height() || d >= args.dst_tensor.Slices()) return; +)"; + shader_src += " " + out_data_type + "4 input = args.src_tensor.Read<" + + out_data_type + ">(x, y, d, b);\n"; + shader_src += " args.dst_tensor.Write(input, x, y, d, b);\n}"; + queue_ = environment->queue(); + context_ = &environment->context(); + shape_ = BHWC(input_def.dimensions.b, input_def.dimensions.h, + input_def.dimensions.w, input_def.dimensions.c); + RETURN_IF_ERROR( + args_.TransformToCLCode(environment->device().info_, {}, &shader_src)); + return environment->program_cache()->GetOrCreateCLKernel( + shader_src, "tensor_to_tensor", environment->context(), + environment->device(), &kernel_); } - std::pair GetToBhwcKernel( - const TensorObjectDef& input_def, - const TensorObjectDef& output_def) const { - return std::make_pair( - "__global " + ToCLDataType(output_def.object_def.data_type) + "* dst", - R"( - int c = d * 4; + absl::Status Convert(const TensorObject& input_obj, + const TensorObject& output_obj) override { + cl_mem in_memory; + RETURN_IF_ERROR(GetOpenCLMemory(input_obj, &in_memory)); + cl_mem out_memory; + RETURN_IF_ERROR(GetOpenCLMemory(output_obj, &out_memory)); + + Tensor src_tensor; + RETURN_IF_ERROR(CreateSharedTensor(*context_, in_memory, shape_, + src_tensor_descriptor_, &src_tensor)); + Tensor dst_tensor; + RETURN_IF_ERROR(CreateSharedTensor(*context_, out_memory, shape_, + dst_tensor_descriptor_, &dst_tensor)); + RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", &src_tensor)); + RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", &dst_tensor)); + RETURN_IF_ERROR(args_.Bind(kernel_.kernel())); + const int3 grid = int3(dst_tensor.Width() * dst_tensor.Batch(), + dst_tensor.Height(), dst_tensor.Slices()); + const int3 work_group_size = {16, 8, 1}; + const int3 work_groups_count = GetWorkGroupsCount(grid, work_group_size); + return queue_->Dispatch(kernel_, work_groups_count, work_group_size); + } + + private: + TensorDescriptor src_tensor_descriptor_; + TensorDescriptor dst_tensor_descriptor_; +}; + +// Implements conversion from OpenCL-specific tensor layout to BHWC OpenCL +// buffer. 
+class TensorToBHWCBufferConverter : public OpenClConverterImpl { + public: + static bool IsSupported(const ObjectDef& input, const ObjectDef& output) { + return IsOpenCLTensor(input) && IsBHWCOpenCLBuffer(output); + } + + absl::Status Init(const TensorObjectDef& input_def, + const TensorObjectDef& output_def, + Environment* environment) final { + TensorStorageType src_tensor_type = ToTensorStorageType( + input_def.object_def.object_type, input_def.object_def.data_layout); + tensor_descriptor_.layout = Layout::BHWC; + tensor_descriptor_.storage_type = src_tensor_type; + tensor_descriptor_.data_type = input_def.object_def.data_type; + args_.AddObjectRef("tensor", AccessType::READ, + absl::make_unique(tensor_descriptor_)); + + const bool need_fp16_support = + input_def.object_def.data_type == DataType::FLOAT16 || + output_def.object_def.data_type == DataType::FLOAT16; + std::string shader_src; + if (need_fp16_support) { + shader_src += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"; + } + const std::string out_data_type = + ToCLDataType(output_def.object_def.data_type); + shader_src += "__kernel void tensor_to_bhwc("; + shader_src += "__global " + out_data_type + "* dst, $0) {\n"; + shader_src += R"( int linear_id = get_global_id(0); + int x = linear_id / args.tensor.Batch(); + int b = linear_id % args.tensor.Batch(); + int y = get_global_id(1); + int d = get_global_id(2); + if (x >= args.tensor.Width() || y >= args.tensor.Height() || d >= args.tensor.Slices()) return; +)"; + shader_src += " " + out_data_type + "4 input = args.tensor.Read<" + + out_data_type + ">(x, y, d, b);\n"; + shader_src += R"( int c = d * 4; int index = ((b * args.tensor.Height() + y) * args.tensor.Width() + x) * args.tensor.Channels() + c; dst[index] = input.x; @@ -115,39 +238,8 @@ class FromTensorConverter : public OpenClConverterImpl { } if (c + 3 < args.tensor.Channels()) { dst[index + 3] = input.w; - })"); } - - absl::Status Init(const TensorObjectDef& input_def, - const TensorObjectDef& output_def, - Environment* environment) final { - auto params_kernel = output_def.object_def.data_layout == DataLayout::BHWC - ? 
GetToBhwcKernel(input_def, output_def) - : GetToDhwc4Kernel(input_def, output_def); - - TensorStorageType src_tensor_type = ToTensorStorageType( - input_def.object_def.object_type, input_def.object_def.data_layout); - tensor_descriptor_.layout = Layout::BHWC; - tensor_descriptor_.storage_type = src_tensor_type; - tensor_descriptor_.data_type = input_def.object_def.data_type; - args_.AddObjectRef("tensor", AccessType::READ, - absl::make_unique(tensor_descriptor_)); - std::string shader_src = - R"( -#pragma OPENCL EXTENSION cl_khr_fp16 : enable - -__kernel void from_tensor()" + - params_kernel.first + R"(, $0) { - int linear_id = get_global_id(0); - int x = linear_id / args.tensor.Batch(); - int b = linear_id % args.tensor.Batch(); - int y = get_global_id(1); - int d = get_global_id(2); - if (x >= args.tensor.Width() || y >= args.tensor.Height() || d >= args.tensor.Slices()) return; - )" + ToCLDataType(output_def.object_def.data_type, 4) + - " input = args.tensor.Read<" + - ToCLDataType(output_def.object_def.data_type) + ">(x, y, d, b);\n" + - params_kernel.second + "\n}"; +})"; queue_ = environment->queue(); context_ = &environment->context(); shape_ = BHWC(input_def.dimensions.b, input_def.dimensions.h, @@ -155,7 +247,7 @@ __kernel void from_tensor()" + RETURN_IF_ERROR( args_.TransformToCLCode(environment->device().info_, {}, &shader_src)); return environment->program_cache()->GetOrCreateCLKernel( - shader_src, "from_tensor", environment->context(), + shader_src, "tensor_to_bhwc", environment->context(), environment->device(), &kernel_); } @@ -164,64 +256,24 @@ __kernel void from_tensor()" + auto output = absl::get_if(&output_obj); if (!output || !output->memobj) { return absl::InvalidArgumentError( - "Missing output in from_tensor converter"); - } - cl_mem memory = nullptr; - auto input_texture = absl::get_if(&input_obj); - if (input_texture && input_texture->memobj) { - memory = input_texture->memobj; - } - auto input_buffer = absl::get_if(&input_obj); - if (input_buffer && input_buffer->memobj) { - memory = input_buffer->memobj; - } - if (!memory) { - return absl::InvalidArgumentError( - "Missing input in from_tensor converter"); + "Missing output in tensor_to_bhwc converter"); } + + cl_mem in_memory; + RETURN_IF_ERROR(GetOpenCLMemory(input_obj, &in_memory)); Tensor tensor; - RETURN_IF_ERROR(CreateSharedTensor(*context_, memory, shape_, + RETURN_IF_ERROR(CreateSharedTensor(*context_, in_memory, shape_, tensor_descriptor_, &tensor)); return DispatchKernel(output->memobj, &tensor); } }; -// Implements conversion from BHWC to OpenCL-specific tensor layout. -class ToTensorConverter : public OpenClConverterImpl { +// Implements conversion from BHWC OpenCL buffer to OpenCL-specific tensor +// layout. 
+class BHWCBufferToTensorConverter : public OpenClConverterImpl { public: static bool IsSupported(const ObjectDef& input, const ObjectDef& output) { - return IsSupportedDataType(input.data_type) && - IsSupportedDataType(output.data_type) && - // Input is always Buffer/BHWC - input.object_type == ObjectType::OPENCL_BUFFER && - (input.data_layout == DataLayout::BHWC || - input.data_layout == DataLayout::DHWC4) && - // -> Texture2D/HDWC4 - ((output.object_type == ObjectType::OPENCL_TEXTURE && - output.data_layout == DataLayout::HDWC4) || - // -> TextureArray/DHWC4 - (output.object_type == ObjectType::OPENCL_TEXTURE && - output.data_layout == DataLayout::DHWC4) || - // -> SingleTextureArray/BHWC - (output.object_type == ObjectType::OPENCL_TEXTURE && - output.data_layout == DataLayout::BHWC) || - // -> Buffer/DHWC4 - (output.object_type == ObjectType::OPENCL_BUFFER && - output.data_layout == DataLayout::DHWC4)); - } - - std::pair GetFromDhwc4Kernel( - const TensorObjectDef& input_def, - const TensorObjectDef& output_def) const { - return std::make_pair( - "__global " + ToCLDataType(input_def.object_def.data_type, 4) + "* src", - output_def.object_def.data_type == input_def.object_def.data_type - ? "result = src[(d * args.tensor.Height() + y) * " - "args.tensor.Width() + x];" - : "result = convert_" + - ToCLDataType(output_def.object_def.data_type, 4) + - "(src[(d * args.tensor.Height() + y) * args.tensor.Width() + " - "x]);"); + return IsBHWCOpenCLBuffer(input) && IsOpenCLTensor(output); } std::pair GetFromBhwcKernel( @@ -241,9 +293,8 @@ class ToTensorConverter : public OpenClConverterImpl { absl::Status Init(const TensorObjectDef& input_def, const TensorObjectDef& output_def, Environment* environment) final { - auto params_kernel = input_def.object_def.data_layout == DataLayout::BHWC - ? 
GetFromBhwcKernel(input_def, output_def) - : GetFromDhwc4Kernel(input_def, output_def); + auto params_kernel = GetFromBhwcKernel(input_def, output_def); + TensorStorageType dst_tensor_type = ToTensorStorageType( output_def.object_def.object_type, output_def.object_def.data_layout); tensor_descriptor_.layout = Layout::BHWC; @@ -251,23 +302,38 @@ class ToTensorConverter : public OpenClConverterImpl { tensor_descriptor_.data_type = output_def.object_def.data_type; args_.AddObjectRef("tensor", AccessType::WRITE, absl::make_unique(tensor_descriptor_)); - std::string shader_src = - R"( -#pragma OPENCL EXTENSION cl_khr_3d_image_writes : enable -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -__kernel void to_tensor()" + - params_kernel.first + R"(, $0) { - int linear_id = get_global_id(0); + const bool need_fp16_support = + input_def.object_def.data_type == DataType::FLOAT16 || + output_def.object_def.data_type == DataType::FLOAT16; + std::string shader_src; + if (need_fp16_support) { + shader_src += "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"; + } + const std::string in_data_type = + ToCLDataType(input_def.object_def.data_type); + const std::string out_data_type = + ToCLDataType(output_def.object_def.data_type); + shader_src += "__kernel void bhwc_to_tensor("; + shader_src += "__global " + in_data_type + "* src, $0) {\n"; + + shader_src += R"( int linear_id = get_global_id(0); int x = linear_id / args.tensor.Batch(); int b = linear_id % args.tensor.Batch(); int y = get_global_id(1); int d = get_global_id(2); if (x >= args.tensor.Width() || y >= args.tensor.Height() || d >= args.tensor.Slices()) return; - )" + ToCLDataType(output_def.object_def.data_type, 4) + - " result;\n" + params_kernel.second + "\n " + - "args.tensor.Write(result, x, y, d, b);\n}"; +)"; + shader_src += " " + out_data_type + "4 result;\n"; + shader_src += R"( int c = d * 4; + int index = ((b * args.tensor.Height() + y) * args.tensor.Width() + x) * args.tensor.Channels() + c; + result.x = src[index]; + result.y = c + 1 < args.tensor.Channels() ? src[index + 1] : 1; + result.z = c + 2 < args.tensor.Channels() ? src[index + 2] : 2; + result.w = c + 3 < args.tensor.Channels() ? 
src[index + 3] : 3; +)"; + shader_src += " args.tensor.Write(result, x, y, d, b);\n}"; queue_ = environment->queue(); context_ = &environment->context(); shape_ = BHWC(output_def.dimensions.b, output_def.dimensions.h, @@ -275,31 +341,21 @@ __kernel void to_tensor()" + RETURN_IF_ERROR( args_.TransformToCLCode(environment->device().info_, {}, &shader_src)); return environment->program_cache()->GetOrCreateCLKernel( - shader_src, "to_tensor", environment->context(), environment->device(), - &kernel_); + shader_src, "bhwc_to_tensor", environment->context(), + environment->device(), &kernel_); } absl::Status Convert(const TensorObject& input_obj, const TensorObject& output_obj) override { auto input = absl::get_if(&input_obj); if (!input || !input->memobj) { - return absl::InvalidArgumentError("Missing input in to_tensor converter"); - } - cl_mem memory = nullptr; - auto output_texture = absl::get_if(&output_obj); - if (output_texture && output_texture->memobj) { - memory = output_texture->memobj; - } - auto output_buffer = absl::get_if(&output_obj); - if (output_buffer && output_buffer->memobj) { - memory = output_buffer->memobj; - } - if (!memory) { return absl::InvalidArgumentError( - "Missing output in to_tensor converter"); + "Missing input in bhwc_to_tensor converter"); } + cl_mem out_memory; + RETURN_IF_ERROR(GetOpenCLMemory(output_obj, &out_memory)); Tensor tensor; - RETURN_IF_ERROR(CreateSharedTensor(*context_, memory, shape_, + RETURN_IF_ERROR(CreateSharedTensor(*context_, out_memory, shape_, tensor_descriptor_, &tensor)); return DispatchKernel(input->memobj, &tensor); } @@ -465,9 +521,10 @@ class OpenClTensorConverterBuilder : public TensorObjectConverterBuilder { const auto& output_def = output.object_def; return input.dimensions == output.dimensions && (TrivialCopier::IsSupported(input_def, output_def) || + TensorToTensorConverter::IsSupported(input_def, output_def) || CpuCopier::IsSupported(input_def, output_def) || - FromTensorConverter::IsSupported(input_def, output_def) || - ToTensorConverter::IsSupported(input_def, output_def)); + TensorToBHWCBufferConverter::IsSupported(input_def, output_def) || + BHWCBufferToTensorConverter::IsSupported(input_def, output_def)); } absl::Status MakeConverter( @@ -478,12 +535,16 @@ class OpenClTensorConverterBuilder : public TensorObjectConverterBuilder { const auto& output_def = output.object_def; if (TrivialCopier::IsSupported(input_def, output_def)) { impl = absl::make_unique(); + } else if (TensorToTensorConverter::IsSupported(input_def, output_def)) { + impl = absl::make_unique(); } else if (CpuCopier::IsSupported(input_def, output_def)) { impl = absl::make_unique(); - } else if (FromTensorConverter::IsSupported(input_def, output_def)) { - impl = absl::make_unique(); - } else if (ToTensorConverter::IsSupported(input_def, output_def)) { - impl = absl::make_unique(); + } else if (TensorToBHWCBufferConverter::IsSupported(input_def, + output_def)) { + impl = absl::make_unique(); + } else if (BHWCBufferToTensorConverter::IsSupported(input_def, + output_def)) { + impl = absl::make_unique(); } else { return absl::UnimplementedError("Unsupported conversion"); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc index d22dbbd88cf..b2bf5216f8e 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.cc @@ -516,12 +516,12 @@ std::string 
ConvolutionTransposed::GenerateConvolutionTransposedCode( return c; } -absl::Status ConvolutionTransposed::BindArguments() { +absl::Status ConvolutionTransposed::BindArguments(ArgumentsBinder* args) { if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) { const int aligned_h = AlignByN(dst_[0]->Height(), stride_.y * block_size_.y); RETURN_IF_ERROR( - args_.SetInt("grid_size_y", DivideRoundUp(aligned_h, block_size_.y))); + args->SetInt("grid_size_y", DivideRoundUp(aligned_h, block_size_.y))); } return absl::OkStatus(); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h index 134bfc4839c..5aa86f33e5a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed.h @@ -44,7 +44,7 @@ class ConvolutionTransposed : public GPUOperation { TuningType tuning_type, const DeviceInfo& device_info, const KernelInfo& kernel_info, std::vector* work_groups) const override; - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; // Move only diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc index af952dd3f78..7880f31013a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.cc @@ -29,10 +29,9 @@ namespace gpu { namespace cl { ConvolutionTransposed3x3::ConvolutionTransposed3x3( const OperationDef& definition, const DeviceInfo& device_info, int2 padding) - : GPUOperation(definition), - padding_(padding), - work_group_launch_order_(2, 0, 1) { + : GPUOperation(definition), padding_(padding) { work_group_size_ = int3(8, 4, 1); + work_group_launch_order_ = int3(2, 0, 1); if (device_info.IsPowerVR()) { weights_upload_type_ = WeightsUploadType::LOCAL_MEM_ASYNC; } else if (device_info.IsNvidia() || device_info.IsIntel()) { @@ -54,14 +53,12 @@ ConvolutionTransposed3x3::ConvolutionTransposed3x3( ConvolutionTransposed3x3&& operation) : GPUOperation(std::move(operation)), padding_(operation.padding_), - work_group_launch_order_(operation.work_group_launch_order_), weights_upload_type_(operation.weights_upload_type_) {} ConvolutionTransposed3x3& ConvolutionTransposed3x3::operator=( ConvolutionTransposed3x3&& operation) { if (this != &operation) { std::swap(padding_, operation.padding_); - std::swap(work_group_launch_order_, operation.work_group_launch_order_); std::swap(weights_upload_type_, operation.weights_upload_type_); GPUOperation::operator=(std::move(operation)); } @@ -305,27 +302,33 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode( return c; } -absl::Status ConvolutionTransposed3x3::BindArguments() { - RETURN_IF_ERROR(args_.SetInt("filter_offset", 4 * 9 * src_[0]->Slices())); +absl::Status ConvolutionTransposed3x3::BindArguments(ArgumentsBinder* args) { + RETURN_IF_ERROR(args->SetInt("filter_offset", 4 * 9 * src_[0]->Slices())); const int padding_x = padding_.x >= 1 ? (padding_.x - 1) / 2 : (padding_.x - 2) / 2; const int padding_y = padding_.y >= 1 ? 
(padding_.y - 1) / 2 : (padding_.y - 2) / 2; - RETURN_IF_ERROR(args_.SetInt("padding_x", padding_x * src_[0]->Batch())); - return args_.SetInt("padding_y", padding_y); + RETURN_IF_ERROR(args->SetInt("padding_x", padding_x * src_[0]->Batch())); + return args->SetInt("padding_y", padding_y); +} + +void ConvolutionTransposed3x3::GetPossibleKernelWorkGroups( + TuningType tuning_type, const DeviceInfo& device_info, + const KernelInfo& kernel_info, std::vector* work_groups) const { + if (weights_upload_type_ == WeightsUploadType::LOCAL_MEM_ASYNC || + weights_upload_type_ == WeightsUploadType::LOCAL_MEM_BY_THREADS) { + work_groups->push_back(work_group_size_); + return; + } + GetPossibleWorkGroupsConv(tuning_type, device_info, kernel_info, grid_size_, + work_groups); } int3 ConvolutionTransposed3x3::GetGridSize() const { const int grid_x = DivideRoundUp(dst_[0]->Width(), 2) * dst_[0]->Batch(); const int grid_y = DivideRoundUp(dst_[0]->Height(), 2); const int grid_z = dst_[0]->Slices(); - int3 wg; - wg.x = DivideRoundUp(grid_x, work_group_size_.x); - wg.y = DivideRoundUp(grid_y, work_group_size_.y); - wg.z = DivideRoundUp(grid_z, work_group_size_.z); - return int3(wg[work_group_launch_order_[0]] * work_group_size_.x, - wg[work_group_launch_order_[1]] * work_group_size_.y, - wg[work_group_launch_order_[2]] * work_group_size_.z); + return int3(grid_x, grid_y, grid_z); } bool IsConvolutionTransposed3x3Supported( diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h index ad3e459da3e..074fc23b0e7 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_3x3.h @@ -40,10 +40,8 @@ class ConvolutionTransposed3x3 : public GPUOperation { void GetPossibleKernelWorkGroups( TuningType tuning_type, const DeviceInfo& device_info, const KernelInfo& kernel_info, - std::vector* work_groups) const override { - work_groups->push_back(work_group_size_); - } - absl::Status BindArguments() override; + std::vector* work_groups) const override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; // Move only @@ -78,7 +76,6 @@ class ConvolutionTransposed3x3 : public GPUOperation { int2 padding, int3 work_group_launch_order); int2 padding_; - int3 work_group_launch_order_; WeightsUploadType weights_upload_type_; }; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc index d606a822d7e..0f389361724 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.cc @@ -296,8 +296,8 @@ std::string ConvolutionTransposed4x4::GenerateConvolutionTransposedCode( return c; } -absl::Status ConvolutionTransposed4x4::BindArguments() { - return args_.SetInt("filter_offset", 4 * 16 * src_[0]->Slices()); +absl::Status ConvolutionTransposed4x4::BindArguments(ArgumentsBinder* args) { + return args->SetInt("filter_offset", 4 * 16 * src_[0]->Slices()); } int3 ConvolutionTransposed4x4::GetGridSize() const { diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h index 2577eb47513..17d63233864 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h +++ 
b/tensorflow/lite/delegates/gpu/cl/kernels/convolution_transposed_4x4.h @@ -43,7 +43,7 @@ class ConvolutionTransposed4x4 : public GPUOperation { std::vector* work_groups) const override { work_groups->push_back(work_group_size_); } - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; // Move only diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc index f42ad824006..05d5d086bc7 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.cc @@ -67,17 +67,18 @@ std::string GetSrcValue(int channel_multiplier, const std::string coords) { return c; } -std::string GenerateDepthwiseConvolutionCode(const OperationDef& op_def, - bool stride_correction, - int channel_multiplier, - bool weights_are_buffer, - GPUOperation* op) { +std::string GenerateDepthwiseConvolutionCode( + const OperationDef& op_def, bool stride_correction, int channel_multiplier, + bool weights_are_buffer, bool dynamic_weights, GPUOperation* op) { auto src_desc = op_def.src_tensors[0]; src_desc.SetTextureAddressMode(TextureAddressMode::ZERO); if (op_def.IsBatchSupported()) { src_desc.SetStateVar("BatchedWidth", "true"); } op->AddSrcTensor("src_tensor", src_desc); + if (dynamic_weights) { + op->AddSrcTensor("weights", op_def.src_tensors[1]); + } auto dst_desc = op_def.dst_tensors[0]; if (op_def.IsBatchSupported()) { @@ -122,16 +123,24 @@ std::string GenerateDepthwiseConvolutionCode(const OperationDef& op_def, } } c += " int y_offseted = Y * args.stride_y + args.padding_y;\n"; - std::string weights_offset = "args.kernel_size_x * args.kernel_size_y"; - if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { - c += " int z_offseted = Z * args.stride_z + args.padding_z;\n"; - weights_offset += " * args.kernel_size_z"; - } - if (weights_are_buffer) { - c += " int fx_c = S * " + weights_offset + ";\n"; - } else { - c += " int fx_c = 0;\n"; + if (!dynamic_weights) { + std::string weights_offset = "args.kernel_size_x * args.kernel_size_y"; + if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { + c += " int z_offseted = Z * args.stride_z + args.padding_z;\n"; + weights_offset += " * args.kernel_size_z"; + } + if (weights_are_buffer) { + c += " int fx_c = S * " + weights_offset + ";\n"; + } else { + c += " int fx_c = 0;\n"; + } } + std::string kernel_size_x = + dynamic_weights ? "args.weights.Width()" : "args.kernel_size_x"; + std::string kernel_size_y = + dynamic_weights ? "args.weights.Height()" : "args.kernel_size_y"; + std::string kernel_size_z = + dynamic_weights ? 
"args.weights.Depth()" : "args.kernel_size_z"; std::string flat_coords = "x_c, y_c"; if (manual_clamp) { @@ -139,29 +148,35 @@ std::string GenerateDepthwiseConvolutionCode(const OperationDef& op_def, if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { check += " && !outside_z"; flat_coords += ", z_c"; - c += " for (int kz = 0; kz < args.kernel_size_z; ++kz) {\n"; + c += " for (int kz = 0; kz < " + kernel_size_z + "; ++kz) {\n"; c += " int z_c = z_offseted + kz * args.dilation_z;\n"; c += " bool outside_z = z_c < 0 || z_c >= args.src_tensor.Depth();\n"; } - c += " for (int ky = 0; ky < args.kernel_size_y; ++ky) {\n"; + c += " for (int ky = 0; ky < " + kernel_size_y + "; ++ky) {\n"; c += " int y_c = y_offseted + ky * args.dilation_y;\n"; c += " bool outside_y = y_c < 0 || y_c >= args.src_tensor.Height();\n"; - c += " for (int kx = 0; kx < args.kernel_size_x; ++kx) {\n"; + c += " for (int kx = 0; kx < " + kernel_size_x + "; ++kx) {\n"; const std::string dilation_x = op_def.IsBatchSupported() ? "args.dilation_x * args.src_tensor.Batch()" : "args.dilation_x"; c += " int x_c = x_offseted + kx * " + dilation_x + ";\n"; c += " bool outside_x = x_c < 0 || x_c >= args.src_tensor.Width();\n"; c += " if (" + check + ") {\n"; - if (weights_are_buffer) { - c += " FLT4 f = args.weights.Read(fx_c);\n"; + if (dynamic_weights) { + c += " FLT4 f = args.weights.Read(kx, ky, S);\n"; } else { - c += " FLT4 f = args.weights.Read(fx_c, S);\n"; + if (weights_are_buffer) { + c += " FLT4 f = args.weights.Read(fx_c);\n"; + } else { + c += " FLT4 f = args.weights.Read(fx_c, S);\n"; + } } c += GetSrcValue(channel_multiplier, flat_coords); c += " r += TO_ACCUM_TYPE(src_final * f);\n"; c += " };\n"; - c += " fx_c++;\n"; + if (!dynamic_weights) { + c += " fx_c++;\n"; + } c += " }\n"; c += " }\n"; if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { @@ -170,7 +185,7 @@ std::string GenerateDepthwiseConvolutionCode(const OperationDef& op_def, } else { // Texture types with ZERO clamping if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) { flat_coords += ", z_c"; - c += " for (int kz = 0; kz < args.kernel_size_z; ++kz) {\n"; + c += " for (int kz = 0; kz < " + kernel_size_z + "; ++kz) {\n"; c += " int z_c = z_offseted + kz * args.dilation_z;\n"; if (src_tensor_type != TensorStorageType::TEXTURE_3D) { // Only TEXTURE_3D supports clamping @@ -181,20 +196,24 @@ std::string GenerateDepthwiseConvolutionCode(const OperationDef& op_def, c += " }\n"; } } - c += " for (int ky = 0; ky < args.kernel_size_y; ++ky) {\n"; + c += " for (int ky = 0; ky < " + kernel_size_y + "; ++ky) {\n"; c += " int y_c = y_offseted + ky * args.dilation_y;\n"; - c += " for (int kx = 0; kx < args.kernel_size_x; ++kx) {\n"; + c += " for (int kx = 0; kx < " + kernel_size_x + "; ++kx) {\n"; const std::string dilation_x = op_def.IsBatchSupported() ? 
"args.dilation_x * args.src_tensor.Batch()" : "args.dilation_x"; c += " int x_c = x_offseted + kx * " + dilation_x + ";\n"; c += GetSrcValue(channel_multiplier, flat_coords); - if (weights_are_buffer) { - c += " FLT4 f = args.weights.Read(fx_c);\n"; + if (dynamic_weights) { + c += " FLT4 f = args.weights.Read(kx, ky, S);\n"; } else { - c += " FLT4 f = args.weights.Read(fx_c, S);\n"; + if (weights_are_buffer) { + c += " FLT4 f = args.weights.Read(fx_c);\n"; + } else { + c += " FLT4 f = args.weights.Read(fx_c, S);\n"; + } + c += " fx_c++;\n"; } - c += " fx_c++;\n"; c += " r += TO_ACCUM_TYPE(src_final * f);\n"; c += " }\n"; c += " }\n"; @@ -234,7 +253,7 @@ GPUOperation CreateDepthwiseConvolution2D( definition.IsBatchSupported() && attr.strides.w != 1; op.code_ = GenerateDepthwiseConvolutionCode(definition, stride_correction, attr.weights.shape.o, - weights_are_buffer, &op); + weights_are_buffer, false, &op); UploadWeightsForDWConv2D(attr.weights, weights_are_buffer, definition.precision, &op); op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ; @@ -249,6 +268,32 @@ GPUOperation CreateDepthwiseConvolution2D( return op; } +GPUOperation CreateDepthwiseConvolution2DDynamicWeights( + const DeviceInfo& device_info, const OperationDef& definition, + const DepthwiseConvolution2DAttributes& attr) { + GPUOperation op(definition); + op.args_.AddInt("stride_x", attr.strides.w); + op.args_.AddInt("padding_x", -attr.padding.prepended.w); + op.args_.AddInt("dilation_x", attr.dilations.w); + op.args_.AddInt("stride_y", attr.strides.h); + op.args_.AddInt("padding_y", -attr.padding.prepended.h); + op.args_.AddInt("dilation_y", attr.dilations.h); + const bool stride_correction = + definition.IsBatchSupported() && attr.strides.w != 1; + op.code_ = GenerateDepthwiseConvolutionCode(definition, stride_correction, 1, + false, true, &op); + op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ; + + TensorLinearDescriptor desc; + desc.storage_type = device_info.IsMali() ? 
LinearStorageType::BUFFER + : LinearStorageType::TEXTURE_2D; + desc.element_type = definition.GetDataType(); + desc.UploadLinearData(attr.bias); + op.args_.AddObject( + "biases", absl::make_unique(std::move(desc))); + return op; +} + GPUOperation CreateDepthwiseConvolution3D( const DeviceInfo& device_info, const OperationDef& definition, const DepthwiseConvolution3DAttributes& attr) { @@ -273,7 +318,7 @@ GPUOperation CreateDepthwiseConvolution3D( definition.IsBatchSupported() && attr.strides.w != 1; op.code_ = GenerateDepthwiseConvolutionCode(definition, stride_correction, attr.weights.shape.o, - weights_are_buffer, &op); + weights_are_buffer, false, &op); UploadWeightsForDWConv3D(attr.weights, weights_are_buffer, definition.precision, &op); op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h index a5a4f5ba339..3bb034849bc 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h @@ -186,6 +186,10 @@ GPUOperation CreateDepthwiseConvolution2D( const DeviceInfo& device_info, const OperationDef& definition, const DepthwiseConvolution2DAttributes& attr); +GPUOperation CreateDepthwiseConvolution2DDynamicWeights( + const DeviceInfo& device_info, const OperationDef& definition, + const DepthwiseConvolution2DAttributes& attr); + GPUOperation CreateDepthwiseConvolution3D( const DeviceInfo& device_info, const OperationDef& definition, const DepthwiseConvolution3DAttributes& attr); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc index b34b8e38b41..025de5a7c7f 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.cc @@ -49,6 +49,33 @@ std::string GetElementWiseCode(const OperationDef& op_def, return c; } +int3 GetWorkGroupsCount(int grid_dimension, const int3& grid_size, + const int3& work_group_size, + const int3& work_group_launch_order) { + int3 work_groups_count; + if (grid_dimension == 1) { + work_groups_count.x = DivideRoundUp(grid_size.x, work_group_size.x); + work_groups_count.y = 1; + work_groups_count.z = 1; + } else if (grid_dimension == 2) { + int3 wgs; + wgs.x = DivideRoundUp(grid_size.x, work_group_size.x); + wgs.y = DivideRoundUp(grid_size.y, work_group_size.y); + work_groups_count.x = wgs[work_group_launch_order[0]]; + work_groups_count.y = wgs[work_group_launch_order[1]]; + work_groups_count.z = 1; + } else { // grid_dimension == 3 + int3 wgs; + wgs.x = DivideRoundUp(grid_size.x, work_group_size.x); + wgs.y = DivideRoundUp(grid_size.y, work_group_size.y); + wgs.z = DivideRoundUp(grid_size.z, work_group_size.z); + work_groups_count.x = wgs[work_group_launch_order[0]]; + work_groups_count.y = wgs[work_group_launch_order[1]]; + work_groups_count.z = wgs[work_group_launch_order[2]]; + } + return work_groups_count; +} + } // namespace DataType OperationDef::GetDataType() const { @@ -106,9 +133,12 @@ GPUOperation::GPUOperation(GPUOperation&& operation) src_(std::move(operation.src_)), dst_(std::move(operation.dst_)), kernel_(std::move(operation.kernel_)), + grid_dimension_(operation.grid_dimension_), + work_group_launch_order_(operation.work_group_launch_order_), grid_size_(operation.grid_size_), src_tensors_names_(std::move(operation.src_tensors_names_)), dst_tensors_names_(std::move(operation.dst_tensors_names_)), + 
work_groups_count_(operation.work_groups_count_), linkable_count_(operation.linkable_count_), elementwise_code_(std::move(operation.elementwise_code_)) {} @@ -126,9 +156,12 @@ GPUOperation& GPUOperation::operator=(GPUOperation&& operation) { src_ = std::move(operation.src_); dst_ = std::move(operation.dst_); kernel_ = std::move(operation.kernel_); + std::swap(grid_dimension_, operation.grid_dimension_); + std::swap(work_group_launch_order_, operation.work_group_launch_order_); std::swap(grid_size_, operation.grid_size_); src_tensors_names_ = std::move(operation.src_tensors_names_); dst_tensors_names_ = std::move(operation.dst_tensors_names_); + std::swap(work_groups_count_, operation.work_groups_count_); std::swap(linkable_count_, operation.linkable_count_); elementwise_code_ = std::move(operation.elementwise_code_); } @@ -183,8 +216,10 @@ absl::Status GPUOperation::UpdateParams() { for (int i = 0; i < dst_tensors_names_.size(); ++i) { RETURN_IF_ERROR(args_.SetObjectRef(dst_tensors_names_[i], dst_[i])); } - RETURN_IF_ERROR(BindArguments()); + RETURN_IF_ERROR(BindArguments(&args_)); grid_size_ = GetGridSize(); + work_groups_count_ = GetWorkGroupsCount( + grid_dimension_, grid_size_, work_group_size_, work_group_launch_order_); return absl::OkStatus(); } @@ -245,14 +280,26 @@ absl::Status GPUOperation::Tune(const TuningParameters& params) { } if (possible_work_groups.size() == 1) { work_group_size_ = possible_work_groups[0]; + work_groups_count_ = + GetWorkGroupsCount(grid_dimension_, grid_size_, work_group_size_, + work_group_launch_order_); return absl::OkStatus(); } else { + std::vector work_groups_count(possible_work_groups.size()); + for (int i = 0; i < work_groups_count.size(); ++i) { + work_groups_count[i] = + GetWorkGroupsCount(grid_dimension_, grid_size_, + possible_work_groups[i], work_group_launch_order_); + } RETURN_IF_ERROR(args_.Bind(kernel_.kernel())); int best_work_group_index; RETURN_IF_ERROR(params.queue->GetBestWorkGroupIndex( - kernel_, *params.info, grid_size_, possible_work_groups, + kernel_, *params.info, work_groups_count, possible_work_groups, &best_work_group_index)); work_group_size_ = possible_work_groups[best_work_group_index]; + work_groups_count_ = + GetWorkGroupsCount(grid_dimension_, grid_size_, work_group_size_, + work_group_launch_order_); return absl::OkStatus(); } } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h index 11386cc434e..fe41e78f93c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h @@ -120,7 +120,7 @@ class GPUOperation { absl::Status AddToQueue(CLCommandQueue* queue) { RETURN_IF_ERROR(args_.Bind(kernel_.kernel())); - return queue->DispatchImplicit(kernel_, grid_size_, work_group_size_); + return queue->Dispatch(kernel_, work_groups_count_, work_group_size_); } virtual void GetPossibleKernelWorkGroups( @@ -164,7 +164,9 @@ class GPUOperation { bool check_src_channels_size_ = false; protected: - virtual absl::Status BindArguments() { return absl::OkStatus(); } + virtual absl::Status BindArguments(ArgumentsBinder* args) { + return absl::OkStatus(); + } virtual int3 GetGridSize() const; // Defines operation calculation precision and format of src/dst tensors. 
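// Hedged illustrative sketch, not part of the patch: the hunks above and below
// move GPUOperation from implicit dispatch to an explicit work_groups_count_
// derived from grid_size_, work_group_size_ and work_group_launch_order_.
// The standalone C++ below mirrors the 3D case of the GetWorkGroupsCount()
// helper added in gpu_operation.cc; Int3 and DivideRoundUp are local
// stand-ins for the real TFLite GPU types, used only to keep the sketch
// self-contained.

struct Int3 {
  int x = 0, y = 0, z = 0;
  int& operator[](int i) { return i == 0 ? x : (i == 1 ? y : z); }
};

inline int DivideRoundUp(int n, int d) { return (n + d - 1) / d; }

// Per-axis work-group counts are the rounded-up ratio of grid extent to
// work-group extent; the result is then permuted by the launch order.
Int3 WorkGroupsCount3D(Int3 grid, Int3 wg_size, Int3 launch_order) {
  Int3 wgs;
  wgs.x = DivideRoundUp(grid.x, wg_size.x);
  wgs.y = DivideRoundUp(grid.y, wg_size.y);
  wgs.z = DivideRoundUp(grid.z, wg_size.z);
  Int3 count;
  count.x = wgs[launch_order.x];
  count.y = wgs[launch_order.y];
  count.z = wgs[launch_order.z];
  return count;
}

// Example (hypothetical sizes): ConvolutionTransposed3x3 now returns the raw
// grid from GetGridSize() and relies on this permutation instead of
// reordering inside GetGridSize(). With grid {33, 17, 4}, work group
// {8, 4, 1} and launch order {2, 0, 1}: per-axis counts are {5, 5, 4},
// permuted to {4, 5, 5}.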
@@ -172,11 +174,14 @@ class GPUOperation { std::vector src_; std::vector dst_; CLKernel kernel_; + int grid_dimension_ = 3; // can be 1, 2 or 3 + int3 work_group_launch_order_ = int3(0, 1, 2); int3 grid_size_ = int3(0, 0, 0); std::vector src_tensors_names_; std::vector dst_tensors_names_; private: + int3 work_groups_count_ = int3(0, 0, 0); int linkable_count_ = 0; std::string elementwise_code_; // temporary, used during op construction }; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/mean.cc b/tensorflow/lite/delegates/gpu/cl/kernels/mean.cc index a4b01ffad16..c5b659463ea 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/mean.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/mean.cc @@ -117,12 +117,12 @@ std::string Mean::GetMeanKernelCode(const OperationDef& op_def, return c; } -absl::Status Mean::BindArguments() { +absl::Status Mean::BindArguments(ArgumentsBinder* args) { const double total_size = src_[0]->Width() * src_[0]->Height(); const double size_0 = work_group_size_.x * work_group_size_.y; const double size_1 = total_size / size_0; - RETURN_IF_ERROR(args_.SetFloat("inv_multiplier_1", 1.0 / size_1)); - RETURN_IF_ERROR(args_.SetFloat("inv_multiplier_2", 1.0 / size_0)); + RETURN_IF_ERROR(args->SetFloat("inv_multiplier_1", 1.0 / size_1)); + RETURN_IF_ERROR(args->SetFloat("inv_multiplier_2", 1.0 / size_0)); return absl::OkStatus(); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/mean.h b/tensorflow/lite/delegates/gpu/cl/kernels/mean.h index 12735c0b916..3bf2061d329 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/mean.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/mean.h @@ -37,7 +37,7 @@ class Mean : public GPUOperation { std::vector* work_groups) const override { work_groups->push_back(work_group_size_); } - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; // Move only diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc b/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc index a0fd699062c..91266ef29a6 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/resize.cc @@ -132,13 +132,13 @@ std::string Resize::GetResizeCode(const OperationDef& op_def, return c; } -absl::Status Resize::BindArguments() { - RETURN_IF_ERROR(args_.SetInt("border_x", src_[0]->Width() - 1)); - RETURN_IF_ERROR(args_.SetInt("border_y", src_[0]->Height() - 1)); - RETURN_IF_ERROR(args_.SetFloat( +absl::Status Resize::BindArguments(ArgumentsBinder* args) { + RETURN_IF_ERROR(args->SetInt("border_x", src_[0]->Width() - 1)); + RETURN_IF_ERROR(args->SetInt("border_y", src_[0]->Height() - 1)); + RETURN_IF_ERROR(args->SetFloat( "scale_factor_x", CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_))); - RETURN_IF_ERROR(args_.SetFloat( + RETURN_IF_ERROR(args->SetFloat( "scale_factor_y", CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_))); return absl::OkStatus(); @@ -286,17 +286,17 @@ std::string Resize3D::GetResize3DCode(const OperationDef& op_def, return c; } -absl::Status Resize3D::BindArguments() { - RETURN_IF_ERROR(args_.SetInt("border_x", src_[0]->Width() - 1)); - RETURN_IF_ERROR(args_.SetInt("border_y", src_[0]->Height() - 1)); - RETURN_IF_ERROR(args_.SetInt("border_z", src_[0]->Depth() - 1)); - RETURN_IF_ERROR(args_.SetFloat( +absl::Status Resize3D::BindArguments(ArgumentsBinder* args) { + RETURN_IF_ERROR(args->SetInt("border_x", src_[0]->Width() - 1)); + RETURN_IF_ERROR(args->SetInt("border_y", 
src_[0]->Height() - 1)); + RETURN_IF_ERROR(args->SetInt("border_z", src_[0]->Depth() - 1)); + RETURN_IF_ERROR(args->SetFloat( "scale_factor_x", CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_))); - RETURN_IF_ERROR(args_.SetFloat( + RETURN_IF_ERROR(args->SetFloat( "scale_factor_y", CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_))); - RETURN_IF_ERROR(args_.SetFloat( + RETURN_IF_ERROR(args->SetFloat( "scale_factor_z", CalculateResizeScale(src_[0]->Depth(), dst_[0]->Depth(), attr_))); return absl::OkStatus(); diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/resize.h b/tensorflow/lite/delegates/gpu/cl/kernels/resize.h index 0349afe5664..859d750b7e0 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/resize.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/resize.h @@ -27,7 +27,7 @@ namespace cl { class Resize : public GPUOperation { public: - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; // Move only @@ -53,7 +53,7 @@ Resize CreateResize(const OperationDef& definition, class Resize3D : public GPUOperation { public: - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; // Move only diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc index e7cf72aa72a..d4d0442e61d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.cc @@ -109,14 +109,14 @@ std::string Softmax1x1::GetSoftmaxKernelCode(const OperationDef& op_def) { return c; } -absl::Status Softmax1x1::BindArguments() { +absl::Status Softmax1x1::BindArguments(ArgumentsBinder* args) { float4 mask = GetMaskForLastPlane(src_[0]->Channels()); - RETURN_IF_ERROR(args_.SetFloat("mask_x", mask.x)); - RETURN_IF_ERROR(args_.SetFloat("mask_y", mask.y)); - RETURN_IF_ERROR(args_.SetFloat("mask_z", mask.z)); - RETURN_IF_ERROR(args_.SetFloat("mask_w", mask.w)); + RETURN_IF_ERROR(args->SetFloat("mask_x", mask.x)); + RETURN_IF_ERROR(args->SetFloat("mask_y", mask.y)); + RETURN_IF_ERROR(args->SetFloat("mask_z", mask.z)); + RETURN_IF_ERROR(args->SetFloat("mask_w", mask.w)); RETURN_IF_ERROR( - args_.SetInt("slices_x32", DivideRoundUp(src_[0]->Slices(), 32))); + args->SetInt("slices_x32", DivideRoundUp(src_[0]->Slices(), 32))); return absl::OkStatus(); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h index 5bc9278d612..202f46d2a51 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/softmax1x1.h @@ -35,7 +35,7 @@ class Softmax1x1 : public GPUOperation { std::vector* work_groups) const override { work_groups->push_back(work_group_size_); } - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; // Move only diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/special/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/special/BUILD index d5ff93e6845..f601556900c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/special/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/kernels/special/BUILD @@ -23,3 +23,30 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:types", ], ) + +cc_library( + name = "fc_fc_add", + srcs = ["fc_fc_add.cc"], + hdrs = ["fc_fc_add.h"], + deps = [ + 
"//tensorflow/lite/delegates/gpu/cl:arguments", + "//tensorflow/lite/delegates/gpu/cl:buffer", + "//tensorflow/lite/delegates/gpu/cl:cl_kernel", + "//tensorflow/lite/delegates/gpu/cl:device_info", + "//tensorflow/lite/delegates/gpu/cl:linear_storage", + "//tensorflow/lite/delegates/gpu/cl:precision", + "//tensorflow/lite/delegates/gpu/cl:tensor", + "//tensorflow/lite/delegates/gpu/cl:tensor_type", + "//tensorflow/lite/delegates/gpu/cl:texture2d", + "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", + "//tensorflow/lite/delegates/gpu/cl/kernels:tuning_parameters", + "//tensorflow/lite/delegates/gpu/cl/kernels:util", + "//tensorflow/lite/delegates/gpu/common:data_type", + "//tensorflow/lite/delegates/gpu/common:operations", + "//tensorflow/lite/delegates/gpu/common:shape", + "//tensorflow/lite/delegates/gpu/common:tensor", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "@com_google_absl//absl/memory", + ], +) diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.cc b/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.cc new file mode 100644 index 00000000000..a8d3d434bd9 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.cc @@ -0,0 +1,207 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/cl/arguments.h" +#include "tensorflow/lite/delegates/gpu/cl/device_info.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" +#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h" +#include "tensorflow/lite/delegates/gpu/cl/precision.h" +#include "tensorflow/lite/delegates/gpu/cl/tensor.h" +#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" + +namespace tflite { +namespace gpu { +namespace cl { +namespace { +bool UseBufferForWeights(const DeviceInfo& device_info) { + return device_info.IsAdreno() || device_info.IsAMD() || device_info.IsMali(); +} +} // namespace + +FCFCAdd::FCFCAdd(const OperationDef& definition, const DeviceInfo& device_info) + : GPUOperation(definition) { + if (device_info.IsAdreno()) { + if (device_info.IsAdreno3xx()) { + work_group_size_ = int3(16, 4, 1); + } else if (device_info.IsAdreno4xx()) { + work_group_size_ = int3(32, 4, 1); + } else { + work_group_size_ = int3(32, 4, 1); + } + } else if (device_info.IsIntel()) { + work_group_size_ = int3(8, 4, 1); + } else if (device_info.IsNvidia()) { + work_group_size_ = int3(8, 4, 1); + } else if (device_info.IsPowerVR()) { + work_group_size_ = int3(8, 4, 1); + } else { + work_group_size_ = int3(16, 4, 1); + } + code_ = GetFCFCAddKernelCode(definition_, device_info); +} + +FCFCAdd::FCFCAdd(FCFCAdd&& kernel) : GPUOperation(std::move(kernel)) {} + +FCFCAdd& FCFCAdd::operator=(FCFCAdd&& kernel) { + if (this != &kernel) { + GPUOperation::operator=(std::move(kernel)); + } + return *this; +} + +// We split vec vec dot (every thread do vec vec dot product in basic +// vec mat mult) on 4 parts to create more threads +// tid.y thread process every 4-th element in vec vec dot +// Good results for ~1024 x 1024 sizes, for other can be written more +// optimized shaders + +std::string FCFCAdd::GetFCFCAddKernelCode(const OperationDef& op_def, + const DeviceInfo& device_info) { + AddSrcTensor("src_tensor_0", op_def.src_tensors[0]); + AddSrcTensor("src_tensor_1", op_def.src_tensors[1]); + AddDstTensor("dst_tensor", op_def.dst_tensors[0]); + + const bool weights_are_buffer = UseBufferForWeights(device_info); + + std::string c = GetCommonDefines(op_def.precision); + switch (op_def.precision) { + case CalculationsPrecision::F32: + c += "#define FLT16 float16\n"; + break; + case CalculationsPrecision::F32_F16: + case CalculationsPrecision::F16: + c += "#define FLT16 half16\n"; + break; + } + + c += "#define WG_X " + std::to_string(work_group_size_.x) + "\n"; + c += "#define WG_Y " + std::to_string(work_group_size_.y) + "\n"; + + c += R"(__kernel void main_function($0) { + int gid = get_global_id(0); + int2 tid = (int2)(get_local_id(0), get_local_id(1)); + ACCUM_FLT4 s = (ACCUM_FLT4)(0.0f); + if (gid < args.dst_tensor.Slices()) { + for (int c = tid.y; c < args.src_tensor_0.Slices(); c += WG_Y) { + FLT4 v = args.src_tensor_0.Read(0, 0, c); +)"; + if (weights_are_buffer) { + c += R"(FLT16 w = args.weights0.Read(c * args.dst_tensor.Slices() + gid); + FLT4 partial = v.s0 * w.s0123; + partial = mad(v.s1, w.s4567, partial); + partial = mad(v.s2, w.s89ab, partial); + partial = 
mad(v.s3, w.scdef, partial); + s += TO_ACCUM_TYPE(partial); +)"; + } else { + c += R"(FLT4 w0 = args.weights0.Read(c * 4 + 0, gid); + FLT4 w1 = args.weights0.Read(c * 4 + 1, gid); + FLT4 w2 = args.weights0.Read(c * 4 + 2, gid); + FLT4 w3 = args.weights0.Read(c * 4 + 3, gid); + FLT4 partial = v.s0 * w0; + partial = mad(v.s1, w1, partial); + partial = mad(v.s2, w2, partial); + partial = mad(v.s3, w3, partial); + s += TO_ACCUM_TYPE(partial); +)"; + } + c += R"( } + for (int c = tid.y; c < args.src_tensor_1.Slices(); c += WG_Y) { + FLT4 v = args.src_tensor_1.Read(0, 0, c); + )"; + if (weights_are_buffer) { + c += R"(FLT16 w = args.weights1.Read(c * args.dst_tensor.Slices() + gid); + FLT4 partial = v.s0 * w.s0123; + partial = mad(v.s1, w.s4567, partial); + partial = mad(v.s2, w.s89ab, partial); + partial = mad(v.s3, w.scdef, partial); + s += TO_ACCUM_TYPE(partial); +)"; + } else { + c += R"(FLT4 w0 = args.weights1.Read(c * 4 + 0, gid); + FLT4 w1 = args.weights1.Read(c * 4 + 1, gid); + FLT4 w2 = args.weights1.Read(c * 4 + 2, gid); + FLT4 w3 = args.weights1.Read(c * 4 + 3, gid); + FLT4 partial = v.s0 * w0; + partial = mad(v.s1, w1, partial); + partial = mad(v.s2, w2, partial); + partial = mad(v.s3, w3, partial); + s += TO_ACCUM_TYPE(partial); +)"; + } + c += R"( } + } + __local ACCUM_FLT4 temp[WG_X][WG_Y]; + temp[tid.x][tid.y] = s; + barrier(CLK_LOCAL_MEM_FENCE); + if (gid >= args.dst_tensor.Slices()) { + return; + } + if (tid.y == 0) { +)"; + for (int i = 1; i < work_group_size_.y; ++i) { + c += " s += temp[tid.x][" + std::to_string(i) + "];\n"; + } + c += + R"( FLT4 r0 = TO_FLT4(s) + args.biases0.Read(gid) + args.biases1.Read(gid); + args.dst_tensor.Write(r0, 0, 0, gid); + } +})"; + + return c; +} + +int3 FCFCAdd::GetGridSize() const { return int3(dst_[0]->Slices(), 1, 1); } + +FCFCAdd CreateFCFCAdd(const DeviceInfo& device_info, + const OperationDef& definition, + const FullyConnectedAttributes& attr0, + const FullyConnectedAttributes& attr1) { + FCFCAdd result(definition, device_info); + result.UploadWeights(attr0.weights, "weights0", + UseBufferForWeights(device_info)); + result.UploadWeights(attr1.weights, "weights1", + UseBufferForWeights(device_info)); + + TensorLinearDescriptor desc0; + desc0.storage_type = LinearStorageType::TEXTURE_2D; + desc0.element_type = definition.GetDataType(); + desc0.UploadLinearData(attr0.bias); + result.args_.AddObject( + "biases0", absl::make_unique(std::move(desc0))); + + TensorLinearDescriptor desc1; + desc1.storage_type = LinearStorageType::TEXTURE_2D; + desc1.element_type = definition.GetDataType(); + desc1.UploadLinearData(attr1.bias); + result.args_.AddObject( + "biases1", absl::make_unique(std::move(desc1))); + + return result; +} + +} // namespace cl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h b/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h new file mode 100644 index 00000000000..fea9d1a4990 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h @@ -0,0 +1,189 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_FC_FC_ADD_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_FC_FC_ADD_H_ + +#include + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/lite/delegates/gpu/cl/arguments.h" +#include "tensorflow/lite/delegates/gpu/cl/buffer.h" +#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h" +#include "tensorflow/lite/delegates/gpu/cl/device_info.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/tuning_parameters.h" +#include "tensorflow/lite/delegates/gpu/cl/precision.h" +#include "tensorflow/lite/delegates/gpu/cl/texture2d.h" +#include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/common/tensor.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" + +namespace tflite { +namespace gpu { +namespace cl { + +template +void RearrangeFCWeightsToIOO4I4(const tflite::gpu::Tensor& weights, + S* dst) { + const int src_channels = weights.shape.i; + const int padded_src_channels = AlignByN(src_channels, 4); + const int dst_channels = weights.shape.o; + const int padded_dst_channels = AlignByN(dst_channels, 4); + + for (int block_y = 0; 4 * block_y < padded_dst_channels; block_y++) { + for (int y_in_block = 0; y_in_block < 4; y_in_block++) { + for (int block_x = 0; 4 * block_x < padded_src_channels; block_x++) { + for (int x_in_block = 0; x_in_block < 4; x_in_block++) { + int y = 4 * block_y + y_in_block; + int x = 4 * block_x + x_in_block; + int dst_index = block_x * padded_dst_channels * 4 + block_y * 16 + + x_in_block * 4 + y_in_block; + if (x < src_channels && y < dst_channels) { + dst[dst_index] = weights.data[src_channels * y + x]; + } else { + dst[dst_index] = 0.0f; + } + } + } + } + } +} + +template +void RearrangeFCWeightsToOIO4I4(const tflite::gpu::Tensor& weights, + S* dst) { + const int src_channels = weights.shape.i; + const int src_depth = DivideRoundUp(src_channels, 4); + const int dst_channels = weights.shape.o; + const int dst_depth = DivideRoundUp(dst_channels, 4); + + int counter = 0; + for (int d = 0; d < dst_depth; ++d) { + for (int s = 0; s < src_depth; ++s) { + for (int i = 0; i < 4; ++i) { + const int src_ch = s * 4 + i; + for (int j = 0; j < 4; ++j) { + const int dst_ch = d * 4 + j; + if (src_ch < src_channels && dst_ch < dst_channels) { + dst[counter++] = weights.data[dst_ch * src_channels + src_ch]; + } else { + dst[counter++] = 0.0f; + } + } + } + } + } +} + +class FCFCAdd : public GPUOperation { + public: + FCFCAdd() = default; + void GetPossibleKernelWorkGroups( + TuningType tuning_type, const DeviceInfo& device_info, + const KernelInfo& kernel_info, + std::vector* work_groups) const override { + work_groups->push_back(work_group_size_); + } + int3 GetGridSize() const override; + + // Move only + 
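+  // (copy construction and assignment are deleted below, so FCFCAdd can only
+  // be moved)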
FCFCAdd(FCFCAdd&& kernel); + FCFCAdd& operator=(FCFCAdd&& kernel); + FCFCAdd(const FCFCAdd&) = delete; + FCFCAdd& operator=(const FCFCAdd&) = delete; + + private: + FCFCAdd(const OperationDef& definition, const DeviceInfo& device_info); + friend FCFCAdd CreateFCFCAdd(const DeviceInfo& device_info, + const OperationDef& definition, + const FullyConnectedAttributes& attr0, + const FullyConnectedAttributes& attr1); + + template + void UploadWeights(const tflite::gpu::Tensor& weights, + const std::string& name, bool weights_are_buffer); + + std::string GetFCFCAddKernelCode(const OperationDef& op_def, + const DeviceInfo& device_info); +}; + +template +void FCFCAdd::UploadWeights(const tflite::gpu::Tensor& weights, + const std::string& name, bool weights_are_buffer) { + const int src_depth = DivideRoundUp(weights.shape.i, 4); + const int dst_depth = DivideRoundUp(weights.shape.o, 4); + + const int elements_count = src_depth * dst_depth * 4; + const bool f32_weights = definition_.precision == CalculationsPrecision::F32; + + const int float4_size = f32_weights ? 16 : 8; + + if (weights_are_buffer) { + BufferDescriptor desc; + desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; + desc.element_size = 16; + desc.size = float4_size * elements_count; + desc.data.resize(desc.size); + + if (f32_weights) { + float* ptr = reinterpret_cast(desc.data.data()); + RearrangeFCWeightsToIOO4I4(weights, ptr); + } else { + half* ptr = reinterpret_cast(desc.data.data()); + RearrangeFCWeightsToIOO4I4(weights, ptr); + } + + args_.AddObject(name, absl::make_unique(std::move(desc))); + } else { + Texture2DDescriptor desc; + desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16; + // desc.element_type = DataType::UINT8; + // desc.normalized = true; + // desc.normalized_type = f32_weights ? 
DataType::FLOAT32 : + // DataType::FLOAT16; + desc.size = int2(src_depth * 4, dst_depth); + desc.data.resize(float4_size * elements_count); + + if (f32_weights) { + float* ptr = reinterpret_cast(desc.data.data()); + RearrangeFCWeightsToOIO4I4(weights, ptr); + } else { + half* ptr = reinterpret_cast(desc.data.data()); + RearrangeFCWeightsToOIO4I4(weights, ptr); + } + + args_.AddObject(name, + absl::make_unique(std::move(desc))); + } +} + +FCFCAdd CreateFCFCAdd(const DeviceInfo& device_info, + const OperationDef& definition, + const FullyConnectedAttributes& attr0, + const FullyConnectedAttributes& attr1); + +} // namespace cl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_SPECIAL_FC_FC_ADD_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc index b2ce0690a9c..1f8f985f3ee 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.cc @@ -154,17 +154,17 @@ std::string StridedSlice::GetStridedSliceCode(const OperationDef& op_def, return c; } -absl::Status StridedSlice::BindArguments() { +absl::Status StridedSlice::BindArguments(ArgumentsBinder* args) { int4 offset = GetOffset(attributes_, src_[0]->Width(), src_[0]->Height(), src_[0]->Channels(), src_[0]->Batch()); - RETURN_IF_ERROR(args_.SetInt("offset_x", offset.x)); - RETURN_IF_ERROR(args_.SetInt("offset_y", offset.y)); - RETURN_IF_ERROR(args_.SetInt("offset_z", offset.z)); - RETURN_IF_ERROR(args_.SetInt("offset_b", offset.w)); - RETURN_IF_ERROR(args_.SetInt("stride_x", attributes_.strides.w)); - RETURN_IF_ERROR(args_.SetInt("stride_y", attributes_.strides.h)); - RETURN_IF_ERROR(args_.SetInt("stride_z", attributes_.strides.c)); - RETURN_IF_ERROR(args_.SetInt("stride_b", attributes_.strides.b)); + RETURN_IF_ERROR(args->SetInt("offset_x", offset.x)); + RETURN_IF_ERROR(args->SetInt("offset_y", offset.y)); + RETURN_IF_ERROR(args->SetInt("offset_z", offset.z)); + RETURN_IF_ERROR(args->SetInt("offset_b", offset.w)); + RETURN_IF_ERROR(args->SetInt("stride_x", attributes_.strides.w)); + RETURN_IF_ERROR(args->SetInt("stride_y", attributes_.strides.h)); + RETURN_IF_ERROR(args->SetInt("stride_z", attributes_.strides.c)); + RETURN_IF_ERROR(args->SetInt("stride_b", attributes_.strides.b)); return absl::OkStatus(); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h index 5a6d8ad6047..dddff2faf35 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/strided_slice.h @@ -27,7 +27,7 @@ namespace cl { class StridedSlice : public GPUOperation { public: StridedSlice(const OperationDef& definition, const SliceAttributes& attr); - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; // Move only diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/util.cc b/tensorflow/lite/delegates/gpu/cl/kernels/util.cc index f0e0c412b7e..25fa60c776a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/util.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/util.cc @@ -188,6 +188,14 @@ int GetRecommendedBlockSizeForConv(const DeviceInfo& device_info, return block_size; } +int3 GetWorkGroupsCount(const int3& grid_size, const int3& work_group_size) { + int3 work_groups_count; + work_groups_count.x = DivideRoundUp(grid_size.x, work_group_size.x); + 
work_groups_count.y = DivideRoundUp(grid_size.y, work_group_size.y); + work_groups_count.z = DivideRoundUp(grid_size.z, work_group_size.z); + return work_groups_count; +} + } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/util.h b/tensorflow/lite/delegates/gpu/cl/kernels/util.h index 093e9b83292..69f6808146c 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/util.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/util.h @@ -213,6 +213,8 @@ int3 GetFirstSuitableWorkGroup(const std::vector& wgs, int max_wg_size); int GetRecommendedBlockSizeForConv(const DeviceInfo& device, CalculationsPrecision precision, int task_size); + +int3 GetWorkGroupsCount(const int3& grid_size, const int3& work_group_size); } // namespace cl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc index 0f94847f08a..1244f769b48 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.cc @@ -262,16 +262,16 @@ int3 Winograd4x4To36::SelectBestWorkGroup(const KernelInfo& kernel_info) const { return GetFirstSuitableWorkGroup(wgs, kernel_info.max_work_group_size); } -absl::Status Winograd4x4To36::BindArguments() { +absl::Status Winograd4x4To36::BindArguments(ArgumentsBinder* args) { const int tiles_x = DivideRoundUp( src_[0]->Width() + padding_.prepended.w + padding_.appended.w - 2, 4); const int tiles_y = DivideRoundUp( src_[0]->Height() + padding_.prepended.h + padding_.appended.h - 2, 4); const int tiles_total = tiles_x * tiles_y; - RETURN_IF_ERROR(args_.SetInt("padding_x", -padding_.prepended.w)); - RETURN_IF_ERROR(args_.SetInt("padding_y", -padding_.prepended.h)); - RETURN_IF_ERROR(args_.SetInt("tiles_total", tiles_total)); - RETURN_IF_ERROR(args_.SetInt("tiles_x", tiles_x)); + RETURN_IF_ERROR(args->SetInt("padding_x", -padding_.prepended.w)); + RETURN_IF_ERROR(args->SetInt("padding_y", -padding_.prepended.h)); + RETURN_IF_ERROR(args->SetInt("tiles_total", tiles_total)); + RETURN_IF_ERROR(args->SetInt("tiles_x", tiles_x)); return absl::OkStatus(); } @@ -463,9 +463,9 @@ int3 Winograd36To4x4::SelectBestWorkGroup(const KernelInfo& kernel_info) const { return GetFirstSuitableWorkGroup(wgs, kernel_info.max_work_group_size); } -absl::Status Winograd36To4x4::BindArguments() { +absl::Status Winograd36To4x4::BindArguments(ArgumentsBinder* args) { const int tiles_x = DivideRoundUp(dst_[0]->Width(), 4); - RETURN_IF_ERROR(args_.SetInt("tiles_x", tiles_x)); + RETURN_IF_ERROR(args->SetInt("tiles_x", tiles_x)); return absl::OkStatus(); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h index a5da49e7939..609e38a4c9a 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/winograd.h @@ -36,7 +36,7 @@ class Winograd4x4To36 : public GPUOperation { Winograd4x4To36() = default; Winograd4x4To36(const OperationDef& definition, const Padding2D& padding, const DeviceInfo& device_info); - absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; void GetPossibleKernelWorkGroups( TuningType tuning_type, const DeviceInfo& device_info, @@ -73,7 +73,7 @@ class Winograd36To4x4 : public GPUOperation { Winograd36To4x4() = default; Winograd36To4x4(const OperationDef& definition, const DeviceInfo& device_info); - 
absl::Status BindArguments() override; + absl::Status BindArguments(ArgumentsBinder* args) override; int3 GetGridSize() const override; void GetPossibleKernelWorkGroups( TuningType tuning_type, const DeviceInfo& device_info, diff --git a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc index bf2fd449291..add0e2fd4e9 100644 --- a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc +++ b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc @@ -36,7 +36,7 @@ namespace cl { #ifdef __ANDROID__ #define LoadFunction(function) \ - if (is_pixel) { \ + if (use_wrapper) { \ function = reinterpret_cast(loadOpenCLPointer(#function)); \ } else { \ function = reinterpret_cast(dlsym(libopencl, #function)); \ @@ -53,7 +53,7 @@ namespace cl { #ifdef __WINDOWS__ void LoadOpenCLFunctions(HMODULE libopencl); #else -void LoadOpenCLFunctions(void* libopencl, bool is_pixel); +void LoadOpenCLFunctions(void* libopencl, bool use_wrapper); #endif absl::Status LoadOpenCL() { @@ -77,8 +77,11 @@ absl::Status LoadOpenCL() { // record error std::string error(dlerror()); #ifdef __ANDROID__ - // Pixel phone? + // Pixel phone or auto? libopencl = dlopen("libOpenCL-pixel.so", RTLD_NOW | RTLD_LOCAL); + if (!libopencl) { + libopencl = dlopen("libOpenCL-car.so", RTLD_NOW | RTLD_LOCAL); + } if (libopencl) { typedef void (*enableOpenCL_t)(); enableOpenCL_t enableOpenCL = @@ -96,11 +99,11 @@ absl::Status LoadOpenCL() { #ifdef __WINDOWS__ void LoadOpenCLFunctions(HMODULE libopencl) { #else -void LoadOpenCLFunctions(void* libopencl, bool is_pixel) { +void LoadOpenCLFunctions(void* libopencl, bool use_wrapper) { #ifdef __ANDROID__ typedef void* (*loadOpenCLPointer_t)(const char* name); loadOpenCLPointer_t loadOpenCLPointer; - if (is_pixel) { + if (use_wrapper) { loadOpenCLPointer = reinterpret_cast( dlsym(libopencl, "loadOpenCLPointer")); } diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD index 520145d7d5b..8a22741f013 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/selectors/BUILD @@ -14,7 +14,6 @@ cc_library( "//tensorflow/lite/delegates/gpu/cl/kernels:conv_common", "//tensorflow/lite/delegates/gpu/cl/kernels:conv_constants", "//tensorflow/lite/delegates/gpu/cl/kernels:conv_powervr", - "//tensorflow/lite/delegates/gpu/cl/kernels:conv_texture", "//tensorflow/lite/delegates/gpu/cl/kernels:conv_weights_converter", "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/cl/kernels:work_group_picking", @@ -82,7 +81,6 @@ cc_library( deps = [ "//tensorflow/lite/delegates/gpu/cl/kernels:conv_buffer_1x1", "//tensorflow/lite/delegates/gpu/cl/kernels:conv_powervr", - "//tensorflow/lite/delegates/gpu/cl/kernels:conv_texture", "//tensorflow/lite/delegates/gpu/cl/kernels:fully_connected", "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/common:operations", @@ -132,6 +130,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/cl/kernels:add", "//tensorflow/lite/delegates/gpu/cl/kernels:concat_xy", "//tensorflow/lite/delegates/gpu/cl/kernels:concat_z", + "//tensorflow/lite/delegates/gpu/cl/kernels:depthwise_conv", "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/cl/kernels:lstm", "//tensorflow/lite/delegates/gpu/cl/kernels:max_unpooling", @@ -167,6 +166,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/cl:tensor_type", 
"//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation", "//tensorflow/lite/delegates/gpu/cl/kernels/special:depthwise_conv_plus_1x1_conv", + "//tensorflow/lite/delegates/gpu/cl/kernels/special:fc_fc_add", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:operations", diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc index 7fa7978034b..a3282f05200 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/convolution_selector.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_constants.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_weights_converter.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" @@ -38,8 +37,8 @@ std::unique_ptr SelectConvolutionAdreno( GPUOperation conv = CreateConvConstants(device_info, op_def, attr); return absl::make_unique(std::move(conv)); } else { - ConvTexture conv = CreateConvTexture(device_info, op_def, attr); - return absl::make_unique(std::move(conv)); + ConvPowerVR conv = CreateConvPowerVR(device_info, op_def, attr, &dst_shape); + return absl::make_unique(std::move(conv)); } } @@ -47,8 +46,9 @@ std::unique_ptr SelectConvolutionWinogradAdreno( const Convolution2DAttributes& attr, const BHWC& dst_shape, const DeviceInfo& device_info, const OperationDef& op_def, ModelHints hints) { - ConvTexture conv = CreateConvTextureWino4x4To6x6(device_info, op_def, attr); - return absl::make_unique(std::move(conv)); + ConvPowerVR conv = + CreateConvPowerVRWino4x4To6x6(device_info, op_def, attr, &dst_shape); + return absl::make_unique(std::move(conv)); } std::unique_ptr SelectConvolutionDynamicWeightsAdreno( diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc index 24c48d52f2a..6c6ee044cdd 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_buffer_1x1.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/conv_powervr.h" -#include "tensorflow/lite/delegates/gpu/cl/kernels/conv_texture.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/fully_connected.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -31,8 +30,9 @@ std::unique_ptr SelectFullyConnectedGeneric( const FullyConnectedAttributes& attr, const DeviceInfo& device_info, const OperationDef& op_def, int batch_size) { if (op_def.IsBatchSupported()) { - ConvTexture conv = CreateConvTexture(device_info, op_def, attr); - return absl::make_unique(std::move(conv)); + BHWC dst_shape = BHWC(batch_size, 1, 1, attr.weights.shape.o); + ConvPowerVR conv = CreateConvPowerVR(device_info, op_def, attr, &dst_shape); + return absl::make_unique(std::move(conv)); } else { FullyConnected fc = CreateFullyConnected(device_info, op_def, attr); return absl::make_unique(std::move(fc)); @@ -43,8 +43,9 @@ std::unique_ptr SelectFullyConnectedAdreno( const FullyConnectedAttributes& attr, const DeviceInfo& device_info, const OperationDef& op_def, int batch_size) { if (op_def.IsBatchSupported()) { - ConvTexture conv = CreateConvTexture(device_info, op_def, attr); - return absl::make_unique(std::move(conv)); + BHWC dst_shape = BHWC(batch_size, 1, 1, attr.weights.shape.o); + ConvPowerVR conv = CreateConvPowerVR(device_info, op_def, attr, &dst_shape); + return absl::make_unique(std::move(conv)); } else { FullyConnected fc = CreateFullyConnected(device_info, op_def, attr); return absl::make_unique(std::move(fc)); @@ -71,8 +72,10 @@ std::unique_ptr SelectFullyConnectedMali( ConvBuffer1x1 conv = CreateConvBuffer1x1(device_info, op_def, attr); return absl::make_unique(std::move(conv)); } else { - ConvTexture conv = CreateConvTexture(device_info, op_def, attr); - return absl::make_unique(std::move(conv)); + BHWC dst_shape = BHWC(batch_size, 1, 1, attr.weights.shape.o); + ConvPowerVR conv = + CreateConvPowerVR(device_info, op_def, attr, &dst_shape); + return absl::make_unique(std::move(conv)); } } else { FullyConnected fc = CreateFullyConnected(device_info, op_def, attr); diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc index 9b0a169e5b4..f7981fc67bb 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.cc @@ -315,7 +315,16 @@ absl::Status GPUOperationFromNode(const DeviceInfo& device_info, case OperationType::DEPTHWISE_CONVOLUTION: { auto attr = absl::any_cast( node.operation.attributes); - *gpu_op = SelectDWConvolution(attr, device_info, op_def); + if (inputs.size() == 1) { + *gpu_op = SelectDWConvolution(attr, device_info, op_def); + } else { + if (inputs[1]->tensor.shape.b != 1) { + return absl::UnimplementedError( + "No support of depthwise runtime weights with channel multiplier " + "!= 1"); + } + *gpu_op = SelectDWConvolutionDynamicWeights(attr, device_info, op_def); + } return absl::OkStatus(); } case OperationType::FULLY_CONNECTED: { diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc index 4dbb1ffd734..713892f9902 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.cc @@ -22,6 
+22,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/kernels/add.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/depthwise_conv.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/lstm.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/max_unpooling.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/mean.h" @@ -110,6 +111,13 @@ absl::Status SelectConcat(const ConcatAttributes& attr, } } +std::unique_ptr SelectDWConvolutionDynamicWeights( + const DepthwiseConvolution2DAttributes& attr, const DeviceInfo& device_info, + const OperationDef& op_def) { + return absl::make_unique( + CreateDepthwiseConvolution2DDynamicWeights(device_info, op_def, attr)); +} + void SelectReshape(int src_channels, int dst_channels, const OperationDef& op_def, std::unique_ptr* ptr) { diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h index c6c604da982..084298442e3 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/simple_selectors.h @@ -57,6 +57,10 @@ absl::Status SelectConcat(const ConcatAttributes& attr, const DeviceInfo& device_info, std::unique_ptr* ptr); +std::unique_ptr SelectDWConvolutionDynamicWeights( + const DepthwiseConvolution2DAttributes& attr, const DeviceInfo& device_info, + const OperationDef& op_def); + void SelectReshape(int src_channels, int dst_channels, const OperationDef& op_def, std::unique_ptr* ptr); diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc index 31480f231b0..631eabc4569 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "absl/types/any.h" #include "tensorflow/lite/delegates/gpu/cl/cl_device.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.h" +#include "tensorflow/lite/delegates/gpu/cl/kernels/special/fc_fc_add.h" #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" @@ -39,6 +40,10 @@ absl::Status TryDepthwiseConvPlus1x1Conv( OperationType::DEPTHWISE_CONVOLUTION) { return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable."); } + auto dw_inputs = graph.FindInputs(dw_node->id); + if (dw_inputs.size() != 1) { + return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable."); + } auto dw_outputs = graph.FindOutputs(dw_node->id); auto consumers = graph.FindConsumers(dw_outputs[0]->id); if (consumers.size() != 1) { @@ -59,7 +64,6 @@ absl::Status TryDepthwiseConvPlus1x1Conv( dw_node->operation.attributes); auto conv_attr = absl::any_cast(conv_node->operation.attributes); - auto dw_inputs = graph.FindInputs(dw_node->id); auto conv_outputs = graph.FindOutputs(conv_node->id); OperationDef op_def; op_def.precision = precision; @@ -82,22 +86,108 @@ absl::Status TryDepthwiseConvPlus1x1Conv( consumed_nodes->insert(conv_node->id); return absl::OkStatus(); } + +// fully connected + fully connected + add +absl::Status TryFCFCAdd( + const DeviceInfo& device_info, CalculationsPrecision precision, + const GraphFloat32& graph, NodeId first_node_id, + const std::map& tensor_descriptors, + std::set* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph) { + auto* fc0_node = graph.GetNode(first_node_id); + if (OperationTypeFromString(fc0_node->operation.type) != + OperationType::FULLY_CONNECTED) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + auto fc0_inputs = graph.FindInputs(fc0_node->id); + if (fc0_inputs.size() != 1) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + auto fc0_output_id = graph.FindOutputs(fc0_node->id)[0]->id; + auto consumers = graph.FindConsumers(fc0_output_id); + if (consumers.size() != 1) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + auto* add_node = consumers[0]; + if (consumed_nodes->find(add_node->id) != consumed_nodes->end()) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + if (OperationTypeFromString(add_node->operation.type) != OperationType::ADD) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + auto add_inputs = graph.FindInputs(add_node->id); + if (add_inputs.size() != 2) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + auto fc1_output_id = add_inputs[0]->id + add_inputs[1]->id - fc0_output_id; + auto* fc1_node = graph.FindProducer(fc1_output_id); + if (OperationTypeFromString(fc1_node->operation.type) != + OperationType::FULLY_CONNECTED) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + if (consumed_nodes->find(fc1_node->id) != consumed_nodes->end()) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + auto fc1_inputs = graph.FindInputs(fc1_node->id); + if (fc1_inputs.size() != 1) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + auto fc0_attr = + absl::any_cast(fc0_node->operation.attributes); + auto fc1_attr = + absl::any_cast(fc1_node->operation.attributes); + if (fc0_attr.weights.shape.o != fc1_attr.weights.shape.o) { + return absl::NotFoundError("FCFCAdd not suitable."); + } + auto add_outputs = graph.FindOutputs(add_node->id); + + OperationDef op_def; + op_def.precision = precision; + 
auto it = tensor_descriptors.find(fc0_inputs[0]->id); + if (it != tensor_descriptors.end()) { + op_def.src_tensors.push_back(it->second); + } + it = tensor_descriptors.find(fc1_inputs[0]->id); + if (it != tensor_descriptors.end()) { + op_def.src_tensors.push_back(it->second); + } + it = tensor_descriptors.find(add_outputs[0]->id); + if (it != tensor_descriptors.end()) { + op_def.dst_tensors.push_back(it->second); + } + + for (int i = 0; i < fc1_inputs.size(); ++i) { + fc0_inputs.push_back(fc1_inputs[i]); + } + std::unique_ptr* gpu_op = + InitSingleOpSubgraph(fc0_inputs, add_outputs, gpu_subgraph); + FCFCAdd fc = CreateFCFCAdd(device_info, op_def, fc0_attr, fc1_attr); + *gpu_op = absl::make_unique(std::move(fc)); + consumed_nodes->insert(fc0_node->id); + consumed_nodes->insert(fc1_node->id); + consumed_nodes->insert(add_node->id); + return absl::OkStatus(); +} } // namespace absl::Status GPUSubgraphFromGraph( const DeviceInfo& device_info, CalculationsPrecision precision, const GraphFloat32& graph, NodeId first_node_id, const std::map& tensor_descriptors, - std::set* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph) { - if (!device_info.IsNvidia()) { - return absl::NotFoundError( - "Experimental feature, enabled for NVidia only, but device is not " - "nvidia gpu."); - } - if (TryDepthwiseConvPlus1x1Conv(precision, graph, first_node_id, + std::set* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph, + std::string* name) { + if ((device_info.IsAdreno() || device_info.IsNvidia()) && + TryDepthwiseConvPlus1x1Conv(precision, graph, first_node_id, tensor_descriptors, consumed_nodes, gpu_subgraph) .ok()) { + *name = "depthwise_conv_plus_1x1_conv"; + return absl::OkStatus(); + } + if ((device_info.IsIntel() || device_info.IsNvidia()) && + TryFCFCAdd(device_info, precision, graph, first_node_id, + tensor_descriptors, consumed_nodes, gpu_subgraph) + .ok()) { + *name = "fully_connected_x2_and_add"; return absl::OkStatus(); } return absl::NotFoundError("No special combination."); diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h b/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h index 3ea99b2515a..6091415e14c 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h +++ b/tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h @@ -34,7 +34,8 @@ absl::Status GPUSubgraphFromGraph( const DeviceInfo& device_info, CalculationsPrecision precision, const GraphFloat32& graph, NodeId first_node_id, const std::map& tensor_descriptors, - std::set* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph); + std::set* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph, + std::string* name); } // namespace cl } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/cl/testing/BUILD b/tensorflow/lite/delegates/gpu/cl/testing/BUILD index c82190ca0e6..a14dfd72cfd 100644 --- a/tensorflow/lite/delegates/gpu/cl/testing/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/testing/BUILD @@ -3,20 +3,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -cc_binary( - name = "performance_profiling", - srcs = ["performance_profiling.cc"], - deps = [ - "//tensorflow/lite/delegates/gpu/cl:environment", - "//tensorflow/lite/delegates/gpu/cl:inference_context", - "//tensorflow/lite/delegates/gpu/common:model", - "//tensorflow/lite/delegates/gpu/common:status", - "//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/time", - ], -) - cc_binary( name = "delegate_testing", srcs = 
["delegate_testing.cc"], @@ -34,3 +20,38 @@ cc_binary( "@com_google_absl//absl/time", ], ) + +cc_binary( + name = "internal_api_samples", + srcs = ["internal_api_samples.cc"], + tags = [ + "nobuilder", + "notap", + ], + deps = [ + "//tensorflow/lite/delegates/gpu:api", + "//tensorflow/lite/delegates/gpu/cl:api", + "//tensorflow/lite/delegates/gpu/cl:environment", + "//tensorflow/lite/delegates/gpu/cl:inference_context", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/kernels:kernel_util", + "@com_google_absl//absl/time", + ], +) + +cc_binary( + name = "performance_profiling", + srcs = ["performance_profiling.cc"], + deps = [ + "//tensorflow/lite/delegates/gpu/cl:environment", + "//tensorflow/lite/delegates/gpu/cl:inference_context", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader", + "//tensorflow/lite/kernels:builtin_ops", + "@com_google_absl//absl/time", + ], +) diff --git a/tensorflow/lite/delegates/gpu/cl/testing/internal_api_samples.cc b/tensorflow/lite/delegates/gpu/cl/testing/internal_api_samples.cc new file mode 100644 index 00000000000..3e9b614c8c4 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/testing/internal_api_samples.cc @@ -0,0 +1,224 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include
+#include  // NOLINT(build/c++11)
+#include
+#include
+
+#include "absl/time/time.h"
+#include "tensorflow/lite/delegates/gpu/api.h"
+#include "tensorflow/lite/delegates/gpu/cl/api.h"
+#include "tensorflow/lite/delegates/gpu/cl/environment.h"
+#include "tensorflow/lite/delegates/gpu/cl/inference_context.h"
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/register.h"
+
+namespace tflite {
+namespace gpu {
+namespace cl {
+namespace {
+void FillInputTensors(tflite::Interpreter* interpreter) {
+  for (int k = 0; k < interpreter->inputs().size(); ++k) {
+    TfLiteTensor* tensor_ptr = interpreter->tensor(interpreter->inputs()[k]);
+    const auto tensor_elements_count = tflite::NumElements(tensor_ptr);
+    if (tensor_ptr->type == kTfLiteFloat32) {
+      float* p = interpreter->typed_input_tensor<float>(k);
+      for (int i = 0; i < tensor_elements_count; ++i) {
+        p[i] = std::sin(i);
+      }
+    } else {
+      std::cout << "No support of non Float32 input/output tensors"
+                << std::endl;
+    }
+  }
+}
+
+void CompareCPUGPUResults(tflite::Interpreter* cpu,
+                          const std::vector<int64_t>& outputs,
+                          const std::vector<std::vector<float>>& gpu,
+                          float eps) {
+  for (int i = 0; i < gpu.size(); ++i) {
+    TfLiteTensor* tensor_ptr = cpu->tensor(outputs[i]);
+    const float* cpu_out = tensor_ptr->data.f;
+    const float* gpu_out = gpu[i].data();
+    const int kMaxPrint = 10;
+    int printed = 0;
+    int total_different = 0;
+    for (int k = 0; k < tensor_ptr->bytes / 4; ++k) {
+      const float abs_diff = fabs(cpu_out[k] - gpu_out[k]);
+      if (abs_diff > eps) {
+        total_different++;
+        if (printed < kMaxPrint) {
+          std::cout << "Output #" << i << ": element #" << k << ": CPU value - "
+                    << cpu_out[k] << ", GPU value - " << gpu_out[k]
+                    << ", abs diff - " << abs_diff << std::endl;
+          printed++;
+        }
+        if (printed == kMaxPrint) {
+          std::cout << "Printed " << kMaxPrint
+                    << " different elements, threshold - " << eps
+                    << ", next different elements skipped" << std::endl;
+          printed++;
+        }
+      }
+    }
+    std::cout << "Total " << total_different
+              << " different elements, for output #" << i << ", threshold - "
+              << eps << std::endl;
+  }
+}
+}  // namespace
+
+// Runs Jet with the OpenCL internal API and compares correctness with TFLite
+// CPU.
+absl::Status RunModelSampleWithInternalAPI(const std::string& model_name) {
+  auto flatbuffer = tflite::FlatBufferModel::BuildFromFile(model_name.c_str());
+
+  ops::builtin::BuiltinOpResolver op_resolver;
+  InterpreterBuilder tfl_builder(*flatbuffer, op_resolver);
+
+  // CPU.
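+  // Build a CPU interpreter over the same FlatBuffer, run it on the generated
+  // inputs, and keep its float outputs as the reference for the GPU comparison
+  // at the end.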
+ std::unique_ptr cpu_inference; + tfl_builder(&cpu_inference); + if (!cpu_inference) { + return absl::InternalError("Failed to build CPU inference."); + } + auto status = cpu_inference->AllocateTensors(); + if (status != kTfLiteOk) { + return absl::InternalError("Failed to AllocateTensors for CPU inference."); + } + for (int k = 0; k < cpu_inference->inputs().size(); ++k) { + TfLiteTensor* tensor_ptr = + cpu_inference->tensor(cpu_inference->inputs()[k]); + if (tensor_ptr->type != kTfLiteFloat32) { + return absl::InvalidArgumentError( + "Internal api supports only F32 input tensors"); + } + } + for (int k = 0; k < cpu_inference->outputs().size(); ++k) { + TfLiteTensor* tensor_ptr = + cpu_inference->tensor(cpu_inference->outputs()[k]); + if (tensor_ptr->type != kTfLiteFloat32) { + return absl::InvalidArgumentError( + "Internal api supports only F32 output tensors"); + } + } + FillInputTensors(cpu_inference.get()); + status = cpu_inference->Invoke(); + if (status != kTfLiteOk) { + return absl::InternalError("Failed to Invoke CPU inference."); + } + + GraphFloat32 graph_cl; + RETURN_IF_ERROR(BuildFromFlatBuffer(*flatbuffer, op_resolver, &graph_cl)); + + auto inputs = graph_cl.inputs(); + auto outputs = graph_cl.outputs(); + std::vector in_refs(inputs.size()); + std::vector out_refs(outputs.size()); + for (int i = 0; i < inputs.size(); ++i) { + in_refs[i] = inputs[i]->tensor.ref; + } + for (int i = 0; i < outputs.size(); ++i) { + out_refs[i] = outputs[i]->tensor.ref; + } + + Environment env; + RETURN_IF_ERROR(CreateEnvironment(&env)); + + std::unique_ptr inf_env; + // Initializes environment. + InferenceEnvironmentOptions env_options; + env_options.device = env.device().id(); + env_options.context = env.context().context(); + env_options.command_queue = env.queue()->queue(); + RETURN_IF_ERROR(NewInferenceEnvironment(env_options, &inf_env, nullptr)); + + std::unique_ptr builder; + // Initializes builder. + InferenceOptions options; + options.priority1 = InferencePriority::MIN_LATENCY; + options.priority2 = InferencePriority::MIN_MEMORY_USAGE; + options.priority3 = InferencePriority::MAX_PRECISION; + options.usage = InferenceUsage::SUSTAINED_SPEED; + RETURN_IF_ERROR( + inf_env->NewInferenceBuilder(options, std::move(graph_cl), &builder)); + + // Sets input/output object def for builder_. + ObjectDef obj_def; + obj_def.data_type = DataType::FLOAT32; + obj_def.data_layout = DataLayout::BHWC; + obj_def.object_type = ObjectType::CPU_MEMORY; + obj_def.user_provided = true; + for (int i = 0; i < in_refs.size(); ++i) { + RETURN_IF_ERROR(builder->SetInputObjectDef(i, obj_def)); + } + for (int i = 0; i < out_refs.size(); ++i) { + RETURN_IF_ERROR(builder->SetOutputObjectDef(i, obj_def)); + } + + std::unique_ptr<::tflite::gpu::InferenceRunner> runner; + // Builds runner. + RETURN_IF_ERROR(builder->Build(&runner)); + + // Sets the input/output object. 
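+  // Input objects alias the CPU interpreter's tensor buffers directly, while
+  // outputs are written into separately allocated float vectors sized from the
+  // corresponding tensor byte counts.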
+ for (int i = 0; i < in_refs.size(); ++i) { + TfLiteTensor* tensor_ptr = cpu_inference->tensor(in_refs[i]); + RETURN_IF_ERROR(runner->SetInputObject( + i, CpuMemory{tensor_ptr->data.data, tensor_ptr->bytes})); + } + + std::vector> output_tensors(out_refs.size()); + for (int i = 0; i < out_refs.size(); ++i) { + TfLiteTensor* tensor_ptr = cpu_inference->tensor(out_refs[i]); + output_tensors[i].resize(tensor_ptr->bytes / 4); + RETURN_IF_ERROR(runner->SetOutputObject( + i, CpuMemory{output_tensors[i].data(), tensor_ptr->bytes})); + } + + RETURN_IF_ERROR(runner->Run()); + + CompareCPUGPUResults(cpu_inference.get(), out_refs, output_tensors, 1e-4f); + + return absl::OkStatus(); +} + +} // namespace cl +} // namespace gpu +} // namespace tflite + +int main(int argc, char** argv) { + if (argc <= 1) { + std::cerr << "Expected model path as second argument."; + return -1; + } + + auto load_status = tflite::gpu::cl::LoadOpenCL(); + if (!load_status.ok()) { + std::cerr << load_status.message(); + return -1; + } + + auto run_status = tflite::gpu::cl::RunModelSampleWithInternalAPI(argv[1]); + if (!run_status.ok()) { + std::cerr << run_status.message(); + return -1; + } + + return EXIT_SUCCESS; +} diff --git a/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc b/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc index ab2e52f14ed..540004ad746 100644 --- a/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc +++ b/tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc @@ -43,7 +43,7 @@ absl::Status RunModelSample(const std::string& model_name) { create_info.precision = env.IsSupported(CalculationsPrecision::F16) ? CalculationsPrecision::F16 : CalculationsPrecision::F32; - create_info.storage_type = GetFastestStorageType(env.device()); + create_info.storage_type = GetFastestStorageType(env.device().GetInfo()); create_info.hints.Add(ModelHints::kAllowSpecialKernels); std::cout << "Precision: " << ToString(create_info.precision) << std::endl; std::cout << "Storage type: " << ToString(create_info.storage_type) diff --git a/tensorflow/lite/delegates/gpu/cl/testing/run_internal_api_samples.sh b/tensorflow/lite/delegates/gpu/cl/testing/run_internal_api_samples.sh new file mode 100755 index 00000000000..21900c55875 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/testing/run_internal_api_samples.sh @@ -0,0 +1,101 @@ +#!/bin/bash +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+
+shopt -s expand_aliases # to work with command aliases in .sh
+
+description="Example of internal API usage:
+How to use:
+[-h or --help, print instructions]
+[-m or --model_path, path to the model in .tflite format]
+[-d or --device, select device](optional, if you have multiple connected devices)"
+
+model_path=""
+alias ADB='adb'
+host=""
+
+while [[ "$1" != "" ]]; do
+  case $1 in
+    -m | --model_path)
+      shift
+      model_path=$1
+      ;;
+    -d | --device)
+      shift
+      if [[ "$1" == "HOST" ]]
+      then
+        host="HOST"
+      fi
+      alias ADB='adb -s '$1''
+      ;;
+    -h | --help)
+      echo "$description"
+      exit
+      ;;
+  esac
+  shift
+done
+
+if [ "$model_path" = "" ]
+then
+echo "No model provided."
+echo "$description"
+exit
+fi
+
+SHELL_DIR=$(dirname "$0")
+BINARY_NAME=internal_api_samples
+
+if [[ "$host" == "HOST" ]]
+then
+bazel build -c opt --copt -DCL_DELEGATE_NO_GL //"$SHELL_DIR":"$BINARY_NAME"
+chmod +x bazel-bin/"$SHELL_DIR"/"$BINARY_NAME"
+./bazel-bin/"$SHELL_DIR"/"$BINARY_NAME" "$model_path"
+exit
+fi
+
+model_name=${model_path##*/} # finds last token after '/'
+
+OPENCL_DIR=/data/local/tmp/internal_api_samples/
+
+ADB shell mkdir -p $OPENCL_DIR
+
+ADB push "$model_path" "$OPENCL_DIR"
+
+declare -a BUILD_CONFIG
+abi_version=$(ADB shell getprop ro.product.cpu.abi | tr -d '\r')
+if [[ "$abi_version" == "armeabi-v7a" ]]; then
+#"32 bit ARM"
+BUILD_CONFIG=( --config=android_arm -c opt --copt=-fPIE --linkopt=-pie )
+elif [[ "$abi_version" == "arm64-v8a" ]]; then
+#"64 bit ARM"
+BUILD_CONFIG=( --config=android_arm64 -c opt )
+elif [[ "$abi_version" == "x86_64" ]]; then
+# x86_64
+BUILD_CONFIG=( --config=android_x86_64 -c opt )
+else
+echo "Error: Unknown processor ABI"
+exit 1
+fi
+
+bazel build "${BUILD_CONFIG[@]}" --copt -DCL_DELEGATE_NO_GL //$SHELL_DIR:$BINARY_NAME
+
+ADB push bazel-bin/$SHELL_DIR/$BINARY_NAME $OPENCL_DIR
+
+ADB shell chmod +x $OPENCL_DIR/$BINARY_NAME
+ADB shell "cd $OPENCL_DIR && ./$BINARY_NAME $model_name"
+
+# clean up files from device
+ADB shell rm -rf $OPENCL_DIR
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index ba7b897cab7..ebb5d628cdc 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -505,54 +505,28 @@ class Conv2DOperationParser : public TFLiteOperationParser {
   }
 };
 
-// Custom op version of TRANSPOSE_CONV.
-class Convolution2DTransposeBiasParser : public TFLiteOperationParser { - public: - absl::Status IsSupported(const TfLiteContext* context, - const TfLiteNode* tflite_node, - const TfLiteRegistration* registration) final { - RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); - const TfLiteTransposeConvParams* tf_options; - RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options)); - RETURN_IF_ERROR( - CheckStrides(tf_options->stride_height, tf_options->stride_width)); - return absl::OkStatus(); - } - - absl::Status Parse(const TfLiteNode* tflite_node, - const TfLiteRegistration* registration, - GraphFloat32* graph, ObjectReader* reader) final { - auto* node = graph->NewNode(); - node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED); - RETURN_IF_ERROR(reader->AddInput(node, 0)); - RETURN_IF_ERROR(reader->AddOutputs(node)); - - const TfLiteTransposeConvParams* tf_options; - auto status = RetrieveCustomInitialData(tflite_node, &tf_options); - - ConvolutionTransposedAttributes attr; - attr.stride = status.ok() - ? HW(tf_options->stride_height, tf_options->stride_width) - : HW(1, 1); - RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); - reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional - - UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown, - graph->FindInputs(node->id)[0]->tensor.shape, &attr); - node->operation.attributes = std::move(attr); - return absl::OkStatus(); - } -}; - class DepthwiseConvolutionOperationParser : public TFLiteOperationParser { public: absl::Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, const TfLiteRegistration* registration) final { RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 6)); - RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, - /*runtime_inputs=*/1, /*outputs=*/1)); - RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); + const int runtime_inputs = + GetNumberOfRuntimeInputsForNode(context, tflite_node); + if (runtime_inputs > 2) { + return absl::InternalError( + absl::StrCat("Expected 1 or 2 input tensor(s), but node has ", + runtime_inputs, " runtime inputs.")); + } + const int runtime_outputs = NumOutputs(tflite_node); + if (runtime_outputs != 1) { + return absl::InternalError( + absl::StrCat("Expected 1 output tensor(s), but node has ", + runtime_outputs, " runtime outputs.")); + } + if (runtime_inputs == 1) { + RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); + } const TfLiteDepthwiseConvParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); RETURN_IF_ERROR(CheckStridesAndDilation( @@ -606,7 +580,12 @@ class DepthwiseConvolutionOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(reader->AddOutputs(node)); DepthwiseConvolution2DAttributes attr; - RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); + const int runtime_inputs = reader->GetNumberOfRuntimeInputs(); + if (runtime_inputs == 2) { + RETURN_IF_ERROR(reader->AddInput(node, 1)); + } else { // runtime_inputs == 1; + RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); + } reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional const TfLiteDepthwiseConvParams* tf_options; RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options)); @@ -1794,6 +1773,9 @@ class SliceOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(reader->ReadValue(0, &input)); RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id)); + const TfLiteTensor* tfl_input = reader->GetInputTensor(0); 
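+    // The rank of the TFLite input tensor determines how the 3- or 4-element
+    // starts/sizes arguments are mapped onto BHWC below.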
+ const int input_dims = tfl_input->dims->size; + SliceAttributes attr; attr.strides = BHWC(1, 1, 1, 1); Tensor starts, sizes; @@ -1802,37 +1784,65 @@ class SliceOperationParser : public TFLiteOperationParser { if (starts.data.size() != sizes.data.size()) { return absl::InvalidArgumentError("Starts amount != sizes amount."); } - const auto& in_shape = input->tensor.shape; - if (starts.data.size() == 4) { - sizes.data[0] = - sizes.data[0] != -1 ? sizes.data[0] : in_shape.b - starts.data[0]; - sizes.data[1] = - sizes.data[1] != -1 ? sizes.data[1] : in_shape.h - starts.data[1]; - sizes.data[2] = - sizes.data[2] != -1 ? sizes.data[2] : in_shape.w - starts.data[2]; - sizes.data[3] = - sizes.data[3] != -1 ? sizes.data[3] : in_shape.c - starts.data[3]; - attr.starts = - BHWC(starts.data[0], starts.data[1], starts.data[2], starts.data[3]); - attr.ends = - BHWC(starts.data[0] + sizes.data[0], starts.data[1] + sizes.data[1], - starts.data[2] + sizes.data[2], starts.data[3] + sizes.data[3]); - } else if (starts.data.size() == 3) { - sizes.data[0] = - sizes.data[0] != -1 ? sizes.data[0] : in_shape.h - starts.data[0]; - sizes.data[1] = - sizes.data[1] != -1 ? sizes.data[1] : in_shape.w - starts.data[1]; - sizes.data[2] = - sizes.data[2] != -1 ? sizes.data[2] : in_shape.c - starts.data[2]; - attr.starts = BHWC(0, starts.data[0], starts.data[1], starts.data[2]); - attr.ends = - BHWC(in_shape.b, starts.data[0] + sizes.data[0], - starts.data[1] + sizes.data[1], starts.data[2] + sizes.data[2]); + BHWC bhwc_starts(0, 0, 0, 0); + BHWC bhwc_sizes = input->tensor.shape; + if (input_dims == 4) { + // input in BHWC layout + if (starts.data.size() == 4) { + bhwc_starts.b = starts.data[0]; + bhwc_starts.h = starts.data[1]; + bhwc_starts.w = starts.data[2]; + bhwc_starts.c = starts.data[3]; + bhwc_sizes.b = sizes.data[0]; + bhwc_sizes.h = sizes.data[1]; + bhwc_sizes.w = sizes.data[2]; + bhwc_sizes.c = sizes.data[3]; + } else if (starts.data.size() == 3) { + // if input is 4D(BHWC) and args 3D, we assume that args in HWC layout + bhwc_starts.h = starts.data[0]; + bhwc_starts.w = starts.data[1]; + bhwc_starts.c = starts.data[2]; + bhwc_sizes.h = sizes.data[0]; + bhwc_sizes.w = sizes.data[1]; + bhwc_sizes.c = sizes.data[2]; + } else { + return absl::UnimplementedError( + "Slicing is supported for 3 or 4 dimensional tensors only."); + } + } else if (input_dims == 3) { + // input in BWC layout + if (starts.data.size() == 3) { + bhwc_starts.b = starts.data[0]; + bhwc_starts.w = starts.data[1]; + bhwc_starts.c = starts.data[2]; + bhwc_sizes.b = sizes.data[0]; + bhwc_sizes.w = sizes.data[1]; + bhwc_sizes.c = sizes.data[2]; + } else { + return absl::UnimplementedError( + "Slicing is supported for 3 or 4 dimensional tensors only."); + } } else { - // Error: Must be catched in IsSupported() return absl::UnimplementedError( "Slicing is supported for 3 or 4 dimensional tensors only."); } + const auto& in_shape = input->tensor.shape; + if (bhwc_sizes.b == -1) { + bhwc_sizes.b = in_shape.b - bhwc_starts.b; + } + if (bhwc_sizes.h == -1) { + bhwc_sizes.h = in_shape.h - bhwc_starts.h; + } + if (bhwc_sizes.w == -1) { + bhwc_sizes.w = in_shape.w - bhwc_starts.w; + } + if (bhwc_sizes.c == -1) { + bhwc_sizes.c = in_shape.c - bhwc_starts.c; + } + attr.starts = bhwc_starts; + attr.ends = + BHWC(bhwc_starts.b + bhwc_sizes.b, bhwc_starts.h + bhwc_sizes.h, + bhwc_starts.w + bhwc_sizes.w, bhwc_starts.c + bhwc_sizes.c); RETURN_IF_ERROR(UpdateIfNegative(in_shape, &attr)); auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape; @@ -2140,7 
+2150,7 @@ class StridedSliceOperationParser : public TFLiteOperationParser { }; // Builtin op version of TRANSPOSE_CONV. -class TransposeConvOperationParser : public TFLiteOperationParser { +class TransposeConvBuiltinOperationParser : public TFLiteOperationParser { public: absl::Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, @@ -2184,6 +2194,45 @@ class TransposeConvOperationParser : public TFLiteOperationParser { } }; +// Custom op version of TRANSPOSE_CONV. +class TransposeConvCustomOperationParser : public TFLiteOperationParser { + public: + absl::Status IsSupported(const TfLiteContext* context, + const TfLiteNode* tflite_node, + const TfLiteRegistration* registration) final { + RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); + const TfLiteTransposeConvParams* tf_options; + RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options)); + RETURN_IF_ERROR( + CheckStrides(tf_options->stride_height, tf_options->stride_width)); + return absl::OkStatus(); + } + + absl::Status Parse(const TfLiteNode* tflite_node, + const TfLiteRegistration* registration, + GraphFloat32* graph, ObjectReader* reader) final { + auto* node = graph->NewNode(); + node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED); + RETURN_IF_ERROR(reader->AddInput(node, 0)); + RETURN_IF_ERROR(reader->AddOutputs(node)); + + const TfLiteTransposeConvParams* tf_options; + auto status = RetrieveCustomInitialData(tflite_node, &tf_options); + + ConvolutionTransposedAttributes attr; + attr.stride = status.ok() + ? HW(tf_options->stride_height, tf_options->stride_width) + : HW(1, 1); + RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights)); + reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional + + UpdatePadding(status.ok() ? 
tf_options->padding : kTfLitePaddingUnknown, + graph->FindInputs(node->id)[0]->tensor.shape, &attr); + node->operation.attributes = std::move(attr); + return absl::OkStatus(); + } +}; + class TransposeOperationParser : public TFLiteOperationParser { public: absl::Status IsSupported(const TfLiteContext* context, @@ -2724,12 +2773,12 @@ std::unique_ptr NewOperationParser( case kTfLiteBuiltinTranspose: return std::make_unique(); case kTfLiteBuiltinTransposeConv: - return std::make_unique(); + return std::make_unique(); case kTfLiteBuiltinCustom: const absl::string_view custom_name = registration->custom_name; if (custom_name == "Convolution2DTransposeBias") { - return std::make_unique(); + return std::make_unique(); } if (custom_name == "MaxPoolingWithArgmax2D") { return std::make_unique(PoolingType::MAX); diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc index 4c6b08e2fb4..af274d8381e 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc @@ -70,6 +70,12 @@ class AddBias : public NodeTransformation { } if (node->operation.type == ToString(OperationType::DEPTHWISE_CONVOLUTION)) { + if (graph->FindInputs(node->id).size() != 1) { + return {TransformStatus::DECLINED, + "This transformation is only applicable to depth wise conv " + "with one " + "runtime input."}; + } auto& attr = absl::any_cast( node->operation.attributes); return FillBias(attr.weights.shape.o * attr.weights.shape.i, &attr.bias); diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc index 2b432bad877..62c3ec39854 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc @@ -54,6 +54,10 @@ class MergeConvolutionWithAdd : public SequenceTransformation { TransformResult ApplyToNodesSequence(const std::vector& sequence, GraphFloat32* graph) final { auto& conv_node = *sequence[0]; + if (graph->FindInputs(conv_node.id).size() != 1) { + return {TransformStatus::DECLINED, + "This fusion is only applicable to ops with one runtime input."}; + } auto& add_node = *sequence[1]; if (add_node.operation.type != ToString(OperationType::ADD)) { return {TransformStatus::SKIPPED, ""}; diff --git a/tensorflow/lite/delegates/gpu/delegate.cc b/tensorflow/lite/delegates/gpu/delegate.cc index bfc2b7f08c4..98303b51da8 100644 --- a/tensorflow/lite/delegates/gpu/delegate.cc +++ b/tensorflow/lite/delegates/gpu/delegate.cc @@ -448,7 +448,7 @@ TfLiteGpuDelegateOptionsV2 TfLiteGpuDelegateOptionsV2Default() { .inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION, .inference_priority2 = TFLITE_GPU_INFERENCE_PRIORITY_AUTO, .inference_priority3 = TFLITE_GPU_INFERENCE_PRIORITY_AUTO, - .experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_NONE, + .experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT, .max_delegated_partitions = 1, }; return options; diff --git a/tensorflow/lite/delegates/gpu/delegate.h b/tensorflow/lite/delegates/gpu/delegate.h index 9af586bfd75..40a06bb4384 100644 --- a/tensorflow/lite/delegates/gpu/delegate.h +++ b/tensorflow/lite/delegates/gpu/delegate.h @@ -51,6 +51,7 @@ enum TfLiteGpuInferencePriority { enum TfLiteGpuExperimentalFlags { TFLITE_GPU_EXPERIMENTAL_FLAGS_NONE = 0, // Enables inference on quantized models with the 
delegate. + // NOTE: This is enabled in TfLiteGpuDelegateOptionsV2Default. TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT = 1 << 0, // Enforces execution with the provided backend. TFLITE_GPU_EXPERIMENTAL_FLAGS_CL_ONLY = 1 << 1, @@ -108,6 +109,8 @@ typedef struct { // priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION // priority2 = TFLITE_GPU_INFERENCE_PRIORITY_AUTO // priority3 = TFLITE_GPU_INFERENCE_PRIORITY_AUTO +// experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT +// max_delegated_partitions = 1 TFL_CAPI_EXPORT TfLiteGpuDelegateOptionsV2 TfLiteGpuDelegateOptionsV2Default(); // Creates a new delegate instance that need to be destroyed with diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc index 71217a8e709..ceda5b68ca8 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.cc @@ -38,6 +38,10 @@ class DepthwiseConvolution : public NodeShader { public: absl::Status GenerateCode(const GenerationContext& ctx, GeneratedCode* generated_code) const final { + if (ctx.input_shapes.size() != 1) { + return absl::UnimplementedError( + "DepthWise Convolution does not support more than 1 runtime tensor"); + } const auto& attr = absl::any_cast(ctx.op_attr); auto weights = attr.weights.shape; diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc b/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc index e6cd539500e..efab4dd2274 100644 --- a/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc +++ b/tensorflow/lite/delegates/gpu/gl/kernels/registry.cc @@ -126,17 +126,20 @@ class Registry : public NodeShader { absl::Status GenerateCode(const GenerationContext& ctx, GeneratedCode* generated_code) const final { - std::vector errors; auto it = shaders_.find(ctx.op_type); - if (it != shaders_.end()) { - for (auto& shader : it->second) { - const auto status = shader->GenerateCode(ctx, generated_code); - if (status.ok()) return status; - errors.push_back(std::string(status.message())); - } + if (it == shaders_.end()) { + return absl::NotFoundError( + absl::StrCat("No shader implementation for ", ctx.op_type)); } - return absl::NotFoundError(absl::StrCat( - "Suitable node shader is not found: ", absl::StrJoin(errors, ", "))); + std::vector errors; + for (const auto& shader : it->second) { + const auto status = shader->GenerateCode(ctx, generated_code); + // Return the first suitable shader. + if (status.ok()) return absl::OkStatus(); + errors.push_back(std::string(status.message())); + } + return errors.empty() ? absl::OkStatus() + : absl::UnknownError(absl::StrJoin(errors, ", ")); } private: diff --git a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java index 78cab0d2cbf..5eb6881be88 100644 --- a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java +++ b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java @@ -68,7 +68,7 @@ public class GpuDelegate implements Delegate, Closeable { * *

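As a side note on the two delegate changes above: `TfLiteGpuDelegateOptionsV2Default()` now sets `TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT`, and the Java `Options` default for `quantizedModelsAllowed` flips to `true`, so quantized models run on the GPU delegate unless the caller opts out. A minimal sketch of restoring the old behavior through the C API declared in `delegate.h` (the helper function name is ours, not part of the API):

```c++
#include "tensorflow/lite/delegates/gpu/delegate.h"

// Sketch: build a GPU delegate that keeps the pre-change behavior of
// rejecting quantized models, by clearing the now-default ENABLE_QUANT bit.
TfLiteDelegate* MakeFloatOnlyGpuDelegate() {
  TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
  options.experimental_flags &= ~TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT;
  // Pass to Interpreter::ModifyGraphWithDelegate(); release with
  // TfLiteGpuDelegateV2Delete() when the interpreter is done with it.
  return TfLiteGpuDelegateV2Create(&options);
}
```

On the Java side the equivalent opt-out is `new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(false))`.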
WARNING: This is an experimental API and subject to change. * - * @param quantizedModelsAllowed When {@code true}, the GPU may run quantized models. + * @param quantizedModelsAllowed When {@code true} (default), the GPU may run quantized models. */ public Options setQuantizedModelsAllowed(boolean quantizedModelsAllowed) { this.quantizedModelsAllowed = quantizedModelsAllowed; @@ -87,7 +87,7 @@ public class GpuDelegate implements Delegate, Closeable { } boolean precisionLossAllowed = true; - boolean quantizedModelsAllowed = false; + boolean quantizedModelsAllowed = true; int inferencePreference = INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER; } diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD index c4e7ca7c10d..9f694b55cdb 100644 --- a/tensorflow/lite/delegates/gpu/metal/BUILD +++ b/tensorflow/lite/delegates/gpu/metal/BUILD @@ -38,6 +38,15 @@ cc_library( ], ) +cc_library( + name = "arguments", + srcs = ["arguments.cc"], + hdrs = ["arguments.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/common:status", + ], +) + objc_library( name = "buffer_convert", srcs = ["buffer_convert.mm"], @@ -134,6 +143,7 @@ objc_library( deps = [ ":common", ":compute_task_descriptor", + ":metal_arguments", ":runtime_options", "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:shape", @@ -149,6 +159,7 @@ objc_library( hdrs = ["compute_task_descriptor.h"], copts = DEFAULT_COPTS, deps = [ + ":arguments", "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:types", @@ -213,6 +224,20 @@ ios_unit_test( deps = [":inference_context_test_lib"], ) +objc_library( + name = "metal_arguments", + srcs = ["metal_arguments.mm"], + hdrs = ["metal_arguments.h"], + copts = DEFAULT_COPTS, + sdk_frameworks = ["Metal"], + deps = [ + ":arguments", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:util", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "runtime_options", hdrs = ["runtime_options.h"], diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc index 3359174e9f8..1f3476170b0 100644 --- a/tensorflow/lite/delegates/gpu/metal/api.cc +++ b/tensorflow/lite/delegates/gpu/metal/api.cc @@ -267,6 +267,11 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node, device_info, options); break; case OperationType::DEPTHWISE_CONVOLUTION: + if (graph.FindInputs(node->id).size() != 1) { + return absl::UnimplementedError( + "DepthWise Convolution does not support more than 1 runtime " + "tensor"); + } *tasks = SelectDepthWiseConv(node_id, inputs[0], outputs[0], absl::any_cast( diff --git a/tensorflow/lite/delegates/gpu/metal/arguments.cc b/tensorflow/lite/delegates/gpu/metal/arguments.cc new file mode 100644 index 00000000000..f4a308e59be --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/arguments.cc @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/metal/arguments.h" + +#include "absl/strings/ascii.h" + +namespace tflite { +namespace gpu { +namespace metal { +namespace { +bool IsWordSymbol(char symbol) { + return absl::ascii_isalnum(symbol) || symbol == '_'; +} + +bool HasWord(const std::string& word, const std::string& text) { + size_t pos = text.find(word); + while (pos != std::string::npos) { + char prev = pos == 0 ? '.' : text[pos - 1]; + char next = pos + word.size() < text.size() ? text[pos + word.size()] : '.'; + if (!IsWordSymbol(prev) & !IsWordSymbol(next)) { + return true; + } + pos = text.find(word, pos + 1); + } + return false; +} +} // namespace + +// Static +constexpr char Arguments::kArgsPrefix[]; + +void Arguments::AddFloat(const std::string& name, float value) { + float_values_[name].value = value; +} + +void Arguments::AddInt(const std::string& name, int value) { + int_values_[name].value = value; +} + +void Arguments::GetActiveArguments(const std::string& code) { + for (auto& float_val : float_values_) { + float_val.second.active = HasWord(kArgsPrefix + float_val.first, code); + } + for (auto& int_val : int_values_) { + int_val.second.active = HasWord(kArgsPrefix + int_val.first, code); + } +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal/arguments.h b/tensorflow/lite/delegates/gpu/metal/arguments.h new file mode 100644 index 00000000000..fbdcfef1358 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/arguments.h @@ -0,0 +1,77 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_ARGUMENTS_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_ARGUMENTS_H_ + +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace metal { + +class Arguments { + public: + Arguments() = default; + + // Move only + Arguments(Arguments&& args) = default; + Arguments& operator=(Arguments&& args) = default; + Arguments(const Arguments&) = delete; + Arguments& operator=(const Arguments&) = delete; + + void AddFloat(const std::string& name, float value = 0.0f); + void AddInt(const std::string& name, int value = 0); + + private: + friend class MetalArguments; + void GetActiveArguments(const std::string& code); + + static constexpr char kArgsPrefix[] = "args."; + struct IntValue { + int value; + + // many arguments generated automatically and not used + // this flag active if argument was used in kernel_code + // Will be filled after GetActiveArguments call + bool active = false; + }; + std::map int_values_; + + struct FloatValue { + float value; + + // many arguments generated automatically and not used + // this flag active if argument was used in kernel_code + // Will be filled after GetActiveArguments call + bool active = false; + }; + std::map float_values_; +}; + +class ArgumentsSetter { + public: + virtual absl::Status SetInt(const std::string& name, int value) = 0; + virtual absl::Status SetFloat(const std::string& name, float value) = 0; + virtual ~ArgumentsSetter() = default; +}; + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_ARGUMENTS_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/compiled_model.cc b/tensorflow/lite/delegates/gpu/metal/compiled_model.cc index 74202edd585..dc55c70f39f 100644 --- a/tensorflow/lite/delegates/gpu/metal/compiled_model.cc +++ b/tensorflow/lite/delegates/gpu/metal/compiled_model.cc @@ -535,6 +535,7 @@ ComputeTaskDescriptorPtr FuseChain(const FusionSequence& chain) { uniform_index++; fused_descriptor->uniform_buffers.push_back({"", buffer.data_function}); } + fused_descriptor->args = std::move(desc->args); if (desc->is_linkable) { call_code += @@ -546,8 +547,8 @@ ComputeTaskDescriptorPtr FuseChain(const FusionSequence& chain) { ComputeTaskDescriptorPtr non_linkable = sequence.front(); fused_descriptor->shader_source = - absl::Substitute(non_linkable->shader_source, function_code, - buffer_declarations, call_code); + absl::Substitute(non_linkable->shader_source, function_code + "$0", + buffer_declarations + "$1", call_code); std::vector alias; alias.reserve(chain.size() - 1); for (int i = 0; i < chain.size() - 1; i++) { diff --git a/tensorflow/lite/delegates/gpu/metal/compiled_model_test.mm b/tensorflow/lite/delegates/gpu/metal/compiled_model_test.mm index 3a76178d71f..0bafcb2616e 100644 --- a/tensorflow/lite/delegates/gpu/metal/compiled_model_test.mm +++ b/tensorflow/lite/delegates/gpu/metal/compiled_model_test.mm @@ -158,6 +158,7 @@ static std::vector Add2Linkable(int id, ValueId input_ ValueId input_id2, ValueId output_id) { std::vector descriptors; descriptors.push_back(ComputeTaskDescriptorPtr(new ComputeTaskDescriptor({ + {}, // args id, true, // linkable true, // associative_op diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.mm b/tensorflow/lite/delegates/gpu/metal/compute_task.mm index 7bfbb55feff..bcff35e29ad 100644 --- 
a/tensorflow/lite/delegates/gpu/metal/compute_task.mm +++ b/tensorflow/lite/delegates/gpu/metal/compute_task.mm @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -70,11 +71,15 @@ struct UniformBuffer { uint3 _groupsCount; DispatchParamsFunction _resizeFunction; std::string _description; + tflite::gpu::metal::MetalArguments _metal_args; } - (absl::Status)compileWithDevice:(id)device taskDescriptor:(ComputeTaskDescriptorPtr)desc runtimeOptions:(const RuntimeOptions&)options { + size_t offset = desc->input_buffers.size() + desc->uniform_buffers.size() + + desc->immutable_buffers.size() + 1; + RETURN_IF_ERROR(_metal_args.Init(offset, &desc->args, &desc->shader_source)); NSString* barrier; // simdgroup_barrier is supported on macOS 10.13+ and Metal shading language version 2.0 if (@available(macOS 10.13, iOS 10.0, tvOS 10.0, *)) { @@ -251,6 +256,7 @@ struct UniformBuffer { [encoder setBytes:uniform.data.data() length:uniform.data.size() atIndex:bindIndex]; bindIndex++; } + _metal_args.Encode(encoder, bindIndex); MTLSize groupsCount = MTLSizeMake(_groupsCount.x, _groupsCount.y, _groupsCount.z); MTLSize groupsSize = MTLSizeMake(_groupsSize.x, _groupsSize.y, _groupsSize.z); diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h b/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h index 923f4dcc245..7b65f2bdb02 100644 --- a/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h +++ b/tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/metal/arguments.h" #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h" namespace tflite { @@ -79,6 +80,7 @@ struct ComputeTaskDescriptor { UniformsFunction data_function; }; + Arguments args; // Unique ID to match the graph compilation errors. int id; bool is_linkable; diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc index 9872328f6d7..152ec0f542a 100644 --- a/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc +++ b/tensorflow/lite/delegates/gpu/metal/kernels/fully_connected.cc @@ -51,13 +51,6 @@ std::string GetFullyConnectedCode(const DeviceInfo& device_info, #include using namespace metal; - struct uniforms { - uint src_depth; - uint dst_channels; - uint out_channels; - uint dummy; - }; - $$0 kernel void ComputeFunction( $$1 @@ -71,11 +64,11 @@ std::string GetFullyConnectedCode(const DeviceInfo& device_info, float summa = 0.0f; threadgroup FLT4 local_vector[32]; for (int j = 0; j < $0; ++j) { - local_vector[tid_index] = j * 32 + tid_index >= params.src_depth ? + local_vector[tid_index] = j * 32 + tid_index >= args.src_slices ? 
FLT4(0.0f) : vector[j * 32 + tid_index]; $1(mem_flags::mem_threadgroup); for (uint i = 0, counter = j * 32 + tid.y * 8; i < 8; ++i, ++counter) { - summa += dot(local_vector[tid.y * 8 + i], matrix[counter * params.dst_channels + ugid.x]); + summa += dot(local_vector[tid.y * 8 + i], matrix[counter * args.dst_channels_alignedx8 + ugid.x]); } $1(mem_flags::mem_none); } @@ -87,10 +80,10 @@ std::string GetFullyConnectedCode(const DeviceInfo& device_info, for (uint i = 0; i < $0; ++i, ++counter) { )"; if (src_depth % 4 != 0) { - code << " if (counter >= params.src_depth) continue;" << std::endl; + code << " if (counter >= args.src_slices) continue;" << std::endl; } code << " summa += dot(vector[counter], matrix[counter * " - "params.dst_channels + ugid.x]);" + "args.dst_channels_alignedx8 + ugid.x]);" << std::endl; code << " }" << std::endl; } @@ -106,7 +99,7 @@ std::string GetFullyConnectedCode(const DeviceInfo& device_info, temp[tid.x][0] = summa; } $1(mem_flags::mem_threadgroup); - if (tid.y == 0 && tid.x % 4 == 0 && ugid.x < params.out_channels) { + if (tid.y == 0 && tid.x % 4 == 0 && ugid.x < args.dst_channels) { const int linear_index = ugid.x / 4; FLT4 value = FLT4(temp[tid.x][0], temp[tid.x + 1][0], temp[tid.x + 2][0], temp[tid.x + 3][0]) + biases[linear_index]; @@ -132,6 +125,11 @@ std::vector FullyConnected( desc->shader_source = GetFullyConnectedCode(device_info, attr.weights.shape.i, attr.weights.shape.o); + desc->args.AddInt("dst_channels", attr.weights.shape.o); + desc->args.AddInt("src_slices", DivideRoundUp(attr.weights.shape.i, 4)); + desc->args.AddInt("dst_channels_alignedx8", + AlignByN(attr.weights.shape.o, 8)); + desc->input_buffers = { {input_id, "device FLT4* const vector"}, }; @@ -174,19 +172,6 @@ std::vector FullyConnected( attr.weights.shape.o)}, }; - desc->uniform_buffers = { - {"constant uniforms& params", - [attr](const std::map& buffers) { - std::vector uniform_params{ - static_cast(DivideRoundUp(attr.weights.shape.i, 4)), - static_cast(AlignByN(attr.weights.shape.o, 8)), - static_cast(attr.weights.shape.o), - static_cast(0), - }; - return GetByteBuffer(uniform_params); - }}, - }; - desc->resize_function = [attr](const std::map& buffers) { const uint3 groups_size{8, 4, 1}; const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 8); diff --git a/tensorflow/lite/delegates/gpu/metal/metal_arguments.h b/tensorflow/lite/delegates/gpu/metal/metal_arguments.h new file mode 100644 index 00000000000..496287c8ff0 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/metal_arguments.h @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
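The new `Arguments`/`MetalArguments` pair replaces hand-written `uniforms` structs such as the one deleted from `fully_connected.cc`: a kernel declares scalar arguments on its descriptor and refers to them as `args.<name>` in the Metal source; `MetalArguments::Init` then detects which names are actually used, rewrites `args.x` to `U.x`, emits a `uniforms_buffer` struct padded to a multiple of four 32-bit values, and `Encode` binds the packed bytes with `setBytes` at the first buffer index after the input/output, uniform and immutable buffers. A rough sketch of the authoring side, modeled on the `fully_connected.cc` change; the shader body and the values passed to `AddInt` are placeholders:

```c++
#include <memory>

#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"

// Sketch: declare scalar kernel arguments instead of a manual uniform struct.
// The $0/$1 slots are later filled with the generated struct declaration and
// buffer bindings (FuseChain re-inserts them so MetalArguments::Init can
// substitute its own text).
tflite::gpu::metal::ComputeTaskDescriptorPtr MakeExampleTask() {
  auto desc = std::make_shared<tflite::gpu::metal::ComputeTaskDescriptor>();
  desc->shader_source = R"(
    #include <metal_stdlib>
    using namespace metal;
    $0
    kernel void ComputeFunction($1 uint3 gid[[thread_position_in_grid]]) {
      if (int(gid.x) >= args.dst_channels) return;  // becomes U.dst_channels
    }
  )";
  desc->args.AddInt("dst_channels", 64);
  desc->args.AddInt("src_slices", 16);  // not referenced above, so Init leaves it inactive
  return desc;
}
```

The `ArgumentsSetter` interface (`SetInt`/`SetFloat`) lets the packed values be updated after compilation without regenerating the shader, which is why each active argument records a byte offset into the shared storage.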
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_ARGUMENTS_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_ARGUMENTS_H_ + +#import + +#include +#include +#include + +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/metal/arguments.h" + +namespace tflite { +namespace gpu { +namespace metal { + +class MetalArguments : public ArgumentsSetter { + public: + MetalArguments() = default; + + absl::Status Init(int buffer_offset, Arguments* args, std::string* code); + + // Move only + MetalArguments(MetalArguments&& args) = default; + MetalArguments& operator=(MetalArguments&& args) = default; + MetalArguments(const MetalArguments&) = delete; + MetalArguments& operator=(const MetalArguments&) = delete; + + absl::Status SetInt(const std::string& name, int value) override; + absl::Status SetFloat(const std::string& name, float value) override; + + void Encode(id encoder, int buffer_offset) const; + + private: + static constexpr char kArgsPrefix[] = "args."; + struct IntValue { + int value; + + // many arguments generated automatically and not used + // to reduce amount of data transferred we adding this optimization + bool active = false; + + // offset to shared storage. + uint32_t bytes_offset = -1; + }; + std::map int_values_; + + struct FloatValue { + float value; + + // many arguments generated automatically and not used + // to reduce amount of data transferred we adding this optimization + bool active = false; + + // offset to shared storage. + uint32_t bytes_offset = -1; + }; + std::map float_values_; + std::vector const_data_; +}; + +} // namespace metal +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_ARGUMENTS_H_ diff --git a/tensorflow/lite/delegates/gpu/metal/metal_arguments.mm b/tensorflow/lite/delegates/gpu/metal/metal_arguments.mm new file mode 100644 index 00000000000..421e4a98db7 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/metal/metal_arguments.mm @@ -0,0 +1,137 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h" + +#include + +#include "absl/strings/substitute.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" + +namespace tflite { +namespace gpu { +namespace metal { +namespace { +bool IsWordSymbol(char symbol) { + return absl::ascii_isalnum(symbol) || symbol == '_'; +} + +void ReplaceAllWords(const std::string& old_word, const std::string& new_word, + std::string* str) { + size_t position = str->find(old_word); + while (position != std::string::npos) { + char prev = position == 0 ? '.' : (*str)[position - 1]; + char next = position + old_word.size() < str->size() + ? 
(*str)[position + old_word.size()] + : '.'; + if (IsWordSymbol(prev) || IsWordSymbol(next)) { + position = str->find(old_word, position + 1); + continue; + } + str->replace(position, old_word.size(), new_word); + position = str->find(old_word, position + new_word.size()); + } +} +} // namespace + +// Static +constexpr char MetalArguments::kArgsPrefix[]; + +absl::Status MetalArguments::Init(int buffer_offset, Arguments* args, std::string* code) { + args->GetActiveArguments(*code); + std::string struct_desc = "struct uniforms_buffer {\n"; + std::string struct_decl; + int pos = 0; + for (auto& fvalue : args->float_values_) { + auto& new_val = float_values_[fvalue.first]; + new_val.value = fvalue.second.value; + new_val.active = fvalue.second.active; + if (fvalue.second.active) { + new_val.bytes_offset = pos * 4; + pos++; + struct_desc += " float " + fvalue.first + ";\n"; + ReplaceAllWords(kArgsPrefix + fvalue.first, "U." + fvalue.first, code); + } + } + for (auto& ivalue : args->int_values_) { + auto& new_val = int_values_[ivalue.first]; + new_val.value = ivalue.second.value; + new_val.active = ivalue.second.active; + if (ivalue.second.active) { + new_val.bytes_offset = pos * 4; + pos++; + struct_desc += " int " + ivalue.first + ";\n"; + ReplaceAllWords(kArgsPrefix + ivalue.first, "U." + ivalue.first, code); + } + } + if (pos != 0) { + struct_decl = "constant uniforms_buffer& U[[buffer(" + std::to_string(buffer_offset) + ")]],\n"; + int aligned_pos = AlignByN(pos, 4); + for (int i = pos; i < aligned_pos; i++) { + struct_desc += " int dummy" + std::to_string(i - pos) + ";\n"; + } + struct_desc += "};"; + const_data_.resize(aligned_pos * 4); + for (auto& it : float_values_) { + float* ptr = reinterpret_cast(&const_data_[it.second.bytes_offset]); + *ptr = it.second.value; + } + for (auto& it : int_values_) { + int32_t* ptr = reinterpret_cast(&const_data_[it.second.bytes_offset]); + *ptr = it.second.value; + } + } else { + struct_desc = ""; + struct_decl = ""; + } + *code = absl::Substitute(*code, struct_desc, struct_decl); + return absl::OkStatus(); +} + +absl::Status MetalArguments::SetInt(const std::string& name, int value) { + auto it = int_values_.find(name); + if (it == int_values_.end()) { + return absl::NotFoundError( + absl::StrCat("No int argument with name - ", name)); + } + it->second.value = value; + if (it->second.active) { + int32_t* ptr = reinterpret_cast(&const_data_[it->second.bytes_offset]); + *ptr = value; + } + return absl::OkStatus(); +} +absl::Status MetalArguments::SetFloat(const std::string& name, float value) { + auto it = float_values_.find(name); + if (it == float_values_.end()) { + return absl::NotFoundError( + absl::StrCat("No float argument with name - ", name)); + } + it->second.value = value; + if (it->second.active) { + float* ptr = reinterpret_cast(&const_data_[it->second.bytes_offset]); + *ptr = value; + } + return absl::OkStatus(); +} + +void MetalArguments::Encode(id encoder, int buffer_offset) const { + if (!const_data_.empty()) { + [encoder setBytes:const_data_.data() length:const_data_.size() atIndex:buffer_offset]; + } +} + +} // namespace metal +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/metal_delegate.h b/tensorflow/lite/delegates/gpu/metal_delegate.h index e4bdba36799..ea9da126954 100644 --- a/tensorflow/lite/delegates/gpu/metal_delegate.h +++ b/tensorflow/lite/delegates/gpu/metal_delegate.h @@ -47,6 +47,7 @@ typedef struct { bool allow_precision_loss; TFLGpuDelegateWaitType wait_type; // Allows execution of 
integer quantized models + // TODO(b/169350710): Enable by default. bool enable_quantization; } TFLGpuDelegateOptions; diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index f9825f62773..e96dfdd187b 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -77,6 +77,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@FP16", "@com_google_googletest//:gtest", "@flatbuffers", @@ -94,6 +95,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@FP16", "@com_google_googletest//:gtest", "@flatbuffers", @@ -111,6 +113,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@FP16", "@com_google_googletest//:gtest", "@flatbuffers", @@ -128,6 +131,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@FP16", "@com_google_googletest//:gtest", "@flatbuffers", @@ -145,6 +149,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -161,6 +166,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -177,6 +183,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -193,6 +200,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@FP16", "@com_google_googletest//:gtest", "@flatbuffers", @@ -210,6 +218,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -226,6 +235,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -242,6 +252,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -258,6 +269,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -274,6 +286,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_googletest//:gtest", "@flatbuffers", ], diff --git a/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.cc 
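For the Metal (Objective-C) delegate the corresponding switch is the `enable_quantization` field documented above in `TFLGpuDelegateOptions`; per the TODO it is not yet on by default there. A hedged sketch of opting in through that C API (the wait type chosen here is only an example):

```c++
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"

// Sketch: allow execution of integer-quantized models on the Metal delegate.
TfLiteDelegate* MakeQuantEnabledMetalDelegate() {
  TFLGpuDelegateOptions options;
  options.allow_precision_loss = false;
  options.wait_type = TFLGpuDelegateWaitTypePassive;
  options.enable_quantization = true;  // still opt-in, see TODO(b/169350710)
  return TFLGpuDelegateCreate(&options);  // release with TFLGpuDelegateDelete
}
```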
b/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.cc index 1ba48c3c0e5..bc18c76f7eb 100644 --- a/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/version.h" namespace tflite { diff --git a/tensorflow/lite/delegates/xnnpack/conv_2d_tester.cc b/tensorflow/lite/delegates/xnnpack/conv_2d_tester.cc index f5a5f809993..b81928717f3 100644 --- a/tensorflow/lite/delegates/xnnpack/conv_2d_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/conv_2d_tester.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/version.h" namespace tflite { diff --git a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc index d846dcf9929..ca40e89375a 100644 --- a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/version.h" namespace tflite { diff --git a/tensorflow/lite/delegates/xnnpack/fully_connected_tester.cc b/tensorflow/lite/delegates/xnnpack/fully_connected_tester.cc index ff3e974a4e4..0ea0580bd79 100644 --- a/tensorflow/lite/delegates/xnnpack/fully_connected_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/fully_connected_tester.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/version.h" namespace tflite { diff --git a/tensorflow/lite/delegates/xnnpack/leaky_relu_tester.cc b/tensorflow/lite/delegates/xnnpack/leaky_relu_tester.cc index e830760a2f9..c44143eb18a 100644 --- a/tensorflow/lite/delegates/xnnpack/leaky_relu_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/leaky_relu_tester.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/version.h" namespace tflite { diff --git a/tensorflow/lite/delegates/xnnpack/pad_tester.cc b/tensorflow/lite/delegates/xnnpack/pad_tester.cc index e9688188d9f..0365fd4cbd5 100644 --- a/tensorflow/lite/delegates/xnnpack/pad_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/pad_tester.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/version.h" namespace tflite { diff --git a/tensorflow/lite/delegates/xnnpack/pool_2d_tester.cc b/tensorflow/lite/delegates/xnnpack/pool_2d_tester.cc index 6f7993b0df4..bb6e8be7b7d 100644 --- a/tensorflow/lite/delegates/xnnpack/pool_2d_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/pool_2d_tester.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/version.h" namespace tflite { diff --git a/tensorflow/lite/delegates/xnnpack/prelu_tester.cc b/tensorflow/lite/delegates/xnnpack/prelu_tester.cc index 01361075c1f..cee690e4dbd 100644 --- a/tensorflow/lite/delegates/xnnpack/prelu_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/prelu_tester.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/version.h" namespace tflite { diff --git a/tensorflow/lite/delegates/xnnpack/reduce_tester.cc b/tensorflow/lite/delegates/xnnpack/reduce_tester.cc index f9db35e6e28..9628dbcc1d4 100644 --- a/tensorflow/lite/delegates/xnnpack/reduce_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/reduce_tester.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/version.h" namespace tflite { diff --git a/tensorflow/lite/delegates/xnnpack/reshape_tester.cc b/tensorflow/lite/delegates/xnnpack/reshape_tester.cc index 6e16c9fe1c0..f44e3bd9ea3 100644 --- a/tensorflow/lite/delegates/xnnpack/reshape_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/reshape_tester.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/version.h" namespace tflite { diff --git a/tensorflow/lite/delegates/xnnpack/resize_bilinear_tester.cc b/tensorflow/lite/delegates/xnnpack/resize_bilinear_tester.cc index 52f8921391a..df11d740696 100644 --- a/tensorflow/lite/delegates/xnnpack/resize_bilinear_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/resize_bilinear_tester.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/version.h" namespace tflite { diff --git a/tensorflow/lite/delegates/xnnpack/softmax_tester.cc b/tensorflow/lite/delegates/xnnpack/softmax_tester.cc index e3636a9e960..4ec916db17f 100644 --- a/tensorflow/lite/delegates/xnnpack/softmax_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/softmax_tester.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/version.h" namespace tflite { diff --git a/tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.cc b/tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.cc index 4b34d80d82b..df22ae2db7a 100644 --- a/tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/version.h" namespace tflite { diff --git a/tensorflow/lite/examples/label_image/BUILD b/tensorflow/lite/examples/label_image/BUILD index 633f767c5e9..e7d2d1a9672 100644 --- a/tensorflow/lite/examples/label_image/BUILD +++ b/tensorflow/lite/examples/label_image/BUILD @@ -29,16 +29,20 @@ cc_binary( }), deps = [ ":bitmap_helpers", - "//tensorflow/lite/c:common", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "//tensorflow/lite:framework", "//tensorflow/lite:string_util", + "//tensorflow/lite/c:common", "//tensorflow/lite/delegates/nnapi:nnapi_delegate", "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/profiling:profiler", + "//tensorflow/lite/tools:command_line_flags", + "//tensorflow/lite/tools:tool_params", + "//tensorflow/lite/tools/delegates:delegate_provider_hdr", + "//tensorflow/lite/tools/delegates:tflite_execution_providers", "//tensorflow/lite/tools/evaluation:utils", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", ] + select({ "//tensorflow:android": [ "//tensorflow/lite/delegates/gpu:delegate", diff --git a/tensorflow/lite/examples/label_image/README.md b/tensorflow/lite/examples/label_image/README.md index 09e21b5d0c6..c13dffe1e3b 100644 --- a/tensorflow/lite/examples/label_image/README.md +++ b/tensorflow/lite/examples/label_image/README.md @@ -222,3 +222,20 @@ label_image ``` See the `label_image.cc` source code for other command line options. + +Note that this binary also supports runtime/delegate arguments introduced by the +[delegate registrar](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/delegates). +If there is any conflict, the arguments mentioned earlier are given precedence. +For example, you can run the binary with additional command line options +such as `--use_nnapi=true --nnapi_accelerator_name=google-edgetpu` to utilize +the EdgeTPU in a 4th-gen Pixel phone. Please be aware that the "=" in the option +should not be omitted. + +``` +adb shell \ + "/data/local/tmp/label_image \ + -m /data/local/tmp/mobilenet_v1_1.0_224_quant.tflite \ + -i /data/local/tmp/grace_hopper.bmp \ + -l /data/local/tmp/labels.txt -j 1 \ + --use_nnapi=true --nnapi_accelerator_name=google-edgetpu" +``` diff --git a/tensorflow/lite/examples/label_image/label_image.cc b/tensorflow/lite/examples/label_image/label_image.cc index d0761c07089..4f6bcb4573c 100644 --- a/tensorflow/lite/examples/label_image/label_image.cc +++ b/tensorflow/lite/examples/label_image/label_image.cc @@ -44,6 +44,8 @@ limitations under the License. 
#include "tensorflow/lite/optional_debug_tools.h" #include "tensorflow/lite/profiling/profiler.h" #include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/tools/command_line_flags.h" +#include "tensorflow/lite/tools/delegates/delegate_provider.h" #include "tensorflow/lite/tools/evaluation/utils.h" #if defined(__ANDROID__) @@ -60,6 +62,56 @@ double get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); } using TfLiteDelegatePtr = tflite::Interpreter::TfLiteDelegatePtr; using TfLiteDelegatePtrMap = std::map; +class DelegateProviders { + public: + DelegateProviders() + : delegates_list_(tflite::tools::GetRegisteredDelegateProviders()) { + for (const auto& delegate : delegates_list_) { + params_.Merge(delegate->DefaultParams()); + } + } + + // Initialize delegate-related parameters from parsing command line arguments, + // and remove the matching arguments from (*argc, argv). Returns true if all + // recognized arg values are parsed correctly. + bool InitFromCmdlineArgs(int* argc, const char** argv) { + std::vector flags; + for (const auto& delegate : delegates_list_) { + auto delegate_flags = delegate->CreateFlags(¶ms_); + flags.insert(flags.end(), delegate_flags.begin(), delegate_flags.end()); + } + + const bool parse_result = Flags::Parse(argc, argv, flags); + if (!parse_result) { + std::string usage = Flags::Usage(argv[0], flags); + LOG(ERROR) << usage; + } + return parse_result; + } + + // Create a list of TfLite delegates based on what have been initialized (i.e. + // 'params_'). + TfLiteDelegatePtrMap CreateAllDelegates() const { + TfLiteDelegatePtrMap delegates_map; + for (const auto& delegate : delegates_list_) { + auto ptr = delegate->CreateTfLiteDelegate(params_); + // It's possible that a delegate of certain type won't be created as + // user-specified benchmark params tells not to. + if (ptr == nullptr) continue; + LOG(INFO) << delegate->GetName() << " delegate created.\n"; + delegates_map.emplace(delegate->GetName(), std::move(ptr)); + } + return delegates_map; + } + + private: + // Contain delegate-related parameters that are initialized from command-line + // flags. + tflite::tools::ToolParams params_; + + const tflite::tools::DelegateProviderList& delegates_list_; +}; + TfLiteDelegatePtr CreateGPUDelegate(Settings* s) { #if defined(__ANDROID__) TfLiteGpuDelegateOptionsV2 gpu_opts = TfLiteGpuDelegateOptionsV2Default(); @@ -74,7 +126,10 @@ TfLiteDelegatePtr CreateGPUDelegate(Settings* s) { #endif } -TfLiteDelegatePtrMap GetDelegates(Settings* s) { +TfLiteDelegatePtrMap GetDelegates(Settings* s, + const DelegateProviders& delegate_providers) { + // TODO(b/169681115): deprecate delegate creation path based on "Settings" by + // mapping settings to DelegateProvider's parameters. TfLiteDelegatePtrMap delegates; if (s->gl_backend) { auto delegate = CreateGPUDelegate(s); @@ -117,6 +172,18 @@ TfLiteDelegatePtrMap GetDelegates(Settings* s) { } } + // Independent of above delegate creation options that are specific to this + // binary, we use delegate providers to create TFLite delegates. Delegate + // providers have been used in TFLite benchmark/evaluation tools and testing + // so that we have a single and more comprehensive set of command line + // arguments for delegate creation. 
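The `DelegateProviders` helper added to `label_image.cc` reuses the delegate-provider machinery from the benchmark/evaluation tools, so the binary now understands the same `--use_nnapi`, `--use_xnnpack`, etc. flags; delegates created this way are keyed with a `Delegate_Provider_` prefix so their map keys do not collide with the delegates this binary already builds from `Settings`. The intended wiring in `Main()`, sketched with the types defined in this same file (error handling trimmed):

```c++
#include <cstdlib>

// Sketch: provider flags are parsed and stripped from argv first, then the
// remaining options go through label_image's own getopt loop.
int Main(int argc, char** argv) {
  DelegateProviders delegate_providers;
  if (!delegate_providers.InitFromCmdlineArgs(
          &argc, const_cast<const char**>(argv))) {
    return EXIT_FAILURE;
  }
  Settings settings;
  // ... existing getopt-based parsing fills `settings` ...
  RunInference(&settings, delegate_providers);
  return 0;
}
```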
+ TfLiteDelegatePtrMap delegates_from_providers = + delegate_providers.CreateAllDelegates(); + for (auto& name_and_delegate : delegates_from_providers) { + delegates.emplace("Delegate_Provider_" + name_and_delegate.first, + std::move(name_and_delegate.second)); + } + return delegates; } @@ -162,21 +229,22 @@ void PrintProfilingInfo(const profiling::ProfileEvent* e, << "\n"; } -void RunInference(Settings* s) { - if (!s->model_name.c_str()) { +void RunInference(Settings* settings, + const DelegateProviders& delegate_providers) { + if (!settings->model_name.c_str()) { LOG(ERROR) << "no model file name\n"; exit(-1); } std::unique_ptr model; std::unique_ptr interpreter; - model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str()); + model = tflite::FlatBufferModel::BuildFromFile(settings->model_name.c_str()); if (!model) { - LOG(ERROR) << "\nFailed to mmap model " << s->model_name << "\n"; + LOG(ERROR) << "\nFailed to mmap model " << settings->model_name << "\n"; exit(-1); } - s->model = model.get(); - LOG(INFO) << "Loaded model " << s->model_name << "\n"; + settings->model = model.get(); + LOG(INFO) << "Loaded model " << settings->model_name << "\n"; model->error_reporter(); LOG(INFO) << "resolved reporter\n"; @@ -188,9 +256,9 @@ void RunInference(Settings* s) { exit(-1); } - interpreter->SetAllowFp16PrecisionForFp32(s->allow_fp16); + interpreter->SetAllowFp16PrecisionForFp32(settings->allow_fp16); - if (s->verbose) { + if (settings->verbose) { LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n"; LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n"; LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n"; @@ -207,28 +275,28 @@ void RunInference(Settings* s) { } } - if (s->number_of_threads != -1) { - interpreter->SetNumThreads(s->number_of_threads); + if (settings->number_of_threads != -1) { + interpreter->SetNumThreads(settings->number_of_threads); } int image_width = 224; int image_height = 224; int image_channels = 3; - std::vector in = read_bmp(s->input_bmp_name, &image_width, - &image_height, &image_channels, s); + std::vector in = read_bmp(settings->input_bmp_name, &image_width, + &image_height, &image_channels, settings); int input = interpreter->inputs()[0]; - if (s->verbose) LOG(INFO) << "input: " << input << "\n"; + if (settings->verbose) LOG(INFO) << "input: " << input << "\n"; const std::vector inputs = interpreter->inputs(); const std::vector outputs = interpreter->outputs(); - if (s->verbose) { + if (settings->verbose) { LOG(INFO) << "number of inputs: " << inputs.size() << "\n"; LOG(INFO) << "number of outputs: " << outputs.size() << "\n"; } - auto delegates_ = GetDelegates(s); + auto delegates_ = GetDelegates(settings, delegate_providers); for (const auto& delegate : delegates_) { if (interpreter->ModifyGraphWithDelegate(delegate.second.get()) != kTfLiteOk) { @@ -244,7 +312,7 @@ void RunInference(Settings* s) { exit(-1); } - if (s->verbose) PrintInterpreterState(interpreter.get()); + if (settings->verbose) PrintInterpreterState(interpreter.get()); // get input dimension from the input tensor metadata // assuming one input only @@ -253,44 +321,45 @@ void RunInference(Settings* s) { int wanted_width = dims->data[2]; int wanted_channels = dims->data[3]; - s->input_type = interpreter->tensor(input)->type; - switch (s->input_type) { + settings->input_type = interpreter->tensor(input)->type; + switch (settings->input_type) { case kTfLiteFloat32: resize(interpreter->typed_tensor(input), in.data(), image_height, image_width, 
image_channels, wanted_height, - wanted_width, wanted_channels, s); + wanted_width, wanted_channels, settings); break; case kTfLiteInt8: resize(interpreter->typed_tensor(input), in.data(), image_height, image_width, image_channels, wanted_height, - wanted_width, wanted_channels, s); + wanted_width, wanted_channels, settings); break; case kTfLiteUInt8: resize(interpreter->typed_tensor(input), in.data(), image_height, image_width, image_channels, wanted_height, - wanted_width, wanted_channels, s); + wanted_width, wanted_channels, settings); break; default: LOG(ERROR) << "cannot handle input type " << interpreter->tensor(input)->type << " yet\n"; exit(-1); } - auto profiler = - absl::make_unique(s->max_profiling_buffer_entries); + auto profiler = absl::make_unique( + settings->max_profiling_buffer_entries); interpreter->SetProfiler(profiler.get()); - if (s->profiling) profiler->StartProfiling(); - if (s->loop_count > 1) - for (int i = 0; i < s->number_of_warmup_runs; i++) { + if (settings->profiling) profiler->StartProfiling(); + if (settings->loop_count > 1) { + for (int i = 0; i < settings->number_of_warmup_runs; i++) { if (interpreter->Invoke() != kTfLiteOk) { LOG(ERROR) << "Failed to invoke tflite!\n"; exit(-1); } } + } struct timeval start_time, stop_time; gettimeofday(&start_time, nullptr); - for (int i = 0; i < s->loop_count; i++) { + for (int i = 0; i < settings->loop_count; i++) { if (interpreter->Invoke() != kTfLiteOk) { LOG(ERROR) << "Failed to invoke tflite!\n"; exit(-1); @@ -299,10 +368,11 @@ void RunInference(Settings* s) { gettimeofday(&stop_time, nullptr); LOG(INFO) << "invoked\n"; LOG(INFO) << "average time: " - << (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000) + << (get_us(stop_time) - get_us(start_time)) / + (settings->loop_count * 1000) << " ms \n"; - if (s->profiling) { + if (settings->profiling) { profiler->StopProfiling(); auto profile_events = profiler->GetProfileEvents(); for (int i = 0; i < profile_events.size(); i++) { @@ -328,18 +398,18 @@ void RunInference(Settings* s) { switch (interpreter->tensor(output)->type) { case kTfLiteFloat32: get_top_n(interpreter->typed_output_tensor(0), output_size, - s->number_of_results, threshold, &top_results, - s->input_type); + settings->number_of_results, threshold, &top_results, + settings->input_type); break; case kTfLiteInt8: get_top_n(interpreter->typed_output_tensor(0), - output_size, s->number_of_results, threshold, - &top_results, s->input_type); + output_size, settings->number_of_results, threshold, + &top_results, settings->input_type); break; case kTfLiteUInt8: get_top_n(interpreter->typed_output_tensor(0), - output_size, s->number_of_results, threshold, - &top_results, s->input_type); + output_size, settings->number_of_results, threshold, + &top_results, settings->input_type); break; default: LOG(ERROR) << "cannot handle output type " @@ -350,7 +420,8 @@ void RunInference(Settings* s) { std::vector labels; size_t label_count; - if (ReadLabelsFile(s->labels_file_name, &labels, &label_count) != kTfLiteOk) + if (ReadLabelsFile(settings->labels_file_name, &labels, &label_count) != + kTfLiteOk) exit(-1); for (const auto& result : top_results) { @@ -383,6 +454,13 @@ void display_usage() { } int Main(int argc, char** argv) { + DelegateProviders delegate_providers; + bool parse_result = delegate_providers.InitFromCmdlineArgs( + &argc, const_cast(argv)); + if (!parse_result) { + return EXIT_FAILURE; + } + Settings s; int c; @@ -476,7 +554,8 @@ int Main(int argc, char** argv) { strtol(optarg, nullptr, 10); // 
NOLINT(runtime/deprecated_fn) break; case 'x': - s.xnnpack_delegate = optarg; + s.xnnpack_delegate = + strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'h': case '?': @@ -487,7 +566,7 @@ int Main(int argc, char** argv) { exit(-1); } } - RunInference(&s); + RunInference(&s, delegate_providers); return 0; } diff --git a/tensorflow/lite/experimental/microfrontend/lib/fft.cc b/tensorflow/lite/experimental/microfrontend/lib/fft.cc index 72779f5d60b..22c6ae93455 100644 --- a/tensorflow/lite/experimental/microfrontend/lib/fft.cc +++ b/tensorflow/lite/experimental/microfrontend/lib/fft.cc @@ -38,10 +38,9 @@ void FftCompute(struct FftState* state, const int16_t* input, } // Apply the FFT. - kiss_fftr( - reinterpret_cast(state->scratch), - state->input, - reinterpret_cast(state->output)); + kiss_fftr(reinterpret_cast(state->scratch), + state->input, + reinterpret_cast(state->output)); } void FftInit(struct FftState* state) { diff --git a/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm b/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm index ad8e5bad7c7..27275212d03 100644 --- a/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm +++ b/tensorflow/lite/experimental/objc/sources/TFLInterpreter.mm @@ -201,7 +201,7 @@ static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_ return NO; } - std::vector cDimensions(self.inputTensorCount); + std::vector cDimensions(shape.count); for (int dimIndex = 0; dimIndex < shape.count; ++dimIndex) { int dimension = shape[dimIndex].intValue; if (dimension <= 0) { diff --git a/tensorflow/lite/experimental/tflite_api_dispatcher/BUILD b/tensorflow/lite/experimental/tflite_api_dispatcher/BUILD index fff0a47c323..bb64be61599 100644 --- a/tensorflow/lite/experimental/tflite_api_dispatcher/BUILD +++ b/tensorflow/lite/experimental/tflite_api_dispatcher/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") + package( default_visibility = ["//tensorflow:internal"], licenses = ["notice"], # Apache 2.0 @@ -6,6 +8,7 @@ package( cc_library( name = "tflite_api_dispatcher", hdrs = ["tflite_api_dispatcher.h"], + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/lite:framework_lib", ], diff --git a/tensorflow/lite/experimental/writer/BUILD b/tensorflow/lite/experimental/writer/BUILD index 2792cc7ae2f..ceb11e204d8 100644 --- a/tensorflow/lite/experimental/writer/BUILD +++ b/tensorflow/lite/experimental/writer/BUILD @@ -34,6 +34,7 @@ cc_library( "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/c:common", "//tensorflow/lite/schema:schema_fbs_with_reflection", + "//tensorflow/lite/schema:schema_utils", ], ) diff --git a/tensorflow/lite/experimental/writer/option_writer_generator.cc b/tensorflow/lite/experimental/writer/option_writer_generator.cc index 898f4a95ef6..47550be2a21 100644 --- a/tensorflow/lite/experimental/writer/option_writer_generator.cc +++ b/tensorflow/lite/experimental/writer/option_writer_generator.cc @@ -163,8 +163,11 @@ class OpOptionData { op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; op_to_option_["MAXIMUM"] = "MaximumMinimumOptions"; op_to_option_["MINIMUM"] = "MaximumMinimumOptions"; + + // These operators are not real ones. op_to_option_["CUSTOM"] = ""; // TODO(aselle): maybe something else. op_to_option_["DELEGATE"] = ""; // TODO(aselle): maybe something else. 
+ op_to_option_["PLACEHOLDER_FOR_GREATER_OP_CODES"] = ""; // Manually specified mappings between ops to "none" options -- these are // ops without a corresponding Options message in schema as yet. If these diff --git a/tensorflow/lite/experimental/writer/writer_lib.cc b/tensorflow/lite/experimental/writer/writer_lib.cc index 2f509daa9cb..9f18fff76d5 100644 --- a/tensorflow/lite/experimental/writer/writer_lib.cc +++ b/tensorflow/lite/experimental/writer/writer_lib.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/experimental/writer/enum_mapping.h" #include "tensorflow/lite/schema/reflection/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/version.h" namespace tflite { diff --git a/tensorflow/lite/g3doc/convert/index.md b/tensorflow/lite/g3doc/convert/index.md index 87cc92826db..36774205770 100644 --- a/tensorflow/lite/g3doc/convert/index.md +++ b/tensorflow/lite/g3doc/convert/index.md @@ -161,17 +161,26 @@ with open('model.tflite', 'wb') as f: and then [create the TensorFlow Lite operator](../guide/ops_custom.md). If you were unsuccessful at creating the TensorFlow operator or don't wish to create one (**not recommended, proceed with caution**), you can - still convert using the `custom_opdefs` attribute and then directly - [create the TensorFlow Lite operator](../guide/ops_custom.md). The - `custom_opdefs` attribute is a string containing an (or a list of) + still convert using the `register_custom_opdefs` method and then + directly [create the TensorFlow Lite operator](../guide/ops_custom.md). + The `register_custom_opdefs` method takes a list of a string containing + an [OpDef](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/op_def.proto) - (s) or operator definition proto(s). Below is an example of a - `TFLiteAwesomeCustomOp` with 1 input, 1 output, and 2 attributes: + (s). Below is an example of a `TFLiteAwesomeCustomOp` with 1 input, 1 + output, and 2 attributes: ```python - converter.custom_opdefs="""name: 'TFLiteAwesomeCustomOp' input_arg: + import tensorflow as tf + + custom_opdef = """name: 'TFLiteAwesomeCustomOp' input_arg: { name: 'In' type: DT_FLOAT } output_arg: { name: 'Out' type: DT_FLOAT } attr : { name: 'a1' type: 'float'} attr : { name: 'a2' type: 'list(float)'}""" + + # Register custom opdefs before the invocation of converter API. + tf.lite.python.convert.register_custom_opdefs([custom_opdef]) + + converter = tf.lite.TFLiteConverter.from_saved_model(...) + converter.allow_custom_ops = True ``` ## Command Line Tool diff --git a/tensorflow/lite/g3doc/guide/build_android.md b/tensorflow/lite/g3doc/guide/build_android.md index 792c609bc0e..32fc6f1facc 100644 --- a/tensorflow/lite/g3doc/guide/build_android.md +++ b/tensorflow/lite/g3doc/guide/build_android.md @@ -99,7 +99,7 @@ prompt. 
Successful configuration should yield entries similar to the following in the `.tf_configure.bazelrc` file in the root folder: ```shell -build --action_env ANDROID_NDK_HOME="/usr/local/android/android-ndk-r17c" +build --action_env ANDROID_NDK_HOME="/usr/local/android/android-ndk-r18b" build --action_env ANDROID_NDK_API_LEVEL="21" build --action_env ANDROID_BUILD_TOOLS_VERSION="28.0.3" build --action_env ANDROID_SDK_API_LEVEL="23" diff --git a/tensorflow/lite/g3doc/guide/get_started.md b/tensorflow/lite/g3doc/guide/get_started.md index df206e73416..2dff4b54c59 100644 --- a/tensorflow/lite/g3doc/guide/get_started.md +++ b/tensorflow/lite/g3doc/guide/get_started.md @@ -274,7 +274,7 @@ import tensorflow as tf converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) converter.optimizations = [tf.lite.Optimize.DEFAULT] -tflite_quant_model = converter.convert() +tflite_quantized_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_quantized_model) ``` diff --git a/tensorflow/lite/g3doc/guide/inference.md b/tensorflow/lite/g3doc/guide/inference.md index e5e65fac069..335720f19c0 100644 --- a/tensorflow/lite/g3doc/guide/inference.md +++ b/tensorflow/lite/g3doc/guide/inference.md @@ -46,7 +46,8 @@ TensorFlow Lite inference typically follows the following steps: ## Supported platforms TensorFlow inference APIs are provided for most common mobile/embedded platforms -such as Android, iOS and Linux, in multiple programming languages. +such as [Android](#android-platform), [iOS](#ios-platform) and +[Linux](#linux-platform), in multiple programming languages. In most cases, the API design reflects a preference for performance over ease of use. TensorFlow Lite is designed for fast inference on small devices, so it @@ -57,20 +58,15 @@ explicit goal and some variance between languages is to be expected. Across all libraries, the TensorFlow Lite API enables you to load models, feed inputs, and retrieve inference outputs. -## Supported operations - -TensorFlow Lite supports a subset of TensorFlow operations with some -limitations. For full list of operations and limitations see -[TF Lite Ops page](https://www.tensorflow.org/mlir/tfl_ops). - -### Android +### Android Platform On Android, TensorFlow Lite inference can be performed using either Java or C++ APIs. The Java APIs provide convenience and can be used directly within your Android Activity classes. The C++ APIs offer more flexibility and speed, but may require writing JNI wrappers to move data between Java and C++ layers. -See below for details about using C++ and Java, or follow the +See below for details about using [C++](#load-and-run-a-model-in-c) and +[Java](#load-and-run-a-model-in-java), or follow the [Android quickstart](android.md) for a tutorial and example code. #### TensorFlow Lite Android wrapper code generator @@ -86,7 +82,7 @@ TensorFlow Lite model with typed objects such as `Bitmap` and `Rect`. For more information, please refer to the [TensorFlow Lite Android wrapper code generator](../inference_with_metadata/codegen.md). -### iOS +### iOS Platform On iOS, TensorFlow Lite is available with native iOS libraries written in [Swift](https://www.tensorflow.org/code/tensorflow/lite/experimental/swift) @@ -96,14 +92,17 @@ You can also use [C API](https://www.tensorflow.org/code/tensorflow/lite/c/c_api.h) directly in Objective-C codes. 
-See below for details about using Swift, Objective-C and C API, or follow the +See below for details about using [Swift](#load-and-run-a-model-in-swift), +[Objective-C](#load-and-run-a-model-in-objective-c) and the +[C API](#using-c-api-in-objective-c-code), or follow the [iOS quickstart](ios.md) for a tutorial and example code. -### Linux +### Linux Platform On Linux platforms (including [Raspberry Pi](build_rpi.md)), you can run -inferences using TensorFlow Lite APIs available in C++ and Python, as shown in -the following sections. +inferences using TensorFlow Lite APIs available in +[C++](#load-and-run-a-model-in-c) and [Python](#load-and-run-a-model-in-python), +as shown in the following sections. ## Running a model @@ -523,103 +522,8 @@ For more Python sample code, see Tip: Run `help(tf.lite.Interpreter)` in the Python terminal to get detailed documentation about the interpreter. -## Write a custom operator +## Supported operations -All TensorFlow Lite operators (both custom and builtin) are defined using a -simple pure-C interface that consists of four functions: - -```c++ -typedef struct { - void* (*init)(TfLiteContext* context, const char* buffer, size_t length); - void (*free)(TfLiteContext* context, void* buffer); - TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node); - TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); -} TfLiteRegistration; -``` - -Refer to `context.h` for details on `TfLiteContext` and `TfLiteNode`. The former -provides error reporting facilities and access to global objects, including all -the tensors. The latter allows implementations to access their inputs and -outputs. - -When the interpreter loads a model, it calls `init()` once for each node in the -graph. A given `init()` will be called more than once if the op is used multiple -times in the graph. For custom ops a configuration buffer will be provided, -containing a flexbuffer that maps parameter names to their values. The buffer is -empty for builtin ops because the interpreter has already parsed the op -parameters. Kernel implementations that require state should initialize it here -and transfer ownership to the caller. For each `init()` call, there will be a -corresponding call to `free()`, allowing implementations to dispose of the -buffer they might have allocated in `init()`. - -Whenever the input tensors are resized, the interpreter will go through the -graph notifying implementations of the change. This gives them the chance to -resize their internal buffer, check validity of input shapes and types, and -recalculate output shapes. This is all done through `prepare()`, and -implementations can access their state using `node->user_data`. - -Finally, each time inference runs, the interpreter traverses the graph calling -`invoke()`, and here too the state is available as `node->user_data`. - -Custom ops can be implemented in exactly the same way as builtin ops, by defined -those four functions and a global registration function that usually looks like -this: - -```c++ -namespace tflite { -namespace ops { -namespace custom { - TfLiteRegistration* Register_MY_CUSTOM_OP() { - static TfLiteRegistration r = {my_custom_op::Init, - my_custom_op::Free, - my_custom_op::Prepare, - my_custom_op::Eval}; - return &r; - } -} // namespace custom -} // namespace ops -} // namespace tflite -``` - -Note that registration is not automatic and an explicit call to -`Register_MY_CUSTOM_OP` should be made somewhere. 
While the standard -`BuiltinOpResolver` (available from the `:builtin_ops` target) takes care of the -registration of builtins, custom ops will have to be collected in separate -custom libraries. - -### Customize the kernel library - -Behind the scenes the interpreter will load a library of kernels which will be -assigned to execute each of the operators in the model. While the default -library only contains builtin kernels, it is possible to replace it with a -custom library. - -The interpreter uses an `OpResolver` to translate operator codes and names into -actual code: - -```c++ -class OpResolver { - virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0; - virtual TfLiteRegistration* FindOp(const char* op) const = 0; - virtual void AddOp(tflite::BuiltinOperator op, TfLiteRegistration* registration) = 0; - virtual void AddOp(const char* op, TfLiteRegistration* registration) = 0; -}; -``` - -Regular usage requires that you use the `BuiltinOpResolver` and write: - -```c++ -tflite::ops::builtin::BuiltinOpResolver resolver; -``` - -You can optionally register custom ops (before you pass the resolver to the -`InterpreterBuilder`): - -```c++ -resolver.AddOp("MY_CUSTOM_OP", Register_MY_CUSTOM_OP()); -``` - -If the set of builtin ops is deemed to be too large, a new `OpResolver` could be -code-generated based on a given subset of ops, possibly only the ones contained -in a given model. This is the equivalent of TensorFlow's selective registration -(and a simple version of it is available in the `tools` directory). +TensorFlow Lite supports a subset of TensorFlow operations with some +limitations. For full list of operations and limitations see +[TF Lite Ops page](https://www.tensorflow.org/mlir/tfl_ops). diff --git a/tensorflow/lite/g3doc/guide/ops_custom.md b/tensorflow/lite/g3doc/guide/ops_custom.md index 5cd43e2b32f..0c229881984 100644 --- a/tensorflow/lite/g3doc/guide/ops_custom.md +++ b/tensorflow/lite/g3doc/guide/ops_custom.md @@ -6,57 +6,179 @@ refer to [operator compatibility](ops_compatibility.md). To allow conversion, users can provide their own custom implementation of an unsupported TensorFlow operator in TensorFlow Lite, known as a custom operator. -*Instead, if you want to combine a series of unsupported (or supported) +*If instead, you wish to combine a series of unsupported (or supported) TensorFlow operators into a single fused optimized custom operator, refer to [operator fusing](https://www.tensorflow.org/lite/convert/operation_fusion).* -Using custom operators consists of three steps. +Using custom operators consists of four steps. -* Make sure the TensorFlow Graph Def or SavedModel refers to the correctly - named TensorFlow Lite operator. +* [Create a TensorFlow Model.](#create-a-tensorflow-model) Make sure the Saved + Model (or Graph Def) refers to the correctly named TensorFlow Lite operator. -* Register a custom kernel with TensorFlow Lite so that the runtime knows how - to map your operator and parameters in your graph to executable C/C++ code. +* [Convert to a TensorFlow Lite Model.](#convert-to-a-tensorflow-lite-model) + Make sure you set the right TensorFlow Lite converter attribute in order to + successfully convert the model. -* Test and profile your operator correctness and performance, respectively. 
If - you wish to test just your custom operator, it is best to create a model - with just your custom operator and using the - [benchmark model](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/benchmark/benchmark_model_test.cc) - code. +* [Create and register the operator.](#create-and-register-the-operator) This + is so that the TensorFlow Lite runtime knows how to map your operator and + parameters in your graph to executable C/C++ code. -Below we describe a complete example of defining `Sin` and some links to -existing conversion process involving custom operators. +* [Test and profile your operator.](#test-and-profile-your-operator) If you + wish to test just your custom operator, it is best to create a model with + just your custom operator and use the + [benchmark_model](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/benchmark/benchmark_model.cc) + program. -## Making a custom operator for Sin +Let’s walk through an end-to-end example of running a model with a custom +operator `tf.sin` (named as `Sin`, refer to #create-a-tensorflow-model) which is +supported in TensorFlow, but unsupported in TensorFlow Lite. + +Note: In reality, `tf.sin` is **not** a custom operator. It is regular operator +which is supported by both TensorFlow and TensorFlow Lite. But we **assume** +that it is a custom operator in the following example in order to demonstrate a +simple workflow. + +## Example: Custom `Sin` operator Let’s walk through an example of supporting a TensorFlow operator that TensorFlow Lite does not have. Assume we are using the `Sin` operator and that we are building a very simple model for a function `y = sin(x + offset)`, where `offset` is trainable. -### Generating the model from TensorFlow +### Create a TensorFlow Model -The code to train the TensorFlow model will be something like: +The following code snippet trains a simple TensorFlow model. This model just +contains a custom operator named `Sin`, which is a function `y = sin(x + +offset)`, where `offset` is trainable. 
```python -offset = tf.get_variable("offset", [1,], tf.float32) -x = tf.placeholder(tf.float32, shape=(None,)) -y = tf.sin(x + offset) -y_ = tf.placeholder(tf.float32, shape=(None,)) -loss = tf.reduce_sum(tf.square(y - y_)) -optimizer = tf.train.GradientDescentOptimizer(0.001) -train = optimizer.minimize(loss) +import tensorflow as tf + +# Define training dataset and variables +x = [-8, 0.5, 2, 2.2, 201] +y = [-0.6569866 , 0.99749499, 0.14112001, -0.05837414, 0.80641841] +offset = tf.Variable(0.0) + +# Define a simple model which just contains a custom operator named `Sin` +@tf.function +def sin(x): + return tf.sin(x + offset, name="Sin") + + # Train model +optimizer = tf.optimizers.Adam(0.01) +def train(x, y): + with tf.GradientTape() as t: + predicted_y = sin(x) + loss = tf.reduce_sum(tf.square(predicted_y - y)) + grads = t.gradient(loss, [offset]) + optimizer.apply_gradients(zip(grads, [offset])) + +for i in range(1000): + train(x, y) + +print("The actual offset is: 1.0") +print("The predicted offset is:", offset.numpy()) ``` -If you convert this model to TensorFlow Lite format using the TensorFlow Lite -Optimizing Converter with `--allow_custom_ops` argument, and run it with the -default interpreter, the interpreter will raise the following error messages: +```python +The actual offset is: 1.0 +The predicted offset is: 1.0000001 +``` + +At this point, if you try to generate a TensorFlow Lite model with the default +converter flags, you will get the following error message: ```none -Didn't find custom op for name 'Sin' +Error: +Some of the operators in the model are not supported by the standard TensorFlow +Lite runtime...... Here is +a list of operators for which you will need custom implementations: Sin. +``` + +### Convert to a TensorFlow Lite Model + +Create a TensorFlow Lite model with custom operators, by setting the converter +attribute `allow_custom_ops` as shown below: + +

+converter = tf.lite.TFLiteConverter.from_concrete_functions([sin.get_concrete_function(x)])
+converter.allow_custom_ops = True
+tflite_model = converter.convert()
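+
+# Illustrative extra step (not part of the original snippet): persist the
+# converted flatbuffer so an interpreter can load it from a file later. The
+# file name 'sin.tflite' is an arbitrary example.
+with open('sin.tflite', 'wb') as f:
+  f.write(tflite_model)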
+
+ +At this point, if you run it with the default interpreter, you will get the +following error messages: + +```none +Error: +Didn't find custom operator for name 'Sin' Registration failed. ``` +### Create and register the operator. + +All TensorFlow Lite operators (both custom and builtin) are defined using a +simple pure-C interface that consists of four functions: + +```c++ +typedef struct { + void* (*init)(TfLiteContext* context, const char* buffer, size_t length); + void (*free)(TfLiteContext* context, void* buffer); + TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node); + TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); +} TfLiteRegistration; +``` + +Refer to +[`common.h`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/c/common.h) +for details on `TfLiteContext` and `TfLiteNode`. The former provides error +reporting facilities and access to global objects, including all the tensors. +The latter allows implementations to access their inputs and outputs. + +When the interpreter loads a model, it calls `init()` once for each node in the +graph. A given `init()` will be called more than once if the op is used multiple +times in the graph. For custom ops a configuration buffer will be provided, +containing a flexbuffer that maps parameter names to their values. The buffer is +empty for builtin ops because the interpreter has already parsed the op +parameters. Kernel implementations that require state should initialize it here +and transfer ownership to the caller. For each `init()` call, there will be a +corresponding call to `free()`, allowing implementations to dispose of the +buffer they might have allocated in `init()`. + +Whenever the input tensors are resized, the interpreter will go through the +graph notifying implementations of the change. This gives them the chance to +resize their internal buffer, check validity of input shapes and types, and +recalculate output shapes. This is all done through `prepare()`, and +implementations can access their state using `node->user_data`. + +Finally, each time inference runs, the interpreter traverses the graph calling +`invoke()`, and here too the state is available as `node->user_data`. + +Custom ops can be implemented in exactly the same way as builtin ops, by +defining those four functions and a global registration function that usually +looks like this: + +```c++ +namespace tflite { +namespace ops { +namespace custom { + TfLiteRegistration* Register_MY_CUSTOM_OP() { + static TfLiteRegistration r = {my_custom_op::Init, + my_custom_op::Free, + my_custom_op::Prepare, + my_custom_op::Eval}; + return &r; + } +} // namespace custom +} // namespace ops +} // namespace tflite +``` + +Note that registration is not automatic and an explicit call to +`Register_MY_CUSTOM_OP` should be made. While the standard `BuiltinOpResolver` +(available from the `:builtin_ops` target) takes care of the registration of +builtins, custom ops will have to be collected in separate custom libraries. + ### Defining the kernel in the TensorFlow Lite runtime All we need to do to use the op in TensorFlow Lite is define two functions @@ -107,20 +229,53 @@ TfLiteRegistration* Register_SIN() { } ``` -When initializing the `OpResolver`, add the custom op into the resolver. This -will register the operator with Tensorflow Lite so that TensorFlow Lite can use -the new implementation. 
Note that the last two arguments in `TfLiteRegistration` -correspond to the `SinPrepare` and `SinEval` functions you defined for the -custom op. If you used `SinInit` and `SinFree` functions to initialize variables -used in the op and to free up space, respectively, then they would be added to -the first two arguments of `TfLiteRegistration`; those arguments are set to -`nullptr` in this example. +When initializing the `OpResolver`, add the custom op into the resolver (see +below for an example). This will register the operator with Tensorflow Lite so +that TensorFlow Lite can use the new implementation. Note that the last two +arguments in `TfLiteRegistration` correspond to the `SinPrepare` and `SinEval` +functions you defined for the custom op. If you used `SinInit` and `SinFree` +functions to initialize variables used in the op and to free up space, +respectively, then they would be added to the first two arguments of +`TfLiteRegistration`; those arguments are set to `nullptr` in this example. -```cpp -tflite::ops::builtin::BuiltinOpResolver builtins; -builtins.AddCustom("Sin", Register_SIN()); +### Register the operator with the kernel library + +Now we need to register the operator with the kernel library. This is done with +an `OpResolver`. Behind the scenes, the interpreter will load a library of +kernels which will be assigned to execute each of the operators in the model. +While the default library only contains builtin kernels, it is possible to +replace/augment it with a custom library op operators. + +The `OpResolver` class, which translates operator codes and names into actual +code, is defined like this: + +```c++ +class OpResolver { + virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0; + virtual TfLiteRegistration* FindOp(const char* op) const = 0; + virtual void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration) = 0; + virtual void AddCustom(const char* op, TfLiteRegistration* registration) = 0; +}; ``` +Regular usage requires that you use the `BuiltinOpResolver` and write: + +```c++ +tflite::ops::builtin::BuiltinOpResolver resolver; +``` + +To add the custom op created above, you call `AddOp` (before you pass the +resolver to the `InterpreterBuilder`): + +```c++ +resolver.AddCustom("Sin", Register_SIN()); +``` + +If the set of builtin ops is deemed to be too large, a new `OpResolver` could be +code-generated based on a given subset of ops, possibly only the ones contained +in a given model. This is the equivalent of TensorFlow's selective registration +(and a simple version of it is available in the `tools` directory). + If you want to define your custom operators in Java, you would currently need to build your own custom JNI layer and compile your own AAR [in this jni code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/java/src/main/native/builtin_ops_jni.cc). @@ -133,6 +288,15 @@ operations instead of a single operator. Just add as many `AddCustom` operators as you need. In addition, `BuiltinOpResolver` also allows you to override implementations of builtins by using the `AddBuiltin`. +### Test and profile your operator + +To profile your op with the TensorFlow Lite benchmark tool, you can use the +[benchmark model tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark#tflite-model-benchmark-tool) +for TensorFlow Lite. 
For testing purposes, you can make your local build of +TensorFlow Lite aware of your custom op by adding the appropriate `AddCustom` +call (as show above) to +[register.cc](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/kernels/register.cc) + ## Best practices 1. Optimize memory allocations and de-allocations cautiously. Allocating memory diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md index 41fc958d53e..ac8eb1975cc 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_nl_classifier.md @@ -67,6 +67,34 @@ See the [source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java) for more details. +## Run inference in Swift + +### Step 1: Import CocoaPods + +Add the TensorFlowLiteTaskText pod in Podfile + +``` +target 'MySwiftAppWithTaskAPI' do + use_frameworks! + pod 'TensorFlowLiteTaskText', '~> 0.0.1-nightly' +end +``` + +### Step 2: Run inference using the API + +```swift +// Initialization +let bertNLClassifier = TFLBertNLClassifier.bertNLClassifier( + modelPath: bertModelPath) + +// Run inference +let categories = bertNLClassifier.classify(text: input) +``` + +See the +[source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h) +for more details. + ## Run inference in C++ Note: We are working on improving the usability of the C++ Task Library, such as diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md index 9c41f23aff8..f5d9aff7b6d 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/bert_question_answerer.md @@ -61,13 +61,41 @@ BertQuestionAnswerer answerer = BertQuestionAnswerer.createFromFile(androidConte // Run inference List answers = answerer.answer(contextOfTheQuestion, questionToAsk); -); ``` See the [source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java) for more details. +## Run inference in Swift + +### Step 1: Import CocoaPods + +Add the TensorFlowLiteTaskText pod in Podfile + +``` +target 'MySwiftAppWithTaskAPI' do + use_frameworks! + pod 'TensorFlowLiteTaskText', '~> 0.0.1-nightly' +end +``` + +### Step 2: Run inference using the API + +```swift +// Initialization +let mobileBertAnswerer = TFLBertQuestionAnswerer.questionAnswerer( + modelPath: mobileBertModelPath) + +// Run inference +let answers = mobileBertAnswerer.answer( + context: context, question: question) +``` + +See the +[source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h) +for more details. 
+ ## Run inference in C++ Note: we are working on improving the usability of the C++ Task Library, such as diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_classifier.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_classifier.md index 94d6b513096..391e75a4b7d 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_classifier.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/image_classifier.md @@ -50,6 +50,10 @@ API. ## Run inference in Java +See the +[Image Classification reference app](https://github.com/tensorflow/examples/blob/master/lite/examples/image_classification/android/EXPLORE_THE_CODE.md) +for an example of how to use `ImageClassifier` in an Android app. + ### Step 1: Import Gradle dependency and other settings Copy the `.tflite` model file to the assets directory of the Android module diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/nl_classifier.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/nl_classifier.md index c2c1b2ac68b..f773b260332 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/nl_classifier.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/nl_classifier.md @@ -29,6 +29,10 @@ API. ## Run inference in Java +See the +[Text Classification reference app](https://github.com/tensorflow/examples/blob/master/lite/examples/text_classification/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/textclassification/client/TextClassificationClient.java) +for an example of how to use `NLClassifier` in an Android app. + ### Step 1: Import Gradle dependency and other settings Copy the `.tflite` model file to the assets directory of the Android module @@ -69,6 +73,38 @@ See the [source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java) for more options to configure `NLClassifier`. +## Run inference in Swift + +### Step 1: Import CocoaPods + +Add the TensorFlowLiteTaskText pod in Podfile + +``` +target 'MySwiftAppWithTaskAPI' do + use_frameworks! + pod 'TensorFlowLiteTaskText', '~> 0.0.1-nightly' +end +``` + +### Step 2: Run inference using the API + +```swift +// Initialization +var modelOptions:TFLNLClassifierOptions = TFLNLClassifierOptions() +modelOptions.inputTensorName = inputTensorName +modelOptions.outputScoreTensorName = outputScoreTensorName +let nlClassifier = TFLNLClassifier.nlClassifier( + modelPath: modelPath, + options: modelOptions) + +// Run inference +let categories = nlClassifier.classify(text: input) +``` + +See the +[source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h) +for more details. + ## Run inference in C++ Note: We are working on improving the usability of the C++ Task Library, such as diff --git a/tensorflow/lite/g3doc/inference_with_metadata/task_library/overview.md b/tensorflow/lite/g3doc/inference_with_metadata/task_library/overview.md index 2a2b124c7f9..d054e4717d7 100644 --- a/tensorflow/lite/g3doc/inference_with_metadata/task_library/overview.md +++ b/tensorflow/lite/g3doc/inference_with_metadata/task_library/overview.md @@ -6,7 +6,7 @@ It provides optimized out-of-box model interfaces for popular machine learning tasks, such as image classification, question and answer, etc. 
The model interfaces are specifically designed for each task to achieve the best performance and usability. Task Library works cross-platform and is supported on -Java, C++, and Swift (coming soon). +Java, C++, and Swift. ## What to expect from the Task Library diff --git a/tensorflow/lite/g3doc/models/image_classification/overview.md b/tensorflow/lite/g3doc/models/image_classification/overview.md index cff23134bda..8f401e916b5 100644 --- a/tensorflow/lite/g3doc/models/image_classification/overview.md +++ b/tensorflow/lite/g3doc/models/image_classification/overview.md @@ -34,6 +34,18 @@ For each example, we provide a guide that explains how it works. #### Android +You can leverage the out-of-box API from TensorFlow Lite Task Library to +[integrate image classification models](../../inference_with_metadata/task_library/image_classifier) +in just a few lines of code. You can also +[build your own custom inference pipleline](../../inference_with_metadata/lite_support) +using the TensorFlow Lite Support Library. + +The Android example below demonstrates the implementation for both methods as +[lib_task_api](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_task_api) +and +[lib_support](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_support), +respectively. + View Android example diff --git a/tensorflow/lite/g3doc/models/object_detection/overview.md b/tensorflow/lite/g3doc/models/object_detection/overview.md index 5ceb6a16ee1..cbe2fc05922 100644 --- a/tensorflow/lite/g3doc/models/object_detection/overview.md +++ b/tensorflow/lite/g3doc/models/object_detection/overview.md @@ -2,45 +2,55 @@ -Detect multiple objects within an image, with bounding boxes. Recognize 90 -different classes of objects. +Given an image or a video stream, an object detection model can identify which +of a known set of objects might be present and provide information about their +positions within the image. + +For example, this screenshot of the example +application shows how two objects have been recognized and their positions +annotated: + +Screenshot of Android example ## Get started -If you are new to TensorFlow Lite and are working with Android or iOS, we -recommend exploring the following example applications that can help you get -started. +If you are new to TensorFlow Lite and are working with Android or iOS, download +the following example applications to get started. Android example iOS example -If you are using a platform other than Android or iOS, or you are already -familiar with the TensorFlow Lite APIs, you can -download our starter object detection model and the accompanying labels. +If you are using a platform other than Android or iOS, or if you are already +familiar with the +TensorFlow Lite +APIs, you can download the starter object detection model and the +accompanying labels. Download starter model with Metadata -For more information about the starter model, see -Starter model. - For more information about Metadata and associated fields (eg: `labels.txt`) see Read the metadata from models -## What is object detection? +If you want to train a custom detection model for your own task, see +Model customization. -Given an image or a video stream, an object detection model can identify which -of a known set of objects might be present and provide information about their -positions within the image. 
+For the following use cases, you should use a different type of model: -For example, this screenshot of our example -application shows how two objects have been recognized and their positions -annotated: +
    +
  • Predicting which single label the image most likely represents (see image classification)
  • +
  • Predicting the composition of an image, for example subject versus background (see segmentation)
  • +
-Screenshot of Android example +## Model description + +This section describes the signature for +[Single-Shot Detector](https://arxiv.org/abs/1512.02325) models converted to +TensorFlow Lite from the +[TensorFlow Object Detection API](https://github.com/tensorflow/models/blob/master/research/object_detection/). An object detection model is trained to detect the presence and location of multiple classes of objects. For example, a model might be trained with images @@ -48,15 +58,69 @@ that contain various pieces of fruit, along with a _label_ that specifies the class of fruit they represent (e.g. an apple, a banana, or a strawberry), and data specifying where each object appears in the image. -When we subsequently provide an image to the model, it will output a list of the -objects it detects, the location of a bounding box that contains each object, -and a score that indicates the confidence that detection was correct. +When an image is subsequently provided to the model, it will output a list of +the objects it detects, the location of a bounding box that contains each +object, and a score that indicates the confidence that detection was correct. -### Model output +### Input Signature -Imagine a model has been trained to detect apples, bananas, and strawberries. -When we pass it an image, it will output a set number of detection results - in -this example, 5. +The model takes an image as input. + +Lets assume the expected image is 300x300 pixels, with three channels (red, +blue, and green) per pixel. This should be fed to the model as a flattened +buffer of 270,000 byte values (300x300x3). If the model is +quantized, each +value should be a single byte representing a value between 0 and 255. + +You can take a look at our +[example app code](https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection/android) +to understand how to do this pre-processing on Android. + +### Output Signature + +The model outputs four arrays, mapped to the indices 0-4. Arrays 0, 1, and 2 +describe `N` detected objects, with one element in each array corresponding to +each object. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
| Index | Name                 | Description |
| ----- | -------------------- | ----------- |
| 0     | Locations            | Multidimensional array of [N][4] floating point values between 0 and 1, the inner arrays representing bounding boxes in the form [top, left, bottom, right] |
| 1     | Classes              | Array of N integers (output as floating point values) each indicating the index of a class label from the labels file |
| 2     | Scores               | Array of N floating point values between 0 and 1 representing probability that a class was detected |
| 3     | Number of detections | Integer value of N |
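
As an illustration (not part of the original documentation), the following is a minimal Python sketch of reading these four arrays with `tf.lite.Interpreter`. It assumes a hypothetical model file `detect.tflite` whose outputs are ordered as in the table above; a real application would feed a preprocessed camera frame or bitmap instead of the dummy input used here:

```python
import numpy as np
import tensorflow as tf

# Load the model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="detect.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Dummy input with the shape/dtype the model expects (e.g. [1, 300, 300, 3]).
input_data = np.zeros(input_details[0]['shape'], dtype=input_details[0]['dtype'])
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()

# The four output arrays described in the table above.
boxes = interpreter.get_tensor(output_details[0]['index'])[0]    # [N, 4], normalized [top, left, bottom, right]
classes = interpreter.get_tensor(output_details[1]['index'])[0]  # [N] class label indices
scores = interpreter.get_tensor(output_details[2]['index'])[0]   # [N] confidence scores in [0, 1]
count = int(interpreter.get_tensor(output_details[3]['index'])[0])

for i in range(count):
  if scores[i] > 0.5:  # example cut-off threshold (see the confidence score discussion below)
    print(int(classes[i]), scores[i], boxes[i])
```

Since the box coordinates are normalized to values between 0 and 1, they must be scaled by the image width and height to obtain pixel coordinates.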
+ +NOTE: The number of results (10 in the above case) is a parameter set while +exporting the detection model to TensorFlow Lite. See +Model customization for more details. + +For example, imagine a model has been trained to detect apples, bananas, and +strawberries. When provided an image, it will output a set number of detection +results - in this example, 5. @@ -95,7 +159,7 @@ this example, 5.
-### Confidence score +#### Confidence score To interpret these results, we can look at the score and the location for each detected object. The score is a number between 0 and 1 that indicates confidence @@ -103,10 +167,10 @@ that the object was genuinely detected. The closer the number is to 1, the more confident the model is. Depending on your application, you can decide a cut-off threshold below which -you will discard detection results. For our example, we might decide a sensible -cut-off is a score of 0.5 (meaning a 50% probability that the detection is -valid). In that case, we would ignore the last two objects in the array, because -those confidence scores are below 0.5: +you will discard detection results. For the current example, a sensible cut-off +is a score of 0.5 (meaning a 50% probability that the detection is valid). In +that case, the last two objects in the array would be ignored because those +confidence scores are below 0.5: @@ -158,11 +222,11 @@ positive. Screenshot of Android example showing a false positive -### Location +#### Location For each detected object, the model will return an array of four numbers representing a bounding rectangle that surrounds its position. For the starter -model we provide, the numbers are ordered as follows: +model provided, the numbers are ordered as follows:
@@ -186,7 +250,9 @@ Note: Object detection models accept input images of a specific size. This is li ## Performance benchmarks -Performance benchmark numbers are generated with the tool +Performance benchmark numbers for our +starter +model are generated with the tool [described here](https://www.tensorflow.org/lite/performance/benchmarks).
@@ -226,79 +292,53 @@ Performance benchmark numbers are generated with the tool \*\* 2 threads used on iPhone for the best performance result. -## Starter model +## Model Customization -We recommend starting with this pre-trained quantized COCO SSD MobileNet v1 -model. +### Pre-trained models -Download -starter model and labels +Mobile-optimized detection models with a variety of latency and precision +characteristics can be found in the +[Detection Zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md#mobile-models). +Each one of them follows the input and output signatures described in the +following sections. -### Uses and limitations +Most of the download zips contain a `model.tflite` file. If there isn't one, a +TensorFlow Lite flatbuffer can be generated using +[these instructions](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tensorflowlite.md). +SSD models from the +[TF2 Object Detection Zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md) +can also be converted to TensorFlow Lite using the instructions +[here](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tf2.md). +It is important to note that detection models cannot be converted directly using +the [TensorFlow Lite Converter](https://www.tensorflow.org/lite/convert), since +they require an intermediate step of generating a mobile-friendly source model. +The scripts linked above perform this step. -The object detection model we provide can identify and locate up to 10 objects -in an image. It is trained to recognize 90 classes of objects. For a full list -of classes, see the labels file embedded in the model with -metadata -visualiztion. +Both the +[TF1](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tensorflowlite.md) +& +[TF2](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tensorflowlite.md) +exporting scripts have parameters that can enable a larger number of output +objects or slower, more-accurate post processing. Please use `--help` with the +scripts to see an exhaustive list of supported arguments. -If you want to train a model to recognize new classes, see -Customize model. +> Currently, on-device inference is only optimized with SSD models. Better +> support for other architectures like CenterNet and EfficientDet is being +> investigated. -For the following use cases, you should use a different type of model: +### How to choose a model to customize? -
    -
  • Predicting which single label the image most likely represents (see image classification)
  • -
  • Predicting the composition of an image, for example subject versus background (see segmentation)
  • -
+Each model comes with its own precision (quantified by mAP value) and latency +characteristics. You should choose a model that works the best for your use-case +and intended hardware. For example, the +[Edge TPU](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md#pixel4-edge-tpu-models) +models are ideal for inference on Google's Edge TPU on Pixel 4. -### Input +You can use our +[benchmark tool](https://www.tensorflow.org/lite/performance/measurement) to +evaluate models and choose the most efficient option available. -The model takes an image as input. The expected image is 300x300 pixels, with -three channels (red, blue, and green) per pixel. This should be fed to the model -as a flattened buffer of 270,000 byte values (300x300x3). Since the model is -quantized, each -value should be a single byte representing a value between 0 and 255. - -### Output - -The model outputs four arrays, mapped to the indices 0-4. Arrays 0, 1, and 2 -describe 10 detected objects, with one element in each array corresponding to -each object. There will always be 10 objects detected. - -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
| Index | Name                  | Description |
| ----- | --------------------- | ----------- |
| 0     | Locations             | Multidimensional array of [10][4] floating point values between 0 and 1, the inner arrays representing bounding boxes in the form [top, left, bottom, right] |
| 1     | Classes               | Array of 10 integers (output as floating point values) each indicating the index of a class label from the labels file |
| 2     | Scores                | Array of 10 floating point values between 0 and 1 representing probability that a class was detected |
| 3     | Number and detections | Array of length 1 containing a floating point value expressing the total number of detection results |
- -## Customize model +## Fine-tuning models on custom data The pre-trained models we provide are trained to detect 90 classes of objects. For a full list of classes, see the labels file in the @@ -309,8 +349,15 @@ You can use a technique known as transfer learning to re-train a model to recognize classes not in the original set. For example, you could re-train the model to detect multiple types of vegetable, despite there only being one vegetable in the original training data. To do this, you will need a set of -training images for each of the new labels you wish to train. +training images for each of the new labels you wish to train. Please see our +[Few-shot detection Colab](https://github.com/tensorflow/models/blob/master/research/object_detection/colab_tutorials/eager_few_shot_od_training_tflite.ipynb) +as an example of fine-tuning a pre-trained model with few examples. -Learn how to perform transfer learning in -Training -and serving a real-time mobile object detector in 30 minutes. +For fine-tuning with larger datasets, take a look at the these guides for +training your own models with the TensorFlow Object Detection API: +[TF1](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_training_and_evaluation.md), +[TF2](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_training_and_evaluation.md). +Once trained, they can be converted to a TFLite-friendly format with the +instructions here: +[TF1](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tensorflowlite.md), +[TF2](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tensorflowlite.md) diff --git a/tensorflow/lite/g3doc/models/text_classification/overview.md b/tensorflow/lite/g3doc/models/text_classification/overview.md index c3463188394..1761974ee3f 100644 --- a/tensorflow/lite/g3doc/models/text_classification/overview.md +++ b/tensorflow/lite/g3doc/models/text_classification/overview.md @@ -7,7 +7,16 @@ Use a pre-trained model to category a paragraph into predefined groups. If you are new to TensorFlow Lite and are working with Android, we recommend -exploring the following example applications that can help you get started. +exploring the guide of TensorFLow Lite Task Library to +[integrate text classification models](../../inference_with_metadata/task_library/nl_classifier). +within just a few lines of code. You can also integrate the model using the +[TensorFlow Lite Interpreter Java API](../../guide/inference#load_and_run_a_model_in_java). + +The Android example below demonstrates the implementation for both methods as +[lib_task_api](https://github.com/tensorflow/examples/tree/master/lite/examples/text_classification/android/lib_task_api) +and +[lib_interpreter](https://github.com/tensorflow/examples/tree/master/lite/examples/text_classification/android/lib_interpreter), +respectively. 
Android example @@ -60,7 +69,7 @@ Performance benchmark numbers are generated with the tool - Text Classification + Text Classification 0.6 Mb diff --git a/tensorflow/lite/g3doc/performance/delegates.md b/tensorflow/lite/g3doc/performance/delegates.md index 14aeece21fa..6b233075398 100644 --- a/tensorflow/lite/g3doc/performance/delegates.md +++ b/tensorflow/lite/g3doc/performance/delegates.md @@ -21,10 +21,11 @@ TensorFlow Lite provides the following delegates for hardware acceleration: * **GPU delegate for cross platform acceleration** - The GPU delegate can be used on both Android and iOS. It is optimized to run 32-bit and 16-bit float - based models where a GPU is available. For an overview of the GPU delegate, - see [TensorFlow Lite on GPU](gpu_advanced.md). For step-by-step tutorials on - using the GPU delegate with Android and iOS, see - [TensorFlow Lite GPU Delegate Tutorial](gpu.md). + based models where a GPU is available. It also supports 8-bit quantized + models and provides GPU performance on par with their float versions. For + details on the GPU delegate, see [TensorFlow Lite on GPU](gpu_advanced.md). + For step-by-step tutorials on using the GPU delegate with Android and iOS, + see [TensorFlow Lite GPU Delegate Tutorial](gpu.md). * **NNAPI delegate for newer Android devices** - The NNAPI delegate can be used to accelerate models on Android devices with GPU, DSP and / or NPU available. It is available in Android 8.1 (API 27+) or higher. For an diff --git a/tensorflow/lite/g3doc/performance/gpu.md b/tensorflow/lite/g3doc/performance/gpu.md index 3d9989558f5..96e8aa6f9dc 100644 --- a/tensorflow/lite/g3doc/performance/gpu.md +++ b/tensorflow/lite/g3doc/performance/gpu.md @@ -14,6 +14,10 @@ run fast enough for previously not available real-time applications. Unlike CPUs, GPUs compute with 16-bit or 32-bit floating point numbers and do not require quantization for optimal performance. +**NOTE:** The delegate does accept 8-bit quantized models on Android. Support on +iOS is experimental. Refer to the [advanced documentation](gpu_advanced.md) for +details. + Another benefit with GPU inference is its power efficiency. GPUs carry out the computations in a very efficient and optimized manner, so that they consume less power and generate less heat than when the same task is run on CPUs. diff --git a/tensorflow/lite/g3doc/performance/gpu_advanced.md b/tensorflow/lite/g3doc/performance/gpu_advanced.md index 652100ab850..71415693f86 100644 --- a/tensorflow/lite/g3doc/performance/gpu_advanced.md +++ b/tensorflow/lite/g3doc/performance/gpu_advanced.md @@ -285,12 +285,10 @@ While it is convenient to use `nullptr`, we recommend that you explicitly set the options, to avoid any unexpected behavior if default values are changed in the future. -### Running quantized models (Experimental) +### Running quantized models on GPU -The GPU delegate already supports -[float16 quantized](https://www.tensorflow.org/lite/performance/post_training_float16_quant) -models. There is experimental support on Android and iOS to run 8-bit quantized -as well. This includes all flavors of quantization, including: +This section explains how the GPU delegate accelerates 8-bit quantized models. +This includes all flavors of quantization, including: * Models trained with [Quantization-aware training](https://www.tensorflow.org/lite/convert/quantization) @@ -322,12 +320,14 @@ This feature can be enabled using delegate options as follows: #### Android +Android APIs support quantized models by default. 
To disable, do the following: + **C++ API** ```c++ // NEW: Prepare custom options with feature enabled. TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default(); -options.experimental_flags |= TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT; +options.experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_NONE; auto* delegate = TfLiteGpuDelegateV2Create(options); if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return false; @@ -337,13 +337,16 @@ if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return false; ```java // NEW: Prepare GPU delegate with feature turned on. -GpuDelegate delegate = new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(true)); +GpuDelegate delegate = new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(false)); Interpreter.Options options = (new Interpreter.Options()).addDelegate(delegate); ``` #### iOS +Support for quantized models on iOS APIs is experimental. To enable, do the +following: + **Swift API** ```swift diff --git a/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md b/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md index e38c2b3e215..a2377412c8d 100644 --- a/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md +++ b/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md @@ -104,8 +104,8 @@ tflite_convert \ --std_dev_values=127.7 ``` -*If you're setting `--inference_type=QUANTIZED_UINT8` then update -`--mean_values=128` and `--std_dev_values=127`* +*If you're setting `--inference_type=UINT8` then update `--mean_values=128` and +`--std_dev_values=127`* #### Convert a model with \"dummy-quantization\" into a quantized TensorFlow Lite model @@ -134,8 +134,8 @@ tflite_convert \ --default_ranges_max=6 ``` -*If you're setting `--inference_type=QUANTIZED_UINT8` then update -`--mean_values=128` and `--std_dev_values=127`* +*If you're setting `--inference_type=UINT8` then update `--mean_values=128` and +`--std_dev_values=127`* #### Convert a model with select TensorFlow operators. diff --git a/tensorflow/lite/g3doc/r1/convert/cmdline_reference.md b/tensorflow/lite/g3doc/r1/convert/cmdline_reference.md index 826bb7afdbb..386d9063f9f 100644 --- a/tensorflow/lite/g3doc/r1/convert/cmdline_reference.md +++ b/tensorflow/lite/g3doc/r1/convert/cmdline_reference.md @@ -63,8 +63,7 @@ based on index. has a shape of [2, 3] and "bar" has a shape of [4, 5, 6]. * `--std_dev_values`, `--mean_values`. Type: comma-separated list of floats. These specify the (de-)quantization parameters of the input array, when it - is quantized. This is only needed if `inference_input_type` is `INT8` or - `QUANTIZED_UINT8`. + is quantized. Only needed if `inference_input_type` is `INT8` or `UINT8`. * The meaning of `mean_values` and `std_dev_values` is as follows: each quantized value in the quantized input array will be interpreted as a mathematical real number (i.e. as an input activation value) according @@ -75,12 +74,12 @@ based on index. the inference code according to the above formula, before proceeding with float inference. * When performing quantized inference (`inference_type` - is`INT8`or`QUANTIZED_UINT8`), no dequantization is performed by the - inference code. However, the quantization parameters of all arrays, - including those of the input arrays as specified - by`mean_value`and`std_dev_value`, determine the fixed-point multipliers - used in the quantized inference code.`mean_value` must be an integer - when performing quantized inference. 
+ is`INT8`or`UINT8`), no dequantization is performed by the inference + code. However, the quantization parameters of all arrays, including + those of the input arrays as specified by`mean_value`and`std_dev_value`, + determine the fixed-point multipliers used in the quantized inference + code.`mean_value` must be an integer when performing quantized + inference. ## Transformation flags @@ -90,7 +89,7 @@ have. * `--inference_type`. Type: string. Default: `FLOAT`. Data type of all real-number arrays in the output file except for input arrays (defined by - `--inference_input_type`). Must be `{FLOAT, INT8, QUANTIZED_UINT8}`. + `--inference_input_type`). Must be `{FLOAT, INT8, UINT8}`. This flag only impacts real-number arrays including float and quantized arrays. This excludes all other data types including plain integer arrays @@ -102,16 +101,15 @@ have. * If `INT8`, then real-numbers arrays will be quantized as int8 in the output file. If they were float in the input file, then they get quantized. - * If `QUANTIZED_UINT8`, then real-numbers arrays will be quantized as - uint8 in the output file. If they were float in the input file, then - they get quantized. + * If `UINT8`, then real-numbers arrays will be quantized as uint8 in the + output file. If they were float in the input file, then they get + quantized. * `--inference_input_type`. Type: string. Data type of a real-number input array in the output file. By default the `--inference_type` is used as type of all of the input arrays. Flag is primarily intended for generating a float-point graph with a quantized input array. A Dequantized operator is - added immediately after the input array. Must be `{FLOAT, INT8, - QUANTIZED_UINT8}`. + added immediately after the input array. Must be `{FLOAT, INT8, UINT8}`. The flag is typically used for vision models taking a bitmap as input but requiring floating-point inference. For such image models, the uint8 input diff --git a/tensorflow/lite/g3doc/tutorials/model_maker_image_classification.ipynb b/tensorflow/lite/g3doc/tutorials/model_maker_image_classification.ipynb index 7be3820c3ef..ef650c6b05b 100644 --- a/tensorflow/lite/g3doc/tutorials/model_maker_image_classification.ipynb +++ b/tensorflow/lite/g3doc/tutorials/model_maker_image_classification.ipynb @@ -1,25 +1,9 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "accelerator": "GPU", - "colab": { - "name": "model_maker_image_classification.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, "cells": [ { "cell_type": "markdown", "metadata": { - "id": "h2q27gKz1H20", - "colab_type": "text" + "id": "h2q27gKz1H20" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." @@ -27,12 +11,12 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "cellView": "form", - "id": "TUfAcER1oUS6", - "colab_type": "code", - "colab": {} + "id": "TUfAcER1oUS6" }, + "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", @@ -45,15 +29,12 @@ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." 
- ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "Gb7qyhNL1yWt", - "colab_type": "text" + "id": "Gb7qyhNL1yWt" }, "source": [ "# Image classification with TensorFlow Lite Model Maker" @@ -62,31 +43,32 @@ { "cell_type": "markdown", "metadata": { - "id": "nDABAblytltI", - "colab_type": "text" + "id": "nDABAblytltI" }, "source": [ - "\n", - " \n", - " \n", - " \n", - " \n", - "
\n", - " View on TensorFlow.org\n", - " \n", - " Run in Google Colab\n", - " \n", - " View source on GitHub\n", - " \n", - " Download notebook\n", - "
" + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/tutorials/model_maker_image_classification\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/tutorials/model_maker_image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/tutorials/model_maker_image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/tutorials/model_maker_image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca href=\"https://tfhub.dev/\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" /\u003eSee TF Hub model\u003c/a\u003e\n", + " \u003c/td\u003e\n", + "\u003c/table\u003e" ] }, { "cell_type": "markdown", "metadata": { - "id": "m86-Nh4pMHqY", - "colab_type": "text" + "id": "m86-Nh4pMHqY" }, "source": [ "Model Maker library simplifies the process of adapting and converting a TensorFlow neural-network model to particular input data when deploying this model for on-device ML applications.\n", @@ -97,8 +79,7 @@ { "cell_type": "markdown", "metadata": { - "id": "bcLF2PKkSbV3", - "colab_type": "text" + "id": "bcLF2PKkSbV3" }, "source": [ "## Prerequisites\n", @@ -108,21 +89,18 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "6cv3K3oaksJv", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, + "id": "6cv3K3oaksJv", "outputId": "911fb544-f618-4cf7-e11d-f42c460d8f67" }, - "source": [ - "!pip install tflite-model-maker" - ], - "execution_count": 1, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ "Collecting tflite-model-maker\n", @@ -131,11 +109,11 @@ "\u001b[?25hCollecting sentencepiece\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)\n", "\u001b[K |████████████████████████████████| 1.1MB 8.9MB/s \n", - "\u001b[?25hRequirement already satisfied: numpy>=1.17.3 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (1.18.5)\n", + "\u001b[?25hRequirement already satisfied: numpy\u003e=1.17.3 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (1.18.5)\n", "Collecting tflite-support==0.1.0rc3.dev2\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/fa/c5/5e9ee3abd5b4ef8294432cd714407f49a66befa864905b66ee8bdc612795/tflite_support-0.1.0rc3.dev2-cp36-cp36m-manylinux2010_x86_64.whl (1.0MB)\n", "\u001b[K |████████████████████████████████| 1.0MB 15.4MB/s \n", - "\u001b[?25hRequirement already satisfied: tensorflow-hub>=0.8.0 in 
/usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (0.9.0)\n", + "\u001b[?25hRequirement already satisfied: tensorflow-hub\u003e=0.8.0 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (0.9.0)\n", "Requirement already satisfied: pillow in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (7.0.0)\n", "Requirement already satisfied: absl-py in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (0.10.0)\n", "Collecting flatbuffers==1.12\n", @@ -146,105 +124,105 @@ "\u001b[?25hCollecting fire\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/34/a7/0e22e70778aca01a52b9c899d9c145c6396d7b613719cd63db97ffa13f2f/fire-0.3.1.tar.gz (81kB)\n", "\u001b[K |████████████████████████████████| 81kB 9.1MB/s \n", - "\u001b[?25hRequirement already satisfied: tensorflow-datasets>=2.1.0 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (2.1.0)\n", + "\u001b[?25hRequirement already satisfied: tensorflow-datasets\u003e=2.1.0 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (2.1.0)\n", "Collecting tf-nightly\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/33/d4/61c47ae889b490b9c5f07f4f61bdc057c158a1a1979c375fa019d647a19e/tf_nightly-2.4.0.dev20200914-cp36-cp36m-manylinux2010_x86_64.whl (390.1MB)\n", "\u001b[K |████████████████████████████████| 390.2MB 43kB/s \n", - "\u001b[?25hCollecting pybind11>=2.4\n", + "\u001b[?25hCollecting pybind11\u003e=2.4\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/89/e3/d576f6f02bc75bacbc3d42494e8f1d063c95617d86648dba243c2cb3963e/pybind11-2.5.0-py2.py3-none-any.whl (296kB)\n", "\u001b[K |████████████████████████████████| 296kB 50.8MB/s \n", - "\u001b[?25hRequirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub>=0.8.0->tflite-model-maker) (1.15.0)\n", - "Requirement already satisfied: protobuf>=3.8.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub>=0.8.0->tflite-model-maker) (3.12.4)\n", - "Requirement already satisfied: gin-config in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.3.0)\n", - "Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (3.2.2)\n", - "Requirement already satisfied: kaggle>=1.3.9 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.5.8)\n", - "Requirement already satisfied: pandas>=0.22.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.0.5)\n", - "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.7)\n", - "Requirement already satisfied: google-api-python-client>=1.6.7 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.7.12)\n", - "Collecting py-cpuinfo>=3.3.0\n", + "\u001b[?25hRequirement already satisfied: six\u003e=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub\u003e=0.8.0-\u003etflite-model-maker) (1.15.0)\n", + "Requirement already satisfied: protobuf\u003e=3.8.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub\u003e=0.8.0-\u003etflite-model-maker) (3.12.4)\n", + "Requirement already satisfied: gin-config in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (0.3.0)\n", + "Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages 
(from tf-models-nightly-\u003etflite-model-maker) (3.2.2)\n", + "Requirement already satisfied: kaggle\u003e=1.3.9 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (1.5.8)\n", + "Requirement already satisfied: pandas\u003e=0.22.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (1.0.5)\n", + "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (0.7)\n", + "Requirement already satisfied: google-api-python-client\u003e=1.6.7 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (1.7.12)\n", + "Collecting py-cpuinfo\u003e=3.3.0\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/f6/f5/8e6e85ce2e9f6e05040cf0d4e26f43a4718bcc4bce988b433276d4b1a5c1/py-cpuinfo-7.0.0.tar.gz (95kB)\n", "\u001b[K |████████████████████████████████| 102kB 13.7MB/s \n", - "\u001b[?25hCollecting tf-slim>=1.1.0\n", + "\u001b[?25hCollecting tf-slim\u003e=1.1.0\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/02/97/b0f4a64df018ca018cc035d44f2ef08f91e2e8aa67271f6f19633a015ff7/tf_slim-1.1.0-py2.py3-none-any.whl (352kB)\n", "\u001b[K |████████████████████████████████| 358kB 46.5MB/s \n", "\u001b[?25hCollecting opencv-python-headless\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/b6/2a/496e06fd289c01dc21b11970be1261c87ce1cc22d5340c14b516160822a7/opencv_python_headless-4.4.0.42-cp36-cp36m-manylinux2014_x86_64.whl (36.6MB)\n", "\u001b[K |████████████████████████████████| 36.6MB 92kB/s \n", - "\u001b[?25hRequirement already satisfied: scipy>=0.19.1 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.4.1)\n", - "Collecting tensorflow-model-optimization>=0.4.1\n", + "\u001b[?25hRequirement already satisfied: scipy\u003e=0.19.1 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (1.4.1)\n", + "Collecting tensorflow-model-optimization\u003e=0.4.1\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/55/38/4fd48ea1bfcb0b6e36d949025200426fe9c3a8bfae029f0973d85518fa5a/tensorflow_model_optimization-0.5.0-py2.py3-none-any.whl (172kB)\n", "\u001b[K |████████████████████████████████| 174kB 54.9MB/s \n", - "\u001b[?25hRequirement already satisfied: google-cloud-bigquery>=0.31.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.21.0)\n", - "Requirement already satisfied: oauth2client in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (4.1.3)\n", + "\u001b[?25hRequirement already satisfied: google-cloud-bigquery\u003e=0.31.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (1.21.0)\n", + "Requirement already satisfied: oauth2client in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (4.1.3)\n", "Collecting seqeval\n", " Downloading https://files.pythonhosted.org/packages/34/91/068aca8d60ce56dd9ba4506850e876aba5e66a6f2f29aa223224b50df0de/seqeval-0.0.12.tar.gz\n", - "Requirement already satisfied: psutil>=5.4.3 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (5.4.8)\n", - "Requirement already satisfied: pycocotools in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (2.0.2)\n", - "Collecting pyyaml>=5.1\n", + "Requirement already satisfied: psutil\u003e=5.4.3 in 
/usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (5.4.8)\n", + "Requirement already satisfied: pycocotools in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (2.0.2)\n", + "Collecting pyyaml\u003e=5.1\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)\n", "\u001b[K |████████████████████████████████| 276kB 47.1MB/s \n", - "\u001b[?25hRequirement already satisfied: Cython in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.29.21)\n", - "Requirement already satisfied: tensorflow-addons in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.8.3)\n", - "Requirement already satisfied: termcolor in /usr/local/lib/python3.6/dist-packages (from fire->tflite-model-maker) (1.1.0)\n", - "Requirement already satisfied: tensorflow-metadata in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (0.24.0)\n", - "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (0.16.0)\n", - "Requirement already satisfied: promise in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (2.3)\n", - "Requirement already satisfied: attrs>=18.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (20.2.0)\n", - "Requirement already satisfied: dill in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (0.3.2)\n", - "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (4.41.1)\n", - "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (2.23.0)\n", - "Requirement already satisfied: wrapt in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (1.12.1)\n", - "Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (0.35.1)\n", - "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (3.3.0)\n", - "Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.32.0)\n", - "Collecting tb-nightly<3.0.0a0,>=2.4.0a0\n", + "\u001b[?25hRequirement already satisfied: Cython in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (0.29.21)\n", + "Requirement already satisfied: tensorflow-addons in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (0.8.3)\n", + "Requirement already satisfied: termcolor in /usr/local/lib/python3.6/dist-packages (from fire-\u003etflite-model-maker) (1.1.0)\n", + "Requirement already satisfied: tensorflow-metadata in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (0.24.0)\n", + "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (0.16.0)\n", + "Requirement already satisfied: promise in /usr/local/lib/python3.6/dist-packages (from 
tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (2.3)\n", + "Requirement already satisfied: attrs\u003e=18.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (20.2.0)\n", + "Requirement already satisfied: dill in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (0.3.2)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (4.41.1)\n", + "Requirement already satisfied: requests\u003e=2.19.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (2.23.0)\n", + "Requirement already satisfied: wrapt in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (1.12.1)\n", + "Requirement already satisfied: wheel\u003e=0.26 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (0.35.1)\n", + "Requirement already satisfied: opt-einsum\u003e=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (3.3.0)\n", + "Requirement already satisfied: grpcio\u003e=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (1.32.0)\n", + "Collecting tb-nightly\u003c3.0.0a0,\u003e=2.4.0a0\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/fc/cb/4dfe0d65bffb5e9663261ff664e6f5a2d37672b31dae27a0f14721ac00d3/tb_nightly-2.4.0a20200914-py3-none-any.whl (10.1MB)\n", "\u001b[K |████████████████████████████████| 10.1MB 45.9MB/s \n", - "\u001b[?25hRequirement already satisfied: keras-preprocessing<1.2,>=1.1.1 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.1.2)\n", - "Requirement already satisfied: google-pasta>=0.1.8 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (0.2.0)\n", + "\u001b[?25hRequirement already satisfied: keras-preprocessing\u003c1.2,\u003e=1.1.1 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (1.1.2)\n", + "Requirement already satisfied: google-pasta\u003e=0.1.8 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (0.2.0)\n", "Collecting tf-estimator-nightly\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/bd/9a/3bfb9994eda11e426c809ebdf434e2ac5824a0784d980018bb53fd1620ec/tf_estimator_nightly-2.4.0.dev2020091401-py2.py3-none-any.whl (460kB)\n", "\u001b[K |████████████████████████████████| 460kB 45.6MB/s \n", - "\u001b[?25hRequirement already satisfied: typing-extensions>=3.7.4.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (3.7.4.3)\n", - "Requirement already satisfied: h5py<2.11.0,>=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (2.10.0)\n", - "Requirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (0.3.3)\n", - "Requirement already satisfied: astunparse==1.6.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.6.3)\n", - "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.8.0->tensorflow-hub>=0.8.0->tflite-model-maker) (50.3.0)\n", - "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (2.8.1)\n", - 
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (2.4.7)\n", - "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (1.2.0)\n", - "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (0.10.0)\n", - "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (1.24.3)\n", - "Requirement already satisfied: slugify in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (0.0.1)\n", - "Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (2020.6.20)\n", - "Requirement already satisfied: python-slugify in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (4.0.1)\n", - "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.22.0->tf-models-nightly->tflite-model-maker) (2018.9)\n", - "Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (1.17.2)\n", - "Requirement already satisfied: uritemplate<4dev,>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (3.0.1)\n", - "Requirement already satisfied: httplib2<1dev,>=0.17.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (0.17.4)\n", - "Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (0.0.4)\n", - "Requirement already satisfied: dm-tree~=0.1.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow-model-optimization>=0.4.1->tf-models-nightly->tflite-model-maker) (0.1.5)\n", - "Requirement already satisfied: google-cloud-core<2.0dev,>=1.0.3 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-nightly->tflite-model-maker) (1.0.3)\n", - "Requirement already satisfied: google-resumable-media!=0.4.0,<0.5.0dev,>=0.3.1 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-nightly->tflite-model-maker) (0.4.1)\n", - "Requirement already satisfied: pyasn1>=0.1.7 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (0.4.8)\n", - "Requirement already satisfied: rsa>=3.1.4 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (4.6)\n", - "Requirement already satisfied: pyasn1-modules>=0.0.5 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (0.2.8)\n", - "Requirement already satisfied: Keras>=2.2.4 in /usr/local/lib/python3.6/dist-packages (from seqeval->tf-models-nightly->tflite-model-maker) (2.4.3)\n", - "Requirement already satisfied: typeguard in /usr/local/lib/python3.6/dist-packages (from tensorflow-addons->tf-models-nightly->tflite-model-maker) (2.7.1)\n", - "Requirement already satisfied: googleapis-common-protos<2,>=1.52.0 in 
/usr/local/lib/python3.6/dist-packages (from tensorflow-metadata->tensorflow-datasets>=2.1.0->tflite-model-maker) (1.52.0)\n", - "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->tensorflow-datasets>=2.1.0->tflite-model-maker) (3.0.4)\n", - "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->tensorflow-datasets>=2.1.0->tflite-model-maker) (2.10)\n", - "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.0.1)\n", - "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (3.2.2)\n", - "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (0.4.1)\n", - "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.7.0)\n", - "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.6/dist-packages (from python-slugify->kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (1.3)\n", - "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth>=1.4.1->google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (4.1.1)\n", - "Requirement already satisfied: google-api-core<2.0.0dev,>=1.14.0 in /usr/local/lib/python3.6/dist-packages (from google-cloud-core<2.0dev,>=1.0.3->google-cloud-bigquery>=0.31.0->tf-models-nightly->tflite-model-maker) (1.16.0)\n", - "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.7.0)\n", - "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.3.0)\n", - "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (3.1.0)\n", - "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (3.1.0)\n", + "\u001b[?25hRequirement already satisfied: typing-extensions\u003e=3.7.4.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (3.7.4.3)\n", + "Requirement already satisfied: h5py\u003c2.11.0,\u003e=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (2.10.0)\n", + "Requirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (0.3.3)\n", + "Requirement already satisfied: astunparse==1.6.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (1.6.3)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf\u003e=3.8.0-\u003etensorflow-hub\u003e=0.8.0-\u003etflite-model-maker) (50.3.0)\n", + "Requirement 
already satisfied: python-dateutil\u003e=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib-\u003etf-models-nightly-\u003etflite-model-maker) (2.8.1)\n", + "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,\u003e=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib-\u003etf-models-nightly-\u003etflite-model-maker) (2.4.7)\n", + "Requirement already satisfied: kiwisolver\u003e=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib-\u003etf-models-nightly-\u003etflite-model-maker) (1.2.0)\n", + "Requirement already satisfied: cycler\u003e=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib-\u003etf-models-nightly-\u003etflite-model-maker) (0.10.0)\n", + "Requirement already satisfied: urllib3\u003c1.25,\u003e=1.21.1 in /usr/local/lib/python3.6/dist-packages (from kaggle\u003e=1.3.9-\u003etf-models-nightly-\u003etflite-model-maker) (1.24.3)\n", + "Requirement already satisfied: slugify in /usr/local/lib/python3.6/dist-packages (from kaggle\u003e=1.3.9-\u003etf-models-nightly-\u003etflite-model-maker) (0.0.1)\n", + "Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from kaggle\u003e=1.3.9-\u003etf-models-nightly-\u003etflite-model-maker) (2020.6.20)\n", + "Requirement already satisfied: python-slugify in /usr/local/lib/python3.6/dist-packages (from kaggle\u003e=1.3.9-\u003etf-models-nightly-\u003etflite-model-maker) (4.0.1)\n", + "Requirement already satisfied: pytz\u003e=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas\u003e=0.22.0-\u003etf-models-nightly-\u003etflite-model-maker) (2018.9)\n", + "Requirement already satisfied: google-auth\u003e=1.4.1 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client\u003e=1.6.7-\u003etf-models-nightly-\u003etflite-model-maker) (1.17.2)\n", + "Requirement already satisfied: uritemplate\u003c4dev,\u003e=3.0.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client\u003e=1.6.7-\u003etf-models-nightly-\u003etflite-model-maker) (3.0.1)\n", + "Requirement already satisfied: httplib2\u003c1dev,\u003e=0.17.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client\u003e=1.6.7-\u003etf-models-nightly-\u003etflite-model-maker) (0.17.4)\n", + "Requirement already satisfied: google-auth-httplib2\u003e=0.0.3 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client\u003e=1.6.7-\u003etf-models-nightly-\u003etflite-model-maker) (0.0.4)\n", + "Requirement already satisfied: dm-tree~=0.1.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow-model-optimization\u003e=0.4.1-\u003etf-models-nightly-\u003etflite-model-maker) (0.1.5)\n", + "Requirement already satisfied: google-cloud-core\u003c2.0dev,\u003e=1.0.3 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery\u003e=0.31.0-\u003etf-models-nightly-\u003etflite-model-maker) (1.0.3)\n", + "Requirement already satisfied: google-resumable-media!=0.4.0,\u003c0.5.0dev,\u003e=0.3.1 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery\u003e=0.31.0-\u003etf-models-nightly-\u003etflite-model-maker) (0.4.1)\n", + "Requirement already satisfied: pyasn1\u003e=0.1.7 in /usr/local/lib/python3.6/dist-packages (from oauth2client-\u003etf-models-nightly-\u003etflite-model-maker) (0.4.8)\n", + "Requirement already satisfied: rsa\u003e=3.1.4 in /usr/local/lib/python3.6/dist-packages (from oauth2client-\u003etf-models-nightly-\u003etflite-model-maker) (4.6)\n", + "Requirement already satisfied: pyasn1-modules\u003e=0.0.5 in 
/usr/local/lib/python3.6/dist-packages (from oauth2client-\u003etf-models-nightly-\u003etflite-model-maker) (0.2.8)\n", + "Requirement already satisfied: Keras\u003e=2.2.4 in /usr/local/lib/python3.6/dist-packages (from seqeval-\u003etf-models-nightly-\u003etflite-model-maker) (2.4.3)\n", + "Requirement already satisfied: typeguard in /usr/local/lib/python3.6/dist-packages (from tensorflow-addons-\u003etf-models-nightly-\u003etflite-model-maker) (2.7.1)\n", + "Requirement already satisfied: googleapis-common-protos\u003c2,\u003e=1.52.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-metadata-\u003etensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (1.52.0)\n", + "Requirement already satisfied: chardet\u003c4,\u003e=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests\u003e=2.19.0-\u003etensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (3.0.4)\n", + "Requirement already satisfied: idna\u003c3,\u003e=2.5 in /usr/local/lib/python3.6/dist-packages (from requests\u003e=2.19.0-\u003etensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (2.10)\n", + "Requirement already satisfied: werkzeug\u003e=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (1.0.1)\n", + "Requirement already satisfied: markdown\u003e=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (3.2.2)\n", + "Requirement already satisfied: google-auth-oauthlib\u003c0.5,\u003e=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (0.4.1)\n", + "Requirement already satisfied: tensorboard-plugin-wit\u003e=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (1.7.0)\n", + "Requirement already satisfied: text-unidecode\u003e=1.3 in /usr/local/lib/python3.6/dist-packages (from python-slugify-\u003ekaggle\u003e=1.3.9-\u003etf-models-nightly-\u003etflite-model-maker) (1.3)\n", + "Requirement already satisfied: cachetools\u003c5.0,\u003e=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth\u003e=1.4.1-\u003egoogle-api-python-client\u003e=1.6.7-\u003etf-models-nightly-\u003etflite-model-maker) (4.1.1)\n", + "Requirement already satisfied: google-api-core\u003c2.0.0dev,\u003e=1.14.0 in /usr/local/lib/python3.6/dist-packages (from google-cloud-core\u003c2.0dev,\u003e=1.0.3-\u003egoogle-cloud-bigquery\u003e=0.31.0-\u003etf-models-nightly-\u003etflite-model-maker) (1.16.0)\n", + "Requirement already satisfied: importlib-metadata; python_version \u003c \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown\u003e=2.6.8-\u003etb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (1.7.0)\n", + "Requirement already satisfied: requests-oauthlib\u003e=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib\u003c0.5,\u003e=0.4.1-\u003etb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (1.3.0)\n", + "Requirement already satisfied: zipp\u003e=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version \u003c \"3.8\"-\u003emarkdown\u003e=2.6.8-\u003etb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (3.1.0)\n", + "Requirement already satisfied: oauthlib\u003e=3.0.0 in /usr/local/lib/python3.6/dist-packages (from 
requests-oauthlib\u003e=0.7.0-\u003egoogle-auth-oauthlib\u003c0.5,\u003e=0.4.1-\u003etb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (3.1.0)\n", "Building wheels for collected packages: fire, py-cpuinfo, seqeval, pyyaml\n", " Building wheel for fire (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Created wheel for fire: filename=fire-0.3.1-py2.py3-none-any.whl size=111005 sha256=8f09a5a04716eb30229b33f5a9031fa22413bd4f709aac5155f4f26c6b070f47\n", @@ -264,16 +242,17 @@ " Uninstalling PyYAML-3.13:\n", " Successfully uninstalled PyYAML-3.13\n", "Successfully installed fire-0.3.1 flatbuffers-1.12 opencv-python-headless-4.4.0.42 py-cpuinfo-7.0.0 pybind11-2.5.0 pyyaml-5.3.1 sentencepiece-0.1.91 seqeval-0.0.12 tb-nightly-2.4.0a20200914 tensorflow-model-optimization-0.5.0 tf-estimator-nightly-2.4.0.dev2020091401 tf-models-nightly-2.3.0.dev20200914 tf-nightly-2.4.0.dev20200914 tf-slim-1.1.0 tflite-model-maker-0.1.2 tflite-support-0.1.0rc3.dev2\n" - ], - "name": "stdout" + ] } + ], + "source": [ + "!pip install tflite-model-maker" ] }, { "cell_type": "markdown", "metadata": { - "id": "Gx1HGRoFQ54j", - "colab_type": "text" + "id": "Gx1HGRoFQ54j" }, "source": [ "Import the required packages." @@ -281,11 +260,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "XtxiUeZEiXpt", - "colab_type": "code", - "colab": {} + "id": "XtxiUeZEiXpt" }, + "outputs": [], "source": [ "import numpy as np\n", "\n", @@ -293,20 +272,18 @@ "assert tf.__version__.startswith('2')\n", "\n", "from tflite_model_maker import configs\n", + "from tflite_model_maker import ExportFormat\n", "from tflite_model_maker import image_classifier\n", "from tflite_model_maker import ImageClassifierDataLoader\n", "from tflite_model_maker import model_spec\n", "\n", "import matplotlib.pyplot as plt" - ], - "execution_count": 2, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "KKRaYHABpob5", - "colab_type": "text" + "id": "KKRaYHABpob5" }, "source": [ "## Simple End-to-End Example" @@ -315,8 +292,7 @@ { "cell_type": "markdown", "metadata": { - "id": "SiZZ5DHXotaW", - "colab_type": "text" + "id": "SiZZ5DHXotaW" }, "source": [ "### Get the data path\n", @@ -326,51 +302,48 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "cellView": "form", - "id": "3jz5x0JoskPv", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, + "id": "3jz5x0JoskPv", "outputId": "07a05f1d-b20d-4f80-8175-76416747476b" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz\n", + "228818944/228813984 [==============================] - 1s 0us/step\n" + ] + } + ], "source": [ "image_path = tf.keras.utils.get_file(\n", " 'flower_photos',\n", " 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',\n", " untar=True)" - ], - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz\n", - "228818944/228813984 [==============================] - 1s 0us/step\n" - ], - "name": "stdout" - } ] }, { "cell_type": "markdown", "metadata": { - "id": "a55MR6i6nuDm", - "colab_type": "text" + "id": "a55MR6i6nuDm" }, "source": [ "You could replace `image_path` with your own image folders. 
As for uploading data to colab, you could find the upload button in the left sidebar shown in the image below with the red rectangle. Just have a try to upload a zip file and unzip it. The root file path is the current path.\n", "\n", - "\"Upload" + "\u003cimg src=\"https://storage.googleapis.com/download.tensorflow.org/models/tflite/screenshots/model_maker_image_classification.png\" alt=\"Upload File\" width=\"800\" hspace=\"100\"\u003e" ] }, { "cell_type": "markdown", "metadata": { - "id": "NNRNv_mloS89", - "colab_type": "text" + "id": "NNRNv_mloS89" }, "source": [ "If you prefer not to upload your images to the cloud, you could try to run the library locally following the [guide](https://github.com/tensorflow/examples/tree/master/tensorflow_examples/lite/model_maker) in github." @@ -379,8 +352,7 @@ { "cell_type": "markdown", "metadata": { - "id": "w-VDriAdsowu", - "colab_type": "text" + "id": "w-VDriAdsowu" }, "source": [ "### Run the example\n", @@ -390,8 +362,7 @@ { "cell_type": "markdown", "metadata": { - "id": "6ahtcO86tZBL", - "colab_type": "text" + "id": "6ahtcO86tZBL" }, "source": [ "Step 1. Load input data specific to an on-device ML app. Split it to training data and testing data." @@ -399,35 +370,33 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "lANoNS_gtdH1", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, + "id": "lANoNS_gtdH1", "outputId": "d2ad1069-373b-47ff-ead1-92982df9f652" }, - "source": [ - "data = ImageClassifierDataLoader.from_folder(image_path)\n", - "train_data, test_data = data.split(0.9)" - ], - "execution_count": 5, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Load image with size: 3670, num_label: 5, labels: daisy, dandelion, roses, sunflowers, tulips.\n" - ], - "name": "stdout" + ] } + ], + "source": [ + "data = ImageClassifierDataLoader.from_folder(image_path)\n", + "train_data, test_data = data.split(0.9)" ] }, { "cell_type": "markdown", "metadata": { - "id": "Y_9IWyIztuRF", - "colab_type": "text" + "id": "Y_9IWyIztuRF" }, "source": [ "Step 2. Customize the TensorFlow model." @@ -435,22 +404,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "yRXMZbrwtyRD", - "colab_type": "code", - "colab": {} + "id": "yRXMZbrwtyRD" }, + "outputs": [], "source": [ "model = image_classifier.create(train_data)" - ], - "execution_count": 11, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "oxU2fDr-t2Ya", - "colab_type": "text" + "id": "oxU2fDr-t2Ya" }, "source": [ "Step 3. Evaluate the model." @@ -458,58 +424,52 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "wQr02VxJt6Cs", - "colab_type": "code", - "colab": {} + "id": "wQr02VxJt6Cs" }, + "outputs": [], "source": [ "loss, accuracy = model.evaluate(test_data)" - ], - "execution_count": 13, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "eVZw9zU8t84y", - "colab_type": "text" + "id": "eVZw9zU8t84y" }, "source": [ "Step 4. Export to TensorFlow Lite model.\n", "\n", - "Here, we export TensorFlow Lite model with [metadata](https://www.tensorflow.org/lite/convert/metadata) which provides a standard for model descriptions.\n", + "Here, we export TensorFlow Lite model with [metadata](https://www.tensorflow.org/lite/convert/metadata) which provides a standard for model descriptions. 
The label file is embedded in metadata.\n", + "\n", "You could download it in the left sidebar same as the uploading part for your own use." ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "Zb-eIzfluCoa", - "colab_type": "code", - "colab": {} + "id": "Zb-eIzfluCoa" }, + "outputs": [], "source": [ "model.export(export_dir='.')" - ], - "execution_count": 14, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pyju1qc_v-wy", - "colab_type": "text" - }, - "source": [ - "After this simple 4 steps, we could further use TensorFlow Lite model file and label file in on-device applications like in [image classification](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification) reference app.\n" ] }, { "cell_type": "markdown", "metadata": { - "id": "R1QG32ivs9lF", - "colab_type": "text" + "id": "pyju1qc_v-wy" + }, + "source": [ + "After this simple 4 steps, we could further use TensorFlow Lite model file in on-device applications like in [image classification](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification) reference app." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R1QG32ivs9lF" }, "source": [ "## Detailed Process\n", @@ -523,8 +483,7 @@ { "cell_type": "markdown", "metadata": { - "id": "ygEncJxtl-nQ", - "colab_type": "text" + "id": "ygEncJxtl-nQ" }, "source": [ "### Step 1: Load Input Data Specific to an On-device ML App\n", @@ -533,52 +492,49 @@ "\n", "The dataset has the following directory structure:\n", "\n", - "
\n",
-        "flower_photos\n",
-        "|__ daisy\n",
+        "\u003cpre\u003e\n",
+        "\u003cb\u003eflower_photos\u003c/b\u003e\n",
+        "|__ \u003cb\u003edaisy\u003c/b\u003e\n",
         "    |______ 100080576_f52e8ee070_n.jpg\n",
         "    |______ 14167534527_781ceb1b7a_n.jpg\n",
         "    |______ ...\n",
-        "|__ dandelion\n",
+        "|__ \u003cb\u003edandelion\u003c/b\u003e\n",
         "    |______ 10043234166_e6dd915111_n.jpg\n",
         "    |______ 1426682852_e62169221f_m.jpg\n",
         "    |______ ...\n",
-        "|__ roses\n",
+        "|__ \u003cb\u003eroses\u003c/b\u003e\n",
         "    |______ 102501987_3cdb8e5394_n.jpg\n",
         "    |______ 14982802401_a3dfb22afb.jpg\n",
         "    |______ ...\n",
-        "|__ sunflowers\n",
+        "|__ \u003cb\u003esunflowers\u003c/b\u003e\n",
         "    |______ 12471791574_bb1be83df4.jpg\n",
         "    |______ 15122112402_cafa41934f.jpg\n",
         "    |______ ...\n",
-        "|__ tulips\n",
+        "|__ \u003cb\u003etulips\u003c/b\u003e\n",
         "    |______ 13976522214_ccec508fe7.jpg\n",
         "    |______ 14487943607_651e8062a1_m.jpg\n",
         "    |______ ...\n",
-        "</pre>
" + "\u003c/pre\u003e" ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "7tOfUr2KlgpU", - "colab_type": "code", - "colab": {} + "id": "7tOfUr2KlgpU" }, + "outputs": [], "source": [ "image_path = tf.keras.utils.get_file(\n", " 'flower_photos',\n", " 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',\n", " untar=True)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "E051HBUM5owi", - "colab_type": "text" + "id": "E051HBUM5owi" }, "source": [ "Use `ImageClassifierDataLoader` class to load data.\n", @@ -588,22 +544,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "I_fOlZsklmlL", - "colab_type": "code", - "colab": {} + "id": "I_fOlZsklmlL" }, + "outputs": [], "source": [ "data = ImageClassifierDataLoader.from_folder(image_path)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "u501eT4koURB", - "colab_type": "text" + "id": "u501eT4koURB" }, "source": [ "Split it to training data (80%), validation data (10%, optional) and testing data (10%)." @@ -611,23 +564,20 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "cY4UU5SUobtJ", - "colab_type": "code", - "colab": {} + "id": "cY4UU5SUobtJ" }, + "outputs": [], "source": [ "train_data, rest_data = data.split(0.8)\n", "validation_data, test_data = rest_data.split(0.5)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "Z9_MYPie3EMO", - "colab_type": "text" + "id": "Z9_MYPie3EMO" }, "source": [ "Show 25 image examples with labels." @@ -635,11 +585,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "Ih4Wx44I482b", - "colab_type": "code", - "colab": {} + "id": "Ih4Wx44I482b" }, + "outputs": [], "source": [ "plt.figure(figsize=(10,10))\n", "for i, (image, label) in enumerate(data.dataset.take(25)):\n", @@ -650,15 +600,12 @@ " plt.imshow(image.numpy(), cmap=plt.cm.gray)\n", " plt.xlabel(data.index_to_label[label.numpy()])\n", "plt.show()" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "AWuoensX4vDA", - "colab_type": "text" + "id": "AWuoensX4vDA" }, "source": [ "### Step 2: Customize the TensorFlow Model\n", @@ -668,22 +615,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "TvYSUuJY3QxR", - "colab_type": "code", - "colab": {} + "id": "TvYSUuJY3QxR" }, + "outputs": [], "source": [ "model = image_classifier.create(train_data, validation_data=validation_data)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "4JFOKWnH9x8_", - "colab_type": "text" + "id": "4JFOKWnH9x8_" }, "source": [ "Have a look at the detailed model structure." 
@@ -691,22 +635,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "QNXAfjl192dC", - "colab_type": "code", - "colab": {} + "id": "QNXAfjl192dC" }, + "outputs": [], "source": [ "model.summary()" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "LP5FPk_tOxoZ", - "colab_type": "text" + "id": "LP5FPk_tOxoZ" }, "source": [ "### Step 3: Evaluate the Customized Model\n", @@ -716,22 +657,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "A8c2ZQ0J3Riy", - "colab_type": "code", - "colab": {} + "id": "A8c2ZQ0J3Riy" }, + "outputs": [], "source": [ "loss, accuracy = model.evaluate(test_data)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "6ZCrYOWoCt05", - "colab_type": "text" + "id": "6ZCrYOWoCt05" }, "source": [ "We could plot the predicted results in 100 test images. Predicted labels with red color are the wrong predicted results while others are correct." @@ -739,11 +677,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "n9O9Kx7nDQWD", - "colab_type": "code", - "colab": {} + "id": "n9O9Kx7nDQWD" }, + "outputs": [], "source": [ "# A helper function that returns 'red'/'black' depending on if its two input\n", "# parameter matches or not.\n", @@ -771,15 +709,12 @@ " ax.xaxis.label.set_color(color)\n", " plt.xlabel('Predicted: %s' % predict_label)\n", "plt.show()" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "S3H0rkbLUZAG", - "colab_type": "text" + "id": "S3H0rkbLUZAG" }, "source": [ "If the accuracy doesn't meet the app requirement, one could refer to [Advanced Usage](#scrollTo=zNDBP2qA54aK) to explore alternatives such as changing to a larger model, adjusting re-training parameters etc." @@ -788,45 +723,64 @@ { "cell_type": "markdown", "metadata": { - "id": "aeHoGAceO2xV", - "colab_type": "text" + "id": "aeHoGAceO2xV" }, "source": [ "### Step 4: Export to TensorFlow Lite Model\n", "\n", - "Convert the existing model to TensorFlow Lite model format and save the image labels in label file. The default TFLite filename is `model.tflite`, the default label filename is `label.txt`." + "Convert the existing model to TensorFlow Lite model format with [metadata](https://www.tensorflow.org/lite/convert/metadata). The default TFLite filename is `model.tflite`." ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "Im6wA9lK3TQB", - "colab_type": "code", - "colab": {} + "id": "Im6wA9lK3TQB" }, + "outputs": [], "source": [ "model.export(export_dir='.')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ROS2Ay2jMPCl", - "colab_type": "text" - }, - "source": [ - "The TensorFlow Lite model file and label file could be used in [image classification](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification) reference app.\n", - "\n", - "As for android reference app as an example, we could add `flower_classifier.tflite` and `flower_label.txt` in [assets](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/app/src/main/assets) folder. 
Meanwhile, change label filename in [code](https://github.com/tensorflow/examples/blob/master/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java#L65) and TensorFlow Lite file name in [code](https://github.com/tensorflow/examples/blob/master/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java#L60). Thus, we could run the retrained float TensorFlow Lite model on the android app.\n" ] }, { "cell_type": "markdown", "metadata": { - "id": "-4jQaxyT5_KV", - "colab_type": "text" + "id": "ROS2Ay2jMPCl" + }, + "source": [ + "See [example applications and guides of image classification](https://www.tensorflow.org/lite/models/image_classification/overview#example_applications_and_guides) for more details about how to integrate the TensorFlow Lite model into mobile apps." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "habFnvRxxQ4A" + }, + "source": [ + "The allowed export formats can be one or a list of the following:\n", + "\n", + "* `ExportFormat.TFLITE`\n", + "* `ExportFormat.LABEL`\n", + "* `ExportFormat.SAVED_MODEL`\n", + "\n", + "By default, it just exports TensorFlow Lite model with metadata. You can also selectively export different files. For instance, exporting only the label file as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BvxWsOTmKG4P" + }, + "outputs": [], + "source": [ + "model.export(export_dir='.', export_format=ExportFormat.LABEL)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-4jQaxyT5_KV" }, "source": [ "You can also evalute the tflite model with the `evaluate_tflite` method." @@ -834,22 +788,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "S1YoPX5wOK-u", - "colab_type": "code", - "colab": {} + "id": "S1YoPX5wOK-u" }, + "outputs": [], "source": [ "model.evaluate_tflite('model.tflite', test_data)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "zNDBP2qA54aK", - "colab_type": "text" + "id": "zNDBP2qA54aK" }, "source": [ "## Advanced Usage\n", @@ -871,8 +822,7 @@ { "cell_type": "markdown", "metadata": { - "id": "Gc4Jk8TvBQfm", - "colab_type": "text" + "id": "Gc4Jk8TvBQfm" }, "source": [ "## Post-training quantization on the TensorFLow Lite model\n" @@ -881,8 +831,7 @@ { "cell_type": "markdown", "metadata": { - "id": "tD8BOYrHBiDt", - "colab_type": "text" + "id": "tD8BOYrHBiDt" }, "source": [ "[Post-training quantization](https://www.tensorflow.org/lite/performance/post_training_quantization) is a conversion technique that can reduce model size and inference latency, while also improving CPU and hardware accelerator latency, with little degradation in model accuracy. Thus, it's widely used to optimize the model.\n" @@ -891,8 +840,7 @@ { "cell_type": "markdown", "metadata": { - "id": "iyIo0d5TCzE2", - "colab_type": "text" + "id": "iyIo0d5TCzE2" }, "source": [ "Model Maker supports multiple post-training quantization options. Let's take full integer quantization as an instance. First, define the quantization config to enforce enforce full integer quantization for all ops including the input and output. The input type and output type are `uint8` by default. You may also change them to other types like `int8` by setting `inference_input_type` and `inference_output_type` in config." 
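A minimal sketch of the `int8` input/output override mentioned above (not a cell from this notebook); it assumes `create_full_integer_quantization` accepts `inference_input_type` and `inference_output_type` keyword arguments taking TensorFlow dtypes, and it reuses the `model` and `test_data` objects defined in the earlier cells:

```python
import tensorflow as tf
from tflite_model_maker import configs

# Assumption: the config factory accepts tf dtypes for the inference
# input/output types, as described in the text above.
config = configs.QuantizationConfig.create_full_integer_quantization(
    representative_data=test_data,   # calibration data from the earlier split
    is_integer_only=True,
    inference_input_type=tf.int8,
    inference_output_type=tf.int8)

# Export the quantized model under an illustrative filename.
model.export(export_dir='.', tflite_filename='model_quant_int8.tflite',
             quantization_config=config)
```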
@@ -900,22 +848,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "k8hL2mstCxQl", - "colab_type": "code", - "colab": {} + "id": "k8hL2mstCxQl" }, + "outputs": [], "source": [ "config = configs.QuantizationConfig.create_full_integer_quantization(representative_data=test_data, is_integer_only=True)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "K1gzx_rmFMOA", - "colab_type": "text" + "id": "K1gzx_rmFMOA" }, "source": [ "Then we export TensorFlow Lite model with such configuration." @@ -923,22 +868,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "WTJzFQnJFMjr", - "colab_type": "code", - "colab": {} + "id": "WTJzFQnJFMjr" }, + "outputs": [], "source": [ "model.export(export_dir='.', tflite_filename='model_quant.tflite', quantization_config=config)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "Safo0e40wKZW", - "colab_type": "text" + "id": "Safo0e40wKZW" }, "source": [ "In Colab, you can download the model named `model_quant.tflite` from the left sidebar, same as the uploading part mentioned above." @@ -947,8 +889,7 @@ { "cell_type": "markdown", "metadata": { - "id": "A4kiTJtZ_sDm", - "colab_type": "text" + "id": "A4kiTJtZ_sDm" }, "source": [ "## Change the model\n" @@ -957,8 +898,7 @@ { "cell_type": "markdown", "metadata": { - "id": "794vgj6ud7Ep", - "colab_type": "text" + "id": "794vgj6ud7Ep" }, "source": [ "### Change to the model that's supported in this library.\n", @@ -970,22 +910,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "7JKsJ6-P6ae1", - "colab_type": "code", - "colab": {} + "id": "7JKsJ6-P6ae1" }, + "outputs": [], "source": [ "model = image_classifier.create(train_data, model_spec=model_spec.mobilenet_v2_spec, validation_data=validation_data)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "gm_B1Wv08AxR", - "colab_type": "text" + "id": "gm_B1Wv08AxR" }, "source": [ "Evaluate the newly retrained MobileNetV2 model to see the accuracy and loss in testing data." 
@@ -993,22 +930,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "lB2Go3HW8X7_", - "colab_type": "code", - "colab": {} + "id": "lB2Go3HW8X7_" }, + "outputs": [], "source": [ "loss, accuracy = model.evaluate(test_data)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "vAciGzVWtmWp", - "colab_type": "text" + "id": "vAciGzVWtmWp" }, "source": [ "### Change to the model in TensorFlow Hub\n", @@ -1022,24 +956,21 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "xdiMF2WMfAR4", - "colab_type": "code", - "colab": {} + "id": "xdiMF2WMfAR4" }, + "outputs": [], "source": [ "inception_v3_spec = model_spec.ImageModelSpec(\n", " uri='https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')\n", "inception_v3_spec.input_image_shape = [299, 299]" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "T_GGIoXZCs5F", - "colab_type": "text" + "id": "T_GGIoXZCs5F" }, "source": [ "Then, by setting parameter `model_spec` to `inception_v3_spec` in `create` method, we could retrain the Inception V3 model.\n", @@ -1050,8 +981,7 @@ { "cell_type": "markdown", "metadata": { - "id": "UhZ5IRKdeex3", - "colab_type": "text" + "id": "UhZ5IRKdeex3" }, "source": [ "### Change your own custom model" @@ -1060,8 +990,7 @@ { "cell_type": "markdown", "metadata": { - "id": "svTjlZhrCrcV", - "colab_type": "text" + "id": "svTjlZhrCrcV" }, "source": [ "If we'd like to use the custom model that's not in TensorFlow Hub, we should create and export [ModelSpec](https://www.tensorflow.org/hub/api_docs/python/hub/ModuleSpec) in TensorFlow Hub.\n", @@ -1072,41 +1001,47 @@ { "cell_type": "markdown", "metadata": { - "id": "4M9bn703AHt2", - "colab_type": "text" + "id": "4M9bn703AHt2" }, "source": [ "## Change the training hyperparameters\n", - "We could also change the training hyperparameters like `epochs`, `dropout_rate` and `batch_size` that could affect the model accuracy. For instance,\n", + "We could also change the training hyperparameters like `epochs`, `dropout_rate` and `batch_size` that could affect the model accuracy. The model parameters you can adjust are:\n", "\n", "\n", "* `epochs`: more epochs could achieve better accuracy until it converges but training for too many epochs may lead to overfitting.\n", - "* `dropout_rate`: avoid overfitting.\n", - "* `batch_size`: number of samples to use in one training step.\n", - "* `validation_data`: number of samples to use in one training step.\n", + "* `dropout_rate`: The rate for dropout, avoid overfitting. None by default.\n", + "* `batch_size`: number of samples to use in one training step. None by default.\n", + "* `validation_data`: Validation data. If None, skips validation process. None by default.\n", + "* `train_whole_model`: If true, the Hub module is trained together with the classification layer on top. Otherwise, only train the top classification layer. None by default.\n", + "* `learning_rate`: Base learning rate. None by default.\n", + "* `momentum`: a Python float forwarded to the optimizer. Only used when\n", + " `use_hub_library` is True. None by default.\n", + "* `shuffle`: Boolean, whether the data should be shuffled. False by default.\n", + "* `use_augmentation`: Boolean, use data augmentation for preprocessing. False by default.\n", + "* `use_hub_library`: Boolean, use `make_image_classifier_lib` from tensorflow hub to retrain the model. 
This training pipline could achieve better performance for complicated dataset with many categories. True by default. \n", + "* `warmup_steps`: Number of warmup steps for warmup schedule on learning rate. If None, the default warmup_steps is used which is the total training steps in two epochs. Only used when `use_hub_library` is False. None by default.\n", + "* `model_dir`: Optional, the location of the model checkpoint files. Only used when `use_hub_library` is False. None by default.\n", "\n", + "Parameters which are None by default like `epochs` will get the concrete default parameters in [make_image_classifier_lib](https://github.com/tensorflow/hub/blob/02ab9b7d3455e99e97abecf43c5d598a5528e20c/tensorflow_hub/tools/make_image_classifier/make_image_classifier_lib.py#L54) from TensorFlow Hub library or [train_image_classifier_lib](https://github.com/tensorflow/examples/blob/f0260433d133fd3cea4a920d1e53ecda07163aee/tensorflow_examples/lite/model_maker/core/task/train_image_classifier_lib.py#L61).\n", "\n", "For example, we could train with more epochs.\n" ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "A3k7mhH54QcK", - "colab_type": "code", - "colab": {} + "id": "A3k7mhH54QcK" }, + "outputs": [], "source": [ "model = image_classifier.create(train_data, validation_data=validation_data, epochs=10)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "VaYBQymQDsXU", - "colab_type": "text" + "id": "VaYBQymQDsXU" }, "source": [ "Evaluate the newly retrained model with 10 training epochs." @@ -1114,16 +1049,29 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "VafIYpKWD4Sw", - "colab_type": "code", - "colab": {} + "id": "VafIYpKWD4Sw" }, + "outputs": [], "source": [ "loss, accuracy = model.evaluate(test_data)" - ], - "execution_count": null, - "outputs": [] + ] } - ] -} \ No newline at end of file + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "model_maker_image_classification.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/lite/g3doc/tutorials/model_maker_question_answer.ipynb b/tensorflow/lite/g3doc/tutorials/model_maker_question_answer.ipynb index 4abee967a09..06f534522c7 100644 --- a/tensorflow/lite/g3doc/tutorials/model_maker_question_answer.ipynb +++ b/tensorflow/lite/g3doc/tutorials/model_maker_question_answer.ipynb @@ -1,25 +1,9 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "accelerator": "GPU", - "colab": { - "name": "model_maker_question_answer.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, "cells": [ { "cell_type": "markdown", "metadata": { - "id": "h2q27gKz1H20", - "colab_type": "text" + "id": "h2q27gKz1H20" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." 
@@ -27,12 +11,12 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "cellView": "form", - "id": "TUfAcER1oUS6", - "colab_type": "code", - "colab": {} + "id": "TUfAcER1oUS6" }, + "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", @@ -45,15 +29,12 @@ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "Gb7qyhNL1yWt", - "colab_type": "text" + "id": "Gb7qyhNL1yWt" }, "source": [ "# Question Answer with TensorFlow Lite Model Maker" @@ -62,31 +43,29 @@ { "cell_type": "markdown", "metadata": { - "id": "Fw5Y7snSuG51", - "colab_type": "text" + "id": "Fw5Y7snSuG51" }, "source": [ - "\n", - " \n", - " \n", - " \n", - " \n", - "
\n", - " View on TensorFlow.org\n", - " \n", - " Run in Google Colab\n", - " \n", - " View source on GitHub\n", - " \n", - " Download notebook\n", - "
" + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/tutorials/model_maker_question_answer\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/tutorials/model_maker_question_answer.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/tutorials/model_maker_question_answer.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/tutorials/model_maker_question_answer.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", + " \u003c/td\u003e\n", + "\u003c/table\u003e" ] }, { "cell_type": "markdown", "metadata": { - "id": "sr3q-gvm3cI8", - "colab_type": "text" + "id": "sr3q-gvm3cI8" }, "source": [ "The TensorFlow Lite Model Maker library simplifies the process of adapting and converting a TensorFlow model to particular input data when deploying this model for on-device ML applications.\n", @@ -97,8 +76,7 @@ { "cell_type": "markdown", "metadata": { - "id": "UxEHFTk755qw", - "colab_type": "text" + "id": "UxEHFTk755qw" }, "source": [ "# Introduction to Question Answer Task" @@ -107,18 +85,17 @@ { "cell_type": "markdown", "metadata": { - "id": "cFbKTCF25-SG", - "colab_type": "text" + "id": "cFbKTCF25-SG" }, "source": [ "The supported task in this library is extractive question answer task, which means given a passage and a question, the answer is the span in the passage. The image below shows an example for question answer.\n", "\n", "\n", - "
<p align=\"center\"><img src=\"https://storage.googleapis.com/download.tensorflow.org/models/tflite/screenshots/model_maker_squad_showcase.png\" width=\"500\"></p>
\n", + "\u003cp align=\"center\"\u003e\u003cimg src=\"https://storage.googleapis.com/download.tensorflow.org/models/tflite/screenshots/model_maker_squad_showcase.png\" width=\"500\"\u003e\u003c/p\u003e\n", "\n", - "
<p align=\"center\">
\n", - " Answers are spans in the passage (image credit: SQuAD blog) \n", - "
</p>
\n", + "\u003cp align=\"center\"\u003e\n", + " \u003cem\u003eAnswers are spans in the passage (image credit: \u003ca href=\"https://rajpurkar.github.io/mlx/qa-and-squad/\"\u003eSQuAD blog\u003c/a\u003e) \u003c/em\u003e\n", + "\u003c/p\u003e\n", "\n", "As for the model of question answer task, the inputs should be the passage and question pair that are already preprocessed, the outputs should be the start logits and end logits for each token in the passage.\n", "The size of input could be set and adjusted according to the length of passage and question." @@ -127,8 +104,7 @@ { "cell_type": "markdown", "metadata": { - "id": "gb7P4WQta8Ub", - "colab_type": "text" + "id": "gb7P4WQta8Ub" }, "source": [ "## End-to-End Overview\n" @@ -137,8 +113,7 @@ { "cell_type": "markdown", "metadata": { - "id": "w7cIHjIfbDlG", - "colab_type": "text" + "id": "w7cIHjIfbDlG" }, "source": [ "The following code snippet demonstrates how to get the model within a few lines of code. The overall process includes 5 steps: (1) choose a model, (2) load data, (3) retrain the model, (4) evaluate, and (5) export it to TensorFlow Lite format." @@ -147,8 +122,7 @@ { "cell_type": "markdown", "metadata": { - "id": "xQPdlxZBYuZG", - "colab_type": "text" + "id": "xQPdlxZBYuZG" }, "source": [ "```python\n", @@ -165,7 +139,7 @@ "# Gets the evaluation result.\n", "metric = model.evaluate(validation_data)\n", "\n", - "# Exports the model to the TensorFlow Lite format in the export directory.\n", + "# Exports the model to the TensorFlow Lite format with metadata in the export directory.\n", "model.export(export_dir)\n", "```" ] @@ -173,8 +147,7 @@ { "cell_type": "markdown", "metadata": { - "id": "exScAdvBbNEi", - "colab_type": "text" + "id": "exScAdvBbNEi" }, "source": [ "The following sections explain the code in more detail." 
@@ -183,8 +156,7 @@ { "cell_type": "markdown", "metadata": { - "id": "bcLF2PKkSbV3", - "colab_type": "text" + "id": "bcLF2PKkSbV3" }, "source": [ "## Prerequisites\n", @@ -194,21 +166,18 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "qhl8lqVamEty", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, + "id": "qhl8lqVamEty", "outputId": "f2ef33ec-ad6b-4b45-9c50-6d65118e80da" }, - "source": [ - "!pip install tflite-model-maker" - ], - "execution_count": 1, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ "Collecting tflite-model-maker\n", @@ -218,14 +187,14 @@ "Collecting tf-nightly\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/33/d4/61c47ae889b490b9c5f07f4f61bdc057c158a1a1979c375fa019d647a19e/tf_nightly-2.4.0.dev20200914-cp36-cp36m-manylinux2010_x86_64.whl (390.1MB)\n", "\u001b[K |████████████████████████████████| 390.2MB 43kB/s \n", - "\u001b[?25hRequirement already satisfied: numpy>=1.17.3 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (1.18.5)\n", + "\u001b[?25hRequirement already satisfied: numpy\u003e=1.17.3 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (1.18.5)\n", "Requirement already satisfied: pillow in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (7.0.0)\n", "Collecting tf-models-nightly\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d3/e9/c4e5a451c268a5a75a27949562364f6086f6bb33b226a065a8beceefa9ba/tf_models_nightly-2.3.0.dev20200914-py2.py3-none-any.whl (993kB)\n", "\u001b[K |████████████████████████████████| 1.0MB 57.6MB/s \n", "\u001b[?25hCollecting flatbuffers==1.12\n", " Downloading https://files.pythonhosted.org/packages/eb/26/712e578c5f14e26ae3314c39a1bdc4eb2ec2f4ddc89b708cf8e0a0d20423/flatbuffers-1.12-py2.py3-none-any.whl\n", - "Requirement already satisfied: tensorflow-hub>=0.8.0 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (0.9.0)\n", + "Requirement already satisfied: tensorflow-hub\u003e=0.8.0 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (0.9.0)\n", "Collecting fire\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/34/a7/0e22e70778aca01a52b9c899d9c145c6396d7b613719cd63db97ffa13f2f/fire-0.3.1.tar.gz (81kB)\n", "\u001b[K |████████████████████████████████| 81kB 11.5MB/s \n", @@ -235,102 +204,102 @@ "\u001b[?25hCollecting tflite-support==0.1.0rc3.dev2\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/fa/c5/5e9ee3abd5b4ef8294432cd714407f49a66befa864905b66ee8bdc612795/tflite_support-0.1.0rc3.dev2-cp36-cp36m-manylinux2010_x86_64.whl (1.0MB)\n", "\u001b[K |████████████████████████████████| 1.0MB 50.9MB/s \n", - "\u001b[?25hRequirement already satisfied: tensorflow-datasets>=2.1.0 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (2.1.0)\n", - "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from absl-py->tflite-model-maker) (1.15.0)\n", - "Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.1.0)\n", - "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (3.3.0)\n", - "Collecting tb-nightly<3.0.0a0,>=2.4.0a0\n", + "\u001b[?25hRequirement already satisfied: tensorflow-datasets\u003e=2.1.0 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (2.1.0)\n", + "Requirement already satisfied: 
six in /usr/local/lib/python3.6/dist-packages (from absl-py-\u003etflite-model-maker) (1.15.0)\n", + "Requirement already satisfied: termcolor\u003e=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (1.1.0)\n", + "Requirement already satisfied: opt-einsum\u003e=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (3.3.0)\n", + "Collecting tb-nightly\u003c3.0.0a0,\u003e=2.4.0a0\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/fc/cb/4dfe0d65bffb5e9663261ff664e6f5a2d37672b31dae27a0f14721ac00d3/tb_nightly-2.4.0a20200914-py3-none-any.whl (10.1MB)\n", "\u001b[K |████████████████████████████████| 10.1MB 51.4MB/s \n", - "\u001b[?25hRequirement already satisfied: typing-extensions>=3.7.4.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (3.7.4.3)\n", - "Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (0.35.1)\n", + "\u001b[?25hRequirement already satisfied: typing-extensions\u003e=3.7.4.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (3.7.4.3)\n", + "Requirement already satisfied: wheel\u003e=0.26 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (0.35.1)\n", "Collecting tf-estimator-nightly\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/bd/9a/3bfb9994eda11e426c809ebdf434e2ac5824a0784d980018bb53fd1620ec/tf_estimator_nightly-2.4.0.dev2020091401-py2.py3-none-any.whl (460kB)\n", "\u001b[K |████████████████████████████████| 460kB 36.0MB/s \n", - "\u001b[?25hRequirement already satisfied: google-pasta>=0.1.8 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (0.2.0)\n", - "Requirement already satisfied: h5py<2.11.0,>=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (2.10.0)\n", - "Requirement already satisfied: keras-preprocessing<1.2,>=1.1.1 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.1.2)\n", - "Requirement already satisfied: wrapt>=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.12.1)\n", - "Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.32.0)\n", - "Requirement already satisfied: protobuf>=3.9.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (3.12.4)\n", - "Requirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (0.3.3)\n", - "Requirement already satisfied: astunparse==1.6.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.6.3)\n", - "Requirement already satisfied: scipy>=0.19.1 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.4.1)\n", - "Collecting pyyaml>=5.1\n", + "\u001b[?25hRequirement already satisfied: google-pasta\u003e=0.1.8 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (0.2.0)\n", + "Requirement already satisfied: h5py\u003c2.11.0,\u003e=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (2.10.0)\n", + "Requirement already satisfied: keras-preprocessing\u003c1.2,\u003e=1.1.1 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (1.1.2)\n", + "Requirement already satisfied: 
wrapt\u003e=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (1.12.1)\n", + "Requirement already satisfied: grpcio\u003e=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (1.32.0)\n", + "Requirement already satisfied: protobuf\u003e=3.9.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (3.12.4)\n", + "Requirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (0.3.3)\n", + "Requirement already satisfied: astunparse==1.6.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (1.6.3)\n", + "Requirement already satisfied: scipy\u003e=0.19.1 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (1.4.1)\n", + "Collecting pyyaml\u003e=5.1\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)\n", "\u001b[K |████████████████████████████████| 276kB 59.8MB/s \n", - "\u001b[?25hCollecting tensorflow-model-optimization>=0.4.1\n", + "\u001b[?25hCollecting tensorflow-model-optimization\u003e=0.4.1\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/55/38/4fd48ea1bfcb0b6e36d949025200426fe9c3a8bfae029f0973d85518fa5a/tensorflow_model_optimization-0.5.0-py2.py3-none-any.whl (172kB)\n", "\u001b[K |████████████████████████████████| 174kB 51.0MB/s \n", - "\u001b[?25hRequirement already satisfied: pandas>=0.22.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.0.5)\n", - "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.7)\n", - "Requirement already satisfied: Cython in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.29.21)\n", + "\u001b[?25hRequirement already satisfied: pandas\u003e=0.22.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (1.0.5)\n", + "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (0.7)\n", + "Requirement already satisfied: Cython in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (0.29.21)\n", "Collecting opencv-python-headless\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/b6/2a/496e06fd289c01dc21b11970be1261c87ce1cc22d5340c14b516160822a7/opencv_python_headless-4.4.0.42-cp36-cp36m-manylinux2014_x86_64.whl (36.6MB)\n", "\u001b[K |████████████████████████████████| 36.6MB 83kB/s \n", - "\u001b[?25hRequirement already satisfied: kaggle>=1.3.9 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.5.8)\n", - "Requirement already satisfied: pycocotools in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (2.0.2)\n", - "Requirement already satisfied: oauth2client in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (4.1.3)\n", - "Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (3.2.2)\n", - "Collecting tf-slim>=1.1.0\n", + "\u001b[?25hRequirement already satisfied: kaggle\u003e=1.3.9 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (1.5.8)\n", 
+ "Requirement already satisfied: pycocotools in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (2.0.2)\n", + "Requirement already satisfied: oauth2client in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (4.1.3)\n", + "Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (3.2.2)\n", + "Collecting tf-slim\u003e=1.1.0\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/02/97/b0f4a64df018ca018cc035d44f2ef08f91e2e8aa67271f6f19633a015ff7/tf_slim-1.1.0-py2.py3-none-any.whl (352kB)\n", "\u001b[K |████████████████████████████████| 358kB 55.9MB/s \n", "\u001b[?25hCollecting seqeval\n", " Downloading https://files.pythonhosted.org/packages/34/91/068aca8d60ce56dd9ba4506850e876aba5e66a6f2f29aa223224b50df0de/seqeval-0.0.12.tar.gz\n", - "Requirement already satisfied: psutil>=5.4.3 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (5.4.8)\n", - "Collecting py-cpuinfo>=3.3.0\n", + "Requirement already satisfied: psutil\u003e=5.4.3 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (5.4.8)\n", + "Collecting py-cpuinfo\u003e=3.3.0\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/f6/f5/8e6e85ce2e9f6e05040cf0d4e26f43a4718bcc4bce988b433276d4b1a5c1/py-cpuinfo-7.0.0.tar.gz (95kB)\n", "\u001b[K |████████████████████████████████| 102kB 13.5MB/s \n", - "\u001b[?25hRequirement already satisfied: google-api-python-client>=1.6.7 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.7.12)\n", - "Requirement already satisfied: gin-config in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.3.0)\n", - "Requirement already satisfied: tensorflow-addons in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.8.3)\n", - "Requirement already satisfied: google-cloud-bigquery>=0.31.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.21.0)\n", - "Collecting pybind11>=2.4\n", + "\u001b[?25hRequirement already satisfied: google-api-python-client\u003e=1.6.7 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (1.7.12)\n", + "Requirement already satisfied: gin-config in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (0.3.0)\n", + "Requirement already satisfied: tensorflow-addons in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (0.8.3)\n", + "Requirement already satisfied: google-cloud-bigquery\u003e=0.31.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (1.21.0)\n", + "Collecting pybind11\u003e=2.4\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/89/e3/d576f6f02bc75bacbc3d42494e8f1d063c95617d86648dba243c2cb3963e/pybind11-2.5.0-py2.py3-none-any.whl (296kB)\n", "\u001b[K |████████████████████████████████| 296kB 47.9MB/s \n", - "\u001b[?25hRequirement already satisfied: promise in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (2.3)\n", - "Requirement already satisfied: tensorflow-metadata in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (0.24.0)\n", - "Requirement already satisfied: requests>=2.19.0 in 
/usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (2.23.0)\n", - "Requirement already satisfied: dill in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (0.3.2)\n", - "Requirement already satisfied: attrs>=18.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (20.2.0)\n", - "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (4.41.1)\n", - "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (0.16.0)\n", - "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.0.1)\n", - "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (50.3.0)\n", - "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.7.0)\n", - "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (0.4.1)\n", - "Requirement already satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.17.2)\n", - "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (3.2.2)\n", - "Requirement already satisfied: dm-tree~=0.1.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow-model-optimization>=0.4.1->tf-models-nightly->tflite-model-maker) (0.1.5)\n", - "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.22.0->tf-models-nightly->tflite-model-maker) (2018.9)\n", - "Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.22.0->tf-models-nightly->tflite-model-maker) (2.8.1)\n", - "Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (2020.6.20)\n", - "Requirement already satisfied: python-slugify in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (4.0.1)\n", - "Requirement already satisfied: slugify in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (0.0.1)\n", - "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (1.24.3)\n", - "Requirement already satisfied: pyasn1>=0.1.7 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (0.4.8)\n", - "Requirement already satisfied: rsa>=3.1.4 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (4.6)\n", - "Requirement already satisfied: httplib2>=0.9.1 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (0.17.4)\n", - "Requirement already satisfied: pyasn1-modules>=0.0.5 in /usr/local/lib/python3.6/dist-packages (from 
oauth2client->tf-models-nightly->tflite-model-maker) (0.2.8)\n", - "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (1.2.0)\n", - "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (2.4.7)\n", - "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (0.10.0)\n", - "Requirement already satisfied: Keras>=2.2.4 in /usr/local/lib/python3.6/dist-packages (from seqeval->tf-models-nightly->tflite-model-maker) (2.4.3)\n", - "Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (0.0.4)\n", - "Requirement already satisfied: uritemplate<4dev,>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (3.0.1)\n", - "Requirement already satisfied: typeguard in /usr/local/lib/python3.6/dist-packages (from tensorflow-addons->tf-models-nightly->tflite-model-maker) (2.7.1)\n", - "Requirement already satisfied: google-cloud-core<2.0dev,>=1.0.3 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-nightly->tflite-model-maker) (1.0.3)\n", - "Requirement already satisfied: google-resumable-media!=0.4.0,<0.5.0dev,>=0.3.1 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-nightly->tflite-model-maker) (0.4.1)\n", - "Requirement already satisfied: googleapis-common-protos<2,>=1.52.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-metadata->tensorflow-datasets>=2.1.0->tflite-model-maker) (1.52.0)\n", - "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->tensorflow-datasets>=2.1.0->tflite-model-maker) (3.0.4)\n", - "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->tensorflow-datasets>=2.1.0->tflite-model-maker) (2.10)\n", - "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.3.0)\n", - "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (4.1.1)\n", - "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.7.0)\n", - "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.6/dist-packages (from python-slugify->kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (1.3)\n", - "Requirement already satisfied: google-api-core<2.0.0dev,>=1.14.0 in /usr/local/lib/python3.6/dist-packages (from google-cloud-core<2.0dev,>=1.0.3->google-cloud-bigquery>=0.31.0->tf-models-nightly->tflite-model-maker) (1.16.0)\n", - "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (3.1.0)\n", - "Requirement already satisfied: 
zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (3.1.0)\n", + "\u001b[?25hRequirement already satisfied: promise in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (2.3)\n", + "Requirement already satisfied: tensorflow-metadata in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (0.24.0)\n", + "Requirement already satisfied: requests\u003e=2.19.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (2.23.0)\n", + "Requirement already satisfied: dill in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (0.3.2)\n", + "Requirement already satisfied: attrs\u003e=18.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (20.2.0)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (4.41.1)\n", + "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (0.16.0)\n", + "Requirement already satisfied: werkzeug\u003e=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (1.0.1)\n", + "Requirement already satisfied: setuptools\u003e=41.0.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (50.3.0)\n", + "Requirement already satisfied: tensorboard-plugin-wit\u003e=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (1.7.0)\n", + "Requirement already satisfied: google-auth-oauthlib\u003c0.5,\u003e=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (0.4.1)\n", + "Requirement already satisfied: google-auth\u003c2,\u003e=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (1.17.2)\n", + "Requirement already satisfied: markdown\u003e=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (3.2.2)\n", + "Requirement already satisfied: dm-tree~=0.1.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow-model-optimization\u003e=0.4.1-\u003etf-models-nightly-\u003etflite-model-maker) (0.1.5)\n", + "Requirement already satisfied: pytz\u003e=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas\u003e=0.22.0-\u003etf-models-nightly-\u003etflite-model-maker) (2018.9)\n", + "Requirement already satisfied: python-dateutil\u003e=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas\u003e=0.22.0-\u003etf-models-nightly-\u003etflite-model-maker) (2.8.1)\n", + "Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from kaggle\u003e=1.3.9-\u003etf-models-nightly-\u003etflite-model-maker) (2020.6.20)\n", + "Requirement already satisfied: python-slugify in /usr/local/lib/python3.6/dist-packages (from kaggle\u003e=1.3.9-\u003etf-models-nightly-\u003etflite-model-maker) (4.0.1)\n", + "Requirement already 
satisfied: slugify in /usr/local/lib/python3.6/dist-packages (from kaggle\u003e=1.3.9-\u003etf-models-nightly-\u003etflite-model-maker) (0.0.1)\n", + "Requirement already satisfied: urllib3\u003c1.25,\u003e=1.21.1 in /usr/local/lib/python3.6/dist-packages (from kaggle\u003e=1.3.9-\u003etf-models-nightly-\u003etflite-model-maker) (1.24.3)\n", + "Requirement already satisfied: pyasn1\u003e=0.1.7 in /usr/local/lib/python3.6/dist-packages (from oauth2client-\u003etf-models-nightly-\u003etflite-model-maker) (0.4.8)\n", + "Requirement already satisfied: rsa\u003e=3.1.4 in /usr/local/lib/python3.6/dist-packages (from oauth2client-\u003etf-models-nightly-\u003etflite-model-maker) (4.6)\n", + "Requirement already satisfied: httplib2\u003e=0.9.1 in /usr/local/lib/python3.6/dist-packages (from oauth2client-\u003etf-models-nightly-\u003etflite-model-maker) (0.17.4)\n", + "Requirement already satisfied: pyasn1-modules\u003e=0.0.5 in /usr/local/lib/python3.6/dist-packages (from oauth2client-\u003etf-models-nightly-\u003etflite-model-maker) (0.2.8)\n", + "Requirement already satisfied: kiwisolver\u003e=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib-\u003etf-models-nightly-\u003etflite-model-maker) (1.2.0)\n", + "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,\u003e=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib-\u003etf-models-nightly-\u003etflite-model-maker) (2.4.7)\n", + "Requirement already satisfied: cycler\u003e=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib-\u003etf-models-nightly-\u003etflite-model-maker) (0.10.0)\n", + "Requirement already satisfied: Keras\u003e=2.2.4 in /usr/local/lib/python3.6/dist-packages (from seqeval-\u003etf-models-nightly-\u003etflite-model-maker) (2.4.3)\n", + "Requirement already satisfied: google-auth-httplib2\u003e=0.0.3 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client\u003e=1.6.7-\u003etf-models-nightly-\u003etflite-model-maker) (0.0.4)\n", + "Requirement already satisfied: uritemplate\u003c4dev,\u003e=3.0.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client\u003e=1.6.7-\u003etf-models-nightly-\u003etflite-model-maker) (3.0.1)\n", + "Requirement already satisfied: typeguard in /usr/local/lib/python3.6/dist-packages (from tensorflow-addons-\u003etf-models-nightly-\u003etflite-model-maker) (2.7.1)\n", + "Requirement already satisfied: google-cloud-core\u003c2.0dev,\u003e=1.0.3 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery\u003e=0.31.0-\u003etf-models-nightly-\u003etflite-model-maker) (1.0.3)\n", + "Requirement already satisfied: google-resumable-media!=0.4.0,\u003c0.5.0dev,\u003e=0.3.1 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery\u003e=0.31.0-\u003etf-models-nightly-\u003etflite-model-maker) (0.4.1)\n", + "Requirement already satisfied: googleapis-common-protos\u003c2,\u003e=1.52.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-metadata-\u003etensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (1.52.0)\n", + "Requirement already satisfied: chardet\u003c4,\u003e=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests\u003e=2.19.0-\u003etensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (3.0.4)\n", + "Requirement already satisfied: idna\u003c3,\u003e=2.5 in /usr/local/lib/python3.6/dist-packages (from requests\u003e=2.19.0-\u003etensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (2.10)\n", + "Requirement already satisfied: requests-oauthlib\u003e=0.7.0 in 
/usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib\u003c0.5,\u003e=0.4.1-\u003etb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (1.3.0)\n", + "Requirement already satisfied: cachetools\u003c5.0,\u003e=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth\u003c2,\u003e=1.6.3-\u003etb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (4.1.1)\n", + "Requirement already satisfied: importlib-metadata; python_version \u003c \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown\u003e=2.6.8-\u003etb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (1.7.0)\n", + "Requirement already satisfied: text-unidecode\u003e=1.3 in /usr/local/lib/python3.6/dist-packages (from python-slugify-\u003ekaggle\u003e=1.3.9-\u003etf-models-nightly-\u003etflite-model-maker) (1.3)\n", + "Requirement already satisfied: google-api-core\u003c2.0.0dev,\u003e=1.14.0 in /usr/local/lib/python3.6/dist-packages (from google-cloud-core\u003c2.0dev,\u003e=1.0.3-\u003egoogle-cloud-bigquery\u003e=0.31.0-\u003etf-models-nightly-\u003etflite-model-maker) (1.16.0)\n", + "Requirement already satisfied: oauthlib\u003e=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib\u003e=0.7.0-\u003egoogle-auth-oauthlib\u003c0.5,\u003e=0.4.1-\u003etb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (3.1.0)\n", + "Requirement already satisfied: zipp\u003e=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version \u003c \"3.8\"-\u003emarkdown\u003e=2.6.8-\u003etb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (3.1.0)\n", "Building wheels for collected packages: fire, pyyaml, seqeval, py-cpuinfo\n", " Building wheel for fire (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Created wheel for fire: filename=fire-0.3.1-py2.py3-none-any.whl size=111005 sha256=f0b82e6b31e21d6db3591478a37188c727533acefe415b16b456c85ef9bef47c\n", @@ -350,16 +319,17 @@ " Uninstalling PyYAML-3.13:\n", " Successfully uninstalled PyYAML-3.13\n", "Successfully installed fire-0.3.1 flatbuffers-1.12 opencv-python-headless-4.4.0.42 py-cpuinfo-7.0.0 pybind11-2.5.0 pyyaml-5.3.1 sentencepiece-0.1.91 seqeval-0.0.12 tb-nightly-2.4.0a20200914 tensorflow-model-optimization-0.5.0 tf-estimator-nightly-2.4.0.dev2020091401 tf-models-nightly-2.3.0.dev20200914 tf-nightly-2.4.0.dev20200914 tf-slim-1.1.0 tflite-model-maker-0.1.2 tflite-support-0.1.0rc3.dev2\n" - ], - "name": "stdout" + ] } + ], + "source": [ + "!pip install tflite-model-maker" ] }, { "cell_type": "markdown", "metadata": { - "id": "l6lRhVK9Q_0U", - "colab_type": "text" + "id": "l6lRhVK9Q_0U" }, "source": [ "Import the required packages." 
@@ -367,11 +337,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "XtxiUeZEiXpt", - "colab_type": "code", - "colab": {} + "id": "XtxiUeZEiXpt" }, + "outputs": [], "source": [ "import numpy as np\n", "import os\n", @@ -380,18 +350,16 @@ "assert tf.__version__.startswith('2')\n", "\n", "from tflite_model_maker import configs\n", + "from tflite_model_maker import ExportFormat\n", "from tflite_model_maker import model_spec\n", "from tflite_model_maker import question_answer\n", "from tflite_model_maker import QuestionAnswerDataLoader" - ], - "execution_count": 3, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "l65ctmtW7_FF", - "colab_type": "text" + "id": "l65ctmtW7_FF" }, "source": [ "The \"End-to-End Overview\" demonstrates a simple end-to-end example. The following sections walk through the example step by step to show more detail." @@ -400,8 +368,7 @@ { "cell_type": "markdown", "metadata": { - "id": "kJ_B8fMDOhMR", - "colab_type": "text" + "id": "kJ_B8fMDOhMR" }, "source": [ "## Choose a model_spec that represents a model for question answer\n", @@ -419,22 +386,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "vEAWuZQ1PFiX", - "colab_type": "code", - "colab": {} + "id": "vEAWuZQ1PFiX" }, + "outputs": [], "source": [ "spec = model_spec.get('mobilebert_qa_squad')" - ], - "execution_count": 4, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "ygEncJxtl-nQ", - "colab_type": "text" + "id": "ygEncJxtl-nQ" }, "source": [ "## Load Input Data Specific to an On-device ML App and Preprocess the Data\n", @@ -450,15 +414,27 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "7tOfUr2KlgpU", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 85 }, + "id": "7tOfUr2KlgpU", "outputId": "a5d42181-82ea-47a2-f364-825701e4a1f8" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading data from https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-web-train-8000.json\n", + "32571392/32570663 [==============================] - 1s 0us/step\n", + "Downloading data from https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-verified-web-dev.json\n", + "1171456/1167744 [==============================] - 0s 0us/step\n" + ] + } + ], "source": [ "train_data_path = tf.keras.utils.get_file(\n", " fname='triviaqa-web-train-8000.json',\n", @@ -466,31 +442,17 @@ "validation_data_path = tf.keras.utils.get_file(\n", " fname='triviaqa-verified-web-dev.json',\n", " origin='https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-verified-web-dev.json')" - ], - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Downloading data from https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-web-train-8000.json\n", - "32571392/32570663 [==============================] - 1s 0us/step\n", - "Downloading data from https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-verified-web-dev.json\n", - "1171456/1167744 [==============================] - 0s 0us/step\n" - ], - "name": "stdout" - } ] }, { "cell_type": "markdown", "metadata": { - "id": "UfZk8GNr_1nc", - "colab_type": "text" + "id": "UfZk8GNr_1nc" }, "source": [ "You can also train the MobileBERT model with your own dataset. 
If you are running this notebook on Colab, upload your data by using the left sidebar.\n", "\n", - "\"Upload\n", + "\u003cimg src=\"https://storage.googleapis.com/download.tensorflow.org/models/tflite/screenshots/model_maker_question_answer.png\" alt=\"Upload File\" width=\"800\" hspace=\"100\"\u003e\n", "\n", "If you prefer not to upload your data to the cloud, you can also run the library offline by following the [guide](https://github.com/tensorflow/examples/tree/master/tensorflow_examples/lite/model_maker)." ] @@ -498,8 +460,7 @@ { "cell_type": "markdown", "metadata": { - "id": "E051HBUM5owi", - "colab_type": "text" + "id": "E051HBUM5owi" }, "source": [ "Use the `QuestionAnswerDataLoader.from_squad` method to load and preprocess the [SQuAD format](https://rajpurkar.github.io/SQuAD-explorer/) data according to a specific `model_spec`. You can use either SQuAD2.0 or SQuAD1.1 formats. Setting parameter `version_2_with_negative` as `True` means the formats is SQuAD2.0. Otherwise, the format is SQuAD1.1. By default, `version_2_with_negative` is `False`." @@ -507,23 +468,20 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "I_fOlZsklmlL", - "colab_type": "code", - "colab": {} + "id": "I_fOlZsklmlL" }, + "outputs": [], "source": [ "train_data = QuestionAnswerDataLoader.from_squad(train_data_path, spec, is_training=True)\n", "validation_data = QuestionAnswerDataLoader.from_squad(validation_data_path, spec, is_training=False)" - ], - "execution_count": 6, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "AWuoensX4vDA", - "colab_type": "text" + "id": "AWuoensX4vDA" }, "source": [ "## Customize the TensorFlow Model\n", @@ -536,48 +494,46 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "TvYSUuJY3QxR", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 68 }, + "id": "TvYSUuJY3QxR", "outputId": "67c38b57-3596-467b-c71b-b1a3980e68c7" }, - "source": [ - "model = question_answer.create(train_data, model_spec=spec)" - ], - "execution_count": null, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Retraining the models...\n" - ], - "name": "stdout" + ] }, { + "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Retraining the models...\n" - ], - "name": "stderr" + ] }, { + "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/2\n" - ], - "name": "stdout" + ] } + ], + "source": [ + "model = question_answer.create(train_data, model_spec=spec)" ] }, { "cell_type": "markdown", "metadata": { - "id": "0JKI-pNc8idH", - "colab_type": "text" + "id": "0JKI-pNc8idH" }, "source": [ "Have a look at the detailed model structure." 
@@ -585,22 +541,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "gd7Hs8TF8n3H", - "colab_type": "code", - "colab": {} + "id": "gd7Hs8TF8n3H" }, + "outputs": [], "source": [ "model.summary()" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "LP5FPk_tOxoZ", - "colab_type": "text" + "id": "LP5FPk_tOxoZ" }, "source": [ "## Evaluate the Customized Model\n", @@ -610,22 +563,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "A8c2ZQ0J3Riy", - "colab_type": "code", - "colab": {} + "id": "A8c2ZQ0J3Riy" }, + "outputs": [], "source": [ "model.evaluate(validation_data)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "aeHoGAceO2xV", - "colab_type": "text" + "id": "aeHoGAceO2xV" }, "source": [ "## Export to TensorFlow Lite Model\n", @@ -636,8 +586,7 @@ { "cell_type": "markdown", "metadata": { - "id": "TwA2Z2pokQJc", - "colab_type": "text" + "id": "TwA2Z2pokQJc" }, "source": [ "Since MobileBERT is too big for on-device applications, use dynamic range quantization on the model to compress MobileBERT by 4x with the minimal loss of performance. First, define the quantization configuration:" @@ -645,56 +594,75 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "1wBVTO8qkmum", - "colab_type": "code", - "colab": {} + "id": "1wBVTO8qkmum" }, + "outputs": [], "source": [ "config = configs.QuantizationConfig.create_dynamic_range_quantization(optimizations=[tf.lite.Optimize.OPTIMIZE_FOR_LATENCY])\n", "config._experimental_new_quantizer = True" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "qea2YkEGkOTH", - "colab_type": "text" + "id": "qea2YkEGkOTH" }, "source": [ - "Export the quantized TFLite model according to the quantization config and save the vocabulary to a vocab file. The default TFLite model filename is `model.tflite`, and the default vocab filename is `vocab`." + "Export the quantized TFLite model according to the quantization config with [metadata](https://www.tensorflow.org/lite/convert/metadata). The default TFLite model filename is `model.tflite`." ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "Im6wA9lK3TQB", - "colab_type": "code", - "colab": {} + "id": "Im6wA9lK3TQB" }, + "outputs": [], "source": [ "model.export(export_dir='.', quantization_config=config)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w12kvDdHJIGH", - "colab_type": "text" - }, - "source": [ - "You can use the TensorFlow Lite model file and vocab file in the [bert_qa](https://github.com/tensorflow/examples/tree/master/lite/examples/bert_qa/android) reference app by downloading it from the left sidebar on Colab." ] }, { "cell_type": "markdown", "metadata": { - "id": "HZKYthlVrTos", - "colab_type": "text" + "id": "w12kvDdHJIGH" + }, + "source": [ + "You can use the TensorFlow Lite model file in the [bert_qa](https://github.com/tensorflow/examples/tree/master/lite/examples/bert_qa/android) reference app using [BertQuestionAnswerer API](https://www.tensorflow.org/lite/inference_with_metadata/task_library/bert_question_answerer) in [TensorFlow Lite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview) by downloading it from the left sidebar on Colab." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VFnJPvq3VGh3" + }, + "source": [ + "The allowed export formats can be one or a list of the following:\n", + "\n", + "* `ExportFormat.TFLITE`\n", + "* `ExportFormat.VOCAB`\n", + "* `ExportFormat.SAVED_MODEL`\n", + "\n", + "By default, it just exports TensorFlow Lite model with metadata. You can also selectively export different files. For instance, exporting only the vocab file as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ro2hz4kXVImY" + }, + "outputs": [], + "source": [ + "model.export(export_dir='.', export_format=ExportFormat.VOCAB)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HZKYthlVrTos" }, "source": [ "You can also evalute the tflite model with the `evaluate_tflite` method. This step is expected to take a long time." @@ -702,22 +670,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "ochbq95ZrVFX", - "colab_type": "code", - "colab": {} + "id": "ochbq95ZrVFX" }, + "outputs": [], "source": [ "model.evaluate_tflite('model.tflite', validation_data)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "EoWiA_zX8rxE", - "colab_type": "text" + "id": "EoWiA_zX8rxE" }, "source": [ "## Advanced Usage\n", @@ -733,8 +698,7 @@ { "cell_type": "markdown", "metadata": { - "id": "mwtiksguDfhl", - "colab_type": "text" + "id": "mwtiksguDfhl" }, "source": [ "### Adjust the model\n", @@ -761,8 +725,7 @@ { "cell_type": "markdown", "metadata": { - "id": "cAOd5_bzH9AQ", - "colab_type": "text" + "id": "cAOd5_bzH9AQ" }, "source": [ "For example, you can train the model with a longer sequence length. If you change the model, you must first construct a new `model_spec`." @@ -770,23 +733,20 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "e9WBN0UTQoMN", - "colab_type": "code", - "colab": {} + "id": "e9WBN0UTQoMN" }, + "outputs": [], "source": [ "new_spec = model_spec.get('mobilebert_qa')\n", "new_spec.seq_len = 512" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "6LSTdghTP0Cv", - "colab_type": "text" + "id": "6LSTdghTP0Cv" }, "source": [ "The remaining steps are the same. Note that you must rerun both the `dataloader` and `create` parts as different model specs may have different preprocessing steps.\n" @@ -795,8 +755,7 @@ { "cell_type": "markdown", "metadata": { - "id": "LvQuy7RSDir3", - "colab_type": "text" + "id": "LvQuy7RSDir3" }, "source": [ "### Tune training hyperparameters\n", @@ -815,8 +774,7 @@ { "cell_type": "markdown", "metadata": { - "id": "Eq6B9lKMfhS6", - "colab_type": "text" + "id": "Eq6B9lKMfhS6" }, "source": [ "### Change the Model Architecture\n", @@ -831,12 +789,26 @@ { "cell_type": "markdown", "metadata": { - "id": "L2d7yycrgu6L", - "colab_type": "text" + "id": "L2d7yycrgu6L" }, "source": [ "The remaining steps are the same." 
] } - ] -} \ No newline at end of file + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "model_maker_question_answer.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb b/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb index 36ac1e0e4c2..43168e394f5 100644 --- a/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb +++ b/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb @@ -1,25 +1,9 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "accelerator": "GPU", - "colab": { - "name": "model_maker_text_classification.ipynb", - "provenance": [], - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, "cells": [ { "cell_type": "markdown", "metadata": { - "id": "h2q27gKz1H20", - "colab_type": "text" + "id": "h2q27gKz1H20" }, "source": [ "##### Copyright 2019 The TensorFlow Authors." @@ -27,12 +11,12 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "cellView": "form", - "id": "TUfAcER1oUS6", - "colab_type": "code", - "colab": {} + "id": "TUfAcER1oUS6" }, + "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", @@ -45,15 +29,12 @@ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "Gb7qyhNL1yWt", - "colab_type": "text" + "id": "Gb7qyhNL1yWt" }, "source": [ "# Text classification with TensorFlow Lite Model Maker" @@ -62,31 +43,29 @@ { "cell_type": "markdown", "metadata": { - "id": "Fw5Y7snSuG51", - "colab_type": "text" + "id": "Fw5Y7snSuG51" }, "source": [ - "\n", - " \n", - " \n", - " \n", - " \n", - "
\n", - " View on TensorFlow.org\n", - " \n", - " Run in Google Colab\n", - " \n", - " View source on GitHub\n", - " \n", - " Download notebook\n", - "
" + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/tutorials/model_maker_text_classification\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", + " \u003c/td\u003e\n", + "\u003c/table\u003e" ] }, { "cell_type": "markdown", "metadata": { - "id": "sr3q-gvm3cI8", - "colab_type": "text" + "id": "sr3q-gvm3cI8" }, "source": [ "The TensorFlow Lite Model Maker library simplifies the process of adapting and converting a TensorFlow model to particular input data when deploying this model for on-device ML applications.\n", @@ -97,8 +76,7 @@ { "cell_type": "markdown", "metadata": { - "id": "bcLF2PKkSbV3", - "colab_type": "text" + "id": "bcLF2PKkSbV3" }, "source": [ "## Prerequisites\n" @@ -107,8 +85,7 @@ { "cell_type": "markdown", "metadata": { - "id": "2vvAObmTqglq", - "colab_type": "text" + "id": "2vvAObmTqglq" }, "source": [ "### Install the required packages\n", @@ -117,27 +94,24 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "qhl8lqVamEty", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, + "id": "qhl8lqVamEty", "outputId": "9757d141-eb26-46e2-fc48-376e0e244142" }, - "source": [ - "!pip install tflite-model-maker" - ], - "execution_count": null, "outputs": [ { + "name": "stdout", "output_type": "stream", "text": [ "Collecting tflite-model-maker\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/13/bc/4c23b9cb9ef612a1f48bac5543bd531665de5eab8f8231111aac067f8c30/tflite_model_maker-0.1.2-py3-none-any.whl (104kB)\n", "\r\u001b[K |███▏ | 10kB 20.4MB/s eta 0:00:01\r\u001b[K |██████▎ | 20kB 5.8MB/s eta 0:00:01\r\u001b[K |█████████▍ | 30kB 7.1MB/s eta 0:00:01\r\u001b[K |████████████▋ | 40kB 7.7MB/s eta 0:00:01\r\u001b[K |███████████████▊ | 51kB 6.9MB/s eta 0:00:01\r\u001b[K |██████████████████▉ | 61kB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████ | 71kB 8.0MB/s eta 0:00:01\r\u001b[K |█████████████████████████▏ | 81kB 8.4MB/s eta 0:00:01\r\u001b[K |████████████████████████████▎ | 92kB 7.9MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▌| 102kB 8.2MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 112kB 8.2MB/s \n", - "\u001b[?25hRequirement already satisfied: tensorflow-hub>=0.8.0 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (0.9.0)\n", + "\u001b[?25hRequirement already satisfied: 
tensorflow-hub\u003e=0.8.0 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (0.9.0)\n", "Collecting fire\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/34/a7/0e22e70778aca01a52b9c899d9c145c6396d7b613719cd63db97ffa13f2f/fire-0.3.1.tar.gz (81kB)\n", "\u001b[K |████████████████████████████████| 81kB 7.7MB/s \n", @@ -149,111 +123,111 @@ "\u001b[?25hCollecting sentencepiece\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)\n", "\u001b[K |████████████████████████████████| 1.1MB 31.6MB/s \n", - "\u001b[?25hRequirement already satisfied: numpy>=1.17.3 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (1.18.5)\n", + "\u001b[?25hRequirement already satisfied: numpy\u003e=1.17.3 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (1.18.5)\n", "Collecting tf-nightly\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/33/d4/61c47ae889b490b9c5f07f4f61bdc057c158a1a1979c375fa019d647a19e/tf_nightly-2.4.0.dev20200914-cp36-cp36m-manylinux2010_x86_64.whl (390.1MB)\n", "\u001b[K |████████████████████████████████| 390.2MB 46kB/s \n", "\u001b[?25hRequirement already satisfied: pillow in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (7.0.0)\n", "Requirement already satisfied: absl-py in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (0.10.0)\n", - "Requirement already satisfied: tensorflow-datasets>=2.1.0 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (2.1.0)\n", + "Requirement already satisfied: tensorflow-datasets\u003e=2.1.0 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (2.1.0)\n", "Collecting tflite-support==0.1.0rc3.dev2\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/fa/c5/5e9ee3abd5b4ef8294432cd714407f49a66befa864905b66ee8bdc612795/tflite_support-0.1.0rc3.dev2-cp36-cp36m-manylinux2010_x86_64.whl (1.0MB)\n", "\u001b[K |████████████████████████████████| 1.0MB 50.0MB/s \n", - "\u001b[?25hRequirement already satisfied: protobuf>=3.8.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub>=0.8.0->tflite-model-maker) (3.12.4)\n", - "Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub>=0.8.0->tflite-model-maker) (1.15.0)\n", - "Requirement already satisfied: termcolor in /usr/local/lib/python3.6/dist-packages (from fire->tflite-model-maker) (1.1.0)\n", - "Requirement already satisfied: pycocotools in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (2.0.2)\n", - "Collecting tensorflow-model-optimization>=0.4.1\n", + "\u001b[?25hRequirement already satisfied: protobuf\u003e=3.8.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub\u003e=0.8.0-\u003etflite-model-maker) (3.12.4)\n", + "Requirement already satisfied: six\u003e=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub\u003e=0.8.0-\u003etflite-model-maker) (1.15.0)\n", + "Requirement already satisfied: termcolor in /usr/local/lib/python3.6/dist-packages (from fire-\u003etflite-model-maker) (1.1.0)\n", + "Requirement already satisfied: pycocotools in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (2.0.2)\n", + "Collecting tensorflow-model-optimization\u003e=0.4.1\n", "\u001b[?25l Downloading 
https://files.pythonhosted.org/packages/55/38/4fd48ea1bfcb0b6e36d949025200426fe9c3a8bfae029f0973d85518fa5a/tensorflow_model_optimization-0.5.0-py2.py3-none-any.whl (172kB)\n", "\u001b[K |████████████████████████████████| 174kB 57.7MB/s \n", - "\u001b[?25hCollecting tf-slim>=1.1.0\n", + "\u001b[?25hCollecting tf-slim\u003e=1.1.0\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/02/97/b0f4a64df018ca018cc035d44f2ef08f91e2e8aa67271f6f19633a015ff7/tf_slim-1.1.0-py2.py3-none-any.whl (352kB)\n", "\u001b[K |████████████████████████████████| 358kB 54.9MB/s \n", - "\u001b[?25hRequirement already satisfied: tensorflow-addons in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.8.3)\n", - "Requirement already satisfied: kaggle>=1.3.9 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.5.8)\n", - "Requirement already satisfied: oauth2client in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (4.1.3)\n", + "\u001b[?25hRequirement already satisfied: tensorflow-addons in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (0.8.3)\n", + "Requirement already satisfied: kaggle\u003e=1.3.9 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (1.5.8)\n", + "Requirement already satisfied: oauth2client in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (4.1.3)\n", "Collecting seqeval\n", " Downloading https://files.pythonhosted.org/packages/34/91/068aca8d60ce56dd9ba4506850e876aba5e66a6f2f29aa223224b50df0de/seqeval-0.0.12.tar.gz\n", - "Requirement already satisfied: scipy>=0.19.1 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.4.1)\n", - "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.7)\n", - "Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (3.2.2)\n", - "Collecting pyyaml>=5.1\n", + "Requirement already satisfied: scipy\u003e=0.19.1 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (1.4.1)\n", + "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (0.7)\n", + "Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (3.2.2)\n", + "Collecting pyyaml\u003e=5.1\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)\n", "\u001b[K |████████████████████████████████| 276kB 55.1MB/s \n", - "\u001b[?25hRequirement already satisfied: Cython in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.29.21)\n", - "Requirement already satisfied: google-cloud-bigquery>=0.31.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.21.0)\n", + "\u001b[?25hRequirement already satisfied: Cython in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (0.29.21)\n", + "Requirement already satisfied: google-cloud-bigquery\u003e=0.31.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (1.21.0)\n", "Collecting opencv-python-headless\n", 
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/b6/2a/496e06fd289c01dc21b11970be1261c87ce1cc22d5340c14b516160822a7/opencv_python_headless-4.4.0.42-cp36-cp36m-manylinux2014_x86_64.whl (36.6MB)\n", "\u001b[K |████████████████████████████████| 36.6MB 88kB/s \n", - "\u001b[?25hRequirement already satisfied: psutil>=5.4.3 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (5.4.8)\n", - "Requirement already satisfied: gin-config in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.3.0)\n", - "Requirement already satisfied: google-api-python-client>=1.6.7 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.7.12)\n", - "Requirement already satisfied: pandas>=0.22.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.0.5)\n", - "Collecting py-cpuinfo>=3.3.0\n", + "\u001b[?25hRequirement already satisfied: psutil\u003e=5.4.3 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (5.4.8)\n", + "Requirement already satisfied: gin-config in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (0.3.0)\n", + "Requirement already satisfied: google-api-python-client\u003e=1.6.7 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (1.7.12)\n", + "Requirement already satisfied: pandas\u003e=0.22.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly-\u003etflite-model-maker) (1.0.5)\n", + "Collecting py-cpuinfo\u003e=3.3.0\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/f6/f5/8e6e85ce2e9f6e05040cf0d4e26f43a4718bcc4bce988b433276d4b1a5c1/py-cpuinfo-7.0.0.tar.gz (95kB)\n", "\u001b[K |████████████████████████████████| 102kB 11.1MB/s \n", - "\u001b[?25hRequirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.32.0)\n", - "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (3.3.0)\n", - "Requirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (0.3.3)\n", - "Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (0.35.1)\n", - "Requirement already satisfied: astunparse==1.6.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.6.3)\n", - "Requirement already satisfied: h5py<2.11.0,>=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (2.10.0)\n", - "Requirement already satisfied: typing-extensions>=3.7.4.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (3.7.4.3)\n", - "Collecting tb-nightly<3.0.0a0,>=2.4.0a0\n", + "\u001b[?25hRequirement already satisfied: grpcio\u003e=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (1.32.0)\n", + "Requirement already satisfied: opt-einsum\u003e=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (3.3.0)\n", + "Requirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (0.3.3)\n", + "Requirement already satisfied: wheel\u003e=0.26 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (0.35.1)\n", + "Requirement already 
satisfied: astunparse==1.6.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (1.6.3)\n", + "Requirement already satisfied: h5py\u003c2.11.0,\u003e=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (2.10.0)\n", + "Requirement already satisfied: typing-extensions\u003e=3.7.4.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (3.7.4.3)\n", + "Collecting tb-nightly\u003c3.0.0a0,\u003e=2.4.0a0\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/fc/cb/4dfe0d65bffb5e9663261ff664e6f5a2d37672b31dae27a0f14721ac00d3/tb_nightly-2.4.0a20200914-py3-none-any.whl (10.1MB)\n", "\u001b[K |████████████████████████████████| 10.1MB 46.1MB/s \n", - "\u001b[?25hRequirement already satisfied: google-pasta>=0.1.8 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (0.2.0)\n", + "\u001b[?25hRequirement already satisfied: google-pasta\u003e=0.1.8 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (0.2.0)\n", "Collecting tf-estimator-nightly\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/bd/9a/3bfb9994eda11e426c809ebdf434e2ac5824a0784d980018bb53fd1620ec/tf_estimator_nightly-2.4.0.dev2020091401-py2.py3-none-any.whl (460kB)\n", "\u001b[K |████████████████████████████████| 460kB 51.7MB/s \n", - "\u001b[?25hRequirement already satisfied: keras-preprocessing<1.2,>=1.1.1 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.1.2)\n", - "Requirement already satisfied: wrapt>=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.12.1)\n", - "Requirement already satisfied: tensorflow-metadata in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (0.24.0)\n", - "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (0.16.0)\n", - "Requirement already satisfied: promise in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (2.3)\n", - "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (2.23.0)\n", - "Requirement already satisfied: attrs>=18.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (20.2.0)\n", - "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (4.41.1)\n", - "Requirement already satisfied: dill in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (0.3.2)\n", - "Collecting pybind11>=2.4\n", + "\u001b[?25hRequirement already satisfied: keras-preprocessing\u003c1.2,\u003e=1.1.1 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (1.1.2)\n", + "Requirement already satisfied: wrapt\u003e=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tf-nightly-\u003etflite-model-maker) (1.12.1)\n", + "Requirement already satisfied: tensorflow-metadata in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (0.24.0)\n", + "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (0.16.0)\n", + "Requirement already satisfied: promise in 
/usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (2.3)\n", + "Requirement already satisfied: requests\u003e=2.19.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (2.23.0)\n", + "Requirement already satisfied: attrs\u003e=18.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (20.2.0)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (4.41.1)\n", + "Requirement already satisfied: dill in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (0.3.2)\n", + "Collecting pybind11\u003e=2.4\n", "\u001b[?25l Downloading https://files.pythonhosted.org/packages/89/e3/d576f6f02bc75bacbc3d42494e8f1d063c95617d86648dba243c2cb3963e/pybind11-2.5.0-py2.py3-none-any.whl (296kB)\n", "\u001b[K |████████████████████████████████| 296kB 55.2MB/s \n", - "\u001b[?25hRequirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.8.0->tensorflow-hub>=0.8.0->tflite-model-maker) (50.3.0)\n", - "Requirement already satisfied: dm-tree~=0.1.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow-model-optimization>=0.4.1->tf-models-nightly->tflite-model-maker) (0.1.5)\n", - "Requirement already satisfied: typeguard in /usr/local/lib/python3.6/dist-packages (from tensorflow-addons->tf-models-nightly->tflite-model-maker) (2.7.1)\n", - "Requirement already satisfied: python-slugify in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (4.0.1)\n", - "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (1.24.3)\n", - "Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (2020.6.20)\n", - "Requirement already satisfied: slugify in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (0.0.1)\n", - "Requirement already satisfied: python-dateutil in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (2.8.1)\n", - "Requirement already satisfied: rsa>=3.1.4 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (4.6)\n", - "Requirement already satisfied: pyasn1>=0.1.7 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (0.4.8)\n", - "Requirement already satisfied: pyasn1-modules>=0.0.5 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (0.2.8)\n", - "Requirement already satisfied: httplib2>=0.9.1 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (0.17.4)\n", - "Requirement already satisfied: Keras>=2.2.4 in /usr/local/lib/python3.6/dist-packages (from seqeval->tf-models-nightly->tflite-model-maker) (2.4.3)\n", - "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (0.10.0)\n", - "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (1.2.0)\n", - "Requirement already 
satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (2.4.7)\n", - "Requirement already satisfied: google-resumable-media!=0.4.0,<0.5.0dev,>=0.3.1 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-nightly->tflite-model-maker) (0.4.1)\n", - "Requirement already satisfied: google-cloud-core<2.0dev,>=1.0.3 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-nightly->tflite-model-maker) (1.0.3)\n", - "Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (0.0.4)\n", - "Requirement already satisfied: uritemplate<4dev,>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (3.0.1)\n", - "Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (1.17.2)\n", - "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.22.0->tf-models-nightly->tflite-model-maker) (2018.9)\n", - "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (3.2.2)\n", - "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (0.4.1)\n", - "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.7.0)\n", - "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.0.1)\n", - "Requirement already satisfied: googleapis-common-protos<2,>=1.52.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-metadata->tensorflow-datasets>=2.1.0->tflite-model-maker) (1.52.0)\n", - "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->tensorflow-datasets>=2.1.0->tflite-model-maker) (2.10)\n", - "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->tensorflow-datasets>=2.1.0->tflite-model-maker) (3.0.4)\n", - "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.6/dist-packages (from python-slugify->kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (1.3)\n", - "Requirement already satisfied: google-api-core<2.0.0dev,>=1.14.0 in /usr/local/lib/python3.6/dist-packages (from google-cloud-core<2.0dev,>=1.0.3->google-cloud-bigquery>=0.31.0->tf-models-nightly->tflite-model-maker) (1.16.0)\n", - "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth>=1.4.1->google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (4.1.1)\n", - "Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.7.0)\n", - "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from 
google-auth-oauthlib<0.5,>=0.4.1->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.3.0)\n", - "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (3.1.0)\n", - "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (3.1.0)\n", + "\u001b[?25hRequirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf\u003e=3.8.0-\u003etensorflow-hub\u003e=0.8.0-\u003etflite-model-maker) (50.3.0)\n", + "Requirement already satisfied: dm-tree~=0.1.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow-model-optimization\u003e=0.4.1-\u003etf-models-nightly-\u003etflite-model-maker) (0.1.5)\n", + "Requirement already satisfied: typeguard in /usr/local/lib/python3.6/dist-packages (from tensorflow-addons-\u003etf-models-nightly-\u003etflite-model-maker) (2.7.1)\n", + "Requirement already satisfied: python-slugify in /usr/local/lib/python3.6/dist-packages (from kaggle\u003e=1.3.9-\u003etf-models-nightly-\u003etflite-model-maker) (4.0.1)\n", + "Requirement already satisfied: urllib3\u003c1.25,\u003e=1.21.1 in /usr/local/lib/python3.6/dist-packages (from kaggle\u003e=1.3.9-\u003etf-models-nightly-\u003etflite-model-maker) (1.24.3)\n", + "Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from kaggle\u003e=1.3.9-\u003etf-models-nightly-\u003etflite-model-maker) (2020.6.20)\n", + "Requirement already satisfied: slugify in /usr/local/lib/python3.6/dist-packages (from kaggle\u003e=1.3.9-\u003etf-models-nightly-\u003etflite-model-maker) (0.0.1)\n", + "Requirement already satisfied: python-dateutil in /usr/local/lib/python3.6/dist-packages (from kaggle\u003e=1.3.9-\u003etf-models-nightly-\u003etflite-model-maker) (2.8.1)\n", + "Requirement already satisfied: rsa\u003e=3.1.4 in /usr/local/lib/python3.6/dist-packages (from oauth2client-\u003etf-models-nightly-\u003etflite-model-maker) (4.6)\n", + "Requirement already satisfied: pyasn1\u003e=0.1.7 in /usr/local/lib/python3.6/dist-packages (from oauth2client-\u003etf-models-nightly-\u003etflite-model-maker) (0.4.8)\n", + "Requirement already satisfied: pyasn1-modules\u003e=0.0.5 in /usr/local/lib/python3.6/dist-packages (from oauth2client-\u003etf-models-nightly-\u003etflite-model-maker) (0.2.8)\n", + "Requirement already satisfied: httplib2\u003e=0.9.1 in /usr/local/lib/python3.6/dist-packages (from oauth2client-\u003etf-models-nightly-\u003etflite-model-maker) (0.17.4)\n", + "Requirement already satisfied: Keras\u003e=2.2.4 in /usr/local/lib/python3.6/dist-packages (from seqeval-\u003etf-models-nightly-\u003etflite-model-maker) (2.4.3)\n", + "Requirement already satisfied: cycler\u003e=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib-\u003etf-models-nightly-\u003etflite-model-maker) (0.10.0)\n", + "Requirement already satisfied: kiwisolver\u003e=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib-\u003etf-models-nightly-\u003etflite-model-maker) (1.2.0)\n", + "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,\u003e=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib-\u003etf-models-nightly-\u003etflite-model-maker) (2.4.7)\n", + "Requirement already satisfied: 
google-resumable-media!=0.4.0,\u003c0.5.0dev,\u003e=0.3.1 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery\u003e=0.31.0-\u003etf-models-nightly-\u003etflite-model-maker) (0.4.1)\n", + "Requirement already satisfied: google-cloud-core\u003c2.0dev,\u003e=1.0.3 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery\u003e=0.31.0-\u003etf-models-nightly-\u003etflite-model-maker) (1.0.3)\n", + "Requirement already satisfied: google-auth-httplib2\u003e=0.0.3 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client\u003e=1.6.7-\u003etf-models-nightly-\u003etflite-model-maker) (0.0.4)\n", + "Requirement already satisfied: uritemplate\u003c4dev,\u003e=3.0.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client\u003e=1.6.7-\u003etf-models-nightly-\u003etflite-model-maker) (3.0.1)\n", + "Requirement already satisfied: google-auth\u003e=1.4.1 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client\u003e=1.6.7-\u003etf-models-nightly-\u003etflite-model-maker) (1.17.2)\n", + "Requirement already satisfied: pytz\u003e=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas\u003e=0.22.0-\u003etf-models-nightly-\u003etflite-model-maker) (2018.9)\n", + "Requirement already satisfied: markdown\u003e=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (3.2.2)\n", + "Requirement already satisfied: google-auth-oauthlib\u003c0.5,\u003e=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (0.4.1)\n", + "Requirement already satisfied: tensorboard-plugin-wit\u003e=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (1.7.0)\n", + "Requirement already satisfied: werkzeug\u003e=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (1.0.1)\n", + "Requirement already satisfied: googleapis-common-protos\u003c2,\u003e=1.52.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-metadata-\u003etensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (1.52.0)\n", + "Requirement already satisfied: idna\u003c3,\u003e=2.5 in /usr/local/lib/python3.6/dist-packages (from requests\u003e=2.19.0-\u003etensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (2.10)\n", + "Requirement already satisfied: chardet\u003c4,\u003e=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests\u003e=2.19.0-\u003etensorflow-datasets\u003e=2.1.0-\u003etflite-model-maker) (3.0.4)\n", + "Requirement already satisfied: text-unidecode\u003e=1.3 in /usr/local/lib/python3.6/dist-packages (from python-slugify-\u003ekaggle\u003e=1.3.9-\u003etf-models-nightly-\u003etflite-model-maker) (1.3)\n", + "Requirement already satisfied: google-api-core\u003c2.0.0dev,\u003e=1.14.0 in /usr/local/lib/python3.6/dist-packages (from google-cloud-core\u003c2.0dev,\u003e=1.0.3-\u003egoogle-cloud-bigquery\u003e=0.31.0-\u003etf-models-nightly-\u003etflite-model-maker) (1.16.0)\n", + "Requirement already satisfied: cachetools\u003c5.0,\u003e=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth\u003e=1.4.1-\u003egoogle-api-python-client\u003e=1.6.7-\u003etf-models-nightly-\u003etflite-model-maker) (4.1.1)\n", + "Requirement already satisfied: importlib-metadata; python_version \u003c \"3.8\" in /usr/local/lib/python3.6/dist-packages 
(from markdown\u003e=2.6.8-\u003etb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (1.7.0)\n", + "Requirement already satisfied: requests-oauthlib\u003e=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib\u003c0.5,\u003e=0.4.1-\u003etb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (1.3.0)\n", + "Requirement already satisfied: zipp\u003e=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version \u003c \"3.8\"-\u003emarkdown\u003e=2.6.8-\u003etb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (3.1.0)\n", + "Requirement already satisfied: oauthlib\u003e=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib\u003e=0.7.0-\u003egoogle-auth-oauthlib\u003c0.5,\u003e=0.4.1-\u003etb-nightly\u003c3.0.0a0,\u003e=2.4.0a0-\u003etf-nightly-\u003etflite-model-maker) (3.1.0)\n", "Building wheels for collected packages: fire, seqeval, pyyaml, py-cpuinfo\n", " Building wheel for fire (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Created wheel for fire: filename=fire-0.3.1-py2.py3-none-any.whl size=111005 sha256=9eaa2d36e17621d136f8ab1707a5a4e8994c53d5076a9edde21aab7696ba3e09\n", @@ -273,16 +247,17 @@ " Uninstalling PyYAML-3.13:\n", " Successfully uninstalled PyYAML-3.13\n", "Successfully installed fire-0.3.1 flatbuffers-1.12 opencv-python-headless-4.4.0.42 py-cpuinfo-7.0.0 pybind11-2.5.0 pyyaml-5.3.1 sentencepiece-0.1.91 seqeval-0.0.12 tb-nightly-2.4.0a20200914 tensorflow-model-optimization-0.5.0 tf-estimator-nightly-2.4.0.dev2020091401 tf-models-nightly-2.3.0.dev20200914 tf-nightly-2.4.0.dev20200914 tf-slim-1.1.0 tflite-model-maker-0.1.2 tflite-support-0.1.0rc3.dev2\n" - ], - "name": "stdout" + ] } + ], + "source": [ + "!pip install tflite-model-maker" ] }, { "cell_type": "markdown", "metadata": { - "id": "l6lRhVK9Q_0U", - "colab_type": "text" + "id": "l6lRhVK9Q_0U" }, "source": [ "Import the required packages." 
@@ -290,11 +265,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "XtxiUeZEiXpt", - "colab_type": "code", - "colab": {} + "id": "XtxiUeZEiXpt" }, + "outputs": [], "source": [ "import numpy as np\n", "import os\n", @@ -303,18 +278,16 @@ "assert tf.__version__.startswith('2')\n", "\n", "from tflite_model_maker import configs\n", + "from tflite_model_maker import ExportFormat\n", "from tflite_model_maker import model_spec\n", "from tflite_model_maker import text_classifier\n", "from tflite_model_maker import TextClassifierDataLoader" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "BRd13bfetO7B", - "colab_type": "text" + "id": "BRd13bfetO7B" }, "source": [ "### Get the data path\n", @@ -323,51 +296,48 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "R2BSkxWg6Rhx", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, + "id": "R2BSkxWg6Rhx", "outputId": "972b735a-d5f6-4152-9c0f-12b7b97b8d86" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading data from https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media\u0026token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8\n", + "7446528/7439277 [==============================] - 0s 0us/step\n" + ] + } + ], "source": [ "data_dir = tf.keras.utils.get_file(\n", " fname='SST-2.zip',\n", - " origin='https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',\n", + " origin='https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media\u0026token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',\n", " extract=True)\n", "data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Downloading data from https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8\n", - "7446528/7439277 [==============================] - 0s 0us/step\n" - ], - "name": "stdout" - } ] }, { "cell_type": "markdown", "metadata": { - "id": "6MSCjPAvs2EQ", - "colab_type": "text" + "id": "6MSCjPAvs2EQ" }, "source": [ "You can also upload your own dataset to work through this tutorial. Upload your dataset by using the left sidebar in Colab.\n", "\n", - "\"Upload\n" + "\u003cimg src=\"https://storage.googleapis.com/download.tensorflow.org/models/tflite/screenshots/model_maker_text_classification.png\" alt=\"Upload File\" width=\"800\" hspace=\"100\"\u003e\n" ] }, { "cell_type": "markdown", "metadata": { - "id": "uO5egTlrtWxm", - "colab_type": "text" + "id": "uO5egTlrtWxm" }, "source": [ "If you prefer not to upload your dataset to the cloud, you can also locally run the library by following the [guide](https://github.com/tensorflow/examples/tree/master/tensorflow_examples/lite/model_maker)." 
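> Editorial note: the cells above download SST-2, but the notebook also invites readers to upload their own dataset. As an illustration only (not part of the notebook), here is a minimal sketch of pointing the Model Maker loader at an uploaded file; the file name, column names, and delimiter are hypothetical placeholders, and the keyword names should be checked against the `TextClassifierDataLoader.from_csv` API.

```python
# Hedged sketch: load a custom dataset uploaded to Colab via the left sidebar.
# The path and the column names below are illustrative assumptions.
from tflite_model_maker import model_spec
from tflite_model_maker import TextClassifierDataLoader

spec = model_spec.get('average_word_vec')

custom_train_data = TextClassifierDataLoader.from_csv(
    filename='my_reviews_train.csv',   # hypothetical uploaded file
    text_column='review_text',         # hypothetical text column name
    label_column='sentiment',          # hypothetical label column name
    model_spec=spec,
    delimiter=',',
    is_training=True)
```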
@@ -376,8 +346,7 @@ { "cell_type": "markdown", "metadata": { - "id": "xushUyZXqP59", - "colab_type": "text" + "id": "xushUyZXqP59" }, "source": [ "## End-to-End Workflow" @@ -386,8 +355,7 @@ { "cell_type": "markdown", "metadata": { - "id": "WlKU3SMX6TnB", - "colab_type": "text" + "id": "WlKU3SMX6TnB" }, "source": [ "This workflow consists of five steps as outlined below:" @@ -396,8 +364,7 @@ { "cell_type": "markdown", "metadata": { - "id": "PBPUIhEjMjTR", - "colab_type": "text" + "id": "PBPUIhEjMjTR" }, "source": [ "Step 1. Choose a model specification that represents a text classification model.\n", @@ -407,22 +374,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "CtdZ-JDwMimd", - "colab_type": "code", - "colab": {} + "id": "CtdZ-JDwMimd" }, + "outputs": [], "source": [ "spec = model_spec.get('mobilebert_classifier')" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "s5U-A3tw6Y27", - "colab_type": "text" + "id": "s5U-A3tw6Y27" }, "source": [ "Step 2. Load train and test data specific to an on-device ML app and preprocess the data according to a specific `model_spec`." @@ -430,11 +394,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "HD5BvzWe6YKa", - "colab_type": "code", - "colab": {} + "id": "HD5BvzWe6YKa" }, + "outputs": [], "source": [ "train_data = TextClassifierDataLoader.from_csv(\n", " filename=os.path.join(os.path.join(data_dir, 'train.tsv')),\n", @@ -450,15 +414,12 @@ " model_spec=spec,\n", " delimiter='\\t',\n", " is_training=False)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "2uZkLR6N6gDR", - "colab_type": "text" + "id": "2uZkLR6N6gDR" }, "source": [ "Step 3. Customize the TensorFlow model." @@ -466,22 +427,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "kwlYdTcg63xy", - "colab_type": "code", - "colab": {} + "id": "kwlYdTcg63xy" }, + "outputs": [], "source": [ "model = text_classifier.create(train_data, model_spec=spec)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "-BzCHLWJ6h7q", - "colab_type": "text" + "id": "-BzCHLWJ6h7q" }, "source": [ "Step 4. Evaluate the model." @@ -489,73 +447,64 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "8xmnl6Yy7ARn", - "colab_type": "code", - "colab": {} + "id": "8xmnl6Yy7ARn" }, + "outputs": [], "source": [ "loss, acc = model.evaluate(test_data)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "CgCDMe0e6jlT", - "colab_type": "text" + "id": "CgCDMe0e6jlT" }, "source": [ - "Step 5. Export as a TensorFlow Lite model.\n", + "Step 5. Export as a TensorFlow Lite model with [metadata](https://www.tensorflow.org/lite/convert/metadata).\n", "\n", "Since MobileBERT is too big for on-device applications, use [dynamic range quantization](https://www.tensorflow.org/lite/performance/post_training_quantization#dynamic_range_quantization) on the model to compress it by almost 4x with minimal performance degradation." 
] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "ZQRLmkGumr9Y", - "colab_type": "code", - "colab": {} + "id": "ZQRLmkGumr9Y" }, + "outputs": [], "source": [ "config = configs.QuantizationConfig.create_dynamic_range_quantization(optimizations=[tf.lite.Optimize.OPTIMIZE_FOR_LATENCY])\n", "config._experimental_new_quantizer = True" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "Hm_UULdW7A9T", - "colab_type": "code", - "colab": {} + "id": "Hm_UULdW7A9T" }, + "outputs": [], "source": [ "model.export(export_dir='mobilebert/', quantization_config=config)" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rVxaf3x_7OfB", - "colab_type": "text" - }, - "source": [ - "You can also download the model using the left sidebar in Colab.\n", - "\n", - "After executing the 5 steps above, you can further use the TensorFlow Lite model file and label file in on-device applications like in a [text classification](https://github.com/tensorflow/examples/tree/master/lite/examples/text_classification) reference app." ] }, { "cell_type": "markdown", "metadata": { - "id": "l65ctmtW7_FF", - "colab_type": "text" + "id": "rVxaf3x_7OfB" + }, + "source": [ + "You can also download the model using the left sidebar in Colab.\n", + "\n", + "After executing the 5 steps above, you can further use the TensorFlow Lite model file in on-device applications using [BertNLClassifier API](https://www.tensorflow.org/lite/inference_with_metadata/task_library/bert_nl_classifier) in [TensorFlow Lite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l65ctmtW7_FF" }, "source": [ "The following sections walk through the example step by step to show more detail." 
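> Editorial note: before wiring the exported model into an Android app with the BertNLClassifier API, it can help to sanity-check the `.tflite` file from Python. The sketch below is an editorial illustration, not part of the notebook; it assumes the export step above wrote `mobilebert/model.tflite` and only inspects the model's input and output details.

```python
# Hedged sketch: inspect the exported TFLite model with the TFLite Python
# interpreter. Assumes the export above produced 'mobilebert/model.tflite'.
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='mobilebert/model.tflite')
interpreter.allocate_tensors()

for detail in interpreter.get_input_details():
    print('input:', detail['name'], detail['shape'], detail['dtype'])
for detail in interpreter.get_output_details():
    print('output:', detail['name'], detail['shape'], detail['dtype'])
```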
@@ -564,8 +513,7 @@ { "cell_type": "markdown", "metadata": { - "id": "kJ_B8fMDOhMR", - "colab_type": "text" + "id": "kJ_B8fMDOhMR" }, "source": [ "## Choose a `model_spec` that Represents a Model for Text Classifier\n", @@ -583,22 +531,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "vEAWuZQ1PFiX", - "colab_type": "code", - "colab": {} + "id": "vEAWuZQ1PFiX" }, + "outputs": [], "source": [ "spec = model_spec.get('average_word_vec')" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "ygEncJxtl-nQ", - "colab_type": "text" + "id": "ygEncJxtl-nQ" }, "source": [ "## Load Input Data Specific to an On-device ML App\n", @@ -610,26 +555,23 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "7tOfUr2KlgpU", - "colab_type": "code", - "colab": {} + "id": "7tOfUr2KlgpU" }, + "outputs": [], "source": [ "data_dir = tf.keras.utils.get_file(\n", " fname='SST-2.zip',\n", - " origin='https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',\n", + " origin='https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media\u0026token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',\n", " extract=True)\n", "data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "E051HBUM5owi", - "colab_type": "text" + "id": "E051HBUM5owi" }, "source": [ "The SST-2 dataset has `train.tsv` for training and `dev.tsv` for validation. The files have the following format:\n", @@ -646,11 +588,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "I_fOlZsklmlL", - "colab_type": "code", - "colab": {} + "id": "I_fOlZsklmlL" }, + "outputs": [], "source": [ "train_data = TextClassifierDataLoader.from_csv(\n", " filename=os.path.join(os.path.join(data_dir, 'train.tsv')),\n", @@ -666,15 +608,12 @@ " model_spec=spec,\n", " delimiter='\\t',\n", " is_training=False)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "MlHvVvv2hw4H", - "colab_type": "text" + "id": "MlHvVvv2hw4H" }, "source": [ "The Model Maker library also supports the `from_folder()` method to load data. It assumes that the text data of the same class are in the same subdirectory and that the subfolder name is the class name. Each text file contains one movie review sample. The `class_labels` parameter is used to specify which the subfolders." @@ -683,8 +622,7 @@ { "cell_type": "markdown", "metadata": { - "id": "AWuoensX4vDA", - "colab_type": "text" + "id": "AWuoensX4vDA" }, "source": [ "## Customize the TensorFlow Model\n", @@ -694,22 +632,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "TvYSUuJY3QxR", - "colab_type": "code", - "colab": {} + "id": "TvYSUuJY3QxR" }, + "outputs": [], "source": [ "model = text_classifier.create(train_data, model_spec=spec, epochs=10)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "0JKI-pNc8idH", - "colab_type": "text" + "id": "0JKI-pNc8idH" }, "source": [ "Examine the detailed model structure." 
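> Editorial note: the `from_folder()` loader described above is mentioned but never exercised in the notebook. As an illustration only, a hedged sketch of what such a call could look like follows; the directory layout and class names are hypothetical, and the keyword arguments are assumptions based on the description (the `class_labels` parameter names the subfolders to use).

```python
# Hedged sketch: load text data from a directory tree where each subfolder is
# one class, e.g. reviews/pos/*.txt and reviews/neg/*.txt, one sample per file.
# The path, class names and keyword names are illustrative assumptions.
from tflite_model_maker import TextClassifierDataLoader

folder_data = TextClassifierDataLoader.from_folder(
    'reviews/',                    # hypothetical root directory
    model_spec=spec,
    class_labels=['pos', 'neg'],   # subfolder names used as class names
    is_training=True)
```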
@@ -717,22 +652,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "gd7Hs8TF8n3H", - "colab_type": "code", - "colab": {} + "id": "gd7Hs8TF8n3H" }, + "outputs": [], "source": [ "model.summary()" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "LP5FPk_tOxoZ", - "colab_type": "text" + "id": "LP5FPk_tOxoZ" }, "source": [ "## Evaluate the Customized Model\n", @@ -744,57 +676,77 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "A8c2ZQ0J3Riy", - "colab_type": "code", - "colab": {} + "id": "A8c2ZQ0J3Riy" }, + "outputs": [], "source": [ "loss, acc = model.evaluate(test_data)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "aeHoGAceO2xV", - "colab_type": "text" + "id": "aeHoGAceO2xV" }, "source": [ "## Export as a TensorFlow Lite Model\n", "\n", - "Convert the existing model to TensorFlow Lite model format that you can later use in an on-device ML application. Save the text labels in a label file and vocabulary in a vocab file. The default TFLite filename is `model.tflite`, the default label filename is `label.txt` and the default vocab filename is `vocab`." + "Convert the existing model to TensorFlow Lite model format with [metadata](https://www.tensorflow.org/lite/convert/metadata) that you can later use in an on-device ML application. The label file and the vocab file are embedded in metadata. The default TFLite filename is `model.tflite`." ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "Im6wA9lK3TQB", - "colab_type": "code", - "colab": {} + "id": "Im6wA9lK3TQB" }, + "outputs": [], "source": [ "model.export(export_dir='average_word_vec/')" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w12kvDdHJIGH", - "colab_type": "text" - }, - "source": [ - "The TensorFlow Lite model file can be used in the [text classification](https://github.com/tensorflow/examples/tree/master/lite/examples/text_classification) reference app by adding `model.tflite` to the [assets directory](https://github.com/tensorflow/examples/tree/master/lite/examples/text_classification/android/lib_task_api/src/main/assets). Do not forget to also change the filenames in the [code](https://github.com/tensorflow/examples/blob/master/lite/examples/text_classification/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/textclassification/TextClassificationClient.java#L31)." ] }, { "cell_type": "markdown", "metadata": { - "id": "HZKYthlVrTos", - "colab_type": "text" + "id": "w12kvDdHJIGH" + }, + "source": [ + "The TensorFlow Lite model file can be used in the [text classification](https://github.com/tensorflow/examples/tree/master/lite/examples/text_classification) reference app using [NLClassifier API](https://www.tensorflow.org/lite/inference_with_metadata/task_library/nl_classifier) in [TensorFlow Lite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AVy0ormoMZwL" + }, + "source": [ + "The allowed export formats can be one or a list of the following:\n", + "\n", + "* `ExportFormat.TFLITE`\n", + "* `ExportFormat.LABEL`\n", + "* `ExportFormat.VOCAB`\n", + "* `ExportFormat.SAVED_MODEL`\n", + "\n", + "By default, it just exports TensorFlow Lite model with metadata. You can also selectively export different files. 
For instance, exporting only the label file and vocab file as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "nbK7nzK_Mfx4" + }, + "outputs": [], + "source": [ + "model.export(export_dir='average_word_vec/', export_format=[ExportFormat.LABEL, ExportFormat.VOCAB])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HZKYthlVrTos" }, "source": [ "You can evalute the tflite model with `evaluate_tflite` method." @@ -802,22 +754,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "ochbq95ZrVFX", - "colab_type": "code", - "colab": {} + "id": "ochbq95ZrVFX" }, + "outputs": [], "source": [ "model.evaluate_tflite('average_word_vec/model.tflite', test_data)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "EoWiA_zX8rxE", - "colab_type": "text" + "id": "EoWiA_zX8rxE" }, "source": [ "## Advanced Usage\n", @@ -833,8 +782,7 @@ { "cell_type": "markdown", "metadata": { - "id": "mwtiksguDfhl", - "colab_type": "text" + "id": "mwtiksguDfhl" }, "source": [ "### Adjust the model\n", @@ -845,8 +793,7 @@ { "cell_type": "markdown", "metadata": { - "id": "cAOd5_bzH9AQ", - "colab_type": "text" + "id": "cAOd5_bzH9AQ" }, "source": [ "For example, you can train the model with a larger value of `wordvec_dim`. Note that you must construct a new `model_spec` if you modify the model." @@ -854,22 +801,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "e9WBN0UTQoMN", - "colab_type": "code", - "colab": {} + "id": "e9WBN0UTQoMN" }, + "outputs": [], "source": [ "new_model_spec = model_spec.AverageWordVecModelSpec(wordvec_dim=32)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "6LSTdghTP0Cv", - "colab_type": "text" + "id": "6LSTdghTP0Cv" }, "source": [ "Get the preprocessed data." @@ -877,11 +821,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "DVZurFBORG3J", - "colab_type": "code", - "colab": {} + "id": "DVZurFBORG3J" }, + "outputs": [], "source": [ "new_train_data = TextClassifierDataLoader.from_csv(\n", " filename=os.path.join(os.path.join(data_dir, 'train.tsv')),\n", @@ -890,15 +834,12 @@ " model_spec=new_model_spec,\n", " delimiter='\\t',\n", " is_training=True)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "tD7QVVHeRZoM", - "colab_type": "text" + "id": "tD7QVVHeRZoM" }, "source": [ "Train the new model." 
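> Editorial note: one step the notebook glosses over is that when the `model_spec` changes (for example the larger `wordvec_dim` above), the dev split must also be reloaded with the new spec so that preprocessing matches before evaluation. A hedged sketch of that follow-up, mirroring the `from_csv`, `create`, and `evaluate` calls already used in this tutorial (column names follow the SST-2 layout assumed throughout):

```python
# Hedged sketch: reload the dev split with the adjusted spec, retrain, and
# evaluate. Mirrors calls used earlier in the notebook; not part of it.
new_test_data = TextClassifierDataLoader.from_csv(
    filename=os.path.join(data_dir, 'dev.tsv'),
    text_column='sentence',
    label_column='label',
    model_spec=new_model_spec,
    delimiter='\t',
    is_training=False)

model = text_classifier.create(new_train_data, model_spec=new_model_spec)
loss, acc = model.evaluate(new_test_data)
```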
@@ -906,22 +847,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "PzpV246_JGEu", - "colab_type": "code", - "colab": {} + "id": "PzpV246_JGEu" }, + "outputs": [], "source": [ "model = text_classifier.create(new_train_data, model_spec=new_model_spec)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "E8VxPiOLy4Gv", - "colab_type": "text" + "id": "E8VxPiOLy4Gv" }, "source": [ "You can also adjust the MobileBERT model.\n", @@ -944,23 +882,20 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "4tr9BLcjy4Sh", - "colab_type": "code", - "colab": {} + "id": "4tr9BLcjy4Sh" }, + "outputs": [], "source": [ "new_model_spec = model_spec.get('mobilebert_classifier')\n", "new_model_spec.seq_len = 256" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "LvQuy7RSDir3", - "colab_type": "text" + "id": "LvQuy7RSDir3" }, "source": [ "### Tune the training hyperparameters\n", @@ -974,22 +909,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "rnWFaYZBG6NW", - "colab_type": "code", - "colab": {} + "id": "rnWFaYZBG6NW" }, + "outputs": [], "source": [ "model = text_classifier.create(train_data, model_spec=spec, epochs=20)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "nUaKQZBQHBQR", - "colab_type": "text" + "id": "nUaKQZBQHBQR" }, "source": [ "Evaluate the newly retrained model with 20 training epochs." @@ -997,22 +929,19 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "BMPi1xflHDSY", - "colab_type": "code", - "colab": {} + "id": "BMPi1xflHDSY" }, + "outputs": [], "source": [ "loss, accuracy = model.evaluate(test_data)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "Eq6B9lKMfhS6", - "colab_type": "text" + "id": "Eq6B9lKMfhS6" }, "source": [ "### Change the Model Architecture\n", @@ -1024,26 +953,38 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "QfFCWrwyggrT", - "colab_type": "code", - "colab": {} + "id": "QfFCWrwyggrT" }, + "outputs": [], "source": [ "spec = model_spec.get('bert_classifier')" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { - "id": "L2d7yycrgu6L", - "colab_type": "text" + "id": "L2d7yycrgu6L" }, "source": [ "The remaining steps are the same." ] } - ] -} \ No newline at end of file + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "model_maker_text_classification.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h index 77253ca836d..6a77c5a5f11 100644 --- a/tensorflow/lite/interpreter.h +++ b/tensorflow/lite/interpreter.h @@ -368,7 +368,7 @@ class Interpreter { /// WARNING: NNAPI cannot be disabled after the graph has been prepared /// (via `AllocateTensors`) with NNAPI enabled. /// - /// NOTE: This API is deprecated, prefer using the NNAPI delegate directly. + /// WARNING: This API is deprecated, prefer using the NNAPI delegate directly. /// This method will be removed in a future release. 
void UseNNAPI(bool enable); @@ -380,8 +380,12 @@ class Interpreter { TfLiteStatus SetNumThreads(int num_threads); /// Allow float16 precision for FP32 calculation when possible. - /// default: not allow. - /// WARNING: This is an experimental API and subject to change. + /// Default: not allow. + /// + /// WARNING: This API is deprecated: prefer controlling this via delegate + /// options, e.g. `tflite::StatefulNnApiDelegate::Options::allow_fp16' or + /// `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`. + /// This method will be removed in a future release. void SetAllowFp16PrecisionForFp32(bool allow); /// Get the half precision flag. @@ -576,6 +580,11 @@ class Interpreter { const Subgraph& primary_subgraph() const { return *subgraphs_.front(); // Safe as subgraphs_ always has 1 entry. } + + /// WARNING: Experimental interface, subject to change + // Get the error reporter associated with this interpreter. + ErrorReporter* error_reporter() const { return error_reporter_; } + #endif // DOXYGEN_SKIP private: @@ -602,9 +611,6 @@ class Interpreter { // Returns true if cancellation function returns true. bool IsCancelled(); - // Get the error reporter associated with this interpreter. - ErrorReporter* error_reporter() { return error_reporter_; } - // A pure C data structure used to communicate with the pure C plugin // interface. To avoid copying tensor metadata, this is also the definitive // structure to store tensors. diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc index 6c2b2e24558..f5c8d97b962 100644 --- a/tensorflow/lite/interpreter_builder.cc +++ b/tensorflow/lite/interpreter_builder.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/profiling/platform_profiler.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/shared_library.h" #include "tensorflow/lite/util.h" #include "tensorflow/lite/version.h" @@ -165,7 +166,7 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { } int num_custom_ops = 0; for (const OperatorCode* opcode : *opcodes) { - if (opcode->builtin_code() == BuiltinOperator_CUSTOM) { + if (GetBuiltinCode(opcode) == BuiltinOperator_CUSTOM) { num_custom_ops++; } } @@ -175,7 +176,7 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_, ®istration); if (status != kTfLiteOk) { - if (opcode->builtin_code() != BuiltinOperator_CUSTOM) { + if (GetBuiltinCode(opcode) != BuiltinOperator_CUSTOM) { return status; } // If it's an unresolved custom op, allow it for now. It might be resolved diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index e1d69c6dc3a..e2044212340 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -485,18 +485,6 @@ public final class Interpreter implements AutoCloseable { return wrapper.getLastNativeInferenceDurationNanoseconds(); } - /** - * Turns on/off Android NNAPI for hardware acceleration when it is available. - * - * @deprecated Prefer using {@link Options#setUseNNAPI(boolean)} directly for enabling NN API. - * This method will be removed in a future release. 
- */ - @Deprecated - public void setUseNNAPI(boolean useNNAPI) { - checkNotClosed(); - wrapper.setUseNNAPI(useNNAPI); - } - /** * Sets the number of threads to be used for ops that support multi-threading. * diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index 6b1686f139f..3d439e4470c 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -218,10 +218,6 @@ final class NativeInterpreterWrapper implements AutoCloseable { private static native long allocateTensors(long interpreterHandle, long errorHandle); - void setUseNNAPI(boolean useNNAPI) { - useNNAPI(interpreterHandle, useNNAPI); - } - void setNumThreads(int numThreads) { numThreads(interpreterHandle, numThreads); } @@ -456,8 +452,6 @@ final class NativeInterpreterWrapper implements AutoCloseable { private static native String[] getOutputNames(long interpreterHandle); - private static native void useNNAPI(long interpreterHandle, boolean state); - private static native void numThreads(long interpreterHandle, int numThreads); private static native void allowFp16PrecisionForFp32(long interpreterHandle, boolean allow); diff --git a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index 006c08f89c1..959acfb205e 100644 --- a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -299,17 +299,6 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env, return names; } -JNIEXPORT void JNICALL -Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env, - jclass clazz, - jlong handle, - jboolean state) { - tflite_api_dispatcher::Interpreter* interpreter = - convertLongToInterpreter(env, handle); - if (interpreter == nullptr) return; - interpreter->UseNNAPI(static_cast(state)); -} - JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32( JNIEnv* env, jclass clazz, jlong handle, jboolean allow) { diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index 4f205f1cdaf..f7e10ae796c 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -394,6 +394,7 @@ public final class InterpreterTest { } @Test + // setAllowFp16PrecisionForFp32 is deprecated, suppress the warning to allow testing. 
@SuppressWarnings("deprecation") public void testTurnOnNNAPI() throws Exception { Interpreter interpreter = diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java index d92a7119aab..de320fd68d6 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java @@ -76,8 +76,8 @@ public final class GpuDelegateTest { "tensorflow/lite/java/src/testdata/grace_hopper_224.jpg"); Interpreter.Options options = new Interpreter.Options(); - try (GpuDelegate delegate = - new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(true)); + // Default behavior allows quantized models. + try (GpuDelegate delegate = new GpuDelegate(); Interpreter interpreter = new Interpreter(MOBILENET_QUANTIZED_MODEL_BUFFER, options.addDelegate(delegate))) { byte[][] output = new byte[1][1001]; @@ -98,12 +98,13 @@ public final class GpuDelegateTest { "tensorflow/lite/java/src/testdata/grace_hopper_224.jpg"); Interpreter.Options options = new Interpreter.Options(); - try (GpuDelegate delegate = new GpuDelegate(); + try (GpuDelegate delegate = + new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(false)); Interpreter interpreter = new Interpreter(MOBILENET_QUANTIZED_MODEL_BUFFER, options.addDelegate(delegate))) { byte[][] output = new byte[1][1001]; interpreter.run(img, output); - // Original execution plan remains since default behavior doesn't allow quantized models. + // Original execution plan remains since we disabled quantized models. assertThat(InterpreterTestHelper.executionPlanLength(interpreter)).isEqualTo(31); assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(new int[] {1, 224, 224, 3}); assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(new int[] {1, 1001}); diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 040815b0a82..f7aa91dc24d 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -98,6 +98,7 @@ config_setting( ) ###### End of config_setting's to match aarch64 ###### + # Suppress warnings that are introduced by Eigen Tensor. EXTRA_EIGEN_COPTS = select({ "//tensorflow:ios": [ @@ -190,6 +191,7 @@ cc_library( "//tensorflow/lite/kernels/internal:tensor_utils", "//tensorflow/lite/nnapi:nnapi_implementation", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "//tensorflow/lite/testing:util", "//tensorflow/lite/tools:command_line_flags", "//tensorflow/lite/tools:logging", @@ -261,6 +263,7 @@ cc_library( hdrs = [ "eigen_support.h", ], + compatible_with = get_compatible_with_portable(), copts = tflite_copts() + EXTRA_EIGEN_COPTS, visibility = ["//visibility:private"], deps = [ @@ -298,6 +301,7 @@ cc_library( cc_library( name = "tflite_with_ruy_default", + compatible_with = get_compatible_with_portable(), visibility = ["//visibility:private"], deps = select({ ":chromiumos_arm64": [":tflite_with_ruy_enabled"], @@ -331,7 +335,13 @@ cc_library( "cpu_backend_context.h", ], compatible_with = get_compatible_with_portable(), - copts = tflite_copts(), + # TF Lite builds in other build systems should "opt in" to cpufinfo. 
+ copts = tflite_copts() + select({ + "//tensorflow:linux_ppc64le": [], + "//tensorflow:linux_s390x": [], + "//tensorflow:fuchsia": [], + "//conditions:default": ["-DTFLITE_HAVE_CPUINFO"], + }), deps = [ ":tflite_with_ruy", ":op_macros", @@ -342,7 +352,14 @@ cc_library( "@gemmlowp", "//tensorflow/lite/c:common", "//tensorflow/lite:external_cpu_backend_context", - ], + "//tensorflow/lite/kernels/internal:compatibility", + ] + select({ + # This select must match the similar select in `copts` + "//tensorflow:linux_ppc64le": [], + "//tensorflow:linux_s390x": [], + "//tensorflow:fuchsia": [], + "//conditions:default": ["@cpuinfo//:cpuinfo_with_unstripped_include_path"], + }), ) cc_library( @@ -384,6 +401,7 @@ cc_library( "cpu_backend_gemm_eigen.cc", "cpu_backend_gemm_eigen.h", "cpu_backend_gemm_gemmlowp.h", + "cpu_backend_gemm_x86.h", ], hdrs = [ "cpu_backend_gemm.h", @@ -701,6 +719,7 @@ cc_library( "complex_support.cc", "cumsum.cc", "multinomial.cc", + "random_standard_normal.cc", "rfft2d.cc", ], hdrs = ["custom_ops_register.h"], @@ -1370,6 +1389,20 @@ cc_test( ], ) +cc_test( + name = "random_standard_normal_test", + size = "small", + srcs = ["random_standard_normal_test.cc"], + deps = [ + ":custom_ops", + ":test_main", + ":test_util", + "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "reshape_test_common", testonly = 1, diff --git a/tensorflow/lite/kernels/activations.cc b/tensorflow/lite/kernels/activations.cc index ee9fc2c3418..0efac36be74 100644 --- a/tensorflow/lite/kernels/activations.cc +++ b/tensorflow/lite/kernels/activations.cc @@ -639,24 +639,15 @@ TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { static const double kBeta = 1.0; if (input->type == kTfLiteUInt8) { TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255); - data->params.table = data->f_table; - optimized_ops::PopulateSoftmaxLookupTable(&data->params, - input->params.scale, kBeta); - data->params.zero_point = output->params.zero_point; - data->params.scale = output->params.scale; } if (input->type == kTfLiteInt8) { TF_LITE_ENSURE_EQ(context, output->params.zero_point, 127); - static const int kScaledDiffIntegerBits = 5; - tflite::PreprocessLogSoftmaxScalingExp( - kBeta, input->params.scale, kScaledDiffIntegerBits, - &data->input_multiplier, &data->input_left_shift, - &data->reverse_scaling_divisor, &data->reverse_scaling_right_shift); - data->reverse_scaling_right_shift *= -1; - data->diff_min = - -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits, - data->input_left_shift); } + data->params.table = data->f_table; + optimized_ops::PopulateSoftmaxLookupTable(&data->params, + input->params.scale, kBeta); + data->params.zero_point = output->params.zero_point; + data->params.scale = output->params.scale; } return context->ResizeTensor(context, output, @@ -1188,18 +1179,26 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } case kTfLiteInt8: { - const auto input_shape = GetTensorShape(input); - const auto output_shape = GetTensorShape(output); - const int trailing_dim = input_shape.DimensionsCount() - 1; - const int outer_size = - MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); - const int depth = - MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); - reference_integer_ops::LogSoftmax( - data->input_multiplier, data->input_left_shift, - data->reverse_scaling_divisor, data->reverse_scaling_right_shift, - data->diff_min, 
outer_size, depth, GetTensorData(input), - GetTensorData(output)); + if (kernel_type == kGenericOptimized) { + SoftmaxParams op_params = data->params; + optimized_ops::LogSoftmax( + op_params, input->params.scale, GetTensorShape(input), + GetTensorData(input), GetTensorShape(output), + GetTensorData(output)); + } else { + const auto input_shape = GetTensorShape(input); + const auto output_shape = GetTensorShape(output); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); + reference_integer_ops::LogSoftmax( + data->input_multiplier, data->input_left_shift, + data->reverse_scaling_divisor, data->reverse_scaling_right_shift, + data->diff_min, outer_size, depth, GetTensorData(input), + GetTensorData(output)); + } return kTfLiteOk; } default: diff --git a/tensorflow/lite/kernels/batch_matmul.cc b/tensorflow/lite/kernels/batch_matmul.cc index 35cf57128e7..5f6afa3d14f 100644 --- a/tensorflow/lite/kernels/batch_matmul.cc +++ b/tensorflow/lite/kernels/batch_matmul.cc @@ -314,7 +314,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Note that quantized inference requires that all tensors have their // parameters set. This is usually done during quantized training. - if (lhs_data->type == kTfLiteInt8) { + if (lhs_data->type == kTfLiteInt8 || lhs_data->type == kTfLiteInt16) { double real_multiplier = 0.0; TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, lhs_data, rhs_data, output, &real_multiplier)); @@ -322,16 +322,34 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { QuantizeMultiplier(real_multiplier, &op_data->output_multiplier, &exponent); op_data->output_shift = exponent; // BatchMatMul has no fused activation functions. Therefore, set - // output activation min and max to min and max of int8_t type, - // respecitvely. - op_data->output_activation_min = std::numeric_limits::min(); - op_data->output_activation_max = std::numeric_limits::max(); + // output activation min and max to min and max of int8_t or int16_t + // type. + if (lhs_data->type == kTfLiteInt8) { + op_data->output_activation_min = std::numeric_limits::min(); + op_data->output_activation_max = std::numeric_limits::max(); + } else { + op_data->output_activation_min = std::numeric_limits::min(); + op_data->output_activation_max = std::numeric_limits::max(); + } + } + + if (lhs_data->type == kTfLiteInt16) { + TF_LITE_ENSURE_EQ(context, lhs_data->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, rhs_data->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); } TF_LITE_ENSURE(context, lhs_data->type == kTfLiteFloat32 || - lhs_data->type == kTfLiteInt8); + lhs_data->type == kTfLiteInt8 || + lhs_data->type == kTfLiteInt16); TF_LITE_ENSURE(context, rhs_data->type == kTfLiteFloat32 || - rhs_data->type == kTfLiteInt8); + rhs_data->type == kTfLiteInt8 || + rhs_data->type == kTfLiteInt16); + // Either we have a hybrid quantization with a float32 and an int8 input, + // otherwise both inputs should be of the same type. + TF_LITE_ENSURE(context, (lhs_data->type == kTfLiteFloat32 && + rhs_data->type == kTfLiteInt8) || + lhs_data->type == rhs_data->type); // Support dimensions between 2 and 4, inclusive. 
TF_LITE_ENSURE(context, NumDimensions(lhs_data) >= 2); TF_LITE_ENSURE(context, NumDimensions(lhs_data) <= 4); @@ -402,9 +420,14 @@ TfLiteStatus TransposeRowsColumns(TfLiteContext* context, tensor_in, GetTensorData(tensor_in), tensor_out, GetTensorData(tensor_out)); return kTfLiteOk; + } else if (tensor_in->type == kTfLiteInt16) { + TransposeRowsColumnsImpl( + tensor_in, GetTensorData(tensor_in), tensor_out, + GetTensorData(tensor_out)); + return kTfLiteOk; } else { - TF_LITE_KERNEL_LOG(context, - "Can only transpose tensors with float and int8 type."); + TF_LITE_KERNEL_LOG( + context, "Can only transpose tensors with float, int8 or int16 type."); return kTfLiteError; } } @@ -501,10 +524,10 @@ TfLiteStatus EvalInt8(TfLiteContext* context, const OpData* data, op_params.rhs_cacheable = IsConstantTensor(rhs); if (kernel_type == kReference) { - reference_ops::BatchMatMul(op_params, rhs_shape, GetTensorData(rhs), - lhs_shape, GetTensorData(lhs), - GetTensorShape(output), - GetTensorData(output)); + reference_ops::BatchMatMul( + op_params, rhs_shape, GetTensorData(rhs), lhs_shape, + GetTensorData(lhs), GetTensorShape(output), + GetTensorData(output)); } else { optimized_ops::BatchMatMul(op_params, rhs_shape, GetTensorData(rhs), lhs_shape, GetTensorData(lhs), @@ -515,13 +538,40 @@ TfLiteStatus EvalInt8(TfLiteContext* context, const OpData* data, return kTfLiteOk; } +template +TfLiteStatus EvalInt16(TfLiteContext* context, const OpData* data, + const RuntimeShape& lhs_shape, const TfLiteTensor* lhs, + const RuntimeShape& rhs_shape, const TfLiteTensor* rhs, + const RuntimeShape& output_shape, TfLiteTensor* output) { + // Reuse params struct from FullyConnected Op. + FullyConnectedParams op_params; + int32_t input_offset = -lhs->params.zero_point; + int32_t filter_offset = -rhs->params.zero_point; + int32_t output_offset = output->params.zero_point; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = data->output_multiplier; + op_params.output_shift = data->output_shift; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + + // optimized_ops not yet implemented for int16_t, use reference_ops in all + // cases.
+ reference_ops::BatchMatMul( + op_params, rhs_shape, GetTensorData(rhs), lhs_shape, + GetTensorData(lhs), GetTensorShape(output), + GetTensorData(output)); + return kTfLiteOk; +} + template TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, OpData* data, const RuntimeShape& lhs_shape, const TfLiteTensor* lhs, const RuntimeShape& rhs_shape, const TfLiteTensor* rhs, TfLiteTensor* output) { - if (lhs->type == kTfLiteFloat32) { + if (lhs->type == kTfLiteFloat32 && rhs->type == kTfLiteInt8) { TfLiteTensor* input_quantized; TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2, &input_quantized)); @@ -540,12 +590,16 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, return EvalHybrid( context, node, data, lhs_shape, lhs, rhs_shape, rhs, input_quantized, scaling_factors, accum_scratch, row_sums, input_offsets, output); - } else if (lhs->type == kTfLiteInt8) { + } else if (lhs->type == kTfLiteInt8 && rhs->type == kTfLiteInt8) { return EvalInt8(context, data, lhs_shape, lhs, rhs_shape, rhs, GetTensorShape(output), output); + } else if (lhs->type == kTfLiteInt16 && rhs->type == kTfLiteInt16) { + return EvalInt16(context, data, lhs_shape, lhs, rhs_shape, rhs, + GetTensorShape(output), output); } else { TF_LITE_KERNEL_LOG( - context, "Currently only hybrid and int8 quantization is supported.\n"); + context, + "Currently only hybrid, int8 and int16 quantization are supported.\n"); return kTfLiteError; } return kTfLiteOk; @@ -558,7 +612,7 @@ TfLiteTensor* GetTempRhs(TfLiteContext* context, TfLiteNode* node, return nullptr; } - if (rhs->type == kTfLiteInt8) { + if (rhs->type == kTfLiteInt8 || rhs->type == kTfLiteInt16) { // Get the quantization params from the RHS tensor. transposed_rhs->params.scale = rhs->params.scale; transposed_rhs->params.zero_point = rhs->params.zero_point; @@ -573,7 +627,7 @@ TfLiteTensor* GetTempLhs(TfLiteContext* context, TfLiteNode* node, return nullptr; } - if (lhs->type == kTfLiteInt8) { + if (lhs->type == kTfLiteInt8 || lhs->type == kTfLiteInt16) { // Get the quantization params from the LHS tensor. 
transposed_lhs->params.scale = lhs->params.scale; transposed_lhs->params.zero_point = lhs->params.zero_point; @@ -646,6 +700,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } break; case kTfLiteInt8: + case kTfLiteInt16: EvalQuantized(context, node, op_data, lhs_shape, lhs_tensor, rhs_shape, rhs_tensor, output); break; diff --git a/tensorflow/lite/kernels/batch_matmul_test.cc b/tensorflow/lite/kernels/batch_matmul_test.cc index 98df8ebe3db..7abef73d5a2 100644 --- a/tensorflow/lite/kernels/batch_matmul_test.cc +++ b/tensorflow/lite/kernels/batch_matmul_test.cc @@ -483,7 +483,12 @@ class QuantizedBatchMatMulOpModel : public SingleOpModel { input_size_ = total_input_size / batches_; lhs_id_ = AddInput(lhs); - rhs_id_ = AddInput({lhs.type, {input_size_, units_}, lhs.min, lhs.max}); + rhs_id_ = AddInput({lhs.type, + {input_size_, units_}, + 0, + 0, + GetScale(lhs_id_), + GetZeroPoint(lhs_id_)}); output_id_ = AddOutput(output); @@ -553,6 +558,35 @@ TEST_P(QuantizedBatchMatMulOpTest, SimpleTestQuantizedInt8) { EXPECT_THAT(m.GetOutput(), ElementsAre(22, 22, 22, 56, 56, 56)); } +TEST_P(QuantizedBatchMatMulOpTest, SimpleTestQuantizedInt16) { + const float inputs_scale = 10.0 / std::numeric_limits::max(); + const float output_scale = 1.0; + const int32_t zero_point = 0; + + QuantizedBatchMatMulOpModel m( + /*units=*/3, /*batches*/ 2, + /*lhs=*/ + {TensorType_INT16, {2, 10}, 0, 0, inputs_scale, zero_point}, + /*output=*/ + {TensorType_INT16, {}, 0, 0, output_scale, zero_point}); + + m.SetWeights({ + 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, + 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, + }); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({23, 23, 23, 57, 57, 57}))); + EXPECT_THAT(m.GetOutput(), ElementsAre(23, 23, 23, 57, 57, 57)); +} + INSTANTIATE_TEST_SUITE_P( QuantizedBatchMatMulOpTest, QuantizedBatchMatMulOpTest, ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap))); diff --git a/tensorflow/lite/kernels/cpu_backend_context.cc b/tensorflow/lite/kernels/cpu_backend_context.cc index a99d08769ea..6eacb3d2216 100644 --- a/tensorflow/lite/kernels/cpu_backend_context.cc +++ b/tensorflow/lite/kernels/cpu_backend_context.cc @@ -17,10 +17,15 @@ limitations under the License. 
#include +#ifdef TFLITE_HAVE_CPUINFO +#include "include/cpuinfo.h" +#endif + #include "public/gemmlowp.h" #include "ruy/context.h" // from @ruy #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/external_cpu_backend_context.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/op_macros.h" namespace { @@ -30,6 +35,62 @@ const int kDefaultNumThreadpoolThreads = 1; namespace tflite { +#ifdef TFLITE_HAVE_CPUINFO +CpuBackendContext::CpuInfo::~CpuInfo() { + if (init_status_ == InitStatus::kInitialized) { + cpuinfo_deinitialize(); + } +} + +bool CpuBackendContext::CpuInfo::EnsureInitialized() { + if (init_status_ == InitStatus::kNotYetAttempted) { + init_status_ = Initialize(); + } + return init_status_ == InitStatus::kInitialized; +} + +CpuBackendContext::CpuInfo::InitStatus +CpuBackendContext::CpuInfo::Initialize() { + TFLITE_DCHECK_EQ(init_status_, InitStatus::kNotYetAttempted); + if (!cpuinfo_initialize()) { + return InitStatus::kFailed; + } + return InitStatus::kInitialized; +} + +bool CpuBackendContext::CpuInfo::Avx2Fma() { + return EnsureInitialized() && cpuinfo_has_x86_avx2() && + cpuinfo_has_x86_fma3(); +} + +bool CpuBackendContext::CpuInfo::Avx() { + return EnsureInitialized() && cpuinfo_has_x86_avx(); +} + +bool CpuBackendContext::CpuInfo::Avx512() { + return EnsureInitialized() && cpuinfo_has_x86_avx512f() && + cpuinfo_has_x86_avx512dq() && cpuinfo_has_x86_avx512cd() && + cpuinfo_has_x86_avx512bw() && cpuinfo_has_x86_avx512vl(); +} +#else + +CpuBackendContext::CpuInfo::~CpuInfo() {} + +bool CpuBackendContext::CpuInfo::EnsureInitialized() { + if (init_status_ == InitStatus::kNotYetAttempted) { + init_status_ = InitStatus::kInitialized; + } + TFLITE_DCHECK_EQ(init_status_, InitStatus::kInitialized); + return true; +} + +bool CpuBackendContext::CpuInfo::Avx2Fma() { return false; } + +bool CpuBackendContext::CpuInfo::Avx() { return false; } + +bool CpuBackendContext::CpuInfo::Avx512() { return false; } +#endif // TFLITE_HAVE_CPUINFO + CpuBackendContext* CpuBackendContext::GetFromContext(TfLiteContext* context) { auto* external_context = static_cast( context->GetExternalContext(context, kTfLiteCpuBackendContext)); @@ -79,4 +140,8 @@ void CpuBackendContext::SetMaxNumThreads(int max_num_threads) { void CpuBackendContext::SetUseCaching(bool flag) { use_caching_ = flag; } +bool CpuBackendContext::HasAvxOrAbove() { + return cpuinfo_.Avx() || cpuinfo_.Avx2Fma() || cpuinfo_.Avx512(); +} + } // namespace tflite diff --git a/tensorflow/lite/kernels/cpu_backend_context.h b/tensorflow/lite/kernels/cpu_backend_context.h index 124b9b849a2..e0207176eb4 100644 --- a/tensorflow/lite/kernels/cpu_backend_context.h +++ b/tensorflow/lite/kernels/cpu_backend_context.h @@ -50,7 +50,35 @@ class CpuBackendContext final : public TfLiteInternalBackendContext { void ClearCaches() override { ruy_context_->ClearPrepackedCache(); } + bool HasAvxOrAbove(); + private: + // Copy the wrapper class for cpuinfo from Ruy. 
+ class CpuInfo final { + public: + CpuInfo() {} + ~CpuInfo(); + + // X86 features + bool Avx(); + bool Avx2Fma(); + bool Avx512(); + + private: + enum class InitStatus { + kNotYetAttempted, + kInitialized, + kFailed, + }; + + InitStatus init_status_ = InitStatus::kNotYetAttempted; + + bool EnsureInitialized(); + InitStatus Initialize(); + CpuInfo(const CpuInfo&) = delete; + CpuInfo& operator=(const CpuInfo&) = delete; + }; + // To enable a smooth transition from the current direct usage // of the underlying gemmlowp context to going through abstractions // (see :cpu_backend_gemm), for now a CpuBackendContext always @@ -59,6 +87,7 @@ class CpuBackendContext final : public TfLiteInternalBackendContext { // elide what can be elided based on TFLITE_WITH_RUY. const std::unique_ptr ruy_context_; const std::unique_ptr gemmlowp_context_; + CpuInfo cpuinfo_; // The maximum of threads used for parallelizing TfLite ops. However, // cpu_backend_threadpool::Execute creates as many threads as it's diff --git a/tensorflow/lite/kernels/cpu_backend_gemm.h b/tensorflow/lite/kernels/cpu_backend_gemm.h index 14ff571e7da..6950e182dfa 100644 --- a/tensorflow/lite/kernels/cpu_backend_gemm.h +++ b/tensorflow/lite/kernels/cpu_backend_gemm.h @@ -27,22 +27,54 @@ limitations under the License. #ifndef TFLITE_WITH_RUY #include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h" #include "tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h" +#include "tensorflow/lite/kernels/cpu_backend_gemm_x86.h" #endif namespace tflite { namespace cpu_backend_gemm { +// The main entry point for CpuBackendGemm::Gemm. +// +// If TFLITE_WITH_RUY is set, CpuBackendGemm::Gemm will always go to Ruy aka +// GemmImplUsingRuy. Other cases are as follows: +// +// |Quantized (uint8)|Quantized (int8)| Float | +// TFLITE_WITH_RUY | Ruy | Ruy | Ruy | +// !TFLITE_WITH_RUY | gemmlowp | Ruy/gemmlowp* | eigen | +// * - Ruy if NEON is not available. + +// On x86 platforms: +// (default) | gemmlowp | Ruy | eigen | +// TFLITE_X86_RUY_\ | Ruy | Ruy | Ruy | +// ENABLED && (AVX +// or above available) + + +#if (defined(__i386) || defined(_M_IX86) || defined(__x86_64__) || \ + defined(_M_X64)) +#define TFLITE_X86_PLATFORM +#endif + +// TODO(b/168923364) Set TFLITE_X86_RUY_ENABLED default 'on' when ready. +#if defined(TFLITE_X86_PLATFORM) && defined(TFLITE_X86_RUY_ENABLED) +/* GEMM dispatch implementation for x86. + */ +template +struct GemmImpl : detail::GemmImplX86 {}; +#else /* Generic implementation using ruy. * Non-ruy implementation will be partial specializations of this template. */ - template struct GemmImpl : detail::GemmImplUsingRuy {}; +#endif -#ifndef TFLITE_WITH_RUY +#if !defined(TFLITE_WITH_RUY) && !defined(TFLITE_X86_RUY_ENABLED) /* Specializations using gemmlowp */ @@ -56,7 +88,7 @@ struct GemmImpl struct GemmImpl @@ -82,7 +114,7 @@ template <> struct GemmImpl : detail::GemmImplUsingEigen {}; -#endif // not TFLITE_WITH_RUY +#endif // not TFLITE_WITH_RUY && not TFLITE_X86_RUY_ENABLED /* Public entry point */ diff --git a/tensorflow/lite/kernels/cpu_backend_gemm_x86.h b/tensorflow/lite/kernels/cpu_backend_gemm_x86.h new file mode 100644 index 00000000000..20af9536d47 --- /dev/null +++ b/tensorflow/lite/kernels/cpu_backend_gemm_x86.h @@ -0,0 +1,115 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_X86_H_ +#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_X86_H_ + +// If TFLITE_WITH_RUY is set, Ruy is the only GEMM option. In this header +// we select either Ruy or an alternative based on the SIMD extensions +// available on the given x86 platform. +#ifndef TFLITE_WITH_RUY + +#include "tensorflow/lite/kernels/cpu_backend_context.h" +#include "tensorflow/lite/kernels/cpu_backend_gemm_eigen.h" +#include "tensorflow/lite/kernels/cpu_backend_gemm_gemmlowp.h" +#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" +#include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" + +namespace tflite { +namespace cpu_backend_gemm { +namespace detail { + +template +struct GemmImplX86 { + static void Run( + const MatrixParams& lhs_params, const LhsScalar* lhs_data, + const MatrixParams& rhs_params, const RhsScalar* rhs_data, + const MatrixParams& dst_params, DstScalar* dst_data, + const GemmParams& params, + CpuBackendContext* context) { + // Run-time dispatch to Ruy for platforms with AVX or above. + if (context->HasAvxOrAbove()) { + detail::GemmImplUsingRuy::Run(lhs_params, lhs_data, + rhs_params, rhs_data, + dst_params, dst_data, + params, context); + } else { + // Dispatch to gemmlowp for SSE. + detail::GemmImplUsingGemmlowp< + LhsScalar, RhsScalar, AccumScalar, DstScalar, + quantization_flavor>::Run(lhs_params, lhs_data, rhs_params, rhs_data, + dst_params, dst_data, params, context); + } + } +}; + +// For float, again prefer Ruy in all cases, but defer to eigen if no flavor of +// AVX is present. +template <> +struct GemmImplX86 { + static void Run(const MatrixParams& lhs_params, const float* lhs_data, + const MatrixParams& rhs_params, const float* rhs_data, + const MatrixParams& dst_params, float* dst_data, + const GemmParams& params, + CpuBackendContext* context) { + // Run-time dispatch to Ruy for platforms with AVX or above. + if (context->HasAvxOrAbove()) { + detail::GemmImplUsingRuy< + float, float, float, float, + QuantizationFlavor::kFloatingPoint>::Run(lhs_params, lhs_data, + rhs_params, rhs_data, + dst_params, dst_data, params, + context); + } else { + // Dispatch to Eigen for SSE. + GemmImplUsingEigen::Run(lhs_params, lhs_data, rhs_params, rhs_data, + dst_params, dst_data, params, context); + } + } +}; + +// gemmlowp requires NEON for certain quantization cases.
See note in +// cpu_backend_gemm.h +#if !defined(GEMMLOWP_NEON) +template +struct GemmImplX86 + : detail::GemmImplUsingRuy {}; + +template +struct GemmImplX86 + : detail::GemmImplUsingRuy {}; + +template +struct GemmImplX86 + : detail::GemmImplUsingRuy {}; +#endif // not GEMMLOWP_NEON +} // namespace detail +} // namespace cpu_backend_gemm +} // namespace tflite + +#endif // not TFLITE_WITH_RUY + +#endif // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_X86_H_ diff --git a/tensorflow/lite/kernels/custom_ops_register.h b/tensorflow/lite/kernels/custom_ops_register.h index 94d4fec5347..8aadd379a43 100644 --- a/tensorflow/lite/kernels/custom_ops_register.h +++ b/tensorflow/lite/kernels/custom_ops_register.h @@ -28,6 +28,7 @@ TfLiteRegistration* Register_HASHTABLE_IMPORT(); TfLiteRegistration* Register_HASHTABLE_SIZE(); TfLiteRegistration* Register_IMAG(); TfLiteRegistration* Register_MULTINOMIAL(); +TfLiteRegistration* Register_RANDOM_STANDARD_NORMAL(); TfLiteRegistration* Register_REAL(); TfLiteRegistration* Register_RFFT2D(); diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 4d4ea84685e..94135c6adbe 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -347,6 +347,7 @@ cc_library( "optimized/eigen_tensor_reduced_instantiations_oss.h", "optimized/multithreaded_conv.h", ], + compatible_with = get_compatible_with_portable(), copts = tflite_copts(), deps = [ ":common", diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index f4cd76386eb..6cabea11ac4 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -4307,30 +4307,36 @@ inline void LogSoftmax(const SoftmaxParams& params, // TODO(tflite): notes for optimization: // 1) See if e^ is also bottleneck in the reference fully-integer // version and apply lookup there and compare. +template inline void LogSoftmax(const SoftmaxParams& params, float input_scale, - const RuntimeShape& input_shape, const uint8* input_data, - const RuntimeShape& output_shape, uint8* output_data) { - ruy::profiler::ScopeLabel label("LogSoftmax/Uint8"); + const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& output_shape, T* output_data) { + ruy::profiler::ScopeLabel label("LogSoftmax"); const int trailing_dim = input_shape.DimensionsCount() - 1; const int excluding_last_dim = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); const int last_dim = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); - const int32_t clamp_max = std::numeric_limits::max(); - const int32_t clamp_min = std::numeric_limits::min(); + const int32_t clamp_max = std::numeric_limits::max(); + const int32_t clamp_min = std::numeric_limits::min(); + + int32_t zero_point_offset = 0; + if (std::is_same::value) { + zero_point_offset = 128; + } for (int i = 0; i < excluding_last_dim; ++i) { - uint8_t max_val = std::numeric_limits::min(); + T max_val = std::numeric_limits::min(); // Find max quantized value. for (int j = 0; j < last_dim; ++j) { max_val = std::max(max_val, input_data[j]); } float sum_exp = 0.0f; - const int32_t max_uint8 = std::numeric_limits::max(); + const int32_t max_q8 = std::numeric_limits::max(); // Offset into table to compute exp(scale*(x - xmax)) instead of // exp(scale*(x)) to prevent overflow. 
- const float* table_offset = ¶ms.table[max_uint8 - max_val]; + const float* table_offset = ¶ms.table[max_q8 - max_val]; // Calculate sum(exp(scale*(x - x_max))). for (int j = 0; j < last_dim; ++j) { sum_exp += table_offset[input_data[j]]; @@ -4340,7 +4346,8 @@ inline void LogSoftmax(const SoftmaxParams& params, float input_scale, // params.scale is the output scale. const float scale = input_scale / params.scale; const float precomputed = - (input_scale * max_val + log_sum_exp) / params.scale; + (input_scale * (max_val + zero_point_offset) + log_sum_exp) / + params.scale; for (int j = 0; j < last_dim; ++j) { // Equivalent to (input_scale * (input_data[j] - max_val) - log_sum_exp) / // output_scale. @@ -4350,7 +4357,7 @@ inline void LogSoftmax(const SoftmaxParams& params, float input_scale, // Use std::rint over std::round (which is used in // FakeQuant) since it's multiple times faster on tested arm32. const int32_t prob_quantized = std::rint(log_prob) + params.zero_point; - output_data[j] = static_cast( + output_data[j] = static_cast( std::max(std::min(clamp_max, prob_quantized), clamp_min)); } input_data += last_dim; diff --git a/tensorflow/lite/kernels/internal/reference/batch_matmul.h b/tensorflow/lite/kernels/internal/reference/batch_matmul.h index 24c3ffe3d7e..f06199c7700 100644 --- a/tensorflow/lite/kernels/internal/reference/batch_matmul.h +++ b/tensorflow/lite/kernels/internal/reference/batch_matmul.h @@ -217,10 +217,11 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data, } } +template inline void BatchMatMul(const FullyConnectedParams& params, - const RuntimeShape& lhs_shape, const int8_t* lhs_data, - const RuntimeShape& rhs_shape, const int8_t* rhs_data, - const RuntimeShape& output_shape, int8_t* output_data) { + const RuntimeShape& lhs_shape, const T* lhs_data, + const RuntimeShape& rhs_shape, const T* rhs_data, + const RuntimeShape& output_shape, T* output_data) { const RuntimeShape extended_lhs_shape = RuntimeShape::ExtendedShape(5, lhs_shape); const RuntimeShape extended_rhs_shape = @@ -276,33 +277,33 @@ inline void BatchMatMul(const FullyConnectedParams& params, TFLITE_DCHECK_LE(output_activation_min, output_activation_max); for (int b0 = 0; b0 < batch_dim0; ++b0) { - const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0); - const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0); + const T* lhs_ptr0 = lhs_data + (b0 * lhs_ext0); + const T* rhs_ptr0 = rhs_data + (b0 * rhs_ext0); for (int b1 = 0; b1 < batch_dim1; ++b1) { - const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1; - const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1; + const T* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1; + const T* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1; for (int b2 = 0; b2 < batch_dim2; ++b2) { - const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2; - const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2; - int8_t* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) + - b1 * batch_dim2 + b2) * - lhs_rows * rhs_cols; + const T* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2; + const T* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2; + T* out_ptr = output_data + + ((b0 * batch_dim1 * batch_dim2) + b1 * batch_dim2 + b2) * + lhs_rows * rhs_cols; for (int j = 0; j < rhs_cols; ++j) { for (int i = 0; i < lhs_rows; ++i) { - int32_t total = 0; + AccumT total = 0; for (int k = 0; k < accum_depth; ++k) { - int32_t lhs_val = lhs_ptr2[accum_depth * i + k]; - int32_t rhs_val = rhs_ptr2[accum_depth * j + k]; + AccumT lhs_val = lhs_ptr2[accum_depth * i + k]; + AccumT rhs_val = rhs_ptr2[accum_depth * j + k]; total += (lhs_val 
+ filter_offset) * (rhs_val + input_offset); } - total = MultiplyByQuantizedMultiplier(total, output_multiplier, - output_shift); - total += output_offset; - total = std::max(total, output_activation_min); - total = std::min(total, output_activation_max); + int32_t total_scaled = MultiplyByQuantizedMultiplier( + total, output_multiplier, output_shift); + total_scaled += output_offset; + total_scaled = std::max(total_scaled, output_activation_min); + total_scaled = std::min(total_scaled, output_activation_max); const int idx = lhs_rows * j + i; - out_ptr[idx] = static_cast(total); + out_ptr[idx] = static_cast(total_scaled); } } } diff --git a/tensorflow/lite/kernels/internal/reference/svdf.h b/tensorflow/lite/kernels/internal/reference/svdf.h index bb986e4de0a..c61abf3adb5 100644 --- a/tensorflow/lite/kernels/internal/reference/svdf.h +++ b/tensorflow/lite/kernels/internal/reference/svdf.h @@ -36,7 +36,7 @@ namespace reference_ops { static inline void ApplyTimeWeightsBiasAndActivation( int batch_size, int memory_size, int num_filters, int num_units, int rank, - const float* const __restrict__ weights_time_ptr, + const float* const __restrict__ weights_time_data, const float* const __restrict__ bias_ptr, TfLiteFusedActivation activation, float* const __restrict__ state_ptr, float* const __restrict__ scratch_ptr, float* const __restrict__ output_ptr) { @@ -45,7 +45,7 @@ static inline void ApplyTimeWeightsBiasAndActivation( float* state_ptr_batch = state_ptr + b * memory_size * num_filters; float* scratch_ptr_batch = scratch_ptr + b * num_filters; tensor_utils::BatchVectorBatchVectorDotProduct( - weights_time_ptr, state_ptr_batch, memory_size, num_filters, + weights_time_data, state_ptr_batch, memory_size, num_filters, scratch_ptr_batch); } @@ -74,44 +74,41 @@ static inline void ApplyTimeWeightsBiasAndActivation( } inline void EvalIntegerSVDF( - TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input_tensor, - const TfLiteTensor* weights_feature_tensor, - const TfLiteTensor* weights_time_tensor, const TfLiteTensor* bias_tensor, - const TfLiteSVDFParams* params, TfLiteTensor* state_tensor, - TfLiteTensor* output_tensor, TfLiteTensor* scratch_tensor, - TfLiteTensor* output_temp_tensor, int32_t scale_1_a, int scale_1_b, - int32_t scale_2_a, int scale_2_b, int32_t input_zp, int32_t output_zp) { + const TfLiteSVDFParams* params, const RuntimeShape& input_shape, + const int8_t* input_data, const RuntimeShape& weights_feature_shape, + const int8_t* weights_feature_data, const RuntimeShape& weights_time_shape, + const int16_t* weights_time_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, int16_t* state_data, + const RuntimeShape& output_shape, int8_t* output_data, + int32_t* scratch_data, int32_t* output_temp_data, int32_t scale_1_a, + int scale_1_b, int32_t scale_2_a, int scale_2_b, int32_t input_zp, + int32_t output_zp) { const int n_rank = params->rank; - const int n_batch = input_tensor->dims->data[0]; - const int n_input = input_tensor->dims->data[1]; - const int n_filter = weights_feature_tensor->dims->data[0]; + const int n_batch = input_shape.Dims(0); + const int n_input = input_shape.Dims(1); + const int n_filter = weights_feature_shape.Dims(0); const int n_unit = n_filter / n_rank; - const int n_memory = weights_time_tensor->dims->data[1]; - - int16_t* const state_ptr = GetTensorData(state_tensor); + const int n_memory = weights_time_shape.Dims(1); // Left shift the activation_state. 
// std::copy is fine for overlapping ranges if the output is outside of the // input range. (This is not true for copy_n.) - std::copy(state_ptr + 1, state_ptr + n_batch * n_memory * n_filter, - state_ptr); + std::copy(state_data + 1, state_data + n_batch * n_memory * n_filter, + state_data); // Feature matmul. // Note: no need to clear the latest activation, matmul is not accumulative. { - const int8_t* input = GetTensorData(input_tensor); - const int8_t* weight_feature = - GetTensorData(weights_feature_tensor); const int32_t output_max = std::numeric_limits::max(); const int32_t output_min = std::numeric_limits::min(); - int16_t* result_in_batch = state_ptr + (n_memory - 1); + int16_t* result_in_batch = state_data + (n_memory - 1); for (int b = 0; b < n_batch; b++) { - const int8_t* matrix_ptr = weight_feature; + const int8_t* matrix_data = weights_feature_data; for (int r = 0; r < n_filter; r++) { int32_t dot_prod = 0; - const int8_t* vector_in_batch = input + b * n_input; + const int8_t* vector_in_batch = input_data + b * n_input; for (int c = 0; c < n_input; c++) { - dot_prod += *matrix_ptr++ * (*vector_in_batch++ - input_zp); + dot_prod += *matrix_data++ * (*vector_in_batch++ - input_zp); } dot_prod = MultiplyByQuantizedMultiplier(dot_prod, scale_1_a, scale_1_b); @@ -131,171 +128,134 @@ inline void EvalIntegerSVDF( // Time. { for (int b = 0; b < n_batch; ++b) { - const int16_t* state_ptr_batch = state_ptr + b * n_memory * n_filter; - int32_t* scratch_ptr_batch = - GetTensorData(scratch_tensor) + b * n_filter; + const int16_t* state_data_batch = state_data + b * n_memory * n_filter; + int32_t* scratch_data_batch = scratch_data + b * n_filter; tensor_utils::BatchVectorBatchVectorDotProduct( - GetTensorData(weights_time_tensor), state_ptr_batch, - n_memory, n_filter, scratch_ptr_batch); + weights_time_data, state_data_batch, n_memory, n_filter, + scratch_data_batch); } } // Reduce, add bias, rescale, activation. { - int32_t* output_temp = GetTensorData(output_temp_tensor); // Add bias. - if (bias_tensor) { - tensor_utils::VectorBatchVectorAssign(GetTensorData(bias_tensor), - n_unit, n_batch, output_temp); + if (bias_data) { + tensor_utils::VectorBatchVectorAssign(bias_data, n_unit, n_batch, + output_temp_data); } else { - std::fill_n(output_temp, n_batch * n_unit, 0); + std::fill_n(output_temp_data, n_batch * n_unit, 0); } // Reduce. for (int b = 0; b < n_batch; ++b) { - int32_t* output_temp_ptr = output_temp + b * n_unit; - int32_t* scratch_ptr_batch = - GetTensorData(scratch_tensor) + b * n_filter; - tensor_utils::ReductionSumVector(scratch_ptr_batch, output_temp_ptr, + int32_t* output_temp_ptr = output_temp_data + b * n_unit; + int32_t* scratch_data_batch = scratch_data + b * n_filter; + tensor_utils::ReductionSumVector(scratch_data_batch, output_temp_ptr, n_unit, n_rank); } // Rescale. 
const int32_t output_max = std::numeric_limits::max(); const int32_t output_min = std::numeric_limits::min(); for (int i = 0; i < n_batch * n_unit; ++i) { - int32_t x1 = output_temp[i]; + int32_t x1 = output_temp_data[i]; int32_t x2 = MultiplyByQuantizedMultiplier(x1, scale_2_a, scale_2_b); int32_t x3 = x2 + output_zp; int32_t x4 = std::min(std::max(output_min, x3), output_max); - GetTensorData(output_tensor)[i] = static_cast(x4); + output_data[i] = static_cast(x4); } } } -inline void EvalFloatSVDF(TfLiteContext* context, TfLiteNode* node, - const TfLiteTensor* input, - const TfLiteTensor* weights_feature, - const TfLiteTensor* weights_time, - const TfLiteTensor* bias, - const TfLiteSVDFParams* params, TfLiteTensor* scratch, - TfLiteTensor* state, TfLiteTensor* output) { +inline void EvalFloatSVDF( + const TfLiteSVDFParams* params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& weights_feature_shape, + const float* weights_feature_data, const RuntimeShape& weights_time_shape, + const float* weights_time_data, const RuntimeShape& bias_shape, + const float* bias_data, float* scratch_data, float* state_data, + const RuntimeShape& output_shape, float* output_data) { const int rank = params->rank; - const int batch_size = input->dims->data[0]; - const int input_size = input->dims->data[1]; - const int num_filters = weights_feature->dims->data[0]; + const int batch_size = input_shape.Dims(0); + const int input_size = input_shape.Dims(1); + const int num_filters = weights_feature_shape.Dims(0); const int num_units = num_filters / rank; - const int memory_size = weights_time->dims->data[1]; - - // Raw pointers to tensor data. - const float* input_ptr = GetTensorData(input); - const float* weights_feature_ptr = GetTensorData(weights_feature); - const float* weights_time_ptr = GetTensorData(weights_time); - const float* bias_ptr = GetTensorData(bias); - - float* state_ptr = GetTensorData(state); - float* scratch_ptr = GetTensorData(scratch); - - float* output_ptr = GetTensorData(output); + const int memory_size = weights_time_shape.Dims(1); // Left shift the activation_state. // std::copy is fine for overlapping ranges if the output is outside of the // input range. (This is not true for copy_n.) - std::copy(state_ptr + 1, state_ptr + batch_size * memory_size * num_filters, - state_ptr); + std::copy(state_data + 1, state_data + batch_size * memory_size * num_filters, + state_data); // Clear scratch (the matmul is accumulative). - std::fill_n(scratch_ptr, batch_size * num_filters, 0.0f); + std::fill_n(scratch_data, batch_size * num_filters, 0.0f); // Compute conv1d(inputs, weights_feature). tensor_utils::MatrixBatchVectorMultiplyAccumulate( - weights_feature_ptr, num_filters, input_size, input_ptr, batch_size, - scratch_ptr); + weights_feature_data, num_filters, input_size, input_data, batch_size, + scratch_data); // Copy the latest activation from scratch into activation_state: // The last, i.e. (memory_size-1)th entry for each batch, and filter. 
for (int i = 0; i < batch_size * num_filters; ++i) { - state_ptr[i * memory_size + memory_size - 1] = scratch_ptr[i]; + state_data[i * memory_size + memory_size - 1] = scratch_data[i]; } ApplyTimeWeightsBiasAndActivation( - batch_size, memory_size, num_filters, num_units, rank, weights_time_ptr, - bias_ptr, params->activation, state_ptr, scratch_ptr, output_ptr); + batch_size, memory_size, num_filters, num_units, rank, weights_time_data, + bias_data, params->activation, state_data, scratch_data, output_data); } inline void EvalHybridSVDF( - TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input, - const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time, - const TfLiteTensor* bias, const TfLiteSVDFParams* params, - TfLiteTensor* scratch, TfLiteTensor* scaling_factors, - TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output, - TfLiteTensor* zero_points, TfLiteTensor* row_sums, bool* compute_row_sums) { + const TfLiteSVDFParams* params, const RuntimeShape& input_shape, + const float* input_data, const RuntimeShape& weights_feature_shape, + const int8_t* weights_feature_data, const float weights_feature_scale, + const RuntimeShape& weights_time_shape, const float* weights_time_data, + const RuntimeShape& bias_shape, const float* bias_data, float* scratch, + float* scaling_factors, int8_t* quantized_input, float* state, + const RuntimeShape& output_shape, float* output_data, int32_t* zero_points, + int32_t* row_sums, bool* compute_row_sums) { const int rank = params->rank; - const int batch_size = input->dims->data[0]; - const int input_size = input->dims->data[1]; - const int num_filters = weights_feature->dims->data[0]; + const int batch_size = input_shape.Dims(0); + const int input_size = input_shape.Dims(1); + const int num_filters = weights_feature_shape.Dims(0); const int num_units = num_filters / rank; - const int memory_size = weights_time->dims->data[1]; - - // Raw pointers to tensor data. - const float* input_ptr = GetTensorData(input); - const int8_t* weights_feature_ptr = GetTensorData(weights_feature); - const float* weights_time_ptr = GetTensorData(weights_time); - const float* bias_ptr = GetTensorData(bias); - - int8_t* quantized_input_ptr = GetTensorData(input_quantized); - float* scaling_factors_ptr = GetTensorData(scaling_factors); - float* state_ptr = GetTensorData(state); - float* scratch_ptr = GetTensorData(scratch); - - float* output_ptr = GetTensorData(output); - - int32_t* zero_points_ptr = nullptr; - int32_t* row_sums_ptr = nullptr; - if (params->asymmetric_quantize_inputs && row_sums != nullptr) { - zero_points_ptr = GetTensorData(zero_points); - row_sums_ptr = GetTensorData(row_sums); - } - - // Initialize the weights scale. - const float weights_feature_scale = weights_feature->params.scale; + const int memory_size = weights_time_shape.Dims(1); // Left shift the activation_state. // std::copy is fine for overlapping ranges if the output is outside of the // input range. (This is not true for copy_n.) - std::copy(state_ptr + 1, state_ptr + batch_size * memory_size * num_filters, - state_ptr); + std::copy(state + 1, state + batch_size * memory_size * num_filters, state); // Clear scratch (the matmul is accumulative). - std::fill_n(scratch_ptr, batch_size * num_filters, 0.0f); + std::fill_n(scratch, batch_size * num_filters, 0.0f); - if (!tensor_utils::IsZeroVector(input_ptr, batch_size * input_size)) { + if (!tensor_utils::IsZeroVector(input_data, batch_size * input_size)) { // Quantize input from float to int8_t. 
- tensor_utils::BatchQuantizeFloats(input_ptr, batch_size, input_size, - quantized_input_ptr, scaling_factors_ptr, - zero_points_ptr, - params->asymmetric_quantize_inputs); + tensor_utils::BatchQuantizeFloats( + input_data, batch_size, input_size, quantized_input, scaling_factors, + zero_points, params->asymmetric_quantize_inputs); for (int b = 0; b < batch_size; ++b) { - scaling_factors_ptr[b] *= weights_feature_scale; + scaling_factors[b] *= weights_feature_scale; } // Compute conv1d(inputs, weights_feature). tensor_utils::MatrixBatchVectorMultiplyAccumulate( - weights_feature_ptr, num_filters, input_size, quantized_input_ptr, - scaling_factors_ptr, batch_size, scratch_ptr, - /*per_channel_scale=*/nullptr, zero_points_ptr, - reinterpret_cast(scratch_ptr), row_sums_ptr, compute_row_sums, + weights_feature_data, num_filters, input_size, quantized_input, + scaling_factors, batch_size, scratch, + /*per_channel_scale=*/nullptr, zero_points, + reinterpret_cast(scratch), row_sums, compute_row_sums, /*context=*/nullptr); } // Copy the latest activation from scratch into activation_state: // The last, i.e. (memory_size-1)th entry for each batch, and filter. for (int i = 0; i < batch_size * num_filters; ++i) { - state_ptr[i * memory_size + memory_size - 1] = scratch_ptr[i]; + state[i * memory_size + memory_size - 1] = scratch[i]; } // TODO(alanchiao): can optimize hybrid case ~5% by unrolling loop in applying // time weights so that the inner loop multiplies eight elements at a time. ApplyTimeWeightsBiasAndActivation( - batch_size, memory_size, num_filters, num_units, rank, weights_time_ptr, - bias_ptr, params->activation, state_ptr, scratch_ptr, output_ptr); + batch_size, memory_size, num_filters, num_units, rank, weights_time_data, + bias_data, params->activation, state, scratch, output_data); } } // namespace reference_ops diff --git a/tensorflow/lite/kernels/lstm_test.cc b/tensorflow/lite/kernels/lstm_test.cc index ccdc8193f09..9e0084e813d 100644 --- a/tensorflow/lite/kernels/lstm_test.cc +++ b/tensorflow/lite/kernels/lstm_test.cc @@ -1458,9 +1458,13 @@ class LSTMIntegerOpModel : public SingleOpModel { BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, CreateLSTMOptions(builder_, ActivationFunctionType_TANH).Union()); - BuildInterpreter({}); // Input sizes are already set + BuildInterpreter(/*input_shapes=*/{}, /*num_threads=*/-1, + /*allow_fp32_relax_to_fp16=*/false, + /*apply_delegate=*/true, /*allocate_and_delegate=*/false); } + void PerformAllocateAndDelegate() { AllocateAndDelegate(true); } + void SetInputToInputWeights(const std::vector& f) { QuantizeAndPopulate(input_to_input_weights_, f); } @@ -1692,6 +1696,8 @@ TEST(IntegerLstmOpTest, NoCifg_NoPeephole_Projection_LayerNorm) { /*use_layer_norm=*/true, /*use_8x8_8_implementation=*/false, ranges, intermediates); + // Do allocate. + lstm.PerformAllocateAndDelegate(); // Set weights. lstm.SetInputToInputWeights(input_to_input_weights); @@ -1859,6 +1865,9 @@ TEST(IntegerLstmOpTest, NoCifg_Peephole_Projection_LayerNorm) { /*use_8x8_8_implementation=*/false, ranges, intermediates); + // Do allocate. + lstm.PerformAllocateAndDelegate(); + // Set weights. lstm.SetInputToInputWeights(input_to_input_weights); lstm.SetInputToCellWeights(input_to_cell_weights); @@ -2026,6 +2035,9 @@ TEST(IntegerLstmOpTest, Cifg_NoPeephole_Projection_LayerNorm_8x8_8) { /*use_8x8_8_implementation=*/true, ranges, intermediates); + // Do allocate. + lstm.PerformAllocateAndDelegate(); + // Set weights. 
// lstm.SetInputToInputWeights(input_to_input_weights); lstm.SetInputToCellWeights(input_to_cell_weights); diff --git a/tensorflow/lite/kernels/multinomial_test.cc b/tensorflow/lite/kernels/multinomial_test.cc index f1e3d7a039e..ee3d5826fe8 100644 --- a/tensorflow/lite/kernels/multinomial_test.cc +++ b/tensorflow/lite/kernels/multinomial_test.cc @@ -100,24 +100,6 @@ using TestTypes = TYPED_TEST_SUITE(MultinomialTest, TestTypes); TYPED_TEST(MultinomialTest, TestMultiBatch) { - std::default_random_engine rng; - std::uniform_real_distribution dist(0, 10); - std::vector results; - for (int i = 0; i < 1000; ++i) { - results.push_back(dist(rng)); - } - double sum = 0; - for (auto f : results) { - sum += f; - } - double avg = sum / results.size(); - double min = *std::min_element(results.begin(), results.end()); - double max = *std::max_element(results.begin(), results.end()); - EXPECT_GT(avg, 4.8); - EXPECT_LT(avg, 5.2); - EXPECT_LT(min, 0.5); - EXPECT_GT(max, 9.5); - const int kNumSamples = 1000; using Float = typename TestFixture::FloatType; using Int = typename TestFixture::IntegralType; diff --git a/tensorflow/lite/kernels/random_standard_normal.cc b/tensorflow/lite/kernels/random_standard_normal.cc new file mode 100644 index 00000000000..9b0b7b0b5d8 --- /dev/null +++ b/tensorflow/lite/kernels/random_standard_normal.cc @@ -0,0 +1,127 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace random_standard_normal { + +struct OpData { + std::default_random_engine rng; +}; + +// Draws a sample from standard normal distribution. 
+template +TfLiteStatus RandomStandardNormalSample(std::default_random_engine& rng, + Float* output, size_t output_size) { + std::normal_distribution dist; + for (Float* it = output; it != output + output_size; ++it) { + *it = dist(rng); + } + return kTfLiteOk; +} + +TfLiteStatus RandomStandardNormalSample(TfLiteContext* context, + std::default_random_engine& rng, + TfLiteTensor* output, + size_t output_size) { + switch (output->type) { + case kTfLiteFloat32: + TF_LITE_ENSURE_OK(context, + RandomStandardNormalSample( + rng, GetTensorData(output), output_size)); + break; + case kTfLiteFloat64: + TF_LITE_ENSURE_OK(context, + RandomStandardNormalSample( + rng, GetTensorData(output), output_size)); + break; + default: + TF_LITE_KERNEL_LOG( + context, "Unsupported output datatype for RandomStandardNormal: %s", + TfLiteTypeGetName(output->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return new OpData(); +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // TODO(b/169611265): Handle optional seed input. + TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1); + + // Input is a shape tensor. + const TfLiteTensor* input = tflite::GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 1); + // TODO(b/169611265): Support dynamic output tensors. + TF_LITE_ENSURE(context, IsConstantTensor(input)); + + // TODO(b/169611265): Handle other input data types. + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt32); + + int output_dims = tflite::SizeOfDimension(input, 0); + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_dims); + for (int i = 0; i < output_dims; i++) { + output_shape->data[i] = input->data.i32[i]; + } + + TfLiteTensor* output = tflite::GetOutput(context, node, 0); + // ResizeTensor takes ownership of output_shape. + return context->ResizeTensor(context, output, output_shape); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + // TODO(b/169611265): Handle optional seed input. + OpData* params = reinterpret_cast(node->user_data); + TF_LITE_ENSURE(context, params != nullptr); + + TfLiteTensor* output = tflite::GetOutput(context, node, 0); + size_t output_size = tflite::NumElements(output); + + TF_LITE_ENSURE_OK(context, RandomStandardNormalSample(context, params->rng, + output, output_size)); + + return kTfLiteOk; +} + +} // namespace random_standard_normal + +TfLiteRegistration* Register_RANDOM_STANDARD_NORMAL() { + static TfLiteRegistration r = { + random_standard_normal::Init, random_standard_normal::Free, + random_standard_normal::Prepare, random_standard_normal::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/random_standard_normal_test.cc b/tensorflow/lite/kernels/random_standard_normal_test.cc new file mode 100644 index 00000000000..88e71f27669 --- /dev/null +++ b/tensorflow/lite/kernels/random_standard_normal_test.cc @@ -0,0 +1,103 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include +#include +#include "tensorflow/lite/kernels/custom_ops_register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/testing/util.h" + +namespace tflite { +namespace { + +template +tflite::TensorType GetTTEnum(); + +template <> +tflite::TensorType GetTTEnum() { + return tflite::TensorType_FLOAT32; +} + +template <> +tflite::TensorType GetTTEnum() { + return tflite::TensorType_FLOAT64; +} + +class RandomStandardNormalOpModel : public tflite::SingleOpModel { + public: + RandomStandardNormalOpModel(const std::initializer_list& input, + tflite::TensorData output) { + input_ = AddConstInput(tflite::TensorType_INT32, input, + {static_cast(input.size())}); + output_ = AddOutput(output); + SetCustomOp("RandomStandardNormal", {}, + ops::custom::Register_RANDOM_STANDARD_NORMAL); + BuildInterpreter({GetShape(input_)}); + } + + int input_; + int output_; + + int input() { return input_; } + int output() { return output_; } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } +}; + +} // namespace +} // namespace tflite + +template +class RandomStandardNormalTest : public ::testing::Test { + public: + using Float = FloatType; +}; + +using TestTypes = ::testing::Types; + +TYPED_TEST_SUITE(RandomStandardNormalTest, TestTypes); + +TYPED_TEST(RandomStandardNormalTest, TestOutput) { + using Float = typename TestFixture::Float; + tflite::RandomStandardNormalOpModel m({1000, 50, 5}, + {tflite::GetTTEnum(), {}}); + m.Invoke(); + auto output = m.GetOutput(); + EXPECT_EQ(output.size(), 1000 * 50 * 5); + + double sum = 0; + for (auto r : output) { + sum += r; + } + double avg = sum / output.size(); + ASSERT_LT(std::abs(avg), 0.05); // Average should be approximately 0. + + double sum_squared = 0; + for (auto r : output) { + sum_squared += std::pow(r - avg, 2); + } + double var = sum_squared / output.size(); + EXPECT_LT(std::abs(1 - var), 0.05); // Variance should be approximately 1. +} diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 6e15f553fdb..8d8095bdc82 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -293,7 +293,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM()); AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL(), /* min_version = */ 1, - /* max_version = */ 2); + /* max_version = */ 3); AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/lite/kernels/register.h b/tensorflow/lite/kernels/register.h index 1a6095c7140..f40798ffd0a 100644 --- a/tensorflow/lite/kernels/register.h +++ b/tensorflow/lite/kernels/register.h @@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_REGISTER_H_ #define TENSORFLOW_LITE_KERNELS_REGISTER_H_ -#include "tensorflow/lite/model.h" +#include "tensorflow/lite/model.h" // Legacy. #include "tensorflow/lite/mutable_op_resolver.h" namespace tflite { diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc index c8fb46adb96..b9a5b13b477 100644 --- a/tensorflow/lite/kernels/register_ref.cc +++ b/tensorflow/lite/kernels/register_ref.cc @@ -447,7 +447,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() { AddBuiltin(BuiltinOperator_DENSIFY, Register_DENSIFY()); AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL_REF(), /* min_version = */ 1, - /* max_version = */ 2); + /* max_version = */ 3); AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY_REF()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that diff --git a/tensorflow/lite/kernels/svdf.cc b/tensorflow/lite/kernels/svdf.cc index f3a17e1d3b5..8f5c9a86bff 100644 --- a/tensorflow/lite/kernels/svdf.cc +++ b/tensorflow/lite/kernels/svdf.cc @@ -304,9 +304,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { switch (weights_feature->type) { case kTfLiteFloat32: { - reference_ops::EvalFloatSVDF(context, node, input, weights_feature, - weights_time, bias, params, scratch, state, - output); + reference_ops::EvalFloatSVDF( + params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(weights_feature), + GetTensorData(weights_feature), GetTensorShape(weights_time), + GetTensorData(weights_time), GetTensorShape(bias), + GetTensorData(bias), GetTensorData(scratch), + GetTensorData(state), GetTensorShape(output), + GetTensorData(output)); return kTfLiteOk; break; } @@ -346,10 +351,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { op_data->float_weights_time_initialized = true; } + int32_t* zero_points_ptr = nullptr; + int32_t* row_sums_ptr = nullptr; + if (params->asymmetric_quantize_inputs && row_sums != nullptr) { + zero_points_ptr = GetTensorData(zero_points); + row_sums_ptr = GetTensorData(row_sums); + } + reference_ops::EvalHybridSVDF( - context, node, input, weights_feature, float_weights_time, bias, - params, scratch, scaling_factors, input_quantized, state, output, - zero_points, row_sums, &op_data->compute_row_sums); + params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(weights_feature), + GetTensorData(weights_feature), + weights_feature->params.scale, GetTensorShape(float_weights_time), + GetTensorData(float_weights_time), GetTensorShape(bias), + GetTensorData(bias), GetTensorData(scratch), + GetTensorData(scaling_factors), + GetTensorData(input_quantized), GetTensorData(state), + GetTensorShape(output), GetTensorData(output), + zero_points_ptr, row_sums_ptr, &op_data->compute_row_sums); return kTfLiteOk; } else { auto* input_params = reinterpret_cast( @@ -363,9 +382,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Currently supports only ReLU. // TODO(jianlijianli): support other activations. 
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu); + reference_ops::EvalIntegerSVDF( - context, node, input, weights_feature, weights_time, bias, params, - state, output, scratch, output_temp, op_data->effective_scale_1_a, + params, GetTensorShape(input), GetTensorData(input), + GetTensorShape(weights_feature), + GetTensorData(weights_feature), + GetTensorShape(weights_time), GetTensorData(weights_time), + GetTensorShape(bias), GetTensorData(bias), + GetTensorData(state), GetTensorShape(output), + GetTensorData(output), GetTensorData(scratch), + GetTensorData(output_temp), op_data->effective_scale_1_a, op_data->effective_scale_1_b, op_data->effective_scale_2_a, op_data->effective_scale_2_b, input_params->zero_point->data[0], output_params->zero_point->data[0]); diff --git a/tensorflow/lite/kernels/test_util.cc b/tensorflow/lite/kernels/test_util.cc index d8e1e84f938..ad513e9f918 100644 --- a/tensorflow/lite/kernels/test_util.cc +++ b/tensorflow/lite/kernels/test_util.cc @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/lite/model.h" #include "tensorflow/lite/nnapi/nnapi_implementation.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/string_type.h" #include "tensorflow/lite/string_util.h" #include "tensorflow/lite/tools/logging.h" diff --git a/tensorflow/lite/micro/BUILD b/tensorflow/lite/micro/BUILD index 65690891ab3..bd8f39bb925 100644 --- a/tensorflow/lite/micro/BUILD +++ b/tensorflow/lite/micro/BUILD @@ -87,6 +87,7 @@ cc_library( "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@flatbuffers//:runtime_cc", ], ) @@ -108,6 +109,7 @@ cc_library( "//tensorflow/lite/core/api", "//tensorflow/lite/kernels:op_macros", "//tensorflow/lite/kernels/internal:compatibility", + "//tensorflow/lite/micro/kernels:fully_connected", "//tensorflow/lite/micro/kernels:micro_ops", "//tensorflow/lite/schema:schema_fbs", ], @@ -327,6 +329,7 @@ tflite_micro_cc_test( ], deps = [ ":micro_framework", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) diff --git a/tensorflow/lite/micro/CONTRIBUTING.md b/tensorflow/lite/micro/CONTRIBUTING.md index 063e1df2161..4227eb42858 100644 --- a/tensorflow/lite/micro/CONTRIBUTING.md +++ b/tensorflow/lite/micro/CONTRIBUTING.md @@ -5,7 +5,6 @@ https://github.com/ekalinin/github-markdown-toc#auto-insert-and-update-toc -* [Resources](#resources) * [Contributing Guidelines](#contributing-guidelines) * [General Pull Request Guidelines](#general-pull-request-guidelines) * [Guidelines for Specific Contribution Categories](#guidelines-for-specific-contribution-categories) @@ -20,25 +19,10 @@ https://github.com/ekalinin/github-markdown-toc#auto-insert-and-update-toc * [Reviewer notes](#reviewer-notes) * [Python notes](#python-notes) - + -# Resources - -A -[TF Lite Micro Github issue](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md) -should be the primary method of getting in touch with the TensorFlow Lite Micro -(TFLM) team. - -The following resources may also be useful: - -1. SIG Micro [email group](https://groups.google.com/a/tensorflow.org/g/micro) - and - [monthly meetings](http://doc/1YHq9rmhrOUdcZnrEnVCWvd87s2wQbq4z17HbeRl-DBc). - -1. SIG Micro [gitter chat room](https://gitter.im/tensorflow/sig-micro). 
- # Contributing Guidelines We look forward to your contributions to the TensorFlow Lite Micro codebase and diff --git a/tensorflow/lite/micro/MEMORY.md b/tensorflow/lite/micro/MEMORY.md deleted file mode 100644 index e9e73ef9ce9..00000000000 --- a/tensorflow/lite/micro/MEMORY.md +++ /dev/null @@ -1,112 +0,0 @@ -# Memory Management in TensorFlow Lite Micro - -This document outlines how memory is managed internally by TensorFlow Lite Micro (TFLM) today. It outlines the "online" allocation strategy used by the default TFLM APIs for loading a model into a shared tensor arena. - -## Tensor Arena - -The main "working" space for TFLM allocations is inside a single `char` or `int8_t` buffer. This buffer can be managed by passing it directly into a `tflite::MicroInterpreter` constructor or through a `tflite::MicroAllocator` instance that can be passed into a `tflite::MicroInterpreter` constructor. Internally, the `tflite::MicroAllocator` classifies allocations into 3 different sections: - -* **Head** - non-persistent allocations. -* **Temporary** - short term "scoped" allocations. -* **Tail** - persistent allocations. - -The illustration below represents typical allocations in TFLM: -``` --------------------------------------------------------------------------------- -| | | | -| HEAD |<-- TEMPORARY -->| TAIL | -| | | | --------------------------------------------------------------------------------- -* Lowest Address Highest Address * -``` - -### Head Section - -This non-persistent section typically holds shared Tensor buffers. This section does not allocate small iterative chunks, it can only be set by a specific length for the entire section. - -This allocation length of this section is managed by the `tflite::GreedyMemoryPlanner`. That memory planner looks at the entire graph of a model and tries to reuse as many buffers as possible to create the smallest length for the head. The Tensor buffers for this section can be accessed via a `TfLiteEvalTensor` or `TfLiteTensor` instance on the `tflite::MicroInterpreter`. - -### Temporary Section - -This section is used to allocate "scoped" or short-term, non-guaranteed buffers. Allocations from this section start from the current end address of the head section and grow towards the tail section. An allocation chain can be reset (and must be reset before adjusting the head) and moves the current allocation start address back to the end of the head section. - - TFLM currently uses these allocations for a scope allocation of large C structs or scratch memory that is expected to be valid for at least the lifetime of a method call. This section. - -### Tail Section - -This section holds all persistent allocations used by TFLM. This section contains many random sized allocations and grows towards the end of the head section. Allocations in this section come from a variety of areas inside of TFLM. TFLM provides a [recording API](#Recording-Memory-APIs) to assist with auditing the contents of this section. - -## Recording Memory APIs - -TFLM provides simple APIs for auditing memory usage in the shared tensor arena. These APIs are opt-in and require some additional memory overhead and a working debug logging implementation [(reference implementation)](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/debug_log.cc). 
- -A typical bare-bones TFLM interpreter setup looks as such: - -```c++ -// Buffer for the tensor arena: -size_t tensor_arena_size = 2048; -uint8_t tensor_arena[tensor_arena_size]; - -// Interpreter using the shared tensor arena above: -tflite::MicroInterpreter interpreter( - tflite::GetModel(my_model_data), ops_resolver, - tensor_arena, tensor_arena_size, error_reporter); - -// Invoke one time which will allocate internals: -if (interpreter.Invoke() != kTfLiteOk) { - TF_LITE_REPORT_ERROR(error_reporter, "Exception during invoke()!"); -} -``` - -Recording API can simply be used by including the `RecordingMicroInterpreter` class (`recording_micro_interpreter.h`) and replace `tflite::MicroInterpreter` with `tflite::RecordingMicroInterpreter`. The same call to `invoke()` is performed, but another call is made to `PrintAllocations()` which will output detailed allocation logging: - -```c++ -// Add an include to the recording API: -#include "recording_micro_interpreter.h" - -// Simply change the class name from 'MicroInterpreter' to 'RecordingMicroInterpreter': -tflite::RecoridngMicroInterpreter interpreter( - tflite::GetModel(my_model_data), ops_resolver, - tensor_arena, tensor_arena_size, error_reporter); - -// Invoke one time which will allocate internals: -if (interpreter.Invoke() != kTfLiteOk) { - TF_LITE_REPORT_ERROR(error_reporter, "Exception during invoke()!"); -} - -// Print out detailed allocation information: -interpreter.PrintAllocations(); -``` - -The output of this call will look something similar to this (output from the [memory_arena_threshold_test](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/memory_arena_threshold_test.cc#L205)): -```sh -[RecordingMicroAllocator] Arena allocation total 9568 bytes -[RecordingMicroAllocator] Arena allocation head 7744 bytes -[RecordingMicroAllocator] Arena allocation tail 1824 bytes -[RecordingMicroAllocator] 'TfLiteEvalTensor data' used 360 bytes with alignment overhead (requested 360 bytes for 15 allocations) -[RecordingMicroAllocator] 'Persistent TfLiteTensor data' used 0 bytes with alignment overhead (requested 0 bytes for 0 tensors) -[RecordingMicroAllocator] 'Persistent TfLiteTensor quantization data' used 0 bytes with alignment overhead (requested 0 bytes for 0 allocations) -[RecordingMicroAllocator] 'TfLiteTensor variable buffer data' used 0 bytes with alignment overhead (requested 0 bytes for 0 allocations) -[RecordingMicroAllocator] 'NodeAndRegistration struct' used 392 bytes with alignment overhead (requested 392 bytes for 7 NodeAndRegistration structs) -[RecordingMicroAllocator] 'Operator runtime data' used 136 bytes with alignment overhead (requested 136 bytes for 5 OpData structs) -``` - -### Allocation Section Details - -More information about each recorded allocation section: - -* 'TfLiteEvalTensor data' - * C struct that holds the data type, dimension, and a pointer to the buffer representing the Tensor. -* 'Persistent TfLiteTensor data' - * C struct that holds more information than a `TfLiteEvalTensor` struct in the graph. - * Allocations in this bucket will only show up when accessing tensors from the accessors on `tflite::MicroInterpreter`. -* 'Persistent TfLiteTensor quantization data' - * Length of persistent quantization data assigned to persistent `TfLiteTensor` structs. - * Allocations in this bucket will only show up when accessing tensors from the accessors on `tflite::MicroInterpreter`. 
-* 'TfLiteTensor variable buffer data' - * Length of buffer data from a variable tensor (retains data throughout calls to `invoke()`). -* 'NodeAndRegistration struct' - * C struct that holds a `TfLiteRegistration` and `TfLiteNode` struct instance. - * Each operator in a model will contain one `NodeAndRegistration` struct. -* 'Operator runtime data' - * Persistent allocations of data cached by TFLM kernels (e.g. quantization params, multipliers, etc). diff --git a/tensorflow/lite/micro/README.md b/tensorflow/lite/micro/README.md index c12c83b98b1..c9dcc43c91a 100644 --- a/tensorflow/lite/micro/README.md +++ b/tensorflow/lite/micro/README.md @@ -1,291 +1,50 @@ + + + + +* [TensorFlow Lite for Microcontrollers](#tensorflow-lite-for-microcontrollers) +* [Getting Help and Involved](#getting-help-and-involved) +* [Additional Documentation for the TFLM Internals](#additional-documentation-for-the-tflm-internals) + + + + + # TensorFlow Lite for Microcontrollers -TensorFlow Lite for Microcontrollers is a port of TensorFlow Lite designed to run -machine learning models on microcontrollers and other devices with only kilobytes -of memory. +TensorFlow Lite for Microcontrollers is a port of TensorFlow Lite designed to +run machine learning models on microcontrollers and other devices with only +kilobytes of memory. To learn how to use the framework, visit the developer documentation at [tensorflow.org/lite/microcontrollers](https://www.tensorflow.org/lite/microcontrollers). -## Porting to a new platform +# Getting Help and Involved -The remainder of this document provides guidance on porting TensorFlow Lite for -Microcontrollers to new platforms. You should read the -[developer documentation](https://www.tensorflow.org/lite/microcontrollers) -first. +A +[TF Lite Micro Github issue](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md) +should be the primary method of getting in touch with the TensorFlow Lite Micro +(TFLM) team. -### Requirements +The following resources may also be useful: -Since the core neural network operations are pure arithmetic, and don't require -any I/O or other system-specific functionality, the code doesn't have to have -many dependencies. We've tried to enforce this, so that it's as easy as possible -to get TensorFlow Lite Micro running even on 'bare metal' systems without an OS. -Here are the core requirements that a platform needs to run the framework: +1. SIG Micro [email group](https://groups.google.com/a/tensorflow.org/g/micro) + and + [monthly meetings](http://doc/1YHq9rmhrOUdcZnrEnVCWvd87s2wQbq4z17HbeRl-DBc). -- C/C++ compiler capable of C++11 compatibility. This is probably the most - restrictive of the requirements, since C++11 is not as widely adopted in the - embedded world as it is elsewhere. We made the decision to require it since - one of the main goals of TFL Micro is to share as much code as possible with - the wider TensorFlow codebase, and since that relies on C++11 features, we - need compatibility to achieve it. We only use a small, sane, subset of C++ - though, so don't worry about having to deal with template metaprogramming or - similar challenges! +1. SIG Micro [gitter chat room](https://gitter.im/tensorflow/sig-micro). -- Debug logging. The core network operations don't need any I/O functions, but - to be able to run tests and tell if they've worked as expected, the - framework needs some way to write out a string to some kind of debug - console. 
This will vary from system to system, for example on Linux it could - just be `fprintf(stderr, debug_string)` whereas an embedded device might - write the string out to a specified UART. As long as there's some mechanism - for outputting debug strings, you should be able to use TFL Micro on that - platform. +If you are interested in contributing code to TensorFlow Lite for +Microcontrollers then please read our [contributions guide](CONTRIBUTING.md). -- Math library. The C standard `libm.a` library is needed to handle some of - the mathematical operations used to calculate neural network results. +# Additional Documentation -- Global variable initialization. We do use a pattern of relying on global - variables being set before `main()` is run in some places, so you'll need to - make sure your compiler toolchain supports this. +For developers that are interested in more details of the internals of the +project, we have additional documentation in the [docs](docs/) folder. -And that's it! You may be wondering about some other common requirements that -are needed by a lot of non-embedded software, so here's a brief list of things -that aren't necessary to get started with TFL Micro on a new platform: - -- Operating system. Since the only platform-specific function we need is - `DebugLog()`, there's no requirement for any kind of Posix or similar - functionality around files, processes, or threads. - -- C or C++ standard libraries. The framework tries to avoid relying on any - standard library functions that require linker-time support. This includes - things like string functions, but still allows us to use headers like - `stdtypes.h` which typically just define constants and typedefs. - Unfortunately this distinction isn't officially defined by any standard, so - it's possible that different toolchains may decide to require linked code - even for the subset we use, but in practice we've found it's usually a - pretty obvious decision and stable over platforms and toolchains. - -- Dynamic memory allocation. All the TFL Micro code avoids dynamic memory - allocation, instead relying on local variables on the stack in most cases, - or global variables for a few situations. These are all fixed-size, which - can mean some compile-time configuration to ensure there's enough space for - particular networks, but does avoid any need for a heap and the - implementation of `malloc\new` on a platform. - -- Floating point. Eight-bit integer arithmetic is enough for inference on many - networks, so if a model sticks to these kind of quantized operations, no - floating point instructions should be required or executed by the framework. - -### Getting started - -We recommend that you start trying to compile and run one of the simplest tests -in the framework as your first step. The full TensorFlow codebase can seem -overwhelming to work with at first, so instead you can begin with a collection -of self-contained project folders that only include the source files needed for -a particular test or executable. You can find a set of pre-generated projects -[here](https://drive.google.com/open?id=1cawEQAkqquK_SO4crReDYqf_v7yAwOY8). - -As mentioned above, the one function you will need to implement for a completely -new platform is debug logging. If your device is just a variation on an existing -platform you may be able to reuse code that's already been written. 
To -understand what's available, begin with the default reference implementation at -[tensorflow/lite/micro/debug_log.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/debug_log.cc), -which uses fprintf and stderr. If your platform has this level of support for -the C standard library in its toolchain, then you can just reuse this. -Otherwise, you'll need to do some research into how your platform and device can -communicate logging statements to the outside world. As another example, take a -look at -[the Mbed version of `DebugLog()`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/mbed/debug_log.cc), -which creates a UART object and uses it to output strings to the host's console -if it's connected. - -Begin by navigating to the micro_error_reporter_test folder in the pregenerated -projects you downloaded. Inside here, you'll see a set of folders containing all -the source code you need. If you look through them, you should find a total of -around 60 C or C++ files that compiled together will create the test executable. -There's an example makefile in the directory that lists all of the source files -and include paths for the headers. If you're building on a Linux or MacOS host -system, you may just be able to reuse that same makefile to cross-compile for -your system, as long as you swap out the `CC` and `CXX` variables from their -defaults, to point to your cross compiler instead (for example -`arm-none-eabi-gcc` or `riscv64-unknown-elf-gcc`). Otherwise, set up a project -in the build system you are using. It should hopefully be fairly -straightforward, since all of the source files in the folder need to be -compiled, so on many IDEs you can just drag the whole lot in. Then you need to -make sure that C++11 compatibility is turned on, and that the right include -paths (as mentioned in the makefile) have been added. - -You'll see the default `DebugLog()` implementation in -'tensorflow/lite/micro/debug_log.cc' inside the -micro_error_reporter_test folder. Modify that file to add the right -implementation for your platform, and then you should be able to build the set -of files into an executable. Transfer that executable to your target device (for -example by flashing it), and then try running it. You should see output that -looks something like this: - -``` -Number: 42 -Badly-formed format string -Another badly-formed format string -~~ALL TESTS PASSED~~~ -``` - -If not, you'll need to debug what went wrong, but hopefully with this small -starting project it should be manageable. - -### Troubleshooting - -When we've been porting to new platforms, it's often been hard to figure out -some of the fundamentals like linker settings and other toolchain setup flags. -If you are having trouble, see if you can find a simple example program for your -platform, like one that just blinks an LED. If you're able to build and run that -successfully, then start to swap in parts of the TF Lite Micro codebase to that -working project, taking it a step at a time and ensuring it's still working -after every change. For example, a first step might be to paste in your -`DebugLog()` implementation and call `DebugLog("Hello World!")` from the main -function. - -Another common problem on embedded platforms is the stack size being too small. -Mbed defaults to 4KB for the main thread's stack, which is too small for most -models since TensorFlow Lite allocates buffers and other data structures that -require more memory. 
The exact size will depend on which model you're running, -but try increasing it if you are running into strange corruption issues that -might be related to stack overwriting. - -### Optimizing for your platform - -The default reference implementations in TensorFlow Lite Micro are written to be -portable and easy to understand, not fast, so you'll want to replace performance -critical parts of the code with versions specifically tailored to your -architecture. The framework has been designed with this in mind, and we hope the -combination of small modules and many tests makes it as straightforward as -possible to swap in your own code a piece at a time, ensuring you have a working -version at every step. To write specialized implementations for a platform, it's -useful to understand how optional components are handled inside the build -system. - -### Code module organization - -We have adopted a system of small modules with platform-specific implementations -to help with portability. Every module is just a standard `.h` header file -containing the interface (either functions or a class), with an accompanying -reference implementation in a `.cc` with the same name. The source file -implements all of the code that's declared in the header. If you have a -specialized implementation, you can create a folder in the same directory as the -header and reference source, name it after your platform, and put your -implementation in a `.cc` file inside that folder. We've already seen one -example of this, where the Mbed and Bluepill versions of `DebugLog()` are inside -[mbed](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/mbed) -and -[bluepill](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/bluepill) -folders, children of the -[same directory](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite//micro) -where the stdio-based -[`debug_log.cc`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/debug_log.cc) -reference implementation is found. - -The advantage of this approach is that we can automatically pick specialized -implementations based on the current build target, without having to manually -edit build files for every new platform. It allows incremental optimizations -from a always-working foundation, without cluttering the reference -implementations with a lot of variants. - -To see why we're doing this, it's worth looking at the alternatives. TensorFlow -Lite has traditionally used preprocessor macros to separate out some -platform-specific code within particular files, for example: - -``` -#ifndef USE_NEON -#if defined(__ARM_NEON__) || defined(__ARM_NEON) -#define USE_NEON -#include -#endif -``` - -There’s also a tradition in gemmlowp of using file suffixes to indicate -platform-specific versions of particular headers, with kernel_neon.h being -included by kernel.h if `USE_NEON` is defined. As a third variation, kernels are -separated out using a directory structure, with -tensorflow/lite/kernels/internal/reference containing portable implementations, -and tensorflow/lite/kernels/internal/optimized holding versions optimized for -NEON on Arm platforms. - -These approaches are hard to extend to multiple platforms. Using macros means -that platform-specific code is scattered throughout files in a hard-to-find way, -and can make following the control flow difficult since you need to understand -the macro state to trace it. 
For example, I temporarily introduced a bug that -disabled NEON optimizations for some kernels when I removed -tensorflow/lite/kernels/internal/common.h from their includes, without realizing -it was where USE_NEON was defined! - -It’s also tough to port to different build systems, since figuring out the right -combination of macros to use can be hard, especially since some of them are -automatically defined by the compiler, and others are only set by build scripts, -often across multiple rules. - -The approach we are using extends the file system approach that we use for -kernel implementations, but with some specific conventions: - -- For each module in TensorFlow Lite, there will be a parent directory that - contains tests, interface headers used by other modules, and portable - implementations of each part. -- Portable means that the code doesn’t include code from any libraries except - flatbuffers, or other TF Lite modules. You can include a limited subset of - standard C or C++ headers, but you can’t use any functions that require - linking against those libraries, including fprintf, etc. You can link - against functions in the standard math library, in . -- Specialized implementations are held inside subfolders of the parent - directory, named after the platform or library that they depend on. So, for - example if you had my_module/foo.cc, a version that used RISC-V extensions - would live in my_module/riscv/foo.cc. If you had a version that used the - CMSIS library, it should be in my_module/cmsis/foo.cc. -- These specialized implementations should completely replace the top-level - implementations. If this involves too much code duplication, the top-level - implementation should be split into smaller files, so only the - platform-specific code needs to be replaced. -- There is a convention about how build systems pick the right implementation - file. There will be an ordered list of 'tags' defining the preferred - implementations, and to generate the right list of source files, each module - will be examined in turn. If a subfolder with a tag’s name contains a .cc - file with the same base name as one in the parent folder, then it will - replace the parent folder’s version in the list of build files. If there are - multiple subfolders with matching tags and file names, then the tag that’s - latest in the ordered list will be chosen. This allows us to express “I’d - like generically-optimized fixed point if it’s available, but I’d prefer - something using the CMSIS library” using the list 'fixed_point cmsis'. These - tags are passed in as `TAGS=""` on the command line when you use the - main Makefile to build. -- There is an implicit “reference” tag at the start of every list, so that - it’s possible to support directory structures like the current - tensorflow/kernels/internal where portable implementations are held in a - “reference” folder that’s a sibling to the NEON-optimized folder. -- The headers for each unit in a module should remain platform-agnostic, and - be the same for all implementations. Private headers inside a sub-folder can - be used as needed, but shouldn’t be referred to by any portable code at the - top level. -- Tests should be at the parent level, with no platform-specific code. -- No platform-specific macros or #ifdef’s should be used in any portable code. 
- -The implementation of these rules is handled inside the Makefile, with a -[`specialize` function](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/tools/make/helper_functions.inc#L42) -that takes a list of reference source file paths as an input, and returns the -equivalent list with specialized versions of those files swapped in if they -exist. - -### Implementing more optimizations - -Clearly, getting debug logging support is only the beginning of the work you'll -need to do on a particular platform. It's very likely that you'll want to -optimize the core deep learning operations that take up the most time when -running models you care about. The good news is that the process for providing -optimized implementations is the same as the one you just went through to -provide your own logging. You'll need to identify parts of the code that are -bottlenecks, and then add specialized implementations in their own folders. -These don't need to be platform specific, they can also be broken out by which -library they rely on for example. [Here's where we do that for the CMSIS -implementation of integer fast-fourier -transforms](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/micro_speech/simple_features/simple_features_generator.cc). -This more complex case shows that you can also add helper source files alongside -the main implementation, as long as you -[mention them in the platform-specific makefile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/micro_speech/CMSIS/Makefile.inc). -You can also do things like update the list of libraries that need to be linked -in, or add include paths to required headers. +* [Benchmarks](benchmarks/README.md) +* [Memory Management](docs/memory_management.md) +* [New Platform Support](docs/new_platform_support.md) diff --git a/tensorflow/lite/micro/all_ops_resolver.cc b/tensorflow/lite/micro/all_ops_resolver.cc index d722ec146fb..0a2a0c0f7fe 100644 --- a/tensorflow/lite/micro/all_ops_resolver.cc +++ b/tensorflow/lite/micro/all_ops_resolver.cc @@ -70,6 +70,7 @@ AllOpsResolver::AllOpsResolver() { AddResizeNearestNeighbor(); AddRound(); AddRsqrt(); + AddShape(); AddSin(); AddSoftmax(); AddSplit(); diff --git a/tensorflow/lite/micro/benchmarks/BUILD b/tensorflow/lite/micro/benchmarks/BUILD index 409fa552130..f2eb0144d32 100644 --- a/tensorflow/lite/micro/benchmarks/BUILD +++ b/tensorflow/lite/micro/benchmarks/BUILD @@ -41,6 +41,7 @@ cc_binary( "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro/kernels:fully_connected", ], ) diff --git a/tensorflow/lite/micro/benchmarks/Makefile.inc b/tensorflow/lite/micro/benchmarks/Makefile.inc index d9dfba265ed..b50a51758c0 100644 --- a/tensorflow/lite/micro/benchmarks/Makefile.inc +++ b/tensorflow/lite/micro/benchmarks/Makefile.inc @@ -27,8 +27,15 @@ tensorflow/lite/micro/examples/person_detection_experimental/person_detect_model $(eval $(call microlite_test,keyword_benchmark,\ $(KEYWORD_BENCHMARK_SRCS),$(KEYWORD_BENCHMARK_HDRS))) +# We can not build the person detection benchmarks for STM32F4 due to +# insufficient flash (see issue #43743 for more details). Currently disabling +# this benchmark for STM32F4 but we can revisit this later. 
+ifneq ($(TARGET), stm32f4) + $(eval $(call microlite_test,person_detection_benchmark,\ $(PERSON_DETECTION_BENCHMARK_SRCS),$(PERSON_DETECTION_BENCHMARK_HDRS))) $(eval $(call microlite_test,person_detection_experimental_benchmark,\ $(PERSON_DETECTION_EXPERIMENTAL_BENCHMARK_SRCS),$(PERSON_DETECTION_EXPERIMENTAL_BENCHMARK_HDRS))) + +endif diff --git a/tensorflow/lite/micro/benchmarks/keyword_benchmark.cc b/tensorflow/lite/micro/benchmarks/keyword_benchmark.cc index dd11892127c..815be071f1f 100644 --- a/tensorflow/lite/micro/benchmarks/keyword_benchmark.cc +++ b/tensorflow/lite/micro/benchmarks/keyword_benchmark.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/benchmarks/keyword_scrambled_model_data.h" #include "tensorflow/lite/micro/benchmarks/micro_benchmark.h" +#include "tensorflow/lite/micro/kernels/fully_connected.h" #include "tensorflow/lite/micro/micro_error_reporter.h" #include "tensorflow/lite/micro/micro_interpreter.h" #include "tensorflow/lite/micro/micro_mutable_op_resolver.h" @@ -31,23 +32,48 @@ limitations under the License. namespace { -// Create an area of memory to use for input, output, and intermediate arrays. -// Align arena to 16 bytes to avoid alignment warnings on certain platforms. -constexpr int tensor_arena_size = 21 * 1024; -alignas(16) uint8_t tensor_arena[tensor_arena_size]; -// A random number generator seed to generate input values. +using KeywordBenchmarkRunner = MicroBenchmarkRunner; +using KeywordOpResolver = tflite::MicroMutableOpResolver<6>; + constexpr int kRandomSeed = 42; -static MicroBenchmarkRunner benchmark_runner( - g_keyword_scrambled_model_data, tensor_arena, tensor_arena_size); +// Create an area of memory to use for input, output, and intermediate arrays. +// Align arena to 16 bytes to avoid alignment warnings on certain platforms. +constexpr int kTensorArenaSize = 21 * 1024; +alignas(16) uint8_t tensor_arena[kTensorArenaSize]; + +uint8_t benchmark_runner_buffer[sizeof(KeywordBenchmarkRunner)]; +uint8_t op_resolver_buffer[sizeof(KeywordOpResolver)]; +KeywordBenchmarkRunner* benchmark_runner = nullptr; + +// Initialize benchmark runner instance explicitly to avoid global init order +// issues on Sparkfun. Use new since static variables within a method +// are automatically surrounded by locking, which breaks bluepill and stm32f4. +void CreateBenchmarkRunner() { + // We allocate the KeywordOpResolver from a global buffer because the object's + // lifetime must exceed that of the KeywordBenchmarkRunner object. + KeywordOpResolver* op_resolver = new (op_resolver_buffer) KeywordOpResolver(); + op_resolver->AddDequantize(); + op_resolver->AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT8()); + op_resolver->AddQuantize(); + op_resolver->AddSoftmax(); + op_resolver->AddSvdf(); + + benchmark_runner = new (benchmark_runner_buffer) + KeywordBenchmarkRunner(g_keyword_scrambled_model_data, op_resolver, + tensor_arena, kTensorArenaSize); +} // Initializes keyword runner and sets random inputs. -void InitializeKeywordRunner() { benchmark_runner.SetRandomInput(kRandomSeed); } +void InitializeKeywordRunner() { + CreateBenchmarkRunner(); + benchmark_runner->SetRandomInput(kRandomSeed); +} // This method assumes InitializeKeywordRunner has already been run. 
void KeywordRunNIerations(int iterations) { for (int i = 0; i < iterations; i++) { - benchmark_runner.RunSingleIteration(); + benchmark_runner->RunSingleIteration(); } } diff --git a/tensorflow/lite/micro/benchmarks/micro_benchmark.h b/tensorflow/lite/micro/benchmarks/micro_benchmark.h index 2c7390b9cd4..83b5cbbdd5c 100644 --- a/tensorflow/lite/micro/benchmarks/micro_benchmark.h +++ b/tensorflow/lite/micro/benchmarks/micro_benchmark.h @@ -18,9 +18,9 @@ limitations under the License. #include -#include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/micro_error_reporter.h" #include "tensorflow/lite/micro/micro_interpreter.h" +#include "tensorflow/lite/micro/micro_op_resolver.h" #include "tensorflow/lite/micro/micro_time.h" namespace micro_benchmark { @@ -43,31 +43,34 @@ extern tflite::ErrorReporter* reporter; return 0; \ } -#define TF_LITE_MICRO_BENCHMARK(func) \ - if (tflite::ticks_per_second() == 0) { \ - TF_LITE_REPORT_ERROR(micro_benchmark::reporter, \ - "no timer implementation found"); \ - return 0; \ - } \ - start_ticks = tflite::GetCurrentTimeTicks(); \ - func; \ - duration_ticks = tflite::GetCurrentTimeTicks() - start_ticks; \ - if (duration_ticks > INT_MAX / 1000) { \ - duration_ms = duration_ticks / (tflite::ticks_per_second() / 1000); \ - } else { \ - duration_ms = (duration_ticks * 1000) / tflite::ticks_per_second(); \ - } \ - TF_LITE_REPORT_ERROR(micro_benchmark::reporter, "%s took %d ticks (%d ms)", \ - #func, duration_ticks, duration_ms); +#define TF_LITE_MICRO_BENCHMARK(func) \ + if (tflite::ticks_per_second() == 0) { \ + TF_LITE_REPORT_ERROR(micro_benchmark::reporter, \ + "no timer implementation found"); \ + return 0; \ + } \ + start_ticks = tflite::GetCurrentTimeTicks(); \ + func; \ + duration_ticks = tflite::GetCurrentTimeTicks() - start_ticks; \ + if (duration_ticks > INT_MAX / 1000) { \ + duration_ms = duration_ticks / (tflite::ticks_per_second() / 1000); \ + } else { \ + duration_ms = (duration_ticks * 1000) / tflite::ticks_per_second(); \ + } \ + micro_benchmark::reporter->Report("%s took %d ticks (%d ms)", #func, \ + duration_ticks, duration_ms); template class MicroBenchmarkRunner { public: - MicroBenchmarkRunner(const uint8_t* model, uint8_t* tensor_arena, - int tensor_arena_size) + // The lifetimes of model, op_resolver and tensor_arena must exceed that of + // the created MicroBenchmarkRunner object. + MicroBenchmarkRunner(const uint8_t* model, + const tflite::MicroOpResolver* op_resolver, + uint8_t* tensor_arena, int tensor_arena_size) : model_(tflite::GetModel(model)), reporter_(µ_reporter_), - interpreter_(model_, resolver_, tensor_arena, tensor_arena_size, + interpreter_(model_, *op_resolver, tensor_arena, tensor_arena_size, reporter_) { interpreter_.AllocateTensors(); } @@ -109,7 +112,6 @@ class MicroBenchmarkRunner { const tflite::Model* model_; tflite::MicroErrorReporter micro_reporter_; tflite::ErrorReporter* reporter_; - tflite::AllOpsResolver resolver_; tflite::MicroInterpreter interpreter_; }; diff --git a/tensorflow/lite/micro/benchmarks/person_detection_benchmark.cc b/tensorflow/lite/micro/benchmarks/person_detection_benchmark.cc index 64759145eda..687440ac9fb 100644 --- a/tensorflow/lite/micro/benchmarks/person_detection_benchmark.cc +++ b/tensorflow/lite/micro/benchmarks/person_detection_benchmark.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/benchmarks/micro_benchmark.h" #include "tensorflow/lite/micro/examples/person_detection/model_settings.h" #include "tensorflow/lite/micro/examples/person_detection/no_person_image_data.h" @@ -21,7 +22,6 @@ limitations under the License. #include "tensorflow/lite/micro/examples/person_detection/person_image_data.h" #include "tensorflow/lite/micro/micro_error_reporter.h" #include "tensorflow/lite/micro/micro_interpreter.h" -#include "tensorflow/lite/micro/micro_mutable_op_resolver.h" #include "tensorflow/lite/micro/micro_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" @@ -34,35 +34,54 @@ limitations under the License. namespace { +using PersonDetectionOpResolver = tflite::AllOpsResolver; +using PersonDetectionBenchmarkRunner = MicroBenchmarkRunner; + +constexpr int kRandomSeed = 42; + // Create an area of memory to use for input, output, and intermediate arrays. // Align arena to 16 bytes to avoid alignment warnings on certain platforms. constexpr int kTensorArenaSize = 95 * 1024; -constexpr int kRandomSeed = 42; alignas(16) uint8_t tensor_arena[kTensorArenaSize]; -static MicroBenchmarkRunner benchmark_runner( - g_person_detect_model_data, tensor_arena, kTensorArenaSize); +uint8_t op_resolver_buffer[sizeof(PersonDetectionOpResolver)]; +uint8_t benchmark_runner_buffer[sizeof(PersonDetectionBenchmarkRunner)]; +PersonDetectionBenchmarkRunner* benchmark_runner = nullptr; + +// Initialize benchmark runner instance explicitly to avoid global init order +// issues on Sparkfun. Use new since static variables within a method +// are automatically surrounded by locking, which breaks bluepill and stm32f4. +void CreateBenchmarkRunner() { + // We allocate PersonDetectionOpResolver from a global buffer because the + // object's lifetime must exceed that of the PersonDetectionBenchmarkRunner + // object. + benchmark_runner = new (benchmark_runner_buffer) + PersonDetectionBenchmarkRunner(g_person_detect_model_data, + new (op_resolver_buffer) + PersonDetectionOpResolver(), + tensor_arena, kTensorArenaSize); +} void PersonDetectionTenIterationsWithRandomInput() { - benchmark_runner.SetRandomInput(kRandomSeed); + benchmark_runner->SetRandomInput(kRandomSeed); for (int i = 0; i < 10; i++) { - benchmark_runner.RunSingleIteration(); + benchmark_runner->RunSingleIteration(); } } void PersonDetectionTenIerationsWithPerson() { // TODO(b/152644476): Add a way to run more than a single deterministic input. - benchmark_runner.SetInput(g_person_data); + benchmark_runner->SetInput(g_person_data); for (int i = 0; i < 10; i++) { - benchmark_runner.RunSingleIteration(); + benchmark_runner->RunSingleIteration(); } } void PersonDetectionTenIerationsWithoutPerson() { // TODO(b/152644476): Add a way to run more than a single deterministic input. 
- benchmark_runner.SetInput(g_no_person_data); + benchmark_runner->SetInput(g_no_person_data); for (int i = 0; i < 10; i++) { - benchmark_runner.RunSingleIteration(); + benchmark_runner->RunSingleIteration(); } } @@ -70,6 +89,7 @@ void PersonDetectionTenIerationsWithoutPerson() { TF_LITE_MICRO_BENCHMARKS_BEGIN +TF_LITE_MICRO_BENCHMARK(CreateBenchmarkRunner()); TF_LITE_MICRO_BENCHMARK(PersonDetectionTenIterationsWithRandomInput()); TF_LITE_MICRO_BENCHMARK(PersonDetectionTenIerationsWithPerson()); TF_LITE_MICRO_BENCHMARK(PersonDetectionTenIerationsWithoutPerson()); diff --git a/tensorflow/lite/micro/benchmarks/person_detection_experimental_benchmark.cc b/tensorflow/lite/micro/benchmarks/person_detection_experimental_benchmark.cc index 8e6be2755fa..68850c90188 100644 --- a/tensorflow/lite/micro/benchmarks/person_detection_experimental_benchmark.cc +++ b/tensorflow/lite/micro/benchmarks/person_detection_experimental_benchmark.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/benchmarks/micro_benchmark.h" #include "tensorflow/lite/micro/examples/person_detection_experimental/model_settings.h" #include "tensorflow/lite/micro/examples/person_detection_experimental/no_person_image_data.h" @@ -21,7 +22,6 @@ limitations under the License. #include "tensorflow/lite/micro/examples/person_detection_experimental/person_image_data.h" #include "tensorflow/lite/micro/micro_error_reporter.h" #include "tensorflow/lite/micro/micro_interpreter.h" -#include "tensorflow/lite/micro/micro_mutable_op_resolver.h" #include "tensorflow/lite/micro/micro_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" @@ -34,30 +34,49 @@ limitations under the License. namespace { +using PersonDetectionExperimentalOpResolver = tflite::AllOpsResolver; +using PersonDetectionExperimentalBenchmarkRunner = MicroBenchmarkRunner; + // Create an area of memory to use for input, output, and intermediate arrays. // Align arena to 16 bytes to avoid alignment warnings on certain platforms. -constexpr int tensor_arena_size = 135 * 1024; -alignas(16) uint8_t tensor_arena[tensor_arena_size]; +constexpr int kTensorArenaSize = 135 * 1024; +alignas(16) uint8_t tensor_arena[kTensorArenaSize]; -static MicroBenchmarkRunner benchmark_runner(g_person_detect_model_data, - tensor_arena, - tensor_arena_size); +uint8_t op_resolver_buffer[sizeof(PersonDetectionExperimentalOpResolver)]; +uint8_t + benchmark_runner_buffer[sizeof(PersonDetectionExperimentalBenchmarkRunner)]; +PersonDetectionExperimentalBenchmarkRunner* benchmark_runner = nullptr; + +// Initialize benchmark runner instance explicitly to avoid global init order +// issues on Sparkfun. Use new since static variables within a method +// are automatically surrounded by locking, which breaks bluepill and stm32f4. +void CreateBenchmarkRunner() { + // We allocate PersonDetectionExperimentalOpResolver from a global buffer + // because the object's lifetime must exceed that of the + // PersonDetectionBenchmarkRunner object. 
+ benchmark_runner = + new (benchmark_runner_buffer) PersonDetectionExperimentalBenchmarkRunner( + g_person_detect_model_data, + new (op_resolver_buffer) PersonDetectionExperimentalOpResolver(), + tensor_arena, kTensorArenaSize); +} void InitializeBenchmarkRunner() { - benchmark_runner.SetInput(reinterpret_cast(g_person_data)); + CreateBenchmarkRunner(); + benchmark_runner->SetInput(reinterpret_cast(g_person_data)); } void PersonDetectionTenIerationsWithPerson() { - benchmark_runner.SetInput(reinterpret_cast(g_person_data)); + benchmark_runner->SetInput(reinterpret_cast(g_person_data)); for (int i = 0; i < 10; i++) { - benchmark_runner.RunSingleIteration(); + benchmark_runner->RunSingleIteration(); } } void PersonDetectionTenIerationsWithoutPerson() { - benchmark_runner.SetInput(reinterpret_cast(g_no_person_data)); + benchmark_runner->SetInput(reinterpret_cast(g_no_person_data)); for (int i = 0; i < 10; i++) { - benchmark_runner.RunSingleIteration(); + benchmark_runner->RunSingleIteration(); } } @@ -66,7 +85,7 @@ void PersonDetectionTenIerationsWithoutPerson() { TF_LITE_MICRO_BENCHMARKS_BEGIN TF_LITE_MICRO_BENCHMARK(InitializeBenchmarkRunner()); -TF_LITE_MICRO_BENCHMARK(benchmark_runner.RunSingleIteration()); +TF_LITE_MICRO_BENCHMARK(benchmark_runner->RunSingleIteration()); TF_LITE_MICRO_BENCHMARK(PersonDetectionTenIerationsWithPerson()); TF_LITE_MICRO_BENCHMARK(PersonDetectionTenIerationsWithoutPerson()); diff --git a/tensorflow/lite/micro/cortex_m_gcc_generic/README.md b/tensorflow/lite/micro/cortex_m_gcc_generic/README.md new file mode 100644 index 00000000000..4d5f85b24ea --- /dev/null +++ b/tensorflow/lite/micro/cortex_m_gcc_generic/README.md @@ -0,0 +1,12 @@ +# Generic Cortex-Mx customizations + +The customization requires a definition where the debug log goes to. The purpose +of the generic Cortex-Mx target is to generate a TFLM library file for use in +application projects outside of this repo. As the chip HAL and the board +specific layer are only defined in the application project, the TFLM library +cannot write the debug log anywhere. Instead, we allow the application layer to +register a callback function for writing the TFLM kernel debug log. + +# Usage + +See debug_log_callback.h diff --git a/tensorflow/lite/micro/cortex_m_gcc_generic/debug_log.cc b/tensorflow/lite/micro/cortex_m_gcc_generic/debug_log.cc new file mode 100644 index 00000000000..fce512e199b --- /dev/null +++ b/tensorflow/lite/micro/cortex_m_gcc_generic/debug_log.cc @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Implementation for the DebugLog() function that prints to the debug logger on +// an generic cortex-m device. 
+ +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#include "tensorflow/lite/micro/debug_log.h" + +#include "tensorflow/lite/micro/cortex_m_gcc_generic/debug_log_callback.h" + +static DebugLogCallback debug_log_callback = nullptr; + +void RegisterDebugLogCallback(void (*cb)(const char* s)) { + debug_log_callback = cb; +} + +void DebugLog(const char* s) { +#ifndef TF_LITE_STRIP_ERROR_STRINGS + if (debug_log_callback != nullptr) { + debug_log_callback(s); + } +#endif +} + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/tensorflow/lite/micro/cortex_m_gcc_generic/debug_log_callback.h b/tensorflow/lite/micro/cortex_m_gcc_generic/debug_log_callback.h new file mode 100644 index 00000000000..d462c8db368 --- /dev/null +++ b/tensorflow/lite/micro/cortex_m_gcc_generic/debug_log_callback.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_MICRO_CORTEX_M_GCC_GENERIC_DEBUG_LOG_CALLBACK_H_ +#define TENSORFLOW_LITE_MICRO_CORTEX_M_GCC_GENERIC_DEBUG_LOG_CALLBACK_H_ + +// The application layer must implement and register a callback before calling +// the network in a way similar to +// +// void debug_log_printf(const char* s) +// { +// printf(s); +// } +// +// int main(void) +// { +// // Register callback for printing debug log +// RegisterDebugLogCallback(debug_log_printf); +// +// // now call the network +// TfLiteStatus invoke_status = interpreter->Invoke(); +// } + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef void (*DebugLogCallback)(const char* s); + +// Registers and application-specific callback for debug logging. It must be +// called before the first call to DebugLog(). +void RegisterDebugLogCallback(DebugLogCallback callback); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_MICRO_CORTEX_M_GCC_GENERIC_DEBUG_LOG_CALLBACK_H_ diff --git a/tensorflow/lite/micro/debug_log.h b/tensorflow/lite/micro/debug_log.h index 1004ab9f5db..c2840d0f4b5 100644 --- a/tensorflow/lite/micro/debug_log.h +++ b/tensorflow/lite/micro/debug_log.h @@ -15,9 +15,17 @@ limitations under the License. #ifndef TENSORFLOW_LITE_MICRO_DEBUG_LOG_H_ #define TENSORFLOW_LITE_MICRO_DEBUG_LOG_H_ +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + // This function should be implemented by each target platform, and provide a // way for strings to be output to some text stream. For more information, see // tensorflow/lite/micro/debug_log.cc. 
-extern "C" void DebugLog(const char* s); +void DebugLog(const char* s); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // TENSORFLOW_LITE_MICRO_DEBUG_LOG_H_ diff --git a/tensorflow/lite/micro/docs/memory_management.md b/tensorflow/lite/micro/docs/memory_management.md new file mode 100644 index 00000000000..a936cb6d7c3 --- /dev/null +++ b/tensorflow/lite/micro/docs/memory_management.md @@ -0,0 +1,173 @@ + + + + +* [Memory Management in TensorFlow Lite Micro](#memory-management-in-tensorflow-lite-micro) + * [Tensor Arena](#tensor-arena) + * [Head Section](#head-section) + * [Temporary Section](#temporary-section) + * [Tail Section](#tail-section) + * [Recording Memory APIs](#recording-memory-apis) + * [Allocation Section Details](#allocation-section-details) + + + + + +# Memory Management in TensorFlow Lite Micro + +This document outlines how memory is managed internally by TensorFlow Lite Micro +(TFLM) today. It outlines the "online" allocation strategy used by the default +TFLM APIs for loading a model into a shared tensor arena. + +## Tensor Arena + +The main "working" space for TFLM allocations is inside a single `char` or +`int8_t` buffer. This buffer can be managed by passing it directly into a +`tflite::MicroInterpreter` constructor or through a `tflite::MicroAllocator` +instance that can be passed into a `tflite::MicroInterpreter` constructor. +Internally, the `tflite::MicroAllocator` classifies allocations into 3 different +sections: + +* **Head** - non-persistent allocations. +* **Temporary** - short term "scoped" allocations. +* **Tail** - persistent allocations. + +The illustration below represents typical allocations in TFLM: + +## ``` + +| | | | | HEAD |<-- TEMPORARY -->| TAIL | + +## | | | | + +* Lowest Address Highest Address * ``` + +### Head Section + +This non-persistent section typically holds shared Tensor buffers. This section +does not allocate small iterative chunks, it can only be set by a specific +length for the entire section. + +This allocation length of this section is managed by the +`tflite::GreedyMemoryPlanner`. That memory planner looks at the entire graph of +a model and tries to reuse as many buffers as possible to create the smallest +length for the head. The Tensor buffers for this section can be accessed via a +`TfLiteEvalTensor` or `TfLiteTensor` instance on the `tflite::MicroInterpreter`. + +### Temporary Section + +This section is used to allocate "scoped" or short-term, non-guaranteed buffers. +Allocations from this section start from the current end address of the head +section and grow towards the tail section. An allocation chain can be reset (and +must be reset before adjusting the head) and moves the current allocation start +address back to the end of the head section. + +TFLM currently uses these allocations for a scope allocation of large C structs +or scratch memory that is expected to be valid for at least the lifetime of a +method call. This section. + +### Tail Section + +This section holds all persistent allocations used by TFLM. This section +contains many random sized allocations and grows towards the end of the head +section. Allocations in this section come from a variety of areas inside of +TFLM. TFLM provides a [recording API](#Recording-Memory-APIs) to assist with +auditing the contents of this section. + +## Recording Memory APIs + +TFLM provides simple APIs for auditing memory usage in the shared tensor arena. 
+These APIs are opt-in and require some additional memory overhead and a working +debug logging implementation +[(reference implementation)](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/debug_log.cc). + +A typical bare-bones TFLM interpreter setup looks like this: + +```c++ +// Buffer for the tensor arena: +size_t tensor_arena_size = 2048; +uint8_t tensor_arena[tensor_arena_size]; + +// Interpreter using the shared tensor arena above: +tflite::MicroInterpreter interpreter( + tflite::GetModel(my_model_data), ops_resolver, + tensor_arena, tensor_arena_size, error_reporter); + +// Invoke one time which will allocate internals: +if (interpreter.Invoke() != kTfLiteOk) { + TF_LITE_REPORT_ERROR(error_reporter, "Exception during invoke()!"); +} +``` + +The recording API can be used simply by including the `RecordingMicroInterpreter` +class (`recording_micro_interpreter.h`) and replacing `tflite::MicroInterpreter` +with `tflite::RecordingMicroInterpreter`. The same call to `invoke()` is +performed, but another call is made to `PrintAllocations()`, which will output +detailed allocation logging: + +```c++ +// Add an include to the recording API: +#include "recording_micro_interpreter.h" + +// Simply change the class name from 'MicroInterpreter' to 'RecordingMicroInterpreter': +tflite::RecordingMicroInterpreter interpreter( + tflite::GetModel(my_model_data), ops_resolver, + tensor_arena, tensor_arena_size, error_reporter); + +// Invoke one time which will allocate internals: +if (interpreter.Invoke() != kTfLiteOk) { + TF_LITE_REPORT_ERROR(error_reporter, "Exception during invoke()!"); +} + +// Print out detailed allocation information: +interpreter.PrintAllocations(); +``` + +The output of this call will look similar to this (output from the +[memory_arena_threshold_test](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/memory_arena_threshold_test.cc#L205)):
+```sh
+[RecordingMicroAllocator] Arena allocation total 9568 bytes
+[RecordingMicroAllocator] Arena allocation head 7744 bytes
+[RecordingMicroAllocator] Arena allocation tail 1824 bytes
+[RecordingMicroAllocator] 'TfLiteEvalTensor data' used 360 bytes with alignment overhead (requested 360 bytes for 15 allocations)
+[RecordingMicroAllocator] 'Persistent TfLiteTensor data' used 0 bytes with alignment overhead (requested 0 bytes for 0 tensors)
+[RecordingMicroAllocator] 'Persistent TfLiteTensor quantization data' used 0 bytes with alignment overhead (requested 0 bytes for 0 allocations)
+[RecordingMicroAllocator] 'TfLiteTensor variable buffer data' used 0 bytes with alignment overhead (requested 0 bytes for 0 allocations)
+[RecordingMicroAllocator] 'NodeAndRegistration struct' used 392 bytes with alignment overhead (requested 392 bytes for 7 NodeAndRegistration structs)
+[RecordingMicroAllocator] 'Operator runtime data' used 136 bytes with alignment overhead (requested 136 bytes for 5 OpData structs)
+```
+ +### Allocation Section Details + +More information about each recorded allocation section: + +* 'TfLiteEvalTensor data' + * C struct that holds the data type, dimension, and a pointer to the + buffer representing the Tensor. +* 'Persistent TfLiteTensor data' + * C struct that holds more information than a `TfLiteEvalTensor` struct in + the graph. + * Allocations in this bucket will only show up when accessing tensors from + the accessors on `tflite::MicroInterpreter`.
+* 'Persistent TfLiteTensor quantization data' + * Length of persistent quantization data assigned to persistent + `TfLiteTensor` structs. + * Allocations in this bucket will only show up when accessing tensors from + the accessors on `tflite::MicroInterpreter`. +* 'TfLiteTensor variable buffer data' + * Length of buffer data from a variable tensor (retains data throughout + calls to `invoke()`). +* 'NodeAndRegistration struct' + * C struct that holds a `TfLiteRegistration` and `TfLiteNode` struct + instance. + * Each operator in a model will contain one `NodeAndRegistration` struct. +* 'Operator runtime data' + * Persistent allocations of data cached by TFLM kernels (e.g. quantization + params, multipliers, etc). diff --git a/tensorflow/lite/micro/docs/new_platform_support.md b/tensorflow/lite/micro/docs/new_platform_support.md new file mode 100644 index 00000000000..0e1287a87cb --- /dev/null +++ b/tensorflow/lite/micro/docs/new_platform_support.md @@ -0,0 +1,306 @@ + + + + +* [Porting to a new platform](#porting-to-a-new-platform) + * [Requirements](#requirements) + * [Getting started](#getting-started) + * [Troubleshooting](#troubleshooting) + * [Optimizing for your platform](#optimizing-for-your-platform) + * [Code module organization](#code-module-organization) + * [Implementing more optimizations](#implementing-more-optimizations) + + + + + +***Please note that we are currently pausing accepting new platforms***. Please +see our [contributions guide](../CONTRIBUTING.md) for more details and context. + +Parts of the documentation below will likely change as we start accepting new +platform support again. + +# Porting to a new platform + +The remainder of this document provides guidance on porting TensorFlow Lite for +Microcontrollers to new platforms. You should read the +[developer documentation](https://www.tensorflow.org/lite/microcontrollers) +first. + +## Requirements + +Since the core neural network operations are pure arithmetic, and don't require +any I/O or other system-specific functionality, the code doesn't have to have +many dependencies. We've tried to enforce this, so that it's as easy as possible +to get TensorFlow Lite Micro running even on 'bare metal' systems without an OS. +Here are the core requirements that a platform needs to run the framework: + +- C/C++ compiler capable of C++11 compatibility. This is probably the most + restrictive of the requirements, since C++11 is not as widely adopted in the + embedded world as it is elsewhere. We made the decision to require it since + one of the main goals of TFL Micro is to share as much code as possible with + the wider TensorFlow codebase, and since that relies on C++11 features, we + need compatibility to achieve it. We only use a small subset of C++ though, + so don't worry about having to deal with template metaprogramming or + similar challenges! + +- Debug logging. The core network operations don't need any I/O functions, but + to be able to run tests and tell if they've worked as expected, the + framework needs some way to write out a string to some kind of debug + console. This will vary from system to system, for example on Linux it could + just be `fprintf(stderr, debug_string)` whereas an embedded device might + write the string out to a specified UART. As long as there's some mechanism + for outputting debug strings, you should be able to use TFL Micro on that + platform. + +- Math library. 
The C standard `libm.a` library is needed to handle some of + the mathematical operations used to calculate neural network results. + +- Global variable initialization. We do use a pattern of relying on global + variables being set before `main()` is run in some places, so you'll need to + make sure your compiler toolchain supports this. + +And that's it! You may be wondering about some other common requirements that +are needed by a lot of non-embedded software, so here's a brief list of things +that aren't necessary to get started with TFL Micro on a new platform: + +- Operating system. Since the only platform-specific function we need is + `DebugLog()`, there's no requirement for any kind of Posix or similar + functionality around files, processes, or threads. + +- C or C++ standard libraries. The framework tries to avoid relying on any + standard library functions that require linker-time support. This includes + things like string functions, but still allows us to use headers like + `stdtypes.h` which typically just define constants and typedefs. + Unfortunately this distinction isn't officially defined by any standard, so + it's possible that different toolchains may decide to require linked code + even for the subset we use, but in practice we've found it's usually a + pretty obvious decision and stable over platforms and toolchains. + +- Dynamic memory allocation. All the TFL Micro code avoids dynamic memory + allocation, instead relying on local variables on the stack in most cases, + or global variables for a few situations. These are all fixed-size, which + can mean some compile-time configuration to ensure there's enough space for + particular networks, but does avoid any need for a heap and the + implementation of `malloc\new` on a platform. + +- Floating point. Eight-bit integer arithmetic is enough for inference on many + networks, so if a model sticks to these kind of quantized operations, no + floating point instructions should be required or executed by the framework. + +## Getting started + +We recommend that you start trying to compile and run one of the simplest tests +in the framework as your first step. The full TensorFlow codebase can seem +overwhelming to work with at first, so instead you can begin with a collection +of self-contained project folders that only include the source files needed for +a particular test or executable. You can find a set of pre-generated projects +[here](https://drive.google.com/open?id=1cawEQAkqquK_SO4crReDYqf_v7yAwOY8). + +As mentioned above, the one function you will need to implement for a completely +new platform is debug logging. If your device is just a variation on an existing +platform you may be able to reuse code that's already been written. To +understand what's available, begin with the default reference implementation at +[tensorflow/lite/micro/debug_log.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/debug_log.cc), +which uses fprintf and stderr. If your platform has this level of support for +the C standard library in its toolchain, then you can just reuse this. +Otherwise, you'll need to do some research into how your platform and device can +communicate logging statements to the outside world. As another example, take a +look at +[the Mbed version of `DebugLog()`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/mbed/debug_log.cc), +which creates a UART object and uses it to output strings to the host's console +if it's connected. 
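In practice the port itself can be only a few lines of code. The sketch below is illustrative rather than canonical: `uart_write_bytes()` is a made-up name standing in for whatever transmit routine your HAL or SDK actually provides.

```c++
// debug_log.cc for a hypothetical bare-metal target. Only DebugLog() needs to
// be implemented; uart_write_bytes() is a placeholder for the platform's own
// UART transmit function.
#include "tensorflow/lite/micro/debug_log.h"

#include <string.h>

extern "C" void uart_write_bytes(const char* data, size_t length);

extern "C" void DebugLog(const char* s) {
  // Forward the debug string to the board's UART unchanged.
  uart_write_bytes(s, strlen(s));
}
```

Once this builds and links, everything in the framework that reports errors or test results will flow through it.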
+ +Begin by navigating to the micro_error_reporter_test folder in the pre-generated +projects you downloaded. Inside here, you'll see a set of folders containing all +the source code you need. If you look through them, you should find a total of +around 60 C or C++ files that, compiled together, will create the test executable. +There's an example makefile in the directory that lists all of the source files +and include paths for the headers. If you're building on a Linux or macOS host +system, you may just be able to reuse that same makefile to cross-compile for +your system, as long as you swap out the `CC` and `CXX` variables from their +defaults to point to your cross compiler instead (for example +`arm-none-eabi-gcc` or `riscv64-unknown-elf-gcc`). Otherwise, set up a project +in the build system you are using. It should hopefully be fairly +straightforward, since all of the source files in the folder need to be +compiled, so on many IDEs you can just drag the whole lot in. Then you need to +make sure that C++11 compatibility is turned on, and that the right include +paths (as mentioned in the makefile) have been added. + +You'll see the default `DebugLog()` implementation in +`tensorflow/lite/micro/debug_log.cc` inside the micro_error_reporter_test +folder. Modify that file to add the right implementation for your platform, and +then you should be able to build the set of files into an executable. Transfer +that executable to your target device (for example by flashing it), and then try +running it. You should see output that looks something like this: + +``` +Number: 42 +Badly-formed format string +Another badly-formed format string +~~~ALL TESTS PASSED~~~ +``` + +If not, you'll need to debug what went wrong, but hopefully with this small +starting project it should be manageable. + +## Troubleshooting + +When we've been porting to new platforms, it's often been hard to figure out +some of the fundamentals like linker settings and other toolchain setup flags. +If you are having trouble, see if you can find a simple example program for your +platform, like one that just blinks an LED. If you're able to build and run that +successfully, then start to swap in parts of the TF Lite Micro codebase to that +working project, taking it a step at a time and ensuring it's still working +after every change. For example, a first step might be to paste in your +`DebugLog()` implementation and call `DebugLog("Hello World!")` from the main +function. + +Another common problem on embedded platforms is the stack size being too small. +Mbed defaults to 4KB for the main thread's stack, which is too small for most +models since TensorFlow Lite allocates buffers and other data structures that +require more memory. The exact size will depend on which model you're running, +but try increasing it if you are running into strange corruption issues that +might be related to stack overwriting. + +## Optimizing for your platform + +The default reference implementations in TensorFlow Lite Micro are written to be +portable and easy to understand, not fast, so you'll want to replace +performance-critical parts of the code with versions specifically tailored to your +architecture. The framework has been designed with this in mind, and we hope the +combination of small modules and many tests makes it as straightforward as +possible to swap in your own code a piece at a time, ensuring you have a working +version at every step.
To write specialized implementations for a platform, it's +useful to understand how optional components are handled inside the build +system. + +## Code module organization + +We have adopted a system of small modules with platform-specific implementations +to help with portability. Every module is just a standard `.h` header file +containing the interface (either functions or a class), with an accompanying +reference implementation in a `.cc` file with the same name. The source file +implements all of the code that's declared in the header. If you have a +specialized implementation, you can create a folder in the same directory as the +header and reference source, name it after your platform, and put your +implementation in a `.cc` file inside that folder. We've already seen one +example of this, where the Mbed and Bluepill versions of `DebugLog()` are inside +[mbed](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/mbed) +and +[bluepill](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/bluepill) +folders, children of the +[same directory](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro) +where the stdio-based +[`debug_log.cc`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/debug_log.cc) +reference implementation is found. + +The advantage of this approach is that we can automatically pick specialized +implementations based on the current build target, without having to manually +edit build files for every new platform. It allows incremental optimizations +from an always-working foundation, without cluttering the reference +implementations with a lot of variants. + +To see why we're doing this, it's worth looking at the alternatives. TensorFlow +Lite has traditionally used preprocessor macros to separate out some +platform-specific code within particular files, for example: + +``` +#ifndef USE_NEON +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#define USE_NEON +#include <arm_neon.h> +#endif +``` + +There’s also a tradition in gemmlowp of using file suffixes to indicate +platform-specific versions of particular headers, with `kernel_neon.h` being +included by `kernel.h` if `USE_NEON` is defined. As a third variation, kernels are +separated out using a directory structure, with +tensorflow/lite/kernels/internal/reference containing portable implementations, +and tensorflow/lite/kernels/internal/optimized holding versions optimized for +NEON on Arm platforms. + +These approaches are hard to extend to multiple platforms. Using macros means +that platform-specific code is scattered throughout files in a hard-to-find way, +and can make following the control flow difficult since you need to understand +the macro state to trace it. For example, I temporarily introduced a bug that +disabled NEON optimizations for some kernels when I removed +tensorflow/lite/kernels/internal/common.h from their includes, without realizing +it was where `USE_NEON` was defined! + +It’s also tough to port to different build systems, since figuring out the right +combination of macros to use can be hard, especially since some of them are +automatically defined by the compiler, and others are only set by build scripts, +often across multiple rules.
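+
+By contrast, the folder-based layout described above keeps each variant in a
+separate file with an identical interface, so no macros are needed in portable
+code. As a hedged sketch (the `my_module` directory and the function below are
+hypothetical, not files from the tree):
+
+```
+// my_module/scale.h -- platform-agnostic interface, shared by every build.
+#ifndef MY_MODULE_SCALE_H_
+#define MY_MODULE_SCALE_H_
+#include <cstdint>
+
+// Multiplies `size` samples in place by a Q8 fixed-point gain.
+void ApplyGain(int16_t* samples, int size, int16_t gain);
+
+#endif  // MY_MODULE_SCALE_H_
+
+// my_module/scale.cc -- portable reference implementation.
+#include "my_module/scale.h"
+
+void ApplyGain(int16_t* samples, int size, int16_t gain) {
+  for (int i = 0; i < size; ++i) {
+    samples[i] = static_cast<int16_t>((samples[i] * gain) >> 8);
+  }
+}
+
+// A CMSIS-accelerated drop-in replacement with the same signature would live
+// at my_module/cmsis/scale.cc and gets picked up when the 'cmsis' tag is used.
+```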
+ +The approach we are using extends the file system approach that we use for +kernel implementations, but with some specific conventions: + +- For each module in TensorFlow Lite, there will be a parent directory that + contains tests, interface headers used by other modules, and portable + implementations of each part. +- Portable means that the code doesn’t include code from any libraries except + flatbuffers, or other TF Lite modules. You can include a limited subset of + standard C or C++ headers, but you can’t use any functions that require + linking against those libraries, including `fprintf`, etc. You can link + against functions in the standard math library, in `<math.h>`. +- Specialized implementations are held inside subfolders of the parent + directory, named after the platform or library that they depend on. So, for + example, if you had `my_module/foo.cc`, a version that used RISC-V extensions + would live in `my_module/riscv/foo.cc`. If you had a version that used the + CMSIS library, it should be in `my_module/cmsis/foo.cc`. +- These specialized implementations should completely replace the top-level + implementations. If this involves too much code duplication, the top-level + implementation should be split into smaller files, so only the + platform-specific code needs to be replaced. +- There is a convention about how build systems pick the right implementation + file. There will be an ordered list of 'tags' defining the preferred + implementations, and to generate the right list of source files, each module + will be examined in turn. If a subfolder with a tag’s name contains a .cc + file with the same base name as one in the parent folder, then it will + replace the parent folder’s version in the list of build files. If there are + multiple subfolders with matching tags and file names, then the tag that’s + latest in the ordered list will be chosen. This allows us to express “I’d + like generically-optimized fixed point if it’s available, but I’d prefer + something using the CMSIS library” using the list 'fixed_point cmsis'. These + tags are passed in with `TAGS` on the command line when you use the + main Makefile to build (for example, `TAGS="fixed_point cmsis"`). +- There is an implicit “reference” tag at the start of every list, so that + it’s possible to support directory structures like the current + tensorflow/lite/kernels/internal where portable implementations are held in a + “reference” folder that’s a sibling to the NEON-optimized folder. +- The headers for each unit in a module should remain platform-agnostic, and + be the same for all implementations. Private headers inside a sub-folder can + be used as needed, but shouldn’t be referred to by any portable code at the + top level. +- Tests should be at the parent level, with no platform-specific code. +- No platform-specific macros or `#ifdef`s should be used in any portable code. + +The implementation of these rules is handled inside the Makefile, with a +[`specialize` function](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/tools/make/helper_functions.inc#L42) +that takes a list of reference source file paths as an input, and returns the +equivalent list with specialized versions of those files swapped in if they +exist. + +## Implementing more optimizations + +Clearly, getting debug logging support is only the beginning of the work you'll +need to do on a particular platform. It's very likely that you'll want to +optimize the core deep learning operations that take up the most time when +running models you care about.
The good news is that the process for providing +optimized implementations is the same as the one you just went through to +provide your own logging. You'll need to identify parts of the code that are +bottlenecks, and then add specialized implementations in their own folders. +These don't need to be platform-specific; they can also be broken out by which +library they rely on, for example. [Here's where we do that for the CMSIS +implementation of integer fast Fourier +transforms](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/micro_speech/simple_features/simple_features_generator.cc). +This more complex case shows that you can also add helper source files alongside +the main implementation, as long as you +[mention them in the platform-specific makefile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/examples/micro_speech/CMSIS/Makefile.inc). +You can also do things like update the list of libraries that need to be linked +in, or add include paths to required headers. diff --git a/tensorflow/lite/micro/examples/hello_world/output_handler_test.cc b/tensorflow/lite/micro/examples/hello_world/output_handler_test.cc index 206113d1427..db529832ef9 100644 --- a/tensorflow/lite/micro/examples/hello_world/output_handler_test.cc +++ b/tensorflow/lite/micro/examples/hello_world/output_handler_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/lite/micro/examples/hello_world/output_handler.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" TF_LITE_MICRO_TESTS_BEGIN diff --git a/tensorflow/lite/micro/examples/magic_wand/output_handler_test.cc b/tensorflow/lite/micro/examples/magic_wand/output_handler_test.cc index 133d62427a1..2c34dfc37a7 100644 --- a/tensorflow/lite/micro/examples/magic_wand/output_handler_test.cc +++ b/tensorflow/lite/micro/examples/magic_wand/output_handler_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/lite/micro/examples/magic_wand/output_handler.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" TF_LITE_MICRO_TESTS_BEGIN diff --git a/tensorflow/lite/micro/examples/micro_speech/BUILD b/tensorflow/lite/micro/examples/micro_speech/BUILD index 268ce77dee0..d5aa451b351 100644 --- a/tensorflow/lite/micro/examples/micro_speech/BUILD +++ b/tensorflow/lite/micro/examples/micro_speech/BUILD @@ -316,6 +316,7 @@ tflite_micro_cc_test( "//tensorflow/lite/c:common", "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) diff --git a/tensorflow/lite/micro/examples/micro_speech/command_responder_test.cc b/tensorflow/lite/micro/examples/micro_speech/command_responder_test.cc index 818b0840d08..99ff2d1d9ec 100644 --- a/tensorflow/lite/micro/examples/micro_speech/command_responder_test.cc +++ b/tensorflow/lite/micro/examples/micro_speech/command_responder_test.cc @@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/lite/micro/examples/micro_speech/command_responder.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" TF_LITE_MICRO_TESTS_BEGIN diff --git a/tensorflow/lite/micro/examples/micro_speech/recognize_commands_test.cc b/tensorflow/lite/micro/examples/micro_speech/recognize_commands_test.cc index 089da9173c7..805ff40820f 100644 --- a/tensorflow/lite/micro/examples/micro_speech/recognize_commands_test.cc +++ b/tensorflow/lite/micro/examples/micro_speech/recognize_commands_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/lite/micro/examples/micro_speech/recognize_commands.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" TF_LITE_MICRO_TESTS_BEGIN diff --git a/tensorflow/lite/micro/examples/network_tester/network_tester_test.cc b/tensorflow/lite/micro/examples/network_tester/network_tester_test.cc index 563500f2115..1d945420b9a 100644 --- a/tensorflow/lite/micro/examples/network_tester/network_tester_test.cc +++ b/tensorflow/lite/micro/examples/network_tester/network_tester_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/lite/micro/micro_interpreter.h" #include "tensorflow/lite/micro/micro_utils.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/micro/examples/person_detection/detection_responder_test.cc b/tensorflow/lite/micro/examples/person_detection/detection_responder_test.cc index 1714079f39a..7292a4a1ade 100644 --- a/tensorflow/lite/micro/examples/person_detection/detection_responder_test.cc +++ b/tensorflow/lite/micro/examples/person_detection/detection_responder_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/lite/micro/examples/person_detection/detection_responder.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" TF_LITE_MICRO_TESTS_BEGIN diff --git a/tensorflow/lite/micro/examples/person_detection_experimental/detection_responder_test.cc b/tensorflow/lite/micro/examples/person_detection_experimental/detection_responder_test.cc index 3d86baa9d59..2cf6f68d7f4 100644 --- a/tensorflow/lite/micro/examples/person_detection_experimental/detection_responder_test.cc +++ b/tensorflow/lite/micro/examples/person_detection_experimental/detection_responder_test.cc @@ -16,7 +16,6 @@ limitations under the License. 
#include "tensorflow/lite/micro/examples/person_detection_experimental/detection_responder.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" TF_LITE_MICRO_TESTS_BEGIN diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 18a19ab599c..6eaf3549b32 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -19,11 +19,82 @@ config_setting( define_values = {"tflm_build": "xtensa_hifimini_staging"}, ) +package_group( + name = "micro", + packages = ["//tensorflow/lite/micro/..."], +) + package_group( name = "micro_top_level", packages = ["//tensorflow/lite/micro"], ) +cc_library( + name = "fixedpoint_utils", + hdrs = select({ + "//conditions:default": [ + ], + ":xtensa_hifimini": [ + "xtensa_hifimini/fixedpoint_utils.h", + ], + ":xtensa_hifimini_staging": [ + "xtensa_hifimini/fixedpoint_utils.h", + ], + }), + copts = micro_copts(), + deps = select({ + "//conditions:default": [], + ":xtensa_hifimini": [ + #"//third_party/xtensa/cstub64s:hifi_mini", + "//tensorflow/lite/kernels/internal:compatibility", + ], + }), +) + +cc_library( + name = "fully_connected", + srcs = select({ + "//conditions:default": [ + "fully_connected.cc", + ], + ":xtensa_hifimini": [ + "xtensa_hifimini/fully_connected.cc", + ], + ":xtensa_hifimini_staging": [ + "xtensa_hifimini_staging/fully_connected.cc", + ], + }), + hdrs = ["fully_connected.h"], + copts = micro_copts(), + visibility = [ + # Kernel variants need to be visible to the examples and benchmarks. + ":micro", + ], + deps = [ + ":activation_utils", + ":fixedpoint_utils", + ":kernel_util", + ":micro_utils", + "//tensorflow/lite/c:common", + "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/kernels:op_macros", + "//tensorflow/lite/kernels:padding", + "//tensorflow/lite/kernels/internal:common", + "//tensorflow/lite/kernels/internal:compatibility", + "//tensorflow/lite/kernels/internal:quantization_util", + "//tensorflow/lite/kernels/internal:reference_base", + "//tensorflow/lite/kernels/internal:tensor", + "//tensorflow/lite/kernels/internal:types", + "//tensorflow/lite/micro:memory_helpers", + "//tensorflow/lite/micro:micro_utils", + ] + select({ + "//conditions:default": [], + ":xtensa_hifimini": [ + #"//third_party/xtensa/cstub64s:hifi_mini", + ], + }), +) + cc_library( name = "micro_ops", srcs = [ @@ -53,6 +124,7 @@ cc_library( "reshape.cc", "resize_nearest_neighbor.cc", "round.cc", + "shape.cc", "split.cc", "split_v.cc", "strided_slice.cc", @@ -63,7 +135,6 @@ cc_library( "//conditions:default": [ "conv.cc", "depthwise_conv.cc", - "fully_connected.cc", "quantize.cc", "softmax.cc", "svdf.cc", @@ -71,8 +142,6 @@ cc_library( ":xtensa_hifimini": [ "xtensa_hifimini/conv.cc", "xtensa_hifimini/depthwise_conv.cc", - "xtensa_hifimini/fixedpoint_utils.h", - "xtensa_hifimini/fully_connected.cc", "xtensa_hifimini/quantize.cc", "xtensa_hifimini/softmax.cc", "xtensa_hifimini/svdf.cc", @@ -85,8 +154,6 @@ cc_library( # behavior that we get with the Makefiles. 
"conv.cc", "depthwise_conv.cc", - "xtensa_hifimini/fixedpoint_utils.h", - "xtensa_hifimini_staging/fully_connected.cc", "xtensa_hifimini_staging/quantize.cc", "xtensa_hifimini_staging/softmax.cc", "xtensa_hifimini_staging/svdf.cc", @@ -102,6 +169,7 @@ cc_library( deps = [ ":activation_utils", ":kernel_util", + ":fixedpoint_utils", ":micro_utils", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:kernel_util", @@ -135,6 +203,7 @@ tflite_micro_cc_test( "//tensorflow/lite/c:common", "//tensorflow/lite/micro:debug_log", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -148,6 +217,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -162,6 +232,7 @@ tflite_micro_cc_test( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels/internal:tensor", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -172,13 +243,13 @@ tflite_micro_cc_test( "fully_connected_test.cc", ], deps = [ + ":fully_connected", ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro:micro_utils", "//tensorflow/lite/micro:op_resolvers", "//tensorflow/lite/micro:test_helpers", - "//tensorflow/lite/micro/kernels:micro_ops", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -206,6 +277,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -234,6 +306,7 @@ tflite_micro_cc_test( "//tensorflow/lite/c:common", "//tensorflow/lite/micro:micro_utils", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -247,6 +320,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -260,6 +334,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -273,6 +348,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -286,6 +362,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -299,6 +376,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -312,6 +390,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -325,6 +404,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", 
"//tensorflow/lite/micro/testing:micro_test", ], ) @@ -338,6 +418,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -351,6 +432,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -364,6 +446,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -377,6 +460,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -390,6 +474,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -404,6 +489,7 @@ tflite_micro_cc_test( "//tensorflow/lite/c:common", "//tensorflow/lite/micro:debug_log", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -418,6 +504,7 @@ tflite_micro_cc_test( "//tensorflow/lite/c:common", "//tensorflow/lite/micro:debug_log", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -432,6 +519,22 @@ tflite_micro_cc_test( "//tensorflow/lite/c:common", "//tensorflow/lite/micro:debug_log", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", + "//tensorflow/lite/micro/testing:micro_test", + ], +) + +tflite_micro_cc_test( + name = "split_v_test", + srcs = [ + "split_v_test.cc", + ], + deps = [ + ":kernel_runner", + "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:debug_log", + "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -445,6 +548,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -557,6 +661,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -570,6 +675,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -598,6 +704,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -612,6 +719,7 @@ tflite_micro_cc_test( ":micro_ops", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -625,6 +733,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], 
) @@ -638,6 +747,7 @@ tflite_micro_cc_test( ":kernel_runner", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -650,6 +760,7 @@ tflite_micro_cc_test( "//tensorflow/lite/c:common", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -662,6 +773,7 @@ tflite_micro_cc_test( "//tensorflow/lite/c:common", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", "//tensorflow/lite/micro/testing:micro_test", ], ) diff --git a/tensorflow/lite/micro/kernels/activations_test.cc b/tensorflow/lite/micro/kernels/activations_test.cc index db23bdec475..edbe717bedb 100644 --- a/tensorflow/lite/micro/kernels/activations_test.cc +++ b/tensorflow/lite/micro/kernels/activations_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/add_test.cc b/tensorflow/lite/micro/kernels/add_test.cc index 5ea9daee621..241dda7d090 100644 --- a/tensorflow/lite/micro/kernels/add_test.cc +++ b/tensorflow/lite/micro/kernels/add_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/arc_mli/conv.cc b/tensorflow/lite/micro/kernels/arc_mli/conv.cc index 2bdbe679087..55ef2650bef 100644 --- a/tensorflow/lite/micro/kernels/arc_mli/conv.cc +++ b/tensorflow/lite/micro/kernels/arc_mli/conv.cc @@ -30,9 +30,7 @@ limitations under the License. #include "tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h" namespace tflite { -namespace ops { -namespace micro { -namespace conv { +namespace { constexpr int kInputTensor = 0; constexpr int kFilterTensor = 1; @@ -498,19 +496,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -} // namespace conv +} // namespace TfLiteRegistration Register_CONV_2D() { - return {/*init=*/conv::Init, + return {/*init=*/Init, /*free=*/nullptr, - /*prepare=*/conv::Prepare, - /*invoke=*/conv::Eval, + /*prepare=*/Prepare, + /*invoke=*/Eval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc b/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc index ae98b996987..d30a5308708 100644 --- a/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc +++ b/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc @@ -31,9 +31,6 @@ limitations under the License. 
#include "tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h" namespace tflite { -namespace ops { -namespace micro { -namespace depthwise_conv { namespace { constexpr int kInputTensor = 0; @@ -127,8 +124,6 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, return kTfLiteOk; } -} // namespace - void* Init(TfLiteContext* context, const char* buffer, size_t length) { TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); return context->AllocatePersistentBuffer(context, sizeof(OpData)); @@ -514,19 +509,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -} // namespace depthwise_conv +} // namespace TfLiteRegistration Register_DEPTHWISE_CONV_2D() { - return {/*init=*/depthwise_conv::Init, + return {/*init=*/Init, /*free=*/nullptr, - /*prepare=*/depthwise_conv::Prepare, - /*invoke=*/depthwise_conv::Eval, + /*prepare=*/Prepare, + /*invoke=*/Eval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/arg_min_max_test.cc b/tensorflow/lite/micro/kernels/arg_min_max_test.cc index dfd04bf74b9..b85d59555f8 100644 --- a/tensorflow/lite/micro/kernels/arg_min_max_test.cc +++ b/tensorflow/lite/micro/kernels/arg_min_max_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/ceil_test.cc b/tensorflow/lite/micro/kernels/ceil_test.cc index 27caa507c00..a388d285db3 100644 --- a/tensorflow/lite/micro/kernels/ceil_test.cc +++ b/tensorflow/lite/micro/kernels/ceil_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/circular_buffer_test.cc b/tensorflow/lite/micro/kernels/circular_buffer_test.cc index 770f4565670..3e321ce070a 100644 --- a/tensorflow/lite/micro/kernels/circular_buffer_test.cc +++ b/tensorflow/lite/micro/kernels/circular_buffer_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" #include "tensorflow/lite/micro/kernels/micro_ops.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc b/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc index cb55e416479..80a0a2ae748 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc @@ -28,15 +28,12 @@ limitations under the License. 
#include "tensorflow/lite/micro/kernels/kernel_util.h" namespace tflite { -namespace ops { -namespace micro { -namespace conv { +namespace { constexpr int kInputTensor = 0; constexpr int kFilterTensor = 1; constexpr int kBiasTensor = 2; constexpr int kOutputTensor = 0; -constexpr int kMaxChannels = 256; // Conv is quantized along dimension 0: // https://www.tensorflow.org/lite/performance/quantization_spec @@ -56,9 +53,8 @@ struct OpData { int output_shift; // Per channel output multiplier and shift. - // TODO(b/141139247): Allocate these dynamically when possible. - int32_t per_channel_output_multiplier[kMaxChannels]; - int32_t per_channel_output_shift[kMaxChannels]; + int32_t* per_channel_output_multiplier; + int32_t* per_channel_output_shift; // The range of the fused activation layer. For example for kNone and // uint8_t these would be 0 and 255. @@ -82,10 +78,10 @@ inline PaddingType RuntimePaddingType(TfLitePadding padding) { } TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, - TfLiteConvParams* params, int width, int height, - int filter_width, int filter_height, int out_width, - int out_height, const TfLiteType data_type, - OpData* data) { + const TfLiteConvParams* params, int width, + int height, int filter_width, int filter_height, + int out_width, int out_height, + const TfLiteType data_type, OpData* data) { bool has_bias = node->inputs->size == 3; // Check number of inputs/outputs TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2); @@ -124,13 +120,13 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { -#if defined(__ARM_FEATURE_DSP) || defined(__ARM_FEATURE_MVE) - int32_t buf_size = 0; - TFLITE_DCHECK(node->user_data != nullptr); TFLITE_DCHECK(node->builtin_data != nullptr); - auto* params = reinterpret_cast(node->builtin_data); - auto* data = reinterpret_cast(node->user_data); + + int32_t buf_size = 0; + const auto params = static_cast(node->builtin_data); + OpData* data = static_cast(node->user_data); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); const TfLiteTensor* output = GetOutput(context, node, kOutputTensor); @@ -159,6 +155,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_dims.w = output->dims->data[2]; output_dims.c = output_shape.Dims(3); + // Dynamically allocate per-channel quantization parameters. + // TODO(#42883): This allocation is done even for non-int8 cases to get around + // a bug in kernel_utils.cc which incorrectly uses per_channel_output_shift in + // non-int8 cases. Protect this section with a if (input->type == kTfLiteInt8) + // when the issue is fixed. 
+ const int num_channels = filter->dims->data[kConvQuantizedDimension]; + data->per_channel_output_multiplier = + static_cast(context->AllocatePersistentBuffer( + context, num_channels * sizeof(int32_t))); + data->per_channel_output_shift = + static_cast(context->AllocatePersistentBuffer( + context, num_channels * sizeof(int32_t))); + TF_LITE_ENSURE_STATUS(CalculateOpData( context, node, params, input_dims.w, input_dims.h, filter_dims.w, filter_dims.h, output_dims.w, output_dims.h, input->type, data)); @@ -191,7 +200,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } else { data->buffer_idx = -1; } -#endif return kTfLiteOk; } @@ -240,128 +248,122 @@ TfLiteStatus EvalQuantizedPerChannel( const OpData& data, const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias, TfLiteEvalTensor* output, TfLiteEvalTensor* im2col) { -#if defined(__ARM_FEATURE_DSP) || defined(__ARM_FEATURE_MVE) - - // Initialize cmsis-nn convolution parameters cmsis_nn_conv_params conv_params; - conv_params.input_offset = -data.input_zero_point; - conv_params.output_offset = data.output_zero_point; - conv_params.stride.h = params->stride_height; - conv_params.stride.w = params->stride_width; conv_params.dilation.h = params->dilation_height_factor; conv_params.dilation.w = params->dilation_width_factor; - conv_params.padding.h = data.padding.height; - conv_params.padding.w = data.padding.width; - conv_params.activation.min = data.output_activation_min; - conv_params.activation.max = data.output_activation_max; + // TODO(#43557) Remove checks for dilation and call to reference + // implementation when dilation is supported in the optimized implementation + // by CMSIS-NN. + if (conv_params.dilation.h == 1 && conv_params.dilation.w == 1) { + // Initialize cmsis-nn convolution parameters + conv_params.input_offset = -data.input_zero_point; + conv_params.output_offset = data.output_zero_point; + conv_params.stride.h = params->stride_height; + conv_params.stride.w = params->stride_width; + conv_params.padding.h = data.padding.height; + conv_params.padding.w = data.padding.width; + conv_params.activation.min = data.output_activation_min; + conv_params.activation.max = data.output_activation_max; - // Initialize cmsis-nn per channel quantization parameters - cmsis_nn_per_channel_quant_params quant_params; - quant_params.multiplier = - const_cast(data.per_channel_output_multiplier); - quant_params.shift = const_cast(data.per_channel_output_shift); + // Initialize cmsis-nn per channel quantization parameters + cmsis_nn_per_channel_quant_params quant_params; + quant_params.multiplier = + const_cast(data.per_channel_output_multiplier); + quant_params.shift = const_cast(data.per_channel_output_shift); - RuntimeShape filter_shape = tflite::micro::GetTensorShape(filter); - RuntimeShape input_shape = tflite::micro::GetTensorShape(input); - RuntimeShape output_shape = tflite::micro::GetTensorShape(output); - RuntimeShape bias_shape = tflite::micro::GetTensorShape(bias); + RuntimeShape filter_shape = tflite::micro::GetTensorShape(filter); + RuntimeShape input_shape = tflite::micro::GetTensorShape(input); + RuntimeShape output_shape = tflite::micro::GetTensorShape(output); + RuntimeShape bias_shape = tflite::micro::GetTensorShape(bias); - // Consistency check. 
- TFLITE_DCHECK_LE(conv_params.activation.min, conv_params.activation.max); - TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); - TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); - TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); - const int batch_size = MatchingDim(input_shape, 0, output_shape, 0); - const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3); - const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3); - if (tflite::micro::GetTensorData(bias)) { - TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); - } + // Consistency check. + TFLITE_DCHECK_LE(conv_params.activation.min, conv_params.activation.max); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batch_size = MatchingDim(input_shape, 0, output_shape, 0); + const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3); + const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3); + if (tflite::micro::GetTensorData(bias)) { + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + } - // Initialize cmsis-nn dimensions - // Input - cmsis_nn_dims input_dims; - input_dims.n = batch_size; - input_dims.h = input_shape.Dims(1); - input_dims.w = input_shape.Dims(2); - input_dims.c = input_depth; + // Initialize cmsis-nn dimensions + // Input + cmsis_nn_dims input_dims; + input_dims.n = batch_size; + input_dims.h = input_shape.Dims(1); + input_dims.w = input_shape.Dims(2); + input_dims.c = input_depth; - // Filter - cmsis_nn_dims filter_dims; - filter_dims.n = output_depth; - filter_dims.h = filter_shape.Dims(1); - filter_dims.w = filter_shape.Dims(2); - filter_dims.c = input_depth; + // Filter + cmsis_nn_dims filter_dims; + filter_dims.n = output_depth; + filter_dims.h = filter_shape.Dims(1); + filter_dims.w = filter_shape.Dims(2); + filter_dims.c = input_depth; - // Bias - cmsis_nn_dims bias_dims; - bias_dims.n = 1; - bias_dims.h = 1; - bias_dims.w = 1; - bias_dims.c = output_depth; + // Bias + cmsis_nn_dims bias_dims; + bias_dims.n = 1; + bias_dims.h = 1; + bias_dims.w = 1; + bias_dims.c = output_depth; - // Output - cmsis_nn_dims output_dims; - output_dims.n = batch_size; - output_dims.h = output_shape.Dims(1); - output_dims.w = output_shape.Dims(2); - output_dims.c = output_depth; + // Output + cmsis_nn_dims output_dims; + output_dims.n = batch_size; + output_dims.h = output_shape.Dims(1); + output_dims.w = output_shape.Dims(2); + output_dims.c = output_depth; - // Initialize cmsis-nn context - cmsis_nn_context ctx; - ctx.buf = nullptr; - ctx.size = 0; + // Initialize cmsis-nn context + cmsis_nn_context ctx; + ctx.buf = nullptr; + ctx.size = 0; - if (data.buffer_idx > -1) { - ctx.buf = context->GetScratchBuffer(context, data.buffer_idx); - // Note: ctx.size is currently not used in cmsis-nn. - // The buffer should be allocated in the Prepare function through - // arm_convolve_wrapper_s8_get_buffer_size - } + if (data.buffer_idx > -1) { + ctx.buf = context->GetScratchBuffer(context, data.buffer_idx); + // Note: ctx.size is currently not used in cmsis-nn. 
+ // The buffer should be allocated in the Prepare function through + // arm_convolve_wrapper_s8_get_buffer_size + } - // arm_convolve_wrapper_s8 dispatches the optimized kernel accordingly with - // the parameters passed - arm_status status = arm_convolve_wrapper_s8( - &ctx, &conv_params, &quant_params, &input_dims, - tflite::micro::GetTensorData(input), &filter_dims, - tflite::micro::GetTensorData(filter), &bias_dims, - tflite::micro::GetTensorData(bias), &output_dims, - tflite::micro::GetTensorData(output)); - - if (status == ARM_MATH_SUCCESS) { - return kTfLiteOk; + // arm_convolve_wrapper_s8 dispatches the optimized kernel accordingly with + // the parameters passed + TFLITE_DCHECK_EQ( + arm_convolve_wrapper_s8( + &ctx, &conv_params, &quant_params, &input_dims, + tflite::micro::GetTensorData(input), &filter_dims, + tflite::micro::GetTensorData(filter), &bias_dims, + tflite::micro::GetTensorData(bias), &output_dims, + tflite::micro::GetTensorData(output)), + ARM_MATH_SUCCESS); } else { - return kTfLiteError; + // TODO(b/154032858): Investigate removing extra copies. + ConvParams op_params; + op_params.input_offset = -data.input_zero_point; + op_params.output_offset = data.output_zero_point; + op_params.stride_height = params->stride_height; + op_params.stride_width = params->stride_width; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.padding_values.height = data.padding.height; + op_params.padding_values.width = data.padding.width; + op_params.quantized_activation_min = data.output_activation_min; + op_params.quantized_activation_max = data.output_activation_max; + + reference_integer_ops::ConvPerChannel( + op_params, data.per_channel_output_multiplier, + data.per_channel_output_shift, tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); } - -#else -#pragma message( \ - "CMSIS-NN optimization for conv not available for this target. 
Using reference kernel.") - - ConvParams op_params; - op_params.input_offset = -data.input_zero_point; - op_params.output_offset = data.output_zero_point; - op_params.stride_height = params->stride_height; - op_params.stride_width = params->stride_width; - op_params.dilation_height_factor = params->dilation_height_factor; - op_params.dilation_width_factor = params->dilation_width_factor; - op_params.padding_values.height = data.padding.height; - op_params.padding_values.width = data.padding.width; - op_params.quantized_activation_min = data.output_activation_min; - op_params.quantized_activation_max = data.output_activation_max; - - reference_integer_ops::ConvPerChannel( - op_params, data.per_channel_output_multiplier, - data.per_channel_output_shift, tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData(input), - tflite::micro::GetTensorShape(filter), - tflite::micro::GetTensorData(filter), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output)); - -#endif return kTfLiteOk; } @@ -438,19 +440,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -} // namespace conv +} // namespace TfLiteRegistration Register_CONV_2D() { - return {/*init=*/conv::Init, + return {/*init=*/Init, /*free=*/nullptr, - /*prepare=*/conv::Prepare, - /*invoke=*/conv::Eval, + /*prepare=*/Prepare, + /*invoke=*/Eval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/depthwise_conv.cc b/tensorflow/lite/micro/kernels/cmsis-nn/depthwise_conv.cc index 42ac15a0837..3a59b71c985 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/depthwise_conv.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/depthwise_conv.cc @@ -28,16 +28,12 @@ limitations under the License. #include "tensorflow/lite/micro/kernels/kernel_util.h" namespace tflite { -namespace ops { -namespace micro { -namespace depthwise_conv { namespace { constexpr int kInputTensor = 0; constexpr int kFilterTensor = 1; constexpr int kBiasTensor = 2; constexpr int kOutputTensor = 0; -constexpr int kMaxChannels = 256; // Depthwise conv is quantized along dimension 3: // https://www.tensorflow.org/lite/performance/quantization_spec @@ -57,9 +53,9 @@ struct OpData { int output_shift; // Per channel output multiplier and shift. - // TODO: Allocate dynamic buffers when b/158779832 is resolved - int32_t per_channel_output_multiplier[kMaxChannels]; - int32_t per_channel_output_shift[kMaxChannels]; + int32_t* per_channel_output_multiplier; + int32_t* per_channel_output_shift; + // The range of the fused activation layer. For example for kNone and // uint8_t these would be 0 and 255. 
int32_t output_activation_min; @@ -105,8 +101,6 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, return kTfLiteOk; } -} // namespace - void* Init(TfLiteContext* context, const char* buffer, size_t length) { TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); return context->AllocatePersistentBuffer(context, sizeof(OpData)); @@ -131,11 +125,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { int filter_height = SizeOfDimension(filter, 1); if (input->type == kTfLiteInt8) { - // Allocate memory for per-channel quantization parameters - const int num_channels = - filter->dims->data[kDepthwiseConvQuantizedDimension]; - TFLITE_DCHECK_LE(num_channels, kMaxChannels); - TF_LITE_ENSURE_EQ(context, filter->quantization.type, kTfLiteAffineQuantization); @@ -154,6 +143,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { affine_quantization->zero_point->size); } + // Allocate memory for per-channel quantization parameters + const int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension]; + + data->per_channel_output_multiplier = + reinterpret_cast(context->AllocatePersistentBuffer( + context, num_channels * sizeof(int32_t))); + data->per_channel_output_shift = + reinterpret_cast(context->AllocatePersistentBuffer( + context, num_channels * sizeof(int32_t))); + TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, width, height, filter_width, filter_height, data_type, data)); @@ -460,19 +459,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -} // namespace depthwise_conv +} // namespace TfLiteRegistration Register_DEPTHWISE_CONV_2D() { - return {/*init=*/depthwise_conv::Init, + return {/*init=*/Init, /*free=*/nullptr, - /*prepare=*/depthwise_conv::Prepare, - /*invoke=*/depthwise_conv::Eval, + /*prepare=*/Prepare, + /*invoke=*/Eval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc b/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc index d34cc3cc053..9f901d436a1 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc @@ -23,12 +23,10 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/fully_connected.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" namespace tflite { -namespace ops { -namespace micro { -namespace fully_connected { namespace { struct OpData { @@ -56,6 +54,11 @@ constexpr int kWeightsTensor = 1; constexpr int kBiasTensor = 2; constexpr int kOutputTensor = 0; +// TODO(b/169801227): This global struct is needed for the linker to drop unused +// code (for example, by using Register_FULLY_CONNECTED_INT8 instead of +// Register_FULLY_CONNECTED). 
+TfLiteRegistration fully_connected_registration; + TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteFusedActivation activation, TfLiteType data_type, const TfLiteTensor* input, @@ -82,8 +85,6 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, return status; } -} // namespace - void* Init(TfLiteContext* context, const char* buffer, size_t length) { TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); return context->AllocatePersistentBuffer(context, sizeof(OpData)); @@ -333,19 +334,59 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -} // namespace fully_connected +// Note that the current function names are not ideal at all (this EvalInt8 +// function internally calls EvalQuantizedInt8, and there is similar name +// aliasing in the Eval function too). We will be attempting to have a more +// descriptive naming convention but holding off on that for now, since the +// renaming might be coupled with reducing code duplication and some additional +// refactoring. +TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) { + const TfLiteEvalTensor* input = + tflite::micro::GetEvalInput(context, node, kInputTensor); + const TfLiteEvalTensor* filter = + tflite::micro::GetEvalInput(context, node, kWeightsTensor); + const TfLiteEvalTensor* bias = + tflite::micro::GetEvalInput(context, node, kBiasTensor); + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kOutputTensor); -TfLiteRegistration Register_FULLY_CONNECTED() { - return {/*init=*/fully_connected::Init, - /*free=*/nullptr, - /*prepare=*/fully_connected::Prepare, - /*invoke=*/fully_connected::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; + TFLITE_DCHECK(node->user_data != nullptr); + const OpData& data = *(static_cast(node->user_data)); + + // Checks in Prepare ensure input, output and filter types are all the same. 
+ if (input->type != kTfLiteInt8) { + TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; + } + + return EvalQuantizedInt8(context, node, data, input, filter, bias, output); +} + +} // namespace + +TfLiteRegistration Register_FULLY_CONNECTED() { + fully_connected_registration.init = Init; + fully_connected_registration.free = nullptr; + fully_connected_registration.prepare = Prepare; + fully_connected_registration.invoke = Eval; + fully_connected_registration.profiling_string = nullptr; + fully_connected_registration.builtin_code = 0; + fully_connected_registration.custom_name = nullptr; + fully_connected_registration.version = 0; + return fully_connected_registration; +} + +TfLiteRegistration Register_FULLY_CONNECTED_INT8() { + fully_connected_registration.init = Init; + fully_connected_registration.free = nullptr; + fully_connected_registration.prepare = Prepare; + fully_connected_registration.invoke = EvalInt8; + fully_connected_registration.profiling_string = nullptr; + fully_connected_registration.builtin_code = 0; + fully_connected_registration.custom_name = nullptr; + fully_connected_registration.version = 0; + return fully_connected_registration; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc b/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc index 194bba4f26a..60e1a9a88b0 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/softmax.cc @@ -21,9 +21,6 @@ limitations under the License. #include "tensorflow/lite/micro/kernels/kernel_util.h" namespace tflite { -namespace ops { -namespace micro { -namespace activations { namespace { TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context, @@ -68,8 +65,6 @@ TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context, return kTfLiteOk; } -} // namespace - void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) { TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); return context->AllocatePersistentBuffer(context, sizeof(SoftmaxParams)); @@ -157,19 +152,17 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { } } -} // namespace activations +} // namespace TfLiteRegistration Register_SOFTMAX() { - return {/*init=*/activations::SoftmaxInit, + return {/*init=*/SoftmaxInit, /*free=*/nullptr, - /*prepare=*/activations::SoftmaxPrepare, - /*invoke=*/activations::SoftmaxEval, + /*prepare=*/SoftmaxPrepare, + /*invoke=*/SoftmaxEval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/svdf.cc b/tensorflow/lite/micro/kernels/cmsis-nn/svdf.cc new file mode 100644 index 00000000000..16358e62e10 --- /dev/null +++ b/tensorflow/lite/micro/kernels/cmsis-nn/svdf.cc @@ -0,0 +1,476 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "cmsis/CMSIS/NN/Include/arm_nn_types.h" +#include "cmsis/CMSIS/NN/Include/arm_nnfunctions.h" +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/micro/kernels/activation_utils.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/micro_utils.h" + +namespace tflite { +namespace { + +struct OpData { + int32_t effective_scale_1_a; + int32_t effective_scale_2_a; + // b versions of each scale are kept at int since the numbers are just the + // shift value - typically between [-32, 32]. + int effective_scale_1_b; + int effective_scale_2_b; + int scratch_tensor_index; + int scratch_output_tensor_index; + + // Cached tensor zero point values for quantized operations. + int input_zero_point; + int output_zero_point; +}; + +// Input tensors. +constexpr int kInputTensor = 0; +constexpr int kWeightsFeatureTensor = 1; +constexpr int kWeightsTimeTensor = 2; +constexpr int kBiasTensor = 3; +// This is a variable tensor, and will be modified by this op. +constexpr int kInputActivationStateTensor = 4; + +// Output tensor. +constexpr int kOutputTensor = 0; + +/** + * This version of SVDF is specific to TFLite Micro. It contains the following + * differences between the TFLite version: + * + * 1.) Scratch tensor allocation - scratch tensors must be known ahead of time + * for the Micro interpreter. + * 2.) Output dimensions - the TFLite version determines output size and runtime + * and resizes the output tensor. Micro runtime does not support tensor + * resizing. + */ +static inline void ApplyTimeWeightsBiasAndActivation( + int batch_size, int memory_size, int num_filters, int num_units, int rank, + const float* const __restrict__ weights_time_ptr, + const float* const __restrict__ bias_ptr, TfLiteFusedActivation activation, + float* const __restrict__ state_ptr, float* const __restrict__ scratch_ptr, + float* const __restrict__ output_ptr) { + // Compute matmul(activation_state, weights_time). + for (int b = 0; b < batch_size; ++b) { + // Perform batched vector dot product: + float* scratch_ptr_batch = scratch_ptr + b * num_filters; + const float* vector1_ptr = weights_time_ptr; + const float* vector2_ptr = state_ptr + b * memory_size * num_filters; + for (int i = 0; i < num_filters; ++i) { + *scratch_ptr_batch = 0.f; + for (int j = 0; j < memory_size; ++j) { + *scratch_ptr_batch += *vector1_ptr++ * *vector2_ptr++; + } + scratch_ptr_batch++; + } + } + + // Initialize output with bias if provided. + if (bias_ptr) { + // VectorBatchVectorAssign + for (int i = 0; i < batch_size; ++i) { + float* output_data = output_ptr + i * num_units; + const float* bias_data = bias_ptr; + for (int j = 0; j < num_units; ++j) { + *output_data++ = *bias_data++; + } + } + } else { + float* output_data = output_ptr; + for (int i = 0; i < batch_size * num_units; ++i) { + *output_data++ = 0.0f; + } + } + + // Reduction sum. 
+ for (int b = 0; b < batch_size; ++b) { + float* output_ptr_batch = output_ptr + b * num_units; + float* scratch_ptr_batch = scratch_ptr + b * num_filters; + + // Reduction sum vector + for (int i = 0; i < num_units; ++i) { + for (int j = 0; j < rank; j++) { + output_ptr_batch[i] += *scratch_ptr_batch++; + } + } + } + + // Apply activation. + for (int b = 0; b < batch_size; ++b) { + float* output_ptr_batch = output_ptr + b * num_units; + for (int i = 0; i < num_units; ++i) { + *output_ptr_batch = + tflite::ops::micro::ActivationValFloat(activation, *output_ptr_batch); + ++output_ptr_batch; + } + } +} + +inline void EvalFloatSVDF( + TfLiteContext* context, TfLiteNode* node, const TfLiteEvalTensor* input, + const TfLiteEvalTensor* weights_feature, + const TfLiteEvalTensor* weights_time, const TfLiteEvalTensor* bias, + const TfLiteSVDFParams* params, int scratch_tensor_index, + TfLiteEvalTensor* activation_state, TfLiteEvalTensor* output) { + const int rank = params->rank; + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + const int num_filters = weights_feature->dims->data[0]; + const int num_units = num_filters / rank; + const int memory_size = weights_time->dims->data[1]; + + const float* weights_feature_ptr = + tflite::micro::GetTensorData(weights_feature); + const float* weights_time_ptr = + tflite::micro::GetTensorData(weights_time); + const float* bias_ptr = tflite::micro::GetTensorData(bias); + const float* input_ptr = tflite::micro::GetTensorData(input); + + float* state_ptr = tflite::micro::GetTensorData(activation_state); + + TFLITE_DCHECK(context != nullptr); + TFLITE_DCHECK(context->GetScratchBuffer != nullptr); + + float* scratch_ptr = static_cast( + context->GetScratchBuffer(context, scratch_tensor_index)); + + float* output_ptr = tflite::micro::GetTensorData(output); + + // Left shift the activation_state. + { + float* new_state_start = state_ptr; + const float* old_state_start = state_ptr + 1; + const float* old_state_end = + state_ptr + batch_size * num_filters * memory_size; + while (old_state_start != old_state_end) { + *new_state_start++ = *old_state_start++; + } + } + + // Note: no need to clear the latest activation, matmul is not accumulative. + + // Compute conv1d(inputs, weights_feature). + // The activation_state's rightmost column is used to save current cycle + // activation. This is achieved by starting at state_ptr[memory_size - 1] and + // having the stride equal to memory_size. 
+ + // Perform batched matrix vector multiply operation: + { + const float* matrix = weights_feature_ptr; + const float* vector = input_ptr; + float* result = &state_ptr[memory_size - 1]; + float* result_in_batch = result; + for (int i = 0; i < batch_size; ++i) { + const float* matrix_ptr = matrix; + for (int j = 0; j < num_filters; ++j) { + float dot_prod = 0.0f; + const float* vector_in_batch = vector + i * input_size; + for (int k = 0; k < input_size; ++k) { + dot_prod += *matrix_ptr++ * *vector_in_batch++; + } + *result_in_batch = dot_prod; + result_in_batch += memory_size; + } + } + } + + ApplyTimeWeightsBiasAndActivation( + batch_size, memory_size, num_filters, num_units, rank, weights_time_ptr, + bias_ptr, params->activation, state_ptr, scratch_ptr, output_ptr); +} + +void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node, + const TfLiteEvalTensor* input_tensor, + const TfLiteEvalTensor* weights_feature_tensor, + const TfLiteEvalTensor* weights_time_tensor, + const TfLiteEvalTensor* bias_tensor, + const TfLiteSVDFParams* params, + TfLiteEvalTensor* activation_state_tensor, + TfLiteEvalTensor* output_tensor, const OpData& data) { + cmsis_nn_dims input_dims; + input_dims.n = input_tensor->dims->data[0]; + input_dims.h = input_tensor->dims->data[1]; + + cmsis_nn_dims weights_feature_dims; + weights_feature_dims.n = weights_feature_tensor->dims->data[0]; + weights_feature_dims.h = weights_feature_tensor->dims->data[1]; + + cmsis_nn_dims weights_time_dims; + weights_time_dims.n = weights_time_tensor->dims->data[0]; + weights_time_dims.h = weights_time_tensor->dims->data[1]; + + cmsis_nn_dims bias_dims; + bias_dims.n = bias_tensor->dims->data[0]; + + cmsis_nn_dims state_dims; + state_dims.n = bias_tensor->dims->data[0]; + state_dims.h = bias_tensor->dims->data[1]; + + cmsis_nn_dims output_dims; + output_dims.n = output_tensor->dims->data[0]; + output_dims.h = output_tensor->dims->data[1]; + + cmsis_nn_svdf_params svdf_params; + svdf_params.rank = params->rank; + svdf_params.input_offset = data.input_zero_point; + svdf_params.output_offset = data.output_zero_point; + + svdf_params.input_activation.min = INT16_MIN; + svdf_params.input_activation.max = INT16_MAX; + + svdf_params.output_activation.min = INT8_MIN; + svdf_params.output_activation.max = INT8_MAX; + + cmsis_nn_per_tensor_quant_params in_quant_params; + in_quant_params.multiplier = data.effective_scale_1_a; + in_quant_params.shift = data.effective_scale_1_b; + + cmsis_nn_per_tensor_quant_params out_quant_params; + out_quant_params.multiplier = data.effective_scale_2_a; + out_quant_params.shift = data.effective_scale_2_b; + + TFLITE_DCHECK(context != nullptr); + TFLITE_DCHECK(context->GetScratchBuffer != nullptr); + + cmsis_nn_context scratch_ctx; + scratch_ctx.buf = static_cast( + context->GetScratchBuffer(context, data.scratch_tensor_index)); + + cmsis_nn_context scratch_output_ctx; + scratch_output_ctx.buf = static_cast( + context->GetScratchBuffer(context, data.scratch_output_tensor_index)); + + int8_t* output_data = tflite::micro::GetTensorData(output_tensor); + arm_svdf_s8( + &scratch_ctx, &scratch_output_ctx, &svdf_params, &in_quant_params, + &out_quant_params, &input_dims, + (int8_t*)tflite::micro::GetTensorData(input_tensor), &state_dims, + (int16_t*)tflite::micro::GetTensorData(activation_state_tensor), + &weights_feature_dims, + (int8_t*)tflite::micro::GetTensorData(weights_feature_tensor), + &weights_time_dims, + (int16_t*)tflite::micro::GetTensorData(weights_time_tensor), + &bias_dims, 
(int32_t*)tflite::micro::GetTensorData(bias_tensor), + &output_dims, output_data); +} + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + return context->AllocatePersistentBuffer(context, sizeof(OpData)); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TFLITE_DCHECK(node->builtin_data != nullptr); + + const auto* params = static_cast(node->builtin_data); + + // Validate Tensor Inputs (dtype depends on quantization): + // [0] = Input, {2, batch_size, input_size} + // [1] = Weights Feature, {2, num_filters, input_size} + // [2] = Weights Time, {2, num_filters, memory_size} + // [3] = Bias (optional), {1, num_units} + // [4] = Activation State (variable), + // {2, batch_size, memory_size * num_filters} + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* weights_feature = + GetInput(context, node, kWeightsFeatureTensor); + const TfLiteTensor* weights_time = + GetInput(context, node, kWeightsTimeTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + const TfLiteTensor* activation_state = + GetInput(context, node, kInputActivationStateTensor); + + // Define input constants based on input tensor definition above: + const int rank = params->rank; + const int input_size = input->dims->data[1]; + const int batch_size = input->dims->data[0]; + const int num_filters = weights_feature->dims->data[0]; + TF_LITE_ENSURE_EQ(context, num_filters % rank, 0); + const int num_units = num_filters / rank; + const int memory_size = weights_time->dims->data[1]; + + // Validate Input Tensor: + TF_LITE_ENSURE(context, + input->type == kTfLiteFloat32 || input->type == kTfLiteInt8); + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2); + + // Validate Tensor Output: + // [0] = float/int8, {2, batch_size, num_units} + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2); + TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size); + TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units); + + // Validate Weights Feature Input Tensor: + TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2); + TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size); + + // Validate Weights Time Input Tensor: + TF_LITE_ENSURE_EQ(context, NumDimensions(weights_time), 2); + TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters); + TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size); + + // Validate Optional Bias Input Tensor: + if (bias != nullptr) { + TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units); + } + + // Validate Activation State Input Tensor: + TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2); + TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size); + TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1], + memory_size * num_filters); + // Since is_variable is not part of TFLiteEvalTensor, check is_variable here. 
+ TF_LITE_ENSURE_EQ(context, activation_state->is_variable, true); + + TF_LITE_ENSURE_EQ(context, node->inputs->size, 5); + + TFLITE_DCHECK(node->user_data != nullptr); + OpData* data = static_cast(node->user_data); + + if (input->type == kTfLiteInt8) { + TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8); + TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteInt16); + TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteInt16); + if (bias != nullptr) { + TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); + } + + TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8); + + const double effective_scale_1 = static_cast( + input->params.scale * weights_feature->params.scale / + activation_state->params.scale); + const double effective_scale_2 = + static_cast(activation_state->params.scale * + weights_time->params.scale / output->params.scale); + + // TODO(b/162018098): Use TF_LITE_ENSURE_NEAR when it is ready. + TF_LITE_ENSURE( + context, + std::abs(static_cast(bias->params.scale) - + static_cast(activation_state->params.scale * + weights_time->params.scale)) < 1e-5); + + QuantizeMultiplier(effective_scale_1, &(data->effective_scale_1_a), + &(data->effective_scale_1_b)); + QuantizeMultiplier(effective_scale_2, &(data->effective_scale_2_a), + &(data->effective_scale_2_b)); + + data->input_zero_point = input->params.zero_point; + data->output_zero_point = output->params.zero_point; + + TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr); + + const TfLiteStatus scratch_status = context->RequestScratchBufferInArena( + context, batch_size * num_filters * sizeof(int32_t), + &(data->scratch_tensor_index)); + TF_LITE_ENSURE_OK(context, scratch_status); + + const TfLiteStatus scratch_output_status = + context->RequestScratchBufferInArena( + context, batch_size * num_units * sizeof(int32_t), + &(data->scratch_output_tensor_index)); + TF_LITE_ENSURE_OK(context, scratch_output_status); + } else { + TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32); + if (bias != nullptr) { + TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32); + } + TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32); + + TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr); + const TfLiteStatus scratch_status = context->RequestScratchBufferInArena( + context, batch_size * num_filters * sizeof(float), + &(data->scratch_tensor_index)); + TF_LITE_ENSURE_OK(context, scratch_status); + } + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + TFLITE_DCHECK(node->user_data != nullptr); + const OpData& data = *(static_cast(node->user_data)); + + const TfLiteEvalTensor* input = + tflite::micro::GetEvalInput(context, node, kInputTensor); + const TfLiteEvalTensor* weights_feature = + tflite::micro::GetEvalInput(context, node, kWeightsFeatureTensor); + const TfLiteEvalTensor* weights_time = + tflite::micro::GetEvalInput(context, node, kWeightsTimeTensor); + const TfLiteEvalTensor* bias = + (NumInputs(node) == 5) + ? 
tflite::micro::GetEvalInput(context, node, kBiasTensor) + : nullptr; + TfLiteEvalTensor* activation_state = tflite::micro::GetMutableEvalInput( + context, node, kInputActivationStateTensor); + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kOutputTensor); + + switch (weights_feature->type) { + case kTfLiteFloat32: { + EvalFloatSVDF(context, node, input, weights_feature, weights_time, bias, + params, data.scratch_tensor_index, activation_state, + output); + return kTfLiteOk; + break; + } + + case kTfLiteInt8: { + EvalIntegerSVDF(context, node, input, weights_feature, weights_time, bias, + params, activation_state, output, data); + return kTfLiteOk; + break; + } + + default: + TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.", + TfLiteTypeGetName(weights_feature->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace + +TfLiteRegistration Register_SVDF() { + return {/*init=*/Init, + /*free=*/nullptr, + /*prepare=*/Prepare, + /*invoke=*/Eval, + /*profiling_string=*/nullptr, + /*builtin_code=*/0, + /*custom_name=*/nullptr, + /*version=*/0}; +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/comparisons_test.cc b/tensorflow/lite/micro/kernels/comparisons_test.cc index 393f0e22187..855192baed2 100644 --- a/tensorflow/lite/micro/kernels/comparisons_test.cc +++ b/tensorflow/lite/micro/kernels/comparisons_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/concatenation_test.cc b/tensorflow/lite/micro/kernels/concatenation_test.cc index d82a804e659..d7ed2213c98 100644 --- a/tensorflow/lite/micro/kernels/concatenation_test.cc +++ b/tensorflow/lite/micro/kernels/concatenation_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/conv.cc b/tensorflow/lite/micro/kernels/conv.cc index ebeb54c64f6..9b1b1148176 100644 --- a/tensorflow/lite/micro/kernels/conv.cc +++ b/tensorflow/lite/micro/kernels/conv.cc @@ -26,9 +26,7 @@ limitations under the License. #include "tensorflow/lite/micro/kernels/kernel_util.h" namespace tflite { -namespace ops { -namespace micro { -namespace conv { +namespace { constexpr int kInputTensor = 0; constexpr int kFilterTensor = 1; @@ -146,10 +144,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Dynimically allocate per-channel quantization parameters. 
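The cmsis-nn SVDF kernel added above keeps a [num_filters x memory_size] activation state per batch and, as its comment notes, writes each new feature dot product into the rightmost column (starting at state_ptr[memory_size - 1] and striding by memory_size). A minimal, self-contained sketch of that write pattern with made-up sizes (illustrative only, not part of the kernel):

#include <cstdio>

// Mirrors the batched matrix-vector multiply in EvalFloatSVDF for batch_size
// = 1: each filter's dot product with the input lands in the last slot of that
// filter's row of the activation state.
int main() {
  constexpr int kNumFilters = 2;
  constexpr int kMemorySize = 3;
  constexpr int kInputSize = 2;
  const float weights_feature[kNumFilters * kInputSize] = {1, 2, 3, 4};
  const float input[kInputSize] = {0.5f, -1.0f};
  float state[kNumFilters * kMemorySize] = {0};  // row per filter, row-major

  float* result_in_batch = &state[kMemorySize - 1];
  const float* matrix_ptr = weights_feature;
  for (int f = 0; f < kNumFilters; ++f) {
    float dot_prod = 0.0f;
    for (int k = 0; k < kInputSize; ++k) {
      dot_prod += *matrix_ptr++ * input[k];
    }
    *result_in_batch = dot_prod;     // state[f * kMemorySize + kMemorySize - 1]
    result_in_batch += kMemorySize;  // jump to the next filter's rightmost slot
  }

  for (int i = 0; i < kNumFilters * kMemorySize; ++i) {
    printf("%g ", state[i]);  // prints: 0 0 -1.5 0 0 -2.5
  }
  printf("\n");
  return 0;
}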
const int num_channels = filter->dims->data[kConvQuantizedDimension]; data->per_channel_output_multiplier = - reinterpret_cast(context->AllocatePersistentBuffer( + static_cast(context->AllocatePersistentBuffer( context, num_channels * sizeof(int32_t))); data->per_channel_output_shift = - reinterpret_cast(context->AllocatePersistentBuffer( + static_cast(context->AllocatePersistentBuffer( context, num_channels * sizeof(int32_t))); // All per-channel quantized tensors need valid zero point and scale arrays. @@ -322,19 +320,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -} // namespace conv +} // namespace TfLiteRegistration Register_CONV_2D() { - return {/*init=*/conv::Init, + return {/*init=*/Init, /*free=*/nullptr, - /*prepare=*/conv::Prepare, - /*invoke=*/conv::Eval, + /*prepare=*/Prepare, + /*invoke=*/Eval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/conv_test.cc b/tensorflow/lite/micro/kernels/conv_test.cc index be646d63659..d0d942c53c8 100644 --- a/tensorflow/lite/micro/kernels/conv_test.cc +++ b/tensorflow/lite/micro/kernels/conv_test.cc @@ -17,27 +17,32 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" #include "tensorflow/lite/micro/micro_utils.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { namespace { // Common inputs and outputs. -static const int kInputElements = 16; +constexpr int kInputElements = 16; static const int kInputShape[] = {4, 2, 2, 4, 1}; -static const float kInputData[] = {1, 1, 1, 1, 2, 2, 2, 2, - 1, 2, 3, 4, 1, 2, 3, 4}; -static const int kFilterElements = 12; +static const float kInputData[kInputElements] = {1, 1, 1, 1, 2, 2, 2, 2, + 1, 2, 3, 4, 1, 2, 3, 4}; + +constexpr int kFilterElements = 12; static const int kFilterShape[] = {4, 3, 2, 2, 1}; -static const float kFilterData[] = {1, 2, 3, 4, -1, 1, -1, 1, -1, -1, 1, 1}; -static const int kBiasElements = 3; +static const float kFilterData[kFilterElements] = {1, 2, 3, 4, -1, 1, + -1, 1, -1, -1, 1, 1}; + +constexpr int kBiasElements = 3; static const int kBiasShape[] = {1, 3}; -static const float kBiasData[] = {1, 2, 3}; -static const int kOutputElements = 12; +static const float kBiasData[kBiasElements] = {1, 2, 3}; + +constexpr int kOutputElements = 12; static const int kOutputShape[] = {4, 2, 1, 2, 3}; -static const float kGoldenData[] = {18, 2, 5, 18, 2, 5, 17, 4, 3, 37, 4, 3}; +static const float kGoldenData[kOutputElements] = {18, 2, 5, 18, 2, 5, + 17, 4, 3, 37, 4, 3}; static TfLiteConvParams common_conv_params = { kTfLitePaddingValid, // padding @@ -59,8 +64,7 @@ TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size, int outputs_array_data[] = {1, 3}; TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); - const TfLiteRegistration registration = - tflite::ops::micro::Register_CONV_2D(); + const TfLiteRegistration registration = Register_CONV_2D(); micro::KernelRunner runner( registration, tensors, tensors_size, inputs_array, outputs_array, reinterpret_cast(conv_params), micro_test::reporter); @@ -412,41 +416,46 @@ TF_LITE_MICRO_TEST(Kernel1x1QuantizedPerChannel) { // padding, stride_, activation, dilation_ TfLiteConvParams conv_params = 
{kTfLitePaddingValid, 1, 1, kTfLiteActNone, 1, 1}; - const int kInputShape[] = {4, 1, 2, 2, 4}; // [len,N,H,W,C] - const int kInputElements = - kInputShape[1] * kInputShape[2] * kInputShape[3] * kInputShape[4]; - float kInputData[/* kInputElements */] = {1, 1, 1, 1, 2, 2, 2, 2, - 1, 2, 3, 4, 1, 2, 3, 4}; - const int kFilterShape[] = {4, 3, 1, 1, 4}; - const int kFilterElements = - kFilterShape[1] * kFilterShape[2] * kFilterShape[3] * kFilterShape[4]; - float kFilterData[/* kFilterElements */] = {1, 2, 3, 4, -1, 1, + + constexpr int input_shape[] = {4, 1, 2, 2, 4}; // [len,N,H,W,C] + constexpr int input_elements = + input_shape[1] * input_shape[2] * input_shape[3] * input_shape[4]; + constexpr float input_data[input_elements] = {1, 1, 1, 1, 2, 2, 2, 2, + 1, 2, 3, 4, 1, 2, 3, 4}; + + constexpr int filter_shape[] = {4, 3, 1, 1, 4}; + constexpr int filter_elements = + filter_shape[1] * filter_shape[2] * filter_shape[3] * filter_shape[4]; + const float filter_data[filter_elements] = {1, 2, 3, 4, -1, 1, -1, 1, -1, -1, 1, 1}; - const int kBiasElements = kFilterShape[1]; - const int kBiasShape[] = {1, kBiasElements}; - float kBiasData[/* kBiasElements */] = {1, 2, 3}; - const int kOutputShape[] = {4, 1, 2, 2, kBiasElements}; - const int kOutputElements = 4 * 3; - int8_t output_data[kOutputElements]; - const float kGoldenData[/* kOutputElements */] = {11, 2, 3, 21, 2, 3, - 31, 4, 7, 31, 4, 7}; + + constexpr int bias_elements = filter_shape[1]; + constexpr int bias_shape[] = {1, bias_elements}; + constexpr float bias_data[bias_elements] = {1, 2, 3}; + + constexpr int output_shape[] = {4, 1, 2, 2, bias_elements}; + constexpr int output_elements = 4 * 3; + int8_t output_data[output_elements]; + + const float golden_data[output_elements] = {11, 2, 3, 21, 2, 3, + 31, 4, 7, 31, 4, 7}; const float input_scale = 0.5f; const float output_scale = 1.0f; const int input_zero_point = 0; const int output_zero_point = 0; - int8_t input_quantized[kInputElements]; - int8_t filter_quantized[kFilterElements]; - int32_t bias_quantized[kBiasElements]; - int8_t golden_quantized[kOutputElements]; - int zero_points[kBiasElements + 1]; - float scales[kBiasElements + 1]; + int8_t input_quantized[input_elements]; + int8_t filter_quantized[filter_elements]; + int32_t bias_quantized[bias_elements]; + int8_t golden_quantized[output_elements]; + int zero_points[bias_elements + 1]; + float scales[bias_elements + 1]; tflite::testing::TestConvQuantizedPerChannel( - kInputShape, kInputData, input_quantized, input_scale, input_zero_point, - kFilterShape, kFilterData, filter_quantized, kBiasShape, kBiasData, - bias_quantized, scales, zero_points, kOutputShape, kGoldenData, + input_shape, input_data, input_quantized, input_scale, input_zero_point, + filter_shape, filter_data, filter_quantized, bias_shape, bias_data, + bias_quantized, scales, zero_points, output_shape, golden_data, golden_quantized, output_data, output_scale, output_zero_point, &conv_params); } @@ -456,41 +465,46 @@ TF_LITE_MICRO_TEST(Kernel1x1QuantizedPerChannelRelu6) { // padding, stride_, activation, dilation_ TfLiteConvParams conv_params = {kTfLitePaddingValid, 1, 1, kTfLiteActRelu6, 1, 1}; - const int kInputShape[] = {4, 1, 2, 2, 4}; // [len,N,H,W,C] - const int kInputElements = - kInputShape[1] * kInputShape[2] * kInputShape[3] * kInputShape[4]; - float kInputData[/* kInputElements */] = {1, 1, 1, 1, 2, 2, 2, 2, - 1, 2, 3, 4, 1, 2, 3, 4}; - const int kFilterShape[] = {4, 3, 1, 1, 4}; - const int kFilterElements = - kFilterShape[1] * kFilterShape[2] * 
kFilterShape[3] * kFilterShape[4]; - float kFilterData[/* kFilterElements */] = {1, 2, 3, 4, -1, 1, + + constexpr int input_shape[] = {4, 1, 2, 2, 4}; // [len,N,H,W,C] + constexpr int input_elements = + input_shape[1] * input_shape[2] * input_shape[3] * input_shape[4]; + constexpr float input_data[input_elements] = {1, 1, 1, 1, 2, 2, 2, 2, + 1, 2, 3, 4, 1, 2, 3, 4}; + + constexpr int filter_shape[] = {4, 3, 1, 1, 4}; + constexpr int filter_elements = + filter_shape[1] * filter_shape[2] * filter_shape[3] * filter_shape[4]; + const float filter_data[filter_elements] = {1, 2, 3, 4, -1, 1, -1, 1, -1, -1, 1, 1}; - const int kBiasElements = kFilterShape[1]; - const int kBiasShape[] = {1, kBiasElements}; - float kBiasData[/* kBiasElements */] = {1, 2, -3}; - const int kOutputShape[] = {4, 1, 2, 2, kBiasElements}; - const int kOutputElements = 4 * 3; - int8_t output_data[kOutputElements]; - const float kGoldenData[/* kOutputElements */] = {6, 2, 0, 6, 2, 0, - 6, 4, 1, 6, 4, 1}; + + constexpr int bias_elements = filter_shape[1]; + constexpr int bias_shape[] = {1, bias_elements}; + constexpr float bias_data[bias_elements] = {1, 2, -3}; + + constexpr int output_shape[] = {4, 1, 2, 2, bias_elements}; + constexpr int output_elements = 4 * 3; + int8_t output_data[output_elements]; + + const float golden_data[output_elements] = {6, 2, 0, 6, 2, 0, + 6, 4, 1, 6, 4, 1}; const float input_scale = 0.023529f; const float output_scale = 0.023529f; const int input_zero_point = -128; const int output_zero_point = -128; - int8_t input_quantized[kInputElements]; - int8_t filter_quantized[kFilterElements]; - int32_t bias_quantized[kBiasElements]; - int8_t golden_quantized[kOutputElements]; - int zero_points[kBiasElements + 1]; - float scales[kBiasElements + 1]; + int8_t input_quantized[input_elements]; + int8_t filter_quantized[filter_elements]; + int32_t bias_quantized[bias_elements]; + int8_t golden_quantized[output_elements]; + int zero_points[bias_elements + 1]; + float scales[bias_elements + 1]; tflite::testing::TestConvQuantizedPerChannel( - kInputShape, kInputData, input_quantized, input_scale, input_zero_point, - kFilterShape, kFilterData, filter_quantized, kBiasShape, kBiasData, - bias_quantized, scales, zero_points, kOutputShape, kGoldenData, + input_shape, input_data, input_quantized, input_scale, input_zero_point, + filter_shape, filter_data, filter_quantized, bias_shape, bias_data, + bias_quantized, scales, zero_points, output_shape, golden_data, golden_quantized, output_data, output_scale, output_zero_point, &conv_params); } diff --git a/tensorflow/lite/micro/kernels/depthwise_conv.cc b/tensorflow/lite/micro/kernels/depthwise_conv.cc index cfb457c2016..85b51233e90 100644 --- a/tensorflow/lite/micro/kernels/depthwise_conv.cc +++ b/tensorflow/lite/micro/kernels/depthwise_conv.cc @@ -27,9 +27,6 @@ limitations under the License. 
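Both the SVDF Prepare() above and the per-channel setup in conv.cc turn a floating-point rescale factor (e.g. input_scale * weights_scale / output_scale) into a fixed-point multiplier/shift pair through QuantizeMultiplier(). A hedged sketch of that decomposition — DecomposeMultiplier is a made-up name and this is my own illustration, not the library routine — expressing r as (q31 / 2^31) * 2^shift with the significand in [0.5, 1):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative only: split a real rescale factor r into (q31, shift) such that
// r ~= (q31 / 2^31) * 2^shift.
void DecomposeMultiplier(double r, int32_t* q31, int* shift) {
  if (r == 0.0) {
    *q31 = 0;
    *shift = 0;
    return;
  }
  const double m = std::frexp(r, shift);  // r = m * 2^shift, m in [0.5, 1)
  int64_t q = static_cast<int64_t>(std::round(m * (1ll << 31)));
  if (q == (1ll << 31)) {  // rounding carried the significand up to 1.0
    q /= 2;
    ++(*shift);
  }
  *q31 = static_cast<int32_t>(q);
}

int main() {
  int32_t q31 = 0;
  int shift = 0;
  DecomposeMultiplier(0.0235, &q31, &shift);  // e.g. an effective output scale
  printf("multiplier=%ld shift=%d\n", static_cast<long>(q31),
         shift);  // roughly 0.752 * 2^31, shift = -5
  return 0;
}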
#include "tensorflow/lite/micro/kernels/kernel_util.h" namespace tflite { -namespace ops { -namespace micro { -namespace depthwise_conv { namespace { constexpr int kInputTensor = 0; @@ -101,8 +98,6 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, return kTfLiteOk; } -} // namespace - void* Init(TfLiteContext* context, const char* buffer, size_t length) { TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); return context->AllocatePersistentBuffer(context, sizeof(OpData)); @@ -315,19 +310,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -} // namespace depthwise_conv +} // namespace TfLiteRegistration Register_DEPTHWISE_CONV_2D() { - return {/*init=*/depthwise_conv::Init, + return {/*init=*/Init, /*free=*/nullptr, - /*prepare=*/depthwise_conv::Prepare, - /*invoke=*/depthwise_conv::Eval, + /*prepare=*/Prepare, + /*invoke=*/Eval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/depthwise_conv_test.cc b/tensorflow/lite/micro/kernels/depthwise_conv_test.cc index e16e9f893cb..358d508a564 100644 --- a/tensorflow/lite/micro/kernels/depthwise_conv_test.cc +++ b/tensorflow/lite/micro/kernels/depthwise_conv_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { @@ -47,8 +47,7 @@ TfLiteStatus ValidateDepthwiseConvGoldens( int outputs_array_data[] = {1, 3}; TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); - const TfLiteRegistration registration = - tflite::ops::micro::Register_DEPTHWISE_CONV_2D(); + const TfLiteRegistration registration = Register_DEPTHWISE_CONV_2D(); micro::KernelRunner runner( registration, tensors, tensors_size, inputs_array, outputs_array, reinterpret_cast(conv_params), micro_test::reporter); diff --git a/tensorflow/lite/micro/kernels/dequantize_test.cc b/tensorflow/lite/micro/kernels/dequantize_test.cc index 6b499204b98..86059c63647 100644 --- a/tensorflow/lite/micro/kernels/dequantize_test.cc +++ b/tensorflow/lite/micro/kernels/dequantize_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/lite/micro/kernels/kernel_runner.h" #include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/elementwise_test.cc b/tensorflow/lite/micro/kernels/elementwise_test.cc index b7094cbd445..1f3b49b3616 100644 --- a/tensorflow/lite/micro/kernels/elementwise_test.cc +++ b/tensorflow/lite/micro/kernels/elementwise_test.cc @@ -17,8 +17,8 @@ limitations under the License. 
#include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/debug_log.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/floor_test.cc b/tensorflow/lite/micro/kernels/floor_test.cc index 3a27a937b17..dc9086a07cd 100644 --- a/tensorflow/lite/micro/kernels/floor_test.cc +++ b/tensorflow/lite/micro/kernels/floor_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc index 74a3f4f97bd..d3fdeacb016 100644 --- a/tensorflow/lite/micro/kernels/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/fully_connected.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,21 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" +#include "tensorflow/lite/micro/kernels/fully_connected.h" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" namespace tflite { -namespace ops { -namespace micro { -namespace fully_connected { namespace { struct OpData { @@ -77,8 +75,6 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, return status; } -} // namespace - void* Init(TfLiteContext* context, const char* buffer, size_t length) { TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); return context->AllocatePersistentBuffer(context, sizeof(OpData)); @@ -241,19 +237,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -} // namespace fully_connected +} // namespace TfLiteRegistration Register_FULLY_CONNECTED() { - return {/*init=*/fully_connected::Init, + return {/*init=*/Init, /*free=*/nullptr, - /*prepare=*/fully_connected::Prepare, - /*invoke=*/fully_connected::Eval, + /*prepare=*/Prepare, + /*invoke=*/Eval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/fully_connected.h b/tensorflow/lite/micro/kernels/fully_connected.h new file mode 100644 index 00000000000..3e6467183fe --- /dev/null +++ 
b/tensorflow/lite/micro/kernels/fully_connected.h @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_ +#define TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_ + +#include "tensorflow/lite/c/common.h" + +namespace tflite { + +// This is the most generic TfLiteRegistration. The actual supported types may +// still be target dependent. The only requirement is that every implementation +// (reference or optimized) must define this function. +TfLiteRegistration Register_FULLY_CONNECTED(); + +#if defined(CMSIS_NN) || defined(ARDUINO) +// The Arduino is a special case where we use the CMSIS kernels, but because of +// the current approach to building for Arduino, we do not support -DCMSIS_NN as +// part of the build. As a result, we use defined(ARDUINO) as proxy for the +// CMSIS kernels for this one special case. + +// Returns a TfLiteRegistration struct for cmsis-nn kernel variant that only +// supports int8. +TfLiteRegistration Register_FULLY_CONNECTED_INT8(); + +#else +// Note that while this block gets used for both reference and optimized kernels +// that do not have any specialized implementations, the only goal here is to +// define fallback implementation that allow reference kernels to still be used +// from applications that call a more specific kernel variant. + +inline TfLiteRegistration Register_FULLY_CONNECTED_INT8() { + return Register_FULLY_CONNECTED(); +} + +#endif +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_ diff --git a/tensorflow/lite/micro/kernels/fully_connected_test.cc b/tensorflow/lite/micro/kernels/fully_connected_test.cc index d41ae4b9e8c..3f113010485 100644 --- a/tensorflow/lite/micro/kernels/fully_connected_test.cc +++ b/tensorflow/lite/micro/kernels/fully_connected_test.cc @@ -20,11 +20,9 @@ limitations under the License. 
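The new fully_connected.h above exposes one registration per variant and falls back to the reference kernel when no int8-specific build is selected, so callers can request the specialized variant unconditionally. A small usage sketch (SelectFullyConnectedKernel is a hypothetical helper, not part of the header):

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"

// On CMSIS-NN/Arduino builds this returns the int8-only cmsis-nn kernel; on
// reference builds the inline fallback forwards to Register_FULLY_CONNECTED(),
// so the call is safe either way.
TfLiteRegistration SelectFullyConnectedKernel(bool model_is_int8_only) {
  return model_is_int8_only ? tflite::Register_FULLY_CONNECTED_INT8()
                            : tflite::Register_FULLY_CONNECTED();
}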
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" -#include "tensorflow/lite/micro/kernels/micro_ops.h" #include "tensorflow/lite/micro/micro_utils.h" #include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { @@ -241,8 +239,7 @@ TfLiteStatus ValidateFullyConnectedGoldens( int outputs_array_data[] = {1, 3}; TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); - const TfLiteRegistration registration = - ops::micro::Register_FULLY_CONNECTED(); + const TfLiteRegistration registration = Register_FULLY_CONNECTED(); micro::KernelRunner runner( registration, tensors, tensors_size, inputs_array, outputs_array, reinterpret_cast(&builtin_data), micro_test::reporter); diff --git a/tensorflow/lite/micro/kernels/hard_swish_test.cc b/tensorflow/lite/micro/kernels/hard_swish_test.cc index 83cdacc96bc..91345870023 100644 --- a/tensorflow/lite/micro/kernels/hard_swish_test.cc +++ b/tensorflow/lite/micro/kernels/hard_swish_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/l2norm_test.cc b/tensorflow/lite/micro/kernels/l2norm_test.cc index 791f9036c56..b37c6394a66 100644 --- a/tensorflow/lite/micro/kernels/l2norm_test.cc +++ b/tensorflow/lite/micro/kernels/l2norm_test.cc @@ -17,8 +17,8 @@ limitations under the License. 
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { @@ -30,64 +30,32 @@ constexpr float kInputMax = 2.0; constexpr float kOutputMin = -1.0; constexpr float kOutputMax = 127.0 / 128.0; -void QuantizeInputData(const float input_data[], int length, - uint8_t* quantized_data) { - for (int i = 0; i < 6; i++) { - quantized_data[i] = tflite::testing::F2Q( - input_data[i], tflite::testing::kInputMin, tflite::testing::kInputMax); - } -} - -void QuantizeInputData(const float input_data[], int length, - int8_t* quantized_data) { - for (int i = 0; i < 6; i++) { - quantized_data[i] = tflite::testing::F2QS( - input_data[i], tflite::testing::kInputMin, tflite::testing::kInputMax); - } -} - TfLiteTensor CreateL2NormTensor(const float* data, TfLiteIntArray* dims, bool is_input) { return CreateFloatTensor(data, dims); } -TfLiteTensor CreateL2NormTensor(const uint8_t* data, TfLiteIntArray* dims, - bool is_input) { - TfLiteTensor tensor; - - if (is_input) { - tensor = CreateQuantizedTensor(data, dims, kInputMin, kInputMax); - } else { - tensor = CreateQuantizedTensor(data, dims, kOutputMin, kOutputMax); - } - - tensor.quantization.type = kTfLiteAffineQuantization; - return tensor; -} - -TfLiteTensor CreateL2NormTensor(const int8_t* data, TfLiteIntArray* dims, - bool is_input) { - TfLiteTensor tensor; - - if (is_input) { - tensor = CreateQuantizedTensor(data, dims, kInputMin, kInputMax); - } else { - tensor = CreateQuantizedTensor(data, dims, kOutputMin, kOutputMax); - } - - tensor.quantization.type = kTfLiteAffineQuantization; - return tensor; -} - template -inline float Dequantize(const T data, float scale, int32_t zero_point) { - return scale * (data - zero_point); +TfLiteTensor CreateL2NormTensor(const T* data, TfLiteIntArray* dims, + bool is_input) { + float kInputScale = ScaleFromMinMax(kInputMin, kInputMax); + int kInputZeroPoint = ZeroPointFromMinMax(kInputMin, kInputMax); + float kOutputScale = ScaleFromMinMax(kOutputMin, kOutputMax); + int kOutputZeroPoint = ZeroPointFromMinMax(kOutputMin, kOutputMax); + TfLiteTensor tensor; + if (is_input) { + tensor = CreateQuantizedTensor(data, dims, kInputScale, kInputZeroPoint); + } else { + tensor = CreateQuantizedTensor(data, dims, kOutputScale, kOutputZeroPoint); + } + + tensor.quantization.type = kTfLiteAffineQuantization; + return tensor; } template void TestL2Normalization(const int* input_dims_data, const T* input_data, - const float* expected_output_data, T* output_data, - float variance) { + const T* expected_output_data, T* output_data) { TfLiteIntArray* dims = IntArrayFromInts(input_dims_data); const int output_dims_count = ElementCount(*dims); @@ -116,25 +84,8 @@ void TestL2Normalization(const int* input_dims_data, const T* input_data, TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); - // Compare the results from dequantization and expected outputs, and make - // sure the difference is within a threshold. 
- if (tensors[1].quantization.type != kTfLiteNoQuantization) { - TfLiteTensor* output_tensor = &tensors[1]; - int32_t zero_point = output_tensor->params.zero_point; - float scale = output_tensor->params.scale; - - for (int i = 0; i < output_dims_count; ++i) { - float output_val = Dequantize(output_data[i], scale, zero_point); - - TF_LITE_MICRO_EXPECT_LE(expected_output_data[i] - variance, output_val); - TF_LITE_MICRO_EXPECT_GE(expected_output_data[i] + variance, output_val); - } - } else { - for (int i = 0; i < output_dims_count; ++i) { - float output_val = static_cast(output_data[i]); - TF_LITE_MICRO_EXPECT_LE(expected_output_data[i] - variance, output_val); - TF_LITE_MICRO_EXPECT_GE(expected_output_data[i] + variance, output_val); - } + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]); } } @@ -153,7 +104,7 @@ TF_LITE_MICRO_TEST(SimpleFloatTest) { float output_data[data_length]; tflite::testing::TestL2Normalization( - input_dims, input_data, expected_output_data, output_data, 0); + input_dims, input_data, expected_output_data, output_data); } TF_LITE_MICRO_TEST(ZerosVectorFloatTest) { @@ -164,7 +115,7 @@ TF_LITE_MICRO_TEST(ZerosVectorFloatTest) { float output_data[data_length]; tflite::testing::TestL2Normalization( - input_dims, input_data, expected_output_data, output_data, 0); + input_dims, input_data, expected_output_data, output_data); } TF_LITE_MICRO_TEST(SimpleFloatWithRankLessThanFourTest) { @@ -176,7 +127,7 @@ TF_LITE_MICRO_TEST(SimpleFloatWithRankLessThanFourTest) { float output_data[data_length]; tflite::testing::TestL2Normalization( - input_dims, input_data, expected_output_data, output_data, 0); + input_dims, input_data, expected_output_data, output_data); } TF_LITE_MICRO_TEST(MultipleBatchFloatTest) { @@ -195,107 +146,91 @@ TF_LITE_MICRO_TEST(MultipleBatchFloatTest) { float output_data[data_length]; tflite::testing::TestL2Normalization( - input_dims, input_data, expected_output_data, output_data, 0); + input_dims, input_data, expected_output_data, output_data); } TF_LITE_MICRO_TEST(ZerosVectorUint8Test) { const int input_dims[] = {4, 1, 1, 1, 6}; constexpr int data_length = 6; - const float input_data[data_length] = {0}; - const float expected_output_data[data_length] = {0}; - uint8_t quantized_input[data_length]; + const uint8_t input_data[data_length] = {127, 127, 127, 127, 127, 127}; + const uint8_t expected_output[data_length] = {128, 128, 128, 128, 128, 128}; uint8_t output_data[data_length]; - tflite::testing::QuantizeInputData(input_data, data_length, quantized_input); - - tflite::testing::TestL2Normalization( - input_dims, quantized_input, expected_output_data, output_data, .1); + tflite::testing::TestL2Normalization(input_dims, input_data, + expected_output, output_data); } TF_LITE_MICRO_TEST(SimpleUint8Test) { const int input_dims[] = {4, 1, 1, 1, 6}; constexpr int data_length = 6; - float input_data[data_length] = {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}; - float expected_output[data_length] = {-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}; - uint8_t quantized_input[data_length]; + const uint8_t input_data[data_length] = {57, 165, 172, 204, 82, 133}; + const uint8_t expected_output[data_length] = { + 58, 166, 173, 205, 83, 134, + }; uint8_t output_data[data_length]; - tflite::testing::QuantizeInputData(input_data, data_length, quantized_input); - - tflite::testing::TestL2Normalization( - input_dims, quantized_input, expected_output, output_data, .1); + tflite::testing::TestL2Normalization(input_dims, input_data, + 
expected_output, output_data); } TF_LITE_MICRO_TEST(SimpleInt8Test) { const int input_dims[] = {4, 1, 1, 1, 6}; constexpr int data_length = 6; - float input_data[data_length] = {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}; - float expected_output[data_length] = {-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}; - int8_t quantized_input[data_length]; + const int8_t input_data[data_length] = {-71, 37, 44, 76, -46, 5}; + const int8_t expected_output[data_length] = {-70, 38, 45, 77, -45, 6}; int8_t output_data[data_length]; - tflite::testing::QuantizeInputData(input_data, data_length, quantized_input); - - tflite::testing::TestL2Normalization( - input_dims, quantized_input, expected_output, output_data, .1); + tflite::testing::TestL2Normalization(input_dims, input_data, + expected_output, output_data); } TF_LITE_MICRO_TEST(ZerosVectorInt8Test) { const int input_dims[] = {4, 1, 1, 1, 6}; constexpr int data_length = 6; - const float input_data[data_length] = {0}; - const float expected_output_data[data_length] = {0}; - int8_t quantized_input[data_length]; + const int8_t input_data[data_length] = {-1, -1, -1, -1, -1, -1}; + const int8_t expected_output[data_length] = {0, 0, 0, 0, 0, 0}; int8_t output_data[data_length]; - tflite::testing::QuantizeInputData(input_data, data_length, quantized_input); - - tflite::testing::TestL2Normalization( - input_dims, quantized_input, expected_output_data, output_data, .1); + tflite::testing::TestL2Normalization(input_dims, input_data, + expected_output, output_data); } TF_LITE_MICRO_TEST(MultipleBatchUint8Test) { - const int input_dims[] = {4, 1, 1, 1, 6}; + const int input_dims[] = {2, 3, 6}; constexpr int data_length = 18; - float input_data[data_length] = { - -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1 - -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2 - -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3 + const uint8_t input_data[data_length] = { + 57, 165, 172, 204, 82, 133, // batch 1 + 57, 165, 172, 204, 82, 133, // batch 2 + 57, 165, 172, 204, 82, 133, // batch 3 }; - float expected_output[data_length] = { - -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1 - -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2 - -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3 + const uint8_t expected_output[data_length] = { + 58, 166, 173, 205, 83, 134, // batch 1 + 58, 166, 173, 205, 83, 134, // batch 2 + 58, 166, 173, 205, 83, 134, // batch 3 }; - uint8_t quantized_input[data_length]; uint8_t output_data[data_length]; - tflite::testing::QuantizeInputData(input_data, data_length, quantized_input); - - tflite::testing::TestL2Normalization( - input_dims, quantized_input, expected_output, output_data, .1); + tflite::testing::TestL2Normalization(input_dims, input_data, + expected_output, output_data); } TF_LITE_MICRO_TEST(MultipleBatchInt8Test) { - const int input_dims[] = {4, 1, 1, 1, 6}; + const int input_dims[] = {2, 3, 6}; constexpr int data_length = 18; - float input_data[data_length] = { - -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1 - -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2 - -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3 + const int8_t input_data[data_length] = { + -71, 37, 44, 76, -46, 5, // batch 1 + -71, 37, 44, 76, -46, 5, // batch 2 + -71, 37, 44, 76, -46, 5, // batch 3 }; - float expected_output[data_length] = { - -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1 - -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2 - -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3 + const int8_t expected_output[data_length] = { + -70, 38, 45, 77, -45, 6, // batch 1 + -70, 38, 45, 77, -45, 6, // batch 2 + -70, 38, 45, 77, -45, 6, // 
batch 3 }; - int8_t quantized_input[data_length]; int8_t output_data[data_length]; - tflite::testing::QuantizeInputData(input_data, data_length, quantized_input); - - tflite::testing::TestL2Normalization( - input_dims, quantized_input, expected_output, output_data, .1); + tflite::testing::TestL2Normalization(input_dims, input_data, + expected_output, output_data); } TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/logical_test.cc b/tensorflow/lite/micro/kernels/logical_test.cc index d5355c830b6..67606e772e4 100644 --- a/tensorflow/lite/micro/kernels/logical_test.cc +++ b/tensorflow/lite/micro/kernels/logical_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/logistic_test.cc b/tensorflow/lite/micro/kernels/logistic_test.cc index b1e56df3161..7ba2dd8f52f 100644 --- a/tensorflow/lite/micro/kernels/logistic_test.cc +++ b/tensorflow/lite/micro/kernels/logistic_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/maximum_minimum_test.cc b/tensorflow/lite/micro/kernels/maximum_minimum_test.cc index 7fab5407cdb..0db93ff18cb 100644 --- a/tensorflow/lite/micro/kernels/maximum_minimum_test.cc +++ b/tensorflow/lite/micro/kernels/maximum_minimum_test.cc @@ -17,8 +17,8 @@ limitations under the License. 
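The l2norm test rewrite above supplies already-quantized uint8/int8 data and derives the tensor quantization from the (kInputMin, kInputMax) and (kOutputMin, kOutputMax) ranges via ScaleFromMinMax / ZeroPointFromMinMax. A hedged sketch of the usual asymmetric mapping those helpers are based on (my own arithmetic and function name, not the test_helpers code); it also shows why a zero vector quantizes to 128 for uint8 and 0 for int8 in the tests above:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

// scale maps [min, max] onto the full integer range of T; the zero point is
// the integer code that represents 0.0f.
template <typename T>
void ScaleAndZeroPointFromMinMax(float min, float max, float* scale,
                                 int* zero_point) {
  const float qmin = std::numeric_limits<T>::min();
  const float qmax = std::numeric_limits<T>::max();
  *scale = (max - min) / (qmax - qmin);
  *zero_point = static_cast<int>(std::round(qmin - min / *scale));
}

int main() {
  float scale;
  int zero_point;
  // Output range used by the l2norm tests: [-1, 127/128].
  ScaleAndZeroPointFromMinMax<uint8_t>(-1.0f, 127.0f / 128.0f, &scale,
                                       &zero_point);
  printf("uint8 scale=%f zero_point=%d\n", scale, zero_point);  // 1/128, 128
  ScaleAndZeroPointFromMinMax<int8_t>(-1.0f, 127.0f / 128.0f, &scale,
                                      &zero_point);
  printf("int8 scale=%f zero_point=%d\n", scale, zero_point);  // 1/128, 0
  return 0;
}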
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { @@ -104,14 +104,11 @@ void TestMaxMinQuantized(const TfLiteRegistration& registration, } } -void TestMaxMinQuantizedInt32(const TfLiteRegistration& registration, - const int* input1_dims_data, - const int32_t* input1_data, float input1_scale, - const int* input2_dims_data, - const int32_t* input2_data, float input2_scale, - const int32_t* expected_output_data, - float output_scale, const int* output_dims_data, - int32_t* output_data) { +void TestMaxMinQuantizedInt32( + const TfLiteRegistration& registration, const int* input1_dims_data, + const int32_t* input1_data, const int* input2_dims_data, + const int32_t* input2_data, const int32_t* expected_output_data, + const int* output_dims_data, int32_t* output_data) { TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data); TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data); TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); @@ -121,9 +118,9 @@ void TestMaxMinQuantizedInt32(const TfLiteRegistration& registration, constexpr int outputs_size = 1; constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { - CreateQuantized32Tensor(input1_data, input1_dims, input1_scale), - CreateQuantized32Tensor(input2_data, input2_dims, input2_scale), - CreateQuantized32Tensor(output_data, output_dims, output_scale), + CreateInt32Tensor(input1_data, input1_dims), + CreateInt32Tensor(input2_data, input2_dims), + CreateInt32Tensor(output_data, output_dims), }; int inputs_array_data[] = {2, 0, 1}; @@ -210,9 +207,6 @@ TF_LITE_MICRO_TEST(FloatWithBroadcastTest) { } TF_LITE_MICRO_TEST(Int32WithBroadcastTest) { - const float input1_scale = 0.5; - const float input2_scale = 0.5; - const float output_scale = 0.5; const int dims[] = {3, 3, 1, 2}; const int dims_scalar[] = {1, 1}; const int32_t data1[] = {1, 0, -1, -2, 3, 11}; @@ -222,14 +216,12 @@ TF_LITE_MICRO_TEST(Int32WithBroadcastTest) { int32_t output_data[6]; tflite::testing::TestMaxMinQuantizedInt32( - tflite::ops::micro::Register_MAXIMUM(), dims, data1, input1_scale, - dims_scalar, data2, input2_scale, golden_max, output_scale, dims, - output_data); + tflite::ops::micro::Register_MAXIMUM(), dims, data1, dims_scalar, data2, + golden_max, dims, output_data); tflite::testing::TestMaxMinQuantizedInt32( - tflite::ops::micro::Register_MINIMUM(), dims, data1, input1_scale, - dims_scalar, data2, input2_scale, golden_min, output_scale, dims, - output_data); + tflite::ops::micro::Register_MINIMUM(), dims, data1, dims_scalar, data2, + golden_min, dims, output_data); } TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/micro_ops.h b/tensorflow/lite/micro/kernels/micro_ops.h index f86b28a9ff2..a65fc4f6a15 100644 --- a/tensorflow/lite/micro/kernels/micro_ops.h +++ b/tensorflow/lite/micro/kernels/micro_ops.h @@ -17,10 +17,6 @@ limitations under the License. #include "tensorflow/lite/c/common.h" -namespace tflite { -namespace ops { -namespace micro { - // Forward declaration of all micro op kernel registration methods. These // registrations are included with the standard `BuiltinOpResolver`. // @@ -29,6 +25,22 @@ namespace micro { // their model requires, using a custom `(Micro)MutableOpResolver`. 
Selective // registration in turn allows the linker to strip unused kernels. +namespace tflite { + +// TFLM is incrementally moving towards a flat tflite namespace +// (https://abseil.io/tips/130). Any new ops (or cleanup of existing ops should +// have their Register function declarations in the tflite namespace. + +TfLiteRegistration Register_CONV_2D(); +TfLiteRegistration Register_DEPTHWISE_CONV_2D(); +TfLiteRegistration Register_QUANTIZE(); +TfLiteRegistration Register_SHAPE(); +TfLiteRegistration Register_SOFTMAX(); +TfLiteRegistration Register_SVDF(); + +namespace ops { +namespace micro { + TfLiteRegistration Register_ABS(); TfLiteRegistration Register_ADD(); TfLiteRegistration Register_ARG_MAX(); @@ -37,14 +49,11 @@ TfLiteRegistration Register_AVERAGE_POOL_2D(); TfLiteRegistration Register_CEIL(); // TODO(b/160234179): Change custom OPs to also return by value. TfLiteRegistration* Register_CIRCULAR_BUFFER(); -TfLiteRegistration Register_CONV_2D(); TfLiteRegistration Register_CONCATENATION(); TfLiteRegistration Register_COS(); -TfLiteRegistration Register_DEPTHWISE_CONV_2D(); TfLiteRegistration Register_DEQUANTIZE(); TfLiteRegistration Register_EQUAL(); TfLiteRegistration Register_FLOOR(); -TfLiteRegistration Register_FULLY_CONNECTED(); TfLiteRegistration Register_GREATER(); TfLiteRegistration Register_GREATER_EQUAL(); TfLiteRegistration Register_HARD_SWISH(); @@ -66,7 +75,6 @@ TfLiteRegistration Register_PACK(); TfLiteRegistration Register_PAD(); TfLiteRegistration Register_PADV2(); TfLiteRegistration Register_PRELU(); -TfLiteRegistration Register_QUANTIZE(); TfLiteRegistration Register_REDUCE_MAX(); TfLiteRegistration Register_RELU(); TfLiteRegistration Register_RELU6(); @@ -75,14 +83,12 @@ TfLiteRegistration Register_RESIZE_NEAREST_NEIGHBOR(); TfLiteRegistration Register_ROUND(); TfLiteRegistration Register_RSQRT(); TfLiteRegistration Register_SIN(); -TfLiteRegistration Register_SOFTMAX(); TfLiteRegistration Register_SPLIT(); TfLiteRegistration Register_SPLIT_V(); TfLiteRegistration Register_SQRT(); TfLiteRegistration Register_SQUARE(); TfLiteRegistration Register_STRIDED_SLICE(); TfLiteRegistration Register_SUB(); -TfLiteRegistration Register_SVDF(); TfLiteRegistration Register_UNPACK(); TfLiteRegistration Register_L2_NORMALIZATION(); TfLiteRegistration Register_TANH(); diff --git a/tensorflow/lite/micro/kernels/mul_test.cc b/tensorflow/lite/micro/kernels/mul_test.cc index 86b4d8be57c..8503a1502d1 100644 --- a/tensorflow/lite/micro/kernels/mul_test.cc +++ b/tensorflow/lite/micro/kernels/mul_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/neg_test.cc b/tensorflow/lite/micro/kernels/neg_test.cc index 544a3eddc1c..f3c0e7d36a8 100644 --- a/tensorflow/lite/micro/kernels/neg_test.cc +++ b/tensorflow/lite/micro/kernels/neg_test.cc @@ -17,8 +17,8 @@ limitations under the License. 
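After the micro_ops.h change above, the migrated kernels (CONV_2D, DEPTHWISE_CONV_2D, QUANTIZE, SHAPE, SOFTMAX, SVDF) register from the flat tflite namespace while the remaining ops stay under tflite::ops::micro during the transition. A small sketch of what mixed call sites look like in the meantime (CollectExampleRegistrations is a made-up function):

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"

// Migrated ops come from the flat tflite namespace; ops that have not yet been
// ported keep their tflite::ops::micro registration for now.
void CollectExampleRegistrations() {
  const TfLiteRegistration conv = tflite::Register_CONV_2D();
  const TfLiteRegistration softmax = tflite::Register_SOFTMAX();
  const TfLiteRegistration add = tflite::ops::micro::Register_ADD();  // legacy
  (void)conv;
  (void)softmax;
  (void)add;
}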
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/pack_test.cc b/tensorflow/lite/micro/kernels/pack_test.cc index c05595df146..5ac80d698b5 100644 --- a/tensorflow/lite/micro/kernels/pack_test.cc +++ b/tensorflow/lite/micro/kernels/pack_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/debug_log.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { @@ -168,9 +168,9 @@ void TestPackTwoInputsQuantized32(const int* input1_dims_data, constexpr int output_size = 1; constexpr int tensors_size = input_size + output_size; TfLiteTensor tensors[tensors_size] = { - CreateQuantized32Tensor(input1_data, input1_dims, 1.0), - CreateQuantized32Tensor(input2_data, input2_dims, 1.0), - CreateQuantized32Tensor(output_data, output_dims, 1.0)}; + CreateInt32Tensor(input1_data, input1_dims), + CreateInt32Tensor(input2_data, input2_dims), + CreateInt32Tensor(output_data, output_dims)}; TfLitePackParams builtin_data = { .values_count = 2, diff --git a/tensorflow/lite/micro/kernels/pad_test.cc b/tensorflow/lite/micro/kernels/pad_test.cc index 4d391057858..e94bc993fea 100644 --- a/tensorflow/lite/micro/kernels/pad_test.cc +++ b/tensorflow/lite/micro/kernels/pad_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/lite/micro/kernels/kernel_runner.h" #include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/pooling_test.cc b/tensorflow/lite/micro/kernels/pooling_test.cc index a33f5df6fd4..9782b49ad98 100644 --- a/tensorflow/lite/micro/kernels/pooling_test.cc +++ b/tensorflow/lite/micro/kernels/pooling_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/prelu_test.cc b/tensorflow/lite/micro/kernels/prelu_test.cc index f559ddff993..3a0b10a0d94 100644 --- a/tensorflow/lite/micro/kernels/prelu_test.cc +++ b/tensorflow/lite/micro/kernels/prelu_test.cc @@ -16,8 +16,8 @@ limitations under the License. 
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/quantize.cc b/tensorflow/lite/micro/kernels/quantize.cc index a5715bcf981..81f5f073b48 100644 --- a/tensorflow/lite/micro/kernels/quantize.cc +++ b/tensorflow/lite/micro/kernels/quantize.cc @@ -23,9 +23,7 @@ limitations under the License. #include "tensorflow/lite/micro/micro_utils.h" namespace tflite { -namespace ops { -namespace micro { -namespace quantize { +namespace { struct OpData { tflite::QuantizationParams quantization_params; @@ -175,22 +173,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -} // namespace quantize +} // namespace -// This Op (QUANTIZE) quantizes the input and produces quantized output. -// AffineQuantize takes scale and zero point and quantizes the float value to -// quantized output, in int8_t or uint8_t format. TfLiteRegistration Register_QUANTIZE() { - return {/*init=*/quantize::Init, + return {/*init=*/Init, /*free=*/nullptr, - /*prepare=*/quantize::Prepare, - /*invoke=*/quantize::Eval, + /*prepare=*/Prepare, + /*invoke=*/Eval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/quantize_test.cc b/tensorflow/lite/micro/kernels/quantize_test.cc index b92758d3b93..b630fb53bca 100644 --- a/tensorflow/lite/micro/kernels/quantize_test.cc +++ b/tensorflow/lite/micro/kernels/quantize_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/lite/micro/kernels/kernel_runner.h" #include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { @@ -35,8 +34,7 @@ void ValidateQuantizeGoldens(TfLiteTensor* tensors, int tensors_size, TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); // Version 1 of quantize supports int8_t and uint8_t quantization. - const TfLiteRegistration registration = - tflite::ops::micro::Register_QUANTIZE(); + const TfLiteRegistration registration = Register_QUANTIZE(); micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, outputs_array, /*builtin_data=*/nullptr, micro_test::reporter); diff --git a/tensorflow/lite/micro/kernels/reduce_test.cc b/tensorflow/lite/micro/kernels/reduce_test.cc index e7be6569ca7..fdb8fe95466 100644 --- a/tensorflow/lite/micro/kernels/reduce_test.cc +++ b/tensorflow/lite/micro/kernels/reduce_test.cc @@ -17,8 +17,8 @@ limitations under the License. 
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/reshape_test.cc b/tensorflow/lite/micro/kernels/reshape_test.cc index 91ecbdc7a49..48d1956f1c8 100644 --- a/tensorflow/lite/micro/kernels/reshape_test.cc +++ b/tensorflow/lite/micro/kernels/reshape_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/lite/micro/micro_utils.h" #include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { @@ -113,22 +112,41 @@ void TestReshapeWithoutShape(TfLiteTensor* input_tensor, expected_dims_len, expect_failure); } -template <typename T> -void TestReshape(const int* input_dims_data, const T* input_data, +void TestReshape(const int* input_dims_data, const float* input_data, const int* shape_dims_data, const int32_t* shape_data, - int* output_dims_data, T* output_data, - const T* expected_output, const size_t expected_output_len, + int* output_dims_data, float* output_data, + const float* expected_output, const size_t expected_output_len, const int* expected_dims, const size_t expected_dims_len, bool expect_failure = false) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* shape_dims = IntArrayFromInts(shape_dims_data); TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); - TfLiteTensor input_tensor = - CreateTensor(input_data, input_dims); - TfLiteTensor shape_tensor = - CreateTensor(shape_data, shape_dims); - TfLiteTensor output_tensor = - CreateTensor(output_data, output_dims); + TfLiteTensor input_tensor = CreateFloatTensor(input_data, input_dims); + TfLiteTensor shape_tensor = CreateInt32Tensor(shape_data, shape_dims); + TfLiteTensor output_tensor = CreateFloatTensor(output_data, output_dims); + + TestReshapeWithShape(&input_tensor, &shape_tensor, &output_tensor, + expected_output, expected_output_len, expected_dims, + expected_dims_len, expect_failure); +} + +template <typename T> +void TestReshapeQuantized(const int* input_dims_data, const T* input_data, + const int* shape_dims_data, const int32_t* shape_data, + int* output_dims_data, T* output_data, + const T* expected_output, + const size_t expected_output_len, + const int* expected_dims, + const size_t expected_dims_len, + bool expect_failure = false) { + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* shape_dims = IntArrayFromInts(shape_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + TfLiteTensor input_tensor = CreateQuantizedTensor( + input_data, input_dims, /*scale=*/1.f, /*zero_point=*/0); + TfLiteTensor shape_tensor = CreateInt32Tensor(shape_data, shape_dims); + TfLiteTensor output_tensor = CreateQuantizedTensor( + output_data, output_dims, /*scale=*/1.f, /*zero_point=*/0); TestReshapeWithShape(&input_tensor, &shape_tensor, &output_tensor, expected_output, expected_output_len, expected_dims, @@ -233,11 +251,11 @@ TF_LITE_MICRO_TEST(ReshapeWithRegularShapesShouldSucceed) { output_dims, output_data_float, golden_output_float, golden_output_len, golden_dims, golden_dims_len, false); - tflite::testing::TestReshape( + tflite::testing::TestReshapeQuantized( input_dims,
input_int8, shape_dims, shape_int32, output_dims, output_data_int8, golden_output_int8, golden_output_len, golden_dims, golden_dims_len, false); - tflite::testing::TestReshape( + tflite::testing::TestReshapeQuantized( input_dims, input_uint8, shape_dims, shape_int32, output_dims, output_data_uint8, golden_output_uint8, golden_output_len, golden_dims, golden_dims_len, false); @@ -265,11 +283,11 @@ TF_LITE_MICRO_TEST(ReshapeWithStretchDimensionShouldSucceed) { output_dims, output_data_float, golden_output_float, golden_output_len, golden_dims, golden_dims_len, false); - tflite::testing::TestReshape( + tflite::testing::TestReshapeQuantized( input_dims, input_int8, shape_dims, shape_int32, output_dims, output_data_int8, golden_output_int8, golden_output_len, golden_dims, golden_dims_len, false); - tflite::testing::TestReshape( + tflite::testing::TestReshapeQuantized( input_dims, input_uint8, shape_dims, shape_int32, output_dims, output_data_uint8, golden_output_uint8, golden_output_len, golden_dims, golden_dims_len, false); @@ -297,11 +315,11 @@ TF_LITE_MICRO_TEST(ReshapeWithScalarOutputShouldSucceed) { output_dims, output_data_float, golden_output_float, golden_output_len, golden_dims, golden_dims_len, false); - tflite::testing::TestReshape( + tflite::testing::TestReshapeQuantized( input_dims, input_int8, shape_dims, shape_int32, output_dims, output_data_int8, golden_output_int8, golden_output_len, golden_dims, golden_dims_len, false); - tflite::testing::TestReshape( + tflite::testing::TestReshapeQuantized( input_dims, input_uint8, shape_dims, shape_int32, output_dims, output_data_uint8, golden_output_uint8, golden_output_len, golden_dims, golden_dims_len, false); @@ -327,8 +345,8 @@ TF_LITE_MICRO_TEST(ReshapeWithLegacyScalarOutputShouldSucceed) { TfLiteIntArray* shape_dims = IntArrayFromInts(shape_dims_data); const int32_t shape_data[] = {0}; - auto shape_tensor = tflite::testing::CreateTensor( - shape_data, shape_dims); + auto shape_tensor = + tflite::testing::CreateInt32Tensor(shape_data, shape_dims); const float expected_output_with_shape[] = {}; const int expected_output_with_shape_len = 0; const float expected_output_no_shape[] = {3}; diff --git a/tensorflow/lite/micro/kernels/resize_nearest_neighbor_test.cc b/tensorflow/lite/micro/kernels/resize_nearest_neighbor_test.cc index bd65077f23a..9362a89a3ed 100644 --- a/tensorflow/lite/micro/kernels/resize_nearest_neighbor_test.cc +++ b/tensorflow/lite/micro/kernels/resize_nearest_neighbor_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/round_test.cc b/tensorflow/lite/micro/kernels/round_test.cc index 7048a98f52b..8067d8cd091 100644 --- a/tensorflow/lite/micro/kernels/round_test.cc +++ b/tensorflow/lite/micro/kernels/round_test.cc @@ -17,8 +17,8 @@ limitations under the License. 
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/shape.cc b/tensorflow/lite/micro/kernels/shape.cc new file mode 100755 index 00000000000..df962f629cd --- /dev/null +++ b/tensorflow/lite/micro/kernels/shape.cc @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/memory_helpers.h" +#include "tensorflow/lite/micro/micro_utils.h" + +namespace tflite { + +namespace { +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +void ExtractShape(const TfLiteEvalTensor* input, int32_t* output_data) { + for (int i = 0; i < input->dims->size; ++i) { + output_data[i] = input->dims->data[i]; + } +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteEvalTensor* input = + tflite::micro::GetEvalInput(context, node, kInputTensor); + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kOutputTensor); + if (output->type != kTfLiteInt32) { + TF_LITE_KERNEL_LOG(context, "Output type %s (%d) not supported.", + TfLiteTypeGetName(output->type), output->type); + return kTfLiteError; + } else { + ExtractShape(input, tflite::micro::GetTensorData<int32_t>(output)); + } + + return kTfLiteOk; +} + +} // namespace + +TfLiteRegistration Register_SHAPE() { + return {/*init=*/nullptr, + /*free=*/nullptr, + /*prepare=*/Prepare, + /*invoke=*/Eval, + /*profiling_string=*/nullptr, + /*builtin_code=*/0, + /*custom_name=*/nullptr, + /*version=*/0}; +} + +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/shape_test.cc b/tensorflow/lite/micro/kernels/shape_test.cc new file mode 100755 index 00000000000..7c7e0db82db --- /dev/null +++ b/tensorflow/lite/micro/kernels/shape_test.cc @@ -0,0 +1,138 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/all_ops_resolver.h" +#include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" +#include "tensorflow/lite/micro/testing/micro_test.h" + +namespace tflite { +namespace testing { +namespace { + +void ValidateShape(TfLiteTensor* tensors, const int tensor_count, + int32_t* output_data, const int32_t* expected_output, + int output_dims_count) { + int inputs_array_data[] = {1, 0}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 1}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + + const TfLiteRegistration registration = tflite::Register_SHAPE(); + micro::KernelRunner runner(registration, tensors, tensor_count, inputs_array, + outputs_array, nullptr, micro_test::reporter); + + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); + + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_EQ(expected_output[i], output_data[i]); + } +} + +void TestShape(const int* input_dims_data, const float* input_data, + const int* output_dims_data, const int32_t* expected_output_data, + int32_t* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + constexpr int inputs_size = 1; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateFloatTensor(input_data, input_dims), + CreateInt32Tensor(output_data, output_dims, true), + }; + + ValidateShape(tensors, tensors_size, output_data, expected_output_data, + output_dims_count); +} + +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(TestShape0) { + int input_shape[] = {1, 5}; + float input_values[] = {1, 3, 1, 3, 5}; + int output_dims[] = {1, 1}; // this is actually input_shapes shape + int32_t expected_output_data[] = {5}; + int32_t output_data[1]; + + tflite::testing::TestShape(input_shape, input_values, output_dims, + expected_output_data, output_data); +} + +TF_LITE_MICRO_TEST(TestShape1) { + int input_shape[] = {2, 4, 3}; + float input_values[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + int output_dims[] = {2, 1, 1}; + int32_t expected_output_data[] = {4, 3}; + int32_t output_data[2]; + + tflite::testing::TestShape(input_shape, input_values, output_dims, + expected_output_data, output_data); +} + +TF_LITE_MICRO_TEST(TestShape2) { + int input_shape[] = {2, 12, 1}; + float input_values[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + int output_dims[] = {2, 1, 1}; + int32_t expected_output_data[] = {12, 1}; + int32_t output_data[2]; + + tflite::testing::TestShape(input_shape, input_values, output_dims, + expected_output_data, output_data); +} + 
+TF_LITE_MICRO_TEST(TestShape3) { + int input_shape[] = {2, 2, 6}; + float input_values[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + int output_dims[] = {2, 1, 1}; + int32_t expected_output_data[] = {2, 6}; + int32_t output_data[2]; + + tflite::testing::TestShape(input_shape, input_values, output_dims, + expected_output_data, output_data); +} + +TF_LITE_MICRO_TEST(TestShape4) { + int input_shape[] = {2, 2, 2, 3}; + float input_values[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + int output_dims[] = {3, 1, 1, 1}; + int32_t expected_output_data[] = {2, 2, 3}; + int32_t output_data[3]; + + tflite::testing::TestShape(input_shape, input_values, output_dims, + expected_output_data, output_data); +} + +TF_LITE_MICRO_TEST(TestShape5) { + int input_shape[] = {1, 1}; + float input_values[] = {1}; + int output_dims[] = {1, 1}; + int32_t expected_output_data[] = {1}; + int32_t output_data[1]; + + tflite::testing::TestShape(input_shape, input_values, output_dims, + expected_output_data, output_data); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/softmax.cc b/tensorflow/lite/micro/kernels/softmax.cc index 5a26c4e4caa..c96fa561c7c 100644 --- a/tensorflow/lite/micro/kernels/softmax.cc +++ b/tensorflow/lite/micro/kernels/softmax.cc @@ -25,9 +25,6 @@ limitations under the License. #include "tensorflow/lite/micro/kernels/kernel_util.h" namespace tflite { -namespace ops { -namespace micro { -namespace activations { namespace { // Softmax parameter data that persists in user_data @@ -92,8 +89,6 @@ TfLiteStatus CalculateSoftmaxParams(TfLiteContext* context, return kTfLiteOk; } -} // namespace - // Takes a tensor and performs softmax along the last dimension. void SoftmaxFloat(const TfLiteEvalTensor* input, TfLiteEvalTensor* output, const SoftmaxParams& op_data) { @@ -212,19 +207,17 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteError; } } -} // namespace activations +} // namespace TfLiteRegistration Register_SOFTMAX() { - return {/*init=*/activations::SoftmaxInit, + return {/*init=*/SoftmaxInit, /*free=*/nullptr, - /*prepare=*/activations::SoftmaxPrepare, - /*invoke=*/activations::SoftmaxEval, + /*prepare=*/SoftmaxPrepare, + /*invoke=*/SoftmaxEval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/softmax_test.cc b/tensorflow/lite/micro/kernels/softmax_test.cc index 808ea9396ba..21fc1074760 100644 --- a/tensorflow/lite/micro/kernels/softmax_test.cc +++ b/tensorflow/lite/micro/kernels/softmax_test.cc @@ -257,8 +257,7 @@ void ValidateSoftmaxGoldens(TfLiteTensor* tensors, const int tensor_count, int outputs_array_data[] = {1, 1}; TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); - const TfLiteRegistration registration = - tflite::ops::micro::Register_SOFTMAX(); + const TfLiteRegistration registration = Register_SOFTMAX(); micro::KernelRunner runner(registration, tensors, tensor_count, inputs_array, outputs_array, &builtin_data, micro_test::reporter); diff --git a/tensorflow/lite/micro/kernels/split_test.cc b/tensorflow/lite/micro/kernels/split_test.cc index 711e807b2e8..cd9a90804e0 100644 --- a/tensorflow/lite/micro/kernels/split_test.cc +++ b/tensorflow/lite/micro/kernels/split_test.cc @@ -18,8 +18,8 @@ limitations under the License. 
#include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/debug_log.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { @@ -42,7 +42,7 @@ void TestSplitTwoOutputsFloat( constexpr int axis_size = 1; constexpr int tensors_size = input_size + output_size + axis_size; TfLiteTensor tensors[tensors_size] = { - CreateQuantized32Tensor(axis_data, axis_dims, 1.0), + CreateInt32Tensor(axis_data, axis_dims), CreateFloatTensor(input_data, input_dims), CreateFloatTensor(output1_data, output1_dims), CreateFloatTensor(output2_data, output2_dims)}; @@ -104,7 +104,7 @@ void TestSplitFourOutputsFloat( constexpr int axis_size = 1; constexpr int tensors_size = input_size + output_size + axis_size; TfLiteTensor tensors[tensors_size] = { - CreateQuantized32Tensor(axis_data, axis_dims, 1.0), + CreateInt32Tensor(axis_data, axis_dims), CreateFloatTensor(input_data, input_dims), CreateFloatTensor(output1_data, output1_dims), CreateFloatTensor(output2_data, output2_dims), @@ -171,7 +171,7 @@ void TestSplitTwoOutputsQuantized( constexpr int axis_size = 1; constexpr int tensors_size = input_size + output_size + axis_size; TfLiteTensor tensors[tensors_size] = { - CreateQuantized32Tensor(axis_data, axis_dims, 1.0), + CreateInt32Tensor(axis_data, axis_dims), CreateQuantizedTensor(input_data, input_dims, 0, 10), CreateQuantizedTensor(output1_data, output1_dims, 0, 10), CreateQuantizedTensor(output2_data, output2_dims, 0, 10)}; @@ -227,10 +227,10 @@ void TestSplitTwoOutputsQuantized32( constexpr int axis_size = 1; constexpr int tensors_size = input_size + output_size + axis_size; TfLiteTensor tensors[tensors_size] = { - CreateQuantized32Tensor(axis_data, axis_dims, 1.0), - CreateQuantized32Tensor(input_data, input_dims, 1.0), - CreateQuantized32Tensor(output1_data, output1_dims, 1.0), - CreateQuantized32Tensor(output2_data, output2_dims, 1.0)}; + CreateInt32Tensor(axis_data, axis_dims), + CreateInt32Tensor(input_data, input_dims), + CreateInt32Tensor(output1_data, output1_dims), + CreateInt32Tensor(output2_data, output2_dims)}; // Currently only support constant axis tensor. tensors[0].allocation_type = kTfLiteMmapRo; diff --git a/tensorflow/lite/micro/kernels/split_v_test.cc b/tensorflow/lite/micro/kernels/split_v_test.cc index 8816ecc4f5d..6a41b2b1985 100755 --- a/tensorflow/lite/micro/kernels/split_v_test.cc +++ b/tensorflow/lite/micro/kernels/split_v_test.cc @@ -18,8 +18,8 @@ limitations under the License. 
#include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/debug_log.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { @@ -64,8 +64,8 @@ void TestSplitVFloat(const int* input_dims_data, const float* input_data, TfLiteTensor tensors[tensors_size]; tensors[0] = CreateFloatTensor(input_data, input_dims); - tensors[1] = CreateQuantized32Tensor(split_data, split_dims, 1.0); - tensors[2] = CreateQuantized32Tensor(axis_data, axis_dims, 1.0); + tensors[1] = CreateInt32Tensor(split_data, split_dims); + tensors[2] = CreateInt32Tensor(axis_data, axis_dims); // add output tensors for (int i = 0; i < N; i++) diff --git a/tensorflow/lite/micro/kernels/strided_slice_test.cc b/tensorflow/lite/micro/kernels/strided_slice_test.cc index a2a472af990..a6de5bd1e59 100644 --- a/tensorflow/lite/micro/kernels/strided_slice_test.cc +++ b/tensorflow/lite/micro/kernels/strided_slice_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { @@ -75,9 +75,9 @@ void TestStridedSliceFloat(const int* input_shape, const int* begin_shape, constexpr int tensors_size = inputs_size + outputs_size; TfLiteTensor tensors[tensors_size] = { CreateFloatTensor(input_data, input_dims), - CreateQuantized32Tensor(begin_data, begin_dims, 1.0), - CreateQuantized32Tensor(end_data, end_dims, 1.0), - CreateQuantized32Tensor(strides_data, strides_dims, 1.0), + CreateInt32Tensor(begin_data, begin_dims), + CreateInt32Tensor(end_data, end_dims), + CreateInt32Tensor(strides_data, strides_dims), CreateFloatTensor(output_data, output_dims), }; @@ -106,9 +106,9 @@ void TestStridedSliceQuantized( std::numeric_limits::max() + std::numeric_limits::min() / 2; TfLiteTensor tensors[tensors_size] = { CreateQuantizedTensor(input_data, input_dims, 1.0, zero_point), - CreateQuantized32Tensor(begin_data, begin_dims, 1.0), - CreateQuantized32Tensor(end_data, end_dims, 1.0), - CreateQuantized32Tensor(strides_data, strides_dims, 1.0), + CreateInt32Tensor(begin_data, begin_dims), + CreateInt32Tensor(end_data, end_dims), + CreateInt32Tensor(strides_data, strides_dims), CreateQuantizedTensor(output_data, output_dims, 1.0, zero_point), }; diff --git a/tensorflow/lite/micro/kernels/sub_test.cc b/tensorflow/lite/micro/kernels/sub_test.cc index fdfe4234c64..1cc0c80527b 100644 --- a/tensorflow/lite/micro/kernels/sub_test.cc +++ b/tensorflow/lite/micro/kernels/sub_test.cc @@ -18,8 +18,8 @@ limitations under the License. 
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/svdf.cc b/tensorflow/lite/micro/kernels/svdf.cc index 077e2323744..764fdc1bf06 100644 --- a/tensorflow/lite/micro/kernels/svdf.cc +++ b/tensorflow/lite/micro/kernels/svdf.cc @@ -27,9 +27,6 @@ limitations under the License. #include "tensorflow/lite/micro/micro_utils.h" namespace tflite { -namespace ops { -namespace micro { -namespace svdf { namespace { struct OpData { @@ -47,6 +44,17 @@ struct OpData { int output_zero_point; }; +// Input tensors. +constexpr int kInputTensor = 0; +constexpr int kWeightsFeatureTensor = 1; +constexpr int kWeightsTimeTensor = 2; +constexpr int kBiasTensor = 3; +// This is a variable tensor, and will be modified by this op. +constexpr int kInputActivationStateTensor = 4; + +// Output tensor. +constexpr int kOutputTensor = 0; + /** * This version of SVDF is specific to TFLite Micro. It contains the following * differences between the TFLite version: @@ -112,7 +120,8 @@ static inline void ApplyTimeWeightsBiasAndActivation( for (int b = 0; b < batch_size; ++b) { float* output_ptr_batch = output_ptr + b * num_units; for (int i = 0; i < num_units; ++i) { - *output_ptr_batch = ActivationValFloat(activation, *output_ptr_batch); + *output_ptr_batch = + tflite::ops::micro::ActivationValFloat(activation, *output_ptr_batch); ++output_ptr_batch; } } @@ -335,19 +344,6 @@ void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node, } } -} // namespace - -// Input tensors. -constexpr int kInputTensor = 0; -constexpr int kWeightsFeatureTensor = 1; -constexpr int kWeightsTimeTensor = 2; -constexpr int kBiasTensor = 3; -// This is a variable tensor, and will be modified by this op. -constexpr int kInputActivationStateTensor = 4; - -// Output tensor. 
-constexpr int kOutputTensor = 0; - void* Init(TfLiteContext* context, const char* buffer, size_t length) { TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); return context->AllocatePersistentBuffer(context, sizeof(OpData)); @@ -535,19 +531,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -} // namespace svdf +} // namespace TfLiteRegistration Register_SVDF() { - return {/*init=*/svdf::Init, + return {/*init=*/Init, /*free=*/nullptr, - /*prepare=*/svdf::Prepare, - /*invoke=*/svdf::Eval, + /*prepare=*/Prepare, + /*invoke=*/Eval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/svdf_test.cc b/tensorflow/lite/micro/kernels/svdf_test.cc index 8999d8fb4a0..771ff66a4b7 100644 --- a/tensorflow/lite/micro/kernels/svdf_test.cc +++ b/tensorflow/lite/micro/kernels/svdf_test.cc @@ -497,7 +497,7 @@ void ValidateSVDFGoldens(const int batch_size, const int num_units, int outputs_array_data[] = {1, 5}; TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); - const TfLiteRegistration registration = tflite::ops::micro::Register_SVDF(); + const TfLiteRegistration registration = Register_SVDF(); micro::KernelRunner runner(registration, tensors, tensor_count, inputs_array, outputs_array, ¶ms, micro_test::reporter); diff --git a/tensorflow/lite/micro/kernels/tanh_test.cc b/tensorflow/lite/micro/kernels/tanh_test.cc index ef1564f4675..4a4f94bc2e5 100644 --- a/tensorflow/lite/micro/kernels/tanh_test.cc +++ b/tensorflow/lite/micro/kernels/tanh_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/kernels/unpack_test.cc b/tensorflow/lite/micro/kernels/unpack_test.cc index 5b2c36cdf3f..b5c17bd8d2f 100644 --- a/tensorflow/lite/micro/kernels/unpack_test.cc +++ b/tensorflow/lite/micro/kernels/unpack_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/debug_log.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" namespace tflite { namespace testing { @@ -222,10 +222,10 @@ void TestUnpackThreeOutputsQuantized32( constexpr int output_size = 3; constexpr int tensors_size = input_size + output_size; TfLiteTensor tensors[tensors_size] = { - CreateQuantized32Tensor(input_data, input_dims, 1.0), - CreateQuantized32Tensor(output1_data, output1_dims, 1.0), - CreateQuantized32Tensor(output2_data, output2_dims, 1.0), - CreateQuantized32Tensor(output3_data, output3_dims, 1.0)}; + CreateInt32Tensor(input_data, input_dims), + CreateInt32Tensor(output1_data, output1_dims), + CreateInt32Tensor(output2_data, output2_dims), + CreateInt32Tensor(output3_data, output3_dims)}; // Place a unique value in the uninitialized output buffer. 
for (int i = 0; i < output1_dims_count; ++i) { diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/conv.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/conv.cc index 273a2fad583..2c3577d77be 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/conv.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/conv.cc @@ -28,11 +28,37 @@ limitations under the License. #include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h" namespace tflite { -namespace ops { -namespace micro { -namespace conv { -namespace xtensa { -namespace hifimini { +namespace { + +constexpr int kInputTensor = 0; +constexpr int kFilterTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; + +// Conv is quantized along dimension 0: +// https://www.tensorflow.org/lite/performance/quantization_spec +constexpr int kConvQuantizedDimension = 0; + +struct OpData { + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + + // Cached tensor zero point values for quantized operations. + int32_t input_zero_point; + int32_t output_zero_point; + + // Per channel output multiplier and shift. + int32_t* per_channel_output_multiplier; + int32_t* per_channel_output_shift; + + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; +}; void ConvPerChannel(const ConvParams& params, const int32_t* output_multiplier, const int32_t* output_shift, @@ -56,7 +82,6 @@ void ConvPerChannel(const ConvParams& params, const int32_t* output_multiplier, const int input_height = input_shape.Dims(1); const int input_width = input_shape.Dims(2); const int input_depth = input_shape.Dims(3); - const int input_depth_iters = input_depth / 2; const int filter_height = filter_shape.Dims(1); const int filter_width = filter_shape.Dims(2); @@ -80,7 +105,7 @@ void ConvPerChannel(const ConvParams& params, const int32_t* output_multiplier, ae_q56s acc_56 = AE_ZEROQ56(); for (int filter_y = 0; filter_y < filter_height; ++filter_y) { - for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + for (int filter_x = 0; filter_x < filter_width; filter_x += 2) { const int in_x = in_x_origin + dilation_width_factor * filter_x; const int in_y = in_y_origin + dilation_height_factor * filter_y; const bool is_point_inside_image = @@ -93,10 +118,10 @@ void ConvPerChannel(const ConvParams& params, const int32_t* output_multiplier, // with intrinsics: int input_idx = ((batch * input_height + in_y) * input_width + in_x) * - input_depth - + input_depth * 2 - 2; const int8_t* input_vals_offset_ptr = input_data + input_idx; - for (int i = 0; i < input_depth_iters; ++i) { + for (int i = 0; i < input_depth; i += 2) { // Load signed 2x 8bit values and right shift into 24bit // alignment: ae_p24x2s input_vals_24x2; @@ -113,7 +138,7 @@ void ConvPerChannel(const ConvParams& params, const int32_t* output_multiplier, ((out_channel * filter_height + filter_y) * filter_width + filter_x) * filter_depth + - (i * 2) - 2; + i - 2; const int8_t* filter_vals_offset_ptr = filter_data + filter_idx; @@ -145,9 +170,10 @@ void ConvPerChannel(const ConvParams& params, const int32_t* output_multiplier, ae_p24x2s acc_24x2 = AE_TRUNCP24Q48(acc_56); // Apply quantized multiplier and accumulate result at 48bit - // alignment: - acc_56 = 
micro::xtensa::hifimini::MultiplyByQuantizedMultiplier( - acc_24x2, output_multiplier[out_channel], + // alignment. Convert the (unsigned) 32-bit multiplier down to a + // 24-bit multiplier. + acc_56 = MultiplyByQuantizedMultiplier( + acc_24x2, output_multiplier[out_channel] >> 8, output_shift[out_channel]); // Add output offset, cap activation, and assign to the output: @@ -223,8 +249,8 @@ inline void Conv1x32Input32x32Filter( // Apply quantized multiplier and accumulate result at 48bit alignment. // Convert the (unsigned) 32-bit multiplier down to a 24-bit multiplier. - acc_56 = micro::xtensa::hifimini::MultiplyByQuantizedMultiplier( - acc_24x2, output_multiplier[ch] >> 8, output_shift[ch]); + acc_56 = MultiplyByQuantizedMultiplier(acc_24x2, output_multiplier[ch] >> 8, + output_shift[ch]); // Add output offset, cap activation, and assign to the output: acc_56 = AE_ADDQ56(acc_56, output_offset_56); @@ -235,39 +261,6 @@ inline void Conv1x32Input32x32Filter( } } -} // namespace hifimini -} // namespace xtensa - -constexpr int kInputTensor = 0; -constexpr int kFilterTensor = 1; -constexpr int kBiasTensor = 2; -constexpr int kOutputTensor = 0; - -// Conv is quantized along dimension 0: -// https://www.tensorflow.org/lite/performance/quantization_spec -constexpr int kConvQuantizedDimension = 0; - -struct OpData { - TfLitePaddingValues padding; - // The scaling factor from input to output (aka the 'real multiplier') can - // be represented as a fixed point multiplier plus a left shift. - int32_t output_multiplier; - int output_shift; - - // Cached tensor zero point values for quantized operations. - int32_t input_zero_point; - int32_t output_zero_point; - - // Per channel output multiplier and shift. - int32_t* per_channel_output_multiplier; - int32_t* per_channel_output_shift; - - // The range of the fused activation layer. For example for kNone and - // uint8_t these would be 0 and 255. 
- int32_t output_activation_min; - int32_t output_activation_max; -}; - TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, TfLiteConvParams* params, int width, int height, int filter_width, int filter_height, int out_width, @@ -386,16 +379,16 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, op_params.quantized_activation_min = data->output_activation_min; op_params.quantized_activation_max = data->output_activation_max; - xtensa::hifimini::ConvPerChannel( - op_params, data->per_channel_output_multiplier, - data->per_channel_output_shift, tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData(input), - tflite::micro::GetTensorShape(filter), - tflite::micro::GetTensorData(filter), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output)); + ConvPerChannel(op_params, data->per_channel_output_multiplier, + data->per_channel_output_shift, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { @@ -420,7 +413,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (input_dims[0] == 1 && input_dims[1] == 1 && input_dims[2] == 1 && input_dims[3] == 32 && filter_dims[0] == 32 && filter_dims[1] == 1 && filter_dims[2] == 1 && filter_dims[3] == 32) { - xtensa::hifimini::Conv1x32Input32x32Filter( + Conv1x32Input32x32Filter( -op_data->input_zero_point, op_data->output_zero_point, op_data->output_activation_min, op_data->output_activation_max, op_data->per_channel_output_multiplier, @@ -447,20 +440,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } return kTfLiteOk; } - -} // namespace conv +} // namespace TfLiteRegistration Register_CONV_2D() { - return {/*init=*/conv::Init, + return {/*init=*/Init, /*free=*/nullptr, - /*prepare=*/conv::Prepare, - /*invoke=*/conv::Eval, + /*prepare=*/Prepare, + /*invoke=*/Eval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv.cc index 0d79d56ad0c..4a37becbf4d 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/depthwise_conv.cc @@ -28,11 +28,38 @@ limitations under the License. #include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h" namespace tflite { -namespace ops { -namespace micro { -namespace depthwise_conv { -namespace xtensa { -namespace hifimini { +namespace { + +constexpr int kInputTensor = 0; +constexpr int kFilterTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; + +// Depthwise conv is quantized along dimension 3: +// https://www.tensorflow.org/lite/performance/quantization_spec +constexpr int kDepthwiseConvQuantizedDimension = 3; + +struct OpData { + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. 
+ int32_t output_multiplier; + int output_shift; + + // Cached tensor zero point values for quantized operations. + int32_t input_zero_point; + int32_t output_zero_point; + + // Per channel output multiplier and shift. + // TODO(b/141139247): Allocate these dynamically when possible. + int32_t* per_channel_output_multiplier; + int32_t* per_channel_output_shift; + + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; +}; inline void DepthwiseConvPerChannel( const DepthwiseParams& params, const int32_t* output_multiplier, @@ -145,7 +172,7 @@ inline void DepthwiseConvPerChannel( // Apply quantized multiplier and accumulate result at 48bit // alignment: - acc_56 = micro::xtensa::hifimini::MultiplyByQuantizedMultiplier( + acc_56 = MultiplyByQuantizedMultiplier( acc_24x2, output_multiplier[output_channel], output_shift[output_channel]); @@ -260,12 +287,10 @@ inline void DepthwiseConv4x32MatchingInputAndFilter( // Apply quantized multiplier and accumulate result at 48bit // alignment: - block_0_acc = micro::xtensa::hifimini::MultiplyByQuantizedMultiplier( - acc_24x2_0, mult, shift); + block_0_acc = MultiplyByQuantizedMultiplier(acc_24x2_0, mult, shift); // Apply quantized multiplier and accumulate result at 48bit // alignment: - block_1_acc = micro::xtensa::hifimini::MultiplyByQuantizedMultiplier( - acc_24x2_1, mult, shift); + block_1_acc = MultiplyByQuantizedMultiplier(acc_24x2_1, mult, shift); // Add output offset, cap activation, and assign to the output: block_0_acc = AE_ADDQ56(block_0_acc, output_offset_56); @@ -280,42 +305,6 @@ inline void DepthwiseConv4x32MatchingInputAndFilter( } } -} // namespace hifimini -} // namespace xtensa - -namespace { - -constexpr int kInputTensor = 0; -constexpr int kFilterTensor = 1; -constexpr int kBiasTensor = 2; -constexpr int kOutputTensor = 0; - -// Depthwise conv is quantized along dimension 3: -// https://www.tensorflow.org/lite/performance/quantization_spec -constexpr int kDepthwiseConvQuantizedDimension = 3; - -struct OpData { - TfLitePaddingValues padding; - // The scaling factor from input to output (aka the 'real multiplier') can - // be represented as a fixed point multiplier plus a left shift. - int32_t output_multiplier; - int output_shift; - - // Cached tensor zero point values for quantized operations. - int32_t input_zero_point; - int32_t output_zero_point; - - // Per channel output multiplier and shift. - // TODO(b/141139247): Allocate these dynamically when possible. - int32_t* per_channel_output_multiplier; - int32_t* per_channel_output_shift; - - // The range of the fused activation layer. For example for kNone and - // uint8_t these would be 0 and 255. 
- int32_t output_activation_min; - int32_t output_activation_max; -}; - TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, TfLiteDepthwiseConvParams* params, int width, int height, int filter_width, int filter_height, @@ -353,8 +342,6 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, return kTfLiteOk; } -} // namespace - void* Init(TfLiteContext* context, const char* buffer, size_t length) { TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); return context->AllocatePersistentBuffer(context, sizeof(OpData)); @@ -437,16 +424,16 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, op_params.quantized_activation_min = std::numeric_limits::min(); op_params.quantized_activation_max = std::numeric_limits::max(); - xtensa::hifimini::DepthwiseConvPerChannel( - op_params, data->per_channel_output_multiplier, - data->per_channel_output_shift, tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData(input), - tflite::micro::GetTensorShape(filter), - tflite::micro::GetTensorData(filter), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output)); + DepthwiseConvPerChannel(op_params, data->per_channel_output_multiplier, + data->per_channel_output_shift, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { @@ -473,7 +460,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (input_dims[0] == 1 && input_dims[1] == 4 && input_dims[2] == 1 && input_dims[3] == 32 && filter_dims[0] == 1 && filter_dims[1] == 4 && filter_dims[2] == 1 && filter_dims[3] == 32) { - xtensa::hifimini::DepthwiseConv4x32MatchingInputAndFilter( + DepthwiseConv4x32MatchingInputAndFilter( -op_data->input_zero_point, op_data->output_zero_point, std::numeric_limits::min(), std::numeric_limits::max(), op_data->per_channel_output_multiplier, @@ -500,19 +487,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -} // namespace depthwise_conv +} // namespace TfLiteRegistration Register_DEPTHWISE_CONV_2D() { - return {/*init=*/depthwise_conv::Init, + return {/*init=*/Init, /*free=*/nullptr, - /*prepare=*/depthwise_conv::Prepare, - /*invoke=*/depthwise_conv::Eval, + /*prepare=*/Prepare, + /*invoke=*/Eval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h b/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h index 918192c4d8f..a1d14df1352 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h @@ -25,19 +25,13 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/compatibility.h" namespace tflite { -namespace ops { -namespace micro { -namespace xtensa { -namespace hifimini { // INT24 MIN/MAX #define INT24_MIN -8388608 #define INT24_MAX 8388607 -// // Multiply 24bit value by a quantized multiplier (w/ shift) and returns a 48bit // aligned value in the QR register. 
-// inline ae_q56s MultiplyByQuantizedMultiplier(ae_p24x2s x_24x2, int32_t quantized_multiplier, int shift) { @@ -77,13 +71,10 @@ inline ae_q56s MultiplyByQuantizedMultiplier(ae_p24x2s x_24x2, return result_56; } -// // Multiply 32bit value by a quantized multiplier (w/ shift) and returns a 48bit // aligned value in the QR register. -// -inline ae_q56s MultiplyByQuantizedMultiplier(int32_t x, - int32_t quantized_multiplier, - int shift) { +inline ae_q56s MultiplyByQuantizedMultiplierResult48Bit( + int32_t x, int32_t quantized_multiplier, int shift) { // Convert x into a 2x24bit PR register file. If x is outside the numerical // limits of a 24bit integer, the "fractional" or lower 8bits are discarded. // If x is within the range of a 24 bit integer, the "signed" or upper 8bits @@ -99,11 +90,10 @@ inline ae_q56s MultiplyByQuantizedMultiplier(int32_t x, return MultiplyByQuantizedMultiplier(x_24x2, quantized_multiplier, shift); } -// // Calculate quantization params for 24bit runtimes. -// -inline void QuantizeMultiplier(float multiplier, int32_t* quantized_multiplier, - int* shift) { +inline void QuantizeMultiplierForInt24(float multiplier, + int32_t* quantized_multiplier, + int* shift) { if (multiplier == 0.0f) { *quantized_multiplier = 0; *shift = 0; @@ -130,9 +120,7 @@ inline void QuantizeMultiplier(float multiplier, int32_t* quantized_multiplier, *quantized_multiplier = static_cast(q_fixed); } -// // Convert a floating point number to a Q representation for 24 bit integers. -// inline int CreateQConstantForInt24(int integer_bits, float f) { const float min_bounds = static_cast(INT24_MIN); const float max_bounds = static_cast(INT24_MAX); @@ -144,10 +132,6 @@ inline int CreateQConstantForInt24(int integer_bits, float f) { return static_cast(raw); } -} // namespace hifimini -} // namespace xtensa -} // namespace micro -} // namespace ops } // namespace tflite #endif // TENSORFLOW_LITE_MICRO_KERNELS_XTENSA_HIFIMINI_FIXEDPOINT_UTILS_H_ diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc index 4bd9875d152..30a5b6a602a 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc @@ -28,11 +28,31 @@ limitations under the License. #include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h" namespace tflite { -namespace ops { -namespace micro { +namespace { -namespace xtensa { -namespace hifimini { +struct OpData { + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + + // Cached tensor zero point values for quantized operations. + int32_t input_zero_point; + int32_t filter_zero_point; + int32_t output_zero_point; + + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; + // The index of the temporary tensor where the quantized inputs are cached. 
+ int input_quantized_index; +}; + +constexpr int kInputTensor = 0; +constexpr int kWeightsTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; void FullyConnected(const FullyConnectedParams& params, const RuntimeShape& input_shape, const int8_t* input_data, @@ -118,36 +138,6 @@ void FullyConnected(const FullyConnectedParams& params, } } -} // namespace hifimini -} // namespace xtensa - -namespace fully_connected { -namespace { - -struct OpData { - // The scaling factor from input to output (aka the 'real multiplier') can - // be represented as a fixed point multiplier plus a left shift. - int32_t output_multiplier; - int output_shift; - - // Cached tensor zero point values for quantized operations. - int32_t input_zero_point; - int32_t filter_zero_point; - int32_t output_zero_point; - - // The range of the fused activation layer. For example for kNone and - // uint8_t these would be 0 and 255. - int32_t output_activation_min; - int32_t output_activation_max; - // The index of the temporary tensor where the quantized inputs are cached. - int input_quantized_index; -}; - -constexpr int kInputTensor = 0; -constexpr int kWeightsTensor = 1; -constexpr int kBiasTensor = 2; -constexpr int kOutputTensor = 0; - TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteFusedActivation activation, TfLiteType data_type, const TfLiteTensor* input, @@ -157,15 +147,13 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, double real_multiplier = 0.0; TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); - xtensa::hifimini::QuantizeMultiplier( - real_multiplier, &data->output_multiplier, &data->output_shift); + QuantizeMultiplierForInt24(real_multiplier, &data->output_multiplier, + &data->output_shift); return CalculateActivationRangeQuantized(context, activation, output, &data->output_activation_min, &data->output_activation_max); } -} // namespace - void* Init(TfLiteContext* context, const char* buffer, size_t length) { TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); return context->AllocatePersistentBuffer(context, sizeof(OpData)); @@ -218,15 +206,14 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, op_params.quantized_activation_min = data.output_activation_min; op_params.quantized_activation_max = data.output_activation_max; - xtensa::hifimini::FullyConnected( - op_params, tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData(input), - tflite::micro::GetTensorShape(filter), - tflite::micro::GetTensorData(filter), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetTensorData(bias), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output)); + FullyConnected(op_params, tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); return kTfLiteOk; } @@ -249,19 +236,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return EvalQuantizedInt8(context, node, data, input, filter, bias, output); } -} // namespace fully_connected +} // namespace TfLiteRegistration Register_FULLY_CONNECTED() { - return {/*init=*/fully_connected::Init, + return {/*init=*/Init, /*free=*/nullptr, - /*prepare=*/fully_connected::Prepare, - /*invoke=*/fully_connected::Eval, + /*prepare=*/Prepare, 
+ /*invoke=*/Eval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc index bbd2a7d9fea..b867e70d98b 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/quantize.cc @@ -25,11 +25,12 @@ limitations under the License. #include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h" namespace tflite { -namespace ops { -namespace micro { +namespace { -namespace xtensa { -namespace hifimini { +struct OpData { + int32_t zero_point = 0; + int scale_multiplier = 0; +}; void AffineQuantize(int scale_multiplier, const int32_t zero_point, const RuntimeShape& input_shape, const int16_t* input_data, @@ -98,16 +99,6 @@ void AffineQuantize(int scale_multiplier, const int32_t zero_point, } } -} // namespace hifimini -} // namespace xtensa - -namespace quantize { - -struct OpData { - int32_t zero_point = 0; - int scale_multiplier = 0; -}; - void* Init(TfLiteContext* context, const char* buffer, size_t length) { TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); return context->AllocatePersistentBuffer(context, sizeof(OpData)); @@ -121,8 +112,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, 0); // TODO(b/155682734): Fix dangerous input/output scale ratio assumptions. - op_data->scale_multiplier = xtensa::hifimini::CreateQConstantForInt24( - 0, input->params.scale / output->params.scale); + op_data->scale_multiplier = + CreateQConstantForInt24(0, input->params.scale / output->params.scale); op_data->zero_point = output->params.zero_point; @@ -146,31 +137,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteError; } - xtensa::hifimini::AffineQuantize( - op_data->scale_multiplier, op_data->zero_point, - tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData(input), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output)); + AffineQuantize(op_data->scale_multiplier, op_data->zero_point, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); return kTfLiteOk; } -} // namespace quantize +} // namespace -// This Op (QUANTIZE) quantizes the input and produces quantized output. -// AffineQuantize takes scale and zero point and quantizes the float value to -// quantized output, in int8_t or uint8_t format. TfLiteRegistration Register_QUANTIZE() { - return {/*init=*/quantize::Init, + return {/*init=*/Init, /*free=*/nullptr, - /*prepare=*/quantize::Prepare, - /*invoke=*/quantize::Eval, + /*prepare=*/Prepare, + /*invoke=*/Eval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc index f0a0134d7e6..79a44e2c670 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc @@ -25,9 +25,6 @@ limitations under the License. 
#include "tensorflow/lite/micro/kernels/kernel_util.h" namespace tflite { -namespace ops { -namespace micro { -namespace activations { namespace { struct OpData { @@ -105,8 +102,6 @@ TfLiteStatus Softmax(OpData op_data, const RuntimeShape& input_shape, return kTfLiteOk; } -} // namespace - TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context, const TfLiteTensor* input, TfLiteTensor* output, @@ -196,19 +191,18 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteError; } } -} // namespace activations + +} // namespace TfLiteRegistration Register_SOFTMAX() { - return {/*init=*/activations::SoftmaxInit, + return {/*init=*/SoftmaxInit, /*free=*/nullptr, - /*prepare=*/activations::SoftmaxPrepare, - /*invoke=*/activations::SoftmaxEval, + /*prepare=*/SoftmaxPrepare, + /*invoke=*/SoftmaxEval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/svdf.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/svdf.cc index 545e91bab3d..28f8f1e1af0 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/svdf.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/svdf.cc @@ -28,9 +28,6 @@ limitations under the License. #include "tensorflow/lite/micro/kernels/xtensa_hifimini/fixedpoint_utils.h" namespace tflite { -namespace ops { -namespace micro { -namespace svdf { namespace { struct OpData { @@ -153,10 +150,8 @@ void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node, dot_prod_56 = AE_Q56S_SLAI(dot_prod_56, 24); ae_p24x2s dot_prod_24x2 = AE_TRUNCP24Q48(dot_prod_56); - dot_prod_56 = - tflite::ops::micro::xtensa::hifimini::MultiplyByQuantizedMultiplier( - dot_prod_24x2, data.effective_scale_1_a, - data.effective_scale_1_b); + dot_prod_56 = MultiplyByQuantizedMultiplier( + dot_prod_24x2, data.effective_scale_1_a, data.effective_scale_1_b); // Cap min/max and convert to int32_t: dot_prod_56 = AE_MAXQ56S(dot_prod_56, output_int16_min_56); @@ -247,10 +242,9 @@ void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node, ae_q56s output_int8_min_56 = AE_CVTQ48A32S(INT8_MIN); ae_q56s output_zp_56 = AE_CVTQ48A32S(data.output_zero_point); for (int i = 0; i < n_batch * n_unit; ++i) { - ae_q56s x_56 = - tflite::ops::micro::xtensa::hifimini::MultiplyByQuantizedMultiplier( - scratch_output_tensor[i], data.effective_scale_2_a, - data.effective_scale_2_b); + ae_q56s x_56 = MultiplyByQuantizedMultiplierResult48Bit( + scratch_output_tensor[i], data.effective_scale_2_a, + data.effective_scale_2_b); // Add output adjustment: x_56 = AE_ADDQ56(x_56, output_zp_56); // Cap min/max and convert to int32_t (already aligned to 32bit): @@ -262,8 +256,6 @@ void EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node, } } -} // namespace - void* Init(TfLiteContext* context, const char* buffer, size_t length) { TFLITE_DCHECK(context != nullptr); return context->AllocatePersistentBuffer(context, sizeof(OpData)); @@ -365,12 +357,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TFLITE_DCHECK(node->user_data != nullptr); OpData* data = static_cast(node->user_data); - xtensa::hifimini::QuantizeMultiplier(effective_scale_1, - &data->effective_scale_1_a, - &data->effective_scale_1_b); - xtensa::hifimini::QuantizeMultiplier(effective_scale_2, - &data->effective_scale_2_a, - &data->effective_scale_2_b); + QuantizeMultiplierForInt24(effective_scale_1, &data->effective_scale_1_a, + &data->effective_scale_1_b); + 
QuantizeMultiplierForInt24(effective_scale_2, &data->effective_scale_2_a, + &data->effective_scale_2_b); data->input_zero_point = input->params.zero_point; data->output_zero_point = output->params.zero_point; @@ -414,19 +404,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -} // namespace svdf +} // namespace TfLiteRegistration Register_SVDF() { - return {/*init=*/svdf::Init, + return {/*init=*/Init, /*free=*/nullptr, - /*prepare=*/svdf::Prepare, - /*invoke=*/svdf::Eval, + /*prepare=*/Prepare, + /*invoke=*/Eval, /*profiling_string=*/nullptr, /*builtin_code=*/0, /*custom_name=*/nullptr, /*version=*/0}; } -} // namespace micro -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/memory_arena_threshold_test.cc b/tensorflow/lite/micro/memory_arena_threshold_test.cc index d2bb404051a..bf78945816b 100644 --- a/tensorflow/lite/micro/memory_arena_threshold_test.cc +++ b/tensorflow/lite/micro/memory_arena_threshold_test.cc @@ -49,15 +49,16 @@ constexpr int kKeywordModelNodeAndRegistrationCount = 15; // Run this test with '--copt=-DTF_LITE_STATIC_MEMORY' to get optimized memory // runtime values: #ifdef TF_LITE_STATIC_MEMORY -constexpr int kKeywordModelTotalSize = 14336; -constexpr int kKeywordModelTailSize = 13664; +constexpr int kKeywordModelTotalSize = 14160; +constexpr int kKeywordModelTailSize = 13488; #else -constexpr int kKeywordModelTotalSize = 14704; -constexpr int kKeywordModelTailSize = 14032; +constexpr int kKeywordModelTotalSize = 14512; +constexpr int kKeywordModelTailSize = 13840; #endif constexpr int kKeywordModelHeadSize = 672; constexpr int kKeywordModelTfLiteTensorVariableBufferDataSize = 10240; constexpr int kKeywordModelOpRuntimeDataSize = 148; +constexpr int kKeywordModelPersistentBufferDataSize = 532; constexpr int kTestConvModelArenaSize = 12 * 1024; uint8_t test_conv_tensor_arena[kTestConvModelArenaSize]; @@ -68,14 +69,15 @@ constexpr int kTestConvModelNodeAndRegistrationCount = 7; // NOTE: These values are measured on x86-64: // TODO(b/158651472): Consider auditing these values on non-64 bit systems. 
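Note: one way to re-derive threshold numbers like the ones above is to run the model through the recording allocator and print the per-bucket usage (which now includes the persistent buffer bucket). A rough sketch, assuming the usual `RecordingMicroInterpreter` constructor and a placeholder `model_data` flatbuffer:

```cpp
#include <cstdint>

#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/recording_micro_interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace {
constexpr size_t kArenaSize = 21 * 1024;
uint8_t g_arena[kArenaSize];
}  // namespace

// `model_data` stands in for a serialized model such as kTestConvModelData.
void ReportArenaUsage(const unsigned char* model_data) {
  tflite::MicroErrorReporter error_reporter;
  tflite::AllOpsResolver op_resolver;
  const tflite::Model* model = tflite::GetModel(model_data);

  tflite::RecordingMicroInterpreter interpreter(model, op_resolver, g_arena,
                                                kArenaSize, &error_reporter);
  if (interpreter.AllocateTensors() != kTfLiteOk) return;

  // Logs used/requested bytes per RecordedAllocationType bucket.
  interpreter.GetMicroAllocator().PrintAllocations();
}
```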
#ifdef TF_LITE_STATIC_MEMORY -constexpr int kTestConvModelTotalSize = 9552; -constexpr int kTestConvModelTailSize = 1808; +constexpr int kTestConvModelTotalSize = 9584; +constexpr int kTestConvModelTailSize = 1840; #else -constexpr int kTestConvModelTotalSize = 9712; -constexpr int kTestConvModelTailSize = 1968; +constexpr int kTestConvModelTotalSize = 9760; +constexpr int kTestConvModelTailSize = 2016; #endif constexpr int kTestConvModelHeadSize = 7744; constexpr int kTestConvModelOpRuntimeDataSize = 136; +constexpr int kTestConvModelPersistentBufferDataSize = 648; struct ModelAllocationThresholds { size_t tensor_count = 0; @@ -85,6 +87,7 @@ struct ModelAllocationThresholds { size_t tail_alloc_size = 0; size_t tensor_variable_buffer_data_size = 0; size_t op_runtime_data_size = 0; + size_t persistent_buffer_data = 0; }; void EnsureAllocatedSizeThreshold(const char* allocation_type, size_t actual, @@ -148,6 +151,13 @@ void ValidateModelAllocationThresholds( kPersistentTfLiteTensorQuantizationData) .used_bytes, 0); + EnsureAllocatedSizeThreshold( + "PersistentBufferData", + allocator + .GetRecordedAllocation( + tflite::RecordedAllocationType::kPersistentBufferData) + .used_bytes, + thresholds.persistent_buffer_data); EnsureAllocatedSizeThreshold( "NodeAndRegistration", allocator @@ -195,6 +205,7 @@ TF_LITE_MICRO_TEST(TestKeywordModelMemoryThreshold) { thresholds.tensor_variable_buffer_data_size = kKeywordModelTfLiteTensorVariableBufferDataSize; thresholds.op_runtime_data_size = kKeywordModelOpRuntimeDataSize; + thresholds.persistent_buffer_data = kKeywordModelPersistentBufferDataSize; ValidateModelAllocationThresholds(interpreter.GetMicroAllocator(), thresholds); @@ -216,6 +227,7 @@ TF_LITE_MICRO_TEST(TestConvModelMemoryThreshold) { thresholds.head_alloc_size = kTestConvModelHeadSize; thresholds.tail_alloc_size = kTestConvModelTailSize; thresholds.op_runtime_data_size = kTestConvModelOpRuntimeDataSize; + thresholds.persistent_buffer_data = kTestConvModelPersistentBufferDataSize; ValidateModelAllocationThresholds(interpreter.GetMicroAllocator(), thresholds); diff --git a/tensorflow/lite/micro/micro_allocator.cc b/tensorflow/lite/micro/micro_allocator.cc index edac0d5ae5e..770921b4234 100644 --- a/tensorflow/lite/micro/micro_allocator.cc +++ b/tensorflow/lite/micro/micro_allocator.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,6 +36,15 @@ limitations under the License. namespace tflite { namespace { + +// Maximum number of scratch buffer requests per operator. Operator kernels that +// request more than this value will receive an exception. +constexpr size_t kMaxScratchBuffersPerOp = 8; + +// Sentinel value used as a placeholder to mark a ScratchBufferRequest request +// needs a node id assignment. +constexpr int kUnassignedScratchBufferRequestIndex = -1; + // Used to hold information used during allocation calculations. struct AllocationInfo { size_t bytes; @@ -50,7 +59,7 @@ struct AllocationInfo { // requirement for SIMD extensions. 
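Note: from the kernel side, the contract these constants support looks roughly like the sketch below (hypothetical `OpData`; the exact wiring of `node->user_data` varies per kernel). A request is issued during `Prepare` without a node id, the allocator back-fills the id in `FinishPrepareNodeAllocations()`, and the buffer only resolves through `GetScratchBuffer` after the memory plan is committed:

```cpp
#include <cstdint>

#include "tensorflow/lite/c/common.h"

namespace {

struct OpData {
  int scratch_index = -1;  // Assigned by RequestScratchBufferInArena.
};

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = static_cast<OpData*>(node->user_data);
  // No node id is passed any more; at most kMaxScratchBuffersPerOp (8)
  // requests may be issued while this node is being prepared.
  return context->RequestScratchBufferInArena(context, /*bytes=*/256,
                                              &data->scratch_index);
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = static_cast<OpData*>(node->user_data);
  // Valid only after FinishModelAllocation() has planned the head section.
  int8_t* scratch = static_cast<int8_t*>(
      context->GetScratchBuffer(context, data->scratch_index));
  if (scratch == nullptr) return kTfLiteError;
  return kTfLiteOk;
}

}  // namespace
```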
constexpr int kBufferAlignment = 16; constexpr char kOfflineMemAllocMetadata[] = "OfflineMemoryAllocation"; -const TfLiteIntArray kZeroLengthIntArray = {0, {}}; +const TfLiteIntArray kZeroLengthIntArray = {}; class MicroBuiltinDataAllocator : public BuiltinDataAllocator { public: @@ -143,17 +152,12 @@ TfLiteStatus CheckOfflinePlannedOffsets(const Model* model, // plan. Methods need to be called in order from `Init`, `Add*`, to `Finish`. class AllocationInfoBuilder { public: - AllocationInfoBuilder(ErrorReporter* reporter, - SimpleMemoryAllocator* allocator) - : reporter_(reporter), allocator_(allocator) {} - - // Initializes the builder by allocating AllocationInfo array from the - // simple memory allocator. - TfLiteStatus Init(size_t tensor_count, size_t scratch_buffer_count) { - tensor_count_ = tensor_count; - buffer_count_ = scratch_buffer_count; - return Allocate(); - } + AllocationInfoBuilder(AllocationInfo* info, size_t tensor_count, + size_t scratch_buffer_count, ErrorReporter* reporter) + : info_(info), + tensor_count_(tensor_count), + buffer_count_(scratch_buffer_count), + reporter_(reporter) {} // Check if model contains offline planned buffer offsets. // - If there's no metadata available, offline_planner_offsets is not set @@ -168,37 +172,20 @@ class AllocationInfoBuilder { TfLiteEvalTensor* eval_tensors); // Add allocation information for the scratch buffers. - TfLiteStatus AddScratchBuffers(internal::ScratchBufferHandle* buffer_handles); + TfLiteStatus AddScratchBuffers( + internal::ScratchBufferRequest* scratch_buffer_requests, + ScratchBufferHandle* scratch_buffer_handles); // Returns a pointer to the built AllocationInfo array. const AllocationInfo* Finish() const { return info_; } - size_t Size() const { return tensor_count_ + buffer_count_; } private: - // Allocate the output AllocationInfo array from the allocator_; - TfLiteStatus Allocate(); - - ErrorReporter* reporter_ = nullptr; - SimpleMemoryAllocator* allocator_ = nullptr; + AllocationInfo* info_ = nullptr; size_t tensor_count_ = 0; size_t buffer_count_ = 0; - AllocationInfo* info_ = nullptr; + ErrorReporter* reporter_ = nullptr; }; -TfLiteStatus AllocationInfoBuilder::Allocate() { - size_t bytes = sizeof(AllocationInfo) * Size(); - info_ = reinterpret_cast( - allocator_->AllocateFromTail(bytes, alignof(AllocationInfo))); - if (info_ == nullptr) { - TF_LITE_REPORT_ERROR( - reporter_, - "Failed to allocate memory for allocation_info, %d bytes required", - bytes); - return kTfLiteError; - } - return kTfLiteOk; -} - TfLiteStatus AllocationInfoBuilder::AddTensors(const SubGraph* subgraph, const int32_t* offline_offsets, TfLiteEvalTensor* eval_tensors) { @@ -342,16 +329,20 @@ TfLiteStatus AllocationInfoBuilder::GetOfflinePlannedOffsets( } TfLiteStatus AllocationInfoBuilder::AddScratchBuffers( - internal::ScratchBufferHandle* buffer_handles) { + internal::ScratchBufferRequest* scratch_buffer_requests, + ScratchBufferHandle* scratch_buffer_handles) { // Set up allocation info for buffers. 
for (size_t i = tensor_count_; i < tensor_count_ + buffer_count_; ++i) { + internal::ScratchBufferRequest* current_request = + &(scratch_buffer_requests[i - tensor_count_]); + ScratchBufferHandle* current_handle = + &(scratch_buffer_handles[i - tensor_count_]); + AllocationInfo* current = &info_[i]; - internal::ScratchBufferHandle* handle = - &(buffer_handles[i - tensor_count_]); - current->output_ptr = reinterpret_cast(&handle->data); - current->bytes = handle->bytes; - current->first_created = handle->node_idx; - current->last_used = handle->node_idx; + current->output_ptr = reinterpret_cast(¤t_handle->data); + current->bytes = current_request->bytes; + current->first_created = current_request->node_idx; + current->last_used = current_request->node_idx; current->offline_offset = kOnlinePlannedBuffer; current->needs_allocating = true; } @@ -459,7 +450,7 @@ void* GetFlatbufferTensorBuffer( // and if there is update the runtime structure to point to its location in // memory. // First see if there's any buffer information in the serialized tensor. - // TODO(b/160894903): Add better unit tests that validate flatbuffer values. + // TODO(b/170379532): Add better unit tests to validate flatbuffer values. void* out_buffer = nullptr; if (auto* buffer = (*buffers)[flatbuffer_tensor.buffer()]) { // If we've found a buffer, does it have any data? @@ -670,7 +661,7 @@ TfLiteStatus MicroAllocator::StartModelAllocation( model_is_allocating_ = true; - TF_LITE_ENSURE_STATUS(InitScratchBufferHandles()); + TF_LITE_ENSURE_STATUS(InitScratchBufferData()); TF_LITE_ENSURE_STATUS(AllocateTfLiteEvalTensors(model, eval_tensors)); TF_LITE_ENSURE_STATUS( AllocateNodeAndRegistrations(model, node_and_registrations)); @@ -682,7 +673,7 @@ TfLiteStatus MicroAllocator::StartModelAllocation( TfLiteStatus MicroAllocator::FinishModelAllocation( const Model* model, TfLiteEvalTensor* eval_tensors, - void** scratch_buffer_handles) { + ScratchBufferHandle** scratch_buffer_handles) { if (!model_is_allocating_) { TF_LITE_REPORT_ERROR(error_reporter_, "MicroAllocator: Model allocation finished before " @@ -693,13 +684,12 @@ TfLiteStatus MicroAllocator::FinishModelAllocation( const SubGraph* subgraph = GetSubGraphFromModel(model); TFLITE_DCHECK(subgraph != nullptr); - TF_LITE_ENSURE_STATUS(MoveScratchBufferHandlesToTail()); - TF_LITE_ENSURE_STATUS(CommitStaticMemoryPlan(model, subgraph, eval_tensors)); + TF_LITE_ENSURE_STATUS(AllocateScratchBufferHandles( + scratch_buffer_handles, scratch_buffer_request_count_)); + TF_LITE_ENSURE_STATUS(CommitStaticMemoryPlan(model, subgraph, eval_tensors, + *scratch_buffer_handles)); TF_LITE_ENSURE_STATUS(AllocateVariables(subgraph, eval_tensors)); - if (scratch_buffer_handles != nullptr) { - *scratch_buffer_handles = scratch_buffer_handles_; - } model_is_allocating_ = false; return kTfLiteOk; } @@ -708,42 +698,74 @@ void* MicroAllocator::AllocatePersistentBuffer(size_t bytes) { return memory_allocator_->AllocateFromTail(bytes, kBufferAlignment); } -TfLiteStatus MicroAllocator::RequestScratchBufferInArena(int node_id, - size_t bytes, +TfLiteStatus MicroAllocator::RequestScratchBufferInArena(size_t bytes, int* buffer_idx) { - // This method is only called during Prepare stage, when the scratch buffer - // handles are placed in the head. + // All scratch buffer requests are stored in the head section of the arena + // when a model is in the prepare phase. 
First align a scratch buffer request + // pointer to the start of the head: + internal::ScratchBufferRequest* requests = GetScratchBufferRequests(); - // Allocate space for the new scratch buffer handle. - TF_LITE_ENSURE_STATUS(memory_allocator_->EnsureHeadSize( - sizeof(internal::ScratchBufferHandle) * (scratch_buffer_count_ + 1), - alignof(internal::ScratchBufferHandle))); - - if (scratch_buffer_handles_ == nullptr) { - // If this is the first scratch buffer handle, place it in the buffer head. - scratch_buffer_handles_ = reinterpret_cast( - memory_allocator_->GetBufferHead()); + // Count the number of requested scratch buffers for the current node: + size_t current_node_request_count = 0; + for (size_t i = 0; i < scratch_buffer_request_count_; ++i) { + if (requests[i].node_idx == kUnassignedScratchBufferRequestIndex) { + ++current_node_request_count; + } } - // Initialize the handle. `data` field will be set during memory planning. - internal::ScratchBufferHandle* handle = - scratch_buffer_handles_ + scratch_buffer_count_; - *handle = {}; - handle->bytes = bytes; - handle->node_idx = node_id; + // First, ensure that the per-kernel request has not exceeded the limit: + if (current_node_request_count >= kMaxScratchBuffersPerOp) { + TF_LITE_REPORT_ERROR( + error_reporter_, + "Scratch buffer request exeeds limit per operator (%d)", + kMaxScratchBuffersPerOp); + return kTfLiteError; + } - // Buffer idx starts from 0 in this implementation. - *buffer_idx = scratch_buffer_count_; - scratch_buffer_count_ += 1; + // Initialize and assign values for the request at the current index: + internal::ScratchBufferRequest* current_request = + &requests[scratch_buffer_request_count_]; + *current_request = {}; + // Assign -1 as a sentinel value that will be updated when the node finishes + // allocating: + current_request->bytes = bytes; + current_request->node_idx = kUnassignedScratchBufferRequestIndex; + + // Assign the current request index to the out-param: + *buffer_idx = scratch_buffer_request_count_; + + // Bump the request count to prepare for the next request: + ++scratch_buffer_request_count_; return kTfLiteOk; } -void* MicroAllocator::GetScratchBuffer(void* scratch_buffer_handles, - int buffer_idx) { - internal::ScratchBufferHandle* handle = - reinterpret_cast(scratch_buffer_handles) + - buffer_idx; - return handle->data; +TfLiteStatus MicroAllocator::FinishPrepareNodeAllocations(int node_id) { + // When a node has finished preparing, all temp allocations performed by the + // kernel should be cleaned up: + ResetTempAllocations(); + + // Find and update any new scratch buffer requests for the current node: + internal::ScratchBufferRequest* requests = GetScratchBufferRequests(); + + for (size_t i = 0; i < scratch_buffer_request_count_; ++i) { + // A request with a node_idx of -1 is a sentinel value used to indicate this + // was a new request for the current node. The allocator finally knows the + // node index at this point. 
Assign the value and update the list of new + // requests so the head section can be adjusted to allow for the next kernel + // to allocate at most kMaxScratchBuffersPerOp requests: + if (requests[i].node_idx == kUnassignedScratchBufferRequestIndex) { + requests[i].node_idx = node_id; + } + } + + // Ensure that the head is re-adjusted to allow for another at-most + // kMaxScratchBuffersPerOp scratch buffer requests in the next operator: + TF_LITE_ENSURE_STATUS(memory_allocator_->SetHeadBufferSize( + sizeof(internal::ScratchBufferRequest) * + (scratch_buffer_request_count_ + kMaxScratchBuffersPerOp), + alignof(internal::ScratchBufferRequest))); + + return kTfLiteOk; } size_t MicroAllocator::used_bytes() const { @@ -997,9 +1019,9 @@ TfLiteTensor* MicroAllocator::AllocatePersistentTfLiteTensorInternal( TfLiteStatus MicroAllocator::PopulateTfLiteTensorFromFlatbuffer( const Model* model, const SubGraph* subgraph, TfLiteTensor* tensor, int tensor_index, bool allocate_temp) { - // TODO(b/160894903): This method serves as a stub to ensure quantized - // allocations in the tail can be recorded. Once all kernels have been ported - // to the new API this can be dropped. + // TODO(b/162311891): This method serves as a stub to ensure quantized + // allocations in the tail can be recorded. Once the interpreter has APIs for + // accessing buffers on TfLiteEvalTensor this method can be dropped. return internal::InitializeTfLiteTensorFromFlatbuffer( memory_allocator_, allocate_temp, *subgraph->tensors()->Get(tensor_index), model->buffers(), error_reporter_, tensor); @@ -1021,89 +1043,138 @@ const SubGraph* MicroAllocator::GetSubGraphFromModel(const Model* model) { TfLiteStatus MicroAllocator::CommitStaticMemoryPlan( const Model* model, const SubGraph* subgraph, - TfLiteEvalTensor* eval_tensors) { + TfLiteEvalTensor* eval_tensors, + ScratchBufferHandle* scratch_buffer_handles) { size_t head_usage = 0; // Create static memory plan // 1. Calculate AllocationInfo to know the lifetime of each tensor/buffer. // 2. Add them into the planner (such as the GreedyMemoryPlanner). // 3. Static memory planning using the planner. // 4. Set tensor/buffer pointers based on the offsets from the previous step. + // // Note that AllocationInfo is only needed for creating the plan. It will be - // thrown away when the child allocator (tmp_allocator) goes out of scope. - { - // TODO(b/162595810): Use temp allocation buffer instead of a stack - // instance: - SimpleMemoryAllocator tmp_allocator(error_reporter_, - memory_allocator_->GetBufferHead(), - memory_allocator_->GetTail()); + // allocated from the temp section and cleaned up at the bottom of this + // function. - AllocationInfoBuilder builder(error_reporter_, &tmp_allocator); - TF_LITE_ENSURE_STATUS( - builder.Init(subgraph->tensors()->size(), scratch_buffer_count_)); + size_t allocation_info_count = + subgraph->tensors()->size() + scratch_buffer_request_count_; + size_t bytes = sizeof(AllocationInfo) * allocation_info_count; - const int32_t* offline_planner_offsets = nullptr; - TF_LITE_ENSURE_STATUS( - builder.GetOfflinePlannedOffsets(model, &offline_planner_offsets)); - TF_LITE_ENSURE_STATUS( - builder.AddTensors(subgraph, offline_planner_offsets, eval_tensors)); - TF_LITE_ENSURE_STATUS(builder.AddScratchBuffers(scratch_buffer_handles_)); - const AllocationInfo* allocation_info = builder.Finish(); - - // Remaining arena size that memory planner can use for calculating offsets. 
- size_t remaining_arena_size = - tmp_allocator.GetAvailableMemory(kBufferAlignment); - uint8_t* planner_arena = - tmp_allocator.AllocateTemp(remaining_arena_size, kBufferAlignment); - TF_LITE_ENSURE(error_reporter_, planner_arena != nullptr); - GreedyMemoryPlanner planner(planner_arena, remaining_arena_size); - TF_LITE_ENSURE_STATUS( - CreatePlan(error_reporter_, &planner, allocation_info, builder.Size())); - - size_t actual_available_arena_size = - memory_allocator_->GetAvailableMemory(kBufferAlignment); - - // Make sure we have enough arena size. - if (planner.GetMaximumMemorySize() > actual_available_arena_size) { - TF_LITE_REPORT_ERROR( - error_reporter_, - "Arena size is too small for all buffers. Needed %u but only " - "%u was available.", - planner.GetMaximumMemorySize(), actual_available_arena_size); - return kTfLiteError; - } - // Commit the plan. - TF_LITE_ENSURE_STATUS(CommitPlan(error_reporter_, &planner, - memory_allocator_->GetBufferHead(), - allocation_info, builder.Size())); - head_usage = planner.GetMaximumMemorySize(); + // Allocate an array of AllocationInfo structs from the temp section. This + // struct will be used by AllocationInfoBuilder to find buffer usage. + AllocationInfo* allocation_info = reinterpret_cast( + memory_allocator_->AllocateTemp(bytes, alignof(AllocationInfo))); + if (allocation_info == nullptr) { + TF_LITE_REPORT_ERROR( + error_reporter_, + "Failed to allocate memory for allocation_info, %d bytes required", + bytes); + return kTfLiteError; } + // Use the AllocationInfoBuilder class to help determine where buffers are + // used in the subgraph. + AllocationInfoBuilder builder(allocation_info, subgraph->tensors()->size(), + scratch_buffer_request_count_, error_reporter_); + + const int32_t* offline_planner_offsets = nullptr; TF_LITE_ENSURE_STATUS( - memory_allocator_->EnsureHeadSize(head_usage, kBufferAlignment)); + builder.GetOfflinePlannedOffsets(model, &offline_planner_offsets)); + TF_LITE_ENSURE_STATUS( + builder.AddTensors(subgraph, offline_planner_offsets, eval_tensors)); + + internal::ScratchBufferRequest* scratch_buffer_requests = + GetScratchBufferRequests(); + + TF_LITE_ENSURE_STATUS(builder.AddScratchBuffers(scratch_buffer_requests, + scratch_buffer_handles)); + + // Remaining arena size that memory planner can use for calculating offsets. + size_t remaining_arena_size = + memory_allocator_->GetAvailableMemory(kBufferAlignment); + uint8_t* planner_arena = + memory_allocator_->AllocateTemp(remaining_arena_size, kBufferAlignment); + TF_LITE_ENSURE(error_reporter_, planner_arena != nullptr); + GreedyMemoryPlanner planner(planner_arena, remaining_arena_size); + TF_LITE_ENSURE_STATUS(CreatePlan(error_reporter_, &planner, allocation_info, + allocation_info_count)); + + // Reset all temp allocations used above: + memory_allocator_->ResetTempAllocations(); + + size_t actual_available_arena_size = + memory_allocator_->GetAvailableMemory(kBufferAlignment); + + // Make sure we have enough arena size. + if (planner.GetMaximumMemorySize() > actual_available_arena_size) { + TF_LITE_REPORT_ERROR( + error_reporter_, + "Arena size is too small for all buffers. Needed %u but only " + "%u was available.", + planner.GetMaximumMemorySize(), actual_available_arena_size); + return kTfLiteError; + } + // Commit the plan. 
+ TF_LITE_ENSURE_STATUS(CommitPlan(error_reporter_, &planner, + memory_allocator_->GetHeadBuffer(), + allocation_info, allocation_info_count)); + head_usage = planner.GetMaximumMemorySize(); + + // The head is used to store memory plans for one model at a time during the + // model preparation stage, and is re-purposed to store scratch buffer handles + // during model invocation. The head must be as large as the greater of the + // largest model memory plan's size and the total space required for all + // scratch buffer handles. + if (max_head_buffer_usage_ < head_usage) { + max_head_buffer_usage_ = head_usage; + } + + // The head is used for storing scratch buffer allocations before finalizing a + // memory plan in this function. Ensure that the head is set to the largest + // memory plan sent through the allocator: + TF_LITE_ENSURE_STATUS(memory_allocator_->SetHeadBufferSize( + max_head_buffer_usage_, kBufferAlignment)); return kTfLiteOk; } -TfLiteStatus MicroAllocator::InitScratchBufferHandles() { - scratch_buffer_count_ = 0; - scratch_buffer_handles_ = nullptr; - return kTfLiteOk; -} +TfLiteStatus MicroAllocator::AllocateScratchBufferHandles( + ScratchBufferHandle** scratch_buffer_handles, size_t handle_count) { + TFLITE_DCHECK(scratch_buffer_handles != nullptr); -TfLiteStatus MicroAllocator::MoveScratchBufferHandlesToTail() { - if (scratch_buffer_count_ == 0) { + if (scratch_buffer_request_count_ == 0) { + // No scratch buffer requests were requested during model allocation. return kTfLiteOk; } - auto src = scratch_buffer_handles_; - internal::ScratchBufferHandle* dest = - reinterpret_cast( - memory_allocator_->AllocateFromTail( - sizeof(internal::ScratchBufferHandle) * scratch_buffer_count_, - alignof(internal::ScratchBufferHandle))); - for (size_t i = 0; i < scratch_buffer_count_; i++) { - *(dest + i) = *(src + i); - } - scratch_buffer_handles_ = dest; + + // Allocate a consecutive block of memory store the scratch buffer handles. + // This alignment ensures quick lookup during inference time for the model: + *scratch_buffer_handles = reinterpret_cast( + memory_allocator_->AllocateFromTail( + sizeof(ScratchBufferHandle) * handle_count, + alignof(ScratchBufferHandle))); + return kTfLiteOk; } +TfLiteStatus MicroAllocator::InitScratchBufferData() { + // A model is preparing to allocate resources, ensure that scratch buffer + // request counter is cleared: + scratch_buffer_request_count_ = 0; + + // All requests will be stored in the head section. Each kernel is allowed at + // most kMaxScratchBuffersPerOp requests. Adjust the head to reserve at most + // that many requests to begin: + TF_LITE_ENSURE_STATUS(memory_allocator_->SetHeadBufferSize( + sizeof(internal::ScratchBufferRequest) * kMaxScratchBuffersPerOp, + alignof(internal::ScratchBufferRequest))); + + return kTfLiteOk; +} + +internal::ScratchBufferRequest* MicroAllocator::GetScratchBufferRequests() { + return reinterpret_cast( + AlignPointerUp(memory_allocator_->GetHeadBuffer(), + alignof(internal::ScratchBufferRequest))); +} + } // namespace tflite diff --git a/tensorflow/lite/micro/micro_allocator.h b/tensorflow/lite/micro/micro_allocator.h index 5b478832d3d..39a12eaf858 100644 --- a/tensorflow/lite/micro/micro_allocator.h +++ b/tensorflow/lite/micro/micro_allocator.h @@ -1,5 +1,5 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. -b/160894903 +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -28,30 +28,29 @@ limitations under the License. namespace tflite { -// Namespace used for unittests. namespace internal { // Sets up all of the data structure members for a TfLiteTensor based on the // contents of a serialized tensor in the flatbuffer. -// TODO(b/160894903): Once all kernels have been updated to the new -// TfLiteEvalTensor API - drop the allocate_temp flag. This enables internal -// flatbuffer quantization or dimension allocations to take place in either the -// temp or tail section of the arena. +// TODO(b/162311891): Drop this method when the interpreter has an API for +// returning buffers on TfLiteEvalTensor. TfLiteStatus InitializeTfLiteTensorFromFlatbuffer( SimpleMemoryAllocator* allocator, bool allocate_temp, const tflite::Tensor& flatbuffer_tensor, const flatbuffers::Vector>* buffers, ErrorReporter* error_reporter, TfLiteTensor* result); -// A handle tracking scratch buffer allocation. This handle is created by -// `RequestScratchBufferInArena`. `data` field is populated in -// `FinishModelAllocation` after static memory planning. -// TODO(b/150257460) As a future optimization, this struct could be replaced by -// a union, since once `data` is populated, `bytes` and `node_idx` is not -// needed. +// Holds placeholder information for a scratch buffer request from a kernel. +// This struct is only used during the model prepare stage. Each request from a +// kernel is stored in the head section. During the prepare stage, the head +// section will at least hold kMaxScratchBuffersPerOp number of requests plus +// any requests from previous kernel requests. +// +// When the memory plan is finalized, these structs are no longer used in favor +// of a sequential, array of ScratchBufferHandle allocations in the tail +// section. These allocations are indexed by the request API defined in the +// TfLiteContext struct. typedef struct { - // Pointer to the scratch buffer. - uint8_t* data; // Number of bytes required by the buffer. The actual allocated size might be // greater than `bytes` due to buffer alignment. size_t bytes; @@ -59,7 +58,8 @@ typedef struct { // determine the lifetime of the buffer. In AllocationInfo, this buffer will // have `before` = node_idx and `after` = node_idx. int node_idx; -} ScratchBufferHandle; +} ScratchBufferRequest; + } // namespace internal typedef struct { @@ -67,6 +67,14 @@ typedef struct { const TfLiteRegistration* registration; } NodeAndRegistration; +// Holds a pointer to a buffer for a scratch buffer requested by a kernel during +// the model prepare stage. This struct is allocated in-place and allows for +// quick pointer-indexed lookup for speed during model inference. +typedef struct { + // Pointer to location of the scratch buffer: + uint8_t* data; +} ScratchBufferHandle; + // Allocator responsible for allocating memory for all intermediate tensors // necessary to invoke a model. // @@ -126,9 +134,9 @@ class MicroAllocator { // passed into this class during StartModelAllocation(). Scratch buffer // handles are stored in the out-param `scratch_buffer_handles`. This value // will be used in `GetScratchBuffer` call to retrieve scratch buffers. 
- TfLiteStatus FinishModelAllocation(const Model* model, - TfLiteEvalTensor* eval_tensors, - void** scratch_buffer_handles = nullptr); + TfLiteStatus FinishModelAllocation( + const Model* model, TfLiteEvalTensor* eval_tensors, + ScratchBufferHandle** scratch_buffer_handles); // Allocates a TfLiteTensor struct and populates the returned value with // properties from the model flatbuffer. This struct is allocated from @@ -157,24 +165,19 @@ class MicroAllocator { // Allocates persistent buffer which has the same life time as the allocator. // The memory is immediately available and is allocated from the tail of the // arena. - void* AllocatePersistentBuffer(size_t bytes); + virtual void* AllocatePersistentBuffer(size_t bytes); // Register a scratch buffer of size `bytes` for Node with `node_id`. - // This method only allocates a BufferHandle holding information for memory - // planning. The buffer ptr is ready after `FinishModelAllocation` and can - // be retrieved by `GetScratchBuffer` method using the returned buffer_idx. - // Note that this method should only be called in the Prepare stage. - TfLiteStatus RequestScratchBufferInArena(int node_id, size_t bytes, - int* buffer_idx); + // This method only requests a buffer with a given size to be used after a + // model has finished allocation via FinishModelAllocation(). All requested + // buffers will be accessible by the out-param in that method. + TfLiteStatus RequestScratchBufferInArena(size_t bytes, int* buffer_idx); - // Return the number of scratch buffers in the allocator. - size_t GetScratchBufferCount() const { return scratch_buffer_count_; } - - // Return the pointer to the planned scratch buffer. `scratch_buffer_handles` - // should be the corresponding value returned in `FinishModelAllocation`. - // `scratch_buffer_handles` is intentionally desigend as void*. The actual - // data type is an implementation detail, and is only visible in this class. - static void* GetScratchBuffer(void* scratch_buffer_handles, int buffer_idx); + // Finish allocating a specific NodeAndRegistration prepare block (kernel + // entry for a model) with a given node ID. This call ensures that any scratch + // buffer requests and temporary allocations are handled and ready for the + // next node prepare block. + TfLiteStatus FinishPrepareNodeAllocations(int node_id); // Returns the arena usage in bytes, only available after // `FinishModelAllocation`. Otherwise, it will return 0. @@ -210,17 +213,15 @@ class MicroAllocator { virtual TfLiteStatus AllocateVariables(const SubGraph* subgraph, TfLiteEvalTensor* eval_tensors); - // TODO(b/160894903): Once all kernels have been updated to the new API drop - // this method. It is only used to record TfLiteTensor persistent allocations. + // Allocate and return a persistent TfLiteTensor. + // TODO(b/162311891): Drop this method when the interpreter has an API for + // accessing TfLiteEvalTensor structs. virtual TfLiteTensor* AllocatePersistentTfLiteTensorInternal( const Model* model, TfLiteEvalTensor* eval_tensors, int tensor_index); // Populates a TfLiteTensor struct with data from the model flatbuffer. Any // quantization data is allocated from either the tail (persistent) or temp // sections of the arena based on the allocation flag. - // TODO(b/160894903): Once all kernels have been updated to the new API drop - // this function since all allocations for quantized data will take place in - // the temp section. 
virtual TfLiteStatus PopulateTfLiteTensorFromFlatbuffer( const Model* model, const SubGraph* subgraph, TfLiteTensor* tensor, int tensor_index, bool allocate_temp); @@ -234,10 +235,28 @@ class MicroAllocator { // Commits a memory plan for all non-persistent buffer allocations in the // 'head' section of the memory arena. The eval_tensors pointer is the list of // pre-allocated TfLiteEvalTensor structs that will point to the buffers that - // will be allocated into the head section in this function call. - virtual TfLiteStatus CommitStaticMemoryPlan(const Model* model, - const SubGraph* subgraph, - TfLiteEvalTensor* eval_tensors); + // will be allocated into the head section in this function call. The + // scratch_buffer_handles pointer is the array of pre-allocated + // ScratchBufferHandle structs that will point to allocated buffers also in + // the head section. + virtual TfLiteStatus CommitStaticMemoryPlan( + const Model* model, const SubGraph* subgraph, + TfLiteEvalTensor* eval_tensors, + ScratchBufferHandle* scratch_buffer_handles); + + // Allocates an array of ScratchBufferHandle structs in the tail section for a + // given number of handles. + virtual TfLiteStatus AllocateScratchBufferHandles( + ScratchBufferHandle** scratch_buffer_handles, size_t handle_count); + + // Clears all internal scratch buffer request counts and resets the head to + // prepare for kernels to request scratch buffer data when a model is + // preparing. + TfLiteStatus InitScratchBufferData(); + + // Returns the pointer for the array of ScratchBufferRequest allocations in + // the head section. + internal::ScratchBufferRequest* GetScratchBufferRequests(); // A simple memory allocator that always allocate from the arena tail or head. SimpleMemoryAllocator* memory_allocator_; @@ -245,15 +264,13 @@ class MicroAllocator { ErrorReporter* error_reporter_; bool model_is_allocating_; - // Points to the first allocated scratch buffer handle. - // Scratch buffer handles are placed in the head during `Prepare` stage and - // then moved to the tail for static memory plan. - internal::ScratchBufferHandle* scratch_buffer_handles_ = nullptr; - // How many scratch buffers have been allocated. - size_t scratch_buffer_count_ = 0; + // Holds the number of ScratchBufferRequest instances stored in the head + // section when a model is allocating. + size_t scratch_buffer_request_count_ = 0; - virtual TfLiteStatus InitScratchBufferHandles(); - virtual TfLiteStatus MoveScratchBufferHandlesToTail(); + // Holds the byte length of the memory plan with the largest head usage. Used + // to ensure that multi-tenant allocations can share the head for buffers. + size_t max_head_buffer_usage_ = 0; TF_LITE_REMOVE_VIRTUAL_DELETE }; diff --git a/tensorflow/lite/micro/micro_allocator_test.cc b/tensorflow/lite/micro/micro_allocator_test.cc index 1ac68443f1a..222c8566fb9 100644 --- a/tensorflow/lite/micro/micro_allocator_test.cc +++ b/tensorflow/lite/micro/micro_allocator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -142,8 +142,9 @@ TF_LITE_MICRO_TEST(TestInitializeRuntimeTensor) { simple_allocator->~SimpleMemoryAllocator(); } -// TODO(b/160894903): Drop this test when InitializeTfLiteTensorFromFlatbuffer() -// always allocates from temp (kernels are using the new TfLiteEvalTensor API): +// TODO(b/162311891): Drop this test when InitializeTfLiteTensorFromFlatbuffer() +// always allocates from temp (interpreter returns buffers from +// TfLiteEvalTensor): TF_LITE_MICRO_TEST(TestInitializeTempRuntimeTensor) { constexpr size_t arena_size = 1024; uint8_t arena[arena_size]; @@ -246,6 +247,7 @@ TF_LITE_MICRO_TEST(TestFailsWhenModelStartsTwice) { TF_LITE_MICRO_TEST(TestFailsWithWrongSequence) { const tflite::Model* model = tflite::testing::GetSimpleMockModel(); TfLiteEvalTensor* eval_tensors = nullptr; + tflite::ScratchBufferHandle* scratch_buffer_handles = nullptr; tflite::AllOpsResolver op_resolver = tflite::testing::GetOpResolver(); tflite::NodeAndRegistration* node_and_registration; constexpr size_t arena_size = 1024; @@ -256,7 +258,8 @@ TF_LITE_MICRO_TEST(TestFailsWithWrongSequence) { // We can't finish allocation before it ever got started. TF_LITE_MICRO_EXPECT_EQ( - kTfLiteError, allocator->FinishModelAllocation(model, eval_tensors)); + kTfLiteError, allocator->FinishModelAllocation(model, eval_tensors, + &scratch_buffer_handles)); // Start twice is not allowed. TF_LITE_MICRO_EXPECT_EQ( @@ -272,6 +275,7 @@ TF_LITE_MICRO_TEST(TestFailsWithWrongSequence) { TF_LITE_MICRO_TEST(TestMockModelAllocation) { const tflite::Model* model = tflite::testing::GetSimpleMockModel(); TfLiteEvalTensor* eval_tensors = nullptr; + tflite::ScratchBufferHandle* scratch_buffer_handles = nullptr; tflite::AllOpsResolver op_resolver = tflite::testing::GetOpResolver(); tflite::NodeAndRegistration* node_and_registration; constexpr size_t arena_size = 1024; @@ -284,7 +288,8 @@ TF_LITE_MICRO_TEST(TestMockModelAllocation) { allocator->StartModelAllocation(model, op_resolver, &node_and_registration, &eval_tensors)); TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator->FinishModelAllocation(model, eval_tensors)); + kTfLiteOk, allocator->FinishModelAllocation(model, eval_tensors, + &scratch_buffer_handles)); size_t model_tensor_size = tflite::testing::GetModelTensorCount(model); TF_LITE_MICRO_EXPECT_EQ(static_cast(4), model_tensor_size); @@ -319,6 +324,7 @@ TF_LITE_MICRO_TEST(TestMultiTenantAllocation) { tflite::MicroAllocator::Create(arena, arena_size, micro_test::reporter); TF_LITE_MICRO_EXPECT_NE(nullptr, allocator); TfLiteEvalTensor* eval_tensors = nullptr; + tflite::ScratchBufferHandle* scratch_buffer_handles = nullptr; // Allocate for model 1. We use ComplexMockModel here to cover the code path // allocatig variables. @@ -329,7 +335,8 @@ TF_LITE_MICRO_TEST(TestMultiTenantAllocation) { allocator->StartModelAllocation(model1, op_resolver, &node_and_registration1, &eval_tensors)); TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator->FinishModelAllocation(model1, eval_tensors)); + kTfLiteOk, allocator->FinishModelAllocation(model1, eval_tensors, + &scratch_buffer_handles)); const size_t single_model_used_bytes = allocator->used_bytes(); // Allocate for model 2. 
@@ -340,7 +347,8 @@ TF_LITE_MICRO_TEST(TestMultiTenantAllocation) { allocator->StartModelAllocation(model2, op_resolver, &node_and_registration2, &eval_tensors)); TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator->FinishModelAllocation(model2, eval_tensors)); + kTfLiteOk, allocator->FinishModelAllocation(model2, eval_tensors, + &scratch_buffer_handles)); // Allocation for two instances of the same model takes less memory as `head` // of the arena is reused. @@ -350,6 +358,7 @@ TF_LITE_MICRO_TEST(TestMultiTenantAllocation) { TF_LITE_MICRO_TEST(TestAllocationForModelsWithBranches) { const tflite::Model* model = tflite::testing::GetSimpleModelWithBranch(); TfLiteEvalTensor* eval_tensors = nullptr; + tflite::ScratchBufferHandle* scratch_buffer_handles = nullptr; tflite::AllOpsResolver op_resolver = tflite::testing::GetOpResolver(); tflite::NodeAndRegistration* node_and_registration; constexpr size_t arena_size = 4096; @@ -362,7 +371,8 @@ TF_LITE_MICRO_TEST(TestAllocationForModelsWithBranches) { allocator->StartModelAllocation(model, op_resolver, &node_and_registration, &eval_tensors)); TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator->FinishModelAllocation(model, eval_tensors)); + kTfLiteOk, allocator->FinishModelAllocation(model, eval_tensors, + &scratch_buffer_handles)); uint8_t* start = eval_tensors[0].data.uint8; // Check test_helpers.cc BuildSimpleModelWithBranch for model structure. @@ -389,6 +399,7 @@ TF_LITE_MICRO_TEST(TestAllocationForModelsWithBranches) { TF_LITE_MICRO_TEST(TestAllocationForComplexModelAllocation) { const tflite::Model* model = tflite::testing::GetComplexMockModel(); TfLiteEvalTensor* eval_tensors = nullptr; + tflite::ScratchBufferHandle* scratch_buffer_handles = nullptr; tflite::AllOpsResolver op_resolver = tflite::testing::GetOpResolver(); tflite::NodeAndRegistration* node_and_registration; constexpr size_t arena_size = 2048; @@ -401,7 +412,8 @@ TF_LITE_MICRO_TEST(TestAllocationForComplexModelAllocation) { allocator->StartModelAllocation(model, op_resolver, &node_and_registration, &eval_tensors)); TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator->FinishModelAllocation(model, eval_tensors)); + kTfLiteOk, allocator->FinishModelAllocation(model, eval_tensors, + &scratch_buffer_handles)); size_t model_tensor_size = tflite::testing::GetModelTensorCount(model); TF_LITE_MICRO_EXPECT_EQ(static_cast(10), model_tensor_size); @@ -463,6 +475,8 @@ TF_LITE_MICRO_TEST(OfflinePlannerBranchesAllOnline) { nbr_tensors, metadata_buffer, node_list, num_conns); TfLiteEvalTensor* eval_tensors = nullptr; + tflite::ScratchBufferHandle* scratch_buffer_handles = nullptr; + constexpr size_t arena_size = 4096; uint8_t arena[arena_size]; tflite::MicroAllocator* allocator = @@ -473,7 +487,8 @@ TF_LITE_MICRO_TEST(OfflinePlannerBranchesAllOnline) { allocator->StartModelAllocation(model, op_resolver, &node_and_registration, &eval_tensors)); TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator->FinishModelAllocation(model, eval_tensors)); + kTfLiteOk, allocator->FinishModelAllocation(model, eval_tensors, + &scratch_buffer_handles)); // Since all of the tensors are online planned and the model structure is // identical to that in TestAllocationForModelsWithBranches, @@ -524,6 +539,7 @@ TF_LITE_MICRO_TEST(OfflinePlannerBasic) { nbr_tensors, metadata_buffer, node_list, num_conns); TfLiteEvalTensor* eval_tensors = nullptr; + tflite::ScratchBufferHandle* scratch_buffer_handles = nullptr; constexpr size_t arena_size = 4096; uint8_t arena[arena_size]; tflite::MicroAllocator* allocator = @@ -534,7 +550,8 @@ 
TF_LITE_MICRO_TEST(OfflinePlannerBasic) { allocator->StartModelAllocation(model, op_resolver, &node_and_registration, &eval_tensors)); TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator->FinishModelAllocation(model, eval_tensors)); + kTfLiteOk, allocator->FinishModelAllocation(model, eval_tensors, + &scratch_buffer_handles)); uint8_t* start = eval_tensors[0].data.uint8; TF_LITE_MICRO_EXPECT_EQ(0, eval_tensors[0].data.uint8 - start); @@ -577,6 +594,7 @@ TF_LITE_MICRO_TEST(OfflinePlannerOverlappingAllocation) { nbr_tensors, metadata_buffer, node_list, num_conns); TfLiteEvalTensor* eval_tensors = nullptr; + tflite::ScratchBufferHandle* scratch_buffer_handles = nullptr; constexpr size_t arena_size = 4096; uint8_t arena[arena_size]; tflite::MicroAllocator* allocator = @@ -587,7 +605,8 @@ TF_LITE_MICRO_TEST(OfflinePlannerOverlappingAllocation) { allocator->StartModelAllocation(model, op_resolver, &node_and_registration, &eval_tensors)); TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator->FinishModelAllocation(model, eval_tensors)); + kTfLiteOk, allocator->FinishModelAllocation(model, eval_tensors, + &scratch_buffer_handles)); uint8_t* start = eval_tensors[0].data.uint8; TF_LITE_MICRO_EXPECT_EQ(0, eval_tensors[0].data.uint8 - start); @@ -633,6 +652,7 @@ TF_LITE_MICRO_TEST(OfflinePlannerOfflineOnline) { nbr_tensors, metadata_buffer, node_list, num_conns); TfLiteEvalTensor* eval_tensors = nullptr; + tflite::ScratchBufferHandle* scratch_buffer_handles = nullptr; constexpr size_t arena_size = 4096; uint8_t arena[arena_size]; tflite::MicroAllocator* allocator = @@ -643,7 +663,8 @@ TF_LITE_MICRO_TEST(OfflinePlannerOfflineOnline) { allocator->StartModelAllocation(model, op_resolver, &node_and_registration, &eval_tensors)); TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator->FinishModelAllocation(model, eval_tensors)); + kTfLiteOk, allocator->FinishModelAllocation(model, eval_tensors, + &scratch_buffer_handles)); uint8_t* start = eval_tensors[0].data.uint8; TF_LITE_MICRO_EXPECT_EQ(0, eval_tensors[0].data.uint8 - start); @@ -775,6 +796,7 @@ TF_LITE_MICRO_TEST(TestOperatorInputsNotInSubgraphInputs) { 1 /* only first tensor (t0) is in subgraph input list*/); TfLiteEvalTensor* eval_tensors = nullptr; + tflite::ScratchBufferHandle* scratch_buffer_handles = nullptr; constexpr size_t arena_size = 4096; uint8_t arena[arena_size]; tflite::MicroAllocator* allocator = @@ -785,7 +807,8 @@ TF_LITE_MICRO_TEST(TestOperatorInputsNotInSubgraphInputs) { allocator->StartModelAllocation(model, op_resolver, &node_and_registration, &eval_tensors)); TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator->FinishModelAllocation(model, eval_tensors)); + kTfLiteOk, allocator->FinishModelAllocation(model, eval_tensors, + &scratch_buffer_handles)); uint8_t* start = eval_tensors[0].data.uint8; TF_LITE_MICRO_EXPECT_EQ(0, eval_tensors[0].data.uint8 - start); diff --git a/tensorflow/lite/micro/micro_interpreter.cc b/tensorflow/lite/micro/micro_interpreter.cc index a17e40fe0c7..8b003d8b829 100644 --- a/tensorflow/lite/micro/micro_interpreter.cc +++ b/tensorflow/lite/micro/micro_interpreter.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -59,31 +59,13 @@ TfLiteStatus ContextHelper::RequestScratchBufferInArena(TfLiteContext* ctx, size_t bytes, int* buffer_idx) { ContextHelper* helper = reinterpret_cast(ctx->impl_); - - // We can not forward the scratch buffer request to the allocator yet, - // otherwise the scratch buffer handles will ruin the data in `temp` section. - // These requests will be processed once the `temp` section is deallocated, - // i.e. after a node has been prepared. - - if (helper->scratch_buffer_count_ >= kMaxScratchBuffersPerOp) { - TF_LITE_REPORT_ERROR( - helper->error_reporter_, - "Node %d is allocating too many scratch buffers per op, max=%d", - helper->current_node_idx_, helper->scratch_buffer_count_); - } - helper->scrach_buffer_sizes_[helper->scratch_buffer_count_] = bytes; - // buffer_idx is 0 indexed. - *buffer_idx = helper->scratch_buffer_count_ + - helper->allocator_->GetScratchBufferCount(); - helper->scratch_buffer_count_++; - return kTfLiteOk; + return helper->allocator_->RequestScratchBufferInArena(bytes, buffer_idx); } void* ContextHelper::GetScratchBuffer(TfLiteContext* ctx, int buffer_idx) { ContextHelper* helper = reinterpret_cast(ctx->impl_); - - return helper->allocator_->GetScratchBuffer(helper->scratch_buffer_handles_, - buffer_idx); + ScratchBufferHandle* handle = helper->scratch_buffer_handles_ + buffer_idx; + return handle->data; } void ContextHelper::ReportOpError(struct TfLiteContext* context, @@ -110,37 +92,13 @@ TfLiteEvalTensor* ContextHelper::GetEvalTensor( return &helper->eval_tensors_[tensor_idx]; } -void ContextHelper::SetNodeIndex(int idx) { - if (scratch_buffer_count_ != 0) { - TF_LITE_REPORT_ERROR(error_reporter_, - "Internal error: Please commit scratch buffers " - "befrore moving to the next node"); - } - current_node_idx_ = idx; -} - void ContextHelper::SetTfLiteEvalTensors(TfLiteEvalTensor* eval_tensors) { eval_tensors_ = eval_tensors; } -void ContextHelper::SetScratchBufferHandles(void* scratch_buffer_handle) { - scratch_buffer_handles_ = scratch_buffer_handle; -} - -TfLiteStatus ContextHelper::CommitScratchBuffers() { - size_t initial_buffer_count = allocator_->GetScratchBufferCount(); - for (size_t i = 0; i < scratch_buffer_count_; i++) { - int buffer_id; - allocator_->RequestScratchBufferInArena( - current_node_idx_, scrach_buffer_sizes_[i], &buffer_id); - if (static_cast(buffer_id) != initial_buffer_count + i) { - TF_LITE_REPORT_ERROR( - error_reporter_, - "Internal error. Scratch buffers are not contiguous.\n"); - } - } - scratch_buffer_count_ = 0; - return kTfLiteOk; +void ContextHelper::SetScratchBufferHandles( + ScratchBufferHandle* scratch_buffer_handles) { + scratch_buffer_handles_ = scratch_buffer_handles; } } // namespace internal @@ -304,7 +262,6 @@ TfLiteStatus MicroInterpreter::AllocateTensors() { context_.GetScratchBuffer = nullptr; for (size_t i = 0; i < subgraph_->operators()->size(); ++i) { - context_helper_.SetNodeIndex(i); auto* node = &(node_and_registrations_[i].node); auto* registration = node_and_registrations_[i].registration; size_t init_data_size; @@ -321,15 +278,12 @@ TfLiteStatus MicroInterpreter::AllocateTensors() { registration->init(&context_, init_data, init_data_size); } } - context_helper_.SetNodeIndex(-1); // Both AllocatePersistentBuffer and RequestScratchBufferInArena is // available in Prepare stage. context_.RequestScratchBufferInArena = context_helper_.RequestScratchBufferInArena; for (size_t i = 0; i < subgraph_->operators()->size(); ++i) { - // Set node idx to annotate the lifetime for scratch buffers. 
- context_helper_.SetNodeIndex(i); auto* node = &(node_and_registrations_[i].node); auto* registration = node_and_registrations_[i].registration; if (registration->prepare) { @@ -342,10 +296,8 @@ TfLiteStatus MicroInterpreter::AllocateTensors() { return kTfLiteError; } } - allocator_.ResetTempAllocations(); - context_helper_.CommitScratchBuffers(); + allocator_.FinishPrepareNodeAllocations(/*node_id=*/i); } - context_helper_.SetNodeIndex(-1); // Prepare is done, we're ready for Invoke. Memory allocation is no longer // allowed. Kernels can only fetch scratch buffers via GetScratchBuffer. @@ -353,12 +305,12 @@ TfLiteStatus MicroInterpreter::AllocateTensors() { context_.RequestScratchBufferInArena = nullptr; context_.GetScratchBuffer = context_helper_.GetScratchBuffer; - void* scratch_buffer_handles = nullptr; - TF_LITE_ENSURE_OK(&context_, allocator_.FinishModelAllocation(model_, eval_tensors_, - &scratch_buffer_handles)); - context_helper_.SetScratchBufferHandles(scratch_buffer_handles); + &scratch_buffer_handles_)); + // TODO(b/16157777): Remove this when ContextHelper is rolled into this class. + context_helper_.SetScratchBufferHandles(scratch_buffer_handles_); + TF_LITE_ENSURE_STATUS(ResetVariableTensors()); tensors_allocated_ = true; @@ -456,8 +408,8 @@ TfLiteTensor* MicroInterpreter::output(size_t index) { outputs().Get(index)); } if (output_tensor_ == nullptr) { - // TODO(b/160894903): This API will allocate TfLiteTensor structs from - // persistent (tail) memory and cache on this pointer. + // TODO(b/162311891): Drop these allocations when the interpreter supports + // handling buffers from TfLiteEvalTensor. output_tensor_ = allocator_.AllocatePersistentTfLiteTensor( model_, eval_tensors_, outputs().Get(index)); } diff --git a/tensorflow/lite/micro/micro_interpreter.h b/tensorflow/lite/micro/micro_interpreter.h index f36d9d80f96..31720c8e82e 100644 --- a/tensorflow/lite/micro/micro_interpreter.h +++ b/tensorflow/lite/micro/micro_interpreter.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,8 +32,6 @@ namespace tflite { namespace internal { -constexpr size_t kMaxScratchBuffersPerOp = 8; - // A helper class to encapsulate the implementation of APIs in Context. // context->impl_ points to an instance of this class. // Check tensorflow/lite/c/common.h for detailed descriptions. @@ -55,28 +53,19 @@ class ContextHelper { int tensor_idx); static TfLiteEvalTensor* GetEvalTensor(const struct TfLiteContext* context, int tensor_idx); - // Commits all scratch buffer allocations to MicroAllocator. - TfLiteStatus CommitScratchBuffers(); - - // Sets the current node index to assist with scratch buffer allocations: - void SetNodeIndex(int idx); // Sets the pointer to a list of TfLiteEvalTensor instances. void SetTfLiteEvalTensors(TfLiteEvalTensor* eval_tensors); - // Sets the pointer to scratch buffer handle, which is needed by - // `GetScratchBuffer`. - void SetScratchBufferHandles(void* scratch_buffer_handle); + + // Sets the pointer to a list of ScratchBufferHandle instances. 
+ void SetScratchBufferHandles(ScratchBufferHandle* scratch_buffer_handles); private: MicroAllocator* allocator_ = nullptr; ErrorReporter* error_reporter_ = nullptr; const Model* model_ = nullptr; TfLiteEvalTensor* eval_tensors_ = nullptr; - void* scratch_buffer_handles_ = nullptr; - int current_node_idx_ = -1; - - size_t scrach_buffer_sizes_[kMaxScratchBuffersPerOp]; - size_t scratch_buffer_count_ = 0; + ScratchBufferHandle* scratch_buffer_handles_ = nullptr; }; } // namespace internal @@ -204,12 +193,15 @@ class MicroInterpreter { TfLiteStatus initialization_status_; - const SubGraph* subgraph_; - TfLiteEvalTensor* eval_tensors_; + const SubGraph* subgraph_ = nullptr; + TfLiteEvalTensor* eval_tensors_ = nullptr; + ScratchBufferHandle* scratch_buffer_handles_ = nullptr; + + // TODO(b/16157777): Drop this reference: internal::ContextHelper context_helper_; - // TODO(b/160894903): Clean these pointers up when all APIs are updated to new - // TfLiteEvalTensor buffers. + // TODO(b/162311891): Clean these pointers up when this class supports buffers + // from TfLiteEvalTensor. TfLiteTensor* input_tensor_; TfLiteTensor* output_tensor_; }; diff --git a/tensorflow/lite/micro/micro_interpreter_test.cc b/tensorflow/lite/micro/micro_interpreter_test.cc index 0d8b61a532a..07cc4cabf4b 100644 --- a/tensorflow/lite/micro/micro_interpreter_test.cc +++ b/tensorflow/lite/micro/micro_interpreter_test.cc @@ -400,11 +400,15 @@ TF_LITE_MICRO_TEST(TestIncompleteInitializationAllocationsWithSmallArena) { // Interpreter fails because arena is too small: TF_LITE_MICRO_EXPECT_EQ(interpreter.Invoke(), kTfLiteError); + // The head will have some allocations because scratch buffer requests are + // stored in the head until memory plan is fully committed (e.g. model has to + // successfully allocate first). + TF_LITE_MICRO_EXPECT_EQ( + static_cast(128), + allocator->GetSimpleMemoryAllocator()->GetHeadUsedBytes()); + // Ensure allocations are zero (ignore tail since some internal structs are // initialized with this space): - TF_LITE_MICRO_EXPECT_EQ( - static_cast(0), - allocator->GetSimpleMemoryAllocator()->GetHeadUsedBytes()); TF_LITE_MICRO_EXPECT_EQ( static_cast(0), allocator diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h index de1bfd00472..0175c8dbd6a 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/op_macros.h" #include "tensorflow/lite/micro/compatibility.h" +#include "tensorflow/lite/micro/kernels/fully_connected.h" #include "tensorflow/lite/micro/kernels/micro_ops.h" #include "tensorflow/lite/micro/micro_op_resolver.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -150,8 +151,7 @@ class MicroMutableOpResolver : public MicroOpResolver { } TfLiteStatus AddConv2D() { - return AddBuiltin(BuiltinOperator_CONV_2D, - tflite::ops::micro::Register_CONV_2D(), ParseConv2D); + return AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D(), ParseConv2D); } TfLiteStatus AddCos() { @@ -161,8 +161,7 @@ class MicroMutableOpResolver : public MicroOpResolver { TfLiteStatus AddDepthwiseConv2D() { return AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, - tflite::ops::micro::Register_DEPTHWISE_CONV_2D(), - ParseDepthwiseConv2D); + Register_DEPTHWISE_CONV_2D(), ParseDepthwiseConv2D); } TfLiteStatus AddDequantize() { @@ -181,9 +180,9 @@ class MicroMutableOpResolver : public MicroOpResolver { tflite::ops::micro::Register_FLOOR(), ParseFloor); } - TfLiteStatus AddFullyConnected() { - return AddBuiltin(BuiltinOperator_FULLY_CONNECTED, - tflite::ops::micro::Register_FULLY_CONNECTED(), + TfLiteStatus AddFullyConnected( + const TfLiteRegistration& registration = Register_FULLY_CONNECTED()) { + return AddBuiltin(BuiltinOperator_FULLY_CONNECTED, registration, ParseFullyConnected); } @@ -305,8 +304,8 @@ class MicroMutableOpResolver : public MicroOpResolver { } TfLiteStatus AddQuantize() { - return AddBuiltin(BuiltinOperator_QUANTIZE, - tflite::ops::micro::Register_QUANTIZE(), ParseQuantize); + return AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE(), + ParseQuantize); } TfLiteStatus AddReduceMax() { @@ -345,14 +344,18 @@ class MicroMutableOpResolver : public MicroOpResolver { tflite::ops::micro::Register_RSQRT(), ParseRsqrt); } + TfLiteStatus AddShape() { + return AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE(), ParseShape); + } + TfLiteStatus AddSin() { return AddBuiltin(BuiltinOperator_SIN, tflite::ops::micro::Register_SIN(), ParseSin); } TfLiteStatus AddSoftmax() { - return AddBuiltin(BuiltinOperator_SOFTMAX, - tflite::ops::micro::Register_SOFTMAX(), ParseSoftmax); + return AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX(), + ParseSoftmax); } TfLiteStatus AddSplit() { @@ -387,8 +390,7 @@ class MicroMutableOpResolver : public MicroOpResolver { } TfLiteStatus AddSvdf() { - return AddBuiltin(BuiltinOperator_SVDF, tflite::ops::micro::Register_SVDF(), - ParseSvdf); + return AddBuiltin(BuiltinOperator_SVDF, Register_SVDF(), ParseSvdf); } TfLiteStatus AddTanh() { @@ -404,6 +406,8 @@ class MicroMutableOpResolver : public MicroOpResolver { unsigned int GetRegistrationLength() { return registrations_len_; } private: + TF_LITE_REMOVE_VIRTUAL_DELETE + TfLiteStatus AddBuiltin(tflite::BuiltinOperator op, const TfLiteRegistration& registration, MicroOpResolver::BuiltinParseFunction parser) { @@ -459,8 +463,6 @@ class MicroMutableOpResolver : public MicroOpResolver { unsigned int num_buitin_ops_ = 0; ErrorReporter* error_reporter_; - - TF_LITE_REMOVE_VIRTUAL_DELETE }; }; // namespace tflite diff --git a/tensorflow/lite/micro/recording_micro_allocator.cc b/tensorflow/lite/micro/recording_micro_allocator.cc index 7e11523fea0..6bb52974ebc 100644 --- a/tensorflow/lite/micro/recording_micro_allocator.cc +++ b/tensorflow/lite/micro/recording_micro_allocator.cc @@ -54,6 +54,8 @@ RecordedAllocation 
RecordingMicroAllocator::GetRecordedAllocation( return recorded_persistent_tflite_tensor_data_; case RecordedAllocationType::kPersistentTfLiteTensorQuantizationData: return recorded_persistent_tflite_tensor_quantization_data_; + case RecordedAllocationType::kPersistentBufferData: + return recorded_persistent_buffer_data_; case RecordedAllocationType::kTfLiteTensorVariableBufferData: return recorded_tflite_tensor_variable_buffer_data_; case RecordedAllocationType::kNodeAndRegistrationArray: @@ -91,6 +93,8 @@ void RecordingMicroAllocator::PrintAllocations() const { PrintRecordedAllocation( RecordedAllocationType::kPersistentTfLiteTensorQuantizationData, "Persistent TfLiteTensor quantization data", "allocations"); + PrintRecordedAllocation(RecordedAllocationType::kPersistentBufferData, + "Persistent buffer data", "allocations"); PrintRecordedAllocation( RecordedAllocationType::kTfLiteTensorVariableBufferData, "TfLiteTensor variable buffer data", "allocations"); @@ -101,17 +105,27 @@ void RecordingMicroAllocator::PrintAllocations() const { "Operator runtime data", "OpData structs"); } +void* RecordingMicroAllocator::AllocatePersistentBuffer(size_t bytes) { + RecordedAllocation allocations = SnapshotAllocationUsage(); + void* buffer = MicroAllocator::AllocatePersistentBuffer(bytes); + RecordAllocationUsage(allocations, recorded_persistent_buffer_data_); + + return buffer; +} + void RecordingMicroAllocator::PrintRecordedAllocation( RecordedAllocationType allocation_type, const char* allocation_name, const char* allocation_description) const { #ifndef TF_LITE_STRIP_ERROR_STRINGS RecordedAllocation allocation = GetRecordedAllocation(allocation_type); - TF_LITE_REPORT_ERROR( - error_reporter(), - "[RecordingMicroAllocator] '%s' used %d bytes with alignment overhead " - "(requested %d bytes for %d %s)", - allocation_name, allocation.used_bytes, allocation.requested_bytes, - allocation.count, allocation_description); + if (allocation.used_bytes > 0 || allocation.requested_bytes > 0) { + TF_LITE_REPORT_ERROR( + error_reporter(), + "[RecordingMicroAllocator] '%s' used %d bytes with alignment overhead " + "(requested %d bytes for %d %s)", + allocation_name, allocation.used_bytes, allocation.requested_bytes, + allocation.count, allocation_description); + } #endif } diff --git a/tensorflow/lite/micro/recording_micro_allocator.h b/tensorflow/lite/micro/recording_micro_allocator.h index 9243fec12e5..47246e17540 100644 --- a/tensorflow/lite/micro/recording_micro_allocator.h +++ b/tensorflow/lite/micro/recording_micro_allocator.h @@ -24,10 +24,12 @@ namespace tflite { // List of buckets currently recorded by this class. Each type keeps a list of // allocated information during model initialization. +// TODO(b/169834511): Add tracking for scratch buffer allocations. enum class RecordedAllocationType { kTfLiteEvalTensorData, kPersistentTfLiteTensorData, kPersistentTfLiteTensorQuantizationData, + kPersistentBufferData, kTfLiteTensorVariableBufferData, kNodeAndRegistrationArray, kOpData, @@ -66,6 +68,8 @@ class RecordingMicroAllocator : public MicroAllocator { // defined in RecordedAllocationType. 
void PrintAllocations() const; + void* AllocatePersistentBuffer(size_t bytes) override; + protected: TfLiteStatus AllocateNodeAndRegistrations( const Model* model, @@ -77,12 +81,12 @@ class RecordingMicroAllocator : public MicroAllocator { const Model* model, TfLiteEvalTensor** eval_tensors) override; TfLiteStatus AllocateVariables(const SubGraph* subgraph, TfLiteEvalTensor* eval_tensors) override; - // TODO(b/160894903): Once all kernels have been updated to the new API drop + // TODO(b/162311891): Once all kernels have been updated to the new API drop // this method. It is only used to record TfLiteTensor persistent allocations. TfLiteTensor* AllocatePersistentTfLiteTensorInternal( const Model* model, TfLiteEvalTensor* eval_tensors, int tensor_index) override; - // TODO(b/160894903): Once all kernels have been updated to the new API drop + // TODO(b/162311891): Once all kernels have been updated to the new API drop // this function since all allocations for quantized data will take place in // the temp section. TfLiteStatus PopulateTfLiteTensorFromFlatbuffer(const Model* model, @@ -108,6 +112,7 @@ class RecordingMicroAllocator : public MicroAllocator { RecordedAllocation recorded_tflite_eval_tensor_data_ = {}; RecordedAllocation recorded_persistent_tflite_tensor_data_ = {}; RecordedAllocation recorded_persistent_tflite_tensor_quantization_data_ = {}; + RecordedAllocation recorded_persistent_buffer_data_ = {}; RecordedAllocation recorded_tflite_tensor_variable_buffer_data_ = {}; RecordedAllocation recorded_node_and_registration_array_data_ = {}; RecordedAllocation recorded_op_data_ = {}; diff --git a/tensorflow/lite/micro/recording_micro_allocator_test.cc b/tensorflow/lite/micro/recording_micro_allocator_test.cc index f46bd29abdd..6854341de90 100644 --- a/tensorflow/lite/micro/recording_micro_allocator_test.cc +++ b/tensorflow/lite/micro/recording_micro_allocator_test.cc @@ -36,6 +36,7 @@ TF_LITE_MICRO_TESTS_BEGIN TF_LITE_MICRO_TEST(TestRecordsTfLiteEvalTensorArrayData) { TfLiteEvalTensor* eval_tensors = nullptr; + tflite::ScratchBufferHandle* scratch_buffer_handles = nullptr; tflite::AllOpsResolver all_ops_resolver; tflite::NodeAndRegistration* node_and_registration; const tflite::Model* model = tflite::GetModel(kTestConvModelData); @@ -55,7 +56,8 @@ TF_LITE_MICRO_TEST(TestRecordsTfLiteEvalTensorArrayData) { TF_LITE_MICRO_EXPECT_EQ(status, kTfLiteOk); if (status != kTfLiteOk) return 1; - status = micro_allocator->FinishModelAllocation(model, eval_tensors); + status = micro_allocator->FinishModelAllocation(model, eval_tensors, + &scratch_buffer_handles); TF_LITE_MICRO_EXPECT_EQ(status, kTfLiteOk); if (status != kTfLiteOk) return 1; @@ -78,6 +80,7 @@ TF_LITE_MICRO_TEST(TestRecordsTfLiteEvalTensorArrayData) { TF_LITE_MICRO_TEST(TestRecordsNodeAndRegistrationArrayData) { TfLiteEvalTensor* eval_tensors = nullptr; + tflite::ScratchBufferHandle* scratch_buffer_handles = nullptr; tflite::AllOpsResolver all_ops_resolver; tflite::NodeAndRegistration* node_and_registration; const tflite::Model* model = tflite::GetModel(kTestConvModelData); @@ -95,7 +98,8 @@ TF_LITE_MICRO_TEST(TestRecordsNodeAndRegistrationArrayData) { TF_LITE_MICRO_EXPECT_EQ(status, kTfLiteOk); if (status != kTfLiteOk) return 1; - status = micro_allocator->FinishModelAllocation(model, eval_tensors); + status = micro_allocator->FinishModelAllocation(model, eval_tensors, + &scratch_buffer_handles); TF_LITE_MICRO_EXPECT_EQ(status, kTfLiteOk); if (status != kTfLiteOk) return 1; @@ -112,6 +116,7 @@ 
TF_LITE_MICRO_TEST(TestRecordsNodeAndRegistrationArrayData) { TF_LITE_MICRO_TEST(TestRecordsMultiTenantAllocations) { TfLiteEvalTensor* eval_tensors = nullptr; + tflite::ScratchBufferHandle* scratch_buffer_handles = nullptr; tflite::AllOpsResolver all_ops_resolver; tflite::NodeAndRegistration* node_and_registration; const tflite::Model* model = tflite::GetModel(kTestConvModelData); @@ -133,7 +138,8 @@ TF_LITE_MICRO_TEST(TestRecordsMultiTenantAllocations) { TF_LITE_MICRO_EXPECT_EQ(status, kTfLiteOk); if (status != kTfLiteOk) return 1; - status = micro_allocator->FinishModelAllocation(model, eval_tensors); + status = micro_allocator->FinishModelAllocation(model, eval_tensors, + &scratch_buffer_handles); TF_LITE_MICRO_EXPECT_EQ(status, kTfLiteOk); if (status != kTfLiteOk) return 1; @@ -143,8 +149,8 @@ TF_LITE_MICRO_TEST(TestRecordsMultiTenantAllocations) { TF_LITE_MICRO_EXPECT_EQ(status, kTfLiteOk); if (status != kTfLiteOk) return 1; - status = kTfLiteOk, - micro_allocator->FinishModelAllocation(model, eval_tensors); + status = kTfLiteOk, micro_allocator->FinishModelAllocation( + model, eval_tensors, &scratch_buffer_handles); TF_LITE_MICRO_EXPECT_EQ(status, kTfLiteOk); if (status != kTfLiteOk) return 1; @@ -233,6 +239,43 @@ TF_LITE_MICRO_TEST(TestRecordsPersistentTfLiteTensorQuantizationData) { expected_requested_bytes); } +TF_LITE_MICRO_TEST(TestRecordsPersistentBufferData) { + uint8_t arena[kTestConvArenaSize]; + + tflite::RecordingMicroAllocator* micro_allocator = + tflite::RecordingMicroAllocator::Create(arena, kTestConvArenaSize, + micro_test::reporter); + TF_LITE_MICRO_EXPECT_NE(micro_allocator, nullptr); + if (micro_allocator == nullptr) return 1; + + void* buffer = micro_allocator->AllocatePersistentBuffer(/*bytes=*/100); + TF_LITE_MICRO_EXPECT_NE(buffer, nullptr); + if (buffer == nullptr) return 1; + + tflite::RecordedAllocation recorded_allocation = + micro_allocator->GetRecordedAllocation( + tflite::RecordedAllocationType::kPersistentBufferData); + + TF_LITE_MICRO_EXPECT_EQ(recorded_allocation.count, static_cast(1)); + TF_LITE_MICRO_EXPECT_EQ(recorded_allocation.requested_bytes, + static_cast(100)); + TF_LITE_MICRO_EXPECT_GE(recorded_allocation.used_bytes, + static_cast(100)); + + buffer = micro_allocator->AllocatePersistentBuffer(/*bytes=*/50); + TF_LITE_MICRO_EXPECT_NE(buffer, nullptr); + if (buffer == nullptr) return 1; + + recorded_allocation = micro_allocator->GetRecordedAllocation( + tflite::RecordedAllocationType::kPersistentBufferData); + + TF_LITE_MICRO_EXPECT_EQ(recorded_allocation.count, static_cast(2)); + TF_LITE_MICRO_EXPECT_EQ(recorded_allocation.requested_bytes, + static_cast(150)); + TF_LITE_MICRO_EXPECT_GE(recorded_allocation.used_bytes, + static_cast(150)); +} + // TODO(b/158124094): Find a way to audit OpData allocations on // cross-architectures. 
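Putting the pieces above together: persistent buffer requests are now tracked in their own recording bucket, and FinishModelAllocation() additionally returns the committed scratch buffer handles. A condensed usage sketch, assuming an arena and the test reporter as in the tests above:

    tflite::ScratchBufferHandle* scratch_buffer_handles = nullptr;
    tflite::RecordingMicroAllocator* allocator =
        tflite::RecordingMicroAllocator::Create(arena, kTestConvArenaSize,
                                                micro_test::reporter);
    // Persistent buffer requests land in the new kPersistentBufferData bucket:
    void* buffer = allocator->AllocatePersistentBuffer(/*bytes=*/100);
    tflite::RecordedAllocation recorded = allocator->GetRecordedAllocation(
        tflite::RecordedAllocationType::kPersistentBufferData);
    // recorded.count == 1, recorded.requested_bytes == 100 and
    // recorded.used_bytes >= 100 (alignment overhead is included).
    // After model allocation, the third argument receives the committed
    // scratch buffer handles:
    // allocator->FinishModelAllocation(model, eval_tensors,
    //                                  &scratch_buffer_handles);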
diff --git a/tensorflow/lite/micro/recording_simple_memory_allocator.cc b/tensorflow/lite/micro/recording_simple_memory_allocator.cc index ef2e9f31664..ef30aca4949 100644 --- a/tensorflow/lite/micro/recording_simple_memory_allocator.cc +++ b/tensorflow/lite/micro/recording_simple_memory_allocator.cc @@ -57,12 +57,13 @@ size_t RecordingSimpleMemoryAllocator::GetAllocatedCount() const { return alloc_count_; } -TfLiteStatus RecordingSimpleMemoryAllocator::EnsureHeadSize(size_t size, - size_t alignment) { - const uint8_t* previous_head = GetHead(); - TfLiteStatus status = SimpleMemoryAllocator::EnsureHeadSize(size, alignment); +TfLiteStatus RecordingSimpleMemoryAllocator::SetHeadBufferSize( + size_t size, size_t alignment) { + const uint8_t* previous_head = head(); + TfLiteStatus status = + SimpleMemoryAllocator::SetHeadBufferSize(size, alignment); if (status == kTfLiteOk) { - used_bytes_ += GetHead() - previous_head; + used_bytes_ += head() - previous_head; requested_head_bytes_ = size; } return status; @@ -70,10 +71,10 @@ TfLiteStatus RecordingSimpleMemoryAllocator::EnsureHeadSize(size_t size, uint8_t* RecordingSimpleMemoryAllocator::AllocateFromTail(size_t size, size_t alignment) { - const uint8_t* previous_tail = GetTail(); + const uint8_t* previous_tail = tail(); uint8_t* result = SimpleMemoryAllocator::AllocateFromTail(size, alignment); if (result != nullptr) { - used_bytes_ += previous_tail - GetTail(); + used_bytes_ += previous_tail - tail(); requested_tail_bytes_ += size; alloc_count_++; } diff --git a/tensorflow/lite/micro/recording_simple_memory_allocator.h b/tensorflow/lite/micro/recording_simple_memory_allocator.h index 8d3e9fb49d4..3526716e3b4 100644 --- a/tensorflow/lite/micro/recording_simple_memory_allocator.h +++ b/tensorflow/lite/micro/recording_simple_memory_allocator.h @@ -47,7 +47,7 @@ class RecordingSimpleMemoryAllocator : public SimpleMemoryAllocator { // Returns the number of alloc calls from the head or tail. 
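Because the overrides above snapshot head()/tail() before and after each call, the recording allocator can report both what was asked for and what was actually consumed once alignment is applied. A brief sketch, assuming a small caller-provided arena:

    uint8_t arena[256];
    tflite::RecordingSimpleMemoryAllocator allocator(micro_test::reporter,
                                                     arena, sizeof(arena));
    allocator.AllocateFromTail(/*size=*/10, /*alignment=*/16);
    size_t requested = allocator.GetRequestedBytes();  // exactly 10
    size_t used = allocator.GetUsedBytes();            // >= 10, with padding
    size_t count = allocator.GetAllocatedCount();      // 1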
size_t GetAllocatedCount() const; - TfLiteStatus EnsureHeadSize(size_t size, size_t alignment) override; + TfLiteStatus SetHeadBufferSize(size_t size, size_t alignment) override; uint8_t* AllocateFromTail(size_t size, size_t alignment) override; private: diff --git a/tensorflow/lite/micro/recording_simple_memory_allocator_test.cc b/tensorflow/lite/micro/recording_simple_memory_allocator_test.cc index 6450cb53cac..910c991978d 100644 --- a/tensorflow/lite/micro/recording_simple_memory_allocator_test.cc +++ b/tensorflow/lite/micro/recording_simple_memory_allocator_test.cc @@ -84,7 +84,7 @@ TF_LITE_MICRO_TEST(TestRecordsHeadSizeAdjustment) { arena_size); TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator.EnsureHeadSize(/*size=*/5, /*alignment=*/1)); + kTfLiteOk, allocator.SetHeadBufferSize(/*size=*/5, /*alignment=*/1)); TF_LITE_MICRO_EXPECT_EQ(allocator.GetUsedBytes(), static_cast(5)); TF_LITE_MICRO_EXPECT_EQ(allocator.GetRequestedBytes(), static_cast(5)); @@ -108,7 +108,7 @@ TF_LITE_MICRO_TEST(TestRecordsMisalignedHeadSizeAdjustments) { arena_size); TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator.EnsureHeadSize(/*size=*/10, /*alignment=*/12)); + kTfLiteOk, allocator.SetHeadBufferSize(/*size=*/10, /*alignment=*/12)); // Validate used bytes in 8 byte range that can included alignment of 12: TF_LITE_MICRO_EXPECT_GE(allocator.GetUsedBytes(), static_cast(10)); TF_LITE_MICRO_EXPECT_LE(allocator.GetUsedBytes(), static_cast(20)); @@ -125,8 +125,8 @@ TF_LITE_MICRO_TEST(TestDoesNotRecordFailedTailAllocations) { tflite::RecordingSimpleMemoryAllocator allocator(micro_test::reporter, arena, arena_size); - TF_LITE_MICRO_EXPECT_EQ( - kTfLiteError, allocator.EnsureHeadSize(/*size=*/2048, /*alignment=*/1)); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, allocator.SetHeadBufferSize( + /*size=*/2048, /*alignment=*/1)); TF_LITE_MICRO_EXPECT_EQ(allocator.GetUsedBytes(), static_cast(0)); TF_LITE_MICRO_EXPECT_EQ(allocator.GetRequestedBytes(), static_cast(0)); diff --git a/tensorflow/lite/micro/simple_memory_allocator.cc b/tensorflow/lite/micro/simple_memory_allocator.cc index bea1a9d7175..08b6789e911 100644 --- a/tensorflow/lite/micro/simple_memory_allocator.cc +++ b/tensorflow/lite/micro/simple_memory_allocator.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -60,27 +60,22 @@ SimpleMemoryAllocator* SimpleMemoryAllocator::Create( SimpleMemoryAllocator::~SimpleMemoryAllocator() {} -TfLiteStatus SimpleMemoryAllocator::EnsureHeadSize(size_t size, - size_t alignment) { +TfLiteStatus SimpleMemoryAllocator::SetHeadBufferSize(size_t size, + size_t alignment) { if (head_ != temp_) { TF_LITE_REPORT_ERROR( error_reporter_, - "Internal error: EnsureHeadSize() needs to be called after" - "ResetTempAllocations()."); + "Internal error: SetHeadBufferSize() needs to be called " + "after ResetTempAllocations()."); return kTfLiteError; } uint8_t* const aligned_result = AlignPointerUp(buffer_head_, alignment); - if (aligned_result + size < head_) { - // Size is below the current head size, just return. - return kTfLiteOk; - } - const size_t available_memory = tail_ - aligned_result; if (available_memory < size) { TF_LITE_REPORT_ERROR( error_reporter_, - "Failed to adjust head size. Requested: %u, available %u, missing: %u", + "Failed to set head size. 
Requested: %u, available %u, missing: %u", size, available_memory, size - available_memory); return kTfLiteError; } @@ -123,11 +118,7 @@ uint8_t* SimpleMemoryAllocator::AllocateTemp(size_t size, size_t alignment) { void SimpleMemoryAllocator::ResetTempAllocations() { temp_ = head_; } -uint8_t* SimpleMemoryAllocator::GetHead() const { return head_; } - -uint8_t* SimpleMemoryAllocator::GetBufferHead() const { return buffer_head_; } - -uint8_t* SimpleMemoryAllocator::GetTail() const { return tail_; } +uint8_t* SimpleMemoryAllocator::GetHeadBuffer() const { return buffer_head_; } size_t SimpleMemoryAllocator::GetHeadUsedBytes() const { return head_ - buffer_head_; @@ -138,17 +129,21 @@ size_t SimpleMemoryAllocator::GetTailUsedBytes() const { } size_t SimpleMemoryAllocator::GetAvailableMemory(size_t alignment) const { - uint8_t* const aligned_head = AlignPointerUp(head_, alignment); + uint8_t* const aligned_temp = AlignPointerUp(temp_, alignment); uint8_t* const aligned_tail = AlignPointerDown(tail_, alignment); - return aligned_tail - aligned_head; + return aligned_tail - aligned_temp; } size_t SimpleMemoryAllocator::GetUsedBytes() const { - return GetBufferSize() - (tail_ - head_); + return GetBufferSize() - (tail_ - temp_); } size_t SimpleMemoryAllocator::GetBufferSize() const { return buffer_tail_ - buffer_head_; } +uint8_t* SimpleMemoryAllocator::head() const { return head_; } + +uint8_t* SimpleMemoryAllocator::tail() const { return tail_; } + } // namespace tflite diff --git a/tensorflow/lite/micro/simple_memory_allocator.h b/tensorflow/lite/micro/simple_memory_allocator.h index 8c216f47848..35adaf1ad24 100644 --- a/tensorflow/lite/micro/simple_memory_allocator.h +++ b/tensorflow/lite/micro/simple_memory_allocator.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -43,13 +43,13 @@ class SimpleMemoryAllocator { uint8_t* buffer_head, size_t buffer_size); - // Ensure that the head (lowest address and moving upwards) memory allocation - // is at least a given size. This function will only increase the head size if - // the passed in value is larger than the current head size. Calls to this - // method will also invalidate all temporary allocation values. This call will - // fail if a chain of allocations through AllocateTemp() have not been cleaned - // up with a call to ResetTempAllocations(). - virtual TfLiteStatus EnsureHeadSize(size_t size, size_t alignment); + // Adjusts the head (lowest address and moving upwards) memory allocation to a + // given size. Calls to this method will also invalidate all temporary + // allocation values (it sets the location of temp space at the end of the + // head section). This call will fail if a chain of allocations through + // AllocateTemp() have not been cleaned up with a call to + // ResetTempAllocations(). + virtual TfLiteStatus SetHeadBufferSize(size_t size, size_t alignment); // Allocates memory starting at the tail of the arena (highest address and // moving downwards). @@ -69,18 +69,31 @@ class SimpleMemoryAllocator { // arena (lowest address). virtual void ResetTempAllocations(); - uint8_t* GetHead() const; - uint8_t* GetBufferHead() const; - uint8_t* GetTail() const; + // Returns a pointer to the buffer currently assigned to the head section. + // This buffer is set by calling SetHeadSize(). 
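In contrast to the old EnsureHeadSize(), SetHeadBufferSize() sets the head section to exactly the requested size (it can shrink it as well as grow it), and it may only be called once outstanding temporary allocations have been cleared. A small sketch of the expected call order, assuming a caller-provided arena:

    uint8_t arena[1024];
    tflite::SimpleMemoryAllocator allocator(micro_test::reporter, arena,
                                            sizeof(arena));
    allocator.AllocateTemp(/*size=*/32, /*alignment=*/4);
    // Calling SetHeadBufferSize() here would fail with kTfLiteError.
    allocator.ResetTempAllocations();
    allocator.SetHeadBufferSize(/*size=*/100, /*alignment=*/4);  // head is 100
    allocator.SetHeadBufferSize(/*size=*/10, /*alignment=*/4);   // shrinks to 10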
+ uint8_t* GetHeadBuffer() const; + // Returns the size of the head section in bytes. size_t GetHeadUsedBytes() const; + + // Returns the size of all allocations in the tail section in bytes. size_t GetTailUsedBytes() const; - // Returns the number of bytes available with a given alignment. + // Returns the number of bytes available with a given alignment. This number + // takes in account any temporary allocations. size_t GetAvailableMemory(size_t alignment) const; + // Returns the number of used bytes in the allocator. This number takes in + // account any temporary allocations. size_t GetUsedBytes() const; + protected: + // Returns a pointer to the current end of the head buffer. + uint8_t* head() const; + + // Returns a pointer to the current end of the tail buffer. + uint8_t* tail() const; + private: size_t GetBufferSize() const; diff --git a/tensorflow/lite/micro/simple_memory_allocator_test.cc b/tensorflow/lite/micro/simple_memory_allocator_test.cc index adffc9566da..eea7a7fad86 100644 --- a/tensorflow/lite/micro/simple_memory_allocator_test.cc +++ b/tensorflow/lite/micro/simple_memory_allocator_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,20 +28,20 @@ TF_LITE_MICRO_TEST(TestEnsureHeadSizeSimpleAlignment) { tflite::SimpleMemoryAllocator allocator(micro_test::reporter, arena, arena_size); - // First head adjustment TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator.EnsureHeadSize(/*size=*/100, /*alignment=*/1)); - TF_LITE_MICRO_EXPECT(arena + 100 == allocator.GetHead()); + kTfLiteOk, allocator.SetHeadBufferSize(/*size=*/100, /*alignment=*/1)); + TF_LITE_MICRO_EXPECT_EQ(static_cast(100), + allocator.GetHeadUsedBytes()); - // Second head adjusment is smaller, head size should still be 100. 
TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator.EnsureHeadSize(/*size=*/10, /*alignment=*/1)); - TF_LITE_MICRO_EXPECT(arena + 100 == allocator.GetHead()); + kTfLiteOk, allocator.SetHeadBufferSize(/*size=*/10, /*alignment=*/1)); + TF_LITE_MICRO_EXPECT_EQ(static_cast(10), + allocator.GetHeadUsedBytes()); - // Third head adjustment re-increases the head size: TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator.EnsureHeadSize(/*size=*/1000, /*alignment=*/1)); - TF_LITE_MICRO_EXPECT(arena + 1000 == allocator.GetHead()); + kTfLiteOk, allocator.SetHeadBufferSize(/*size=*/1000, /*alignment=*/1)); + TF_LITE_MICRO_EXPECT_EQ(static_cast(1000), + allocator.GetHeadUsedBytes()); } TF_LITE_MICRO_TEST(TestAdjustHeadSizeMisalignment) { @@ -52,25 +52,22 @@ TF_LITE_MICRO_TEST(TestAdjustHeadSizeMisalignment) { // First head adjustment of 100 bytes (aligned 12): TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator.EnsureHeadSize(/*size=*/100, /*alignment=*/12)); + kTfLiteOk, allocator.SetHeadBufferSize(/*size=*/100, /*alignment=*/12)); // Offset alignment of 12 can lead to allocation within 8 byte range of // requested bytes based to arena alignment at runtime: - TF_LITE_MICRO_EXPECT_GE(allocator.GetHead(), arena + 100); - TF_LITE_MICRO_EXPECT_LE(allocator.GetHead(), arena + 100 + 11); + TF_LITE_MICRO_EXPECT_GE(allocator.GetHeadUsedBytes(), 100); + TF_LITE_MICRO_EXPECT_LE(allocator.GetHeadUsedBytes(), 100 + 11); - // Second head adjusment shrinks the head size (aligned at 12), head size - // should still be 100: TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator.EnsureHeadSize(/*size=*/10, /*alignment=*/12)); - TF_LITE_MICRO_EXPECT_GE(allocator.GetHead(), arena + 100); - TF_LITE_MICRO_EXPECT_LE(allocator.GetHead(), arena + 100 + 11); + kTfLiteOk, allocator.SetHeadBufferSize(/*size=*/10, /*alignment=*/12)); + TF_LITE_MICRO_EXPECT_GE(allocator.GetHeadUsedBytes(), 10); + TF_LITE_MICRO_EXPECT_LE(allocator.GetHeadUsedBytes(), 100 + 11); - // Third head adjustment re-increases the head size (aligned at 12): TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator.EnsureHeadSize(/*size=*/1000, /*alignment=*/12)); - TF_LITE_MICRO_EXPECT_GE(allocator.GetHead(), arena + 1000); - TF_LITE_MICRO_EXPECT_LE(allocator.GetHead(), arena + 1000 + 11); + kTfLiteOk, allocator.SetHeadBufferSize(/*size=*/1000, /*alignment=*/12)); + TF_LITE_MICRO_EXPECT_GE(allocator.GetHeadUsedBytes(), 1000); + TF_LITE_MICRO_EXPECT_LE(allocator.GetHeadUsedBytes(), 1000 + 11); } TF_LITE_MICRO_TEST(TestAdjustHeadSizeMisalignedHandlesCorrectBytesAvailable) { @@ -81,7 +78,7 @@ TF_LITE_MICRO_TEST(TestAdjustHeadSizeMisalignedHandlesCorrectBytesAvailable) { // First head adjustment of 100 bytes (aligned 12): TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator.EnsureHeadSize(/*size=*/100, /*alignment=*/12)); + kTfLiteOk, allocator.SetHeadBufferSize(/*size=*/100, /*alignment=*/12)); // allocator.GetAvailableMemory() should also report the actual amount of // memory available based on a requested offset (12): @@ -90,23 +87,92 @@ TF_LITE_MICRO_TEST(TestAdjustHeadSizeMisalignedHandlesCorrectBytesAvailable) { TF_LITE_MICRO_EXPECT_LE(aligned_available_bytes, arena_size - 100); TF_LITE_MICRO_EXPECT_GE(aligned_available_bytes, arena_size - 100 - 24); - // Second head adjusment shrinks the head size (aligned at 12), head size - // should still be 100: TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator.EnsureHeadSize(/*size=*/10, /*alignment=*/12)); + kTfLiteOk, allocator.SetHeadBufferSize(/*size=*/10, /*alignment=*/12)); aligned_available_bytes = allocator.GetAvailableMemory(/*alignment=*/12); 
- TF_LITE_MICRO_EXPECT_LE(aligned_available_bytes, arena_size - 100); - TF_LITE_MICRO_EXPECT_GE(aligned_available_bytes, arena_size - 100 - 24); + TF_LITE_MICRO_EXPECT_LE(aligned_available_bytes, arena_size - 10); + TF_LITE_MICRO_EXPECT_GE(aligned_available_bytes, arena_size - 10 - 24); - // Third head adjustment re-increases the head size (aligned at 12): TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, allocator.EnsureHeadSize(/*size=*/1000, /*alignment=*/12)); + kTfLiteOk, allocator.SetHeadBufferSize(/*size=*/1000, /*alignment=*/12)); aligned_available_bytes = allocator.GetAvailableMemory(/*alignment=*/12); TF_LITE_MICRO_EXPECT_LE(aligned_available_bytes, arena_size - 1000); TF_LITE_MICRO_EXPECT_GE(aligned_available_bytes, arena_size - 1000 - 24); } +TF_LITE_MICRO_TEST(TestGetAvailableMemory) { + constexpr size_t arena_size = 1024; + uint8_t arena[arena_size]; + tflite::SimpleMemoryAllocator allocator(micro_test::reporter, arena, + arena_size); + + constexpr size_t allocation_size = 100; + allocator.SetHeadBufferSize(/*size=*/allocation_size, + /*alignment=*/1); + allocator.AllocateFromTail(/*size=*/allocation_size, + /*alignment=*/1); + + TF_LITE_MICRO_EXPECT_EQ(allocator.GetAvailableMemory(/*alignment=*/1), + arena_size - allocation_size * 2); +} + +TF_LITE_MICRO_TEST(TestGetAvailableMemoryWithTempAllocations) { + constexpr size_t arena_size = 1024; + uint8_t arena[arena_size]; + tflite::SimpleMemoryAllocator allocator(micro_test::reporter, arena, + arena_size); + + constexpr size_t allocation_size = 100; + allocator.AllocateTemp(/*size=*/allocation_size, + /*alignment=*/1); + + TF_LITE_MICRO_EXPECT_EQ(allocator.GetAvailableMemory(/*alignment=*/1), + arena_size - allocation_size); + + // Reset temp allocations and ensure GetAvailableMemory() is back to the + // starting size: + allocator.ResetTempAllocations(); + + TF_LITE_MICRO_EXPECT_EQ(allocator.GetAvailableMemory(/*alignment=*/1), + arena_size); +} + +TF_LITE_MICRO_TEST(TestGetUsedBytes) { + constexpr size_t arena_size = 1024; + uint8_t arena[arena_size]; + tflite::SimpleMemoryAllocator allocator(micro_test::reporter, arena, + arena_size); + TF_LITE_MICRO_EXPECT_EQ(allocator.GetUsedBytes(), static_cast(0)); + + constexpr size_t allocation_size = 100; + allocator.SetHeadBufferSize(/*size=*/allocation_size, + /*alignment=*/1); + allocator.AllocateFromTail(/*size=*/allocation_size, + /*alignment=*/1); + + TF_LITE_MICRO_EXPECT_EQ(allocator.GetUsedBytes(), allocation_size * 2); +} + +TF_LITE_MICRO_TEST(TestGetUsedBytesTempAllocations) { + constexpr size_t arena_size = 1024; + uint8_t arena[arena_size]; + tflite::SimpleMemoryAllocator allocator(micro_test::reporter, arena, + arena_size); + + constexpr size_t allocation_size = 100; + allocator.AllocateTemp(/*size=*/allocation_size, + /*alignment=*/1); + + TF_LITE_MICRO_EXPECT_EQ(allocator.GetUsedBytes(), allocation_size); + + // Reset temp allocations and ensure GetUsedBytes() is back to the starting + // size: + allocator.ResetTempAllocations(); + + TF_LITE_MICRO_EXPECT_EQ(allocator.GetUsedBytes(), static_cast(0)); +} + TF_LITE_MICRO_TEST(TestJustFits) { constexpr size_t arena_size = 1024; uint8_t arena[arena_size]; @@ -190,16 +256,16 @@ TF_LITE_MICRO_TEST(TestEnsureHeadSizeWithoutResettingTemp) { // Adjustment to head should fail since temp allocation was not followed by a // call to ResetTempAllocations(). 
- TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, allocator.EnsureHeadSize(100, 1)); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, allocator.SetHeadBufferSize(100, 1)); allocator.ResetTempAllocations(); // Reduce head size back to zero. - TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, allocator.EnsureHeadSize(0, 1)); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, allocator.SetHeadBufferSize(0, 1)); // The most recent head allocation should be in the same location as the // original temp allocation pointer. - TF_LITE_MICRO_EXPECT(temp == allocator.GetHead()); + TF_LITE_MICRO_EXPECT(temp == allocator.GetHeadBuffer()); } TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/test_helpers.cc b/tensorflow/lite/micro/test_helpers.cc index 26575a4d98d..82a57890231 100644 --- a/tensorflow/lite/micro/test_helpers.cc +++ b/tensorflow/lite/micro/test_helpers.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/lite/micro/all_ops_resolver.h" #include "tensorflow/lite/micro/micro_utils.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" namespace tflite { namespace testing { diff --git a/tensorflow/lite/micro/test_helpers.h b/tensorflow/lite/micro/test_helpers.h index d80f70a1bbc..57c6c365662 100644 --- a/tensorflow/lite/micro/test_helpers.h +++ b/tensorflow/lite/micro/test_helpers.h @@ -19,6 +19,7 @@ limitations under the License. // Useful functions for writing tests. #include +#include #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" @@ -192,6 +193,21 @@ TfLiteTensor CreateSymmetricPerChannelQuantizedTensor( // Returns the number of tensors in the default subgraph for a tflite::Model. size_t GetModelTensorCount(const Model* model); +// Derives the quantization scaling factor from a min and max range. +template +inline float ScaleFromMinMax(const float min, const float max) { + return (max - min) / + static_cast((std::numeric_limits::max() * 1.0) - + std::numeric_limits::min()); +} + +// Derives the quantization zero point from a min and max range. +template +inline int ZeroPointFromMinMax(const float min, const float max) { + return static_cast(std::numeric_limits::min()) + + static_cast(-min / ScaleFromMinMax(min, max) + 0.5f); +} + } // namespace testing } // namespace tflite diff --git a/tensorflow/lite/micro/testing/BUILD b/tensorflow/lite/micro/testing/BUILD index 207d500c53d..6b382929923 100644 --- a/tensorflow/lite/micro/testing/BUILD +++ b/tensorflow/lite/micro/testing/BUILD @@ -20,12 +20,8 @@ exports_files(["test_linux_binary.sh"]) cc_library( name = "micro_test", - srcs = [ - "test_utils.cc", - ], hdrs = [ "micro_test.h", - "test_utils.h", ], visibility = [ ":micro", diff --git a/tensorflow/lite/micro/testing/test_utils.cc b/tensorflow/lite/micro/testing/test_utils.cc deleted file mode 100644 index 8c16ed70bbc..00000000000 --- a/tensorflow/lite/micro/testing/test_utils.cc +++ /dev/null @@ -1,240 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
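For reference, the ScaleFromMinMax/ZeroPointFromMinMax templates added to test_helpers.h above implement the usual asymmetric quantization mapping that the deleted test_utils.h used to provide. A worked sketch for uint8, using an assumed example range of [-1.0, 1.0]:

    float scale = tflite::testing::ScaleFromMinMax<uint8_t>(-1.0f, 1.0f);
    // scale = (1.0 - (-1.0)) / (255 - 0) ~ 0.00784
    int zero_point =
        tflite::testing::ZeroPointFromMinMax<uint8_t>(-1.0f, 1.0f);
    // zero_point = 0 + round(1.0 / 0.00784) = 128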
-==============================================================================*/ - -#include "tensorflow/lite/micro/testing/test_utils.h" - -#include "tensorflow/lite/micro/simple_memory_allocator.h" - -namespace tflite { -namespace testing { - -namespace { -// TODO(b/141330728): Refactor out of test_utils.cc -// The variables below (and the AllocatePersistentBuffer function) are only -// needed for the kernel tests and benchmarks, i.e. where we do not have an -// interpreter object, and the fully featured MicroAllocator. -// Currently, these need to be sufficient for all the kernel_tests. If that -// becomes problematic, we can investigate allowing the arena_size to be -// specified for each call to PopulatContext. -constexpr size_t kArenaSize = 10000; -uint8_t raw_arena_[kArenaSize]; -SimpleMemoryAllocator* simple_memory_allocator_ = nullptr; -constexpr size_t kBufferAlignment = 16; - -// We store the pointer to the ith scratch buffer to implement the Request/Get -// ScratchBuffer API for the tests. scratch_buffers_[i] will be the ith scratch -// buffer and will still be allocated from within raw_arena_. -constexpr int kNumScratchBuffers = 5; -uint8_t* scratch_buffers_[kNumScratchBuffers]; -int scratch_buffer_count_ = 0; - -// Note that the context parameter in this function is only needed to match the -// signature of TfLiteContext::AllocatePersistentBuffer and isn't needed in the -// implementation because we are assuming a single global -// simple_memory_allocator_ -void* AllocatePersistentBuffer(TfLiteContext* context, size_t bytes) { - TFLITE_DCHECK(simple_memory_allocator_ != nullptr); - return simple_memory_allocator_->AllocateFromTail(bytes, kBufferAlignment); -} - -TfLiteStatus RequestScratchBufferInArena(TfLiteContext* context, size_t bytes, - int* buffer_index) { - TFLITE_DCHECK(simple_memory_allocator_ != nullptr); - TFLITE_DCHECK(buffer_index != nullptr); - - if (scratch_buffer_count_ == kNumScratchBuffers) { - TF_LITE_REPORT_ERROR( - static_cast(context->impl_), - "Exceeded the maximum number of scratch tensors allowed (%d).", - kNumScratchBuffers); - return kTfLiteError; - } - - // For tests, we allocate scratch buffers from the tail and keep them around - // for the lifetime of model. This means that the arena size in the tests will - // be more than what we would have if the scratch buffers could share memory. - scratch_buffers_[scratch_buffer_count_] = - simple_memory_allocator_->AllocateFromTail(bytes, kBufferAlignment); - TFLITE_DCHECK(scratch_buffers_[scratch_buffer_count_] != nullptr); - - *buffer_index = scratch_buffer_count_++; - return kTfLiteOk; -} - -void* GetScratchBuffer(TfLiteContext* context, int buffer_index) { - TFLITE_DCHECK(scratch_buffer_count_ <= kNumScratchBuffers); - if (buffer_index >= scratch_buffer_count_) { - return nullptr; - } - return scratch_buffers_[buffer_index]; -} - -TfLiteTensor* GetTensor(const struct TfLiteContext* context, int subgraph_idx) { - // TODO(b/160894903): Return this value from temp allocated memory. - return &context->tensors[subgraph_idx]; -} - -} // namespace - -uint8_t F2Q(float value, float min, float max) { - int32_t result = ZeroPointFromMinMax(min, max) + - (value / ScaleFromMinMax(min, max)) + 0.5f; - if (result < std::numeric_limits::min()) { - result = std::numeric_limits::min(); - } - if (result > std::numeric_limits::max()) { - result = std::numeric_limits::max(); - } - return result; -} - -// Converts a float value into a signed eight-bit quantized value. 
-int8_t F2QS(float value, float min, float max) { - return F2Q(value, min, max) + std::numeric_limits::min(); -} - -int32_t F2Q32(float value, float scale) { - double quantized = static_cast(value / scale); - if (quantized > std::numeric_limits::max()) { - quantized = std::numeric_limits::max(); - } else if (quantized < std::numeric_limits::min()) { - quantized = std::numeric_limits::min(); - } - return static_cast(quantized); -} - -// TODO(b/141330728): Move this method elsewhere as part clean up. -void PopulateContext(TfLiteTensor* tensors, int tensors_size, - ErrorReporter* error_reporter, TfLiteContext* context) { - simple_memory_allocator_ = - SimpleMemoryAllocator::Create(error_reporter, raw_arena_, kArenaSize); - TFLITE_DCHECK(simple_memory_allocator_ != nullptr); - scratch_buffer_count_ = 0; - - context->tensors_size = tensors_size; - context->tensors = tensors; - context->impl_ = static_cast(error_reporter); - context->GetExecutionPlan = nullptr; - context->ResizeTensor = nullptr; - context->ReportError = ReportOpError; - context->AddTensors = nullptr; - context->GetNodeAndRegistration = nullptr; - context->ReplaceNodeSubsetsWithDelegateKernels = nullptr; - context->recommended_num_threads = 1; - context->GetExternalContext = nullptr; - context->SetExternalContext = nullptr; - - context->GetTensor = GetTensor; - context->GetEvalTensor = nullptr; - - context->AllocatePersistentBuffer = AllocatePersistentBuffer; - context->RequestScratchBufferInArena = RequestScratchBufferInArena; - context->GetScratchBuffer = GetScratchBuffer; - - for (int i = 0; i < tensors_size; ++i) { - if (context->tensors[i].is_variable) { - ResetVariableTensor(&context->tensors[i]); - } - } -} - -TfLiteTensor CreateQuantizedTensor(const uint8_t* data, TfLiteIntArray* dims, - float min, float max, bool is_variable) { - TfLiteTensor result; - result.type = kTfLiteUInt8; - result.data.uint8 = const_cast(data); - result.dims = dims; - result.params = {ScaleFromMinMax(min, max), - ZeroPointFromMinMax(min, max)}; - result.allocation_type = kTfLiteMemNone; - result.bytes = ElementCount(*dims) * sizeof(uint8_t); - result.is_variable = false; - return result; -} - -TfLiteTensor CreateQuantizedTensor(const int8_t* data, TfLiteIntArray* dims, - float min, float max, bool is_variable) { - TfLiteTensor result; - result.type = kTfLiteInt8; - result.data.int8 = const_cast(data); - result.dims = dims; - result.params = {ScaleFromMinMax(min, max), - ZeroPointFromMinMax(min, max)}; - result.allocation_type = kTfLiteMemNone; - result.bytes = ElementCount(*dims) * sizeof(int8_t); - result.is_variable = is_variable; - return result; -} - -TfLiteTensor CreateQuantizedTensor(const float* data, uint8_t* quantized_data, - TfLiteIntArray* dims, bool is_variable) { - TfLiteTensor result; - SymmetricQuantize(data, dims, quantized_data, &result.params.scale); - result.data.uint8 = quantized_data; - result.type = kTfLiteUInt8; - result.dims = dims; - result.params.zero_point = 128; - result.allocation_type = kTfLiteMemNone; - result.bytes = ElementCount(*dims) * sizeof(uint8_t); - result.is_variable = is_variable; - return result; -} - -TfLiteTensor CreateQuantizedTensor(const float* data, int8_t* quantized_data, - TfLiteIntArray* dims, bool is_variable) { - TfLiteTensor result; - SignedSymmetricQuantize(data, dims, quantized_data, &result.params.scale); - result.data.int8 = quantized_data; - result.type = kTfLiteInt8; - result.dims = dims; - result.params.zero_point = 0; - result.allocation_type = kTfLiteMemNone; - result.bytes = 
ElementCount(*dims) * sizeof(int8_t); - result.is_variable = is_variable; - return result; -} - -TfLiteTensor CreateQuantizedTensor(const float* data, int16_t* quantized_data, - TfLiteIntArray* dims, bool is_variable) { - TfLiteTensor result; - SignedSymmetricQuantize(data, dims, quantized_data, &result.params.scale); - result.data.i16 = quantized_data; - result.type = kTfLiteInt16; - result.dims = dims; - result.params.zero_point = 0; - result.allocation_type = kTfLiteMemNone; - result.bytes = ElementCount(*dims) * sizeof(int16_t); - result.is_variable = is_variable; - return result; -} - -TfLiteTensor CreateQuantized32Tensor(const int32_t* data, TfLiteIntArray* dims, - float scale, bool is_variable) { - TfLiteTensor result; - result.type = kTfLiteInt32; - result.data.i32 = const_cast(data); - result.dims = dims; - // Quantized int32_t tensors always have a zero point of 0, since the range of - // int32_t values is large, and because zero point costs extra cycles during - // processing. - result.params = {scale, 0}; - result.allocation_type = kTfLiteMemNone; - result.bytes = ElementCount(*dims) * sizeof(int32_t); - result.is_variable = is_variable; - return result; -} - -} // namespace testing -} // namespace tflite diff --git a/tensorflow/lite/micro/testing/test_utils.h b/tensorflow/lite/micro/testing/test_utils.h deleted file mode 100644 index 1b5f5c64979..00000000000 --- a/tensorflow/lite/micro/testing/test_utils.h +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_LITE_MICRO_TESTING_TEST_UTILS_H_ -#define TENSORFLOW_LITE_MICRO_TESTING_TEST_UTILS_H_ - -#include -#include -#include - -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/core/api/tensor_utils.h" -#include "tensorflow/lite/micro/micro_utils.h" -#include "tensorflow/lite/micro/test_helpers.h" -#include "tensorflow/lite/micro/testing/micro_test.h" - -namespace tflite { -namespace testing { - -// Note: These methods are deprecated, do not use. See b/141332970. - -// Derives the quantization range max from scaling factor and zero point. -template -inline float MaxFromZeroPointScale(const int zero_point, const float scale) { - return (std::numeric_limits::max() - zero_point) * scale; -} - -// Derives the quantization range min from scaling factor and zero point. -template -inline float MinFromZeroPointScale(const int zero_point, const float scale) { - return (std::numeric_limits::min() - zero_point) * scale; -} - -// Derives the quantization scaling factor from a min and max range. -template -inline float ScaleFromMinMax(const float min, const float max) { - return (max - min) / - static_cast((std::numeric_limits::max() * 1.0) - - std::numeric_limits::min()); -} - -// Derives the quantization zero point from a min and max range. 
-template -inline int ZeroPointFromMinMax(const float min, const float max) { - return static_cast(std::numeric_limits::min()) + - static_cast(-min / ScaleFromMinMax(min, max) + 0.5f); -} - -// Converts a float value into an unsigned eight-bit quantized value. -uint8_t F2Q(float value, float min, float max); - -// Converts a float value into a signed eight-bit quantized value. -int8_t F2QS(const float value, const float min, const float max); - -// Converts a float value into a signed thirty-two-bit quantized value. Note -// that values close to max int and min int may see significant error due to -// a lack of floating point granularity for large values. -int32_t F2Q32(const float value, const float scale); - -// TODO(b/141330728): Move this method elsewhere as part clean up. -void PopulateContext(TfLiteTensor* tensors, int tensors_size, - ErrorReporter* error_reporter, TfLiteContext* context); - -TfLiteTensor CreateQuantizedTensor(const uint8_t* data, TfLiteIntArray* dims, - float min, float max, - bool is_variable = false); - -TfLiteTensor CreateQuantizedTensor(const int8_t* data, TfLiteIntArray* dims, - float min, float max, - bool is_variable = false); - -TfLiteTensor CreateQuantizedTensor(const float* data, uint8_t* quantized_data, - TfLiteIntArray* dims, - bool is_variable = false); - -TfLiteTensor CreateQuantizedTensor(const float* data, int8_t* quantized_data, - TfLiteIntArray* dims, - bool is_variable = false); - -TfLiteTensor CreateQuantizedTensor(const float* data, int16_t* quantized_data, - TfLiteIntArray* dims, - bool is_variable = false); - -TfLiteTensor CreateQuantized32Tensor(const int32_t* data, TfLiteIntArray* dims, - float scale, bool is_variable = false); - -template -inline TfLiteTensor CreateTensor(const input_type* data, TfLiteIntArray* dims, - bool is_variable = false) { - TfLiteTensor result; - result.type = tensor_input_type; - result.data.raw = reinterpret_cast(const_cast(data)); - result.dims = dims; - result.allocation_type = kTfLiteMemNone; - result.bytes = ElementCount(*dims) * sizeof(input_type); - result.is_variable = is_variable; - return result; -} - -} // namespace testing -} // namespace tflite - -#endif // TENSORFLOW_LITE_MICRO_TESTING_TEST_UTILS_H_ diff --git a/tensorflow/lite/micro/testing/util_test.cc b/tensorflow/lite/micro/testing/util_test.cc index 261e9f29a25..1720c81ceaf 100644 --- a/tensorflow/lite/micro/testing/util_test.cc +++ b/tensorflow/lite/micro/testing/util_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" TF_LITE_MICRO_TESTS_BEGIN diff --git a/tensorflow/lite/micro/testing_helpers_test.cc b/tensorflow/lite/micro/testing_helpers_test.cc index 885bd873b53..7ef669fb9a5 100644 --- a/tensorflow/lite/micro/testing_helpers_test.cc +++ b/tensorflow/lite/micro/testing_helpers_test.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" -#include "tensorflow/lite/micro/testing/test_utils.h" TF_LITE_MICRO_TESTS_BEGIN diff --git a/tensorflow/lite/micro/tools/ci_build/test_all.sh b/tensorflow/lite/micro/tools/ci_build/test_all.sh index e79d0d4d1ad..354d26d9102 100755 --- a/tensorflow/lite/micro/tools/ci_build/test_all.sh +++ b/tensorflow/lite/micro/tools/ci_build/test_all.sh @@ -52,4 +52,7 @@ tensorflow/lite/micro/tools/ci_build/test_stm32f4.sh PRESUBMIT echo "Running Arduino tests at `date`" tensorflow/lite/micro/tools/ci_build/test_arduino.sh +echo "Running cortex_m_gcc_generic tests at `date`" +tensorflow/lite/micro/tools/ci_build/test_cortex_m_gcc_generic.sh + echo "Finished all micro tests at `date`" diff --git a/tensorflow/lite/micro/tools/ci_build/test_cortex_m_gcc_generic.sh b/tensorflow/lite/micro/tools/ci_build/test_cortex_m_gcc_generic.sh new file mode 100755 index 00000000000..596c88965e7 --- /dev/null +++ b/tensorflow/lite/micro/tools/ci_build/test_cortex_m_gcc_generic.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Tests the microcontroller code using a Cortex-M4/M4F platform. + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR=${SCRIPT_DIR}/../../../../.. +cd "${ROOT_DIR}" + +source tensorflow/lite/micro/tools/ci_build/helper_functions.sh + +TARGET=cortex_m_gcc_generic + +# TODO(b/143715361): downloading first to allow for parallel builds. 
+readable_run make -f tensorflow/lite/micro/tools/make/Makefile TAGS=cmsis-nn TARGET=${TARGET} CORTEX_M_CORE=M4F third_party_downloads + +# Build for Cortex-M4 (no FPU) without CMSIS +readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean +readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} CORTEX_M_CORE=M4 microlite + +# Build for Cortex-M4F (FPU present) without CMSIS +readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean +readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} CORTEX_M_CORE=M4F microlite + +# Build for Cortex-M4 (no FPU) with CMSIS +readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean +readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TAGS=cmsis-nn TARGET=${TARGET} CORTEX_M_CORE=M4 microlite + +# Build for Cortex-M4 (FPU present) with CMSIS +readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean +readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TAGS=cmsis-nn TARGET=${TARGET} CORTEX_M_CORE=M4F microlite diff --git a/tensorflow/lite/micro/tools/ci_build/test_sparkfun.sh b/tensorflow/lite/micro/tools/ci_build/test_sparkfun.sh index 87ff8537311..88367cd9c11 100755 --- a/tensorflow/lite/micro/tools/ci_build/test_sparkfun.sh +++ b/tensorflow/lite/micro/tools/ci_build/test_sparkfun.sh @@ -24,10 +24,13 @@ cd "${ROOT_DIR}" source tensorflow/lite/micro/tools/ci_build/helper_functions.sh -readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean - TARGET=sparkfun_edge # TODO(b/143715361): downloading first to allow for parallel builds. readable_run make -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} third_party_downloads + +readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} build + +readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean +readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} TAGS=cmsis-nn build diff --git a/tensorflow/lite/micro/tools/ci_build/test_x86.sh b/tensorflow/lite/micro/tools/ci_build/test_x86.sh index dfd779a485b..844dccbafb7 100755 --- a/tensorflow/lite/micro/tools/ci_build/test_x86.sh +++ b/tensorflow/lite/micro/tools/ci_build/test_x86.sh @@ -37,6 +37,13 @@ readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile BUILD_TYPE=no readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile BUILD_TYPE=release build +# Next, build with TAGS=posix. See b/170223867 for more details. +# TODO(b/168123200): Since the benchmarks can not currently be run as a test, we +# only build with TAGS=posix. Once benchmarks can be run (and are added to make +# test) we should switch to running all the tests with TAGS=posix. +readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean +readable_run make -s -j8 -f tensorflow/lite/micro/tools/make/Makefile build TAGS=posix + # Next, build w/o release so that we can run the tests and get additional # debugging info on failures. 
readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index f7f2de4f726..fa1aa3f0baf 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -90,17 +90,48 @@ TAG_DEFINES := $(foreach TAG,$(TAGS),-D$(shell echo $(TAG) | tr [a-z] [A-Z] | tr OPTIMIZATION_LEVEL := -O3 -CC_WARNINGS := -Werror -Wsign-compare -Wdouble-promotion \ - -Wshadow -Wunused-variable -Wmissing-field-initializers \ - -Wunused-function -Wswitch -Wvla -# TODO(b/150240249): Add in -fno-rtti once that works for the Xtensa toolchain. -CXXFLAGS := -std=c++11 -Wstrict-aliasing -DTF_LITE_STATIC_MEMORY \ - $(CC_WARNINGS) $(OPTIMIZATION_LEVEL) $(TAG_DEFINES) -CCFLAGS := -DTF_LITE_STATIC_MEMORY $(CC_WARNINGS) $(OPTIMIZATION_LEVEL) \ - $(TAG_DEFINES) +CC_WARNINGS := \ + -Werror \ + -Wsign-compare \ + -Wdouble-promotion \ + -Wshadow \ + -Wunused-variable \ + -Wmissing-field-initializers \ + -Wunused-function \ + -Wswitch \ + -Wvla \ + -Wall \ + -Wextra \ + -Wstrict-aliasing \ + -Wno-unused-parameter + +COMMON_FLAGS := \ + -fno-unwind-tables \ + -ffunction-sections \ + -fdata-sections \ + -fmessage-length=0 \ + -DTF_LITE_STATIC_MEMORY \ + -DTF_LITE_DISABLE_X86_NEON \ + $(OPTIMIZATION_LEVEL) \ + $(CC_WARNINGS) \ + $(TAG_DEFINES) + +CXXFLAGS := \ + -std=c++11 \ + -fno-rtti \ + -fno-exceptions \ + -fno-threadsafe-statics \ + $(COMMON_FLAGS) + +CCFLAGS := \ + -std=c11 \ + $(COMMON_FLAGS) + ARFLAGS := -r -LDFLAGS += -Wl,--fatal-warnings +LDFLAGS += \ + -Wl,--fatal-warnings \ + -Wl,--gc-sections # override these in the makefile.inc for specific compiler targets TARGET_TOOLCHAIN_PREFIX := @@ -194,7 +225,7 @@ tensorflow/lite/core/api/op_resolver.cc \ tensorflow/lite/core/api/tensor_utils.cc \ tensorflow/lite/kernels/internal/quantization_util.cc \ tensorflow/lite/kernels/kernel_util.cc \ -tensorflow/lite/micro/testing/test_utils.cc +tensorflow/lite/schema/schema_utils.cc MICROLITE_CC_SRCS := $(filter-out $(MICROLITE_TEST_SRCS), $(MICROLITE_CC_BASE_SRCS)) MICROLITE_CC_SRCS := $(filter-out $(MICROLITE_BENCHMARK_SRCS), $(MICROLITE_CC_SRCS)) @@ -270,6 +301,7 @@ tensorflow/lite/kernels/op_macros.h \ tensorflow/lite/kernels/padding.h \ tensorflow/lite/portable_type_to_tflitetype.h \ tensorflow/lite/schema/schema_generated.h \ +tensorflow/lite/schema/schema_utils.h \ tensorflow/lite/version.h # TODO(b/165940489): Figure out how to avoid including fixed point @@ -323,11 +355,18 @@ $(eval $(call add_third_party_download,$(RUY_URL),$(RUY_MD5),ruy,)) $(eval $(call add_third_party_download,$(PERSON_MODEL_URL),$(PERSON_MODEL_MD5),person_model_grayscale,)) $(eval $(call add_third_party_download,$(PERSON_MODEL_INT8_URL),$(PERSON_MODEL_INT8_MD5),person_model_int8,)) -# These target-specific makefiles should modify or replace options like -# CXXFLAGS or LIBS to work for a specific targeted architecture. All logic -# based on platforms or architectures should happen within these files, to -# keep this main makefile focused on the sources and dependencies. -include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc) +# The target-specific makefile must have a name that is exactly +# TARGET_makefile.inc and is only needed for cross-compilation (i.e. when TARGET +# is different from the HOST_OS). +ifneq ($(TARGET),$(HOST_OS)) + # The arduino is also special in that it does not have an arduino_makefile but + # the TARGET=arduino is still used to create a directory for the generated + # artifacts. 
We are using a workaround right now and will be separating the + # project generation from the Makefile in the future. + ifneq ($(TARGET),arduino) + include $(MAKEFILE_DIR)/targets/$(TARGET)_makefile.inc + endif +endif # Load dependencies for optimized kernel implementations. include $(wildcard $(MAKEFILE_DIR)/ext_libs/*.inc) diff --git a/tensorflow/lite/micro/tools/make/download_and_extract.sh b/tensorflow/lite/micro/tools/make/download_and_extract.sh index e72fd7a0184..da98ab1b042 100755 --- a/tensorflow/lite/micro/tools/make/download_and_extract.sh +++ b/tensorflow/lite/micro/tools/make/download_and_extract.sh @@ -72,6 +72,7 @@ patch_am_sdk() { patch_kissfft() { sed -i -E $'s@#ifdef FIXED_POINT@// Patched automatically by download_dependencies.sh so default is 16 bit.\\\n#ifndef FIXED_POINT\\\n#define FIXED_POINT (16)\\\n#endif\\\n// End patch.\\\n\\\n#ifdef FIXED_POINT@g' tensorflow/lite/micro/tools/make/downloads/kissfft/kiss_fft.h + sed -i -E '/^#include /d' tensorflow/lite/micro/tools/make/downloads/kissfft/kiss_fft.h # Fix for https://github.com/mborgerding/kissfft/issues/20 sed -i -E $'s@#ifdef FIXED_POINT@#ifdef FIXED_POINT\\\n#include /* Patched. */@g' tensorflow/lite/micro/tools/make/downloads/kissfft/kiss_fft.h @@ -114,6 +115,77 @@ patch_cmsis() { -iname '*.*' -exec \ sed -i -E $'s@#include "arm_nn_tables.h"@#include "cmsis/CMSIS/NN/Include/arm_nn_tables.h"@g' {} \; + find tensorflow/lite/micro/tools/make/downloads/cmsis \ + -iname '*.*' -exec \ + sed -i -E $'s@#include "arm_math_types.h"@#include "cmsis/CMSIS/DSP/Include/arm_math_types.h"@g' {} \; + + find tensorflow/lite/micro/tools/make/downloads/cmsis \ + -iname '*.*' -exec \ + sed -i -E $'s@#include "arm_math_memory.h"@#include "cmsis/CMSIS/DSP/Include/arm_math_memory.h"@g' {} \; + + find tensorflow/lite/micro/tools/make/downloads/cmsis \ + -iname '*.*' -exec \ + sed -i -E $'s@#include "dsp/none.h"@#include "cmsis/CMSIS/DSP/Include/dsp/none.h"@g' {} \; + + find tensorflow/lite/micro/tools/make/downloads/cmsis \ + -iname '*.*' -exec \ + sed -i -E $'s@#include "dsp/utils.h"@#include "cmsis/CMSIS/DSP/Include/dsp/utils.h"@g' {} \; + + find tensorflow/lite/micro/tools/make/downloads/cmsis \ + -iname '*.*' -exec \ + sed -i -E $'s@#include "dsp/basic_math_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/basic_math_functions.h"@g' {} \; + find tensorflow/lite/micro/tools/make/downloads/cmsis \ + -iname '*.*' -exec \ + sed -i -E $'s@#include "dsp/statistics_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/statistics_functions.h"@g' {} \; + + find tensorflow/lite/micro/tools/make/downloads/cmsis \ + -iname '*.*' -exec \ + sed -i -E $'s@#include "dsp/svm_defines.h"@#include "cmsis/CMSIS/DSP/Include/dsp/svm_defines.h"@g' {} \; + + find tensorflow/lite/micro/tools/make/downloads/cmsis \ + -iname '*.*' -exec \ + sed -i -E $'s@#include "dsp/svm_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/svm_functions.h"@g' {} \; + + find tensorflow/lite/micro/tools/make/downloads/cmsis \ + -iname '*.*' -exec \ + sed -i -E $'s@#include "dsp/support_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/support_functions.h"@g' {} \; + + find tensorflow/lite/micro/tools/make/downloads/cmsis \ + -iname '*.*' -exec \ + sed -i -E $'s@#include "dsp/transform_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/transform_functions.h"@g' {} \; + + find tensorflow/lite/micro/tools/make/downloads/cmsis \ + -iname '*.*' -exec \ + sed -i -E $'s@#include "dsp/bayes_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/bayes_functions.h"@g' {} \; + + find 
tensorflow/lite/micro/tools/make/downloads/cmsis \ + -iname '*.*' -exec \ + sed -i -E $'s@#include "dsp/complex_math_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/complex_math_functions.h"@g' {} \; + + find tensorflow/lite/micro/tools/make/downloads/cmsis \ + -iname '*.*' -exec \ + sed -i -E $'s@#include "dsp/controller_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/controller_functions.h"@g' {} \; + + find tensorflow/lite/micro/tools/make/downloads/cmsis \ + -iname '*.*' -exec \ + sed -i -E $'s@#include "dsp/distance_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/distance_functions.h"@g' {} \; + + find tensorflow/lite/micro/tools/make/downloads/cmsis \ + -iname '*.*' -exec \ + sed -i -E $'s@#include "dsp/fast_math_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/fast_math_functions.h"@g' {} \; + + find tensorflow/lite/micro/tools/make/downloads/cmsis \ + -iname '*.*' -exec \ + sed -i -E $'s@#include "dsp/filtering_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/filtering_functions.h"@g' {} \; + + find tensorflow/lite/micro/tools/make/downloads/cmsis \ + -iname '*.*' -exec \ + sed -i -E $'s@#include "dsp/interpolation_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/interpolation_functions.h"@g' {} \; + + find tensorflow/lite/micro/tools/make/downloads/cmsis \ + -iname '*.*' -exec \ + sed -i -E $'s@#include "dsp/matrix_functions.h"@#include "cmsis/CMSIS/DSP/Include/dsp/matrix_functions.h"@g' {} \; + # Until the fix for https://github.com/ARMmbed/mbed-os/issues/12568 is # rolled into Mbed version used on the Arduino IDE, we have to replace # one intrinsic with a patched equivalent. diff --git a/tensorflow/lite/micro/tools/make/ext_libs/cmsis.inc b/tensorflow/lite/micro/tools/make/ext_libs/cmsis.inc index 8bb0d58bad1..b988dac0c94 100644 --- a/tensorflow/lite/micro/tools/make/ext_libs/cmsis.inc +++ b/tensorflow/lite/micro/tools/make/ext_libs/cmsis.inc @@ -26,6 +26,7 @@ ifneq ($(filter cmsis-nn,$(ALL_TAGS)),) $(CMSIS_PATH)/CMSIS/NN/Source/FullyConnectedFunctions/arm_fully_connected_mat_q7_vec_q15.c \ $(CMSIS_PATH)/CMSIS/NN/Source/FullyConnectedFunctions/arm_fully_connected_q7.c \ $(CMSIS_PATH)/CMSIS/NN/Source/FullyConnectedFunctions/arm_fully_connected_mat_q7_vec_q15_opt.c \ + $(CMSIS_PATH)/CMSIS/NN/Source/SVDFunctions/arm_svdf_s8.c \ $(CMSIS_PATH)/CMSIS/NN/Source/ConcatenationFunctions/arm_concatenation_s8_z.c \ $(CMSIS_PATH)/CMSIS/NN/Source/ConcatenationFunctions/arm_concatenation_s8_x.c \ $(CMSIS_PATH)/CMSIS/NN/Source/ConcatenationFunctions/arm_concatenation_s8_w.c \ @@ -84,6 +85,7 @@ ifneq ($(filter cmsis-nn,$(ALL_TAGS)),) $(CMSIS_PATH)/CMSIS/NN/Source/SoftmaxFunctions/arm_softmax_u8.c \ $(CMSIS_PATH)/CMSIS/NN/Source/SoftmaxFunctions/arm_softmax_q7.c \ $(CMSIS_PATH)/CMSIS/NN/Source/SoftmaxFunctions/arm_softmax_with_batch_q7.c \ + $(CMSIS_PATH)/CMSIS/NN/Source/SVDFunctions/arm_svdf_s8.c \ $(CMSIS_PATH)/CMSIS/NN/Source/ReshapeFunctions/arm_reshape_s8.c # Cherry-picked list of headers that are needed to compile the CMSIS-NN @@ -96,6 +98,34 @@ ifneq ($(filter cmsis-nn,$(ALL_TAGS)),) $(CMSIS_PATH)CMSIS/NN/Include/arm_nn_tables.h \ $(CMSIS_PATH)CMSIS/NN/Include/arm_nn_types.h \ $(CMSIS_PATH)CMSIS/DSP/Include/arm_math.h \ - $(CMSIS_PATH)CMSIS/DSP/Include/arm_common_tables.h + $(CMSIS_PATH)CMSIS/DSP/Include/arm_common_tables.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/arm_math_types.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/arm_math_memory.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/none.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/utils.h \ + 
$(CMSIS_PATH)CMSIS/DSP/Include/dsp/basic_math_functions.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/statistics_functions.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/svm_defines.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/svm_functions.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/support_functions.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/transform_functions.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/bayes_functions.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/complex_math_functions.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/controller_functions.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/distance_functions.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/fast_math_functions.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/filtering_functions.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/interpolation_functions.h \ + $(CMSIS_PATH)CMSIS/DSP/Include/dsp/matrix_functions.h + + + # Need to add the CMSIS Core includes path. + # All other CMSIS header files are included with their relative path + # in the CMSIS-NN micro kernel source files in + # tensorflow/lite/micro/kernels/cmsis-nn + INCLUDES += \ + -I$(CMSIS_PATH)/CMSIS/Core/Include \ + -I$(CMSIS_PATH)/CMSIS/DSP/Include \ + -I$(CMSIS_PATH)/CMSIS/NN/Include endif diff --git a/tensorflow/lite/micro/tools/make/helper_functions.inc b/tensorflow/lite/micro/tools/make/helper_functions.inc index a2127e4bdca..6ff5fd1e3d5 100644 --- a/tensorflow/lite/micro/tools/make/helper_functions.inc +++ b/tensorflow/lite/micro/tools/make/helper_functions.inc @@ -459,6 +459,7 @@ define microlite_test ifeq (,$(findstring _test, $(1))) $(eval $(call generate_project_third_party_parsing)) endif + $(1)_LOCAL_SRCS := $(2) $(1)_LOCAL_SRCS := $$(call specialize,$$($(1)_LOCAL_SRCS)) ALL_SRCS += $$($(1)_LOCAL_SRCS) @@ -476,10 +477,23 @@ $(1)_bin: $$($(1)_BINARY).bin test_$(1): $$($(1)_BINARY) @test -f $$(TEST_SCRIPT) || (echo 'Unable to find the test script. Is the software emulation available in $$(TARGET)?'; exit 1) $$(TEST_SCRIPT) $$($(1)_BINARY) '~~~ALL TESTS PASSED~~~' + ifneq (,$(findstring _test,$(1))) MICROLITE_TEST_TARGETS += test_$(1) MICROLITE_BUILD_TARGETS += $$($(1)_BINARY) endif + +# The ifneq can make it seem that the body of the if block is executed when +# _benchmark is not found in $(1). Actually, the check is saying that if +# findstring does not return empty, i.e. if _benchmark is found in $(1), we +# should add something to the MICROLITE_BUILD_TARGETS. +# +# This ensures that a `make build` command will build all the tests and +# benchmarks, though `make test` will only run the tests. +ifneq (,$(findstring _benchmark,$(1))) + MICROLITE_BUILD_TARGETS += $$($(1)_BINARY) +endif + $(eval $(call generate_microlite_projects,$(1),$(call specialize,$(2)),$(3))) endef diff --git a/tensorflow/lite/micro/tools/make/targets/apollo3evb_makefile.inc b/tensorflow/lite/micro/tools/make/targets/apollo3evb_makefile.inc index 68792496ec3..4d5e9e542b2 100644 --- a/tensorflow/lite/micro/tools/make/targets/apollo3evb_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/apollo3evb_makefile.inc @@ -1,143 +1,136 @@ -# Settings for apollo3 evb and SparkFun Edge platforms.
-ifeq ($(TARGET),$(filter $(TARGET),\ - apollo3evb\ - sparkfun_edge\ - )) - export PATH := $(MAKEFILE_DIR)/downloads/gcc_embedded/bin/:$(PATH) - TARGET_ARCH := cortex-m4 - TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- - TARGET_TOOLCHAIN_ROOT := $(TENSORFLOW_ROOT)$(MAKEFILE_DIR)/downloads/gcc_embedded/bin/ - # Download the Ambiq Apollo3 SDK and set this variable to find the header - # files: - APOLLO3_SDK := $(MAKEFILE_DIR)/downloads/$(AM_SDK_DEST) - # Need a pointer to the GNU ARM toolchain for crtbegin.o for the fp functions - # with the hard interfaces. - GCC_ARM := $(MAKEFILE_DIR)/downloads/gcc_embedded/ +export PATH := $(MAKEFILE_DIR)/downloads/gcc_embedded/bin/:$(PATH) +TARGET_ARCH := cortex-m4 +TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- +TARGET_TOOLCHAIN_ROOT := $(TENSORFLOW_ROOT)$(MAKEFILE_DIR)/downloads/gcc_embedded/bin/ +# Download the Ambiq Apollo3 SDK and set this variable to find the header +# files: +APOLLO3_SDK := $(MAKEFILE_DIR)/downloads/$(AM_SDK_DEST) +# Need a pointer to the GNU ARM toolchain for crtbegin.o for the fp functions +# with the hard interfaces. +GCC_ARM := $(MAKEFILE_DIR)/downloads/gcc_embedded/ - $(eval $(call add_third_party_download,$(GCC_EMBEDDED_URL),$(GCC_EMBEDDED_MD5),gcc_embedded,)) - $(eval $(call add_third_party_download,$(CMSIS_URL),$(CMSIS_MD5),cmsis,patch_cmsis)) - $(eval $(call add_third_party_download,$(AM_SDK_URL),$(AM_SDK_MD5),$(AM_SDK_DEST),patch_am_sdk)) +$(eval $(call add_third_party_download,$(GCC_EMBEDDED_URL),$(GCC_EMBEDDED_MD5),gcc_embedded,)) +$(eval $(call add_third_party_download,$(CMSIS_URL),$(CMSIS_MD5),cmsis,patch_cmsis)) +$(eval $(call add_third_party_download,$(AM_SDK_URL),$(AM_SDK_MD5),$(AM_SDK_DEST),patch_am_sdk)) - ifeq ($(findstring sparkfun,$(TARGET)), sparkfun) - $(eval $(call add_third_party_download,$(SF_BSPS_URL),$(SF_BSPS_MD5),$(AM_SDK_DEST)/$(SF_BSPS_DEST),)) - # Make sure that we download the full Ambiq SDK before the SparkFun BSPs. +ifeq ($(findstring sparkfun,$(TARGET)), sparkfun) + $(eval $(call add_third_party_download,$(SF_BSPS_URL),$(SF_BSPS_MD5),$(AM_SDK_DEST)/$(SF_BSPS_DEST),)) + # Make sure that we download the full Ambiq SDK before the SparkFun BSPs. 
$(MAKEFILE_DIR)/downloads/$(AM_SDK_DEST)/$(SF_BSPS_DEST): $(MAKEFILE_DIR)/downloads/$(AM_SDK_DEST) - endif - - PLATFORM_FLAGS = \ - -DPART_apollo3 \ - -DAM_PACKAGE_BGA \ - -DAM_PART_APOLLO3 \ - -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ - -DTF_LITE_STATIC_MEMORY \ - -DNDEBUG \ - -DTF_LITE_MCU_DEBUG_LOG \ - -D __FPU_PRESENT=1 \ - -DARM_MATH_CM4 \ - -fno-rtti \ - -fmessage-length=0 \ - -fno-exceptions \ - -fno-unwind-tables \ - -ffunction-sections \ - -fdata-sections \ - -funsigned-char \ - -MMD \ - -mcpu=cortex-m4 \ - -mthumb \ - -mfpu=fpv4-sp-d16 \ - -mfloat-abi=hard \ - -std=gnu++11 \ - -Wvla \ - -Wall \ - -Wextra \ - -Wno-missing-field-initializers \ - -Wno-strict-aliasing \ - -Wno-type-limits \ - -Wno-unused-function \ - -Wno-unused-parameter \ - -fno-delete-null-pointer-checks \ - -fno-threadsafe-statics \ - -fomit-frame-pointer \ - -fno-use-cxa-atexit \ - -nostdlib \ - -ggdb \ - -O3 - CXXFLAGS += $(PLATFORM_FLAGS) - CCFLAGS += $(PLATFORM_FLAGS) - LDFLAGS += \ - -mthumb -mcpu=cortex-m4 -mfpu=fpv4-sp-d16 -mfloat-abi=hard \ - -nostartfiles -static \ - -Wl,--gc-sections -Wl,--entry,Reset_Handler \ - -Wl,--start-group -lm -lc -lgcc -Wl,--end-group \ - -fno-exceptions \ - -nostdlib --specs=nano.specs -t -lstdc++ -lc -lnosys -lm \ - -Wl,-T,$(TENSORFLOW_ROOT)$(APOLLO3_SDK)/boards/apollo3_evb/examples/hello_world/gcc_patched/apollo3evb.ld \ - -Wl,-Map=$(TENSORFLOW_ROOT)$(MAKEFILE_DIR)/gen/$(TARGET).map,--cref - BUILD_TYPE := micro - ifeq ($(TARGET), apollo3evb) - BOARD_BSP_PATH := $(APOLLO3_SDK)/boards/apollo3_evb/bsp - endif - ifeq ($(findstring sparkfun,$(TARGET)), sparkfun) - BOARD_BSP_PATH := $(APOLLO3_SDK)/$(SF_BSPS_DEST)/$(subst sparkfun_,,$(TARGET))/bsp - INCLUDES+= \ - -I$(APOLLO3_SDK)/$(SF_BSPS_DEST)/common/third_party/hm01b0 - endif - MICROLITE_LIBS := \ - $(BOARD_BSP_PATH)/gcc/bin/libam_bsp.a \ - $(APOLLO3_SDK)/mcu/apollo3/hal/gcc/bin/libam_hal.a \ - $(GCC_ARM)/lib/gcc/arm-none-eabi/7.3.1/thumb/v7e-m/fpv4-sp/hard/crtbegin.o \ - -lm - INCLUDES += \ - -isystem$(MAKEFILE_DIR)/downloads/cmsis/CMSIS/Core/Include/ \ - -isystem$(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Include/ \ - -I$(GCC_ARM)/arm-none-eabi/ \ - -I$(APOLLO3_SDK)/mcu/apollo3/ \ - -I$(APOLLO3_SDK)/mcu/apollo3/regs \ - -I$(APOLLO3_SDK)/mcu/apollo3/hal \ - -I$(APOLLO3_SDK)/CMSIS/AmbiqMicro/Include/ \ - -I$(BOARD_BSP_PATH) \ - -I$(APOLLO3_SDK)/devices/ \ - -I$(APOLLO3_SDK)/utils/ \ - - - # The startup_gcc.c file is an altered version of the examples/hello_world/gcc/startup_gcc.c - # file from Ambiq: - # - Increase the stack size from 1k to 20k - # - Change the application entry call from main() to _main() - # The am_*.c files should be copied from the Ambiq Apollo3 SDK - # _main.c contains application and target specific initialization, like - # setting clock speed, default uart setups, etc. and an implementation - # of the DebugLog interfaces. 
- MICROLITE_CC_SRCS += \ - $(APOLLO3_SDK)/boards/apollo3_evb/examples/hello_world/gcc_patched/startup_gcc.c \ - $(APOLLO3_SDK)/utils/am_util_delay.c \ - $(APOLLO3_SDK)/utils/am_util_faultisr.c \ - $(APOLLO3_SDK)/utils/am_util_id.c \ - $(APOLLO3_SDK)/utils/am_util_stdio.c \ - $(APOLLO3_SDK)/devices/am_devices_led.c - - CMSIS_SRC_DIR := $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source - THIRD_PARTY_CC_SRCS := \ - $(CMSIS_SRC_DIR)/BasicMathFunctions/arm_dot_prod_q15.c \ - $(CMSIS_SRC_DIR)/BasicMathFunctions/arm_mult_q15.c \ - $(CMSIS_SRC_DIR)/TransformFunctions/arm_rfft_init_q15.c \ - $(CMSIS_SRC_DIR)/TransformFunctions/arm_rfft_q15.c \ - $(CMSIS_SRC_DIR)/TransformFunctions/arm_bitreversal2.c \ - $(CMSIS_SRC_DIR)/TransformFunctions/arm_cfft_q15.c \ - $(CMSIS_SRC_DIR)/TransformFunctions/arm_cfft_radix4_q15.c \ - $(CMSIS_SRC_DIR)/CommonTables/arm_const_structs.c \ - $(CMSIS_SRC_DIR)/CommonTables/arm_common_tables.c \ - $(CMSIS_SRC_DIR)/StatisticsFunctions/arm_mean_q15.c \ - $(CMSIS_SRC_DIR)/StatisticsFunctions/arm_max_q7.c - - MICRO_SPEECH_TEST_SRCS += \ - $(AP3_MICRO_DIR)/_main.c - - TEST_SCRIPT := tensorflow/lite/micro/testing/test_apollo3evb_binary.sh - # These are tests that don't currently work on the Apollo3 board. - EXCLUDED_TESTS := \ - tensorflow/lite/micro/micro_interpreter_test.cc \ - tensorflow/lite/micro/simple_tensor_allocator_test.cc - MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS)) - endif + +PLATFORM_FLAGS = \ + -DPART_apollo3 \ + -DAM_PACKAGE_BGA \ + -DAM_PART_APOLLO3 \ + -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ + -DTF_LITE_STATIC_MEMORY \ + -DNDEBUG \ + -DTF_LITE_MCU_DEBUG_LOG \ + -D __FPU_PRESENT=1 \ + -DARM_MATH_CM4 \ + -fno-rtti \ + -fmessage-length=0 \ + -fno-exceptions \ + -fno-unwind-tables \ + -ffunction-sections \ + -fdata-sections \ + -funsigned-char \ + -MMD \ + -mcpu=cortex-m4 \ + -mthumb \ + -mfpu=fpv4-sp-d16 \ + -mfloat-abi=hard \ + -std=gnu++11 \ + -Wvla \ + -Wall \ + -Wextra \ + -Wno-missing-field-initializers \ + -Wno-strict-aliasing \ + -Wno-type-limits \ + -Wno-unused-function \ + -Wno-unused-parameter \ + -fno-delete-null-pointer-checks \ + -fno-threadsafe-statics \ + -fomit-frame-pointer \ + -fno-use-cxa-atexit \ + -nostdlib \ + -ggdb \ + -O3 +CXXFLAGS += $(PLATFORM_FLAGS) +CCFLAGS += $(PLATFORM_FLAGS) +LDFLAGS += \ + -mthumb -mcpu=cortex-m4 -mfpu=fpv4-sp-d16 -mfloat-abi=hard \ + -nostartfiles -static \ + -Wl,--gc-sections -Wl,--entry,Reset_Handler \ + -Wl,--start-group -lm -lc -lgcc -Wl,--end-group \ + -fno-exceptions \ + -nostdlib --specs=nano.specs -t -lstdc++ -lc -lnosys -lm \ + -Wl,-T,$(TENSORFLOW_ROOT)$(APOLLO3_SDK)/boards/apollo3_evb/examples/hello_world/gcc_patched/apollo3evb.ld \ + -Wl,-Map=$(TENSORFLOW_ROOT)$(MAKEFILE_DIR)/gen/$(TARGET).map,--cref +BUILD_TYPE := micro +ifeq ($(TARGET), apollo3evb) + BOARD_BSP_PATH := $(APOLLO3_SDK)/boards/apollo3_evb/bsp +endif +ifeq ($(findstring sparkfun,$(TARGET)), sparkfun) + BOARD_BSP_PATH := $(APOLLO3_SDK)/$(SF_BSPS_DEST)/$(subst sparkfun_,,$(TARGET))/bsp + INCLUDES+= \ + -I$(APOLLO3_SDK)/$(SF_BSPS_DEST)/common/third_party/hm01b0 +endif +MICROLITE_LIBS := \ + $(BOARD_BSP_PATH)/gcc/bin/libam_bsp.a \ + $(APOLLO3_SDK)/mcu/apollo3/hal/gcc/bin/libam_hal.a \ + $(GCC_ARM)/lib/gcc/arm-none-eabi/7.3.1/thumb/v7e-m/fpv4-sp/hard/crtbegin.o \ + -lm +INCLUDES += \ + -isystem$(MAKEFILE_DIR)/downloads/cmsis/CMSIS/Core/Include/ \ + -isystem$(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Include/ \ + -I$(GCC_ARM)/arm-none-eabi/ \ + -I$(APOLLO3_SDK)/mcu/apollo3/ \ + -I$(APOLLO3_SDK)/mcu/apollo3/regs \ 
+ -I$(APOLLO3_SDK)/mcu/apollo3/hal \ + -I$(APOLLO3_SDK)/CMSIS/AmbiqMicro/Include/ \ + -I$(BOARD_BSP_PATH) \ + -I$(APOLLO3_SDK)/devices/ \ + -I$(APOLLO3_SDK)/utils/ \ + + +# The startup_gcc.c file is an altered version of the examples/hello_world/gcc/startup_gcc.c +# file from Ambiq: +# - Increase the stack size from 1k to 20k +# - Change the application entry call from main() to _main() +# The am_*.c files should be copied from the Ambiq Apollo3 SDK +# _main.c contains application and target specific initialization, like +# setting clock speed, default uart setups, etc. and an implementation +# of the DebugLog interfaces. +MICROLITE_CC_SRCS += \ + $(APOLLO3_SDK)/boards/apollo3_evb/examples/hello_world/gcc_patched/startup_gcc.c \ + $(APOLLO3_SDK)/utils/am_util_delay.c \ + $(APOLLO3_SDK)/utils/am_util_faultisr.c \ + $(APOLLO3_SDK)/utils/am_util_id.c \ + $(APOLLO3_SDK)/utils/am_util_stdio.c \ + $(APOLLO3_SDK)/devices/am_devices_led.c + +CMSIS_SRC_DIR := $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source +THIRD_PARTY_CC_SRCS := \ +$(CMSIS_SRC_DIR)/BasicMathFunctions/arm_dot_prod_q15.c \ +$(CMSIS_SRC_DIR)/BasicMathFunctions/arm_mult_q15.c \ +$(CMSIS_SRC_DIR)/TransformFunctions/arm_rfft_init_q15.c \ +$(CMSIS_SRC_DIR)/TransformFunctions/arm_rfft_q15.c \ +$(CMSIS_SRC_DIR)/TransformFunctions/arm_bitreversal2.c \ +$(CMSIS_SRC_DIR)/TransformFunctions/arm_cfft_q15.c \ +$(CMSIS_SRC_DIR)/TransformFunctions/arm_cfft_radix4_q15.c \ +$(CMSIS_SRC_DIR)/CommonTables/arm_const_structs.c \ +$(CMSIS_SRC_DIR)/CommonTables/arm_common_tables.c \ +$(CMSIS_SRC_DIR)/StatisticsFunctions/arm_mean_q15.c \ +$(CMSIS_SRC_DIR)/StatisticsFunctions/arm_max_q7.c + +MICRO_SPEECH_TEST_SRCS += \ + $(AP3_MICRO_DIR)/_main.c + +TEST_SCRIPT := tensorflow/lite/micro/testing/test_apollo3evb_binary.sh +# These are tests that don't currently work on the Apollo3 board. +EXCLUDED_TESTS := \ + tensorflow/lite/micro/micro_interpreter_test.cc \ + tensorflow/lite/micro/simple_tensor_allocator_test.cc +MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS)) diff --git a/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc b/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc index 06fdddbbeb9..e516554c063 100644 --- a/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc @@ -1,86 +1,60 @@ -# Settings for Blue Pill platforms. 
-ifeq ($(TARGET), bluepill) - export PATH := $(MAKEFILE_DIR)/downloads/gcc_embedded/bin/:$(PATH) - TARGET_ARCH := cortex-m3 - TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- +export PATH := $(MAKEFILE_DIR)/downloads/gcc_embedded/bin/:$(PATH) +TARGET_ARCH := cortex-m3 +TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- - $(eval $(call add_third_party_download,$(GCC_EMBEDDED_URL),$(GCC_EMBEDDED_MD5),gcc_embedded,)) - $(eval $(call add_third_party_download,$(CMSIS_URL),$(CMSIS_MD5),cmsis,patch_cmsis)) - $(eval $(call add_third_party_download,$(STM32_BARE_LIB_URL),$(STM32_BARE_LIB_MD5),stm32_bare_lib,)) +$(eval $(call add_third_party_download,$(GCC_EMBEDDED_URL),$(GCC_EMBEDDED_MD5),gcc_embedded,)) +$(eval $(call add_third_party_download,$(CMSIS_URL),$(CMSIS_MD5),cmsis,patch_cmsis)) +$(eval $(call add_third_party_download,$(STM32_BARE_LIB_URL),$(STM32_BARE_LIB_MD5),stm32_bare_lib,)) - PLATFORM_FLAGS = \ - -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ - -DTF_LITE_MCU_DEBUG_LOG \ - -fmessage-length=0 \ - -fno-exceptions \ - -fno-unwind-tables \ - -ffunction-sections \ - -fdata-sections \ - -funsigned-char \ - -MMD \ - -mcpu=cortex-m3 \ - -mthumb \ - -Wall \ - -Wextra \ - -Wno-vla \ - -Wno-unused-parameter \ - -Wno-strict-aliasing \ - -Wno-shadow \ - -Wno-type-limits \ - -fno-delete-null-pointer-checks \ - -fomit-frame-pointer \ - -nostdlib \ - -g \ - -Os +PLATFORM_FLAGS = \ + -DTF_LITE_MCU_DEBUG_LOG \ + -mcpu=cortex-m3 \ + -mthumb \ + -Wno-vla \ + -Wno-strict-aliasing \ + -Wno-shadow \ + -Wno-type-limits \ + -fomit-frame-pointer \ + -nostdlib - # TODO(b/168334217): Currently we always add -DNDEBUG because the build is - # broken w/o it. Remove this workaround once the issue is resolved. - PLATFORM_FLAGS += -DNDEBUG +# TODO(b/168334217): Currently we always add -DNDEBUG because the build is +# broken w/o it. Remove this workaround once the issue is resolved. +PLATFORM_FLAGS += -DNDEBUG - CXXFLAGS += $(PLATFORM_FLAGS) -fno-rtti -fno-threadsafe-statics \ - -fno-use-cxa-atexit - CCFLAGS += $(PLATFORM_FLAGS) - LDFLAGS += \ - -T $(MAKEFILE_DIR)/targets/bluepill/bluepill.lds \ - -Wl,-Map=$(MAKEFILE_DIR)/gen/$(TARGET).map,--cref \ - -Wl,--gc-sections - BUILD_TYPE := micro - MICROLITE_LIBS := \ - -lm - INCLUDES += \ - -isystem$(MAKEFILE_DIR)/downloads/cmsis/CMSIS/Core/Include/ \ - -I$(MAKEFILE_DIR)/downloads/stm32_bare_lib/include - MICROLITE_CC_SRCS += \ - $(wildcard $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/*.c) \ - $(wildcard $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/*.cc) - EXCLUDED_SRCS := \ - $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/debug_log.c - MICROLITE_CC_SRCS := $(filter-out $(EXCLUDED_SRCS), $(MICROLITE_CC_SRCS)) - TEST_SCRIPT := tensorflow/lite/micro/testing/test_bluepill_binary.sh +CXXFLAGS += $(PLATFORM_FLAGS) -fno-use-cxa-atexit +CCFLAGS += $(PLATFORM_FLAGS) - # TODO(b/143286954): Figure out why some tests fail and enable ince the issues - # are resolved. 
- EXCLUDED_TESTS := \ - tensorflow/lite/micro/micro_interpreter_test.cc \ - tensorflow/lite/micro/micro_allocator_test.cc \ - tensorflow/lite/micro/memory_helpers_test.cc \ - tensorflow/lite/micro/memory_arena_threshold_test.cc \ - tensorflow/lite/micro/kernels/circular_buffer_test.cc - MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS)) +LDFLAGS += \ + -T $(MAKEFILE_DIR)/targets/bluepill/bluepill.lds \ + -Wl,-Map=$(MAKEFILE_DIR)/gen/$(TARGET).map,--cref - EXCLUDED_EXAMPLE_TESTS := \ - tensorflow/lite/micro/examples/magic_wand/Makefile.inc \ - tensorflow/lite/micro/examples/micro_speech/Makefile.inc \ - tensorflow/lite/micro/examples/image_recognition_experimental/Makefile.inc - MICRO_LITE_EXAMPLE_TESTS := $(filter-out $(EXCLUDED_EXAMPLE_TESTS), $(MICRO_LITE_EXAMPLE_TESTS)) +# Additional include paths needed for the stm32_bare_lib only. +INCLUDES += \ + -isystem$(MAKEFILE_DIR)/downloads/cmsis/CMSIS/Core/Include/ \ + -I$(MAKEFILE_DIR)/downloads/stm32_bare_lib/include +MICROLITE_CC_SRCS += \ + $(wildcard $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/*.c) \ + $(wildcard $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/*.cc) +EXCLUDED_SRCS := \ + $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/debug_log.c +MICROLITE_CC_SRCS := $(filter-out $(EXCLUDED_SRCS), $(MICROLITE_CC_SRCS)) -# These are microcontroller-specific rules for converting the ELF output -# of the linker into a binary image that can be loaded directly. -OBJCOPY := $(TARGET_TOOLCHAIN_PREFIX)objcopy +# TODO(b/143286954): Figure out why some tests fail and enable once the issues +# are resolved. +EXCLUDED_TESTS := \ + tensorflow/lite/micro/micro_interpreter_test.cc \ + tensorflow/lite/micro/micro_allocator_test.cc \ + tensorflow/lite/micro/memory_helpers_test.cc \ + tensorflow/lite/micro/memory_arena_threshold_test.cc \ + tensorflow/lite/micro/kernels/circular_buffer_test.cc +MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS)) -$(BINDIR)/%.bin: $(BINDIR)/% - @mkdir -p $(dir $@) - $(OBJCOPY) $< $@ -O binary +EXCLUDED_EXAMPLE_TESTS := \ + tensorflow/lite/micro/examples/magic_wand/Makefile.inc \ + tensorflow/lite/micro/examples/micro_speech/Makefile.inc \ + tensorflow/lite/micro/examples/image_recognition_experimental/Makefile.inc +MICRO_LITE_EXAMPLE_TESTS := $(filter-out $(EXCLUDED_EXAMPLE_TESTS), $(MICRO_LITE_EXAMPLE_TESTS)) + +TEST_SCRIPT := tensorflow/lite/micro/testing/test_bluepill_binary.sh -endif diff --git a/tensorflow/lite/micro/tools/make/targets/cortex_m4_generic_makefile.inc b/tensorflow/lite/micro/tools/make/targets/cortex_m4_generic_makefile.inc deleted file mode 100644 index cc34d684054..00000000000 --- a/tensorflow/lite/micro/tools/make/targets/cortex_m4_generic_makefile.inc +++ /dev/null @@ -1,51 +0,0 @@ -# Generic Makefile target for ARM Cortex M4 builds. -# REQUIRED: -# - TOOLCHAIN_PATH: The path to the ARM GCC toolchain to use.
- -ifeq ($(TARGET), cortex_m4_generic) - TARGET_ARCH := arm - TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- - export PATH := $(TOOLCHAIN_PATH):$(PATH) - - PLATFORM_FLAGS = \ - -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ - -DTF_LITE_STATIC_MEMORY \ - -DNDEBUG \ - -DTF_LITE_MCU_DEBUG_LOG \ - -D __FPU_PRESENT=1 \ - -DARM_MATH_CM4 \ - -fno-rtti \ - -fmessage-length=0 \ - -fno-exceptions \ - -fno-unwind-tables \ - -ffunction-sections \ - -fdata-sections \ - -funsigned-char \ - -MMD \ - -mcpu=cortex-m4 \ - -mthumb \ - -mfpu=fpv4-sp-d16 \ - -mfloat-abi=softfp \ - -std=gnu++11 \ - -Wvla \ - -Wall \ - -Wextra \ - -Wno-shadow \ - -Wno-missing-field-initializers \ - -Wno-strict-aliasing \ - -Wno-type-limits \ - -Wno-unused-function \ - -Wno-unused-parameter \ - -fno-delete-null-pointer-checks \ - -fno-threadsafe-statics \ - -fomit-frame-pointer \ - -fno-use-cxa-atexit \ - -O3 - - CXXFLAGS += $(PLATFORM_FLAGS) - CCFLAGS += $(PLATFORM_FLAGS) - - LDFLAGS += -Wl,--gc-sections - -endif - diff --git a/tensorflow/lite/micro/tools/make/targets/cortex_m_gcc_generic_makefile.inc b/tensorflow/lite/micro/tools/make/targets/cortex_m_gcc_generic_makefile.inc new file mode 100644 index 00000000000..dd7ccca7ba5 --- /dev/null +++ b/tensorflow/lite/micro/tools/make/targets/cortex_m_gcc_generic_makefile.inc @@ -0,0 +1,31 @@ +# Generic Makefile target for ARM Cortex Mx gcc builds. +ifeq ($(TARGET), cortex_m_gcc_generic) + TARGET_ARCH := arm + TARGET_TOOLCHAIN_PREFIX := arm-none-eabi- + export PATH := $(MAKEFILE_DIR)/downloads/gcc_embedded/bin/:$(PATH) + + $(eval $(call add_third_party_download,$(GCC_EMBEDDED_URL),$(GCC_EMBEDDED_MD5),gcc_embedded,)) + + PLATFORM_FLAGS = \ + -DTF_LITE_MCU_DEBUG_LOG \ + -Wno-type-limits \ + -funsigned-char \ + -mcpu=cortex-m4 \ + -mfpu=fpv4-sp-d16 \ + -mthumb \ + -fomit-frame-pointer + +ifeq ($(CORTEX_M_CORE), M4F) + PLATFORM_FLAGS += -mfloat-abi=hard +else ifeq ($(CORTEX_M_CORE), M4) + PLATFORM_FLAGS += -mfloat-abi=softfp +else ifeq ($(CORTEX_M_CORE), ) + $(error CORTEX_M_CORE=[M4|M4F] not defined on the command line) +else + $(error invalid target defined in command line option CORTEX_M_CORE=[M4|M4F]) +endif + + CXXFLAGS += $(PLATFORM_FLAGS) + CCFLAGS += $(PLATFORM_FLAGS) + +endif diff --git a/tensorflow/lite/micro/tools/make/targets/linux_x86_makefile.inc b/tensorflow/lite/micro/tools/make/targets/linux_x86_makefile.inc deleted file mode 100644 index 8ea78e8f3e3..00000000000 --- a/tensorflow/lite/micro/tools/make/targets/linux_x86_makefile.inc +++ /dev/null @@ -1,9 +0,0 @@ -# Settings for x86 on Linux -ifeq ($(TARGET), linux) - ifeq ($(TARGET_ARCH), x86_64) - PLATFORM_FLAGS = \ - -DTF_LITE_DISABLE_X86_NEON - CXXFLAGS += $(PLATFORM_FLAGS) - CCFLAGS += $(PLATFORM_FLAGS) - endif -endif diff --git a/tensorflow/lite/micro/tools/make/targets/osx_makefile.inc b/tensorflow/lite/micro/tools/make/targets/osx_makefile.inc deleted file mode 100644 index 9b1e2220575..00000000000 --- a/tensorflow/lite/micro/tools/make/targets/osx_makefile.inc +++ /dev/null @@ -1,13 +0,0 @@ -# Settings for Mac OS platforms. -ifeq ($(TARGET), osx) - - # Make sure we can find the embedded GCC compiler. 
- export PATH := ${PATH}:tensorflow/lite/micro/tools/make/downloads/gcc_embedded/bin/ - - PLATFORM_FLAGS = \ - -DTF_LITE_DISABLE_X86_NEON - - CXXFLAGS += $(PLATFORM_FLAGS) - CCFLAGS += $(PLATFORM_FLAGS) - -endif \ No newline at end of file diff --git a/tensorflow/lite/micro/tools/make/targets/osx_x86_64_makefile.inc b/tensorflow/lite/micro/tools/make/targets/osx_x86_64_makefile.inc deleted file mode 100644 index 78febaf5ddd..00000000000 --- a/tensorflow/lite/micro/tools/make/targets/osx_x86_64_makefile.inc +++ /dev/null @@ -1,10 +0,0 @@ -# Settings for x86 on Mac -ifeq ($(TARGET), osx) - ifeq ($(TARGET_ARCH), x86_64) - PLATFORM_FLAGS = \ - -DTF_LITE_DISABLE_X86_NEON - CXXFLAGS += $(PLATFORM_FLAGS) - CCFLAGS += $(PLATFORM_FLAGS) - endif -endif - diff --git a/tensorflow/lite/micro/tools/make/targets/sparkfun_edge_makefile.inc b/tensorflow/lite/micro/tools/make/targets/sparkfun_edge_makefile.inc new file mode 100644 index 00000000000..0a4e53202eb --- /dev/null +++ b/tensorflow/lite/micro/tools/make/targets/sparkfun_edge_makefile.inc @@ -0,0 +1,2 @@ +include $(MAKEFILE_DIR)/targets/apollo3evb_makefile.inc + diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_hifimini_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_hifimini_makefile.inc index 664e5c3373c..96bea86d0d7 100644 --- a/tensorflow/lite/micro/tools/make/targets/xtensa_hifimini_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/xtensa_hifimini_makefile.inc @@ -23,28 +23,25 @@ ifeq ($(TARGET), xtensa_hifimini) $(error XTENSA_CORE is undefined) endif - PLATFORM_ARGS = \ + PLATFORM_FLAGS = \ -DTF_LITE_MCU_DEBUG_LOG \ --xtensa-core=$(XTENSA_CORE) \ -mcoproc \ - -DXTENSA -DMAX_RFFT_PWR=9 -DMIN_RFFT_PWR=MAX_RFFT_PWR \ - -fdata-sections \ - -ffunction-sections \ - -fno-exceptions \ - -fno-unwind-tables \ - -fno-use-cxa-atexit \ - -fmessage-length=0 \ - -fno-threadsafe-statics + -DXTENSA \ + -DMAX_RFFT_PWR=9 \ + -DMIN_RFFT_PWR=MAX_RFFT_PWR + export PATH := $(XTENSA_BASE)/tools/$(XTENSA_TOOLS_VERSION)/XtensaTools/bin:$(PATH) TARGET_TOOLCHAIN_PREFIX := xt- CXX_TOOL := clang++ CC_TOOL := clang - CXXFLAGS += $(PLATFORM_ARGS) - CCFLAGS += $(PLATFORM_ARGS) + CXXFLAGS += $(PLATFORM_FLAGS) + CCFLAGS += $(PLATFORM_FLAGS) - LDFLAGS += -Wl,-gc-sections + # TODO(b/150240249): Do not remove -fno-rtti once that works for the Xtensa toolchain. 
+ CXXFLAGS := $(filter-out -fno-rtti, $(CXXFLAGS)) TEST_SCRIPT := tensorflow/lite/micro/testing/test_xtensa_hifimini_binary.sh diff --git a/tensorflow/lite/micro/tools/make/third_party_downloads.inc b/tensorflow/lite/micro/tools/make/third_party_downloads.inc index 3bbeaeaff54..f77cad59fba 100644 --- a/tensorflow/lite/micro/tools/make/third_party_downloads.inc +++ b/tensorflow/lite/micro/tools/make/third_party_downloads.inc @@ -9,11 +9,11 @@ GEMMLOWP_URL := "https://github.com/google/gemmlowp/archive/719139ce755a0f31cbf1 GEMMLOWP_MD5 := "7e8191b24853d75de2af87622ad293ba" ifeq ($(HOST_OS),windows) - FLATBUFFERS_URL := "http://mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.12.0.zip" - FLATBUFFERS_MD5 := "a1afdbf114dec01a861c1b8c917d0fc7" + FLATBUFFERS_URL := "http://mirror.tensorflow.org/github.com/google/flatbuffers/archive/dca12522a9f9e37f126ab925fd385c807ab4f84e.zip" + FLATBUFFERS_MD5 := "aa9adc93eb9b33fa1a2a90969e48baee" else - FLATBUFFERS_URL := "http://mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.12.0.tar.gz" - FLATBUFFERS_MD5 := "c62ffefb3d4548b127cca14ce047f16c" + FLATBUFFERS_URL := "http://mirror.tensorflow.org/github.com/google/flatbuffers/archive/dca12522a9f9e37f126ab925fd385c807ab4f84e.tar.gz" + FLATBUFFERS_MD5 := "dfa0ac3073b78ddacdcacf8ca189be91" endif ifeq ($(HOST_OS),osx) @@ -33,8 +33,8 @@ LEON_BCC2_MD5 := "cdf78082be4882da2a92c9baa82fe765" TSIM_URL := "http://mirror.tensorflow.org/www.gaisler.com/anonftp/tsim/tsim-eval-2.0.63.tar.gz" TSIM_MD5 := "afa0095d3ed989a949e1467f94e41d2f" -CMSIS_URL := "http://mirror.tensorflow.org/github.com/ARM-software/CMSIS_5/archive/9daaa7a34a5627a24009462b8fa8413a00c4fdb1.zip" -CMSIS_MD5 := "b988dacff8925ffffcb7e5079cc713b7" +CMSIS_URL := "http://github.com/ARM-software/CMSIS_5/archive/01f5b32badf7b78c85a24a7149b56400fa6a2999.zip" +CMSIS_MD5 := "823916c6f1749c65fd0bfdeec20b30ed" AM_SDK_URL := "http://mirror.tensorflow.org/s3.asia.ambiqmicro.com/downloads/AmbiqSuite-Rel2.2.0.zip" AM_SDK_MD5 := "7605fa2d4d97e6bb7a1190c92b66b597" diff --git a/tensorflow/lite/model_builder.h b/tensorflow/lite/model_builder.h index e4233998a30..9ffb54ce2b8 100644 --- a/tensorflow/lite/model_builder.h +++ b/tensorflow/lite/model_builder.h @@ -26,23 +26,13 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/core/api/verifier.h" #include "tensorflow/lite/mutable_op_resolver.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/stderr_reporter.h" namespace tflite { -/// Abstract interface that verifies whether a given model is legit. -/// It facilitates the use-case to verify and build a model without loading it -/// twice. -class TfLiteVerifier { - public: - /// Returns true if the model is legit. - virtual bool Verify(const char* data, int length, - ErrorReporter* reporter) = 0; - virtual ~TfLiteVerifier() {} -}; - /// An RAII object that represents a read-only tflite model, copied from disk, /// or mmapped. This uses flatbuffers as the serialization format. /// diff --git a/tensorflow/lite/nnapi/nnapi_implementation.cc b/tensorflow/lite/nnapi/nnapi_implementation.cc index 25b0d8920dd..856cd27602d 100644 --- a/tensorflow/lite/nnapi/nnapi_implementation.cc +++ b/tensorflow/lite/nnapi/nnapi_implementation.cc @@ -146,9 +146,14 @@ const NnApi LoadNnApi() { void* libneuralnetworks = nullptr; // TODO(b/123243014): change RTLD_LOCAL? 
Assumes there can be multiple // instances of nn api RT - libneuralnetworks = dlopen("libneuralnetworks.so", RTLD_LAZY | RTLD_LOCAL); + static const char nnapi_library_name[] = "libneuralnetworks.so"; + libneuralnetworks = dlopen(nnapi_library_name, RTLD_LAZY | RTLD_LOCAL); if (libneuralnetworks == nullptr) { - NNAPI_LOG("nnapi error: unable to open library %s", "libneuralnetworks.so"); + const char* error = dlerror(); + if (error) { + NNAPI_LOG("%s\n", error); + } + NNAPI_LOG("nnapi error: unable to open library %s", nnapi_library_name); } nnapi.nnapi_exists = libneuralnetworks != nullptr; diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index e1d7195955f..5c863886523 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -1,4 +1,5 @@ load("@flatbuffers//:build_defs.bzl", "flatbuffer_py_library") +load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") package( default_visibility = ["//tensorflow:internal"], @@ -17,6 +18,7 @@ py_library( srcs = [ "interpreter.py", ], + compatible_with = get_compatible_with_portable(), srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ @@ -97,7 +99,7 @@ py_test( ], python_version = "PY3", # Increased thread count for reducing timeout failures. - shard_count = 4, + shard_count = 10, srcs_version = "PY2AND3", tags = [ "no_oss", @@ -107,9 +109,26 @@ py_test( "notsan", # b/160824139 ], deps = [ + ":convert", ":tflite_convert", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform", + "//tensorflow/python:random_ops", + "//tensorflow/python:session", + "//tensorflow/python:tf2", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/keras", + "//tensorflow/python/saved_model", + "//tensorflow/python/saved_model:save", + "//tensorflow/python/training:training_util", + "//tensorflow/python/training/tracking", + "//third_party/py/numpy", ], ) @@ -128,6 +147,7 @@ py_library( "//tensorflow/lite/experimental/examples/lstm:tflite_lstm_ops", "//tensorflow/lite/experimental/microfrontend:audio_microfrontend_py", "//tensorflow/lite/experimental/tensorboard:ops_util", + "//tensorflow/lite/python/keras/saving:saving_utils", "//tensorflow/lite/python/optimize:calibrator", "//tensorflow/python:graph_util", "//tensorflow/python/keras", @@ -159,13 +179,11 @@ py_test( name = "lite_v2_test", srcs = ["lite_v2_test.py"], python_version = "PY3", - shard_count = 4, + shard_count = 12, srcs_version = "PY2AND3", tags = [ "no_mac", # TODO(b/148247402): flatbuffers import broken on Mac OS. 
"no_windows", - "noasan", # b/168812578 - "notsan", # b/168812578 ], deps = [ ":lite", diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index 84140a10356..68f49d50498 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -35,6 +35,7 @@ from tensorflow.lite.python import wrap_toco from tensorflow.lite.toco import model_flags_pb2 as _model_flags_pb2 from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2 from tensorflow.lite.toco import types_pb2 as _types_pb2 +from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import resource_loader as _resource_loader from tensorflow.python.util import deprecation @@ -159,6 +160,19 @@ def mlir_sparsify(input_data_str): return wrap_toco.wrapped_experimental_mlir_sparsify(input_data_str) +def register_custom_opdefs(custom_opdefs_list): + """Register the given custom opdefs to the TensorFlow global op registry. + + Args: + custom_opdefs_list: String representing the custom ops OpDefs that are + included in the GraphDef. + + Returns: + True if the registration is successfully completed. + """ + return wrap_toco.wrapped_register_custom_opdefs(custom_opdefs_list) + + def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str, @@ -288,7 +302,7 @@ Alternative, use virtualenv.""") pass -def build_toco_flags(inference_type=lite_constants.FLOAT, +def build_toco_flags(inference_type=dtypes.float32, inference_input_type=None, input_format=lite_constants.TENSORFLOW_GRAPHDEF, output_format=lite_constants.TFLITE, @@ -339,7 +353,7 @@ def build_toco_flags(inference_type=lite_constants.FLOAT, def build_toco_convert_protos(input_tensors, output_tensors, - inference_type=lite_constants.FLOAT, + inference_type=dtypes.float32, inference_input_type=None, input_format=lite_constants.TENSORFLOW_GRAPHDEF, input_shapes=None, diff --git a/tensorflow/lite/python/convert_test.py b/tensorflow/lite/python/convert_test.py index e3654217b3a..3cccce38669 100644 --- a/tensorflow/lite/python/convert_test.py +++ b/tensorflow/lite/python/convert_test.py @@ -20,7 +20,6 @@ from __future__ import print_function import numpy as np from tensorflow.lite.python import convert -from tensorflow.lite.python import lite_constants from tensorflow.lite.python import op_hint from tensorflow.lite.python.interpreter import Interpreter from tensorflow.python.client import session @@ -59,7 +58,7 @@ class ConvertTest(test_util.TensorFlowTestCase): tflite_model = convert.toco_convert( sess.graph_def, [in_tensor], [out_tensor], - inference_type=lite_constants.QUANTIZED_UINT8, + inference_type=dtypes.uint8, quantized_input_stats=[(0., 1.)]) self.assertTrue(tflite_model) @@ -73,7 +72,7 @@ class ConvertTest(test_util.TensorFlowTestCase): tflite_model = convert.toco_convert_graph_def( sess.graph_def, [("input", [1, 16, 16, 3])], ["add"], enable_mlir_converter=False, - inference_type=lite_constants.FLOAT) + inference_type=dtypes.float32) self.assertTrue(tflite_model) # Check values from converted model. 
@@ -111,7 +110,7 @@ class ConvertTest(test_util.TensorFlowTestCase): input_arrays_map, output_arrays, enable_mlir_converter=False, - inference_type=lite_constants.QUANTIZED_UINT8, + inference_type=dtypes.uint8, quantized_input_stats=[(0., 1.), (0., 1.)]) self.assertTrue(tflite_model) @@ -158,7 +157,7 @@ class ConvertTest(test_util.TensorFlowTestCase): input_arrays_map, output_arrays, enable_mlir_converter=False, - inference_type=lite_constants.QUANTIZED_UINT8) + inference_type=dtypes.uint8) self.assertEqual( "std_dev and mean must be defined when inference_type or " "inference_input_type is QUANTIZED_UINT8 or INT8.", diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py index 2e5e5a1baf9..cd5a237b0ef 100644 --- a/tensorflow/lite/python/interpreter.py +++ b/tensorflow/lite/python/interpreter.py @@ -26,7 +26,8 @@ import os import numpy as np # pylint: disable=g-import-not-at-top -if not os.path.splitext(__file__)[0].endswith('tflite_runtime/interpreter'): +if not os.path.splitext(__file__)[0].endswith( + os.path.join('tflite_runtime', 'interpreter')): # This file is part of tensorflow package. from tensorflow.lite.python.interpreter_wrapper import _pywrap_tensorflow_interpreter_wrapper as _interpreter_wrapper from tensorflow.python.util.tf_export import tf_export as _tf_export diff --git a/tensorflow/lite/python/interpreter_wrapper/BUILD b/tensorflow/lite/python/interpreter_wrapper/BUILD index 427f54d0e2c..3055d37fa07 100644 --- a/tensorflow/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/lite/python/interpreter_wrapper/BUILD @@ -1,5 +1,8 @@ load("//tensorflow:tensorflow.bzl", "pybind_extension") +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") + package( default_visibility = ["//visibility:public"], licenses = ["notice"], # Apache 2.0 @@ -9,6 +12,7 @@ cc_library( name = "numpy", srcs = ["numpy.cc"], hdrs = ["numpy.h"], + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/lite:string_util", "//tensorflow/lite/c:common", @@ -23,6 +27,7 @@ cc_library( hdrs = [ "interpreter_wrapper.h", ], + compatible_with = get_compatible_with_portable(), deps = [ ":numpy", ":python_error_reporter", @@ -45,8 +50,9 @@ cc_library( name = "python_error_reporter", srcs = ["python_error_reporter.cc"], hdrs = ["python_error_reporter.h"], + compatible_with = get_compatible_with_portable(), deps = [ - "//tensorflow/lite/core/api", + "//tensorflow/lite:stateful_error_reporter", "//third_party/python_runtime:headers", # buildcleaner: keep ], ) @@ -55,6 +61,7 @@ cc_library( name = "python_utils", srcs = ["python_utils.cc"], hdrs = ["python_utils.h"], + compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/lite:framework", "//tensorflow/lite:string_util", @@ -75,6 +82,7 @@ pybind_extension( "interpreter_wrapper_pybind11.cc", ], hdrs = ["interpreter_wrapper.h"], + compatible_with = get_compatible_with_portable(), link_in_framework = True, module_name = "_pywrap_tensorflow_interpreter_wrapper", deps = [ diff --git a/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h b/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h index 7d4e308834a..b406187e5dd 100644 --- a/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h +++ b/tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h @@ -21,12 +21,12 @@ limitations under the License. 
#include <sstream> #include <string> -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/stateful_error_reporter.h" namespace tflite { namespace interpreter_wrapper { -class PythonErrorReporter : public tflite::ErrorReporter { +class PythonErrorReporter : public tflite::StatefulErrorReporter { public: PythonErrorReporter() {} @@ -38,7 +38,7 @@ class PythonErrorReporter : public tflite::ErrorReporter { PyObject* exception(); // Gets the last error message and clears the buffer. - std::string message(); + std::string message() override; private: std::stringstream buffer_; diff --git a/tensorflow/lite/python/interpreter_wrapper/python_utils.cc b/tensorflow/lite/python/interpreter_wrapper/python_utils.cc index 1b31a0dcb54..886f1ef5335 100644 --- a/tensorflow/lite/python/interpreter_wrapper/python_utils.cc +++ b/tensorflow/lite/python/interpreter_wrapper/python_utils.cc @@ -22,6 +22,11 @@ namespace python_utils { int ConvertFromPyString(PyObject* obj, char** data, Py_ssize_t* length) { #if PY_MAJOR_VERSION >= 3 + if (PyUnicode_Check(obj)) { + // const_cast<> is for CPython 3.7 finally adding const to the API. + *data = const_cast<char*>(PyUnicode_AsUTF8AndSize(obj, length)); + return *data == nullptr ? -1 : 0; + } return PyBytes_AsStringAndSize(obj, data, length); #else return PyString_AsStringAndSize(obj, data, length); diff --git a/tensorflow/lite/python/keras/saving/BUILD b/tensorflow/lite/python/keras/saving/BUILD new file mode 100644 index 00000000000..ff5c679a527 --- /dev/null +++ b/tensorflow/lite/python/keras/saving/BUILD @@ -0,0 +1,15 @@ +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +py_library( + name = "saving_utils", + srcs = [ + "saving_utils.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:util", + ], +) diff --git a/tensorflow/lite/python/keras/saving/saving_utils.py b/tensorflow/lite/python/keras/saving/saving_utils.py new file mode 100644 index 00000000000..03a442d2ee3 --- /dev/null +++ b/tensorflow/lite/python/keras/saving/saving_utils.py @@ -0,0 +1,83 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== + +"""Utility functions for TensorFlow models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +from tensorflow.python.util import nest +from tensorflow.python.util.compat import collections_abc + + +def _enforce_names_consistency(specs): + """Enforces that either all specs have names or none do.""" + + def _has_name(spec): + return hasattr(spec, 'name') and spec.name is not None + + def _clear_name(spec): + spec = copy.deepcopy(spec) + if hasattr(spec, 'name'): + spec._name = None # pylint:disable=protected-access + return spec + + flat_specs = nest.flatten(specs) + name_inconsistency = ( + any(_has_name(s) for s in flat_specs) and + not all(_has_name(s) for s in flat_specs)) + + if name_inconsistency: + specs = nest.map_structure(_clear_name, specs) + return specs + + +def model_input_signature(model, keep_original_batch_size=False): + """Inspect model to get its input signature. + + The model's input signature is a list with a single (possibly-nested) object. + This is due to the Keras-enforced restriction that tensor inputs must be + passed in as the first argument. + + For example, a model with input {'feature1': , 'feature2': } + will have input signature: [{'feature1': TensorSpec, 'feature2': TensorSpec}] + + Args: + model: Keras Model object. + keep_original_batch_size: A boolean indicating whether we want to keep using + the original batch size or set it to None. Default is `False`, which means + that the batch dim of the returned input signature will always be set to + `None`. + + Returns: + A list containing either a single TensorSpec or an object with nested + TensorSpecs. This list does not contain the `training` argument. + """ + input_specs = model._get_save_spec(dynamic_batch=not keep_original_batch_size) # pylint: disable=protected-access + if input_specs is None: + return None + input_specs = _enforce_names_consistency(input_specs) + # Return a list with a single element as the model's input signature. + if isinstance(input_specs, + collections_abc.Sequence) and len(input_specs) == 1: + # Note that the isinstance check filters out single-element dictionaries, + # which should also be wrapped as a single-element list. + return input_specs + else: + return [input_specs] + diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 1d30f50c155..eb1e5509874 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -50,6 +50,7 @@ from tensorflow.lite.python.convert import toco_convert_protos # pylint: disabl from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import from tensorflow.lite.python.interpreter import load_delegate # pylint: disable=unused-import +from tensorflow.lite.python.keras.saving import saving_utils as _keras_saving_utils from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import @@ -162,9 +163,8 @@ class TargetSpec(object): supported_ops: Experimental flag, subject to change. Set of OpsSet options supported by the device. 
(default set([OpsSet.TFLITE_BUILTINS])) supported_types: List of types for constant values on the target device. - Supported values are types exported by lite.constants. Frequently, an - optimization choice is driven by the most compact (i.e. smallest) type in - this list (default [constants.FLOAT]) + Frequently, an optimization choice is driven by the most compact + (i.e. smallest) type in this list (default [tf.float32]) """ def __init__(self, supported_ops=None, supported_types=None): @@ -199,7 +199,7 @@ class QuantizationMode(object): return (self._any_optimization_enabled() and not self._is_int16x8_target_required() and self._representative_dataset is not None and - self._smallest_supported_type() == constants.INT8) + self._smallest_supported_type() == _dtypes.int8) def is_post_training_integer_quantize_8(self): """Post training integer 8 quantization.""" @@ -240,12 +240,12 @@ class QuantizationMode(object): return (self._any_optimization_enabled() and self._representative_dataset is None and not self.contains_training_quant_op() and - self._smallest_supported_type() == constants.INT8) + self._smallest_supported_type() == _dtypes.int8) def post_training_fp16(self): """Post training fp16 quantize.""" return (self._any_optimization_enabled() and - self._smallest_supported_type() == constants.FLOAT16) + self._smallest_supported_type() == _dtypes.float16) def fp32_execution(self): """If none of the above are true.""" @@ -258,36 +258,36 @@ class QuantizationMode(object): self.post_training_fp16()) def activations_type(self): - return constants.INT16 if self._is_int16x8_target_required() \ - else constants.INT8 + return _dtypes.int16 if self._is_int16x8_target_required() \ + else _dtypes.int8 def converter_flags(self, inference_ty=None, inference_input_ty=None): """Flags to the converter.""" if self.is_post_training_integer_quantize(): # The inference_input_type is for the quantizer, then we need to keep the # converter inference_input_type to float. - inference_input_ty = constants.FLOAT + inference_input_ty = _dtypes.float32 if self.training_time_int8_allow_float(): return { "inference_type": inference_ty if inference_ty else \ self.activations_type(), "inference_input_type": - inference_input_ty if inference_input_ty else constants.FLOAT, + inference_input_ty if inference_input_ty else _dtypes.float32, "post_training_quantize": False, # disable dynamic range quantization "quantize_to_float16": False # disable float16 quantization } elif self.post_training_dynamic_range_int8(): return { - "inference_type": constants.FLOAT, - "inference_input_type": constants.FLOAT, + "inference_type": _dtypes.float32, + "inference_input_type": _dtypes.float32, "post_training_quantize": True, # enable dynamic range quantization "quantize_to_float16": False # disable float16 quantization } elif self.post_training_fp16(): return { - "inference_type": constants.FLOAT, - "inference_input_type": constants.FLOAT, + "inference_type": _dtypes.float32, + "inference_input_type": _dtypes.float32, "post_training_quantize": True, "quantize_to_float16": True # enable float16 quantization } @@ -295,7 +295,7 @@ class QuantizationMode(object): # Note this might still trigger (uint8) quantization to be compatible with # TOCO. 
return { - "inference_type": inference_ty if inference_ty else constants.FLOAT, + "inference_type": inference_ty if inference_ty else _dtypes.float32, "inference_input_type": inference_input_ty, "post_training_quantize": False, # enable dynamic range quantization "quantize_to_float16": False # disable float16 quantization @@ -304,8 +304,8 @@ class QuantizationMode(object): def quantizer_flags(self, input_ty=None, output_ty=None): """Default flags to the TFMOT quantizer.""" - inference_input_type = input_ty if input_ty else constants.FLOAT - inference_output_type = output_ty if output_ty else constants.FLOAT + inference_input_type = input_ty if input_ty else _dtypes.float32 + inference_output_type = output_ty if output_ty else _dtypes.float32 if self.post_training_int8_no_float() \ or self.post_training_int16x8_no_float(): @@ -326,9 +326,8 @@ class QuantizationMode(object): else: return False, None - def flags_modify_model_io_type(self, - input_type=constants.FLOAT, - output_type=constants.FLOAT): + def flags_modify_model_io_type( + self, input_type=_dtypes.float32, output_type=_dtypes.float32): """Flags for modifying the input and output type of a tflite model.""" is_post_training_quantize = self.quantizer_flags(input_type, output_type)[0] is_training_time_only_quantize = self.training_time_int8_allow_float() and \ @@ -352,7 +351,7 @@ class QuantizationMode(object): return if self._target_spec.supported_types and (self._smallest_supported_type() != - constants.INT8): + _dtypes.int8): raise ValueError("TFLITE_BUILTINS_INT8 requires smallest supported " "type to be INT8.") @@ -371,7 +370,7 @@ class QuantizationMode(object): def _is_int8_target_required(self): return (set([OpsSet.TFLITE_BUILTINS_INT8]) == set( self._target_spec.supported_ops) or - set(self._target_spec.supported_types) == set([constants.INT8])) + set(self._target_spec.supported_types) == set([_dtypes.int8])) def _is_int16x8_target_required(self): return bool( @@ -396,7 +395,7 @@ class QuantizationMode(object): return min(self._target_spec.supported_types, key=lambda x: x.size) else: # The default smallest supported type is INT8. - return constants.INT8 + return _dtypes.int8 def contains_training_quant_op(self): """Checks if the graph contains any training-time quantization ops.""" @@ -555,18 +554,18 @@ class TFLiteConverterBaseV2(TFLiteConverterBase): def __init__(self): """Constructor for TFLiteConverter.""" super(TFLiteConverterBaseV2, self).__init__() - self.inference_input_type = constants.FLOAT - self.inference_output_type = constants.FLOAT + self.inference_input_type = _dtypes.float32 + self.inference_output_type = _dtypes.float32 def _validate_inference_input_output_types(self, quant_mode): """Validate inference_input_type and inference_output_type flags.""" - default_types = [constants.FLOAT] + default_types = [_dtypes.float32] # We support integer input/output for integer quantized models only. if quant_mode.training_time_int8_allow_float(): if quant_mode.is_post_training_integer_quantize_16x8(): - all_types = default_types + [constants.INT16] + all_types = default_types + [_dtypes.int16] else: - all_types = default_types + [constants.INT8, constants.QUANTIZED_UINT8] + all_types = default_types + [_dtypes.int8, _dtypes.uint8] if self.inference_input_type not in all_types or \ self.inference_output_type not in all_types: all_types_names = ["tf." 
+ t.name for t in all_types] @@ -858,7 +857,8 @@ class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2): if not isinstance(self._keras_model.call, _def_function.Function): # Pass `keep_original_batch_size=True` will ensure that we get an input # signature including the batch dimension specified by the user. - input_signature = _saving_utils.model_input_signature( + # TODO(b/169898786): Use the Keras public API when TFLite moves out of TF + input_signature = _keras_saving_utils.model_input_signature( self._keras_model, keep_original_batch_size=True) func = _saving_utils.trace_model_call(self._keras_model, input_signature) @@ -1146,7 +1146,7 @@ class TFLiteConverterBaseV1(TFLiteConverterBase): graph debug info for a set of nodes from the `graph_def`. """ super(TFLiteConverterBaseV1, self).__init__() - self.inference_type = constants.FLOAT + self.inference_type = _dtypes.float32 self.inference_input_type = None self.inference_output_type = None self.output_format = constants.TFLITE @@ -1193,7 +1193,7 @@ class TFLiteConverterBaseV1(TFLiteConverterBase): def _validate_quantized_input_stats(self, converter_kwargs, calibrate): """Ensure the `quantized_input_stats` flag is provided if required.""" - quantized_types = frozenset({constants.INT8, constants.QUANTIZED_UINT8}) + quantized_types = frozenset({_dtypes.int8, _dtypes.uint8}) requires_quantized_input_stats = ( (converter_kwargs["inference_type"] in quantized_types or @@ -1643,8 +1643,8 @@ class TFLiteConverter(TFLiteFrozenGraphConverter): quantized_input_stats: Dict of strings representing input tensor names mapped to tuple of floats representing the mean and standard deviation of the training data (e.g., {"foo" : (0., 1.)}). Only need if - `inference_input_type` is `QUANTIZED_UINT8`. real_input_value = - (quantized_input_value - mean_value) / std_dev_value. (default {}) + `inference_input_type` is `QUANTIZED_UINT8`. real_input_value = + (quantized_input_value - mean_value) / std_dev_value. (default {}) default_ranges_stats: Tuple of integers representing (min, max) range values for all arrays without a specified range. Intended for experimenting with quantization via "dummy quantization". (default None) diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py index d17fc94cd20..65e7b572a85 100644 --- a/tensorflow/lite/python/lite_test.py +++ b/tensorflow/lite/python/lite_test.py @@ -155,8 +155,8 @@ class FromSessionTest(TestModels, parameterized.TestCase): # Convert model and ensure model is not None. 
converter = lite.TFLiteConverter.from_session(sess, [in_tensor], [out_tensor]) - converter.inference_input_type = lite_constants.QUANTIZED_UINT8 - converter.inference_type = lite_constants.FLOAT + converter.inference_input_type = dtypes.uint8 + converter.inference_type = dtypes.float32 converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev tflite_model = converter.convert() self.assertIsNotNone(tflite_model) @@ -788,8 +788,8 @@ class FromSessionTest(TestModels, parameterized.TestCase): quantized_converter = lite.TFLiteConverter.from_session( sess, [inp], [output]) quantized_converter.experimental_new_converter = enable_mlir_converter - quantized_converter.inference_input_type = lite_constants.INT8 - quantized_converter.inference_output_type = lite_constants.INT8 + quantized_converter.inference_input_type = dtypes.int8 + quantized_converter.inference_output_type = dtypes.int8 quantized_converter.optimizations = [lite.Optimize.DEFAULT] quantized_converter.representative_dataset = calibration_gen quantized_tflite_model = quantized_converter.convert() @@ -832,7 +832,7 @@ class FromSessionTest(TestModels, parameterized.TestCase): quantized_converter.experimental_new_converter = enable_mlir_converter quantized_converter.optimizations = [lite.Optimize.DEFAULT] # Restricting to int8 type only - quantized_converter.target_spec.supported_types = [lite.constants.INT8] + quantized_converter.target_spec.supported_types = [dtypes.int8] # A representative dataset is required for full fixed point quantization. with self.assertRaises(ValueError) as error: quantized_converter.convert() @@ -857,7 +857,7 @@ class FromSessionTest(TestModels, parameterized.TestCase): converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1, in_tensor_2], [out_tensor]) - converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.inference_type = dtypes.uint8 converter.quantized_input_stats = { 'inputA': (0., 1.), 'inputB': (0., 1.) @@ -898,7 +898,7 @@ class FromSessionTest(TestModels, parameterized.TestCase): # Convert model and ensure model is not None. converter = lite.TFLiteConverter.from_session(sess, [in_tensor], [out_tensor]) - converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.inference_type = dtypes.uint8 converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev converter.default_ranges_stats = (0, 6) # min, max tflite_model = converter.convert() @@ -954,16 +954,15 @@ class FromSessionTest(TestModels, parameterized.TestCase): interpreter.allocate_tensors() self.assertEqual(interpreter.get_tensor_details()[idx]['name'], node_name) self.assertEqual(interpreter.get_tensor_details()[idx]['dtype'], - lite.constants.FLOAT) + dtypes.float32) # Convert model to quantized version quantized_converter = lite.TFLiteConverter.from_session( sess, [inp], [output]) quantized_converter.experimental_new_converter = enable_mlir_converter quantized_converter.optimizations = [lite.Optimize.DEFAULT] - quantized_converter.target_spec.supported_types = [lite.constants.FLOAT16] + quantized_converter.target_spec.supported_types = [dtypes.float16] if include_int8: - quantized_converter.target_spec.supported_types.append( - lite.constants.INT8) + quantized_converter.target_spec.supported_types.append(dtypes.int8) if use_rep_data: quantized_converter.representative_dataset = calibration_gen @@ -984,11 +983,11 @@ class FromSessionTest(TestModels, parameterized.TestCase): if is_float16_quantized: # Verify that bias constant is float16 type. 
self.assertEqual(interpreter.get_tensor_details()[idx]['dtype'], - lite.constants.FLOAT16) + dtypes.float16) elif is_post_training_quantized: # Verify that bias constants is int32 type. self.assertEqual(interpreter.get_tensor_details()[idx]['dtype'], - lite.constants.INT32) + dtypes.int32) else: raise ValueError('Invalid test options.') @@ -1005,7 +1004,7 @@ class FromSessionTest(TestModels, parameterized.TestCase): sess, [inp], [output]) quantized_converter.experimental_new_converter = enable_mlir_converter quantized_converter.optimizations = [lite.Optimize.DEFAULT] - quantized_converter.target_spec.supported_types = [lite.constants.FLOAT16] + quantized_converter.target_spec.supported_types = [dtypes.float16] # Specify only int8 builtin ops quantized_converter.target_spec.supported_ops = [ lite.OpsSet.TFLITE_BUILTINS_INT8 @@ -1017,8 +1016,8 @@ class FromSessionTest(TestModels, parameterized.TestCase): str(error.exception)) @parameterized.named_parameters( - ('InferenceType_INT8', lite_constants.INT8), - ('InferenceType_UINT8', lite_constants.QUANTIZED_UINT8)) + ('InferenceType_INT8', dtypes.int8), + ('InferenceType_UINT8', dtypes.uint8)) def testInvalidQuantizeQATModelRequiresInputStats(self, quantized_type): with ops.Graph().as_default(): in_tensor = array_ops.placeholder( @@ -1039,7 +1038,7 @@ class FromSessionTest(TestModels, parameterized.TestCase): 'flag is set to tf.uint8 or tf.int8.', str(error.exception)) with self.assertRaises(ValueError) as error: - quantized_converter.inference_type = lite_constants.FLOAT + quantized_converter.inference_type = dtypes.float32 quantized_converter.inference_input_type = quantized_type quantized_converter.convert() self.assertEqual( @@ -1070,7 +1069,7 @@ class FromSessionTest(TestModels, parameterized.TestCase): converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1, in_tensor_2], [out_tensor]) - converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.inference_type = dtypes.uint8 converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev with self.assertRaises(ValueError) as error: converter.convert() @@ -1091,9 +1090,9 @@ class FromSessionTest(TestModels, parameterized.TestCase): converter = lite.TFLiteConverter.from_session(sess, [inp], [output]) # extra flags to trigger training time quantization conversion - converter.inference_type = lite_constants.INT8 - converter.inference_input_type = lite_constants.FLOAT - converter.inference_output_type = lite_constants.FLOAT + converter.inference_type = dtypes.int8 + converter.inference_input_type = dtypes.float32 + converter.inference_output_type = dtypes.float32 input_arrays = converter.get_input_arrays() converter.quantized_input_stats = { input_arrays[0]: (0., 1.) @@ -1255,7 +1254,7 @@ class FromSessionTest(TestModels, parameterized.TestCase): # Convert model and ensure model is not None. converter = lite.TFLiteConverter.from_session(sess, [in_tensor], [out_tensor]) - converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.inference_type = dtypes.uint8 converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev tflite_model = converter.convert() self.assertIsNotNone(tflite_model) @@ -2334,7 +2333,7 @@ class DefaultConverterAttrsTest(LiteTest): self.assertEqual(converter.output_format, lite_constants.TFLITE) # Assert the default inference type is float. - self.assertEqual(converter.inference_type, lite_constants.FLOAT) + self.assertEqual(converter.inference_type, dtypes.float32) # Assert the default inference type overrides are None. 
self.assertIsNone(converter.inference_input_type) diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py index 9cdfcade2f0..4851912226a 100644 --- a/tensorflow/lite/python/lite_v2_test.py +++ b/tensorflow/lite/python/lite_v2_test.py @@ -32,6 +32,7 @@ from tensorflow.lite.python import lite_v2_test_util from tensorflow.lite.python.convert import mlir_quantize from tensorflow.lite.python.interpreter import Interpreter from tensorflow.lite.toco import types_pb2 as _types_pb2 +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras.layers import recurrent @@ -74,9 +75,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): self.assertEqual(expected_value.numpy(), actual_value) @parameterized.named_parameters( - ('_INT8InputOutput', lite.constants.INT8), - ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8), - ('_INT16InputOutput', lite.constants.INT16)) + ('_INT8InputOutput', dtypes.int8), + ('_UINT8InputOutput', dtypes.uint8), + ('_INT16InputOutput', dtypes.int16)) @test_util.run_v2_only def testInvalidFloat(self, inference_input_output_type): root = self._getSimpleVariableModel() @@ -194,9 +195,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): self.assertLess(len(quantized_tflite_model), len(float_tflite_model)) @parameterized.named_parameters( - ('_INT8InputOutput', lite.constants.INT8), - ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8), - ('_INT16InputOutput', lite.constants.INT16)) + ('_INT8InputOutput', dtypes.int8), + ('_UINT8InputOutput', dtypes.uint8), + ('_INT16InputOutput', dtypes.int16)) @test_util.run_v2_only def testInvalidPostTrainingDynamicRangeQuantization( self, inference_input_output_type): @@ -219,18 +220,18 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): 'must be tf.float32.', str(error.exception)) @parameterized.named_parameters( - ('_Default', False, False, lite.constants.FLOAT), - ('_INT8InputOutput', False, False, lite.constants.INT8), - ('_UINT8InputOutput', False, False, lite.constants.QUANTIZED_UINT8), - ('_INT16Quantize', False, True, lite.constants.FLOAT), - ('_INT16Quantize_INT16InputOutput', False, True, lite.constants.INT16), - ('_IntOnly', True, False, lite.constants.FLOAT), - ('_IntOnly_INT8InputOutput', True, False, lite.constants.INT8), + ('_Default', False, False, dtypes.float32), + ('_INT8InputOutput', False, False, dtypes.int8), + ('_UINT8InputOutput', False, False, dtypes.uint8), + ('_INT16Quantize', False, True, dtypes.float32), + ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16), + ('_IntOnly', True, False, dtypes.float32), + ('_IntOnly_INT8InputOutput', True, False, dtypes.int8), ('_IntOnly_UINT8InputOutput', True, False, - lite.constants.QUANTIZED_UINT8), - ('_IntOnly_INT16Quantize', True, True, lite.constants.FLOAT), + dtypes.uint8), + ('_IntOnly_INT16Quantize', True, True, dtypes.float32), ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, - lite.constants.INT16)) + dtypes.int16)) def testIntegerQuantization(self, is_int_only, is_int16_quantize, inference_input_output_type): func, calibration_gen = self._getIntegerQuantizeModel() @@ -281,7 +282,7 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): self.assertLess(len(quantized_tflite_model), len(tflite_model)) @parameterized.named_parameters( - ('_INT16Quantize_INT8InputOutput', True, lite.constants.INT8)) + ('_INT16Quantize_INT8InputOutput', True, dtypes.int8)) def 
testInvalidIntegerQuantization(self, is_int16_quantize, inference_input_output_type): func, calibration_gen = self._getIntegerQuantizeModel() @@ -297,8 +298,8 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): lite.OpsSet.TFLITE_BUILTINS ] with self.assertRaises(ValueError) as error: - quantized_converter.inference_input_type = lite.constants.INT8 - quantized_converter.inference_output_type = lite.constants.INT8 + quantized_converter.inference_input_type = dtypes.int8 + quantized_converter.inference_output_type = dtypes.int8 quantized_converter.convert() self.assertEqual( "The inference_input_type and inference_output_type " @@ -377,9 +378,9 @@ class FromConcreteFunctionTest(lite_v2_test_util.ModelTest): return tf.keras.Sequential(QLinear(3, input_shape=(2,))) @parameterized.named_parameters( - ('_DefaultFLOAT32InputOutput', lite.constants.FLOAT), - ('_INT8InputOutput', lite.constants.INT8), - ('_UINT8InputOutput', lite.constants.QUANTIZED_UINT8)) + ('_DefaultFLOAT32InputOutput', dtypes.float32), + ('_INT8InputOutput', dtypes.int8), + ('_UINT8InputOutput', dtypes.uint8)) @test_util.run_v2_only def testTrainingTimeQuantization(self, inference_input_output_type): model = self._getTrainingTimeQuantizedModel() @@ -1019,7 +1020,7 @@ class ControlFlowTest(lite_v2_test_util.ModelTest): self.assertAllClose(expected_value, actual_value, atol=1e-05) @test_util.run_v2_only - def testKerasBidirectionalRNN(self): + def testKerasBidirectionalRNNReturnSequence(self): input_data = tf.constant( np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32)) model = tf.keras.models.Sequential() @@ -1028,6 +1029,25 @@ class ControlFlowTest(lite_v2_test_util.ModelTest): tf.keras.layers.Bidirectional( recurrent_v2.LSTM(units=10, return_sequences=True), input_shape=(10, 10))) + model.add(tf.keras.layers.Flatten()) + model.add(tf.keras.layers.Dense(5)) + model.add(tf.keras.layers.Activation('softmax')) + + # Convert model. + converter = lite.TFLiteConverterV2.from_keras_model(model) + tflite_model = converter.convert() + actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0] + + # Check values from converted model. 
+ expected_value = model.predict(input_data) + self.assertAllClose(expected_value, actual_value, atol=1e-05) + + @test_util.run_v2_only + def testKerasBidirectionalRNN(self): + input_data = tf.constant( + np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32)) + model = tf.keras.models.Sequential() + model.add(tf.keras.layers.Input(batch_size=1, shape=(10, 10), name='input')) model.add(tf.keras.layers.Bidirectional(recurrent_v2.LSTM(units=10))) model.add(tf.keras.layers.Dense(5)) model.add(tf.keras.layers.Activation('softmax')) diff --git a/tensorflow/lite/python/optimize/BUILD b/tensorflow/lite/python/optimize/BUILD index 1a0d3db3b73..b921fc45cde 100644 --- a/tensorflow/lite/python/optimize/BUILD +++ b/tensorflow/lite/python/optimize/BUILD @@ -50,6 +50,7 @@ py_library( visibility = ["//visibility:public"], deps = [ ":_pywrap_tensorflow_lite_calibration_wrapper", # buildcleaner: keep + "//tensorflow/python:dtypes", "//tensorflow/python:util", "//third_party/py/numpy", ], @@ -67,8 +68,8 @@ py_test( tags = ["no_oss"], deps = [ ":calibrator", - "//tensorflow/lite/python:lite_constants", "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform", "//third_party/py/numpy", diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.cc b/tensorflow/lite/python/optimize/calibration_wrapper.cc index de3de413c1d..7639e493e83 100644 --- a/tensorflow/lite/python/optimize/calibration_wrapper.cc +++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc @@ -225,7 +225,8 @@ PyObject* CalibrationWrapper::SetTensor(int index, PyObject* value) { for (int j = 0; j < PyArray_NDIM(array); j++) { // Ensure the calibration data input shape is the same as the model input // shape unless the dimension is unknown. - if (tensor->dims_signature->size == tensor->dims->size && + if (tensor->dims_signature != nullptr && + tensor->dims_signature->size == tensor->dims->size && tensor->dims_signature->data[j] == -1) { has_unknown_dims = true; } else if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) { diff --git a/tensorflow/lite/python/optimize/calibrator.py b/tensorflow/lite/python/optimize/calibrator.py index 2b08ec690ff..dfef8b9cb79 100644 --- a/tensorflow/lite/python/optimize/calibrator.py +++ b/tensorflow/lite/python/optimize/calibrator.py @@ -18,8 +18,9 @@ from __future__ import division from __future__ import print_function import numpy as np + +from tensorflow.python.framework import dtypes from tensorflow.python.util.lazy_loader import LazyLoader -from tensorflow.lite.python import lite_constants # Lazy load since some of the performance benchmark skylark rules # break dependencies. Must use double quotes to match code internal rewrite @@ -60,7 +61,7 @@ class Calibrator(object): input_type, output_type, allow_float, - activations_type=lite_constants.INT8, + activations_type=dtypes.int8, resize_input=True): """Calibrates the model with specified generator and then quantizes it. 
diff --git a/tensorflow/lite/python/optimize/calibrator_test.py b/tensorflow/lite/python/optimize/calibrator_test.py index 371b3514ca3..a9ab12c6095 100644 --- a/tensorflow/lite/python/optimize/calibrator_test.py +++ b/tensorflow/lite/python/optimize/calibrator_test.py @@ -23,8 +23,8 @@ from absl.testing import parameterized import numpy as np from six.moves import range -from tensorflow.lite.python import lite_constants as constants from tensorflow.lite.python.optimize import calibrator as _calibrator +from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test @@ -34,9 +34,9 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): @parameterized.named_parameters( # Activation type Int8 - ('UseActivationTypeInt8', constants.INT8), + ('UseActivationTypeInt8', dtypes.int8), # Activation type Int16 - ('UseActivationTypeInt16', constants.INT16)) + ('UseActivationTypeInt16', dtypes.int16)) def test_calibration_with_quantization(self, activations_type): model_path = resource_loader.get_path_to_datafile( 'test_data/mobilenet_like_model.bin') @@ -49,16 +49,17 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)] quantized_model = quantizer.calibrate_and_quantize(input_gen, - constants.FLOAT, - constants.FLOAT, False, + dtypes.float32, + dtypes.float32, + False, activations_type) self.assertIsNotNone(quantized_model) @parameterized.named_parameters( # Activation type Int8 - ('UseActivationTypeInt8', constants.INT8), + ('UseActivationTypeInt8', dtypes.int8), # Activation type Int16 - ('UseActivationTypeInt16', constants.INT16)) + ('UseActivationTypeInt16', dtypes.int16)) def test_calibration_with_quantization_allow_float(self, activations_type): model_path = resource_loader.get_path_to_datafile( 'test_data/mobilenet_like_model.bin') @@ -71,8 +72,9 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)] quantized_model = quantizer.calibrate_and_quantize(input_gen, - constants.FLOAT, - constants.FLOAT, True, + dtypes.float32, + dtypes.float32, + True, activations_type) self.assertIsNotNone(quantized_model) @@ -88,7 +90,7 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): yield [np.ones(shape=(1, 5, 5, 3), dtype=np.float32)] quantized_model = quantizer.calibrate_and_quantize_single( - input_gen, constants.FLOAT, constants.FLOAT, True, 'conv2d_8/BiasAdd') + input_gen, dtypes.float32, dtypes.float32, True, 'conv2d_8/BiasAdd') self.assertIsNotNone(quantized_model) def test_calibration_with_string_input(self): @@ -103,14 +105,14 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): yield [np.array(u'Test' + str(i))] quantized_model = quantizer.calibrate_and_quantize_single( - input_gen, constants.FLOAT, constants.FLOAT, True, 'Identity') + input_gen, dtypes.float32, dtypes.float32, True, 'Identity') self.assertIsNotNone(quantized_model) @parameterized.named_parameters( # Activation type Int8 - ('UseActivationTypeInt8 - EnableMlirQuantizer', constants.INT8), + ('UseActivationTypeInt8 - EnableMlirQuantizer', dtypes.int8), # Activation type Int16 - ('UseActivationTypeInt16 - DisableEnableMlirQuantizer', constants.INT16)) + ('UseActivationTypeInt16 - DisableEnableMlirQuantizer', dtypes.int16)) def test_calibration_with_quantization_multiple_inputs( 
self, activations_type): # Load multi add model from test data. @@ -126,8 +128,9 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): yield [np.ones(shape=(1, 8, 8, 3), dtype=np.float32) for _ in range(4)] quantized_model = quantizer.calibrate_and_quantize(input_gen, - constants.FLOAT, - constants.FLOAT, False, + dtypes.float32, + dtypes.float32, + False, activations_type) self.assertIsNotNone(quantized_model) @@ -148,8 +151,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): yield i with self.assertRaises(RuntimeError): - quantizer.calibrate_and_quantize(empty_input_gen, constants.FLOAT, - constants.FLOAT, False) + quantizer.calibrate_and_quantize(empty_input_gen, dtypes.float32, + dtypes.float32, False) def test_invalid_shape_calibrator_gen(self): model_path = resource_loader.get_path_to_datafile( @@ -163,8 +166,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): yield [np.ones(shape=(1, 2, 2, 3), dtype=np.float32)] with self.assertRaisesRegex(ValueError, 'Size mismatch'): - quantizer.calibrate_and_quantize(input_gen, constants.FLOAT, - constants.FLOAT, False, constants.INT8, + quantizer.calibrate_and_quantize(input_gen, dtypes.float32, + dtypes.float32, False, dtypes.int8, False) def test_invalid_type_calibrator_gen(self): @@ -179,8 +182,8 @@ class CalibratorTest(test_util.TensorFlowTestCase, parameterized.TestCase): yield [np.ones(shape=(1, 5, 5, 3), dtype=np.int32)] with self.assertRaises(ValueError): - quantizer.calibrate_and_quantize(input_gen, constants.FLOAT, - constants.FLOAT, False, constants.INT8) + quantizer.calibrate_and_quantize(input_gen, dtypes.float32, + dtypes.float32, False, dtypes.int8) def test_calibration(self): model_path = resource_loader.get_path_to_datafile( diff --git a/tensorflow/lite/python/tflite_convert.py b/tensorflow/lite/python/tflite_convert.py index b5eb66eedc4..6a1e996135b 100644 --- a/tensorflow/lite/python/tflite_convert.py +++ b/tensorflow/lite/python/tflite_convert.py @@ -28,11 +28,12 @@ import six from six.moves import zip from tensorflow.lite.python import lite -from tensorflow.lite.python import lite_constants +from tensorflow.lite.python.convert import register_custom_opdefs from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2 from tensorflow.lite.toco.logging import gen_html from tensorflow.python import keras from tensorflow.python import tf2 +from tensorflow.python.framework import dtypes from tensorflow.python.platform import app @@ -62,13 +63,14 @@ def _parse_inference_type(value, flag): ValueError: Unsupported value. """ if value == "FLOAT": - return lite_constants.FLOAT - if value == "QUANTIZED_UINT8": - return lite_constants.QUANTIZED_UINT8 + return dtypes.float32 if value == "INT8": - return lite_constants.INT8 - raise ValueError("Unsupported value for --{0}. Only FLOAT and " - "QUANTIZED_UINT8 are supported.".format(flag)) + return dtypes.int8 + if value == "UINT8" or value == "QUANTIZED_UINT8": + return dtypes.uint8 + raise ValueError( + "Unsupported value for `{}` flag. Expected FLOAT, INT8 or UINT8, instead " + "got {}.".format(flag, value)) def _get_tflite_converter(flags): @@ -128,6 +130,10 @@ def _convert_tf1_model(flags): Raises: ValueError: Invalid flags. """ + # Register custom opdefs before converter object creation. + if flags.custom_opdefs: + register_custom_opdefs(_parse_array(flags.custom_opdefs)) + # Create converter. 
converter = _get_tflite_converter(flags) if flags.inference_type: @@ -146,10 +152,10 @@ def _convert_tf1_model(flags): # In quantized inference, mean_value has to be integer so that the real # value 0.0 is exactly representable. - if converter.inference_type == lite_constants.QUANTIZED_UINT8: - mean_values = _parse_array(flags.mean_values, type_fn=int) - else: + if converter.inference_type == dtypes.float32: mean_values = _parse_array(flags.mean_values, type_fn=float) + else: + mean_values = _parse_array(flags.mean_values, type_fn=int) quant_stats = list(zip(mean_values, std_dev_values)) if ((not flags.input_arrays and len(input_arrays) > 1) or (len(input_arrays) != len(quant_stats))): @@ -176,8 +182,7 @@ def _convert_tf1_model(flags): if flags.allow_custom_ops: converter.allow_custom_ops = flags.allow_custom_ops - if flags.custom_opdefs: - converter._custom_opdefs = _parse_array(flags.custom_opdefs) # pylint: disable=protected-access + if flags.target_ops: ops_set_options = lite.OpsSet.get_options() converter.target_spec.supported_ops = set() @@ -189,13 +194,13 @@ def _convert_tf1_model(flags): if flags.post_training_quantize: converter.optimizations = [lite.Optimize.DEFAULT] - if converter.inference_type == lite_constants.QUANTIZED_UINT8: + if converter.inference_type != dtypes.float32: print("--post_training_quantize quantizes a graph of inference_type " - "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.") - converter.inference_type = lite_constants.FLOAT + "FLOAT. Overriding inference_type to FLOAT.") + converter.inference_type = dtypes.float32 if flags.quantize_to_float16: - converter.target_spec.supported_types = [lite.constants.FLOAT16] + converter.target_spec.supported_types = [dtypes.float16] if not flags.post_training_quantize: print("--quantize_to_float16 will only take effect with the " "--post_training_quantize flag enabled.") @@ -354,14 +359,15 @@ def _get_tf1_flags(parser): parser.add_argument( "--inference_type", type=str.upper, - choices=["FLOAT", "QUANTIZED_UINT8", "INT8"], - help="Target data type of real-number arrays in the output file.") + default="FLOAT", + help=("Target data type of real-number arrays in the output file. " + "Must be either FLOAT, INT8 or UINT8.")) parser.add_argument( "--inference_input_type", type=str.upper, - choices=["FLOAT", "QUANTIZED_UINT8", "INT8"], help=("Target data type of real-number input arrays. Allows for a " - "different type for input arrays in the case of quantization.")) + "different type for input arrays in the case of quantization. " + "Must be either FLOAT, INT8 or UINT8.")) # Input and output arrays flags. 
parser.add_argument( diff --git a/tensorflow/lite/python/tflite_convert_test.py b/tensorflow/lite/python/tflite_convert_test.py index 17d466db326..1cd5e61e118 100644 --- a/tensorflow/lite/python/tflite_convert_test.py +++ b/tensorflow/lite/python/tflite_convert_test.py @@ -22,7 +22,9 @@ import os import numpy as np +from tensorflow.core.framework import graph_pb2 from tensorflow.lite.python import tflite_convert +from tensorflow.lite.python.convert import register_custom_opdefs from tensorflow.python import keras from tensorflow.python import tf2 from tensorflow.python.client import session @@ -31,6 +33,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.framework.importer import import_graph_def from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import gfile @@ -165,6 +168,37 @@ class TfLiteConvertV1Test(TestModels): self._run(flags_str, should_succeed=True) os.remove(graph_def_file) + def testQATFrozenGraphDefUInt8(self): + with ops.Graph().as_default(): + in_tensor_1 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') + in_tensor_2 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') + _ = array_ops.fake_quant_with_min_max_args( + in_tensor_1 + in_tensor_2, min=0., max=1., name='output') + sess = session.Session() + + # Write graph to file. + graph_def_file = self._getFilepath('model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + sess.close() + + # Define converter flags + flags_str = ('--std_dev_values=128,128 --mean_values=128,128 ' + '--graph_def_file={0} --input_arrays={1} ' + '--output_arrays={2}'.format( + graph_def_file, 'inputA,inputB', 'output')) + + # Set inference_type UINT8 and (default) inference_input_type UINT8 + flags_str_1 = flags_str + ' --inference_type=UINT8' + self._run(flags_str_1, should_succeed=True) + + # Set inference_type UINT8 and inference_input_type FLOAT + flags_str_2 = flags_str_1 + ' --inference_input_type=FLOAT' + self._run(flags_str_2, should_succeed=True) + + os.remove(graph_def_file) + def testSavedModel(self): saved_model_dir = self._getFilepath('model') with ops.Graph().as_default(): @@ -179,6 +213,73 @@ class TfLiteConvertV1Test(TestModels): flags_str = '--saved_model_dir={}'.format(saved_model_dir) self._run(flags_str, should_succeed=True) + def _createSavedModelWithCustomOp(self): + custom_opdefs_str = ( + 'name: \'CustomAdd\' input_arg: {name: \'Input1\' type: DT_FLOAT} ' + 'input_arg: {name: \'Input2\' type: DT_FLOAT} output_arg: {name: ' + '\'Output\' type: DT_FLOAT}') + + # Create a graph that has one add op. + new_graph = graph_pb2.GraphDef() + with ops.Graph().as_default(): + with session.Session() as sess: + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input') + out_tensor = in_tensor + in_tensor + inputs = {'x': in_tensor} + outputs = {'z': out_tensor} + + new_graph.CopyFrom(sess.graph_def) + + # Rename Add op name to CustomAdd. + for node in new_graph.node: + if node.op.startswith('Add'): + node.op = 'CustomAdd' + del node.attr['T'] + + # Register custom op defs to import modified graph def. + register_custom_opdefs([custom_opdefs_str]) + + # Store saved model. 
+ saved_model_dir = self._getFilepath('model') + with ops.Graph().as_default(): + with session.Session() as sess: + import_graph_def(new_graph, name='') + saved_model.simple_save(sess, saved_model_dir, inputs, outputs) + return (saved_model_dir, custom_opdefs_str) + + def testEnsureCustomOpdefsFlag(self): + saved_model_dir, _ = self._createSavedModelWithCustomOp() + + # Ensure --custom_opdefs. + flags_str = ('--saved_model_dir={0} --allow_custom_ops ' + '--experimental_new_converter'.format(saved_model_dir)) + self._run(flags_str, should_succeed=False) + + def testSavedModelWithCustomOpdefsFlag(self): + saved_model_dir, custom_opdefs_str = self._createSavedModelWithCustomOp() + + # Valid conversion. + flags_str = ( + '--saved_model_dir={0} --custom_opdefs="{1}" --allow_custom_ops ' + '--experimental_new_converter'.format(saved_model_dir, + custom_opdefs_str)) + self._run(flags_str, should_succeed=True) + + def testSavedModelWithInvalidCustomOpdefsFlag(self): + saved_model_dir, _ = self._createSavedModelWithCustomOp() + + invalid_custom_opdefs_str = ( + 'name: \'CustomAdd\' input_arg: {name: \'Input1\' type: DT_FLOAT} ' + 'output_arg: {name: \'Output\' type: DT_FLOAT}') + + # Valid conversion. + flags_str = ( + '--saved_model_dir={0} --custom_opdefs="{1}" --allow_custom_ops ' + '--experimental_new_converter'.format(saved_model_dir, + invalid_custom_opdefs_str)) + self._run(flags_str, should_succeed=False) + def testKerasFile(self): keras_file = self._getKerasModelFile() @@ -269,9 +370,9 @@ class TfLiteConvertV1Test(TestModels): 'attr : { name: \'nms_iou_threshold\' type: \'float\'} ' 'attr : { name: \'nms_score_threshold\' type: \'float\'} ' 'attr : { name: \'num_classes\' type: \'int\'} ' - 'attr : { name: \'w_scale\' type: \'int\'} ' - 'attr : { name: \'x_scale\' type: \'int\'} ' - 'attr : { name: \'y_scale\' type: \'int\'}') + 'attr : { name: \'w_scale\' type: \'float\'} ' + 'attr : { name: \'x_scale\' type: \'float\'} ' + 'attr : { name: \'y_scale\' type: \'float\'}') flags_str = ('--graph_def_file={0} --input_arrays={1} ' '--output_arrays={2} --input_shapes={3} ' diff --git a/tensorflow/lite/python/util.py b/tensorflow/lite/python/util.py index bb66d78a89c..4fc00778b74 100644 --- a/tensorflow/lite/python/util.py +++ b/tensorflow/lite/python/util.py @@ -763,6 +763,7 @@ def _modify_model_output_type(model, inference_output_type=dtypes.float32): if quant_opcode_idx == -1: quant_op = schema_fb.OperatorCodeT() quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE + quant_op.deprecatedBuiltinCode = schema_fb.BuiltinOperator.QUANTIZE model.operatorCodes.append(quant_op) quant_opcode_idx = len(model.operatorCodes) - 1 # Change dequant op (int8 to float) to quant op (int8 to uint8) diff --git a/tensorflow/lite/python/util_test.py b/tensorflow/lite/python/util_test.py index 1950331ce02..528caeab7d3 100644 --- a/tensorflow/lite/python/util_test.py +++ b/tensorflow/lite/python/util_test.py @@ -24,7 +24,6 @@ import numpy as np from six.moves import range import tensorflow as tf -from tensorflow.lite.python import lite_constants from tensorflow.lite.python import util from tensorflow.lite.toco import types_pb2 as _types_pb2 from tensorflow.python.client import session @@ -292,9 +291,9 @@ def _test_param_modify_integer_model_io_type(): # "DuringTraining": False, } map_types = { - "": lite_constants.FLOAT, - "INT8": lite_constants.INT8, - "UINT8": lite_constants.QUANTIZED_UINT8 + "": dtypes.float32, + "INT8": dtypes.int8, + "UINT8": dtypes.uint8, } for k1, v1 in map_model_type.items(): for k2, v2 
in map_types.items(): diff --git a/tensorflow/lite/python/wrap_toco.py b/tensorflow/lite/python/wrap_toco.py index c6176275d81..60b33cea8fd 100644 --- a/tensorflow/lite/python/wrap_toco.py +++ b/tensorflow/lite/python/wrap_toco.py @@ -55,3 +55,8 @@ def wrapped_experimental_mlir_quantize(input_data_str, disable_per_channel, def wrapped_experimental_mlir_sparsify(input_data_str): """Wraps experimental mlir sparsify model.""" return _pywrap_toco_api.ExperimentalMlirSparsifyModel(input_data_str) + + +def wrapped_register_custom_opdefs(custom_opdefs_list): + """Wraps RegisterCustomOpdefs with lazy loader.""" + return _pywrap_toco_api.RegisterCustomOpdefs(custom_opdefs_list) diff --git a/tensorflow/lite/schema/BUILD b/tensorflow/lite/schema/BUILD index 5e3ab511792..13a996cf56e 100644 --- a/tensorflow/lite/schema/BUILD +++ b/tensorflow/lite/schema/BUILD @@ -1,8 +1,21 @@ load("//tensorflow:tensorflow.bzl", "py_test") -load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") +load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite", "tflite_schema_utils_friends") load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") +# This is the package group declaration to which targets for TensorFlow Lite +# Flatbuffer schema utilities. +# +# Its usage should be rare, and is often abused by tools that are doing +# Flatbuffer creation/manipulation in unofficially supported ways. +package_group( + name = "utils_friends", + packages = [ + "//tensorflow/compiler/mlir/lite/...", + "//tensorflow/lite/...", + ] + tflite_schema_utils_friends(), +) + package( default_visibility = [ "//visibility:public", @@ -60,6 +73,7 @@ exports_files([ "schema_v1.fbs", "schema_v2.fbs", "schema_v3.fbs", + "schema_v3a.fbs", ]) flatbuffer_cc_library( @@ -102,7 +116,7 @@ cc_test( srcs = ["flatbuffer_compatibility_test.cc"], data = [ "schema.fbs", - "schema_v3.fbs", + "schema_v3a.fbs", ], tags = [ "no_oss", @@ -116,4 +130,17 @@ cc_test( ], ) +cc_library( + name = "schema_utils", + srcs = ["schema_utils.cc"], + hdrs = ["schema_utils.h"], + compatible_with = get_compatible_with_portable(), + visibility = [":utils_friends"], + deps = [ + ":schema_fbs", + "//tensorflow/lite/kernels/internal:compatibility", + "@flatbuffers", + ], +) + tflite_portable_test_suite() diff --git a/tensorflow/lite/schema/builtin_ops_header/generator.cc b/tensorflow/lite/schema/builtin_ops_header/generator.cc index e2967aee0ff..6a004223bf2 100644 --- a/tensorflow/lite/schema/builtin_ops_header/generator.cc +++ b/tensorflow/lite/schema/builtin_ops_header/generator.cc @@ -46,7 +46,8 @@ extern "C" { #endif // __cplusplus // The enum for builtin operators. -// Note: CUSTOM and DELEGATE are 2 special ops which are not real built-in ops. +// Note: CUSTOM, DELEGATE, and PLACEHOLDER_FOR_GREATER_OP_CODES are 3 special +// ops which are not real built-in ops. typedef enum { )"; diff --git a/tensorflow/lite/schema/flatbuffer_compatibility_test.cc b/tensorflow/lite/schema/flatbuffer_compatibility_test.cc index 8f88b6204a1..85002b2911e 100644 --- a/tensorflow/lite/schema/flatbuffer_compatibility_test.cc +++ b/tensorflow/lite/schema/flatbuffer_compatibility_test.cc @@ -61,8 +61,7 @@ TEST(SchemaTest, TestCompatibility) { // Read file contents of schemas into strings // TODO(aselle): Need a reliable way to load files. 
std::string base_contents, current_contents; - const char *base_filename = - TFLITE_TF_PREFIX "lite/schema/schema_v3.fbs"; + const char *base_filename = TFLITE_TF_PREFIX "lite/schema/schema_v3a.fbs"; const char *current_filename = TFLITE_TF_PREFIX "lite/schema/schema.fbs"; diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index baeb49f7b7a..71312a7b016 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -17,6 +17,8 @@ // Version 1: Add subgraphs to schema. // Version 2: Rename operators to conform to NN API. // Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers. +// Version 3a: Add new builtin op code field. Has backward compatibility with +// version 3. namespace tflite; @@ -215,7 +217,7 @@ table Tensor { // object containing configuration parameters, builtins have a predetermined // set of acceptable options. -enum BuiltinOperator : byte { +enum BuiltinOperator : int32 { ADD = 0, AVERAGE_POOL_2D = 1, CONCATENATION = 2, @@ -248,7 +250,6 @@ enum BuiltinOperator : byte { SPACE_TO_DEPTH = 26, SVDF = 27, TANH = 28, - // TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS CONCAT_EMBEDDINGS = 29, SKIP_GRAM = 30, CALL = 31, @@ -349,7 +350,8 @@ enum BuiltinOperator : byte { SELECT_V2 = 123, DENSIFY = 124, SEGMENT_SUM = 125, - BATCH_MATMUL = 126 + BATCH_MATMUL = 126, + PLACEHOLDER_FOR_GREATER_OP_CODES = 127 } @@ -982,12 +984,21 @@ table BatchMatMulOptions { // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { - builtin_code:BuiltinOperator; + // This field is for backward compatibility. This field will be used when + // the value of the extended builtin_code field has less than + // BulitinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES. + deprecated_builtin_code:byte; custom_code:string; // The version of the operator. The version need to be bumped whenever new // parameters are introduced into an op. version:int = 1; + + // This field is introduced for resolving op builtin code shortage problem + // (the original BuiltinOperator enum field was represented as a byte). + // This field will be used when the value of the extended builtin_code field + // has greater than BulitinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES. + builtin_code:BuiltinOperator; } enum CustomOptionsFormat : byte { @@ -1066,6 +1077,32 @@ table Metadata { buffer:uint; } +// Map from an alias name of tensor to tensor index in the graph. +// This is used in Signature def. +table TensorMap { + // Represents the alias to use for this tensor. + name:string; + + // The actual tensor index in the primary graph, that 'name' corresponds to. + tensor_index:uint; +} + +// This corresponds to SignatureDef in Tensorflow SavedModel. +// The SignatureDef will be part of the SavedModel provided for conversion. +table SignatureDef { + // Named inputs for this signature. + inputs:[TensorMap]; + + // Named outputs for this signature. + outputs:[TensorMap]; + + // Exported method name for this signature. + method_name:string; + + // Key value which was in the Tensorflow SavedModel SignatureDef map. + key:string; +} + table Model { // Version of the schema. version:uint; @@ -1094,6 +1131,9 @@ table Model { // Metadata about the model. metadata:[Metadata]; + + // Optional SignatureDefs for the model. 
+ signature_defs:[SignatureDef]; } root_type Model; diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index c5013edb179..bbc000cc5dc 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -781,11 +781,12 @@ enum BuiltinOperator { BuiltinOperator_DENSIFY = 124, BuiltinOperator_SEGMENT_SUM = 125, BuiltinOperator_BATCH_MATMUL = 126, + BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES = 127, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_BATCH_MATMUL + BuiltinOperator_MAX = BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[127] { +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[128] { static const BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -913,13 +914,14 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[127] { BuiltinOperator_SELECT_V2, BuiltinOperator_DENSIFY, BuiltinOperator_SEGMENT_SUM, - BuiltinOperator_BATCH_MATMUL + BuiltinOperator_BATCH_MATMUL, + BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES }; return values; } inline const char * const *EnumNamesBuiltinOperator() { - static const char * const names[128] = { + static const char * const names[129] = { "ADD", "AVERAGE_POOL_2D", "CONCATENATION", @@ -1047,13 +1049,14 @@ inline const char * const *EnumNamesBuiltinOperator() { "DENSIFY", "SEGMENT_SUM", "BATCH_MATMUL", + "PLACEHOLDER_FOR_GREATER_OP_CODES", nullptr }; return names; } inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { - if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_BATCH_MATMUL)) return ""; + if (flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES)) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOperator()[index]; } @@ -9336,24 +9339,27 @@ flatbuffers::Offset CreateBatchMatMulOptions(flatbuffers::Fl struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; - tflite::BuiltinOperator builtin_code; + int8_t deprecated_builtin_code; std::string custom_code; int32_t version; + tflite::BuiltinOperator builtin_code; OperatorCodeT() - : builtin_code(tflite::BuiltinOperator_ADD), - version(1) { + : deprecated_builtin_code(0), + version(1), + builtin_code(tflite::BuiltinOperator_ADD) { } }; struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef OperatorCodeT NativeTableType; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { - VT_BUILTIN_CODE = 4, + VT_DEPRECATED_BUILTIN_CODE = 4, VT_CUSTOM_CODE = 6, - VT_VERSION = 8 + VT_VERSION = 8, + VT_BUILTIN_CODE = 10 }; - tflite::BuiltinOperator builtin_code() const { - return static_cast(GetField(VT_BUILTIN_CODE, 0)); + int8_t deprecated_builtin_code() const { + return GetField(VT_DEPRECATED_BUILTIN_CODE, 0); } const flatbuffers::String *custom_code() const { return GetPointer(VT_CUSTOM_CODE); @@ -9361,12 +9367,16 @@ struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { int32_t version() const { return GetField(VT_VERSION, 1); } + tflite::BuiltinOperator builtin_code() const { + return static_cast(GetField(VT_BUILTIN_CODE, 0)); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_BUILTIN_CODE) && + VerifyField(verifier, VT_DEPRECATED_BUILTIN_CODE) && VerifyOffset(verifier, VT_CUSTOM_CODE) && 
verifier.VerifyString(custom_code()) && VerifyField(verifier, VT_VERSION) && + VerifyField(verifier, VT_BUILTIN_CODE) && verifier.EndTable(); } OperatorCodeT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -9377,8 +9387,8 @@ struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { struct OperatorCodeBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_builtin_code(tflite::BuiltinOperator builtin_code) { - fbb_.AddElement(OperatorCode::VT_BUILTIN_CODE, static_cast(builtin_code), 0); + void add_deprecated_builtin_code(int8_t deprecated_builtin_code) { + fbb_.AddElement(OperatorCode::VT_DEPRECATED_BUILTIN_CODE, deprecated_builtin_code, 0); } void add_custom_code(flatbuffers::Offset custom_code) { fbb_.AddOffset(OperatorCode::VT_CUSTOM_CODE, custom_code); @@ -9386,6 +9396,9 @@ struct OperatorCodeBuilder { void add_version(int32_t version) { fbb_.AddElement(OperatorCode::VT_VERSION, version, 1); } + void add_builtin_code(tflite::BuiltinOperator builtin_code) { + fbb_.AddElement(OperatorCode::VT_BUILTIN_CODE, static_cast(builtin_code), 0); + } explicit OperatorCodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -9400,27 +9413,31 @@ struct OperatorCodeBuilder { inline flatbuffers::Offset CreateOperatorCode( flatbuffers::FlatBufferBuilder &_fbb, - tflite::BuiltinOperator builtin_code = tflite::BuiltinOperator_ADD, + int8_t deprecated_builtin_code = 0, flatbuffers::Offset custom_code = 0, - int32_t version = 1) { + int32_t version = 1, + tflite::BuiltinOperator builtin_code = tflite::BuiltinOperator_ADD) { OperatorCodeBuilder builder_(_fbb); + builder_.add_builtin_code(builtin_code); builder_.add_version(version); builder_.add_custom_code(custom_code); - builder_.add_builtin_code(builtin_code); + builder_.add_deprecated_builtin_code(deprecated_builtin_code); return builder_.Finish(); } inline flatbuffers::Offset CreateOperatorCodeDirect( flatbuffers::FlatBufferBuilder &_fbb, - tflite::BuiltinOperator builtin_code = tflite::BuiltinOperator_ADD, + int8_t deprecated_builtin_code = 0, const char *custom_code = nullptr, - int32_t version = 1) { + int32_t version = 1, + tflite::BuiltinOperator builtin_code = tflite::BuiltinOperator_ADD) { auto custom_code__ = custom_code ? 
_fbb.CreateString(custom_code) : 0; return tflite::CreateOperatorCode( _fbb, - builtin_code, + deprecated_builtin_code, custom_code__, - version); + version, + builtin_code); } flatbuffers::Offset CreateOperatorCode(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -13695,9 +13712,10 @@ inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_ inline void OperatorCode::UnPackTo(OperatorCodeT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = builtin_code(); _o->builtin_code = _e; } + { auto _e = deprecated_builtin_code(); _o->deprecated_builtin_code = _e; } { auto _e = custom_code(); if (_e) _o->custom_code = _e->str(); } { auto _e = version(); _o->version = _e; } + { auto _e = builtin_code(); _o->builtin_code = _e; } } inline flatbuffers::Offset OperatorCode::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -13708,14 +13726,16 @@ inline flatbuffers::Offset CreateOperatorCode(flatbuffers::FlatBuf (void)_rehasher; (void)_o; struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OperatorCodeT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; - auto _builtin_code = _o->builtin_code; + auto _deprecated_builtin_code = _o->deprecated_builtin_code; auto _custom_code = _o->custom_code.empty() ? 0 : _fbb.CreateString(_o->custom_code); auto _version = _o->version; + auto _builtin_code = _o->builtin_code; return tflite::CreateOperatorCode( _fbb, - _builtin_code, + _deprecated_builtin_code, _custom_code, - _version); + _version, + _builtin_code); } inline OperatorT *Operator::UnPack(const flatbuffers::resolver_function_t *_resolver) const { diff --git a/tensorflow/lite/schema/schema_utils.cc b/tensorflow/lite/schema/schema_utils.cc new file mode 100644 index 00000000000..ea930fc9587 --- /dev/null +++ b/tensorflow/lite/schema/schema_utils.cc @@ -0,0 +1,109 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/schema/schema_utils.h" + +#include + +#include "tensorflow/lite/kernels/internal/compatibility.h" + +namespace tflite { + +// The following GetBuiltinCode methods are the utility methods for reading +// builtin operatore code, ensuring compatibility issues between v3 and v3a +// schema. Always the maximum value of the two fields always will be the correct +// value as follows: +// +// - Supporting schema version v3 models +// +// The `builtin_code` field is not available in the v3 models. Flatbuffer +// library will feed zero value, which is the default value in the v3a schema. +// The actual builtin operatore code value will exist in the +// `deprecated_builtin_code` field. 
At the same time, it implies that +// `deprecated_builtin_code` >= `builtin_code` and the maximum value of the two +// fields will be same with `deprecated_builtin_code'. +// +// - Supporting builtin operator codes beyonds 127 +// +// New builtin operators, whose operator code is larger than 127, can not be +// assigned to the `deprecated_builtin_code` field. In such cases, the +// value of the `builtin_code` field should be used for the builtin operator +// code. In the case, the maximum value of the two fields will be the value of +// the `builtin_code` as the right value. + +BuiltinOperator GetBuiltinCode(const OperatorCode *op_code) { + // Caller should guarantee that the given argument value is not a nullptr. + TFLITE_DCHECK(op_code != nullptr); + + return std::max( + op_code->builtin_code(), + static_cast(op_code->deprecated_builtin_code())); +} + +BuiltinOperator GetBuiltinCode(const OperatorCodeT *op_code) { + // Caller should guarantee that the given argument value is not a nullptr. + TFLITE_DCHECK(op_code != nullptr); + + return std::max(op_code->builtin_code, static_cast( + op_code->deprecated_builtin_code)); +} + +int8_t ConvertBuiltinCodeToDeprecatedBuiltinCode( + const BuiltinOperator builtin_code) { + return (builtin_code < BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES) + ? static_cast(builtin_code) + : static_cast( + BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES); +} + +// The following methods are the following `OperatorCode` table object creation +// methods for backward compatibility. These are manually copied from the +// flatbuffer generated code from schema v3. They serve as overloads for the +// v3a's CreateOperatorCode functions in schema_generated.h and enable code that +// still assumes flatbuffer schema v3 to be unchanged with the inclusion of the +// schema_utils header. +// TODO(b/162392898): remove once all callers are updated to use schema v3a +// functions. + +flatbuffers::Offset CreateOperatorCode( + flatbuffers::FlatBufferBuilder &_fbb, BuiltinOperator builtin_code, + flatbuffers::Offset custom_code, int32_t version) { + OperatorCodeBuilder builder_(_fbb); + builder_.add_version(version); + + int8_t deprecated_builtin_code = + static_cast(BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES); + if (builtin_code < BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES) { + deprecated_builtin_code = static_cast(builtin_code); + } + builder_.add_deprecated_builtin_code(deprecated_builtin_code); + builder_.add_custom_code(custom_code); + builder_.add_builtin_code(builtin_code); + return builder_.Finish(); +} + +flatbuffers::Offset CreateOperatorCodeDirect( + flatbuffers::FlatBufferBuilder &_fbb, BuiltinOperator builtin_code, + const char *custom_code, int32_t version) { + auto custom_code__ = custom_code ? _fbb.CreateString(custom_code) : 0; + int8_t deprecated_builtin_code = + static_cast(BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES); + if (builtin_code < BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES) { + deprecated_builtin_code = static_cast(builtin_code); + } + return CreateOperatorCode(_fbb, deprecated_builtin_code, custom_code__, + version, builtin_code); +} + +} // namespace tflite diff --git a/tensorflow/lite/schema/schema_utils.h b/tensorflow/lite/schema/schema_utils.h new file mode 100644 index 00000000000..315a8d0daf4 --- /dev/null +++ b/tensorflow/lite/schema/schema_utils.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SCHEMA_SCHEMA_UTILS_H_ +#define TENSORFLOW_LITE_SCHEMA_SCHEMA_UTILS_H_ + +#include "flatbuffers/flatbuffers.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { + +// The following methods are introduced to resolve op builtin code shortage +// problem. The new builtin opreator will be assigned to the extended builtin +// code field in the flatbuffer schema. Those methods helps to hide builtin code +// details. +BuiltinOperator GetBuiltinCode(const OperatorCode *op_code); + +BuiltinOperator GetBuiltinCode(const OperatorCodeT *op_code); + +int8_t ConvertBuiltinCodeToDeprecatedBuiltinCode( + const BuiltinOperator builtin_code); + +// The following methods are for backward compatibility for the early version +// three, which does not have an extended builtin code. +flatbuffers::Offset CreateOperatorCode( + flatbuffers::FlatBufferBuilder &_fbb, + BuiltinOperator builtin_code = BuiltinOperator_ADD, + flatbuffers::Offset custom_code = 0, + int32_t version = 1); + +flatbuffers::Offset CreateOperatorCodeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + BuiltinOperator builtin_code = BuiltinOperator_ADD, + const char *custom_code = nullptr, int32_t version = 1); + +} // namespace tflite + +#endif // TENSORFLOW_LITE_SCHEMA_SCHEMA_UTILS_H_ diff --git a/tensorflow/lite/schema/schema_v3a.fbs b/tensorflow/lite/schema/schema_v3a.fbs new file mode 100644 index 00000000000..cae5a63c615 --- /dev/null +++ b/tensorflow/lite/schema/schema_v3a.fbs @@ -0,0 +1,1109 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Revision History +// Version 0: Initial version. +// Version 1: Add subgraphs to schema. +// Version 2: Rename operators to conform to NN API. +// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers. +// Version 3a: Add new builtin op code field. Has backward compatibility with +// version 3. + +namespace tflite; + +// This corresponds to the version. +file_identifier "TFL3"; +// File extension of any written files. +file_extension "tflite"; + +// IMPORTANT: All new members of tables, enums and unions must be added at the +// end to ensure backwards compatibility. + +// The type of data stored in a tensor. 
+enum TensorType : byte { + FLOAT32 = 0, + FLOAT16 = 1, + INT32 = 2, + UINT8 = 3, + INT64 = 4, + STRING = 5, + BOOL = 6, + INT16 = 7, + COMPLEX64 = 8, + INT8 = 9, + FLOAT64 = 10, + COMPLEX128 = 11, +} + +// Custom quantization parameters for experimenting with new quantization +// techniques. +table CustomQuantization { + custom:[ubyte] (force_align: 16); +} + +// Represents a specific quantization technique's parameters. +union QuantizationDetails { + CustomQuantization, +} + +// Parameters for converting a quantized tensor back to float. +table QuantizationParameters { + // These four parameters are the asymmetric linear quantization parameters. + // Given a quantized value q, the corresponding float value f should be: + // f = scale * (q - zero_point) + // For other quantization types, the QuantizationDetails below is used. + min:[float]; // For importing back into tensorflow. + max:[float]; // For importing back into tensorflow. + scale:[float]; // For dequantizing the tensor's values. + zero_point:[long]; + + // If this is not none, the other quantization parameters (i.e. min, max, + // scale, zero_point fields above) are ignored and the value of the + // QuantizationDetails union should be used. + details:QuantizationDetails; + + // Specifies the dimension of the Tensor's shape that the scales and + // zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1] + // with quantization params: + // scale=[1.0, 2.0, 3.0], zero_point=[1, 2, 3], quantization_dimension=1 + // will be quantized across the second dimension of t. + // t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1 + // t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2 + // t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3 + quantized_dimension:int; +} + +// Sparse tensors. +// We use a modification of the TACO format. +// Reference: http://tensor-compiler.org/kjolstad-oopsla17-tensor-compiler.pdf +// +// To encode a conceptual n-dimensional dense tensor with dims (d0, ..., dn-1), +// potentially with a k-dimensional block (0 <= k <= n) with dims +// (dn, ..., dn+k-1), the format needs to specify: +// 1. In what order to traverse these dimensions. For example, to store a 2-D +// matrix in row major order, the traversal order would be (d0, d1), +// whereas to store it in column major order, the traversal order would be +// (d1, d0). If the 2-D matrix has a 2-D inner block, the traversal order +// could be (d0, d1, d2, d3). +// 2. How each block dimension in (dn, ..., dn+k-1) maps to the original +// tensor dimension in (d0, ..., dn-1). +// 3. In the traversal order defined above, the format (dense vs. sparse) and +// index metadata for each dimension. For a dense dimension, this is just +// the size of that dimension. For a sparse dimension, it's the same as +// the compressed index defined in the Compressed Sparse Row (CSR) format. +// (http://scipy-lectures.org/advanced/scipy_sparse/csr_matrix.html) + +// The storage type for a dimension. Currently we support: +// 1. DENSE: each coordinate in this dimension is stored implicitly. +// 2. SPARSE_CSR: only the coordinates with non-zero elements are stored. The +// compression technique is the same what CSR uses. +// More types like a sparse dimension with a different compression technique +// could be added to the list in the future. 
+enum DimensionType : byte { + DENSE = 0, + SPARSE_CSR = 1, +} + +table Int32Vector { + values:[int]; +} + +table Uint16Vector { + values:[ushort] (force_align: 4); +} + +table Uint8Vector { + values:[ubyte] (force_align: 4); +} + +// Variable-typed buffer to store the index metadata for a sparse dimension. +// The widest type is Int32 instead of UInt32 because tensor's shape is a int32 +// vector. We don't want the per-dimensional index to overflow that range. +union SparseIndexVector { + Int32Vector, + Uint16Vector, + Uint8Vector +} + +table DimensionMetadata { + // Whether a dimension is dense or sparse. + format:DimensionType; + // Index metadata used for a dimension. + // - If format is DimensionType.DENSE then we use the dense_size field to + // store the size of that dimension. Each index in that dimension is + // stored implicitly. + // - If format is DimensionType.SPARSE_CSR then we use array_segments and + // array_indices to encode that dimension. array_segments represents how + // to segment the indices array, each segment corresponds to one element + // in the previous dimension. array_indices represents the index of the + // non-zero elements within this dimension (as those in the CSR matrix + // format, where the first array is row pointers and the second array is + // column indices). + dense_size:int; + array_segments:SparseIndexVector; + array_indices:SparseIndexVector; +} + +// Parameters to encode a sparse TfLite tensor. +table SparsityParameters { + // The traversal order of the dimensions defined in the `shape` field of the + // conceptual dense tensor. For a n-dimensional tensors with dims (d0, d1, + // ..., dn-1), + // - if not block sparse, the traversal_order is just a permutation of (d0, + // ..., dn-1). For example, a 2-D matrix stored in row-major order would + // have traversal_order = (d0, d1). + // - if block sparse with a k-dimensional block (0 <= k <= n), the + // traversal_order has n + k elements. The first n elements are still a + // permutation of (d0, ..., dn-1). The lask k elements are a permutation + // of (dn, ..., dn+k-1), defining how to traverse a block internally. For + // example, a 2-D matrix with 2-D blocks, both stored in row-major order + // would have traversal_order = (d0, d1, d2, d3). + traversal_order:[int]; + // For an n-dimensional tensor with a k-dimensional block (0 <= k <= n), + // stores how a block dimension in (dn, ..., dn+k-1) maps to the original + // tensor dimension in (d0, ..., dn). + // It's stored in the order of (dn, ..., dn+k-1). + // If not block-sparse, this field is NULL. + block_map:[int]; + // In the traversal order defined above, the metadata needed for + // each dimension to locate the non-zero values in the original dense tensor. + // The size of the dim_metadata array = the size of the traversal_order array + // = n + k. + dim_metadata:[DimensionMetadata]; +} + +table Tensor { + // The tensor shape. The meaning of each entry is operator-specific but + // builtin ops use: [batch size, height, width, number of channels] (That's + // Tensorflow's NHWC). + shape:[int]; + type:TensorType; + // An index that refers to the buffers table at the root of the model. Or, + // if there is no data buffer associated (i.e. intermediate results), then + // this is 0 (which refers to an always existent empty buffer). + // + // The data_buffer itself is an opaque container, with the assumption that the + // target device is little-endian. 
In addition, all builtin operators assume + // the memory is ordered such that if `shape` is [4, 3, 2], then index + // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k]. + buffer:uint; + name:string; // For debugging and importing back into tensorflow. + quantization:QuantizationParameters; // Optional. + + is_variable:bool = false; + + // Parameters to encode a sparse tensor. See the example in + // tensorflow/lite/testdata/sparse_tensor.json. + sparsity:SparsityParameters; // Optional. + + // Encodes `shape` with unknown dimensions. Unknown dimensions are + // represented with -1. + shape_signature:[int]; // Optional. +} + +// A list of builtin operators. Builtin operators are slightly faster than custom +// ones, but not by much. Moreover, while custom operators accept an opaque +// object containing configuration parameters, builtins have a predetermined +// set of acceptable options. + +enum BuiltinOperator : int32 { + ADD = 0, + AVERAGE_POOL_2D = 1, + CONCATENATION = 2, + CONV_2D = 3, + DEPTHWISE_CONV_2D = 4, + DEPTH_TO_SPACE = 5, + DEQUANTIZE = 6, + EMBEDDING_LOOKUP = 7, + FLOOR = 8, + FULLY_CONNECTED = 9, + HASHTABLE_LOOKUP = 10, + L2_NORMALIZATION = 11, + L2_POOL_2D = 12, + LOCAL_RESPONSE_NORMALIZATION = 13, + LOGISTIC = 14, + LSH_PROJECTION = 15, + LSTM = 16, + MAX_POOL_2D = 17, + MUL = 18, + RELU = 19, + // NOTE(aselle): RELU_N1_TO_1 used to be called RELU1, but it was renamed + // since different model developers use RELU1 in different ways. Never + // create another op called RELU1. + RELU_N1_TO_1 = 20, + RELU6 = 21, + RESHAPE = 22, + RESIZE_BILINEAR = 23, + RNN = 24, + SOFTMAX = 25, + SPACE_TO_DEPTH = 26, + SVDF = 27, + TANH = 28, + CONCAT_EMBEDDINGS = 29, + SKIP_GRAM = 30, + CALL = 31, + CUSTOM = 32, + EMBEDDING_LOOKUP_SPARSE = 33, + PAD = 34, + UNIDIRECTIONAL_SEQUENCE_RNN = 35, + GATHER = 36, + BATCH_TO_SPACE_ND = 37, + SPACE_TO_BATCH_ND = 38, + TRANSPOSE = 39, + MEAN = 40, + SUB = 41, + DIV = 42, + SQUEEZE = 43, + UNIDIRECTIONAL_SEQUENCE_LSTM = 44, + STRIDED_SLICE = 45, + BIDIRECTIONAL_SEQUENCE_RNN = 46, + EXP = 47, + TOPK_V2 = 48, + SPLIT = 49, + LOG_SOFTMAX = 50, + // DELEGATE is a special op type for the operations which are delegated to + // other backends. 
+ // WARNING: Experimental interface, subject to change + DELEGATE = 51, + BIDIRECTIONAL_SEQUENCE_LSTM = 52, + CAST = 53, + PRELU = 54, + MAXIMUM = 55, + ARG_MAX = 56, + MINIMUM = 57, + LESS = 58, + NEG = 59, + PADV2 = 60, + GREATER = 61, + GREATER_EQUAL = 62, + LESS_EQUAL = 63, + SELECT = 64, + SLICE = 65, + SIN = 66, + TRANSPOSE_CONV = 67, + SPARSE_TO_DENSE = 68, + TILE = 69, + EXPAND_DIMS = 70, + EQUAL = 71, + NOT_EQUAL = 72, + LOG = 73, + SUM = 74, + SQRT = 75, + RSQRT = 76, + SHAPE = 77, + POW = 78, + ARG_MIN = 79, + FAKE_QUANT = 80, + REDUCE_PROD = 81, + REDUCE_MAX = 82, + PACK = 83, + LOGICAL_OR = 84, + ONE_HOT = 85, + LOGICAL_AND = 86, + LOGICAL_NOT = 87, + UNPACK = 88, + REDUCE_MIN = 89, + FLOOR_DIV = 90, + REDUCE_ANY = 91, + SQUARE = 92, + ZEROS_LIKE = 93, + FILL = 94, + FLOOR_MOD = 95, + RANGE = 96, + RESIZE_NEAREST_NEIGHBOR = 97, + LEAKY_RELU = 98, + SQUARED_DIFFERENCE = 99, + MIRROR_PAD = 100, + ABS = 101, + SPLIT_V = 102, + UNIQUE = 103, + CEIL = 104, + REVERSE_V2 = 105, + ADD_N = 106, + GATHER_ND = 107, + COS = 108, + WHERE = 109, + RANK = 110, + ELU = 111, + REVERSE_SEQUENCE = 112, + MATRIX_DIAG = 113, + QUANTIZE = 114, + MATRIX_SET_DIAG = 115, + ROUND = 116, + HARD_SWISH = 117, + IF = 118, + WHILE = 119, + NON_MAX_SUPPRESSION_V4 = 120, + NON_MAX_SUPPRESSION_V5 = 121, + SCATTER_ND = 122, + SELECT_V2 = 123, + DENSIFY = 124, + SEGMENT_SUM = 125, + BATCH_MATMUL = 126, + PLACEHOLDER_FOR_GREATER_OP_CODES = 127 +} + + +// Options for the builtin operators. +union BuiltinOptions { + Conv2DOptions, + DepthwiseConv2DOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + Pool2DOptions, + SVDFOptions, + RNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormalizationOptions, + LSTMOptions, + ResizeBilinearOptions, + CallOptions, + ReshapeOptions, + SkipGramOptions, + SpaceToDepthOptions, + EmbeddingLookupSparseOptions, + MulOptions, + PadOptions, + GatherOptions, + BatchToSpaceNDOptions, + SpaceToBatchNDOptions, + TransposeOptions, + ReducerOptions, + SubOptions, + DivOptions, + SqueezeOptions, + SequenceRNNOptions, + StridedSliceOptions, + ExpOptions, + TopKV2Options, + SplitOptions, + LogSoftmaxOptions, + CastOptions, + DequantizeOptions, + MaximumMinimumOptions, + ArgMaxOptions, + LessOptions, + NegOptions, + PadV2Options, + GreaterOptions, + GreaterEqualOptions, + LessEqualOptions, + SelectOptions, + SliceOptions, + TransposeConvOptions, + SparseToDenseOptions, + TileOptions, + ExpandDimsOptions, + EqualOptions, + NotEqualOptions, + ShapeOptions, + PowOptions, + ArgMinOptions, + FakeQuantOptions, + PackOptions, + LogicalOrOptions, + OneHotOptions, + LogicalAndOptions, + LogicalNotOptions, + UnpackOptions, + FloorDivOptions, + SquareOptions, + ZerosLikeOptions, + FillOptions, + BidirectionalSequenceLSTMOptions, + BidirectionalSequenceRNNOptions, + UnidirectionalSequenceLSTMOptions, + FloorModOptions, + RangeOptions, + ResizeNearestNeighborOptions, + LeakyReluOptions, + SquaredDifferenceOptions, + MirrorPadOptions, + AbsOptions, + SplitVOptions, + UniqueOptions, + ReverseV2Options, + AddNOptions, + GatherNdOptions, + CosOptions, + WhereOptions, + RankOptions, + ReverseSequenceOptions, + MatrixDiagOptions, + QuantizeOptions, + MatrixSetDiagOptions, + HardSwishOptions, + IfOptions, + WhileOptions, + DepthToSpaceOptions, + NonMaxSuppressionV4Options, + NonMaxSuppressionV5Options, + ScatterNdOptions, + SelectV2Options, + DensifyOptions, + SegmentSumOptions, + BatchMatMulOptions +} + +enum Padding : byte { 
SAME, VALID } + +enum ActivationFunctionType : byte { + NONE = 0, + RELU = 1, + RELU_N1_TO_1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} + +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +table Pool2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConv2DOptions { + // Parameters for DepthwiseConv version 1 or above. + padding:Padding; + stride_w:int; + stride_h:int; + // `depth_multiplier` is redundant. It's used by CPU kernels in + // TensorFlow 2.0 or below, but ignored in versions above. + // See comments in lite/c/builtin_op_data.h for more details. + depth_multiplier:int; + fused_activation_function:ActivationFunctionType; + // Parameters for DepthwiseConv version 2 or above. + dilation_w_factor:int = 1; + dilation_h_factor:int = 1; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; + // For weights-only quantization, use asymmetric quantization for non + // constant inputs at evaluation time. + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow RNNCell. +table RNNOptions { + fused_activation_function:ActivationFunctionType; + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow dynamic_rnn with RNNCell. +table SequenceRNNOptions { + time_major:bool; + fused_activation_function:ActivationFunctionType; + asymmetric_quantize_inputs:bool; +} + +// An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell. +table BidirectionalSequenceRNNOptions { + time_major:bool; + fused_activation_function:ActivationFunctionType; + merge_outputs: bool; + asymmetric_quantize_inputs:bool; +} + +enum FullyConnectedOptionsWeightsFormat: byte { + DEFAULT = 0, + SHUFFLED4x16INT8 = 1, +} + +// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. +table FullyConnectedOptions { + // Parameters for FullyConnected version 1 or above. + fused_activation_function:ActivationFunctionType; + + // Parameters for FullyConnected version 2 or above. + weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT; + + // Parameters for FullyConnected version 5 or above. + // If set to true, then the number of dimension is preserved. Furthermore, + // all but the last dimension of the input and output shapes will be equal. + keep_num_dims: bool; + + // Parameters for FullyConnected version 7 or above. + // If set to true, then weights-only op will use asymmetric quantization for + // inputs. + asymmetric_quantize_inputs: bool; +} + +table SoftmaxOptions { + beta: float; +} + +// An implementation of TensorFlow concat. +table ConcatenationOptions { + axis:int; + fused_activation_function:ActivationFunctionType; +} + +table AddOptions { + fused_activation_function:ActivationFunctionType; + // Parameters supported by version 4. 
+ pot_scale_int16:bool = true; +} + +table MulOptions { + fused_activation_function:ActivationFunctionType; +} + +table L2NormOptions { + fused_activation_function:ActivationFunctionType; +} + +table LocalResponseNormalizationOptions { + radius:int; + bias:float; + alpha:float; + beta:float; +} + +enum LSTMKernelType : byte { + // Full LSTM kernel which supports peephole and projection. + FULL = 0, + // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell. + BASIC = 1, +} + +// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell +table LSTMOptions { + // Parameters for LSTM version 1 or above. + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // Parameters for LSTM version 2 or above. + // Basic kernel is only supported in version 2 or above. + kernel_type: LSTMKernelType = FULL; + + // Parameters for LSTM version 4 or above. + asymmetric_quantize_inputs: bool; +} + +// An implementation of TensorFlow dynamic_rnn with LSTMCell. +table UnidirectionalSequenceLSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true then first dimension is sequence, otherwise batch. + time_major:bool; + + // Parameter for Unidirectional Sequence LSTM version 4. + asymmetric_quantize_inputs:bool; +} + +table BidirectionalSequenceLSTMOptions { + // Parameters supported by version 1: + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true, store the outputs of both directions into the first output. + merge_outputs: bool; + + // Parameters supported by version 2: + // If true then first dimension is sequence, otherwise batch. + // Version 1 implementations assumed time_major to be true, so this default + // value should never change. + time_major: bool = true; + + // Parameters for version 3 or above. + asymmetric_quantize_inputs:bool; +} + +table ResizeBilinearOptions { + new_height: int (deprecated); + new_width: int (deprecated); + align_corners: bool; + half_pixel_centers: bool; +} + +table ResizeNearestNeighborOptions { + align_corners: bool; + half_pixel_centers: bool; +} + +// A call operation options +table CallOptions { + // The subgraph index that needs to be called. 
+ subgraph:uint; +} + +table PadOptions { +} + +table PadV2Options { +} + +table ReshapeOptions { + new_shape:[int]; +} + +table SpaceToBatchNDOptions { +} + +table BatchToSpaceNDOptions { +} + +table SkipGramOptions { + ngram_size: int; + max_skip_size: int; + include_all_ngrams: bool; +} + +table SpaceToDepthOptions { + block_size: int; +} + +table DepthToSpaceOptions { + block_size: int; +} + +table SubOptions { + fused_activation_function:ActivationFunctionType; + // Parameters supported by version 5 + pot_scale_int16:bool = true; +} + +table DivOptions { + fused_activation_function:ActivationFunctionType; +} + +table TopKV2Options { +} + +enum CombinerType : byte { + SUM = 0, + MEAN = 1, + SQRTN = 2, +} + +table EmbeddingLookupSparseOptions { + combiner:CombinerType; +} + +table GatherOptions { + axis: int; +} + +table TransposeOptions { +} + +table ExpOptions { +} + +table CosOptions { +} + +table ReducerOptions { + keep_dims: bool; +} + +table SqueezeOptions { + squeeze_dims:[int]; +} + +table SplitOptions { + num_splits: int; +} + +table SplitVOptions { + num_splits: int; +} + +table StridedSliceOptions { + begin_mask: int; + end_mask: int; + ellipsis_mask: int; + new_axis_mask: int; + shrink_axis_mask: int; +} + +table LogSoftmaxOptions { +} + +table CastOptions { + in_data_type: TensorType; + out_data_type: TensorType; +} + +table DequantizeOptions { +} + +table MaximumMinimumOptions { +} + +table TileOptions { +} + +table ArgMaxOptions { + output_type : TensorType; +} + +table ArgMinOptions { + output_type : TensorType; +} + +table GreaterOptions { +} + +table GreaterEqualOptions { +} + +table LessOptions { +} + +table LessEqualOptions { +} + +table NegOptions { +} + +table SelectOptions { +} + +table SliceOptions { +} + +table TransposeConvOptions { + padding:Padding; + stride_w:int; + stride_h:int; +} + +table ExpandDimsOptions { +} + +table SparseToDenseOptions { + validate_indices:bool; +} + +table EqualOptions { +} + +table NotEqualOptions { +} + +table ShapeOptions { + // Optional output type of the operation (int32 or int64). Defaults to int32. + out_type : TensorType; +} + +table RankOptions { +} + +table PowOptions { +} + +table FakeQuantOptions { + // Parameters supported by version 1: + min:float; + max:float; + num_bits:int; + + // Parameters supported by version 2: + narrow_range:bool; +} + +table PackOptions { + values_count:int; + axis:int; +} + +table LogicalOrOptions { +} + +table OneHotOptions { + axis:int; +} + +table AbsOptions { +} + + +table HardSwishOptions { +} + +table LogicalAndOptions { +} + +table LogicalNotOptions { +} + +table UnpackOptions { + num:int; + axis:int; +} + +table FloorDivOptions { +} + +table SquareOptions { +} + +table ZerosLikeOptions { +} + +table FillOptions { +} + +table FloorModOptions { +} + +table RangeOptions { +} + +table LeakyReluOptions { + alpha:float; +} + +table SquaredDifferenceOptions { +} + +enum MirrorPadMode : byte { + // Doesn't include borders. + REFLECT = 0, + // Includes borders. 
+ SYMMETRIC = 1, +} + +table MirrorPadOptions { + mode:MirrorPadMode; +} + +table UniqueOptions { + idx_out_type:TensorType = INT32; +} + +table ReverseV2Options { +} + +table AddNOptions { +} + +table GatherNdOptions { +} + +table WhereOptions { +} + +table ReverseSequenceOptions { + seq_dim:int; + batch_dim:int = 0; +} + +table MatrixDiagOptions { +} + +table QuantizeOptions { +} + +table MatrixSetDiagOptions { +} + +table IfOptions { + then_subgraph_index:int; + else_subgraph_index:int; +} + +table WhileOptions { + cond_subgraph_index:int; + body_subgraph_index:int; +} + +table NonMaxSuppressionV4Options { +} + +table NonMaxSuppressionV5Options { +} + +table ScatterNdOptions { +} + +table SelectV2Options { +} + +table DensifyOptions { +} + +table SegmentSumOptions { +} + +table BatchMatMulOptions { + adj_x:bool; + adj_y:bool; +} + +// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a +// builtin, or a string if the operator is custom. +table OperatorCode { + // This field is for backward compatibility. This field will be used when + // the value of the extended builtin_code field has less than + // BulitinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES. + deprecated_builtin_code:byte; + custom_code:string; + + // The version of the operator. The version need to be bumped whenever new + // parameters are introduced into an op. + version:int = 1; + + // This field is introduced for resolving op builtin code shortage problem. + // This field will be used when the value of the extended builtin_code field + // is greater than BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES. + builtin_code:BuiltinOperator; +} + +enum CustomOptionsFormat : byte { + FLEXBUFFERS = 0, +} + +// An operator takes tensors as inputs and outputs. The type of operation being +// performed is determined by an index into the list of valid OperatorCodes, +// while the specifics of each operations is configured using builtin_options +// or custom_options. +table Operator { + // Index into the operator_codes array. Using an integer here avoids + // complicate map lookups. + opcode_index:uint; + + // Optional input are indicated by -1. + inputs:[int]; + outputs:[int]; + + builtin_options:BuiltinOptions; + custom_options:[ubyte]; + custom_options_format:CustomOptionsFormat; + + // A list of booleans indicating the input tensors which are being mutated by + // this operator.(e.g. used by RNN and LSTM). + // For example, if the "inputs" array refers to 5 tensors and the second and + // fifth are mutable variables, then this list will contain + // [false, true, false, false, true]. + // + // If the list is empty, no variable is mutated in this operator. + // The list either has the same length as `inputs`, or is empty. + mutating_variable_inputs:[bool]; + + // A list of indices to the subgraph's "tensors" that are internal to an Op. + // Internal tensors are those that do not flow in or out of the operation, + // but instead are part of internal computation. As such, the operation's + // implementation may manage its memory more efficiently. They are needed + // however (i.e. not just an implementation detail) since they are part of the + // computation, which may require relevant metadata such as quantization + // parameters. + intermediates:[int]; +} + +// The root type, defining a subgraph, which typically represents an entire +// model. +table SubGraph { + // A list of all tensors used in this subgraph. + tensors:[Tensor]; + + // Indices of the tensors that are inputs into this subgraph. 
Note this is + // the list of non-static tensors that feed into the subgraph for inference. + inputs:[int]; + + // Indices of the tensors that are outputs out of this subgraph. Note this is + // the list of output tensors that are considered the product of the + // subgraph's inference. + outputs:[int]; + + // All operators, in execution order. + operators:[Operator]; + + // Name of this subgraph (used for debugging). + name:string; +} + +// Table of raw data buffers (used for constant tensors). Referenced by tensors +// by index. The generous alignment accommodates mmap-friendly data structures. +table Buffer { + data:[ubyte] (force_align: 16); +} + +table Metadata { + // A human readable string to uniquely identify a Metadata. + name:string; + // An index to the buffers table. + buffer:uint; +} + +table Model { + // Version of the schema. + version:uint; + + // A list of all operator codes used in this model. This is + // kept in order because operators carry an index into this + // vector. + operator_codes:[OperatorCode]; + + // All the subgraphs of the model. The 0th is assumed to be the main + // model. + subgraphs:[SubGraph]; + + // A description of the model. + description:string; + + // Buffers of the model. + // Note the 0th entry of this array must be an empty buffer (sentinel). + // This is a convention so that tensors without a buffer can provide 0 as + // their buffer. + buffers:[Buffer]; + + // Metadata about the model. Indirects into the existings buffers list. + // Deprecated, prefer to use metadata field. + metadata_buffer:[int]; + + // Metadata about the model. + metadata:[Metadata]; +} + +root_type Model; diff --git a/tensorflow/lite/special_rules.bzl b/tensorflow/lite/special_rules.bzl index cc5fd15e5d5..68143c976f4 100644 --- a/tensorflow/lite/special_rules.bzl +++ b/tensorflow/lite/special_rules.bzl @@ -68,3 +68,12 @@ def tflite_hexagon_nn_skel_libraries(): return ["//third_party/hexagon_nn_skel:libhexagon_nn_skel"] """ return [] + +def tflite_schema_utils_friends(): + """This is a no-op outside of Google. + + Return the package group declaration to which targets for Flatbuffer schema utilities.""" + + # Its usage should be rare, and is often abused by tools that are doing + # Flatbuffer creation/manipulation in unofficially supported ways." + return ["//..."] diff --git a/tensorflow/lite/stateful_error_reporter.h b/tensorflow/lite/stateful_error_reporter.h new file mode 100644 index 00000000000..cf6693431f9 --- /dev/null +++ b/tensorflow/lite/stateful_error_reporter.h @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_ +#define TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_ + +#include + +#include "tensorflow/lite/core/api/error_reporter.h" + +namespace tflite { + +// Similar to tflite::ErrorReporter, except that it allows callers to get the +// last error message. 
+class StatefulErrorReporter : public ErrorReporter { + public: + // Returns last error message. Returns empty string if no error is reported. + virtual std::string message() = 0; +}; + +} // namespace tflite + +#endif // TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_ diff --git a/tensorflow/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py index d9cd6883a8d..7825dfb560a 100644 --- a/tensorflow/lite/testing/model_coverage/model_coverage_lib.py +++ b/tensorflow/lite/testing/model_coverage/model_coverage_lib.py @@ -28,11 +28,11 @@ from google.protobuf.message import DecodeError from tensorflow.core.framework import graph_pb2 as _graph_pb2 from tensorflow.lite.python import convert_saved_model as _convert_saved_model from tensorflow.lite.python import lite as _lite -from tensorflow.lite.python import lite_constants as constants from tensorflow.lite.python import util as _util from tensorflow.python import keras as _keras from tensorflow.python.client import session as _session from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework.importer import import_graph_def as _import_graph_def from tensorflow.python.keras.preprocessing import image @@ -97,7 +97,7 @@ def _convert(converter, **kwargs): if "post_training_quantize" in kwargs: converter.optimizations = [_lite.Optimize.DEFAULT] if kwargs.get("quantize_to_float16", False): - converter.target_spec.supported_types = [constants.FLOAT16] + converter.target_spec.supported_types = [dtypes.float16] if kwargs.get("post_training_quantize_16x8", False): input_size = kwargs.get("model_input_size") diff --git a/tensorflow/lite/testing/selective_build_test.cc b/tensorflow/lite/testing/selective_build_test.cc index 1a9a5b2efdb..c3a0cf20ecc 100644 --- a/tensorflow/lite/testing/selective_build_test.cc +++ b/tensorflow/lite/testing/selective_build_test.cc @@ -68,7 +68,7 @@ TEST(SelectiveBuiltTest, AddModel) { EXPECT_THAT(RunWithRandomInputs(model), true); } -TEST(SelectiveBuiltTest, LGTMModel) { +TEST(SelectiveBuiltTest, LSTMModel) { std::string model = "third_party/tensorflow/lite/testdata/lstm.bin"; EXPECT_THAT(RunWithRandomInputs(model), true); } diff --git a/tensorflow/lite/toco/python/BUILD b/tensorflow/lite/toco/python/BUILD index ac7d94a37bd..b71e2f7728f 100644 --- a/tensorflow/lite/toco/python/BUILD +++ b/tensorflow/lite/toco/python/BUILD @@ -37,6 +37,9 @@ cc_library( deps = [ "@com_google_protobuf//:protobuf_headers", "//third_party/python_runtime:headers", # build_cleaner: keep; DNR: b/35864863 + "//tensorflow/c:kernels", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", diff --git a/tensorflow/lite/toco/python/toco_python_api.cc b/tensorflow/lite/toco/python/toco_python_api.cc index eea3fb005bd..edcc1f805b4 100644 --- a/tensorflow/lite/toco/python/toco_python_api.cc +++ b/tensorflow/lite/toco/python/toco_python_api.cc @@ -20,10 +20,13 @@ limitations under the License. 
#include #include "google/protobuf/text_format.h" +#include "tensorflow/c/kernels.h" #include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h" #include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" @@ -317,4 +320,69 @@ PyObject* MlirSparsifyModel(PyObject* data) { builder.GetSize()); } +PyObject* RegisterCustomOpdefs(PyObject* list) { + if (!PyList_Check(list)) { + PyErr_SetString(PyExc_TypeError, "Expected list in argument"); + return nullptr; + } + + int64 size = PyList_Size(list); + for (int i = 0; i < size; ++i) { + // Get character array from Python object. + char* tf_opdefs; + Py_ssize_t len; + if (tflite::python_utils::ConvertFromPyString(PyList_GetItem(list, i), + &tf_opdefs, &len) == -1) { + PyErr_Format(PyExc_ValueError, + "Failed to convert Python string at index %d of custom op " + "defs argument", + i); + return nullptr; + } + + // Parse op def from character array. + tensorflow::OpDef opdef; + if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs, &opdef)) { + PyErr_Format( + PyExc_ValueError, + "Failed to parse opdefs at index %d of custom op defs argument: %s", + i, tf_opdefs); + return nullptr; + } + + // Register extra opdefs to TensorFlow global op registry. + tensorflow::OpRegistry::Global()->Register( + [opdef]( + tensorflow::OpRegistrationData* op_reg_data) -> tensorflow::Status { + *op_reg_data = tensorflow::OpRegistrationData(opdef); + return tensorflow::Status::OK(); + }); + + // Register the corresponding fake op kernel. + const char* node_name = opdef.name().c_str(); + const char* op_name = opdef.name().c_str(); + const char* device_name = "CPU"; + static auto fake_compute_func = [](void* kernel, TF_OpKernelContext* ctx) { + }; + + TF_KernelBuilder* builder = + TF_NewKernelBuilder(op_name, device_name, /*create_func=*/nullptr, + fake_compute_func, /*delete_func=*/nullptr); + + TF_Status* status = TF_NewStatus(); + TF_RegisterKernelBuilder(node_name, builder, status); + if (TF_GetCode(status) != TF_OK) { + TF_DeleteStatus(status); + PyErr_Format(PyExc_ValueError, + "Failed to register fake op kernel at index %d of custom op " + "defs argument", + i); + return nullptr; + } + TF_DeleteStatus(status); + } + + Py_RETURN_TRUE; +} + } // namespace toco diff --git a/tensorflow/lite/toco/python/toco_python_api.h b/tensorflow/lite/toco/python/toco_python_api.h index 058ae9fb942..df9d6e11bcf 100644 --- a/tensorflow/lite/toco/python/toco_python_api.h +++ b/tensorflow/lite/toco/python/toco_python_api.h @@ -49,6 +49,9 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, // Sparsifies model to encode sparse tensors with proper format. Throws error if // sparsification fails. PyObject* MlirSparsifyModel(PyObject* data); + +// Registers the given custom opdefs to TensorFlow global op registry. 
+PyObject* RegisterCustomOpdefs(PyObject* list); } // namespace toco #endif // TENSORFLOW_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ diff --git a/tensorflow/lite/toco/tflite/BUILD b/tensorflow/lite/toco/tflite/BUILD index 18b531fd5f1..91bb77ff391 100644 --- a/tensorflow/lite/toco/tflite/BUILD +++ b/tensorflow/lite/toco/tflite/BUILD @@ -95,6 +95,7 @@ cc_library( ":types", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "//tensorflow/lite/toco:model", "//tensorflow/lite/toco:tooling_util", "//tensorflow/lite/tools/optimize:quantize_weights", @@ -113,6 +114,7 @@ tf_cc_test( ":operator", "//tensorflow/core:ops", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_googletest//:gtest_main", "@flatbuffers", ], @@ -132,6 +134,7 @@ cc_library( ":types", "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "//tensorflow/lite/toco:model", "//tensorflow/lite/toco:tooling_util", "//tensorflow/lite/tools:verifier", @@ -172,6 +175,7 @@ tf_cc_test( "//tensorflow/core:ops", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_googletest//:gtest_main", "@flatbuffers", ], diff --git a/tensorflow/lite/toco/tflite/export.cc b/tensorflow/lite/toco/tflite/export.cc index d109ab875b5..3ef1c67c721 100644 --- a/tensorflow/lite/toco/tflite/export.cc +++ b/tensorflow/lite/toco/tflite/export.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/lite/context.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/toco/tflite/op_version.h" #include "tensorflow/lite/toco/tflite/operator.h" #include "tensorflow/lite/toco/tflite/types.h" diff --git a/tensorflow/lite/toco/tflite/export_test.cc b/tensorflow/lite/toco/tflite/export_test.cc index ced55921e50..b5a55cfa090 100644 --- a/tensorflow/lite/toco/tflite/export_test.cc +++ b/tensorflow/lite/toco/tflite/export_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/toco/tflite/builtin_operator.h" #include "tensorflow/lite/toco/tflite/operator.h" #include "tensorflow/lite/toco/tflite/types.h" @@ -171,10 +172,10 @@ class ExportTest : public ::testing::Test { auto* model = ::tflite::GetModel(result.data()); for (const ::tflite::OperatorCode* opcode : *model->operator_codes()) { - if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) { - names.push_back( - std::string("builtin:") + - ::tflite::EnumNameBuiltinOperator(opcode->builtin_code())); + auto builtin_code = GetBuiltinCode(opcode); + if (builtin_code != ::tflite::BuiltinOperator_CUSTOM) { + names.push_back(std::string("builtin:") + + ::tflite::EnumNameBuiltinOperator(builtin_code)); } else { names.push_back(std::string("custom:") + opcode->custom_code()->c_str()); diff --git a/tensorflow/lite/toco/tflite/import.cc b/tensorflow/lite/toco/tflite/import.cc index 4253ab93160..1ea6907b73c 100644 --- a/tensorflow/lite/toco/tflite/import.cc +++ b/tensorflow/lite/toco/tflite/import.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "flatbuffers/flexbuffers.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/toco/tflite/operator.h" #include "tensorflow/lite/toco/tflite/types.h" #include "tensorflow/lite/toco/tooling_util.h" @@ -42,9 +43,9 @@ void LoadOperatorsTable(const ::tflite::Model& input_model, auto opcodes = input_model.operator_codes(); if (!opcodes) return; for (const auto* opcode : *opcodes) { - if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) { - operators_table->push_back( - EnumNameBuiltinOperator(opcode->builtin_code())); + auto builtin_code = GetBuiltinCode(opcode); + if (builtin_code != ::tflite::BuiltinOperator_CUSTOM) { + operators_table->push_back(EnumNameBuiltinOperator(builtin_code)); } else { operators_table->push_back(opcode->custom_code()->c_str()); } diff --git a/tensorflow/lite/toco/tflite/import_test.cc b/tensorflow/lite/toco/tflite/import_test.cc index 6163ebab45b..fe7dd31a40a 100644 --- a/tensorflow/lite/toco/tflite/import_test.cc +++ b/tensorflow/lite/toco/tflite/import_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/version.h" namespace toco { diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index a94aca941c9..22fa1ff1cea 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -160,6 +160,7 @@ cc_library( deps = [ "//tensorflow/lite:framework", "//tensorflow/lite:string", + "//tensorflow/lite/schema:schema_utils", "@com_googlesource_code_re2//:re2", ], ) @@ -194,7 +195,10 @@ cc_library( "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite:string_util", "//tensorflow/lite/c:common", + "//tensorflow/lite/core/api:error_reporter", + "//tensorflow/lite/core/api:op_resolver", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -217,6 +221,7 @@ cc_test( "//tensorflow/lite:util", "//tensorflow/lite/core/api", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest", "@flatbuffers", @@ -278,6 +283,7 @@ cc_library( "//tensorflow/core:tensorflow", "//tensorflow/lite:framework", "//tensorflow/lite:util", + "//tensorflow/lite/schema:schema_utils", "@flatbuffers", "@jsoncpp_git//:jsoncpp", ], diff --git a/tensorflow/lite/tools/benchmark/android/BUILD b/tensorflow/lite/tools/benchmark/android/BUILD index 6645b730bac..be9c379af50 100644 --- a/tensorflow/lite/tools/benchmark/android/BUILD +++ b/tensorflow/lite/tools/benchmark/android/BUILD @@ -10,8 +10,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - # See README.md for details about building and executing this benchmark. 
android_binary( name = "benchmark_model", diff --git a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h index 31405dfb998..8917c254825 100644 --- a/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h +++ b/tensorflow/lite/tools/benchmark/experimental/c/c_api_types.h @@ -421,7 +421,7 @@ typedef struct TfLiteCustomAllocation { size_t bytes; } TfLiteCustomAllocation; -// An tensor in the interpreter system which is a wrapper around a buffer of +// A tensor in the interpreter system which is a wrapper around a buffer of // data including a dimensionality (or NULL if not currently defined). #ifndef TF_LITE_STATIC_MEMORY typedef struct TfLiteTensor { diff --git a/tensorflow/lite/tools/cmake/README.md b/tensorflow/lite/tools/cmake/README.md index 7624b6623c2..23ddd2d8093 100644 --- a/tensorflow/lite/tools/cmake/README.md +++ b/tensorflow/lite/tools/cmake/README.md @@ -40,11 +40,39 @@ cd tflite_build cmake ../tensorflow_src/tensorflow/lite ``` +If you want to configure Android build with GPU delegate support, + +```sh +mkdir tflite_build +cd tflite_build +cmake -DCMAKE_TOOLCHAIN_FILE=/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a -DTFLITE_ENABLE_GPU=ON ../tensorflow_src/tensorflow/lite +``` + + #### Step 4. Build TensorFlow Lite +In the tflite_build directory, + ```sh cmake --build . -j ``` +Or + +```sh +make -j +``` + + **Note:** This should compile a static library `libtensorflow-lite.a` in the current directory. + + +#### Step 5. Build TensorFlow Lite Benchmark Tool + +In the tflite_build directory, + +```sh +make benchmark_model -j +``` diff --git a/tensorflow/lite/tools/cmake/modules/Findopencl_headers.cmake b/tensorflow/lite/tools/cmake/modules/Findopencl_headers.cmake new file mode 100644 index 00000000000..6384aa0d1da --- /dev/null +++ b/tensorflow/lite/tools/cmake/modules/Findopencl_headers.cmake @@ -0,0 +1,16 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +include(opencl_headers) diff --git a/tensorflow/lite/tools/cmake/modules/Findvulkan_headers.cmake b/tensorflow/lite/tools/cmake/modules/Findvulkan_headers.cmake new file mode 100644 index 00000000000..8dbb0e6d6ae --- /dev/null +++ b/tensorflow/lite/tools/cmake/modules/Findvulkan_headers.cmake @@ -0,0 +1,16 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +include(vulkan_headers) diff --git a/tensorflow/lite/tools/cmake/modules/opencl_headers.cmake b/tensorflow/lite/tools/cmake/modules/opencl_headers.cmake new file mode 100644 index 00000000000..c54e526f5ed --- /dev/null +++ b/tensorflow/lite/tools/cmake/modules/opencl_headers.cmake @@ -0,0 +1,40 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if(TARGET opencl_headers OR opencl_headers_POPULATED) + return() +endif() + +include(FetchContent) + +OverridableFetchContent_Declare( + opencl_headers + GIT_REPOSITORY https://github.com/KhronosGroup/OpenCL-Headers + # GIT_TAG must keep in sync with tensorflow/third_party/opencl_headers/workspace.bzl + GIT_TAG 0d5f18c6e7196863bc1557a693f1509adfcee056 + GIT_PROGRESS TRUE + PREFIX "${CMAKE_BINARY_DIR}" + SOURCE_DIR "${CMAKE_BINARY_DIR}/opencl_headers" +) + +OverridableFetchContent_GetProperties(opencl_headers) +if(NOT opencl_headers) + OverridableFetchContent_Populate(opencl_headers) +endif() + +include_directories( + AFTER + "${opencl_headers_SOURCE_DIR}/" +) diff --git a/tensorflow/lite/tools/cmake/modules/vulkan_headers.cmake b/tensorflow/lite/tools/cmake/modules/vulkan_headers.cmake new file mode 100644 index 00000000000..4b8fc34104b --- /dev/null +++ b/tensorflow/lite/tools/cmake/modules/vulkan_headers.cmake @@ -0,0 +1,40 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +if(TARGET vulkan_headers OR vulkan_headers_POPULATED) + return() +endif() + +include(FetchContent) + +OverridableFetchContent_Declare( + vulkan_headers + GIT_REPOSITORY https://github.com/KhronosGroup/Vulkan-Headers + # GIT_TAG must keep in sync with tensorflow/third_party/vulkan_headers/workspace.bzl + GIT_TAG 0e57fc1cfa56a203efe43e4dfb9b3c9e9b105593 + GIT_PROGRESS TRUE + PREFIX "${CMAKE_BINARY_DIR}" + SOURCE_DIR "${CMAKE_BINARY_DIR}/vulkan_headers" +) + +OverridableFetchContent_GetProperties(vulkan_headers) +if(NOT vulkan_headers) + OverridableFetchContent_Populate(vulkan_headers) +endif() + +include_directories( + AFTER + "${vulkan_headers_SOURCE_DIR}/include" +) diff --git a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake index ea6306aea5a..e5f205c0b0b 100644 --- a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake +++ b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake @@ -50,4 +50,5 @@ include_directories( "${PTHREADPOOL_SOURCE_DIR}/include" "${FP16_SOURCE_DIR}/include" "${XNNPACK_SOURCE_DIR}/include" + "${CPUINFO_SOURCE_DIR}/" ) diff --git a/tensorflow/lite/tools/evaluation/stages/BUILD b/tensorflow/lite/tools/evaluation/stages/BUILD index 17cb0f4ef9f..4363aff01a9 100644 --- a/tensorflow/lite/tools/evaluation/stages/BUILD +++ b/tensorflow/lite/tools/evaluation/stages/BUILD @@ -51,7 +51,7 @@ cc_library( "//tensorflow/lite/tools/evaluation/proto:preprocessing_steps_cc_proto", ] + select({ "//tensorflow:android": [ - "//tensorflow/core:android_jpeg_internal", + "//tensorflow/core:portable_jpeg_internal", ], "//conditions:default": [ "//tensorflow/core:jpeg_internal", diff --git a/tensorflow/lite/tools/gen_op_registration.cc b/tensorflow/lite/tools/gen_op_registration.cc index 2a95df799e8..76f1d4e4046 100644 --- a/tensorflow/lite/tools/gen_op_registration.cc +++ b/tensorflow/lite/tools/gen_op_registration.cc @@ -19,6 +19,7 @@ limitations under the License. #include "re2/re2.h" #include "tensorflow/lite/model.h" +#include "tensorflow/lite/schema/schema_utils.h" namespace tflite { @@ -37,10 +38,11 @@ void ReadOpsFromModel(const ::tflite::Model* model, if (!opcodes) return; for (const auto* opcode : *opcodes) { const int version = opcode->version(); - if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) { - auto iter_and_bool = builtin_ops->insert(std::make_pair( - tflite::EnumNameBuiltinOperator(opcode->builtin_code()), - std::make_pair(version, version))); + auto builtin_code = GetBuiltinCode(opcode); + if (builtin_code != ::tflite::BuiltinOperator_CUSTOM) { + auto iter_and_bool = builtin_ops->insert( + std::make_pair(tflite::EnumNameBuiltinOperator(builtin_code), + std::make_pair(version, version))); auto& versions = iter_and_bool.first->second; versions.first = std::min(versions.first, version); versions.second = std::max(versions.second, version); diff --git a/tensorflow/lite/tools/list_flex_ops.cc b/tensorflow/lite/tools/list_flex_ops.cc index b4db4f76d2b..188efc57f90 100644 --- a/tensorflow/lite/tools/list_flex_ops.cc +++ b/tensorflow/lite/tools/list_flex_ops.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/util.h" namespace tflite { @@ -87,7 +88,7 @@ void AddFlexOpsFromModel(const tflite::Model* model, OpKernelSet* flex_ops) { for (int i = 0; i < operators->size(); ++i) { const tflite::Operator* op = operators->Get(i); const tflite::OperatorCode* opcode = opcodes->Get(op->opcode_index()); - if (opcode->builtin_code() != tflite::BuiltinOperator_CUSTOM || + if (tflite::GetBuiltinCode(opcode) != tflite::BuiltinOperator_CUSTOM || !tflite::IsFlexOp(opcode->custom_code()->c_str())) { continue; } diff --git a/tensorflow/lite/tools/make/Makefile b/tensorflow/lite/tools/make/Makefile index eb07503fd5d..9d90e9526be 100644 --- a/tensorflow/lite/tools/make/Makefile +++ b/tensorflow/lite/tools/make/Makefile @@ -41,6 +41,10 @@ INCLUDES := \ -I$(MAKEFILE_DIR)/downloads/farmhash/src \ -I$(MAKEFILE_DIR)/downloads/flatbuffers/include \ -I$(MAKEFILE_DIR)/downloads/fp16/include \ +-I$(MAKEFILE_DIR)/downloads/cpuinfo \ +-I$(MAKEFILE_DIR)/downloads/cpuinfo/include \ +-I$(MAKEFILE_DIR)/downloads/cpuinfo/src \ +-I$(MAKEFILE_DIR)/downloads/cpuinfo/deps/clog/include \ -I$(OBJDIR) # This is at the end so any globally-installed frameworks like protobuf don't # override local versions in the source tree. @@ -58,7 +62,7 @@ LIBS := \ # There are no rules for compiling objects for the host system (since we don't # generate things like the protobuf compiler that require that), so all of # these settings are for the target compiler. -CFLAGS := -O3 -DNDEBUG -fPIC $(EXTRA_CFLAGS) +CFLAGS := -O3 -DNDEBUG -DCPU_SETSIZE=__CPU_SETSIZE -fPIC $(EXTRA_CFLAGS) CXXFLAGS := $(CFLAGS) --std=c++11 $(EXTRA_CXXFLAGS) LDOPTS := -L/usr/local/lib ARFLAGS := -r @@ -69,6 +73,13 @@ ifeq ($(HOST_OS),windows) CXXFLAGS += -fext-numeric-literals -D__LITTLE_ENDIAN__ endif +ifeq ($(TARGET_ARCH),x86_64) + ifeq ($(TARGET),linux) + CXXFLAGS += -DTFLITE_HAVE_CPUINFO + TFLITE_HAVE_CPUINFO := true + endif +endif + # Auto-detect optimization opportunity if building natively. 
ifeq ($(HOST_OS),$(TARGET)) ifeq ($(HOST_ARCH),$(TARGET_ARCH)) @@ -127,7 +138,17 @@ $(wildcard tensorflow/lite/c/*.cc) \ $(wildcard tensorflow/lite/core/*.cc) \ $(wildcard tensorflow/lite/core/api/*.cc) \ $(wildcard tensorflow/lite/experimental/resource/*.cc) \ +$(wildcard tensorflow/lite/schema/schema_utils.cc) \ $(wildcard tensorflow/lite/tools/make/downloads/ruy/ruy/*.cc) +ifeq ($(TFLITE_HAVE_CPUINFO),true) +CORE_CC_ALL_SRCS += \ +$(wildcard tensorflow/lite/tools/make/downloads/cpuinfo/src/*.c) \ +$(wildcard tensorflow/lite/tools/make/downloads/cpuinfo/src/x86/*.c) \ +$(wildcard tensorflow/lite/tools/make/downloads/cpuinfo/src/x86/linux/*.c) \ +$(wildcard tensorflow/lite/tools/make/downloads/cpuinfo/src/x86/cache/*.c) \ +$(wildcard tensorflow/lite/tools/make/downloads/cpuinfo/src/linux/*.c) \ +$(wildcard tensorflow/lite/tools/make/downloads/cpuinfo/deps/clog/src/*.c) +endif ifneq ($(BUILD_TYPE),micro) CORE_CC_ALL_SRCS += \ $(wildcard tensorflow/lite/kernels/*.cc) \ @@ -167,6 +188,7 @@ $(wildcard tensorflow/lite/*/*/*/*/*/*test.cc) \ $(wildcard tensorflow/lite/*/*/*/*/*/*tool.cc) \ $(wildcard tensorflow/lite/kernels/*test_main.cc) \ $(wildcard tensorflow/lite/kernels/*test_util*.cc) \ +$(wildcard tensorflow/lite/tools/make/downloads/cpuinfo/src/*/mock*.c) \ tensorflow/lite/tflite_with_xnnpack.cc \ $(MINIMAL_SRCS) diff --git a/tensorflow/lite/tools/make/download_dependencies.sh b/tensorflow/lite/tools/make/download_dependencies.sh index 27537823be2..698256302c3 100755 --- a/tensorflow/lite/tools/make/download_dependencies.sh +++ b/tensorflow/lite/tools/make/download_dependencies.sh @@ -51,6 +51,8 @@ FLATBUFFERS_SHA="62f2223fb9181d1d6338451375628975775f7522185266cd5296571ac152bc4 FFT2D_URL="https://storage.googleapis.com/mirror.tensorflow.org/www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz" FP16_URL="https://github.com/Maratyszcza/FP16/archive/febbb1c163726b5db24bed55cc9dc42529068997.zip" FFT2D_SHA="ada7e99087c4ed477bfdf11413f2ba8db8a840ba9bbf8ac94f4f3972e2a7cec9" +CPUINFO_URL="https://github.com/pytorch/cpuinfo/archive/c2092219e7c874783a00a62edb94ddc672f57ab3.zip" +CPUINFO_SHA="ea56c399a4f6ca5f749e71acb6a7bfdc653eb65d8f658cb2e414a2fcdca1fe8b" # TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, # so work around it by patching the source. 
replace_by_sed() { @@ -115,6 +117,7 @@ download_and_extract "${FARMHASH_URL}" "${DOWNLOADS_DIR}/farmhash" "${FARMHASH_S download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers" "${FLATBUFFERS_SHA}" download_and_extract "${FFT2D_URL}" "${DOWNLOADS_DIR}/fft2d" "${FFT2D_SHA}" download_and_extract "${FP16_URL}" "${DOWNLOADS_DIR}/fp16" +download_and_extract "${CPUINFO_URL}" "${DOWNLOADS_DIR}/cpuinfo" replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \ "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD index fe9ef54cb80..9beb7a239e8 100644 --- a/tensorflow/lite/tools/optimize/BUILD +++ b/tensorflow/lite/tools/optimize/BUILD @@ -23,6 +23,7 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_absl//absl/memory", "@flatbuffers", ], @@ -39,6 +40,7 @@ tf_cc_test( ":modify_model_interface", "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_absl//absl/memory", "@com_google_googletest//:gtest", "@flatbuffers", @@ -78,6 +80,7 @@ tf_cc_test( ":quantization_wrapper_utils", "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", @@ -132,6 +135,7 @@ cc_library( "//tensorflow/lite/kernels/internal:tensor_utils", "//tensorflow/lite/kernels/internal:types", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -165,6 +169,7 @@ cc_library( "//tensorflow/lite:framework", "//tensorflow/lite/kernels/internal:types", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", ], ) @@ -188,6 +193,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "//tensorflow/lite/testing:util", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -212,6 +218,7 @@ cc_library( # TODO(suharshs): Move the relevant quantization utils to a non-internal location. 
"//tensorflow/lite/kernels/internal:tensor_utils", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "//tensorflow/core:tflite_portable_logging", ], ) @@ -239,6 +246,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_googletest//:gtest", "@flatbuffers", ], @@ -269,6 +277,7 @@ cc_library( "//tensorflow/lite:util", "//tensorflow/lite/core/api", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@flatbuffers", ], ) @@ -316,6 +325,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_googletest//:gtest", "@flatbuffers", ], diff --git a/tensorflow/lite/tools/optimize/calibration/BUILD b/tensorflow/lite/tools/optimize/calibration/BUILD index 674ef0ae4f6..53bd1bb4faf 100644 --- a/tensorflow/lite/tools/optimize/calibration/BUILD +++ b/tensorflow/lite/tools/optimize/calibration/BUILD @@ -48,6 +48,7 @@ cc_library( "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/schema:schema_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/tensorflow/lite/tools/optimize/calibration/calibrator.cc b/tensorflow/lite/tools/optimize/calibration/calibrator.cc index e4e1afad7e3..be8fad8a221 100644 --- a/tensorflow/lite/tools/optimize/calibration/calibrator.cc +++ b/tensorflow/lite/tools/optimize/calibration/calibrator.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/lite/model.h" #include "tensorflow/lite/op_resolver.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/stderr_reporter.h" #include "tensorflow/lite/string_util.h" #include "tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.h" @@ -286,7 +287,7 @@ string GetOpName(const tflite::OperatorCode& opcode) { if (opcode.custom_code() != nullptr) { return opcode.custom_code()->str(); } - return tflite::EnumNamesBuiltinOperator()[opcode.builtin_code()]; + return tflite::EnumNamesBuiltinOperator()[GetBuiltinCode(&opcode)]; } // A |CalibrationReader| that owns the Calibrator. 
@@ -348,7 +349,7 @@ TfLiteStatus BuildLoggingInterpreter( op_info.node_index = i; auto op = operators->Get(i); auto operator_code = operator_codes->Get(op->opcode_index()); - op_info.builtin_op_code = operator_code->builtin_code(); + op_info.builtin_op_code = GetBuiltinCode(operator_code); op_info.name = GetOpName(*operator_code); op_info.is_custom_op = operator_code->custom_code() != nullptr; op_info.version = operator_code->version(); @@ -367,7 +368,7 @@ TfLiteStatus BuildLoggingInterpreter( custom_op_and_versions.insert( {op_info.name.c_str(), operator_code->version()}); } else { - op_info.registration = op_resolver.FindOp(operator_code->builtin_code(), + op_info.registration = op_resolver.FindOp(GetBuiltinCode(operator_code), operator_code->version()); builtin_op_and_versions.insert( {op_info.builtin_op_code, operator_code->version()}); diff --git a/tensorflow/lite/tools/optimize/model_utils.cc b/tensorflow/lite/tools/optimize/model_utils.cc index f30fa8b7bdd..7224c623c77 100644 --- a/tensorflow/lite/tools/optimize/model_utils.cc +++ b/tensorflow/lite/tools/optimize/model_utils.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/types.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/tools/optimize/operator_property.h" namespace tflite { @@ -34,13 +35,15 @@ namespace { int32_t GetOrInsertOpCodeIndex(ModelT* model, const BuiltinOperator& op_code, int32_t version) { for (size_t i = 0; i < model->operator_codes.size(); ++i) { - if (model->operator_codes[i]->builtin_code == op_code) { + if (GetBuiltinCode(model->operator_codes[i].get()) == op_code) { return i; } } model->operator_codes.push_back(absl::make_unique()); int op_code_idx = model->operator_codes.size() - 1; model->operator_codes[op_code_idx]->builtin_code = op_code; + model->operator_codes[op_code_idx]->deprecated_builtin_code = + ConvertBuiltinCodeToDeprecatedBuiltinCode(op_code); // Version 2 and onwards supports INT8 inputs. model->operator_codes[op_code_idx]->version = version; diff --git a/tensorflow/lite/tools/optimize/modify_model_interface.cc b/tensorflow/lite/tools/optimize/modify_model_interface.cc index d40d4455e24..570f84251a3 100644 --- a/tensorflow/lite/tools/optimize/modify_model_interface.cc +++ b/tensorflow/lite/tools/optimize/modify_model_interface.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/tools/optimize/model_utils.h" namespace tflite { @@ -66,7 +67,7 @@ std::vector GetInputTensors(const TensorType& input_type, op_idx--) { OperatorT* op = subgraph->operators[op_idx].get(); const BuiltinOperator op_code = - model->operator_codes[op->opcode_index]->builtin_code; + GetBuiltinCode(model->operator_codes[op->opcode_index].get()); TensorT* input_tensor = subgraph->tensors[op->inputs[0]].get(); if (input_tensors.find(input_tensor) == input_tensors.end()) { continue; @@ -141,7 +142,7 @@ std::vector GetOutputTensors(const TensorType& output_type, op_idx--) { OperatorT* op = subgraph->operators[op_idx].get(); const BuiltinOperator op_code = - model->operator_codes[op->opcode_index]->builtin_code; + GetBuiltinCode(model->operator_codes[op->opcode_index].get()); TensorT* output_tensor = subgraph->tensors[op->outputs[0]].get(); if (output_tensors.find(output_tensor) == output_tensors.end()) { continue; @@ -217,7 +218,8 @@ TfLiteStatus SetOutputTypeToUINT8(ModelT* model, // Find Quant op code index. size_t quant_op_index = 0; for (size_t i = 0; i < model->operator_codes.size(); ++i) { - if (model->operator_codes[i]->builtin_code == BuiltinOperator_QUANTIZE) { + if (GetBuiltinCode(model->operator_codes[i].get()) == + BuiltinOperator_QUANTIZE) { quant_op_index = i; } } diff --git a/tensorflow/lite/tools/optimize/modify_model_interface_test.cc b/tensorflow/lite/tools/optimize/modify_model_interface_test.cc index 99e0ad35b2d..08a31cd0051 100644 --- a/tensorflow/lite/tools/optimize/modify_model_interface_test.cc +++ b/tensorflow/lite/tools/optimize/modify_model_interface_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" namespace tflite { namespace optimize { @@ -45,12 +46,18 @@ std::unique_ptr CreateQuantizedModelSingleInputOutput( // Op code quant_op_code->builtin_code = BuiltinOperator_QUANTIZE; + quant_op_code->deprecated_builtin_code = + static_cast(BuiltinOperator_QUANTIZE); quant_op_code->version = 2; fc_op_code->builtin_code = BuiltinOperator_FULLY_CONNECTED; + fc_op_code->deprecated_builtin_code = + static_cast(BuiltinOperator_FULLY_CONNECTED); fc_op_code->version = 2; dequant_op_code->builtin_code = BuiltinOperator_DEQUANTIZE; + dequant_op_code->deprecated_builtin_code = + static_cast(BuiltinOperator_DEQUANTIZE); dequant_op_code->version = 2; // Op. @@ -136,12 +143,18 @@ std::unique_ptr CreateQuantizedModelMultipleInputOutput( // Op code quant_op_code->builtin_code = BuiltinOperator_QUANTIZE; + quant_op_code->deprecated_builtin_code = + static_cast(BuiltinOperator_QUANTIZE); quant_op_code->version = 2; fc_op_code->builtin_code = BuiltinOperator_FULLY_CONNECTED; + fc_op_code->deprecated_builtin_code = + static_cast(BuiltinOperator_FULLY_CONNECTED); fc_op_code->version = 2; dequant_op_code->builtin_code = BuiltinOperator_DEQUANTIZE; + dequant_op_code->deprecated_builtin_code = + static_cast(BuiltinOperator_DEQUANTIZE); dequant_op_code->version = 2; // Op. 
@@ -257,6 +270,8 @@ std::unique_ptr CreateFloatModel() { // Op code fc_op_code->builtin_code = BuiltinOperator_FULLY_CONNECTED; + fc_op_code->deprecated_builtin_code = + static_cast(BuiltinOperator_FULLY_CONNECTED); fc_op_code->version = 2; // Op. @@ -599,10 +614,12 @@ TEST(ModelInterface, Float) { EXPECT_EQ(model->subgraphs[0]->outputs.size(), 1); EXPECT_EQ(model->subgraphs[0]->outputs[0], 1); EXPECT_EQ(model->operator_codes.size(), 3); - EXPECT_EQ(model->operator_codes[0]->builtin_code, + EXPECT_EQ(GetBuiltinCode(model->operator_codes[0].get()), BuiltinOperator_FULLY_CONNECTED); - EXPECT_EQ(model->operator_codes[1]->builtin_code, BuiltinOperator_DEQUANTIZE); - EXPECT_EQ(model->operator_codes[2]->builtin_code, BuiltinOperator_QUANTIZE); + EXPECT_EQ(GetBuiltinCode(model->operator_codes[1].get()), + BuiltinOperator_DEQUANTIZE); + EXPECT_EQ(GetBuiltinCode(model->operator_codes[2].get()), + BuiltinOperator_QUANTIZE); EXPECT_EQ(model->subgraphs[0]->operators.size(), 3); auto dequantize_op = model->subgraphs[0]->operators[0].get(); diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc index ce56b186216..6375016d527 100644 --- a/tensorflow/lite/tools/optimize/operator_property.cc +++ b/tensorflow/lite/tools/optimize/operator_property.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/tools/optimize/operator_property.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" namespace tflite { namespace optimize { @@ -41,7 +42,8 @@ const OpVariant GetOperatorVariant(const ModelT* model, int subgraph_index, OpVariant op_variant; OperatorT* op = model->subgraphs.at(subgraph_index)->operators[op_index].get(); - op_variant.op_code = model->operator_codes[op->opcode_index]->builtin_code; + op_variant.op_code = + GetBuiltinCode(model->operator_codes[op->opcode_index].get()); if (op_variant.op_code == BuiltinOperator_LSTM) { if (op->inputs.size() == 5) { // The 5 input ("basic") LSTM is not supported in this tooling (yet). 
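
The change repeated through `operator_property.cc`, `quantize_model.cc`, and the tests below is the same three-step lookup: follow `opcode_index` into `operator_codes`, then resolve the builtin code through `GetBuiltinCode()` instead of the raw field. A condensed sketch of that access path (the wrapper function name is invented for illustration):

```c++
// Sketch of the lookup these hunks standardize on: OperatorT::opcode_index
// points into ModelT::operator_codes, and the builtin code is read through
// GetBuiltinCode() rather than the raw builtin_code member.
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"

tflite::BuiltinOperator OpCodeOf(const tflite::ModelT& model,
                                 const tflite::OperatorT& op) {
  const tflite::OperatorCodeT* op_code =
      model.operator_codes[op.opcode_index].get();
  return tflite::GetBuiltinCode(op_code);
}
```
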
@@ -97,6 +99,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, property.inputs = {{0, {}}, {1, {}}}; property.outputs = {{0, {}}}; property.version = 2; + property.quantize_input_as_activations = true; break; } case BuiltinOperator_BATCH_TO_SPACE_ND: diff --git a/tensorflow/lite/tools/optimize/python/BUILD b/tensorflow/lite/tools/optimize/python/BUILD index 050c0008924..34f57cbecaf 100644 --- a/tensorflow/lite/tools/optimize/python/BUILD +++ b/tensorflow/lite/tools/optimize/python/BUILD @@ -52,7 +52,7 @@ py_library( name = "modify_model_interface_constants", srcs = ["modify_model_interface_constants.py"], srcs_version = "PY3", - deps = ["//tensorflow/lite/python:lite_constants"], + deps = ["//tensorflow/python:dtypes"], ) pybind_extension( diff --git a/tensorflow/lite/tools/optimize/python/modify_model_interface_constants.py b/tensorflow/lite/tools/optimize/python/modify_model_interface_constants.py index cbe1aa92022..f7c7cc60d5c 100644 --- a/tensorflow/lite/tools/optimize/python/modify_model_interface_constants.py +++ b/tensorflow/lite/tools/optimize/python/modify_model_interface_constants.py @@ -19,12 +19,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.lite.python import lite_constants +from tensorflow.python.framework import dtypes STR_TO_TFLITE_TYPES = { - 'INT8': lite_constants.INT8, - 'INT16': lite_constants.INT16, - 'UINT8': lite_constants.QUANTIZED_UINT8 + 'INT8': dtypes.int8, + 'UINT8': dtypes.uint8, + 'INT16': dtypes.int16, } TFLITE_TO_STR_TYPES = {v: k for k, v in STR_TO_TFLITE_TYPES.items()} diff --git a/tensorflow/lite/tools/optimize/quantization_utils_test.cc b/tensorflow/lite/tools/optimize/quantization_utils_test.cc index 49009e49600..4ce0d01fd12 100644 --- a/tensorflow/lite/tools/optimize/quantization_utils_test.cc +++ b/tensorflow/lite/tools/optimize/quantization_utils_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/testing/util.h" #include "tensorflow/lite/tools/optimize/test_util.h" @@ -471,8 +472,9 @@ TEST_F(QuantizationUtilsTest, SymmetricQuantizeTensorFromMinMax) { readonly_model->UnPackTo(&model); auto subgraph = model.subgraphs[0].get(); auto conv_op = subgraph->operators.at(0).get(); - ASSERT_EQ(model.operator_codes.at(conv_op->opcode_index)->builtin_code, - BuiltinOperator_CONV_2D); + ASSERT_EQ( + GetBuiltinCode(model.operator_codes.at(conv_op->opcode_index).get()), + BuiltinOperator_CONV_2D); int32_t weights_tensor_idx = conv_op->inputs[1]; TensorT* weights_tensor = subgraph->tensors.at(weights_tensor_idx).get(); @@ -520,8 +522,9 @@ TEST_F(QuantizationUtilsTest, SymmetricQuantizeTensorNullQuantParams) { readonly_model->UnPackTo(&model); auto subgraph = model.subgraphs[0].get(); auto conv_op = subgraph->operators.at(0).get(); - ASSERT_EQ(model.operator_codes.at(conv_op->opcode_index)->builtin_code, - BuiltinOperator_CONV_2D); + ASSERT_EQ( + GetBuiltinCode(model.operator_codes.at(conv_op->opcode_index).get()), + BuiltinOperator_CONV_2D); int32_t weights_tensor_idx = conv_op->inputs[1]; TensorT* weights_tensor = subgraph->tensors.at(weights_tensor_idx).get(); // Empty quantization parameters. 
@@ -554,8 +557,9 @@ TEST_F(QuantizationUtilsTest, SymmetricQuantizeTensor) { readonly_model->UnPackTo(&model); auto subgraph = model.subgraphs[0].get(); auto conv_op = subgraph->operators.at(0).get(); - ASSERT_EQ(model.operator_codes.at(conv_op->opcode_index)->builtin_code, - BuiltinOperator_CONV_2D); + ASSERT_EQ( + GetBuiltinCode(model.operator_codes.at(conv_op->opcode_index).get()), + BuiltinOperator_CONV_2D); int32_t weights_tensor_idx = conv_op->inputs[1]; TensorT* weights_tensor = subgraph->tensors.at(weights_tensor_idx).get(); @@ -586,8 +590,9 @@ TEST_F(QuantizationUtilsTest, QuantizeFloat16) { readonly_model->UnPackTo(&model); auto subgraph = model.subgraphs[0].get(); auto conv_op = subgraph->operators.at(0).get(); - ASSERT_EQ(model.operator_codes.at(conv_op->opcode_index)->builtin_code, - BuiltinOperator_CONV_2D); + ASSERT_EQ( + GetBuiltinCode(model.operator_codes.at(conv_op->opcode_index).get()), + BuiltinOperator_CONV_2D); int32_t weights_tensor_idx = conv_op->inputs[1]; TensorT* weights_tensor = subgraph->tensors.at(weights_tensor_idx).get(); diff --git a/tensorflow/lite/tools/optimize/quantization_wrapper_utils_test.cc b/tensorflow/lite/tools/optimize/quantization_wrapper_utils_test.cc index 371de7d04bc..38fef0660d5 100644 --- a/tensorflow/lite/tools/optimize/quantization_wrapper_utils_test.cc +++ b/tensorflow/lite/tools/optimize/quantization_wrapper_utils_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" namespace tflite { namespace optimize { @@ -35,6 +36,8 @@ TEST(LstmPreprocess, Add2Tensors) { auto lstm_op = absl::make_unique(); lstm_op_code->builtin_code = BuiltinOperator_LSTM; + lstm_op_code->deprecated_builtin_code = + static_cast(BuiltinOperator_LSTM); lstm_op_code->version = 2; lstm_op->opcode_index = 0; lstm_op->inputs = {0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1, -1, @@ -68,7 +71,8 @@ TEST(LstmPreprocess, Add2Tensors) { EXPECT_EQ(model->subgraphs[0]->tensors.size(), 26); EXPECT_EQ(model->buffers.size(), 1); - EXPECT_EQ(model->operator_codes[0]->builtin_code, BuiltinOperator_LSTM); + EXPECT_EQ(GetBuiltinCode(model->operator_codes[0].get()), + BuiltinOperator_LSTM); EXPECT_EQ(model->subgraphs[0]->tensors[0]->name, "lstm_tensor0"); EXPECT_EQ(model->subgraphs[0]->tensors[21]->name, "intermediate_0_0"); EXPECT_EQ(model->subgraphs[0]->tensors[22]->name, "intermediate_0_1"); @@ -94,7 +98,8 @@ TEST(LstmPreprocess, Add2Tensors) { EXPECT_EQ(model->subgraphs[0]->tensors.size(), 26); EXPECT_EQ(model->buffers.size(), 1); - EXPECT_EQ(model->operator_codes[0]->builtin_code, BuiltinOperator_LSTM); + EXPECT_EQ(GetBuiltinCode(model->operator_codes[0].get()), + BuiltinOperator_LSTM); EXPECT_EQ(model->subgraphs[0]->tensors[0]->name, "lstm_tensor0"); EXPECT_EQ(model->subgraphs[0]->tensors[21]->name, "intermediate_0_0"); EXPECT_EQ(model->subgraphs[0]->tensors[22]->name, "intermediate_0_1"); diff --git a/tensorflow/lite/tools/optimize/quantize_model.cc b/tensorflow/lite/tools/optimize/quantize_model.cc index ca9a51abefe..e7f1c7a8bdf 100644 --- a/tensorflow/lite/tools/optimize/quantize_model.cc +++ b/tensorflow/lite/tools/optimize/quantize_model.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/tools/optimize/model_utils.h" #include "tensorflow/lite/tools/optimize/operator_property.h" #include "tensorflow/lite/tools/optimize/quantization_utils.h" @@ -59,7 +60,7 @@ operator_property::OperatorProperty GetOperatorProperty( const SubGraphT* subgraph = model->subgraphs[subgraph_index].get(); const OperatorT* op = subgraph->operators[op_idx].get(); const BuiltinOperator op_code = - model->operator_codes[op->opcode_index]->builtin_code; + GetBuiltinCode(model->operator_codes[op->opcode_index].get()); if (activations_type == TensorType_INT16 && !property.quantizable_int16) { property.quantizable = false; } @@ -521,7 +522,7 @@ TfLiteStatus QuantizeOpInput( SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get(); OperatorT* op = subgraph->operators[*op_idx].get(); const BuiltinOperator op_code = - model->operator_codes[op->opcode_index]->builtin_code; + GetBuiltinCode(model->operator_codes[op->opcode_index].get()); if (input_idx >= op->inputs.size()) { TF_LITE_REPORT_ERROR( error_reporter, @@ -716,7 +717,7 @@ TfLiteStatus QuantizeOpOutput( SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get(); OperatorT* op = subgraph->operators[op_idx].get(); const BuiltinOperator op_code = - model->operator_codes[op->opcode_index]->builtin_code; + GetBuiltinCode(model->operator_codes[op->opcode_index].get()); if (output_idx >= op->outputs.size()) { TF_LITE_REPORT_ERROR( error_reporter, @@ -812,7 +813,7 @@ TfLiteStatus QuantizeIntemediateTensors(ModelT* model, if (!property.intermediates.empty()) { OperatorT* op = subgraph->operators[op_idx].get(); const BuiltinOperator op_code = - model->operator_codes[op->opcode_index]->builtin_code; + GetBuiltinCode(model->operator_codes[op->opcode_index].get()); for (const std::pair& input : property.intermediates) { const int index_local = input.first; @@ -941,7 +942,7 @@ TfLiteStatus QuantizeWeightsInputOutput( for (size_t op_idx = 0; op_idx < subgraph->operators.size(); op_idx++) { OperatorT* op = subgraph->operators[op_idx].get(); const BuiltinOperator op_code = - model->operator_codes[op->opcode_index]->builtin_code; + GetBuiltinCode(model->operator_codes[op->opcode_index].get()); const string operator_name = subgraph->tensors[op->outputs[0]]->name; operator_property::OperatorProperty property = GetOperatorProperty(operator_names, model, subgraph_idx, op_idx, @@ -1001,7 +1002,7 @@ TfLiteStatus QuantizeBiases(ModelT* model, for (size_t op_idx = 0; op_idx < subgraph->operators.size(); op_idx++) { OperatorT* op = subgraph->operators[op_idx].get(); const BuiltinOperator op_code = - model->operator_codes[op->opcode_index]->builtin_code; + GetBuiltinCode(model->operator_codes[op->opcode_index].get()); const string operator_name = subgraph->tensors[op->outputs[0]]->name; operator_property::OperatorProperty property = GetOperatorProperty(operator_names, model, subgraph_idx, op_idx, diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc index 0bf6bb8b5a9..32a23033019 100644 --- a/tensorflow/lite/tools/optimize/quantize_model_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/tools/optimize/test_util.h" // Note: More rigorous model tests can be found in subgraph_quantizer_test.cc @@ -153,7 +154,8 @@ TEST_P(QuantizeConvModelTest, TensorShapesAndStructureIsUnchanged) { } // check op and versioning. EXPECT_EQ(model_.operator_codes.size(), 1); - EXPECT_EQ(model_.operator_codes[0]->builtin_code, BuiltinOperator_CONV_2D); + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()), + BuiltinOperator_CONV_2D); EXPECT_EQ(model_.operator_codes[0]->version, 3); } @@ -166,9 +168,10 @@ TEST_P(QuantizeConvModelTest, OperatorsAreUnchanged) { readonly_model_->operator_codes()->size()); for (size_t i = 0; i < model_.operator_codes.size(); i++) { const auto float_model_op = readonly_model_->operator_codes()->Get(i); - EXPECT_EQ(model_.operator_codes[i]->builtin_code, - float_model_op->builtin_code()); - if (model_.operator_codes[i]->builtin_code == BuiltinOperator_CONV_2D) { + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[i].get()), + GetBuiltinCode(float_model_op)); + if (GetBuiltinCode(model_.operator_codes[i].get()) == + BuiltinOperator_CONV_2D) { EXPECT_EQ(model_.operator_codes[i]->version, 3); } else { EXPECT_EQ(model_.operator_codes[i]->version, 2); @@ -232,9 +235,9 @@ TEST_P(QuantizeConvModelTest, FloatInputAndOutput) { subgraph->operators[subgraph->operators.size() - 1]; const int32_t quant_idx = quant_op->opcode_index; const int32_t dequant_idx = dequant_op->opcode_index; - EXPECT_EQ(model_.operator_codes[quant_idx]->builtin_code, + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[quant_idx].get()), BuiltinOperator_QUANTIZE); - EXPECT_EQ(model_.operator_codes[dequant_idx]->builtin_code, + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[dequant_idx].get()), BuiltinOperator_DEQUANTIZE); // The model should only have one input and output. EXPECT_EQ(subgraph->inputs.size(), 1); @@ -297,10 +300,12 @@ TEST_P(QuantizeConvModelTest, Uint8InputAndOutput) { subgraph->operators[subgraph->operators.size() - 1]; const int32_t quant_op_uint8_int8_idx = quant_op_uint8_int8->opcode_index; const int32_t quant_op_int8_uint8_idx = quant_op_int8_uint8->opcode_index; - EXPECT_EQ(model_.operator_codes[quant_op_uint8_int8_idx]->builtin_code, - BuiltinOperator_QUANTIZE); - EXPECT_EQ(model_.operator_codes[quant_op_int8_uint8_idx]->builtin_code, - BuiltinOperator_QUANTIZE); + EXPECT_EQ( + GetBuiltinCode(model_.operator_codes[quant_op_uint8_int8_idx].get()), + BuiltinOperator_QUANTIZE); + EXPECT_EQ( + GetBuiltinCode(model_.operator_codes[quant_op_int8_uint8_idx].get()), + BuiltinOperator_QUANTIZE); // The model should only have one input and output. EXPECT_EQ(subgraph->inputs.size(), 1); EXPECT_EQ(subgraph->outputs.size(), 1); @@ -405,9 +410,9 @@ TEST_P(QuantizeConcatModelTest, AddRequantBeforeConcat) { EXPECT_EQ(subgraph->operators.size(), 2); const auto& requant = subgraph->operators[0]; const auto& concat = subgraph->operators[1]; - EXPECT_EQ(model_.operator_codes[requant->opcode_index]->builtin_code, + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[requant->opcode_index].get()), BuiltinOperator_QUANTIZE); - EXPECT_EQ(model_.operator_codes[concat->opcode_index]->builtin_code, + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[concat->opcode_index].get()), BuiltinOperator_CONCATENATION); auto zero_point_control = tensor_type_ == TensorType_INT8 ? 
-128 : 0; @@ -469,10 +474,11 @@ TEST_P(QuantizeConcatModelTest, AddRequantBeforeConcat) { // check op and versioning. EXPECT_EQ(model_.operator_codes.size(), 2); - EXPECT_EQ(model_.operator_codes[0]->builtin_code, + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()), BuiltinOperator_CONCATENATION); EXPECT_EQ(model_.operator_codes[0]->version, 2); - EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_QUANTIZE); + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()), + BuiltinOperator_QUANTIZE); EXPECT_EQ(model_.operator_codes[1]->version, 2); } INSTANTIATE_TEST_SUITE_P(QuantizeConcatModelInst, QuantizeConcatModelTest, @@ -506,9 +512,9 @@ TEST_F(QuantizeSplitModelTest, QuantizeSplit) { EXPECT_EQ(subgraph->operators.size(), 2); const auto& split = subgraph->operators[0]; const auto& add = subgraph->operators[1]; - EXPECT_EQ(model_.operator_codes[split->opcode_index]->builtin_code, + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[split->opcode_index].get()), BuiltinOperator_SPLIT); - EXPECT_EQ(model_.operator_codes[add->opcode_index]->builtin_code, + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[add->opcode_index].get()), BuiltinOperator_ADD); // There should be 5 tensors: input, output, split, split/split_dim, split:1. @@ -541,7 +547,8 @@ TEST_F(QuantizeSplitModelTest, QuantizeSplit) { // check op and versioning. EXPECT_EQ(model_.operator_codes.size(), 2); - EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_SPLIT); + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()), + BuiltinOperator_SPLIT); EXPECT_EQ(model_.operator_codes[0]->version, 2); } @@ -643,7 +650,8 @@ TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) { // check op and versioning. EXPECT_EQ(model_.operator_codes.size(), 1); - EXPECT_EQ(model_.operator_codes[0]->builtin_code, BuiltinOperator_CONV_2D); + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()), + BuiltinOperator_CONV_2D); EXPECT_EQ(model_.operator_codes[0]->version, 3); } @@ -762,7 +770,8 @@ TEST_P(QuantizeConvModel2Test, VerifyConvQuantization) { // check op and versioning. EXPECT_EQ(model_.operator_codes.size(), 1); - EXPECT_EQ(model_.operator_codes[0]->builtin_code, BuiltinOperator_CONV_2D); + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()), + BuiltinOperator_CONV_2D); EXPECT_EQ(model_.operator_codes[0]->version, 3); } @@ -785,7 +794,7 @@ TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { auto op = subgraph->operators[0].get(); // Model has a single softmax op. ASSERT_EQ(op->opcode_index, 0); - ASSERT_EQ(model_.operator_codes[0].get()->builtin_code, + ASSERT_EQ(GetBuiltinCode(model_.operator_codes[0].get()), BuiltinOperator_SOFTMAX); ASSERT_EQ(op->inputs.size(), 1); @@ -823,7 +832,8 @@ TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { // check op and versioning. EXPECT_EQ(model_.operator_codes.size(), 1); - EXPECT_EQ(model_.operator_codes[0]->builtin_code, BuiltinOperator_SOFTMAX); + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()), + BuiltinOperator_SOFTMAX); EXPECT_EQ(model_.operator_codes[0]->version, 2); } @@ -846,7 +856,7 @@ TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { auto op = subgraph->operators[0].get(); // Model has a single AveragePool op. 
ASSERT_EQ(op->opcode_index, 0); - ASSERT_EQ(model_.operator_codes[0].get()->builtin_code, + ASSERT_EQ(GetBuiltinCode(model_.operator_codes[0].get()), BuiltinOperator_AVERAGE_POOL_2D); ASSERT_EQ(op->inputs.size(), 1); @@ -884,7 +894,7 @@ TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { // check op and versioning. EXPECT_EQ(model_.operator_codes.size(), 1); - EXPECT_EQ(model_.operator_codes[0]->builtin_code, + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()), BuiltinOperator_AVERAGE_POOL_2D); EXPECT_EQ(model_.operator_codes[0]->version, 2); } @@ -907,7 +917,7 @@ TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) { // Verify Reshape is quantized. const auto& subgraph = model_.subgraphs[0]; auto op = subgraph->operators[1].get(); - ASSERT_EQ(model_.operator_codes[op->opcode_index].get()->builtin_code, + ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()), BuiltinOperator_RESHAPE); ASSERT_EQ(op->inputs.size(), 2); @@ -921,7 +931,6 @@ TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) { EXPECT_EQ(subgraph->tensors[op->inputs[0]].get()->type, TensorType_INT8); EXPECT_EQ(subgraph->tensors[op->outputs[0]].get()->type, TensorType_INT8); - auto float_input_quant_params = float_graph->tensors()->Get(op->inputs[0])->quantization(); auto input_quant_params = @@ -940,9 +949,11 @@ TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) { // check op and versioning. EXPECT_EQ(model_.operator_codes.size(), 2); - EXPECT_EQ(model_.operator_codes[0]->builtin_code, BuiltinOperator_ADD); + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()), + BuiltinOperator_ADD); EXPECT_EQ(model_.operator_codes[0]->version, 2); - EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_RESHAPE); + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()), + BuiltinOperator_RESHAPE); EXPECT_EQ(model_.operator_codes[1]->version, 1); } @@ -955,7 +966,7 @@ TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyAddQuantization) { // Verify ADD is quantized. const auto& subgraph = model_.subgraphs[0]; auto op = subgraph->operators[0].get(); - ASSERT_EQ(model_.operator_codes[op->opcode_index].get()->builtin_code, + ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()), BuiltinOperator_ADD); ASSERT_EQ(op->inputs.size(), 2); @@ -992,9 +1003,11 @@ TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyAddQuantization) { // check op and versioning. EXPECT_EQ(model_.operator_codes.size(), 2); - EXPECT_EQ(model_.operator_codes[0]->builtin_code, BuiltinOperator_ADD); + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()), + BuiltinOperator_ADD); EXPECT_EQ(model_.operator_codes[0]->version, 2); - EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_RESHAPE); + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()), + BuiltinOperator_RESHAPE); EXPECT_EQ(model_.operator_codes[1]->version, 1); } @@ -1016,7 +1029,7 @@ TEST_F(QuantizeConstInputTest, VerifyConstOpInput) { // Verify ConstOp is quantized. const auto& subgraph = model_.subgraphs[0]; auto op = subgraph->operators[0].get(); - ASSERT_EQ(model_.operator_codes[op->opcode_index].get()->builtin_code, + ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()), BuiltinOperator_ADD); ASSERT_EQ(op->inputs.size(), 2); @@ -1037,7 +1050,8 @@ TEST_F(QuantizeConstInputTest, VerifyConstOpInput) { // check op and versioning. 
EXPECT_EQ(model_.operator_codes.size(), 1); - EXPECT_EQ(model_.operator_codes[0]->builtin_code, BuiltinOperator_ADD); + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()), + BuiltinOperator_ADD); EXPECT_EQ(model_.operator_codes[0]->version, 2); } @@ -1058,7 +1072,7 @@ TEST_F(QuantizeArgMaxTest, VerifyArgMax) { const auto& subgraph = model_.subgraphs[0]; auto op = subgraph->operators[0].get(); - ASSERT_EQ(model_.operator_codes[op->opcode_index].get()->builtin_code, + ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()), BuiltinOperator_ARG_MAX); ASSERT_EQ(op->inputs.size(), 2); @@ -1080,7 +1094,8 @@ TEST_F(QuantizeArgMaxTest, VerifyArgMax) { // check op and versioning. EXPECT_EQ(model_.operator_codes.size(), 1); - EXPECT_EQ(model_.operator_codes[0]->builtin_code, BuiltinOperator_ARG_MAX); + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()), + BuiltinOperator_ARG_MAX); EXPECT_EQ(model_.operator_codes[0]->version, 2); } @@ -1281,7 +1296,7 @@ TEST_F(QuantizeFCTest, VerifyFC) { const auto& subgraph = model_.subgraphs[0]; auto op = subgraph->operators[0].get(); - ASSERT_EQ(model_.operator_codes[op->opcode_index].get()->builtin_code, + ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()), BuiltinOperator_FULLY_CONNECTED); ASSERT_EQ(op->inputs.size(), 3); @@ -1308,10 +1323,11 @@ TEST_F(QuantizeFCTest, VerifyFC) { // check op and versioning. EXPECT_EQ(model_.operator_codes.size(), 2); - EXPECT_EQ(model_.operator_codes[0]->builtin_code, + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[0].get()), BuiltinOperator_FULLY_CONNECTED); EXPECT_EQ(model_.operator_codes[0]->version, 4); - EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_RESHAPE); + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[1].get()), + BuiltinOperator_RESHAPE); EXPECT_EQ(model_.operator_codes[1]->version, 1); } @@ -1347,7 +1363,7 @@ TEST_P(QuantizeCustomOpTest, VerifyMixedQuantization) { TensorType_FLOAT32, TensorType_FLOAT32, GetParam()}; for (int i = 0; i < subgraph->operators.size(); ++i) { OperatorT* op = subgraph->operators[i].get(); - ASSERT_EQ(model_.operator_codes[op->opcode_index]->builtin_code, + ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()), op_codes[i]); ASSERT_EQ(subgraph->tensors[op->inputs[0]]->type, op_input_types[i]); } @@ -1384,7 +1400,7 @@ TEST_F(QuantizeOp16x8Test, VerifyMixedQuantization16x8) { TensorType_INT16, TensorType_INT16, TensorType_FLOAT32}; for (int i = 0; i < subgraph->operators.size(); ++i) { OperatorT* op = subgraph->operators[i].get(); - ASSERT_EQ(model_.operator_codes[op->opcode_index]->builtin_code, + ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()), op_codes[i]); ASSERT_EQ(subgraph->tensors[op->inputs[0]]->type, op_input_types[i]); } @@ -1415,13 +1431,13 @@ TEST_F(QuantizePackTest, VerifyPack) { const auto& op3 = subgraph->operators[3].get(); const auto& op4 = subgraph->operators[4].get(); - ASSERT_EQ(model_.operator_codes[op1->opcode_index].get()->builtin_code, + ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op1->opcode_index].get()), BuiltinOperator_QUANTIZE); - ASSERT_EQ(model_.operator_codes[op2->opcode_index].get()->builtin_code, + ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op2->opcode_index].get()), BuiltinOperator_QUANTIZE); - ASSERT_EQ(model_.operator_codes[op3->opcode_index].get()->builtin_code, + ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op3->opcode_index].get()), BuiltinOperator_PACK); - ASSERT_EQ(model_.operator_codes[op4->opcode_index].get()->builtin_code, + 
ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op4->opcode_index].get()), BuiltinOperator_DEQUANTIZE); const auto& pack_input0 = subgraph->tensors[op3->inputs[0]].get(); @@ -1473,27 +1489,27 @@ TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) { const auto& dequant_op = subgraph->operators[subgraph->operators.size() - 1]; const int32_t quant_idx = quant_op->opcode_index; const int32_t dequant_idx = dequant_op->opcode_index; - EXPECT_EQ(model_.operator_codes[quant_idx]->builtin_code, + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[quant_idx].get()), BuiltinOperator_QUANTIZE); - EXPECT_EQ(model_.operator_codes[dequant_idx]->builtin_code, + EXPECT_EQ(GetBuiltinCode(model_.operator_codes[dequant_idx].get()), BuiltinOperator_DEQUANTIZE); const auto& requant1 = subgraph->operators[1].get(); // Check that we have RE operator. auto requant1_builtin_code = - model_.operator_codes[requant1->opcode_index].get()->builtin_code; + GetBuiltinCode(model_.operator_codes[requant1->opcode_index].get()); ASSERT_TRUE(requant1_builtin_code == tflite::BuiltinOperator_QUANTIZE); const auto& requant2 = subgraph->operators[2].get(); // Check that we have RE operator. auto requant2_builtin_code = - model_.operator_codes[requant2->opcode_index].get()->builtin_code; + GetBuiltinCode(model_.operator_codes[requant2->opcode_index].get()); ASSERT_TRUE(requant2_builtin_code == tflite::BuiltinOperator_QUANTIZE); const auto& op = subgraph->operators[3].get(); // Check that we have MINIMUM or MAXIMUM operator. auto op_builtin_code = - model_.operator_codes[op->opcode_index].get()->builtin_code; + GetBuiltinCode(model_.operator_codes[op->opcode_index].get()); ASSERT_TRUE(op_builtin_code == tflite::BuiltinOperator_MINIMUM || op_builtin_code == tflite::BuiltinOperator_MAXIMUM); @@ -1554,7 +1570,7 @@ TEST_F(QuantizeUnpackTest, VerifyUnpack) { auto float_graph = readonly_model_->subgraphs()->Get(0); - ASSERT_EQ(model_.operator_codes[op->opcode_index].get()->builtin_code, + ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()), BuiltinOperator_UNPACK); // Get unpack input and output tensors @@ -1603,7 +1619,7 @@ TEST_F(QuantizeTransposeTest, VerifyTranspose) { auto float_graph = readonly_model_->subgraphs()->Get(0); - ASSERT_EQ(model_.operator_codes[op->opcode_index].get()->builtin_code, + ASSERT_EQ(GetBuiltinCode(model_.operator_codes[op->opcode_index].get()), BuiltinOperator_TRANSPOSE); // The model should only have one input and one outputs. diff --git a/tensorflow/lite/tools/optimize/quantize_weights.cc b/tensorflow/lite/tools/optimize/quantize_weights.cc index 35c90d7859e..1b22cb56117 100644 --- a/tensorflow/lite/tools/optimize/quantize_weights.cc +++ b/tensorflow/lite/tools/optimize/quantize_weights.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/tools/optimize/model_utils.h" #include "tensorflow/lite/tools/optimize/quantization_utils.h" @@ -79,7 +80,7 @@ std::vector GetTensorConsumers(const ModelT* model, // the provided op. 
std::vector GetWeightInputIndices(const OperatorCodeT* op_code, const CustomOpMap& custom_op_map) { - const BuiltinOperator builtin_op_code = op_code->builtin_code; + const BuiltinOperator builtin_op_code = GetBuiltinCode(op_code); if (builtin_op_code == BuiltinOperator_CUSTOM) { const std::string custom_code = op_code->custom_code; const auto& custom_op_info = custom_op_map.find(custom_code); @@ -132,7 +133,7 @@ bool IsQuantizedInput(const OperatorCodeT* op_code, bool IsHybridEvaluationOp(const OperatorT* op, const OperatorCodeT* op_code, const CustomOpMap& custom_op_map, bool use_updated_hybrid_scheme) { - const BuiltinOperator builtin_op_code = op_code->builtin_code; + const BuiltinOperator builtin_op_code = GetBuiltinCode(op_code); // Operations that support hybrid evaluation. bool eval_hybrid = false; if (builtin_op_code == BuiltinOperator_CUSTOM) { @@ -196,6 +197,7 @@ TfLiteStatus InsertQuantizableInputTensorsFromOperator( int subgraph_index, bool use_updated_hybrid_scheme) { SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get(); const OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get(); + auto builtin_code = GetBuiltinCode(op_code); std::vector op_input_indices = GetWeightInputIndices(op_code, custom_op_map); @@ -203,8 +205,7 @@ TfLiteStatus InsertQuantizableInputTensorsFromOperator( int32_t tensor_idx = op->inputs[op_input_idx]; if (tensor_idx == -1) { LOG(INFO) << "Skipping optional tensor input " << op_input_idx - << " of operation " - << EnumNameBuiltinOperator(op_code->builtin_code); + << " of operation " << EnumNameBuiltinOperator(builtin_code); continue; } @@ -232,16 +233,16 @@ TfLiteStatus InsertQuantizableInputTensorsFromOperator( continue; } - if (op_code->builtin_code == BuiltinOperator_DEPTHWISE_CONV_2D) { + if (builtin_code == BuiltinOperator_DEPTHWISE_CONV_2D) { tensor_map->insert({tensor_idx, {tensor, /*is_per_channel=*/use_updated_hybrid_scheme, /*dim=*/3}}); - } else if (op_code->builtin_code == BuiltinOperator_CONV_2D) { + } else if (builtin_code == BuiltinOperator_CONV_2D) { tensor_map->insert({tensor_idx, {tensor, /*is_per_channel=*/use_updated_hybrid_scheme, /*dim=*/0}}); } else { - switch (op_code->builtin_code) { + switch (builtin_code) { case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: op->builtin_options.AsBidirectionalSequenceLSTMOptions() ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; @@ -288,13 +289,16 @@ TfLiteStatus InsertQuantizableInputTensorsFromOperator( // If a Dequantize op_code doesn't exist, adds it and returns its index. int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) { for (size_t i = 0; i < model->operator_codes.size(); ++i) { - if (model->operator_codes[i]->builtin_code == BuiltinOperator_DEQUANTIZE) { + if (GetBuiltinCode(model->operator_codes[i].get()) == + BuiltinOperator_DEQUANTIZE) { return i; } } model->operator_codes.push_back(absl::make_unique()); int op_code_idx = model->operator_codes.size() - 1; model->operator_codes[op_code_idx]->builtin_code = BuiltinOperator_DEQUANTIZE; + model->operator_codes[op_code_idx]->deprecated_builtin_code = + static_cast(BuiltinOperator_DEQUANTIZE); // Version 2 and onwards supports INT8 inputs. model->operator_codes[op_code_idx]->version = 2; @@ -330,7 +334,8 @@ void MakeTensor(const string& name, const std::vector& shape, // Updates operator code versions for the operators with INT8 inputs. 
void UpdateInt8OperatorVersions(ModelT* model, bool use_updated_hybrid_scheme) { for (int i = 0, end = model->operator_codes.size(); i < end; ++i) { - const BuiltinOperator& op_code = model->operator_codes[i]->builtin_code; + const BuiltinOperator& op_code = + GetBuiltinCode(model->operator_codes[i].get()); if (op_code == BuiltinOperator_RNN || op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN || op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM || @@ -361,7 +366,7 @@ bool IsQuantizationPassThroughOps( } const OperatorT* consumer_op = consumer_op_infos.front().op; const BuiltinOperator op_code = - model->operator_codes[consumer_op->opcode_index]->builtin_code; + GetBuiltinCode(model->operator_codes[consumer_op->opcode_index].get()); return op_code == BuiltinOperator_GATHER || op_code == BuiltinOperator_EMBEDDING_LOOKUP; } diff --git a/tensorflow/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/lite/tools/optimize/quantize_weights_test.cc index 94bff2d5eb8..05b135fef98 100644 --- a/tensorflow/lite/tools/optimize/quantize_weights_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_weights_test.cc @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/lite/tools/optimize/quantize_weights.h" + #include #include #include @@ -25,7 +27,7 @@ limitations under the License. #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/tools/optimize/quantize_weights.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/tools/optimize/test_util.h" namespace { @@ -120,7 +122,7 @@ class QuantizeWeightsTest : public testing::Test { for (size_t i = 0; i < op->outputs()->size(); ++i) { if (op->outputs()->Get(i) == tensor_idx) { const uint32_t op_code_idx = op->opcode_index(); - *op_code = model->operator_codes()->Get(op_code_idx)->builtin_code(); + *op_code = GetBuiltinCode(model->operator_codes()->Get(op_code_idx)); return true; } } @@ -195,7 +197,7 @@ TEST_F(QuantizeWeightsTest, HybridConv) { ASSERT_EQ(quantized_graph->operators()->size(), 1); const auto op = quantized_graph->operators()->Get(0); const uint32_t op_code_idx = op->opcode_index(); - ASSERT_EQ(output_model->operator_codes()->Get(op_code_idx)->builtin_code(), + ASSERT_EQ(GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)), BuiltinOperator_CONV_2D); for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) { const auto quant_tensor = quantized_graph->tensors()->Get(i); @@ -254,7 +256,7 @@ TEST_F(QuantizeWeightsTest, DequantizeConv) { for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { const auto op = quantized_graph->operators()->Get(i); const uint32_t op_code_idx = op->opcode_index(); - if (output_model->operator_codes()->Get(op_code_idx)->builtin_code() == + if (GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)) == BuiltinOperator_DEQUANTIZE) { dequant_input_idx = op->inputs()->Get(0); dequant_output_idx = op->outputs()->Get(0); @@ -313,7 +315,7 @@ TEST_F(QuantizeWeightsTest, DequantizeConvFloat16) { for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { const auto op = quantized_graph->operators()->Get(i); const uint32_t op_code_idx = op->opcode_index(); - if 
(output_model->operator_codes()->Get(op_code_idx)->builtin_code() == + if (GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)) == BuiltinOperator_DEQUANTIZE) { dequant_input_idx = op->inputs()->Get(0); dequant_output_idx = op->outputs()->Get(0); @@ -365,7 +367,7 @@ TEST_F(QuantizeWeightsTest, SharedWeights_Hybrid) { const auto op = quantized_graph->operators()->Get(i); const uint32_t op_code_idx = op->opcode_index(); const auto op_code = - output_model->operator_codes()->Get(op_code_idx)->builtin_code(); + GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)); if (op_code == BuiltinOperator_CONV_2D) { num_conv_ops++; // Ensure that each convolution's weights tensor is now INT8. @@ -399,7 +401,7 @@ TEST_F(QuantizeWeightsTest, SharedWeights_Dequantize) { const auto op = quantized_graph->operators()->Get(i); const uint32_t op_code_idx = op->opcode_index(); const auto op_code = - output_model->operator_codes()->Get(op_code_idx)->builtin_code(); + GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)); if (op_code == BuiltinOperator_CONV_2D) { num_conv_ops++; // Ensure that each convolution's weights tensor is still FLOAT @@ -439,7 +441,7 @@ TEST_F(QuantizeWeightsTest, VerifyGatherQuantization) { const auto op = quantized_graph->operators()->Get(i); const uint32_t op_code_idx = op->opcode_index(); const auto op_code = - output_model->operator_codes()->Get(op_code_idx)->builtin_code(); + GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)); if (op_code == BuiltinOperator_GATHER) { uint32_t input_tensor_index = op->inputs()->Get(0); const auto weights_tensor = @@ -479,7 +481,7 @@ TEST_F(QuantizeWeightsTest, VerifyCustomOpQuantizationDequantize) { const auto op = quantized_graph->operators()->Get(i); const uint32_t op_code_idx = op->opcode_index(); const auto op_code = - output_model->operator_codes()->Get(op_code_idx)->builtin_code(); + GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)); if (op_code == BuiltinOperator_CUSTOM) { uint32_t weights_tensor_index = op->inputs()->Get(1); const auto weights_tensor = @@ -525,7 +527,7 @@ TEST_F(QuantizeWeightsTest, VerifyCustomOpQuantizationHybrid) { const auto op = quantized_graph->operators()->Get(i); const uint32_t op_code_idx = op->opcode_index(); const auto op_code = - output_model->operator_codes()->Get(op_code_idx)->builtin_code(); + GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)); if (op_code == BuiltinOperator_CUSTOM) { uint32_t weights_tensor_index = op->inputs()->Get(1); const auto weights_tensor = @@ -560,7 +562,7 @@ TEST_F(QuantizeWeightsTest, VerifyUpdatedHybridSchemeFalseQuantizationHybrid) { ASSERT_EQ(quantized_graph->operators()->size(), 1); const auto op = quantized_graph->operators()->Get(0); const uint32_t op_code_idx = op->opcode_index(); - ASSERT_EQ(output_model->operator_codes()->Get(op_code_idx)->builtin_code(), + ASSERT_EQ(GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)), BuiltinOperator_CONV_2D); for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) { const auto quant_tensor = quantized_graph->tensors()->Get(i); diff --git a/tensorflow/lite/tools/pip_package/README.md b/tensorflow/lite/tools/pip_package/README.md index eb6338a96bb..44c60972133 100644 --- a/tensorflow/lite/tools/pip_package/README.md +++ b/tensorflow/lite/tools/pip_package/README.md @@ -98,6 +98,12 @@ tensorflow/tools/ci_build/ci_build.sh PI-PYTHON38 \ tensorflow/lite/tools/pip_package/build_pip_package_with_bazel.sh aarch64 ``` +### Native build for Windows + 
+```sh +bash tensorflow/lite/tools/pip_package/build_pip_package_with_bazel.sh windows +``` + ## Enable TF OP support (Flex delegate) If you want to use TF ops with Python API, you need to enable flex support. diff --git a/tensorflow/lite/tools/pip_package/build_pip_package_with_bazel.sh b/tensorflow/lite/tools/pip_package/build_pip_package_with_bazel.sh index 66bf1bed00f..6724674df35 100755 --- a/tensorflow/lite/tools/pip_package/build_pip_package_with_bazel.sh +++ b/tensorflow/lite/tools/pip_package/build_pip_package_with_bazel.sh @@ -79,10 +79,23 @@ esac # include path for Python 3.x builds to work. export CROSSTOOL_PYTHON_INCLUDE_PATH +case "${TENSORFLOW_TARGET}" in + windows) + LIBRARY_EXTENSION=".pyd" + ;; + *) + LIBRARY_EXTENSION=".so" + ;; +esac + bazel build -c opt -s --config=monolithic --config=noaws --config=nogcp --config=nohdfs --config=nonccl \ ${BAZEL_FLAGS} ${CUSTOM_BAZEL_FLAGS} //tensorflow/lite/python/interpreter_wrapper:_pywrap_tensorflow_interpreter_wrapper -cp "${TENSORFLOW_DIR}/bazel-bin/tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.so" \ +cp "${TENSORFLOW_DIR}/bazel-bin/tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper${LIBRARY_EXTENSION}" \ "${BUILD_DIR}/tflite_runtime" +# Bazel generates the wrapper library with r-x permissions for user. +# At least on Windows, we need write permissions to delete the file. +# Without this, setuptools fails to clean the build directory. +chmod u+w "${BUILD_DIR}/tflite_runtime/_pywrap_tensorflow_interpreter_wrapper${LIBRARY_EXTENSION}" # Build python wheel. cd "${BUILD_DIR}" diff --git a/tensorflow/lite/tools/pip_package/setup_with_bazel.py b/tensorflow/lite/tools/pip_package/setup_with_bazel.py index e3e9a35a62e..2c9decc7e55 100644 --- a/tensorflow/lite/tools/pip_package/setup_with_bazel.py +++ b/tensorflow/lite/tools/pip_package/setup_with_bazel.py @@ -63,7 +63,7 @@ setup( ], packages=find_packages(exclude=[]), package_dir={'': '.'}, - package_data={'': ['*.so']}, + package_data={'': ['*.so', '*.pyd']}, install_requires=[ 'numpy >= 1.16.0', 'pybind11 >= 2.4.3', diff --git a/tensorflow/lite/tools/sanitizers/BUILD b/tensorflow/lite/tools/sanitizers/BUILD new file mode 100644 index 00000000000..6da688fd9cb --- /dev/null +++ b/tensorflow/lite/tools/sanitizers/BUILD @@ -0,0 +1,21 @@ +load("//tensorflow:tensorflow.bzl", "pybind_extension") + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +pybind_extension( + name = "_pywrap_tensorflow_lite_sanitizers", + srcs = [ + "sanitizers_pybind11.cc", + ], + link_in_framework = True, + module_name = "_pywrap_tensorflow_lite_sanitizers", + deps = [ + "//tensorflow/python:pybind11_lib", + "@pybind11", + ], +) diff --git a/tensorflow/lite/tools/sanitizers/sanitizers_pybind11.cc b/tensorflow/lite/tools/sanitizers/sanitizers_pybind11.cc new file mode 100644 index 00000000000..42e4c2951f2 --- /dev/null +++ b/tensorflow/lite/tools/sanitizers/sanitizers_pybind11.cc @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "pybind11/pybind11.h" +#include "tensorflow/python/lib/core/pybind11_lib.h" + +PYBIND11_MODULE(_pywrap_tensorflow_lite_sanitizers, m) { + m.def("TSan_Enabled", []() -> bool { +#ifdef THREAD_SANITIZER + return true; +#else + return false; +#endif + }); + m.def("MSan_Enabled", []() -> bool { +#ifdef MEMORY_SANITIZER + return true; +#else + return false; +#endif + }); + m.def("ASan_Enabled", []() -> bool { +#ifdef ADDRESS_SANITIZER + return true; +#else + return false; +#endif + }); +} diff --git a/tensorflow/lite/tools/verifier.cc b/tensorflow/lite/tools/verifier.cc index bdd4a35dac7..a4cc4e0c9a0 100644 --- a/tensorflow/lite/tools/verifier.cc +++ b/tensorflow/lite/tools/verifier.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/string_util.h" #include "tensorflow/lite/version.h" @@ -478,6 +479,7 @@ bool VerifySubGraphConsistency(const Model& model, const SubGraph& subgraph, return false; } const auto& opcode = model.operator_codes()->Get(op->opcode_index()); + auto builtin_code = GetBuiltinCode(opcode); // Check for invalid inputs by ensuring all exist in produced_tensors. for (const int input_idx : *op->inputs()) { if (input_idx == kTfLiteOptionalTensor) continue; @@ -488,8 +490,7 @@ bool VerifySubGraphConsistency(const Model& model, const SubGraph& subgraph, output_tensors.find(input_idx) == output_tensors.end()) { ReportError(error_reporter, "Input tensor %d to op %d (%s) is not produced", - input_idx, op_idx, - EnumNameBuiltinOperator(opcode->builtin_code())); + input_idx, op_idx, EnumNameBuiltinOperator(builtin_code)); return false; } } @@ -497,31 +498,29 @@ bool VerifySubGraphConsistency(const Model& model, const SubGraph& subgraph, // produced_tensors. for (const int output_idx : *op->outputs()) { if (constant_tensors.find(output_idx) != constant_tensors.end()) { - ReportError(error_reporter, - "Output tensor %d to op %d (%s) is a constant", - output_idx, op_idx, - EnumNameBuiltinOperator(opcode->builtin_code())); + ReportError( + error_reporter, "Output tensor %d to op %d (%s) is a constant", + output_idx, op_idx, EnumNameBuiltinOperator(builtin_code)); return false; } else if (variable_tensors.find(output_idx) != variable_tensors.end()) { - ReportError(error_reporter, - "Output tensor %d to op %d (%s) is a variable", - output_idx, op_idx, - EnumNameBuiltinOperator(opcode->builtin_code())); + ReportError( + error_reporter, "Output tensor %d to op %d (%s) is a variable", + output_idx, op_idx, EnumNameBuiltinOperator(builtin_code)); return false; } else if (subgraph_input_tensors.find(output_idx) != subgraph_input_tensors.end()) { ReportError(error_reporter, "Output tensor %d to op %d (%s) is a subgraph input", output_idx, op_idx, - EnumNameBuiltinOperator(opcode->builtin_code())); + EnumNameBuiltinOperator(builtin_code)); return false; } else if (output_tensors.find(output_idx) != output_tensors.end()) { ReportError(error_reporter, "Output tensor %d to op %d (%s) is an output from " "another op. There is a cycle in the graph", output_idx, op_idx, - EnumNameBuiltinOperator(opcode->builtin_code())); + EnumNameBuiltinOperator(builtin_code)); return false; } // This can be an input to a subsequent op. 
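
The verifier hunks above thread the resolved builtin code through `VerifySubGraphConsistency`, and `verifier.h` now pulls its `ErrorReporter` and `OpResolver` types from `core/api`. For context, a typical caller of the verifier looks roughly like the sketch below, under the assumption that the existing `tflite::Verify(buf, len, resolver, reporter)` overload is unchanged by this patch:

```c++
// Rough usage sketch for the verifier touched above; assumes the existing
// Verify(buf, len, resolver, error_reporter) entry point. Not part of the patch.
#include <cstddef>

#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow/lite/tools/verifier.h"

bool ModelLooksValid(const void* buffer, size_t length) {
  tflite::ops::builtin::BuiltinOpResolver resolver;
  return tflite::Verify(buffer, length, resolver,
                        tflite::DefaultErrorReporter());
}
```
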
@@ -608,14 +607,15 @@ bool VerifyOps(const Model& model, const OpResolver& resolver, return true; } for (const auto& opcode : *model.operator_codes()) { - if (opcode->builtin_code() < BuiltinOperator_MIN || - opcode->builtin_code() > BuiltinOperator_MAX) { + auto builtin_code = GetBuiltinCode(opcode); + if (builtin_code < BuiltinOperator_MIN || + builtin_code > BuiltinOperator_MAX) { ReportError(error_reporter, "Operator id '%d' is out of range.", - opcode->builtin_code()); + builtin_code); return false; } - if (opcode->builtin_code() == BuiltinOperator_CUSTOM) { + if (builtin_code == BuiltinOperator_CUSTOM) { if (IsNullOrEmptyString(opcode->custom_code())) { ReportError(error_reporter, "Invalid custom op name, cannot be null/empty."); @@ -627,10 +627,9 @@ bool VerifyOps(const Model& model, const OpResolver& resolver, return false; } } else { - if (!resolver.FindOp(opcode->builtin_code(), opcode->version())) { + if (!resolver.FindOp(builtin_code, opcode->version())) { ReportError(error_reporter, "Unsupported builtin op: %s, version: %d", - EnumNameBuiltinOperator(opcode->builtin_code()), - opcode->version()); + EnumNameBuiltinOperator(builtin_code), opcode->version()); return false; } } diff --git a/tensorflow/lite/tools/verifier.h b/tensorflow/lite/tools/verifier.h index 50b6432d4e3..35d7e30a172 100644 --- a/tensorflow/lite/tools/verifier.h +++ b/tensorflow/lite/tools/verifier.h @@ -18,8 +18,10 @@ limitations under the License. #include -#include "tensorflow/lite/error_reporter.h" -#include "tensorflow/lite/model.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/error_reporter.h" // Legacy. +#include "tensorflow/lite/model.h" // Legacy. namespace tflite { diff --git a/tensorflow/lite/tools/verifier_test.cc b/tensorflow/lite/tools/verifier_test.cc index da4eb55c4f9..f37eaaa99be 100644 --- a/tensorflow/lite/tools/verifier_test.cc +++ b/tensorflow/lite/tools/verifier_test.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "tensorflow/lite/mutable_op_resolver.h" #include "tensorflow/lite/op_resolver.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/testing/util.h" #include "tensorflow/lite/util.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index a15fc27d43a..4d41b7d13e9 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -543,6 +543,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { return 1; case BuiltinOperator_CONCATENATION: + case BuiltinOperator_BATCH_MATMUL: case BuiltinOperator_SOFTMAX: case BuiltinOperator_MEAN: case BuiltinOperator_PAD: @@ -585,7 +586,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { case BuiltinOperator_LESS: case BuiltinOperator_LESS_EQUAL: case BuiltinOperator_SELECT: - case BuiltinOperator_BATCH_MATMUL: if (op_sig.input_types.at(0) == TensorType_INT8) { return 2; } diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index cac6779c48a..3c0a92f9df1 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -59,6 +59,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_AVERAGE_POOL_2D, 3}, "2.3.0"}, {{BuiltinOperator_BATCH_MATMUL, 1}, "2.3.0"}, {{BuiltinOperator_BATCH_MATMUL, 2}, "2.3.0"}, + {{BuiltinOperator_BATCH_MATMUL, 3}, kPendingReleaseVersion}, {{BuiltinOperator_CONV_2D, 1}, "1.5.0"}, {{BuiltinOperator_CONV_2D, 2}, "1.14.0"}, {{BuiltinOperator_CONV_2D, 3}, "1.14.0"}, diff --git a/tensorflow/lite/tools/visualize.py b/tensorflow/lite/tools/visualize.py index 3d22d1bb05b..a9d337bee16 100644 --- a/tensorflow/lite/tools/visualize.py +++ b/tensorflow/lite/tools/visualize.py @@ -452,6 +452,10 @@ def CreateHtmlFile(tflite_input, html_output): ("custom_code", None), ("version", None)] + # Update builtin code fields. + for idx, d in enumerate(data["operator_codes"]): + d["builtin_code"] = max(d["builtin_code"], d["deprecated_builtin_code"]) + for subgraph_idx, g in enumerate(data["subgraphs"]): # Subgraph local specs on what to display html += "
" diff --git a/tensorflow/lite/tutorials/BUILD b/tensorflow/lite/tutorials/BUILD index 277d807b2ca..d8f9e254ece 100644 --- a/tensorflow/lite/tutorials/BUILD +++ b/tensorflow/lite/tutorials/BUILD @@ -5,8 +5,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - py_binary( name = "mnist_tflite", srcs = [ diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index 2aad135c23b..ff944035b22 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -8,6 +8,7 @@ tensorflow/go/op/wrappers.go tensorflow/lite/micro/build_def.bzl tensorflow/python/autograph/core/config.py tensorflow/python/eager/benchmarks_test_base.py +tensorflow/python/framework/tfrt_utils.py tensorflow/python/tpu/profiler/pip_package/BUILD tensorflow/python/tpu/profiler/pip_package/README tensorflow/python/tpu/profiler/pip_package/build_pip_package.sh @@ -203,6 +204,7 @@ tensorflow/third_party/tensorrt/build_defs.bzl.tpl tensorflow/third_party/tensorrt/tensorrt/include/tensorrt_config.h.tpl tensorflow/third_party/tensorrt/tensorrt_configure.bzl tensorflow/third_party/termcolor.BUILD +tensorflow/third_party/tf_toolchains.BUILD tensorflow/third_party/tflite_mobilenet.BUILD tensorflow/third_party/tflite_mobilenet_float.BUILD tensorflow/third_party/tflite_mobilenet_quant.BUILD diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 9471bc99dd6..214f7e13d5c 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -8,6 +8,9 @@ load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud") +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") + # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_monitoring_python_deps") @@ -62,6 +65,7 @@ visibility = [ "//third_party/py/tf_agents:__subpackages__", # For benchmarks. "//third_party/py/tf_slim:__subpackages__", "//third_party/py/tensorflow_docs:__subpackages__", + "//third_party/py/keras:__subpackages__", ] package( @@ -69,8 +73,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - # Description: py_library( @@ -127,6 +129,7 @@ py_library( "//tensorflow/python/tools/api/generator:__pkg__", "//tensorflow/tools/api/tests:__pkg__", "//tensorflow/tools/compatibility/update:__pkg__", + "//third_party/py/keras:__subpackages__", "//third_party/py/tensorflow_core:__subpackages__", ], deps = [ @@ -216,8 +219,10 @@ py_library( "//tensorflow/python/data", "//tensorflow/python/debug:debug_py", "//tensorflow/python/distribute", + "//tensorflow/python/distribute:combinations", # For tf.__internal__ API. 
"//tensorflow/python/distribute:distribute_config", "//tensorflow/python/distribute:estimator_training", + "//tensorflow/python/distribute:strategy_combinations", # For tf.__internal__, "//tensorflow/python/dlpack", "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:monitoring", @@ -258,8 +263,9 @@ py_library( "//third_party/py/tensorflow_core:__subpackages__", ], deps = [ - ":framework_combinations", ":no_contrib", + "//tensorflow/python/distribute:multi_process_runner", + "//tensorflow/python/distribute:multi_worker_test_base", ], ) @@ -601,6 +607,7 @@ cc_library( cc_library( name = "pybind11_lib", hdrs = ["lib/core/pybind11_lib.h"], + compatible_with = get_compatible_with_portable(), features = ["-parse_headers"], visibility = tf_external_workspace_visible(visibility), deps = [ @@ -805,6 +812,7 @@ tf_python_pybind_extension( name = "_pywrap_tensor_float_32_execution", srcs = ["util/tensor_float_32.cc"], hdrs = ["//tensorflow/core/platform:tensor_float_32_hdr"], + compatible_with = get_compatible_with_portable(), module_name = "_pywrap_tensor_float_32_execution", deps = [ "@pybind11", @@ -2194,13 +2202,20 @@ py_library( ) # Including this as a dependency will result in tests to use TFRT. -# TODO(b/153582383): Move tf_ops_alwayslink dependency to c_api_tfrt instead. +# TODO(b/170139514): Add a curated list of TFRT native ops. +# TODO(kkb): Deprecated by `tfrt_utils`, remove. py_library( name = "is_tfrt_test_true", srcs = ["framework/is_tfrt_test_true.py"], srcs_version = "PY2AND3", ) +py_library( + name = "tfrt_utils", + srcs = ["framework/tfrt_utils.py"], + srcs_version = "PY2AND3", +) + py_library( name = "distributed_framework_test_lib", srcs_version = "PY2AND3", @@ -2498,7 +2513,6 @@ tf_py_test( srcs = ["framework/importer_test.py"], main = "framework/importer_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -2820,13 +2834,13 @@ tf_py_test( main = "framework/test_util_test.py", python_version = "PY3", tags = ["no_windows"], - tfrt_enabled = True, deps = [ ":control_flow_ops", ":errors", ":framework_combinations", ":framework_for_generated_wrappers", ":framework_test_lib", + ":lookup_ops", ":platform_test", ":random_ops", ":resource_variable_ops", @@ -3504,7 +3518,6 @@ tf_py_test( srcs = ["ops/collective_ops_test.py"], python_version = "PY3", tags = ["no_rocm"], - tfrt_enabled = True, deps = [ ":client_testlib", ":collective_ops", @@ -3620,7 +3633,6 @@ py_library( "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/eager:context", "//tensorflow/python/eager:function", - "//tensorflow/python/keras/engine:base_layer_utils", ], ) @@ -3713,6 +3725,7 @@ py_library( ":gradients", ":gradients_util", ":graph_to_function_def", + ":handle_data_util", ":pywrap_tensorflow", ":util", "//tensorflow/python/compat", @@ -3847,6 +3860,21 @@ py_library( ], ) +py_library( + name = "handle_data_util", + srcs = [ + "ops/handle_data_util.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":dtypes", + ":framework_ops", + ":protos_all_py", + ":pywrap_tf_session", + ":util", + ], +) + py_library( name = "gradients", srcs = [ @@ -3857,6 +3885,7 @@ py_library( deps = [ ":gradients_impl", ":gradients_util", + ":handle_data_util", ":pywrap_tf_session", ":unconnected_gradients", "//tensorflow/python/eager:forwardprop", @@ -4251,6 +4280,7 @@ py_library( ":auto_control_deps_utils", ":dtypes", ":framework_ops", + ":handle_data_util", ":pywrap_tf_session", ":resource_variable_ops_gen", ":tensor_shape", @@ -4284,6 +4314,7 @@ 
py_library( srcs_version = "PY2AND3", deps = [ ":array_ops", + ":handle_data_util", ":list_ops_gen", ], ) @@ -4432,6 +4463,7 @@ py_library( ":math_ops", ":random_ops_gen", ":random_seed", + ":stateless_random_ops", ], ) @@ -4455,7 +4487,6 @@ cuda_py_test( size = "medium", srcs = ["ops/stateful_random_ops_test.py"], python_version = "PY3", - tfrt_enabled = True, xla_enable_strict_auto_jit = False, xla_enabled = True, deps = [ @@ -4474,10 +4505,10 @@ py_library( srcs = ["ops/stateless_random_ops.py"], srcs_version = "PY2AND3", deps = [ + ":array_ops", ":dtypes", ":framework_ops", ":math_ops", - ":random_ops", ":stateless_random_ops_gen", ":stateless_random_ops_v2_gen", ], @@ -5069,6 +5100,7 @@ cuda_py_test( size = "medium", srcs = ["ops/gradient_checker_v2_test.py"], python_version = "PY3", + tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -5141,6 +5173,7 @@ cuda_py_test( srcs = ["ops/image_grad_deterministic_test.py"], python_version = "PY3", shard_count = 5, + tfrt_enabled = True, deps = [ ":image_grad_test_base", ], @@ -5236,6 +5269,7 @@ cuda_py_test( srcs = ["ops/math_grad_test.py"], python_version = "PY3", tags = ["no_windows_gpu"], + tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -5275,6 +5309,7 @@ cuda_py_test( tags = [ "no_windows_gpu", ], + tfrt_enabled = True, deps = [ ":framework_for_generated_wrappers", ":framework_test_lib", @@ -5311,7 +5346,6 @@ cuda_py_test( srcs = ["ops/nn_fused_batchnorm_test.py"], python_version = "PY3", shard_count = 24, - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -5329,7 +5363,6 @@ cuda_py_test( srcs = ["ops/nn_test.py"], python_version = "PY3", tags = ["no_windows"], - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -5354,6 +5387,7 @@ py_test( ":client_testlib", "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/distribute:test_util", "@absl_py//absl/testing:parameterized", ], ) @@ -5363,7 +5397,6 @@ cuda_py_test( size = "medium", srcs = ["ops/nn_xent_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":client_testlib", ":framework_for_generated_wrappers", @@ -5470,6 +5503,7 @@ py_library( py_strict_library( name = "tf_export", srcs = ["util/tf_export.py"], + compatible_with = get_compatible_with_portable(), srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ @@ -5498,6 +5532,7 @@ py_strict_library( "util/tf_decorator.py", "util/tf_inspect.py", ], + compatible_with = get_compatible_with_portable(), srcs_version = "PY2AND3", visibility = [ "//tensorflow:__subpackages__", @@ -5515,6 +5550,7 @@ py_strict_library( py_strict_library( name = "tf_stack", srcs = ["util/tf_stack.py"], + compatible_with = get_compatible_with_portable(), srcs_version = "PY2AND3", # TODO(mdan): Remove public visibility. visibility = ["//visibility:public"], @@ -5527,6 +5563,7 @@ py_strict_library( pybind_extension( name = "_tf_stack", srcs = ["util/tf_stack.cc"], + compatible_with = get_compatible_with_portable(), # TODO(b/138203821): change to "util._tf_stack" once the bug is fixed. 
module_name = "_tf_stack", deps = [ @@ -5611,6 +5648,7 @@ tf_py_test( py_library( name = "global_test_configuration", + compatible_with = get_compatible_with_portable(), deps = if_mlir(["//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration"]) + tf_enable_mlir_bridge(), ) @@ -5629,6 +5667,7 @@ py_library( "util/**/*_test.py", ], ), + compatible_with = get_compatible_with_portable(), srcs_version = "PY2AND3", visibility = visibility + [ "//tensorflow:__pkg__", @@ -5936,9 +5975,9 @@ cc_library( deps = [ "//tensorflow/core:core_cpu", "//tensorflow/core:lib", - "//tensorflow/core:master_proto_cc", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:replay_log_proto_cc", + "//tensorflow/core/protobuf:master_proto_cc", + "//tensorflow/core/protobuf:replay_log_proto_cc", ], ) @@ -6021,6 +6060,7 @@ pywrap_tensorflow_macro( "//tensorflow/c/eager:c_api_experimental", "//tensorflow/c/experimental/ops", "//tensorflow/c/experimental/gradients", + "//tensorflow/c/experimental/gradients/tape", "//tensorflow/c/eager:mnist_gradients_testutil", "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/core/data/service:server_lib", @@ -6084,8 +6124,6 @@ filegroup( "//tensorflow/compiler/jit:get_compiler_ir", #tfe "//tensorflow/compiler/jit:flags", #tfe "//tensorflow/compiler/mlir/python:mlir", # mlir - "//tensorflow/core/common_runtime:device", # device_lib, tfe, tf_session - "//tensorflow/core/common_runtime:device_factory", # device_lib, tfe, tf_session "//tensorflow/core/common_runtime:graph_constructor", # tf_session "//tensorflow/core/common_runtime:quantize_training", # quantize_training "//tensorflow/core/common_runtime:session_options", # device_lib, tfe, tf_session @@ -6114,9 +6152,8 @@ filegroup( "//tensorflow/core/grappler/optimizers:meta_optimizer", # tf_optimizer "//tensorflow/core/grappler/utils:topological_sort", # tf_item "//tensorflow/core/platform:tensor_float_32_utils", # tensor_float_32 - "//tensorflow/core/profiler/internal:annotation_stack_impl", # profiler "//tensorflow/core/profiler/internal:print_model_analysis", # tfprof - "//tensorflow/core/profiler/internal:traceme_recorder_impl", # profiler + "//tensorflow/core/profiler/internal/cpu:traceme_recorder_impl", # profiler "//tensorflow/core/profiler/lib:profiler_session_impl", # profiler "//tensorflow/core/profiler/rpc:profiler_server_impl", # profiler "//tensorflow/core/profiler/rpc/client:profiler_client_impl", # profiler @@ -6333,7 +6370,6 @@ tf_py_test( "no_pip_gpu", # testInteractivePlacePrunedGraph fails on invalid assumption about GPU ops. 
"no_windows", ], - tfrt_enabled = True, deps = [ ":array_ops", ":client", @@ -6510,7 +6546,6 @@ tf_py_test( srcs = ["framework/convert_to_constants_test.py"], python_version = "PY3", tags = ["no_rocm"], - tfrt_enabled = True, deps = [ ":client_testlib", ":control_flow_v2_toggles", @@ -6699,7 +6734,6 @@ tf_py_test( srcs = ["ops/dequantize_op_test.py"], python_version = "PY3", tags = ["no_windows"], - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -6714,7 +6748,6 @@ tf_py_test( srcs = ["ops/quantized_ops_test.py"], python_version = "PY3", tags = ["no_windows"], - tfrt_enabled = True, deps = [ ":array_ops", ":client_testlib", @@ -6729,7 +6762,6 @@ tf_py_test( srcs = ["ops/quantized_conv_ops_test.py"], python_version = "PY3", tags = ["no_windows"], - tfrt_enabled = True, deps = [ ":client_testlib", ":framework_for_generated_wrappers", @@ -7534,6 +7566,7 @@ cc_library( "//tensorflow/c/eager:pywrap_required_hdrs", "//tensorflow/c/experimental/ops:pywrap_required_hdrs", "//tensorflow/c/experimental/gradients:pywrap_required_hdrs", + "//tensorflow/c/experimental/gradients/tape:pywrap_required_hdrs", "//tensorflow/core/common_runtime/eager:pywrap_required_hdrs", "//tensorflow/python/eager:pywrap_required_hdrs", "//tensorflow/core/common_runtime:core_cpu_lib_headers", diff --git a/tensorflow/python/autograph/g3doc/reference/common_errors.md b/tensorflow/python/autograph/g3doc/reference/common_errors.md index c663413c532..7263d01e14b 100644 --- a/tensorflow/python/autograph/g3doc/reference/common_errors.md +++ b/tensorflow/python/autograph/g3doc/reference/common_errors.md @@ -67,7 +67,33 @@ logger, with `WARNING` severity. To direct these warnings to `stdout`, use This exception is raised whenever a `tf.Tensor` is type-cast as a Python `bool`, in a context where eager execution is not active. The exception is only raised when graph execution is active, for example inside a `@tf.function` with -AutoGraph turned off. It can be caused by using a `tf.Tensor` value as: +AutoGraph turned off. + +**When AutoGraph is on**, it can be caused by: + * placing a Tensor-dependent `break`, `continue` or `return` inside a Python + loop (see example below) + * attempting to use a `tf.Tensor` in a list comprehension, by iterating over + it or using it in a condition) + +A typical example of mixing Python and TF control flow in an incompatible way +is: + +``` +for i in range(3): # Python loop + if i > tf.constant(0): # TF conditional + break # raises OperatorNotAllowedInGraphError +``` + +The way these errors are typically fixed is by ensuring all control flow is +TF control flow: + +``` +for i in tf.range(3): # TF loop + if i > tf.constant(0): # TF conditional + break # works +``` + +**When AutoGraph is off**, it can be caused by using a `tf.Tensor` value as: * the condition of an `if` or `while` statement: `if :` * the argument in a logical expression: `tensor and another_tensor` @@ -75,9 +101,6 @@ AutoGraph turned off. It can be caused by using a `tf.Tensor` value as: Note: These operations are allowed when executing eagerly. -Within the context of AutoGraph, it usually indicates eager-style control -flow that has not been converted by AutoGraph, for any reason. - When encountering this error, make sure that the function is either decorated with `@tf.function`, or called from another function decorated in this way. 
Also look at the console and logging output for conversion warnings (see the section diff --git a/tensorflow/python/autograph/impl/BUILD b/tensorflow/python/autograph/impl/BUILD index 30ee67d41bc..4c5475bbb74 100644 --- a/tensorflow/python/autograph/impl/BUILD +++ b/tensorflow/python/autograph/impl/BUILD @@ -42,7 +42,6 @@ tf_py_test( srcs = ["api_test.py"], python_version = "PY3", srcs_version = "PY3", - tfrt_enabled = True, deps = [ ":impl", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index ff8f55431ef..089c4848030 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -297,8 +297,9 @@ def _convert_actual(entity, program_ctx): def autograph_artifact(entity, extras=None): if inspect.ismethod(entity): - entity = entity.__func__ - setattr(entity, 'autograph_info__', extras) + setattr(entity.__func__, 'autograph_info__', extras) + else: + setattr(entity, 'autograph_info__', extras) return entity diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py index c8e70c1385a..59ae5f4d856 100644 --- a/tensorflow/python/autograph/impl/api_test.py +++ b/tensorflow/python/autograph/impl/api_test.py @@ -1213,8 +1213,9 @@ class ApiTest(test.TestCase): obj = TestClass() # It's intended that tf_convert modifies the original method in this case. # This is not desirable, but options are limited. - api.tf_convert(obj.method, - ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED)) + converted = api.tf_convert( + obj.method, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED)) + self.assertTrue(converted()) self.assertTrue(obj.method()) def test_super_with_one_arg(self): diff --git a/tensorflow/python/autograph/lang/BUILD b/tensorflow/python/autograph/lang/BUILD index dca39bbbd8d..ceccc6f0c93 100644 --- a/tensorflow/python/autograph/lang/BUILD +++ b/tensorflow/python/autograph/lang/BUILD @@ -4,8 +4,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/python/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD index 77d02e69976..13b3b7a1764 100644 --- a/tensorflow/python/autograph/operators/BUILD +++ b/tensorflow/python/autograph/operators/BUILD @@ -4,8 +4,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py index 9bd139c031f..aaa4808cb0a 100644 --- a/tensorflow/python/autograph/operators/control_flow.py +++ b/tensorflow/python/autograph/operators/control_flow.py @@ -116,6 +116,33 @@ def _is_none_or_undef(value): or isinstance(value, variables.Undefined)) +def _verify_tf_condition(cond, tag): + """Ensures that the condition can be used in a TF control flow.""" + extra_hint = 'to check for None, use `is not None`' + cond = ops.convert_to_tensor_v2(cond) + + if cond.dtype != dtypes.bool: + raise ValueError( + 'condition of {} expected to be `tf.bool` scalar, got {}' + '; to use as boolean Tensor, use `tf.cast`' + '; {}'.format(tag, cond, extra_hint)) + + if cond.shape is None or cond.shape.ndims is None: + # TODO(mdan): Consider a explicit size check, if not too slow. 
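The `autograph_artifact` fix above hinges on a plain Python detail: bound methods reject attribute assignment, so the marker has to be written to `__func__`, while callers should still receive the original entity (the previous code returned the unwrapped function instead). A small illustration, independent of TensorFlow:

```
class Example(object):

  def method(self):
    return True

obj = Example()

try:
  obj.method.autograph_info__ = {}          # bound methods reject attribute writes
except AttributeError as e:
  print(e)

obj.method.__func__.autograph_info__ = {}   # works
print(hasattr(obj.method, 'autograph_info__'))  # True: lookup falls through to __func__
```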
+ cond = array_ops.reshape(cond, ()) + + elif cond.shape.ndims > 0: + known_dims = [d for d in cond.shape.as_list() if d is not None] + if np.prod(known_dims) > 1: + raise ValueError( + 'condition of {} expected to be `tf.bool` scalar, got {}' + '; {}'.format(tag, cond, extra_hint)) + else: + cond = array_ops.reshape(cond, ()) + + return cond + + def _verify_loop_init_vars(init_vars, symbol_names, first_iter_vars=None): """Ensures that all values in the state are valid to use in a TF loop. @@ -1038,7 +1065,7 @@ def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts): loop_vars = loop_vars[1:] set_state(loop_vars) - return test() + return _verify_tf_condition(test(), 'while loop') def aug_body(*loop_vars): if require_one_iteration: @@ -1141,6 +1168,8 @@ def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts): def _tf_if_stmt( cond, body, orelse, get_state, set_state, symbol_names, nouts): """Overload of if_stmt that stages a TF cond.""" + cond = _verify_tf_condition(cond, 'if statement') + if not nouts: prev_get_state, prev_set_state = get_state, set_state # Control flow V1 wants at least one output. diff --git a/tensorflow/python/autograph/operators/control_flow_test.py b/tensorflow/python/autograph/operators/control_flow_test.py index 32b36a29797..8d3c63b5a89 100644 --- a/tensorflow/python/autograph/operators/control_flow_test.py +++ b/tensorflow/python/autograph/operators/control_flow_test.py @@ -35,6 +35,7 @@ from tensorflow.python.autograph.utils import testing from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -46,6 +47,20 @@ from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import test +def _unranked_item(value): + rand_rank = random_ops.random_uniform( + shape=(), minval=3, maxval=4, dtype=dtypes.int32) + rand_shape = array_ops.ones([rand_rank], dtype=dtypes.int32) + return array_ops.fill(rand_shape, value) + + +def _partial_shaped_bools(): + rand_vect = math_ops.range( + random_ops.random_uniform( + shape=(), minval=2, maxval=3, dtype=dtypes.int32)) + return array_ops.expand_dims_v2(rand_vect, 0) < 0 + + class ForLoopTest(testing.AutoGraphTestCase): def test_tensor(self): @@ -871,6 +886,60 @@ class WhileLoopTest(testing.AutoGraphTestCase): with self.assertRaisesRegex(ValueError, r"'s'.* shape \(1,\) after"): self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32)) + def _fixed_while_loop(self, cond_fn): + def test_(): + return cond_fn(s) + + def body(): + nonlocal s + s += 1 + + def set_state(loop_vars): + nonlocal s + s, = loop_vars + + s = constant_op.constant(0) + control_flow.while_stmt( + test=test_, + body=body, + get_state=lambda: (s,), + set_state=set_state, + symbol_names=('s',), + opts={}) + return s + + def _assertFixedLoopResult(self, cond, expected): + def test_fn(): + return self._fixed_while_loop(cond) + self.assertEqual(test_fn(), expected) + + def test_tensor_legal_cond_scalar(self): + self._assertFixedLoopResult(lambda s: constant_op.constant(False), 0) + self._assertFixedLoopResult(lambda s: s < 2, 2) + + def test_tensor_legal_cond_single_element_nd(self): + self._assertFixedLoopResult(lambda s: constant_op.constant([[False]]), 0) + self._assertFixedLoopResult(lambda s: 
_unranked_item(False), 0) + + def _assertCondCheckFails(self, cond): + with self.assertRaisesRegex( + ValueError, 'condition of while loop expected to be `tf.bool`'): + self._fixed_while_loop(cond) + + def test_tensor_illegal_cond_not_bool(self): + self._assertCondCheckFails(lambda s: constant_op.constant(1)) + self._assertCondCheckFails(lambda s: s) + + def test_tensor_illegal_cond_not_single_element(self): + self._assertCondCheckFails(lambda s: constant_op.constant([1, 2, 3])) + self._assertCondCheckFails(lambda s: constant_op.constant([True, False])) + + def test_tensor_illegal_cond_not_single_element_dynamic_shape(self): + self._fixed_while_loop(lambda s: _partial_shaped_bools()) + # TODO(mdan): This error is quite bad. Measure the cost of an assertion. + self.assertRaisesRuntime( + errors_impl.InvalidArgumentError, 'requested shape has 1') + class IfStmtTest(testing.AutoGraphTestCase): @@ -1065,6 +1134,62 @@ class IfStmtTest(testing.AutoGraphTestCase): TypeError, "'x' has dtype int32.*but.*float32"): self._basic_cond(lambda: 1, lambda: 1.0) + def _fixed_cond(self, cond_val): + def body(): + nonlocal x + x = 1 + + def orelse(): + nonlocal x + x = -1 + + def set_state(cond_vars): + nonlocal x + x, = cond_vars + + x = 0 + control_flow.if_stmt( + cond=cond_val, + body=body, + orelse=orelse, + get_state=lambda: (x,), + set_state=set_state, + symbol_names=('x',), + nouts=1) + return x + + def _assertFixedCondResult(self, cond, expected): + def test_fn(): + return self._fixed_cond(cond) + self.assertEqual(test_fn(), expected) + + def test_tensor_legal_cond_scalar(self): + self._assertFixedCondResult(constant_op.constant(True), 1) + self._assertFixedCondResult(constant_op.constant(False), -1) + + def test_tensor_legal_cond_single_element_nd(self): + self._assertFixedCondResult(constant_op.constant([[True]]), 1) + self._assertFixedCondResult(constant_op.constant([[False]]), -1) + self._assertFixedCondResult(_unranked_item(True), 1) + self._assertFixedCondResult(_unranked_item(False), -1) + + def _assertCondCheckFails(self, cond): + with self.assertRaisesRegex( + ValueError, 'condition of if statement expected to be `tf.bool`'): + self._fixed_cond(cond) + + def test_tensor_illegal_cond_not_bool(self): + self._assertCondCheckFails(constant_op.constant(1)) + + def test_tensor_illegal_cond_not_single_element(self): + self._assertCondCheckFails(constant_op.constant([1, 2, 3])) + self._assertCondCheckFails(constant_op.constant([True, False])) + + def test_tensor_illegal_cond_not_single_element_dynamic_shape(self): + self._fixed_cond(_partial_shaped_bools()) + # TODO(mdan): This error is quite bad. Measure the cost of an assertion. 
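The `_verify_tf_condition` check exercised by these tests surfaces to users as an early, descriptive `ValueError` when a staged loop or conditional receives a condition that is not a single-element `tf.bool` tensor. A hedged user-level sketch of what now fails at trace time with the message added in this change:

```
import tensorflow as tf

@tf.function
def countdown(n):
  i = n
  while i:        # condition is an int32 Tensor, not a `tf.bool` scalar
    i -= 1
  return i

try:
  countdown(tf.constant(3))
except ValueError as e:
  # Roughly: "condition of while loop expected to be `tf.bool` scalar, got ..."
  print(e)
```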
+ self.assertRaisesRuntime( + errors_impl.InvalidArgumentError, 'requested shape has 1') if __name__ == '__main__': test.main() diff --git a/tensorflow/python/autograph/operators/py_builtins.py b/tensorflow/python/autograph/operators/py_builtins.py index bf5ea035b54..b35e51f3e7d 100644 --- a/tensorflow/python/autograph/operators/py_builtins.py +++ b/tensorflow/python/autograph/operators/py_builtins.py @@ -191,8 +191,10 @@ def _tf_abs(x): def _tf_dataset_abs(x): specs = nest.flatten(x.element_spec) if len(specs) == 1: - return x.map(math_ops.abs) - return x.map(lambda *e: nest.map_structure(math_ops.abs, e)) + return x.map(math_ops.abs, num_parallel_calls=dataset_ops.AUTOTUNE) + return x.map( + lambda *e: nest.map_structure(math_ops.abs, e), + num_parallel_calls=dataset_ops.AUTOTUNE) def _py_abs(x): @@ -427,7 +429,8 @@ def map_(fn, *iterables): def _tf_dataset_map(fn, *iterables): - return dataset_ops.DatasetV2.zip(iterables).map(fn) + zipped_dataset = dataset_ops.DatasetV2.zip(iterables) + return zipped_dataset.map(fn, num_parallel_calls=dataset_ops.AUTOTUNE) def _py_map(fn, *iterables): diff --git a/tensorflow/python/autograph/pyct/BUILD b/tensorflow/python/autograph/pyct/BUILD index 93bd9228f36..8ff18fcf2f4 100644 --- a/tensorflow/python/autograph/pyct/BUILD +++ b/tensorflow/python/autograph/pyct/BUILD @@ -4,8 +4,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/python/autograph/utils/BUILD b/tensorflow/python/autograph/utils/BUILD index 5d32f60465d..ba44b2b435c 100644 --- a/tensorflow/python/autograph/utils/BUILD +++ b/tensorflow/python/autograph/utils/BUILD @@ -4,8 +4,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/python/autograph/utils/testing.py b/tensorflow/python/autograph/utils/testing.py index bec6966e7cb..5b74496fb74 100644 --- a/tensorflow/python/autograph/utils/testing.py +++ b/tensorflow/python/autograph/utils/testing.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import re +import sys import types import unittest @@ -81,18 +82,29 @@ class AutoGraphTestCase(test.TestCase): @def_function.function(autograph=False) # Testing autograph itself. def fn_wrapper(): self.assertions = [] + self.raises_cm = None self.graph_assertions = [] self.trace_log = [] fn() targets = [args for _, args in self.assertions] return targets - tensors = fn_wrapper() + try: + tensors = fn_wrapper() - for assertion in self.graph_assertions: - assertion(fn_wrapper.get_concrete_function().graph) + for assertion in self.graph_assertions: + assertion(fn_wrapper.get_concrete_function().graph) + + actuals = self.evaluate(tensors) + + except: # pylint:disable=bare-except + if self.raises_cm is not None: + # Note: Yes, the Raises and function contexts cross. 
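The `py_builtins` change above makes AutoGraph's dataset overloads of `abs` and `map` stage parallel, autotuned maps instead of sequential ones. Roughly the user-visible equivalent, written against the public API (a sketch, not the overload itself):

```
import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE

ds = tf.data.Dataset.range(-5, 5)

# What `abs(ds)` now stages, approximately:
abs_ds = ds.map(tf.abs, num_parallel_calls=AUTOTUNE)

# What `map(fn, ds_a, ds_b)` now stages, approximately: zip, then parallel map.
pairs = tf.data.Dataset.zip((ds, ds))
mapped = pairs.map(lambda a, b: a + b, num_parallel_calls=AUTOTUNE)
```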
+ self.raises_cm.__exit__(*sys.exc_info()) + return + else: + raise - actuals = self.evaluate(tensors) for (assertion, _), values in zip(self.assertions, actuals): assertion(*values) @@ -109,6 +121,7 @@ class AutoGraphTestCase(test.TestCase): super().setUp() self.variables = {} self.trace_log = [] + self.raises_cm = None op_callbacks.add_op_callback(self._op_callback) def tearDown(self): @@ -145,3 +158,9 @@ class AutoGraphTestCase(test.TestCase): def assertDictEqual(self, *args): self.assertions.append((super().assertDictEqual, list(args))) + + def assertRaisesRuntime(self, *args): + if self.raises_cm is not None: + raise ValueError('cannot use more than one assertRaisesRuntime in a test') + self.raises_cm = self.assertRaisesRegex(*args) + self.raises_cm.__enter__() diff --git a/tensorflow/python/compat/BUILD b/tensorflow/python/compat/BUILD index 1c650ea5499..ddc8c484390 100644 --- a/tensorflow/python/compat/BUILD +++ b/tensorflow/python/compat/BUILD @@ -4,8 +4,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - py_library( name = "v2_compat", srcs = ["v2_compat.py"], diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index f1798090eb5..0c700718285 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 9, 26) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 10, 13) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/compiler/tensorrt/BUILD b/tensorflow/python/compiler/tensorrt/BUILD index 77c98bc61d6..fcbb5c0bd87 100644 --- a/tensorflow/python/compiler/tensorrt/BUILD +++ b/tensorflow/python/compiler/tensorrt/BUILD @@ -3,7 +3,10 @@ # and provide TensorRT operators and converter package. # APIs are meant to change over time. +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "cuda_py_tests") # cuda_py_test and cuda_py_tests enable XLA tests by default. 
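For context on the `_FORWARD_COMPATIBILITY_HORIZON` bump above: graph-building code consults the horizon through `tf.compat.forward_compatible` and keeps emitting the old op until the given date has passed for all consumers. A small hedged sketch of that pattern:

```
import tensorflow as tf

def build_op():
  if tf.compat.forward_compatible(2020, 10, 14):
    # All consumers are assumed to have picked up the new kernel/attribute.
    return "new construction path"
  # Older binaries may still be running; keep the old graph construction.
  return "old construction path"
```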
We can't @@ -37,6 +40,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/compiler/tf2tensorrt:_pywrap_py_utils", + "//tensorflow/compiler/tf2tensorrt:trt_engine_instance_proto_py", "//tensorflow/compiler/tf2tensorrt:trt_ops_loader", "//tensorflow/python:convert_to_constants", "//tensorflow/python:func_graph", @@ -101,6 +105,7 @@ cuda_py_test( xla_enable_strict_auto_jit = False, deps = [ ":trt_convert_py", + "//tensorflow/compiler/tf2tensorrt:trt_engine_instance_proto_py", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", "//tensorflow/python:graph_util", diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py index 5ae4b32279b..76aa755391f 100644 --- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py @@ -44,6 +44,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.profiler import trace from tensorflow.python.saved_model import builder from tensorflow.python.saved_model import load from tensorflow.python.saved_model import loader @@ -418,7 +419,6 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): run_params, saved_model_dir, inputs_data, - config, graph_state, num_runs=2): params = self._GetParamsCached() @@ -430,6 +430,14 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): num_runs) gc.collect() # Force GC to destroy the TRT engine cache. return results + + # The default config for tf.session is None. Create a config with + # TensorRTOptimizer enabled to support convert_online for inference. + config = None + # TODO(b/170220818): use the default session config to run inferenence + # graphs for the offline conversion case after fixing the bug. + if graph_state == GraphState.INFERENCE: + config = self._GetConfigProto(run_params, GraphState.INFERENCE) return self._RunGraphV1(saved_model_dir, inputs_data, config, num_runs) def _CreateConverter(self, run_params, saved_model_dir, session_config, @@ -814,70 +822,72 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): return self._MakeSavedModelV1(run_params) def RunTest(self, run_params): - should_run, reason_for_skipping = self.ShouldRunTest(run_params) - if not should_run: - return self.skipTest(reason_for_skipping) + with trace.Trace(run_params.test_name): + should_run, reason_for_skipping = self.ShouldRunTest(run_params) + if not should_run: + return self.skipTest(reason_for_skipping) - saved_model_dir = self._MakeSavedModel(run_params) + saved_model_dir = self._MakeSavedModel(run_params) - np.random.seed(12345) # Fix the seed so the test is deterministic. - inputs_data = [] - input_specs = self._GetParamsCached().input_specs - for dim_list in self._GetParamsCached().input_dims: - assert len(input_specs) == len(dim_list) - current_input_data = [] - for spec, np_shape in zip(input_specs, dim_list): - np_dtype = spec.dtype.as_numpy_dtype() - # Multiply the input by some constant to avoid all zeros input for - # integer types. - scale = 10.0 if np.issubdtype(np_dtype, np.integer) else 1.0 - # TODO(laigd): add debug options. E.g. 
we can set the input data to be - # continuous natural numbers: - # seq = np.arange(np.prod(np_shape)) - # seq.resize(np_shape) - # current_inputs_data.append(scale * seq.astype(np_dtype)) - data = (scale * np.random.random_sample(np_shape)).astype(np_dtype) - if run_params.is_v2: - with ops.device("/GPU:0"): - data = ops.convert_to_tensor(data) - current_input_data.append(data) - inputs_data.append(current_input_data) + np.random.seed(12345) # Fix the seed so the test is deterministic. + inputs_data = [] + input_specs = self._GetParamsCached().input_specs + for dim_list in self._GetParamsCached().input_dims: + assert len(input_specs) == len(dim_list) + current_input_data = [] + for spec, np_shape in zip(input_specs, dim_list): + np_dtype = spec.dtype.as_numpy_dtype() + # Multiply the input by some constant to avoid all zeros input for + # integer types. + scale = 10.0 if np.issubdtype(np_dtype, np.integer) else 1.0 + # TODO(laigd): add debug options. E.g. we can set the input data to be + # continuous natural numbers: + # seq = np.arange(np.prod(np_shape)) + # seq.resize(np_shape) + # current_inputs_data.append(scale * seq.astype(np_dtype)) + data = (scale * np.random.random_sample(np_shape)).astype(np_dtype) + if run_params.is_v2: + with ops.device("/GPU:0"): + data = ops.convert_to_tensor(data) + current_input_data.append(data) + inputs_data.append(current_input_data) - # Verify original graph. - self._VerifyGraphDef(run_params, saved_model_dir, saved_model_dir, - GraphState.ORIGINAL) + # Verify the original graph. + self._VerifyGraphDef(run_params, saved_model_dir, saved_model_dir, + GraphState.ORIGINAL) - # Run original graph without trt to get reference result. - config_no_trt = self._GetConfigProto(run_params, GraphState.ORIGINAL) - logging.info("Running original graph w/o trt, config:\n%s", - str(config_no_trt)) - ref_result = self._RunGraph(run_params, saved_model_dir, inputs_data, - config_no_trt, GraphState.ORIGINAL) + # Run the original graph without TensorRT to get the reference result. + logging.info("Running original graph w/o TensorRT\n") + ref_result = self._RunGraph( + run_params, + saved_model_dir, + inputs_data, + GraphState.ORIGINAL, + num_runs=1) - # Run calibration if necessary. - if IsQuantizationWithCalibration(run_params): - infer_saved_model_dir = self._GetCalibratedInferGraph( - run_params, saved_model_dir, inputs_data) - self._VerifyGraphDef(run_params, saved_model_dir, infer_saved_model_dir, - GraphState.INFERENCE) - elif not run_params.convert_online: - infer_saved_model_dir = self._GetInferGraph(run_params, saved_model_dir) - self._VerifyGraphDef(run_params, saved_model_dir, infer_saved_model_dir, - GraphState.INFERENCE) - else: - infer_saved_model_dir = saved_model_dir + # Run calibration if necessary. + if IsQuantizationWithCalibration(run_params): + infer_saved_model_dir = self._GetCalibratedInferGraph( + run_params, saved_model_dir, inputs_data) + self._VerifyGraphDef(run_params, saved_model_dir, infer_saved_model_dir, + GraphState.INFERENCE) + elif not run_params.convert_online: + infer_saved_model_dir = self._GetInferGraph(run_params, saved_model_dir) + self._VerifyGraphDef(run_params, saved_model_dir, infer_saved_model_dir, + GraphState.INFERENCE) + else: + infer_saved_model_dir = saved_model_dir - # Run inference. 
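The `RunTest` refactor above wraps the whole test body in a `trace.Trace(run_params.test_name)` scope, so each parameterized run shows up as a named region in profiler output. A hedged sketch using what appears to be the public equivalent of that internal helper (`tf.profiler.experimental.Trace`; the name and metadata below are placeholders):

```
import tensorflow as tf

def run_test(test_name):
  # Anything executed inside the scope is attributed to the named trace event
  # whenever a profiler session is active.
  with tf.profiler.experimental.Trace(test_name, phase="inference"):
    return tf.reduce_sum(tf.random.normal([1024, 1024]))
```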
- infer_config = self._GetConfigProto(run_params, GraphState.INFERENCE) - logging.info("Running final inference graph, config:\n%s", - str(infer_config)) - result = self._RunGraph(run_params, infer_saved_model_dir, inputs_data, - infer_config, GraphState.INFERENCE) - self.assertAllClose( - ref_result, - result, - atol=self.ExpectedAbsoluteTolerance(run_params), - rtol=self.ExpectedRelativeTolerance(run_params)) + # Run the inference graph, either using the converted graph or the + # original graph with convert_online == True. + logging.info("Running final inference graph\n") + result = self._RunGraph(run_params, infer_saved_model_dir, inputs_data, + GraphState.INFERENCE) + self.assertAllClose( + ref_result, + result, + atol=self.ExpectedAbsoluteTolerance(run_params), + rtol=self.ExpectedRelativeTolerance(run_params)) def testIdempotence(self): # Test that applying tensorrt optimizer or offline conversion tools multiple diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py index 7d52e054bea..0baae0bd3bf 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py @@ -27,6 +27,7 @@ from absl.testing import parameterized import numpy as np from tensorflow.compiler.tf2tensorrt._pywrap_py_utils import is_tensorrt_enabled +from tensorflow.compiler.tf2tensorrt.utils.trt_engine_instance_pb2 import TRTEngineInstance # pylint: disable=g-importing-member from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 @@ -74,6 +75,10 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): def mkdtemp(self): return tempfile.mkdtemp(dir=self.get_temp_dir()) + def testTRTEngineInstanceAvailable(self): + # test if we can access the TRTEngineInstance protobuf + assert hasattr(TRTEngineInstance(), "serialized_engine") + def testGetTensorrtRewriterConfig(self): """Test case for TrtGraphConverter.get_tensorrt_rewriter_config().""" if not is_tensorrt_enabled(): diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index 84e21a7b536..0ea339d6301 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -92,7 +92,6 @@ cuda_py_test( size = "small", srcs = ["copy_to_device_test.py"], tags = ["no_windows_gpu"], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", @@ -158,6 +157,7 @@ tf_py_test( srcs = ["data_service_ops_test.py"], shard_count = 10, srcs_version = "PY3", + tags = ["notap"], # TODO(b/170783829) deps = [ ":data_service_test_base", "//tensorflow:tensorflow_py", @@ -578,7 +578,6 @@ cuda_py_test( size = "small", srcs = ["prefetch_to_device_test.py"], tags = ["no_windows_gpu"], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_ft_test.py b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_ft_test.py index f1c2dc488d6..99d1a7b563e 100644 --- a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_ft_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_ft_test.py @@ -35,14 +35,6 @@ TMP_WORK_DIR = data_service_test_base.TMP_WORK_DIR NO_WORK_DIR = 
data_service_test_base.NO_WORK_DIR -def _all_cluster_configurations(): - with_work_dir = combinations.combine( - work_dir=TMP_WORK_DIR, fault_tolerant_mode=[True, False]) - without_work_dir = combinations.combine( - work_dir=NO_WORK_DIR, fault_tolerant_mode=False) - return with_work_dir + without_work_dir - - class DataServiceOpsTest(data_service_test_base.TestBase, parameterized.TestCase): @@ -112,8 +104,7 @@ class DataServiceOpsTest(data_service_test_base.TestBase, def testDispatcherAndWorkerRestart(self): cluster = self.create_cluster(num_workers=1) num_elements = 100 - ds = dataset_ops.Dataset.range(num_elements) - ds = self.make_distributed_dataset(ds, cluster) + ds = self.make_distributed_range_dataset(num_elements, cluster) cluster.restart_dispatcher() cluster.restart_worker() @@ -122,6 +113,28 @@ class DataServiceOpsTest(data_service_test_base.TestBase, cluster.restart_worker() self.assertDatasetProduces(ds, list(range(num_elements))) + @combinations.generate(test_base.eager_only_combinations()) + def testDispatcherAndMultiWorkerRestart(self): + num_workers = 2 + cluster = self.create_cluster(num_workers=num_workers) + num_elements = 100 + ds = self.make_distributed_range_dataset(num_elements, cluster) + iterator = iter(ds) + results = [] + + cluster.restart_dispatcher() + for worker_index in range(num_workers): + cluster.restart_worker(worker_index=worker_index) + for elem in iterator: + results.append(elem.numpy()) + self.assertCountEqual(num_workers * list(range(num_elements)), results) + cluster.restart_dispatcher() + for worker_index in range(num_workers): + cluster.restart_worker(worker_index=worker_index) + for elem in iterator: + results.append(elem.numpy()) + self.assertCountEqual(num_workers * list(range(num_elements)), results) + @combinations.generate(test_base.eager_only_combinations()) def testStartServersLate(self): # Test that the data service client performs retries instead of failing when diff --git a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py index be3c0404ee3..65674818daf 100644 --- a/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/data_service_ops_test.py @@ -49,14 +49,6 @@ TMP_WORK_DIR = data_service_test_base.TMP_WORK_DIR NO_WORK_DIR = data_service_test_base.NO_WORK_DIR -def _all_cluster_configurations(): - with_work_dir = combinations.combine( - work_dir=TMP_WORK_DIR, fault_tolerant_mode=[True, False]) - without_work_dir = combinations.combine( - work_dir=NO_WORK_DIR, fault_tolerant_mode=False) - return with_work_dir + without_work_dir - - class DataServiceOpsTest(data_service_test_base.TestBase, parameterized.TestCase): @@ -241,29 +233,14 @@ class DataServiceOpsTest(data_service_test_base.TestBase, results.append(elem.numpy()) self.assertCountEqual(list(range(num_elements)), results) - @combinations.generate(test_base.eager_only_combinations()) - def testSharedJobNameDifferentDatasets(self): - cluster = self.create_cluster(num_workers=1) - - def make_ds(num_elements): - return dataset_ops.Dataset.range(num_elements) - - ds1 = self.make_distributed_dataset( - make_ds(num_elements=10), cluster, job_name="job_name") - ds2 = self.make_distributed_dataset( - make_ds(num_elements=11), cluster, job_name="job_name") - iter(ds1) - with self.assertRaisesRegex(errors.FailedPreconditionError, - "AttrValues are different"): - iter(ds2) - 
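The restart tests above now build their inputs through `make_distributed_range_dataset`. In public-API terms (a sketch, assuming a dispatcher is already serving at `target`), the helper amounts to:

```
import tensorflow as tf

def make_distributed_range_dataset(num_elements, target, job_name=None):
  ds = tf.data.Dataset.range(num_elements)
  return ds.apply(
      tf.data.experimental.service.distribute(
          processing_mode="parallel_epochs",
          service=target,
          job_name=job_name))

# target = "grpc://localhost:5000"   # hypothetical dispatcher address
# ds = make_distributed_range_dataset(100, target)
```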
@combinations.generate(test_base.eager_only_combinations()) def testDifferentJobNames(self): cluster = self.create_cluster(num_workers=1) num_elements = 10 - ds = dataset_ops.Dataset.range(num_elements) - ds1 = self.make_distributed_dataset(ds, cluster, job_name="job_name1") - ds2 = self.make_distributed_dataset(ds, cluster, job_name="job_name2") + ds1 = self.make_distributed_range_dataset( + num_elements, cluster, job_name="job_name1") + ds2 = self.make_distributed_range_dataset( + num_elements, cluster, job_name="job_name2") self.assertDatasetProduces(ds1, list(range(num_elements))) self.assertDatasetProduces(ds2, list(range(num_elements))) @@ -271,9 +248,10 @@ class DataServiceOpsTest(data_service_test_base.TestBase, def testSharedJobNameMultiIteration(self): cluster = self.create_cluster(num_workers=1) num_elements = 10 - ds = dataset_ops.Dataset.range(num_elements) - ds1 = self.make_distributed_dataset(ds, cluster, job_name="job_name") - ds2 = self.make_distributed_dataset(ds, cluster, job_name="job_name") + ds1 = self.make_distributed_range_dataset( + num_elements, cluster, job_name="job_name") + ds2 = self.make_distributed_range_dataset( + num_elements, cluster, job_name="job_name") # iteration 1 self.assertDatasetProduces(ds1, list(range(num_elements))) self.assertDatasetProduces(ds2, []) @@ -286,10 +264,11 @@ class DataServiceOpsTest(data_service_test_base.TestBase, cluster = self.create_cluster(num_workers=1) num_elements = 100 num_repetitions = 3 - ds = dataset_ops.Dataset.range(num_elements) - ds1 = self.make_distributed_dataset(ds, cluster, job_name="job_name") + ds1 = self.make_distributed_range_dataset( + num_elements, cluster, job_name="job_name") ds1 = ds1.repeat(num_repetitions) - ds2 = self.make_distributed_dataset(ds, cluster, job_name="job_name") + ds2 = self.make_distributed_range_dataset( + num_elements, cluster, job_name="job_name") ds2 = ds2.repeat(num_repetitions) results = [] iter1 = iter(ds1) @@ -380,7 +359,7 @@ class DataServiceOpsTest(data_service_test_base.TestBase, options.experimental_external_state_policy = external_state_policy ds = ds.with_options(options) - cluster = self.create_cluster(3) + cluster = self.create_cluster(num_workers=3) ds = self.make_distributed_dataset(ds, cluster) next(iter(ds)) @@ -401,36 +380,65 @@ class DataServiceOpsTest(data_service_test_base.TestBase, @combinations.generate(test_base.eager_only_combinations()) def testDistributeDistributedEpochTensorSlices(self): - cluster = self.create_cluster(2) + cluster = self.create_cluster(num_workers=2) vals = [5, 1, 2, 4] ds = dataset_ops.Dataset.from_tensor_slices(vals) - ds = ds.apply( - data_service_ops.distribute( - processing_mode="distributed_epoch", service=cluster.target)) + ds = self.make_distributed_dataset( + ds, cluster, processing_mode="distributed_epoch") self.assertDatasetProduces(ds, vals, assert_items_equal=True) + @combinations.generate(test_base.eager_only_combinations()) + def testDistributeDistributedEpochInterleave(self): + cluster = self.create_cluster(num_workers=2) + elements = [1, 5, 0] + ds = dataset_ops.Dataset.from_tensor_slices(elements) + ds = ds.interleave(lambda x: dataset_ops.Dataset.from_tensor_slices([x])) + ds = self.make_distributed_dataset( + ds, cluster, processing_mode="distributed_epoch") + self.assertDatasetProduces(ds, elements, assert_items_equal=True) + + @combinations.generate(test_base.eager_only_combinations()) + def testDistributeDistributedEpochParallelInterleave(self): + cluster = self.create_cluster(num_workers=2) + elements = [1, 5, 
0] + ds = dataset_ops.Dataset.from_tensor_slices(elements) + ds = ds.interleave( + lambda x: dataset_ops.Dataset.from_tensor_slices([x]), + num_parallel_calls=dataset_ops.AUTOTUNE) + ds = self.make_distributed_dataset( + ds, cluster, processing_mode="distributed_epoch") + self.assertDatasetProduces(ds, elements, assert_items_equal=True) + + @combinations.generate(test_base.eager_only_combinations()) + def testDistributeDistributedEpochFlatMap(self): + cluster = self.create_cluster(num_workers=2) + elements = [1, 5, 0] + ds = dataset_ops.Dataset.from_tensor_slices(elements) + ds = ds.flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices([x])) + ds = self.make_distributed_dataset( + ds, cluster, processing_mode="distributed_epoch") + self.assertDatasetProduces(ds, elements, assert_items_equal=True) + @combinations.generate(test_base.eager_only_combinations()) def testDistributeDistributedEpochRepeat(self): - cluster = self.create_cluster(2) + cluster = self.create_cluster(num_workers=2) num_repeats = 5 num_elements = 20 ds = dataset_ops.Dataset.range(num_elements).repeat(num_repeats) - ds = ds.apply( - data_service_ops.distribute( - processing_mode="distributed_epoch", service=cluster.target)) + ds = self.make_distributed_dataset( + ds, cluster, processing_mode="distributed_epoch") self.assertDatasetProduces( ds, num_repeats * list(range(num_elements)), assert_items_equal=True) @combinations.generate(test_base.eager_only_combinations()) def testDistributeDistributedEpochShuffleAndRepeat(self): - cluster = self.create_cluster(2) + cluster = self.create_cluster(num_workers=2) num_repeats = 5 num_elements = 20 ds = dataset_ops.Dataset.range(num_elements).shuffle(num_elements).repeat( num_repeats) - ds = ds.apply( - data_service_ops.distribute( - processing_mode="distributed_epoch", service=cluster.target)) + ds = self.make_distributed_dataset( + ds, cluster, processing_mode="distributed_epoch") self.assertDatasetProduces( ds, num_repeats * list(range(num_elements)), assert_items_equal=True) @@ -451,9 +459,8 @@ class DataServiceOpsTest(data_service_test_base.TestBase, cluster = self.create_cluster(num_workers=2) num_elements = 100 ds = dataset_ops.Dataset.range(num_elements) - ds = ds.apply( - data_service_ops.distribute( - processing_mode="distributed_epoch", service=cluster.target)) + ds = self.make_distributed_dataset( + ds, cluster, processing_mode="distributed_epoch") self.assertDatasetProduces( ds, list(range(num_elements)), assert_items_equal=True) @@ -608,7 +615,7 @@ class DataServiceOpsTest(data_service_test_base.TestBase, combinations.times(test_base.eager_only_combinations())) def testDistributeLargeGraph(self): cluster = self.create_cluster( - 1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False) + num_workers=1, work_dir=NO_WORK_DIR, fault_tolerant_mode=False) # Larger than default OSS grpc message size limit of 4MB. 
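The new interleave and flat_map tests above exercise "distributed_epoch" mode. The practical difference from "parallel_epochs", sketched with the public API (the dispatcher address is a placeholder):

```
import tensorflow as tf

target = "grpc://localhost:5000"  # placeholder dispatcher address
elements = tf.data.Dataset.from_tensor_slices([1, 5, 0])

# parallel_epochs: every worker processes the full dataset, so with two
# workers each element is produced twice (order not guaranteed).
per_worker = elements.apply(
    tf.data.experimental.service.distribute(
        processing_mode="parallel_epochs", service=target))

# distributed_epoch: a single epoch is split across workers, so each element
# is produced exactly once, again in no particular order.
single_epoch = elements.apply(
    tf.data.experimental.service.distribute(
        processing_mode="distributed_epoch", service=target))
```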
tensor = array_ops.ones((2, 1000, 1000), dtype=dtypes.float32) ds = dataset_ops.Dataset.from_tensors(tensor) diff --git a/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py b/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py index bbf70eeebaf..0bb1383a56b 100644 --- a/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py +++ b/tensorflow/python/data/experimental/kernel_tests/data_service_test_base.py @@ -115,7 +115,14 @@ class TestCluster(object): # pylint: disable=protected-access def restart_dispatcher(self): - """Stops `dispatcher` and creates a new dispatcher with the same port.""" + """Stops `dispatcher` and creates a new dispatcher with the same port. + + Restarting is supported only when the dispatcher is configured with + `fault_tolerant_mode=True`. + """ + if not self.dispatcher._config.fault_tolerant_mode: + raise ValueError( + "Trying to restart the dispatcher without fault-tolerance.") port = int(self.dispatcher_address().split(":")[1]) self.dispatcher._stop() self.dispatcher = server_lib.DispatchServer( @@ -190,12 +197,13 @@ class TestBase(test_base.DatasetTestBase): def make_distributed_dataset(self, dataset, cluster, + processing_mode="parallel_epochs", job_name=None, max_outstanding_requests=None): # pylint: disable=protected-access return dataset.apply( data_service_ops._distribute( - "parallel_epochs", + processing_mode, cluster.target, job_name=job_name, max_outstanding_requests=max_outstanding_requests, @@ -203,9 +211,12 @@ class TestBase(test_base.DatasetTestBase): def make_distributed_range_dataset(self, num_elements, - dispatcher, + cluster, job_name=None, max_outstanding_requests=None): dataset = dataset_ops.Dataset.range(num_elements) - return self.make_distributed_dataset(dataset, dispatcher, job_name, - max_outstanding_requests) + return self.make_distributed_dataset( + dataset, + cluster, + job_name=job_name, + max_outstanding_requests=max_outstanding_requests) diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/autotune_buffer_sizes_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/autotune_buffer_sizes_test.py index 5d89688e181..c45809a2dad 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/autotune_buffer_sizes_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/autotune_buffer_sizes_test.py @@ -81,12 +81,10 @@ class AutotuneBufferSizesTest(test_base.DatasetTestBase, ])) dataset = dataset.map( lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE) - dataset = dataset.prefetch(buffer_size=3) dataset = dataset.map( lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE) dataset = dataset.map( lambda x: x + 1, num_parallel_calls=dataset_ops.AUTOTUNE) - dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE) dataset = dataset.interleave( lambda x: dataset_ops.Dataset.from_tensors(x + 1), num_parallel_calls=dataset_ops.AUTOTUNE) diff --git a/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py index fa55d3cf60b..14a0eafdd01 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import functools +import os import warnings from absl.testing import parameterized @@ -224,6 +225,65 @@ class 
OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces(dataset, expected_output=expected_output) + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(autotune=False, autotune_buffers=False) + + combinations.combine(autotune=True, autotune_buffers=False) + + combinations.combine(autotune=True, autotune_buffers=True), + combinations.combine(set_env=[False, True]))) + def testOptimizationEnableGradientDescent(self, autotune, autotune_buffers, + set_env): + if set_env: + os.environ["TF_DATA_EXPERIMENT_OPT_IN"] = "enable_gradient_descent" + os.environ["TF_JOB_NAME"] = "test_job" + + dataset = dataset_ops.Dataset.range(5) + dataset = dataset.prefetch(buffer_size=-1) + dataset = dataset.map(lambda x: x + 1, num_parallel_calls=2) + dataset = dataset.map(lambda x: x + 1, num_parallel_calls=-1) + dataset = dataset.prefetch(buffer_size=3) + dataset = dataset.map(lambda x: x + 1, num_parallel_calls=-1) + dataset = dataset.prefetch(buffer_size=1) + + options = dataset_ops.Options() + options.experimental_optimization.autotune = autotune + options.experimental_optimization.autotune_buffers = autotune_buffers + dataset = dataset.with_options(options) + + self.assertDatasetProduces(dataset, expected_output=list(range(3, 8))) + + if set_env: + del os.environ["TF_DATA_EXPERIMENT_OPT_IN"] + del os.environ["TF_JOB_NAME"] + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + combinations.combine(autotune=False, autotune_buffers=False) + + combinations.combine(autotune=True, autotune_buffers=False) + + combinations.combine(autotune=True, autotune_buffers=True), + combinations.combine(first_buffer_sizes=[(1, -1, -1, 4), + (2, -1, 3, -1), + (2, 1, -1, -1)]), + combinations.combine(second_buffer_sizes=[(1, -1, -1, 4), + (2, -1, 3, -1), + (2, 1, -1, -1)])) + ) + def testOptimizationAutotuneBuffers(self, autotune, autotune_buffers, + first_buffer_sizes, second_buffer_sizes): + dataset = dataset_ops.Dataset.range(10) + for buffer_size in first_buffer_sizes: + dataset = dataset.prefetch(buffer_size=buffer_size) + dataset = dataset.map(lambda x: x + 1) + for buffer_size in second_buffer_sizes: + dataset = dataset.prefetch(buffer_size=buffer_size) + options = dataset_ops.Options() + options.experimental_optimization.autotune = autotune + options.experimental_optimization.autotune_buffers = autotune_buffers + dataset = dataset.with_options(options) + self.assertDatasetProduces(dataset, expected_output=list(range(1, 11))) + @combinations.generate(test_base.default_test_combinations()) def testOptimizationThreadPoolDataset(self): dataset = dataset_ops.Dataset.range(10).batch(10) @@ -366,6 +426,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): "make_sloppy", "latency_all_edges", "slack", + "disable_prefetch_legacy_autotune", ] expected_optimizations_disabled = [] expected_optimizations_default = [] @@ -414,6 +475,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): "make_sloppy", "latency_all_edges", "slack", + "disable_prefetch_legacy_autotune", ] expected_optimizations_default = [] graph_rewrites = options._graph_rewrites() @@ -443,6 +505,9 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): options.experimental_optimization.autotune_ram_budget = 999999999 options.experimental_optimization.autotune_buffers = True self.assertIn("autotune_buffer_sizes", options._graph_rewrites().enabled) + 
self.assertIn("disable_prefetch_legacy_autotune", + options._graph_rewrites().enabled) + autotune, algorithm, cpu_budget, ram_budget = options._autotune_settings() self.assertTrue(autotune) self.assertEqual(algorithm, diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py index 2a368c56272..c6652fb69d2 100644 --- a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py @@ -224,57 +224,6 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(device_dataset._variant_tensor.device, "/job:localhost/replica:0/task:0/device:GPU:0") - @combinations.generate(test_base.eager_only_combinations()) - def testIteratorOnDeviceEagerMode(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - dataset = dataset_ops.Dataset.range(10) - dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0")) - iterator = iter(dataset) - data = next(iterator) - optional_data = iterator.get_next_as_optional() - - self.assertIn("gpu:0", dataset._variant_tensor.device.lower()) - self.assertIn("gpu:0", iterator._iterator_resource.device.lower()) - self.assertIn("gpu:0", data.device.lower()) - self.assertIn("gpu:0", optional_data.get_value().device.lower()) - self.assertIn("gpu:0", optional_data.has_value().device.lower()) - - @combinations.generate(test_base.graph_only_combinations()) - def testIteratorOnDeviceGraphModeOneShotIterator(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - dataset = dataset_ops.Dataset.range(10) - dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0")) - iterator = dataset_ops.make_one_shot_iterator(dataset) - data = iterator.get_next() - optional_data = iterator.get_next_as_optional() - - self.assertIn("gpu:0", dataset._variant_tensor.device.lower()) - self.assertIn("gpu:0", iterator._iterator_resource.device.lower()) - self.assertIn("gpu:0", data.device.lower()) - self.assertIn("gpu:0", optional_data.get_value().device.lower()) - self.assertIn("gpu:0", optional_data.has_value().device.lower()) - - @combinations.generate(test_base.graph_only_combinations()) - def testIteratorOnDeviceGraphModeInitializableIterator(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - dataset = dataset_ops.Dataset.range(10) - dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0")) - iterator = dataset_ops.make_initializable_iterator(dataset) - data = iterator.get_next() - optional_data = iterator.get_next_as_optional() - - self.assertIn("gpu:0", dataset._variant_tensor.device.lower()) - self.assertIn("gpu:0", iterator._iterator_resource.device.lower()) - self.assertIn("gpu:0", data.device.lower()) - self.assertIn("gpu:0", optional_data.get_value().device.lower()) - self.assertIn("gpu:0", optional_data.has_value().device.lower()) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/experimental/ops/data_service_ops.py b/tensorflow/python/data/experimental/ops/data_service_ops.py index f034725fc6c..f6f967ebac1 100644 --- a/tensorflow/python/data/experimental/ops/data_service_ops.py +++ b/tensorflow/python/data/experimental/ops/data_service_ops.py @@ -238,13 +238,9 @@ def _from_dataset_id(processing_mode, job_name=job_name, max_outstanding_requests=max_outstanding_requests, 
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms) - # TODO(b/157105111): Make this an autotuned parallel map when we have a way - # to limit memory usage. - # The value 16 is chosen based on experience with pipelines that require - # more than 8 parallel calls to prevent this stage from being a bottleneck. dataset = dataset.map( lambda x: compression_ops.uncompress(x, output_spec=element_spec), - num_parallel_calls=16) + num_parallel_calls=dataset_ops.AUTOTUNE) # Disable autosharding for shared jobs. if job_name: @@ -537,11 +533,7 @@ def register_dataset(service, dataset): dataset = dataset.map( lambda *x: compression_ops.compress(x), num_parallel_calls=dataset_ops.AUTOTUNE) - # Prefetch one compressed element to reduce latency when requesting data - # from tf.data workers. - # TODO(b/157105111): Set this to autotune when we have a way to limit - # memory usage - dataset = dataset.prefetch(1) + dataset = dataset.prefetch(dataset_ops.AUTOTUNE) # Apply options so that the dataset executed in the tf.data service will # be optimized and support autotuning. dataset = dataset._apply_options() # pylint: disable=protected-access diff --git a/tensorflow/python/data/experimental/ops/distribute.py b/tensorflow/python/data/experimental/ops/distribute.py index 5105f30fd07..568c01646de 100644 --- a/tensorflow/python/data/experimental/ops/distribute.py +++ b/tensorflow/python/data/experimental/ops/distribute.py @@ -85,9 +85,9 @@ class _AutoShardDataset(dataset_ops.UnaryDataset): return self._element_spec -def _AutoShardDatasetV1(input_dataset, num_workers, index): # pylint: disable=invalid-name +def _AutoShardDatasetV1(input_dataset, num_workers, index, num_replicas=None): # pylint: disable=invalid-name return dataset_ops.DatasetV1Adapter( - _AutoShardDataset(input_dataset, num_workers, index)) + _AutoShardDataset(input_dataset, num_workers, index, num_replicas)) class _RebatchDataset(dataset_ops.UnaryDataset): diff --git a/tensorflow/python/data/experimental/ops/optimization_options.py b/tensorflow/python/data/experimental/ops/optimization_options.py index 9e9d7254ee6..5c69855e15f 100644 --- a/tensorflow/python/data/experimental/ops/optimization_options.py +++ b/tensorflow/python/data/experimental/ops/optimization_options.py @@ -300,8 +300,11 @@ class OptimizationOptions(options.OptionsBase): # equivalent to tuning the buffer sizes of the other asynchronous # transformations. 
result.enabled.append("autotune_buffer_sizes") + result.enabled.append("disable_prefetch_legacy_autotune") + if self.autotune is False: # pylint: disable=g-bool-id-comparison result.disabled.append("autotune_buffer_sizes") + result.disabled.append("disable_prefetch_legacy_autotune") return result @@ -312,6 +315,7 @@ class OptimizationOptions(options.OptionsBase): graph_rewrite_configs = [] autotune_only_optimizations = [ "autotune_buffer_sizes", + "disable_prefetch_legacy_autotune", "enable_gradient_descent", "map_parallelization" ] diff --git a/tensorflow/python/data/experimental/ops/snapshot.py b/tensorflow/python/data/experimental/ops/snapshot.py index acba57da2a4..2cf7cd8704d 100644 --- a/tensorflow/python/data/experimental/ops/snapshot.py +++ b/tensorflow/python/data/experimental/ops/snapshot.py @@ -334,23 +334,27 @@ def snapshot(path, compression="AUTO", reader_func=None, shard_func=None): def _apply_fn(dataset): """Actual dataset transformation.""" + project_func = None if shard_func is None: dataset = dataset.enumerate() - dataset = _SnapshotDataset( - input_dataset=dataset, - path=path, - compression=compression, - reader_func=reader_func, - # This will not do the right thing where the graph is built on a - # different machine than the executor (e.g. Cloud TPUs). - shard_func=lambda index, _: index % multiprocessing.cpu_count()) - return dataset.map(lambda _, elem: elem) + # This sets the amount of parallelism based on the number of CPU cores on + # the machine where this Python code is executed, which may differ from + # the number of CPU cores where the input pipeline graph is actually + # executed (e.g. remote Cloud TPU workers). + local_shard_func = lambda index, _: index % multiprocessing.cpu_count() + project_func = lambda _, elem: elem else: - return _SnapshotDataset( - input_dataset=dataset, - path=path, - compression=compression, - reader_func=reader_func, - shard_func=shard_func) + local_shard_func = shard_func + dataset = _SnapshotDataset( + input_dataset=dataset, + path=path, + compression=compression, + reader_func=reader_func, + # This will not do the right thing where the graph is built on a + # different machine than the executor (e.g. Cloud TPUs). + shard_func=local_shard_func) + if project_func is not None: + dataset = dataset.map(project_func) + return dataset return _apply_fn diff --git a/tensorflow/python/data/experimental/service/server_lib_test.py b/tensorflow/python/data/experimental/service/server_lib_test.py index 16cea26614f..8248abdc2fd 100644 --- a/tensorflow/python/data/experimental/service/server_lib_test.py +++ b/tensorflow/python/data/experimental/service/server_lib_test.py @@ -172,7 +172,7 @@ class ServerLibTest(test.TestCase): # return UnavailableError with no trace events collected string. 
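The snapshot.py change above keeps the default sharding behavior (enumerate elements and shard by index % multiprocessing.cpu_count() on the machine that builds the graph) but routes both paths through a single `_SnapshotDataset` construction. A minimal sketch of the user-facing API, assuming a writable path such as `/tmp/snapshot_dir`, where passing an explicit `shard_func` avoids the machine-dependent default:

    import tensorflow as tf

    dataset = tf.data.Dataset.range(100)
    # Passing shard_func avoids the default, which depends on the CPU count of
    # the machine where this Python code runs rather than where it executes.
    dataset = dataset.apply(
        tf.data.experimental.snapshot(
            "/tmp/snapshot_dir",  # illustrative path
            shard_func=lambda x: x % 8))
    for element in dataset:
        pass  # the first run writes the snapshot; later runs read it back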
with self.assertRaises(errors.UnavailableError) as error: profiler_client.trace(worker._address, tempfile.mkdtemp(), duration_ms=10) - self.assertEqual("No trace event is collected", str(error.exception)) + self.assertStartsWith(str(error.exception), "No trace event was collected") if __name__ == "__main__": diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 23640f3196c..2705778fba6 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -572,6 +572,7 @@ cuda_py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:math_ops", + "//tensorflow/python:test_ops", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/eager:def_function", "@absl_py//absl/testing:parameterized", diff --git a/tensorflow/python/data/kernel_tests/from_tensors_test.py b/tensorflow/python/data/kernel_tests/from_tensors_test.py index e526745e0e5..2515faec041 100644 --- a/tensorflow/python/data/kernel_tests/from_tensors_test.py +++ b/tensorflow/python/data/kernel_tests/from_tensors_test.py @@ -252,12 +252,12 @@ class FromTensorsTest(test_base.DatasetTestBase, parameterized.TestCase): # placement algorithm overriding the DT_RESOURCE colocation constraints. with ops.device("/cpu:0"): var_0 = resource_variable_ops.ResourceVariable(initial_value=1) - dataset = dataset.map(lambda x: x + var_0.read_value()) + dataset = dataset.map(lambda x: x + var_0.read_value()) sess.run(var_0.initializer) with ops.device("/cpu:1"): var_1 = resource_variable_ops.ResourceVariable(initial_value=1) - dataset = dataset.map(lambda x: x + var_1.read_value()) + dataset = dataset.map(lambda x: x + var_1.read_value()) sess.run(var_1.initializer) iterator = dataset_ops.make_initializable_iterator(dataset) diff --git a/tensorflow/python/data/kernel_tests/placement_test.py b/tensorflow/python/data/kernel_tests/placement_test.py index 339f132b06c..43bf1c8c0b5 100644 --- a/tensorflow/python/data/kernel_tests/placement_test.py +++ b/tensorflow/python/data/kernel_tests/placement_test.py @@ -19,6 +19,7 @@ from __future__ import print_function from absl.testing import parameterized +from tensorflow.python.data.experimental.ops import prefetching_ops from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import def_function @@ -27,13 +28,14 @@ from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -@combinations.generate(test_base.v2_eager_only_combinations()) class PlacementTest(test_base.DatasetTestBase, parameterized.TestCase): """Tests for tf.data placement within tf.functions. @@ -48,6 +50,7 @@ class PlacementTest(test_base.DatasetTestBase, parameterized.TestCase): # cross-device copies. 
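The iterator-on-device tests deleted from prefetch_to_device_test.py above reappear in placement_test.py below. For reference, a minimal sketch of the API they exercise (requires a visible GPU; the device string is illustrative):

    import tensorflow as tf

    dataset = tf.data.Dataset.range(10)
    # Prefetch elements onto the GPU; the iterator and the produced tensors
    # are then placed on that device as well.
    dataset = dataset.apply(tf.data.experimental.prefetch_to_device("/gpu:0"))
    for element in dataset:
        print(element.device)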
config.set_optimizer_experimental_options({"disable_meta_optimizer": True}) + @combinations.generate(test_base.eager_only_combinations()) def testWhileWithCapturedDataset(self): dataset = dataset_ops.Dataset.range(10) @@ -61,6 +64,7 @@ class PlacementTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(f().numpy(), 45) + @combinations.generate(test_base.eager_only_combinations()) def testWhile(self): self.skipTest("b/166625126") @@ -75,8 +79,8 @@ class PlacementTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(f().numpy(), 45) + @combinations.generate(test_base.eager_only_combinations()) def testCondWithPlacement(self): - self.skipTest("b/166625126") # When the cond op is explicitly placed, there shouldn't be cross-device # copies. @def_function.function @@ -93,10 +97,10 @@ class PlacementTest(test_base.DatasetTestBase, parameterized.TestCase): nxt = next(iterator) return nxt - self.assertEqual(f(), 1) + self.assertEqual(f().numpy(), 1) + @combinations.generate(test_base.eager_only_combinations()) def testCondWithColocation(self): - self.skipTest("b/166625126") # When the cond op is colocated with the dataset, there shouldn't be # cross-device copies. @def_function.function @@ -115,6 +119,7 @@ class PlacementTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(f().numpy(), 1) + @combinations.generate(test_base.eager_only_combinations()) def testCond(self): self.skipTest("b/166625126") # Ideally, placer should avoid cross-device copies even when the cond op @@ -134,6 +139,7 @@ class PlacementTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(f().numpy(), 1) + @combinations.generate(test_base.eager_only_combinations()) def testId(self): self.skipTest("b/166625126") # Ideally, placer should know that Identity(dataset) should be on the same @@ -145,5 +151,104 @@ class PlacementTest(test_base.DatasetTestBase, parameterized.TestCase): return dataset f() + @combinations.generate(test_base.eager_only_combinations()) + def testIteratorOnDeviceEagerMode(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + dataset = dataset_ops.Dataset.range(10) + dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0")) + iterator = iter(dataset) + data = next(iterator) + optional_data = iterator.get_next_as_optional() + + self.assertIn("gpu:0", dataset._variant_tensor.device.lower()) + self.assertIn("gpu:0", iterator._iterator_resource.device.lower()) + self.assertIn("gpu:0", data.device.lower()) + self.assertIn("gpu:0", optional_data.get_value().device.lower()) + self.assertIn("gpu:0", optional_data.has_value().device.lower()) + + @combinations.generate(test_base.graph_only_combinations()) + def testIteratorOnDeviceGraphModeOneShotIterator(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + self.skipTest("TODO(b/169429285): tf.data.Dataset.make_one_shot_iterator " + "does not support GPU placement.") + + dataset = dataset_ops.Dataset.range(10) + dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0")) + iterator = dataset_ops.make_one_shot_iterator(dataset) + data = iterator.get_next() + optional_data = iterator.get_next_as_optional() + + with ops.colocate_with(dataset._variant_tensor): + dataset_device = test_ops.device_placement_op() + self.assertIn(b"GPU:0", self.evaluate(dataset_device)) + + with ops.colocate_with(iterator._iterator_resource): + iterator_device = test_ops.device_placement_op() + self.assertIn(b"GPU:0", 
self.evaluate(iterator_device)) + + with ops.colocate_with(data): + data_device = test_ops.device_placement_op() + self.assertIn(b"GPU:0", self.evaluate(data_device)) + + with ops.colocate_with(optional_data.get_value()): + get_value_device = test_ops.device_placement_op() + self.assertIn(b"GPU:0", self.evaluate(get_value_device)) + + with ops.colocate_with(optional_data.has_value()): + has_value_device = test_ops.device_placement_op() + self.assertIn(b"GPU:0", self.evaluate(has_value_device)) + + @combinations.generate(test_base.graph_only_combinations()) + def testIteratorOnDeviceGraphModeInitializableIterator(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + dataset = dataset_ops.Dataset.range(10) + dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0")) + iterator = dataset_ops.make_initializable_iterator(dataset) + data = iterator.get_next() + optional_data = iterator.get_next_as_optional() + + with ops.colocate_with(dataset._variant_tensor): + dataset_device = test_ops.device_placement_op() + self.assertIn(b"GPU:0", self.evaluate(dataset_device)) + + with ops.colocate_with(iterator._iterator_resource): + iterator_device = test_ops.device_placement_op() + self.assertIn(b"GPU:0", self.evaluate(iterator_device)) + + with ops.colocate_with(data): + data_device = test_ops.device_placement_op() + self.assertIn(b"GPU:0", self.evaluate(data_device)) + + with ops.colocate_with(optional_data.get_value()): + get_value_device = test_ops.device_placement_op() + self.assertIn(b"GPU:0", self.evaluate(get_value_device)) + + with ops.colocate_with(optional_data.has_value()): + has_value_device = test_ops.device_placement_op() + self.assertIn(b"GPU:0", self.evaluate(has_value_device)) + + @combinations.generate(test_base.eager_only_combinations()) + def testIterDatasetEagerModeWithExplicitDevice(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + @def_function.function + def comp(): + value = constant_op.constant(0, dtype=dtypes.int64) + for d in iter(dataset_ops.Dataset.range(10)): + value += d + return value + + with ops.device("/gpu:0"): + result = comp() + self.assertEqual(result.numpy(), 45) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/window_test.py b/tensorflow/python/data/kernel_tests/window_test.py index 98b453a5900..2515bd52f60 100644 --- a/tensorflow/python/data/kernel_tests/window_test.py +++ b/tensorflow/python/data/kernel_tests/window_test.py @@ -239,6 +239,17 @@ class WindowTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces(x, range(i*10, (i+1)*10)) self.assertDatasetProduces(y, range(i*10, (i+1)*10)) + @combinations.generate(test_base.default_test_combinations()) + def testDropRemainderOutput(self): + dataset = dataset_ops.Dataset.range(100) + dataset = dataset.window(30, drop_remainder=True) + dataset = dataset.flat_map(lambda x: x.batch(30)) + dataset = dataset.batch(4) + + self.assertDatasetProduces( + dataset, + expected_output=[[[y + 30 * x for y in range(30)] for x in range(3)]]) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 2ac9643ec0c..a7d2ed4840f 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -418,7 +418,7 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable, RuntimeError: If not inside of tf.function and not executing eagerly. 
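The new `testDropRemainderOutput` above pins down `window(..., drop_remainder=True)` semantics. A small self-contained sketch of the same behavior (values are illustrative):

    import tensorflow as tf

    dataset = tf.data.Dataset.range(100)
    # Three full windows of 30; the trailing 10 elements are dropped.
    dataset = dataset.window(30, drop_remainder=True)
    # Each window is itself a dataset; batching it by the window size turns it
    # back into a single tensor of shape (30,).
    dataset = dataset.flat_map(lambda window: window.batch(30))
    for batch in dataset:
        print(batch.shape)  # (30,)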
""" if context.executing_eagerly() or ops.inside_function(): - with ops.device(self._variant_tensor.device): + with ops.colocate_with(self._variant_tensor): return iterator_ops.OwnedIterator(self) else: raise RuntimeError("__iter__() is only supported inside of tf.function " @@ -2273,7 +2273,7 @@ class DatasetV1(DatasetV2): def _make_one_shot_iterator(self): # pylint: disable=missing-docstring if context.executing_eagerly(): - with ops.device(self._variant_tensor.device): + with ops.colocate_with(self._variant_tensor): return iterator_ops.OwnedIterator(self) _ensure_same_dataset_graph(self) @@ -2316,7 +2316,7 @@ class DatasetV1(DatasetV2): else: six.reraise(ValueError, err) - with ops.device(self._variant_tensor.device): + with ops.colocate_with(self._variant_tensor): # pylint: disable=protected-access return iterator_ops.Iterator( gen_dataset_ops.one_shot_iterator( @@ -2383,7 +2383,7 @@ class DatasetV1(DatasetV2): if shared_name is None: shared_name = "" - with ops.device(self._variant_tensor.device): + with ops.colocate_with(self._variant_tensor): iterator_resource = gen_dataset_ops.iterator_v2( container="", shared_name=shared_name, **self._flat_structure) @@ -4389,7 +4389,7 @@ class PrefetchDataset(UnaryUnchangedStructureDataset): # pylint: disable=protected-access # We colocate the prefetch dataset with its input as this collocation only # happens automatically in graph mode. - with ops.device(input_dataset._variant_tensor.device): + with ops.colocate_with(input_dataset._variant_tensor): variant_tensor = gen_dataset_ops.prefetch_dataset( input_dataset._variant_tensor, buffer_size=self._buffer_size, diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index c18831cc45d..748ae00b384 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -205,22 +205,12 @@ class Iterator(trackable.Trackable): output_types, output_shapes, output_classes) if shared_name is None: shared_name = "" - if _device_stack_is_empty(): - with ops.device("/cpu:0"): - iterator_resource = gen_dataset_ops.iterator_v2( - container="", - shared_name=shared_name, - output_types=structure.get_flat_tensor_types( - output_structure), - output_shapes=structure.get_flat_tensor_shapes( - output_structure)) - else: - iterator_resource = gen_dataset_ops.iterator_v2( - container="", - shared_name=shared_name, - output_types=structure.get_flat_tensor_types(output_structure), - output_shapes=structure.get_flat_tensor_shapes( - output_structure)) + iterator_resource = gen_dataset_ops.iterator_v2( + container="", + shared_name=shared_name, + output_types=structure.get_flat_tensor_types(output_structure), + output_shapes=structure.get_flat_tensor_shapes( + output_structure)) return Iterator(iterator_resource, None, output_types, output_shapes, output_classes) @@ -288,17 +278,10 @@ class Iterator(trackable.Trackable): output_structure = structure.convert_legacy_structure( output_types, output_shapes, output_classes) string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string) - if _device_stack_is_empty(): - with ops.device("/cpu:0"): - iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2( - string_handle, - output_types=structure.get_flat_tensor_types(output_structure), - output_shapes=structure.get_flat_tensor_shapes(output_structure)) - else: - iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2( - string_handle, - output_types=structure.get_flat_tensor_types(output_structure), - 
output_shapes=structure.get_flat_tensor_shapes(output_structure)) + iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2( + string_handle, + output_types=structure.get_flat_tensor_types(output_structure), + output_shapes=structure.get_flat_tensor_shapes(output_structure)) return Iterator(iterator_resource, None, output_types, output_shapes, output_classes) @@ -372,7 +355,8 @@ class Iterator(trackable.Trackable): "dataset with output shapes %r." % (self.output_shapes, dataset_output_shapes)) - with ops.device(self._iterator_resource.device): + # TODO(b/169442955): Investigate the need for this colocation constraint. + with ops.colocate_with(self._iterator_resource): # pylint: disable=protected-access return gen_dataset_ops.make_iterator( dataset._variant_tensor, self._iterator_resource, name=name) @@ -425,7 +409,8 @@ class Iterator(trackable.Trackable): if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD: warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE) - with ops.device(self._iterator_resource.device): + # TODO(b/169442955): Investigate the need for this colocation constraint. + with ops.colocate_with(self._iterator_resource): # pylint: disable=protected-access flat_ret = gen_dataset_ops.iterator_get_next( self._iterator_resource, @@ -435,15 +420,15 @@ class Iterator(trackable.Trackable): return structure.from_tensor_list(self._element_spec, flat_ret) def get_next_as_optional(self): - device = self._iterator_resource.device - # pylint: disable=protected-access - with ops.device(device): + # TODO(b/169442955): Investigate the need for this colocation constraint. + with ops.colocate_with(self._iterator_resource): + # pylint: disable=protected-access return optional_ops._OptionalImpl( gen_dataset_ops.iterator_get_next_as_optional( self._iterator_resource, output_types=structure.get_flat_tensor_types(self.element_spec), output_shapes=structure.get_flat_tensor_shapes( - self.element_spec)), self.element_spec, device) + self.element_spec)), self.element_spec) def string_handle(self, name=None): """Returns a string-valued `tf.Tensor` that represents this iterator. @@ -535,25 +520,23 @@ class IteratorResourceDeleter(object): object is part of a reference cycle, the cycle will be collectable. """ - __slots__ = ["_deleter", "_handle", "_device", "_eager_mode"] + __slots__ = ["_deleter", "_handle", "_eager_mode"] - def __init__(self, handle, device, deleter): + def __init__(self, handle, deleter): self._deleter = deleter self._handle = handle - self._device = device self._eager_mode = context.executing_eagerly() def __del__(self): - with ops.device(self._device): - # Make sure the resource is deleted in the same mode as it was created in. - if self._eager_mode: - with context.eager_mode(): - gen_dataset_ops.delete_iterator( - handle=self._handle, deleter=self._deleter) - else: - with context.graph_mode(): - gen_dataset_ops.delete_iterator( - handle=self._handle, deleter=self._deleter) + # Make sure the resource is deleted in the same mode as it was created in. 
+ if self._eager_mode: + with context.eager_mode(): + gen_dataset_ops.delete_iterator( + handle=self._handle, deleter=self._deleter) + else: + with context.graph_mode(): + gen_dataset_ops.delete_iterator( + handle=self._handle, deleter=self._deleter) @tf_export("data.Iterator", v1=[]) @@ -683,8 +666,6 @@ class OwnedIterator(IteratorBase): error_message = ("Either `dataset` or both `components` and " "`element_spec` need to be provided.") - self._device = context.context().device_name - if dataset is None: if (components is None or element_spec is None): raise ValueError(error_message) @@ -698,12 +679,7 @@ class OwnedIterator(IteratorBase): else: if (components is not None or element_spec is not None): raise ValueError(error_message) - if (_device_stack_is_empty() or - context.context().device_spec.device_type != "CPU"): - with ops.device("/cpu:0"): - self._create_iterator(dataset) - else: - self._create_iterator(dataset) + self._create_iterator(dataset) def _create_iterator(self, dataset): # pylint: disable=protected-access @@ -721,7 +697,7 @@ class OwnedIterator(IteratorBase): self._element_spec) self._flat_output_shapes = structure.get_flat_tensor_shapes( self._element_spec) - with ops.device(ds_variant.device): + with ops.colocate_with(ds_variant): self._iterator_resource, self._deleter = ( gen_dataset_ops.anonymous_iterator_v2( output_types=self._flat_output_types, @@ -730,7 +706,6 @@ class OwnedIterator(IteratorBase): # Delete the resource when this object is deleted self._resource_deleter = IteratorResourceDeleter( handle=self._iterator_resource, - device=self._device, deleter=self._deleter) def __iter__(self): @@ -741,25 +716,21 @@ class OwnedIterator(IteratorBase): def _next_internal(self): if not context.executing_eagerly(): - with ops.device(self._device): + # TODO(b/169442955): Investigate the need for this colocation constraint. + with ops.colocate_with(self._iterator_resource): ret = gen_dataset_ops.iterator_get_next( self._iterator_resource, output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) return structure.from_compatible_tensor_list(self._element_spec, ret) - # This runs in sync mode as iterators use an error status to communicate - # that there is no more data to iterate over. - # TODO(b/77291417): Fix + # TODO(b/77291417): This runs in sync mode as iterators use an error status + # to communicate that there is no more data to iterate over. with context.execution_mode(context.SYNC): - with ops.device(self._device): - # TODO(ashankar): Consider removing this ops.device() context manager - # and instead mimic ops placement in graphs: Operations on resource - # handles execute on the same device as where the resource is placed. - ret = gen_dataset_ops.iterator_get_next( - self._iterator_resource, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) + ret = gen_dataset_ops.iterator_get_next( + self._iterator_resource, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) try: # Fast path for the case `self._structure` is not a nested structure. @@ -829,14 +800,15 @@ class OwnedIterator(IteratorBase): return self._next_internal() def get_next_as_optional(self): - # pylint: disable=protected-access - with ops.device(self._device): + # TODO(b/169442955): Investigate the need for this colocation constraint. 
+ with ops.colocate_with(self._iterator_resource): + # pylint: disable=protected-access return optional_ops._OptionalImpl( gen_dataset_ops.iterator_get_next_as_optional( self._iterator_resource, output_types=structure.get_flat_tensor_types(self.element_spec), output_shapes=structure.get_flat_tensor_shapes( - self.element_spec)), self.element_spec, self._device) + self.element_spec)), self.element_spec) def _gather_saveables_for_checkpoint(self): diff --git a/tensorflow/python/data/ops/optional_ops.py b/tensorflow/python/data/ops/optional_ops.py index 8f679927855..d3ca3f667af 100644 --- a/tensorflow/python/data/ops/optional_ops.py +++ b/tensorflow/python/data/ops/optional_ops.py @@ -170,13 +170,12 @@ class _OptionalImpl(Optional): `Optional.__init__()` in the public API. """ - def __init__(self, variant_tensor, element_spec, device=None): + def __init__(self, variant_tensor, element_spec): self._variant_tensor = variant_tensor self._element_spec = element_spec - self._device = device def has_value(self, name=None): - with ops.device(self._device): + with ops.colocate_with(self._variant_tensor): return gen_dataset_ops.optional_has_value(self._variant_tensor, name=name) def get_value(self, name=None): @@ -184,16 +183,15 @@ class _OptionalImpl(Optional): # in `Iterator.get_next()` and `StructuredFunctionWrapper`. with ops.name_scope(name, "OptionalGetValue", [self._variant_tensor]) as scope: - with ops.device(self._device): - return structure.from_tensor_list( - self._element_spec, - gen_dataset_ops.optional_get_value( - self._variant_tensor, - name=scope, - output_types=structure.get_flat_tensor_types( - self._element_spec), - output_shapes=structure.get_flat_tensor_shapes( - self._element_spec))) + with ops.colocate_with(self._variant_tensor): + result = gen_dataset_ops.optional_get_value( + self._variant_tensor, + name=scope, + output_types=structure.get_flat_tensor_types(self._element_spec), + output_shapes=structure.get_flat_tensor_shapes(self._element_spec)) + # NOTE: We do not colocate the deserialization of composite tensors + # because not all ops are guaranteed to have non-GPU kernels. 
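The recurring change in dataset_ops.py, iterator_ops.py, and optional_ops.py above replaces explicit `ops.device(...)` scopes with colocation against the dataset or iterator resource tensor. A minimal sketch of the idea using the public `tf.compat.v1.colocate_with` (the function body is illustrative, not from the patch):

    import tensorflow as tf

    @tf.function
    def read_next_like_op():
        source = tf.constant([1, 2, 3])
        # Instead of pinning a device by name, place the dependent op on
        # whatever device `source` is assigned to.
        with tf.compat.v1.colocate_with(source):
            return tf.identity(source)

    print(read_next_like_op().numpy())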
+ return structure.from_tensor_list(self._element_spec, result) @property def element_spec(self): diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py index 61fbda603fe..8b005983719 100644 --- a/tensorflow/python/data/util/structure.py +++ b/tensorflow/python/data/util/structure.py @@ -537,6 +537,12 @@ class NoneTensorSpec(type_spec.BatchableTypeSpec): def _to_legacy_output_classes(self): return self + def most_specific_compatible_shape(self, other): + if type(self) is not type(other): + raise ValueError("No TypeSpec is compatible with both %s and %s" % + (self, other)) + return self + type_spec.register_type_spec_from_value_converter(type(None), NoneTensorSpec.from_value) diff --git a/tensorflow/python/debug/cli/cli_shared_test.py b/tensorflow/python/debug/cli/cli_shared_test.py index 3182d29d125..2d461c9e630 100644 --- a/tensorflow/python/debug/cli/cli_shared_test.py +++ b/tensorflow/python/debug/cli/cli_shared_test.py @@ -105,7 +105,7 @@ class TimeToReadableStrTest(test_util.TensorFlowTestCase): cli_shared.time_to_readable_str(100, force_time_unit="ks") -@test_util.run_deprecated_v1 +@test_util.run_v1_only("tfdbg CLI is for tf.Session only") class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase): def setUp(self): @@ -319,7 +319,7 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase): self.assertEqual("run #1: 1 fetch (a:0); 1 feed (foo)", short_description) -@test_util.run_deprecated_v1 +@test_util.run_v1_only("tfdbg CLI is for tf.Session only") class GetErrorIntroTest(test_util.TensorFlowTestCase): def setUp(self): diff --git a/tensorflow/python/debug/lib/common_test.py b/tensorflow/python/debug/lib/common_test.py index f6413f6b7b3..8469700accf 100644 --- a/tensorflow/python/debug/lib/common_test.py +++ b/tensorflow/python/debug/lib/common_test.py @@ -27,7 +27,7 @@ from tensorflow.python.platform import googletest class CommonTest(test_util.TensorFlowTestCase): - @test_util.run_deprecated_v1 + @test_util.run_v1_only("Relies on tensor name, which is unavailable in TF2") def testOnFeedOneFetch(self): a = constant_op.constant(10.0, name="a") b = constant_op.constant(20.0, name="b") @@ -36,7 +36,7 @@ class CommonTest(test_util.TensorFlowTestCase): self.assertItemsEqual(["a:0"], loaded[0]) self.assertItemsEqual(["b:0"], loaded[1]) - @test_util.run_deprecated_v1 + @test_util.run_v1_only("Relies on tensor name, which is unavailable in TF2") def testGetRunKeyFlat(self): a = constant_op.constant(10.0, name="a") b = constant_op.constant(20.0, name="b") @@ -45,7 +45,7 @@ class CommonTest(test_util.TensorFlowTestCase): self.assertItemsEqual(["a:0"], loaded[0]) self.assertItemsEqual(["a:0", "b:0"], loaded[1]) - @test_util.run_deprecated_v1 + @test_util.run_v1_only("Relies on tensor name, which is unavailable in TF2") def testGetRunKeyNestedFetches(self): a = constant_op.constant(10.0, name="a") b = constant_op.constant(20.0, name="b") diff --git a/tensorflow/python/debug/lib/debug_gradients_test.py b/tensorflow/python/debug/lib/debug_gradients_test.py index 95da6cb9ff8..e7a4d2d5450 100644 --- a/tensorflow/python/debug/lib/debug_gradients_test.py +++ b/tensorflow/python/debug/lib/debug_gradients_test.py @@ -36,7 +36,7 @@ from tensorflow.python.platform import googletest from tensorflow.python.training import gradient_descent -@test_util.run_deprecated_v1 +@test_util.run_v1_only("Sessions are not available in TF 2.x") class IdentifyGradientTest(test_util.TensorFlowTestCase): def setUp(self): diff --git 
a/tensorflow/python/debug/lib/source_utils_test.py b/tensorflow/python/debug/lib/source_utils_test.py index d0a4ecdbac4..366b25e89ac 100644 --- a/tensorflow/python/debug/lib/source_utils_test.py +++ b/tensorflow/python/debug/lib/source_utils_test.py @@ -106,7 +106,7 @@ class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase): self.assertTrue( source_utils.guess_is_tensorflow_py_library(source_utils.__file__)) - @test_util.run_deprecated_v1 + @test_util.run_v1_only("Tensor.op is not available in TF 2.x") def testFileInPythonKernelsPathReturnsTrue(self): x = constant_op.constant(42.0, name="x") self.assertTrue( @@ -320,7 +320,7 @@ class SourceHelperTest(test_util.TensorFlowTestCase): source_utils.load_source(source_path) -@test_util.run_v1_only("b/120545219") +@test_util.run_v1_only("Sessions are not available in TF 2.x") class ListSourceAgainstDumpTest(test_util.TensorFlowTestCase): def createAndRunGraphWithWhileLoop(self): diff --git a/tensorflow/python/debug/wrappers/disk_usage_test.py b/tensorflow/python/debug/wrappers/disk_usage_test.py index 71c56b33106..e9cb388c89f 100644 --- a/tensorflow/python/debug/wrappers/disk_usage_test.py +++ b/tensorflow/python/debug/wrappers/disk_usage_test.py @@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest from tensorflow.python.training import monitored_session -@test_util.run_deprecated_v1 +@test_util.run_v1_only("Sessions are not available in TF 2.x") class DumpingDebugWrapperDiskUsageLimitTest(test_util.TensorFlowTestCase): @classmethod diff --git a/tensorflow/python/debug/wrappers/framework_test.py b/tensorflow/python/debug/wrappers/framework_test.py index 266404a1038..7ca0cf192fa 100644 --- a/tensorflow/python/debug/wrappers/framework_test.py +++ b/tensorflow/python/debug/wrappers/framework_test.py @@ -141,7 +141,7 @@ class TestDebugWrapperSessionBadAction(framework.BaseDebugWrapperSession): return framework.OnRunEndResponse() -@test_util.run_deprecated_v1 +@test_util.run_v1_only("Sessions are not available in TF 2.x") class DebugWrapperSessionTest(test_util.TensorFlowTestCase): def _no_rewrite_session_config(self): diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index dcd32555abe..c9599747f3f 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -4,7 +4,10 @@ load("//tensorflow/core/platform/default:distribute.bzl", "distribute_py_test") load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test") package( - default_visibility = ["//tensorflow:internal"], + default_visibility = [ + "//tensorflow:internal", + "//third_party/py/keras:__subpackages__", # TODO(scottzhu): remove this once keras is relying on tf.__internal__. 
+ ], licenses = ["notice"], # Apache 2.0 ) @@ -93,13 +96,14 @@ py_library( ":values", "//tensorflow/python:array_ops", "//tensorflow/python:collective_ops", - "//tensorflow/python:device", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", "//tensorflow/python:math_ops", "//tensorflow/python:nccl_ops", "//tensorflow/python:platform", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", ], ) @@ -142,6 +146,8 @@ py_library( ":cross_device_ops", ":distribute_lib", ":mirrored_strategy", + ":multi_process_runner", + ":multi_worker_test_base", ":one_device_strategy", ":sharded_variable", "//tensorflow/python/distribute/client", @@ -201,7 +207,6 @@ py_test( "//tensorflow/python/autograph/core:test_lib", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", - "//third_party/py/numpy", ], ) @@ -788,6 +793,7 @@ py_library( visibility = [ "//tensorflow:internal", "//tensorflow_models:__subpackages__", + "//third_party/py/keras:__subpackages__", ], deps = [ ":collective_all_reduce_strategy", @@ -811,6 +817,7 @@ py_test( python_version = "PY3", deps = [ ":combinations", + ":test_util", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_combinations", "//tensorflow/python/distribute/cluster_resolver:tfconfig_cluster_resolver_py", @@ -825,6 +832,7 @@ py_library( visibility = [ "//tensorflow:internal", "//tensorflow_models:__subpackages__", + "//third_party/py/keras:__subpackages__", ], deps = [ ":central_storage_strategy", @@ -835,6 +843,7 @@ py_library( ":multi_process_runner", ":multi_worker_test_base", ":one_device_strategy", + ":test_util", ":tpu_strategy", "//tensorflow/python:config", "//tensorflow/python:platform", @@ -853,12 +862,21 @@ distribute_py_test( disable_mlir_bridge = False, python_version = "PY3", deps = [ + ":central_storage_strategy", + ":collective_all_reduce_strategy", ":combinations", + ":mirrored_strategy", + ":one_device_strategy", ":reduce_util", ":strategy_combinations", + ":test_util", + ":tpu_strategy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", "//tensorflow/python:config", "//tensorflow/python:constant_op", - "//tensorflow/python/eager:context", + "//tensorflow/python:tf2", + "//tensorflow/python/eager:def_function", "@absl_py//absl/testing:parameterized", ], ) @@ -868,6 +886,7 @@ py_library( srcs = ["multi_worker_test_base.py"], srcs_version = "PY2AND3", deps = [ + ":distribute_coordinator", ":multi_process_runner", "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", @@ -946,6 +965,7 @@ distribute_py_test( ":multi_worker_test_base", ":reduce_util", ":strategy_combinations", + ":test_util", ":values", "//tensorflow/python:control_flow_ops", "//tensorflow/python:errors", @@ -1021,6 +1041,7 @@ cuda_py_test( "multi_and_single_gpu", ], deps = [ + ":collective_util", ":combinations", ":cross_device_ops", ":multi_process_runner", @@ -1028,6 +1049,10 @@ cuda_py_test( ":reduce_util", ":values", "//tensorflow/python:array_ops", + "//tensorflow/python:collective_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:indexed_slices", "//tensorflow/python:util", @@ -1228,6 +1253,7 @@ distribute_py_test( ":combinations", ":distribute_lib", ":strategy_combinations", + ":test_util", ":tpu_strategy", ":tpu_values", ":values", @@ 
-1280,6 +1306,7 @@ distribute_py_test( deps = [ ":combinations", ":strategy_combinations", + ":test_util", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", @@ -1330,23 +1357,6 @@ distribute_py_test( ], ) -distribute_py_test( - name = "strategy_reduce_test", - srcs = ["strategy_reduce_test.py"], - main = "strategy_reduce_test.py", - tags = [ - "multi_and_single_gpu", - ], - deps = [ - ":combinations", - ":strategy_combinations", - "//tensorflow/python:errors", - "//tensorflow/python:variables", - "//tensorflow/python/eager:test", - "@absl_py//absl/testing:parameterized", - ], -) - py_library( name = "single_loss_example", srcs = ["single_loss_example.py"], @@ -1522,7 +1532,9 @@ cuda_py_test( ":multi_worker_test_base", ":multi_worker_util", ":reduce_util", + ":strategy_combinations", ":strategy_test_lib", + ":test_util", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -1708,6 +1720,38 @@ distribute_py_test( ":reduce_util", ":strategy_combinations", ":strategy_test_lib", + ":test_util", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:def_function", + "@absl_py//absl/testing:parameterized", + ], +) + +distribute_py_test( + name = "strategy_gather_test", + srcs = ["strategy_gather_test.py"], + disable_mlir_bridge = False, + python_version = "PY3", + shard_count = 2, + tags = [ + "multi_and_single_gpu", + "notsan", # TODO(b/160006974) + ], + xla_enable_strict_auto_jit = True, + deps = [ + ":collective_all_reduce_strategy", + ":combinations", + ":multi_worker_test_base", + ":reduce_util", + ":strategy_combinations", + ":strategy_test_lib", + ":test_util", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -1753,20 +1797,21 @@ py_library( srcs_version = "PY3", deps = [ ":collective_all_reduce_strategy", - ":cross_device_utils", - ":distribute_utils", + ":multi_process_runner", ":values", "//tensorflow/python:array_ops", + "//tensorflow/python:config", "//tensorflow/python:framework_ops", "//tensorflow/python:util", - "//tensorflow/python/eager:def_function", - "//tensorflow/python/types", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/eager:context", ], ) distribute_py_test( name = "test_util_test", srcs = ["test_util_test.py"], + disable_mlir_bridge = False, tags = [ "multi_and_single_gpu", ], diff --git a/tensorflow/python/distribute/client/client.py b/tensorflow/python/distribute/client/client.py index 34b76615d2a..6eabbfa219a 100644 --- a/tensorflow/python/distribute/client/client.py +++ b/tensorflow/python/distribute/client/client.py @@ -163,13 +163,14 @@ class RemoteValue(object): The remote value, as a numpy data type (if scalar) or ndarray. Raises: - FunctionRetryableError: If the function that produces this `RemoteValue` + tf.errors.CancelledError: If the function that produces this `RemoteValue` is aborted or cancelled due to failure, and the user should handle and reschedule. """ self._status_available_event.wait() if self._status is _RemoteValueStatus.ABORTED: - raise FunctionRetryableError( + raise errors.CancelledError( + None, None, "The corresponding function is aborted. 
Please reschedule the " "function.") if self._error is not None: @@ -191,11 +192,6 @@ class InputError(Exception): super().__init__(message) -class FunctionRetryableError(Exception): - """An error that represents the closure was aborted and should be retried.""" - pass - - def _maybe_rebuild_remote_values(worker, structure): """Attempts to return errors from `RemoteValue`s. Rebuilds them if needed.""" errors_in_structure = [] @@ -324,17 +320,12 @@ class Closure(object): # It will do nothing if there is no return value. nest.map_structure(lambda x: x.fetch(), self._output_remote_values) # pylint: disable=protected-access - def _set_output_remote_values_aborted(self): - """Set output remote_value aborted.""" - # It will do nothing if there is no return value. - nest.map_structure(lambda x: x._set_aborted(), self._output_remote_values) # pylint: disable=protected-access - def _set_output_remote_values_cancelled(self): nest.map_structure( lambda x: x._set_error( # pylint: disable=protected-access,g-long-lambda - FunctionRetryableError("The corresponding function is " - "cancelled. Please reschedule the " - "function.")), + errors.CancelledError( + None, None, "The corresponding function is " + "cancelled. Please reschedule the function.")), self._output_remote_values) # pylint: disable=protected-access def execute_on(self, worker): diff --git a/tensorflow/python/distribute/client/client_mpr_test.py b/tensorflow/python/distribute/client/client_mpr_test.py index fc1c0ce7c7c..802b23e87ec 100644 --- a/tensorflow/python/distribute/client/client_mpr_test.py +++ b/tensorflow/python/distribute/client/client_mpr_test.py @@ -49,7 +49,7 @@ class ClientMprTest(test.TestCase): test_schedule=False, test_join=False): - def proc_func(functions_scheduled_event, test_finished_event): + def fn(functions_scheduled_event, test_finished_event): cluster_resolver = TFConfigClusterResolver() if cluster_resolver.task_type != "chief": utils.start_server(cluster_resolver, "grpc") @@ -107,12 +107,12 @@ class ClientMprTest(test.TestCase): functions_scheduled_event = manager.Event() test_finished_event = manager.Event() mpr = multi_process_runner.MultiProcessRunner( - proc_func, + fn, multi_worker_test_base.create_cluster_spec( has_chief=True, num_workers=3, num_ps=1, has_eval=False), args=(functions_scheduled_event, test_finished_event), rpc_layer="grpc", - list_stdout=True, + return_output=True, use_dill_for_args=False) mpr.start() diff --git a/tensorflow/python/distribute/client/client_test.py b/tensorflow/python/distribute/client/client_test.py index fb68716dc26..981ad964b6d 100644 --- a/tensorflow/python/distribute/client/client_test.py +++ b/tensorflow/python/distribute/client/client_test.py @@ -27,6 +27,7 @@ import time from tensorflow.python.distribute.client import client from tensorflow.python.eager import cancellation from tensorflow.python.eager import def_function +from tensorflow.python.framework import errors from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import coordinator @@ -203,7 +204,7 @@ class CoordinatedClosureQueueTest(test.TestCase): self.assertTrue(closure_queue.done()) with self.assertRaisesRegex( - client.FunctionRetryableError, + errors.CancelledError, 'The corresponding function is cancelled. 
Please reschedule the ' 'function.'): closure2._fetch_output_remote_values() @@ -225,7 +226,7 @@ class CoordinatedClosureQueueTest(test.TestCase): self.assertTrue(closure_queue.done()) with self.assertRaisesRegex( - client.FunctionRetryableError, + errors.CancelledError, 'The corresponding function is cancelled. Please reschedule the ' 'function.'): closure2._fetch_output_remote_values() @@ -307,7 +308,7 @@ class CoordinatedClosureQueueTest(test.TestCase): # The following asserts that closure3 should have been cancelled. if not call_wait: with self.assertRaisesRegex( - client.FunctionRetryableError, + errors.CancelledError, 'The corresponding function is cancelled. Please reschedule the ' 'function.'): closure3._fetch_output_remote_values() diff --git a/tensorflow/python/distribute/client/parameter_server_client_test.py b/tensorflow/python/distribute/client/parameter_server_client_test.py index 78bd5f76c61..022539308d1 100644 --- a/tensorflow/python/distribute/client/parameter_server_client_test.py +++ b/tensorflow/python/distribute/client/parameter_server_client_test.py @@ -406,7 +406,7 @@ class ErrorReportingTest(TestCaseWithErrorReportingThread): with self.assertRaises(errors.InvalidArgumentError): self.client.join() - with self.assertRaises(client_lib.FunctionRetryableError): + with self.assertRaises(errors.CancelledError): long_function.fetch() for _ in range(3): diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py index a3e63e8a6f1..5719fc6870a 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py @@ -101,6 +101,9 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy): # TODO(anjalisridhar): Update our guides with examples showing how we can use # the cluster_resolver argument. + # The starting number for collective keys. This should only be set in tests. 
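With the client.py change above, `FunctionRetryableError` is gone and callers of `fetch()` see the standard `tf.errors.CancelledError` instead. A hedged sketch of the caller-side pattern (`fetch_fn` and the retry handling are illustrative stand-ins, not part of the client API):

    import tensorflow as tf

    def fetch_result(fetch_fn):
        # Callers that previously caught client.FunctionRetryableError now
        # catch tf.errors.CancelledError and reschedule the function.
        try:
            return fetch_fn()
        except tf.errors.CancelledError as e:
            print("cancelled, reschedule:", e.message)
            return None

    def aborted_fetch():
        raise tf.errors.CancelledError(
            None, None, "The corresponding function is aborted.")

    print(fetch_result(aborted_fetch))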
+ _collective_key_base = 0 + def __init__( self, communication=cross_device_ops_lib.CollectiveCommunication.AUTO, @@ -362,7 +365,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): else: local_devices = (self._worker_device,) - self._collective_keys = cross_device_utils.CollectiveKeys() + self._collective_keys = cross_device_utils.CollectiveKeys( + group_key_start=1 + CollectiveAllReduceStrategy._collective_key_base) # pylint: disable=protected-access self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce( devices=local_devices, group_size=len(local_devices) * self._num_workers, @@ -428,7 +432,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): group_key = self._collective_keys.get_group_key([device]) group_size = self._num_workers collective_instance_key = ( - self._collective_keys.get_variable_instance_key()) + self._collective_keys.get_instance_key(group_key, device)) with ops.device(device): initial_value = kwargs["initial_value"] @@ -476,7 +480,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): dataset, self._input_workers_with_options(options), self._container_strategy(), - split_batch_by=self._num_replicas_in_sync, + num_replicas_in_sync=self._num_replicas_in_sync, input_context=input_context) def _distribute_datasets_from_function(self, dataset_fn, options): @@ -505,7 +509,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): dataset, self._input_workers, self._container_strategy(), - split_batch_by=self._num_replicas_in_sync, + num_replicas_in_sync=self._num_replicas_in_sync, input_context=input_context) def _make_input_fn_iterator( @@ -767,3 +771,10 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): Boolean. """ return True + + def _get_replica_id_in_sync_group(self, replica_id): + return self._id_in_cluster * len(self.worker_devices) + replica_id + + def _get_local_replica_id(self, replica_id_in_sync_group): + return (replica_id_in_sync_group - + self._id_in_cluster * len(self.worker_devices)) diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py index 67e156f1a3d..305008f6cb8 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy_test.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy_test.py @@ -29,14 +29,16 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import combinations -from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribute_utils +from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_test_lib +from tensorflow.python.distribute import test_util from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.eager import context from tensorflow.python.framework import config as tf_config @@ -56,6 +58,8 @@ from 
tensorflow.python.platform import test from tensorflow.python.training.server_lib import ClusterSpec +CollectiveAllReduceStrategy = ( + collective_all_reduce_strategy.CollectiveAllReduceStrategy) CollectiveAllReduceExtended = ( collective_all_reduce_strategy.CollectiveAllReduceExtended) @@ -90,14 +94,10 @@ def create_test_objects(cluster_spec=None, class CollectiveAllReduceStrategyTestBase( multi_worker_test_base.MultiWorkerTestBase): - collective_key_base = 0 - def setUp(self): # We use a different key_base for each test so that collective keys won't be # reused. - # TODO(yuefengz, ayushd): enable it to reuse collective keys in different - # tests. - CollectiveAllReduceStrategyTestBase.collective_key_base += 100000 + CollectiveAllReduceStrategy._collective_key_base += 100000 super(CollectiveAllReduceStrategyTestBase, self).setUp() def _get_test_object(self, task_type, task_id, num_gpus=0): @@ -106,18 +106,6 @@ class CollectiveAllReduceStrategyTestBase( task_type=task_type, task_id=task_id, num_gpus=num_gpus) - - collective_keys = cross_device_utils.CollectiveKeys( - group_key_start=10 + - CollectiveAllReduceStrategyTestBase.collective_key_base, - op_instance_key_start=100 + - CollectiveAllReduceStrategyTestBase.collective_key_base, - variable_instance_key_start=10000 + - CollectiveAllReduceStrategyTestBase.collective_key_base) - strategy.extended._collective_keys = collective_keys - strategy.extended._cross_device_ops._collective_keys = collective_keys - strategy.extended._host_cross_device_ops._collective_keys = collective_keys - return strategy, target, session_config def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): @@ -598,5 +586,29 @@ class LogicalDeviceTest(test.TestCase, parameterized.TestCase): context._reset_context() # pylint: disable=protected-access +@combinations.generate( + combinations.combine( + strategy=[ + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu, + ], + mode=['eager'])) +class CollectiveAllReduceStrategyV2Test(test.TestCase, parameterized.TestCase): + + def test_replica_id_in_sync_group(self, strategy): + + def replica_fn(): + replica_ctx = distribution_strategy_context.get_replica_context() + return replica_ctx.replica_id_in_sync_group, replica_ctx._replica_id + + results = test_util.gather(strategy, strategy.run(replica_fn)) + self.assertAllEqual(list(range(strategy.extended._num_replicas_in_sync)), + results[0].numpy()) + self.assertAllEqual( + list(range(len(strategy.extended.worker_devices))) * + strategy.extended._num_workers, results[1].numpy()) + + if __name__ == '__main__': - test.main() + test_util.main() diff --git a/tensorflow/python/distribute/combinations.py b/tensorflow/python/distribute/combinations.py index 50fcc1eb14c..c9d3d7d9a9a 100644 --- a/tensorflow/python/distribute/combinations.py +++ b/tensorflow/python/distribute/combinations.py @@ -40,12 +40,12 @@ from tensorflow.python.eager import context from tensorflow.python.framework import combinations as framework_combinations from tensorflow.python.framework import ops from tensorflow.python.framework import test_combinations as combinations_lib +from tensorflow.python.framework import test_util from tensorflow.python.platform import flags from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect - -FLAGS = flags.FLAGS +from tensorflow.python.util.tf_export import 
tf_export # TODO(rchao): Rename `distribution` parameter to `strategy` or @@ -87,28 +87,42 @@ class ClusterParameters(combinations_lib.ParameterModifier): raise ValueError("Only support one NamedDistribution for multi worker" "tests.") strategy = v + + if strategy: + has_chief = strategy.has_chief + num_workers = strategy.num_workers + runner = strategy.runner + if "has_chief" in kwargs and kwargs["has_chief"] != has_chief: + raise ValueError( + "both has_chief and strategy specified but are not compatible") + if "num_workers" in kwargs and kwargs["num_workers"] != num_workers: + raise ValueError( + "both num_workers and strategy specified but are not compatible") + else: + has_chief = kwargs.get("has_chief", False) + num_workers = kwargs.get("num_workers", 1) + runner = None + # Always set cluster parameters if they're requested. So that generate() # works when there's no startegy in the combinations. update = {} if "has_chief" in requested_parameters: - update["has_chief"] = strategy.has_chief if strategy else False + update["has_chief"] = has_chief if "num_workers" in requested_parameters: - update["num_workers"] = strategy.num_workers if strategy else 1 + update["num_workers"] = num_workers if "runner" in requested_parameters: - update["runner"] = strategy.runner if strategy else None + update["runner"] = runner return update class DistributionCombination(combinations_lib.TestCombination): """Sets up distribution strategy for tests.""" - XLA_TEST = re.search(r"(test_xla|test_xla_gpu)$", sys.argv[0]) - def should_execute_combination(self, kwargs): distributions = [ v for v in kwargs.values() if isinstance(v, NamedDistribution) ] - if self.XLA_TEST and any(d.no_xla for d in distributions): + if test_util.is_xla_enabled() and any(d.no_xla for d in distributions): return ( False, "n/a: skipping strategy combination with no_xla=True in XLA tests") @@ -212,7 +226,7 @@ class TPUCombination(combinations_lib.TestCombination): [d.required_tpu or 0 for d in distributions]) use_cloud_tpu = any([kwargs.get("use_cloud_tpu")] + [d.use_cloud_tpu for d in distributions]) - tpu = hasattr(FLAGS, "tpu") and FLAGS.tpu or "" + tpu = hasattr(flags.FLAGS, "tpu") and flags.FLAGS.tpu or "" if not number_of_required_tpus and TPUCombination.TPU_TEST: return (False, "Test that doesn't require TPUs.") @@ -302,15 +316,16 @@ def concat(*combined): return result +@tf_export("__internal__.distribute.combinations.generate", v1=[]) def generate(combinations, test_combinations=()): # pylint: disable=g-doc-args,g-doc-return-or-yield - """Distributed adapter of `framework.combinations_lib.generate`. + """Distributed adapter of `tf.__internal__.test.combinations.generate`. All tests with distributed strategy should use this one instead of - `framework.test_combinations.generate`. This function has support of strategy - combinations, GPU/TPU and multi worker support. + `tf.__internal__.test.combinations.generate`. This function has support of + strategy combinations, GPU/TPU and multi worker support. - See `framework.test_combinations_lib.generate` for usage. + See `tf.__internal__.test.combinations.generate` for usage. """ # pylint: enable=g-doc-args,g-doc-return-or-yield default_combinations = ( @@ -348,11 +363,6 @@ times = combinations_lib.times NamedObject = combinations_lib.NamedObject -def main(): - """Tests must call this main().""" - return multi_process_runner.test_main() - - # Identifies whether we're in the main process or worker processes. 
# `_multi_worker_test` decoration behaves differently in the main processs and # the worker processes. See the documentation of _multi_worker_test for detail. diff --git a/tensorflow/python/distribute/combinations_test.py b/tensorflow/python/distribute/combinations_test.py index 3fc3735d560..02ddcbef632 100644 --- a/tensorflow/python/distribute/combinations_test.py +++ b/tensorflow/python/distribute/combinations_test.py @@ -25,6 +25,7 @@ import unittest from absl.testing import parameterized from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import test_util from tensorflow.python.distribute.cluster_resolver import tfconfig_cluster_resolver from tensorflow.python.framework import combinations as framework_combinations from tensorflow.python.platform import test @@ -88,6 +89,13 @@ class ClusterCombinationTest(test.TestCase, parameterized.TestCase): # If combinations library doesn't raise an exception, the test is passed. pass + @combinations.generate(combinations.combine(num_workers=2,)) + def testUseWithoutStrategy(self): + # There's no perfect way to check if the test runs in a subprocess. We + # approximate by checking the presence of TF_CONFIG, which is normally not + # set to the main process. + self.assertNotEqual(os.getenv("TF_CONFIG"), "") + # unittest.expectedFailure doesn't work with parameterized test methods, so we # have to decorate the class instead. @@ -106,14 +114,6 @@ class ClusterParametersShouldFailTest(test.TestCase, parameterized.TestCase): # combinations library should raise an exception. pass - @combinations.generate(combinations.combine(num_workers=2,)) - def testUseWithoutStrategy(self): - # There's no perfect way to check if the test runs in a subprocess. We - # approximate by checking the presence of TF_CONFIG, which is normally not - # set to the main process. - self.assertNotEqual(os.getenv("TF_CONFIG"), "") - raise ValueError("actually run") - # Tests that we *actually* run the test method in multiple workers instead of # just passing silently. More importantly, it verifies that the test can fail. @@ -157,4 +157,4 @@ class CombinationsOnClassMultiWorkerExpectedFailureTest(test.TestCase, if __name__ == "__main__": - combinations.main() + test_util.main() diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py index c4a4bf6ac0c..6d2a4e16f84 100644 --- a/tensorflow/python/distribute/cross_device_ops.py +++ b/tensorflow/python/distribute/cross_device_ops.py @@ -36,7 +36,7 @@ from tensorflow.python.distribute import tpu_values from tensorflow.python.distribute import values as value_lib from tensorflow.python.eager import context from tensorflow.python.eager import def_function -from tensorflow.python.eager import executor +from tensorflow.python.eager import executor as executor_lib from tensorflow.python.framework import kernels from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util @@ -144,8 +144,8 @@ def _normalize_value_destination_pairs(value_destination_pairs): def _validate_value_destination_pairs(value_destination_pairs): + """Validates value_destination_pairs are valid.""" # TODO(yuefengz): raise exceptions instead of returning False. 
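The combinations.py changes above let tests request cluster parameters such as `num_workers` without naming a multi-worker strategy, and combinations_test.py moves `testUseWithoutStrategy` accordingly. A minimal sketch of such a test file, closely mirroring that moved test (class and test names are illustrative):

    import os

    from absl.testing import parameterized

    from tensorflow.python.distribute import combinations
    from tensorflow.python.distribute import test_util
    from tensorflow.python.platform import test


    class WithoutStrategyTest(test.TestCase, parameterized.TestCase):

      @combinations.generate(combinations.combine(num_workers=2))
      def testRunsInWorkerSubprocess(self):
        # Worker subprocesses have TF_CONFIG set, unlike the main process.
        self.assertNotEqual(os.getenv("TF_CONFIG"), "")


    if __name__ == "__main__":
      test_util.main()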
- # pylint: disable=g-missing-docstring if not value_destination_pairs: return False if not isinstance(value_destination_pairs, (list, tuple)): return False if not all(isinstance(pair, tuple) for pair in value_destination_pairs): @@ -197,7 +197,7 @@ def simple_broadcast(value, destinations, always_mirrored=False): def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn, reduce_op): - # pylint: disable=g-missing-docstring + """Reduces the value by accumulation_fn and reduce_op.""" all_values = per_replica_value.values if not all_values: raise ValueError("`per_replica_value` must be non-empty") @@ -215,6 +215,18 @@ def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn, return reduced +def _simple_gather(per_replica_value, reduce_to_device, axis): + """Concatenate all values in the DistributedValues input and return.""" + all_values = per_replica_value.values + if not all_values: + raise ValueError("`per_replica_value` must be non-empty") + + with ops.device(reduce_to_device): + with context.device_policy(context.DEVICE_PLACEMENT_SILENT): + gathered = array_ops.concat(all_values, axis) + return gathered + + @tf_export("distribute.CrossDeviceOps") class CrossDeviceOps(object): """Base class for cross-device reduction and broadcasting algorithms. @@ -317,9 +329,6 @@ class CrossDeviceOps(object): `tf.distribute.DistributedValues` or if destinations is not a string, `tf.Variable` or `tf.distribute.DistributedValues`. """ - if isinstance(per_replica_value, ops.IndexedSlices): - raise NotImplementedError("gather/all_gather does not support " - "IndexedSlices") if experimental_hints is None: experimental_hints = collective_util.Hints() @@ -562,6 +571,20 @@ class ReductionToOneDevice(CrossDeviceOps): self.accumulation_fn, reduce_op) return self.broadcast(reduced, destinations) + def _gather_implementation(self, per_replica_value, destinations, axis, + experimental_hints): + del experimental_hints # Unused. + if check_destinations(destinations): + devices = get_devices_from(destinations) + else: + devices = get_devices_from(per_replica_value) + reduce_to_device = self.reduce_to_device or devices[0] + logging.log_first_n( + logging.INFO, + "Gather to %s then broadcast to %r." % (reduce_to_device, devices), 10) + gathered = _simple_gather(per_replica_value, reduce_to_device, axis) + return self.broadcast(gathered, destinations) + def batch_reduce_implementation(self, reduce_op, value_destination_pairs, experimental_hints): return [ @@ -857,6 +880,15 @@ class AllReduceCrossDeviceOps(CrossDeviceOps): return self._simple_cross_replica_ops.batch_reduce( reduce_op, zip(sparse_values, sparse_values)) + def _gather_implementation(self, per_replica_value, destinations, axis, + experimental_hints): + logging.warning("gather/all_gather with NCCL or HierarchicalCopy is not " + "supported. Falling back to gather on one device and " + "then broadcast. We're working on a more efficient " + "implementation.") + return ReductionToOneDevice()._gather(per_replica_value, destinations, axis, # pylint: disable=protected-access + experimental_hints) + # For compatibility with code using the old name of `AllReduceCrossDeviceOps`. 
AllReduceCrossTowerOps = AllReduceCrossDeviceOps @@ -987,7 +1019,6 @@ class CollectiveAllReduce(CrossDeviceOps): if group_size % len(devices) > 0: raise ValueError("group_size must be divisible by the number of devices.") - self._devices = tuple(device_util.canonicalize(d) for d in devices) self._group_size = group_size self._collective_keys = (collective_keys or cross_device_utils.CollectiveKeys()) @@ -1007,14 +1038,21 @@ class CollectiveAllReduce(CrossDeviceOps): # This deadlocks since neither collective is able to finish. self._lock = threading.Lock() + self._devices = tuple(device_util.canonicalize(d) for d in devices) + group_key = self._collective_keys.get_group_key(self._devices) # Collective ops requires all devices to participate and is blocking. In # eager, we need one async executor for each device to be able to launch # them altogether. Note that async doesn't imply concurrency. Within an # async executor operations are still executed sequentially. In graph or # function building, the executors are not used. self._executors = [] - for _ in range(len(devices)): - self._executors.append(executor.new_executor(enable_async=True)) + self._launchers = [] + for device in self._devices: + executor = executor_lib.new_executor(enable_async=True) + self._executors.append(executor) + launcher = cross_device_utils.CollectiveReplicaLauncher( + group_key, group_size, self._collective_keys, device, executor) + self._launchers.append(launcher) super(CollectiveAllReduce, self).__init__() @@ -1098,9 +1136,11 @@ class CollectiveAllReduce(CrossDeviceOps): batch_size = len(per_replica_values) # Pass self._communication to the runtime as a communication hint. communication = self._communication.value - # For now, we use NCCL only when batch_size > 1. + # For now, we use NCCL only when batch_size > 1 since we don't have a way to + # order NCCL launches. We're hoping that there's only one batched + # all-reduce, which is the gradients. # TODO(b/132575814): switch to NCCL for all collectives when communication - # is NCCL. + # is NCCL if and only if we can order collectives deterministically. if self._communication == CollectiveCommunication.NCCL and batch_size == 1: communication = CollectiveCommunication.AUTO.value @@ -1114,63 +1154,39 @@ class CollectiveAllReduce(CrossDeviceOps): # queuing time due to concurrent intense computation. # # TODO(b/147393503): explore solutions for optimal ordering. - packs = cross_device_utils.pack_by_size( - list(reversed(per_replica_values)), experimental_hints.bytes_per_pack) + values_by_device = [[] for _ in range(len(self._devices))] + for per_replica in reversed(per_replica_values): + for i in range(len(self._devices)): + values_by_device[i].append(per_replica.values[i]) - if batch_size > 1: - logging.info( - "Collective batch_all_reduce: %d all-reduces, num_devices = %d, " - "group_size = %d, communication_hint = %s, num_packs = %d", - batch_size, len(self._devices), self._group_size, communication, - len(packs)) - else: - logging.log_first_n( - logging.INFO, "Collective batch_all_reduce: %d all-reduces, " - "num_devices = %d, group_size = %d, communication_hint = %s, " - "num_packs = %d" % (batch_size, len( - self._devices), self._group_size, communication, len(packs)), 10) - - reduced_values = [] + outputs_by_device = [] with self._lock: - for pack in packs: - # By placing all CollectiveReduce ops in a pack under single name scope, - # we ensure they will be picked up by the `ScopedAllocator` grappler - # optimizer and packed into a single all-reduce. 
- with ops.name_scope("allreduce"): - for per_replica in pack: - # Add control dependencies per device from the last gradients to the - # current set, in order to serialize NCCL launches. - if (communication == CollectiveCommunication.NCCL.value and - reduced_values): - control_inputs = list(reduced_values[-1]) - else: - control_inputs = None - reduced_values.append( - cross_device_utils.build_collective_reduce( - per_replica.values, - self._devices, - self._group_size, - self._collective_keys, - "Add", - "Id", - communication, - control_inputs, - executors=self._executors, - timeout=experimental_hints.timeout_seconds)) + for i in range(len(self._devices)): + packs = cross_device_utils.group_by_size( + values_by_device[i], experimental_hints.bytes_per_pack) + if not context.executing_eagerly() and i == 0: + logging.info( + "Collective batch_all_reduce: %d all-reduces, num_devices = %d, " + "group_size = %d, communication_hint = %s, num_packs = %d", + batch_size, len(self._launchers), self._group_size, communication, + len(packs)) + outputs_by_device.append(self._launchers[i].batch_all_reduce( + packs, communication, experimental_hints.timeout_seconds)) for e in self._executors: e.wait() mirrored = [] - # Reverse the order of reduced value to recover the order in the input. - for value in reversed(reduced_values): + for values in zip(*outputs_by_device): if reduce_op == reduce_util.ReduceOp.MEAN: - for i, v in enumerate(value): + values = list(values) + for i, v in enumerate(values): with ops.device(v.device): - value[i] = v / self._group_size + values[i] = v / self._group_size mirrored.append( - distribute_utils.regroup(value, wrap_class=value_lib.Mirrored)) - return mirrored + distribute_utils.regroup(values, wrap_class=value_lib.Mirrored)) + # Reverse the order of reduced value to recover the order in the input. + return list(reversed(mirrored)) def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values, experimental_hints): @@ -1183,24 +1199,16 @@ class CollectiveAllReduce(CrossDeviceOps): # Pass self._communication to the runtime as a communication hint. communication_hint = self._communication.value - # For now, we use NCCL only when batch_size > 1. - # TODO(b/132575814): switch to NCCL for all collectives when communication - # is NCCL. - if self._communication == CollectiveCommunication.NCCL and len( - per_replica_values) == 1: - communication_hint = CollectiveCommunication.AUTO.value gathered_values = [] - with self._lock, ops.name_scope("allreduce"): + with self._lock: for per_replica in per_replica_values: - gathered_values.append( - cross_device_utils.build_collective_gather_indexed_slices( - per_replica.values, - self._devices, - self._group_size, - self._collective_keys, - communication_hint, - timeout=experimental_hints.timeout_seconds)) + outputs = [] + for i in range(len(self._devices)): + outputs.append(self._launchers[i].all_reduce_indexed_slices( + per_replica.values[i], communication_hint, + experimental_hints.timeout_seconds)) + gathered_values.append(outputs) mirrored = [] for value in gathered_values: @@ -1247,11 +1255,6 @@ class CollectiveAllReduce(CrossDeviceOps): batch_size = len(per_replica_values) # Pass self._communication to the runtime as a communication hint. communication = self._communication.value - # For now, we use NCCL only when batch_size > 1. - # TODO(b/132575814): switch to NCCL for all collectives when communication - # is NCCL. 
- if self._communication == CollectiveCommunication.NCCL and batch_size == 1: - communication = CollectiveCommunication.AUTO.value logging.log_first_n( logging.INFO, "Collective batch_all_gather: %d all-gathers, " @@ -1262,21 +1265,12 @@ class CollectiveAllReduce(CrossDeviceOps): gathered_values = [] with self._lock, ops.name_scope("allgather"): for per_replica in per_replica_values: - if (communication == CollectiveCommunication.NCCL.value and - gathered_values): - control_inputs = list(gathered_values[-1]) - else: - control_inputs = None - gathered_values.append( - cross_device_utils.build_collective_gather( - per_replica.values, - self._devices, - self._group_size, - self._collective_keys, - axis, - communication, - control_inputs, - timeout=experimental_hints.timeout_seconds)) + outputs = [] + for i in range(len(self._devices)): + outputs.append(self._launchers[i].all_gather( + per_replica.values[i], axis, communication, + experimental_hints.timeout_seconds)) + gathered_values.append(outputs) return gathered_values if context.executing_eagerly(): @@ -1285,8 +1279,7 @@ class CollectiveAllReduce(CrossDeviceOps): gathered_values = compute_gathered_values() mirrored = [] - # Reverse the order of gathered value to recover the order in the input. - for value in reversed(gathered_values): + for value in gathered_values: mirrored.append( distribute_utils.regroup(value, wrap_class=value_lib.Mirrored)) return mirrored @@ -1299,8 +1292,8 @@ class CollectiveAllReduce(CrossDeviceOps): self._communication) -def choose_the_best(devices, session_config=None): - """Find the best CrossDeviceOps locally given a `tf.compat.v1.ConfigProto`. +def select_cross_device_ops(devices, session_config=None): + """Find the best `CrossDeviceOps` locally given a `tf.compat.v1.ConfigProto`. Args: devices: a list of devices passed to `tf.distribute.Strategy`. 
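The new gather path in `ReductionToOneDevice._gather_implementation` concatenates the per-replica values along `axis` on a single device and then broadcasts the result. A minimal sketch of that behavior, assuming eager execution; the device string and example tensors are illustrative and not part of this change:

import tensorflow as tf

per_replica_values = [tf.constant([[1., 2.]]), tf.constant([[3., 4.]])]
with tf.device("/CPU:0"):  # stands in for `reduce_to_device`
    gathered = tf.concat(per_replica_values, axis=0)  # what _simple_gather computes
print(gathered.numpy())  # [[1. 2.] [3. 4.]]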
diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py index 2ca046c9d18..461b32d57b0 100644 --- a/tensorflow/python/distribute/cross_device_ops_test.py +++ b/tensorflow/python/distribute/cross_device_ops_test.py @@ -20,12 +20,15 @@ from __future__ import print_function import collections import os +import threading +import time from absl.testing import parameterized from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib +from tensorflow.python.distribute import collective_util from tensorflow.python.distribute import combinations from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import multi_process_runner @@ -37,21 +40,26 @@ from tensorflow.python.eager import def_function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import indexed_slices from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import collective_ops from tensorflow.python.util import nest CollectiveCommunication = cross_device_ops_lib.CollectiveCommunication ReduceOp = reduce_util.ReduceOp +IndexedSlicesValue = indexed_slices.IndexedSlicesValue +IndexedSlices = indexed_slices.IndexedSlices -def make_per_replica_value(value_fn, devices): +def make_per_replica_value(value, devices): """Creates a `PerReplica` object whose values reside in `devices`. Args: - value_fn: a callable that takes one argument (`device_idx`) and should - return the value that is going to be created on devices[device_idx]. + value: a tensor-convertible value or a `IndexedSlicesValue`, or a callable + that takes one argument (`device_idx`) and should return the value that is + going to be created on devices[device_idx]. devices: a list of device strings to create `PerReplica` values on. Returns: @@ -59,11 +67,11 @@ def make_per_replica_value(value_fn, devices): """ values = [] for device_idx, device in enumerate(devices): - v = value_fn(device_idx) - if isinstance(v, indexed_slices.IndexedSlicesValue): + v = value(device_idx) if callable(value) else value + if isinstance(v, IndexedSlicesValue): with ops.device(device): values.append( - indexed_slices.IndexedSlices( + IndexedSlices( values=array_ops.identity(v.values), indices=array_ops.identity(v.indices), dense_shape=array_ops.identity(v.dense_shape))) @@ -102,7 +110,7 @@ class MultiProcessPoolRunner(): # expensive initialization cost of TensorFlow in new processes. # # Note that they have to be globals and can't be owned by test classes because -# usually proc_func usually captures the test class instance, and test class +# usually fn usually captures the test class instance, and test class # instance can't be pickled if it has mpr as a member (it is not allowed to # pickle Process objects). # TODO(crccw): Use `num_workers` combination once it is ready. 
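`make_per_replica_value` now accepts either a plain value (tensor-convertible or `IndexedSlicesValue`) or a callable taking `device_idx`. A small, self-contained sketch of that value-or-callable convention; `resolve_value` is an illustrative helper name, not part of the test:

def resolve_value(value, device_idx):
    # A callable is evaluated once per device; anything else is reused as-is.
    return value(device_idx) if callable(value) else value

assert resolve_value(1.0, 0) == 1.0
assert resolve_value(lambda idx: 2.0 * idx, 1) == 2.0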
@@ -179,7 +187,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): """ if isinstance(value, ops.Tensor): return [value] - elif isinstance(value, indexed_slices.IndexedSlices): + elif isinstance(value, IndexedSlices): return [value] elif isinstance(value, value_lib.Mirrored): return value.values @@ -336,27 +344,27 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): group_size = options.num_processes * (options.gpus_per_process or 1) inputs_data = [ - indexed_slices.IndexedSlicesValue( + IndexedSlicesValue( values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1]), - indexed_slices.IndexedSlicesValue( + IndexedSlicesValue( values=[[3.], [4.]], indices=[1, 2], dense_shape=[10, 1]), - indexed_slices.IndexedSlicesValue( + IndexedSlicesValue( values=[[5.], [6.]], indices=[7, 8], dense_shape=[10, 1]), - indexed_slices.IndexedSlicesValue( + IndexedSlicesValue( values=[[7.], [8.]], indices=[3, 2], dense_shape=[10, 1]), ] inputs = inputs_data[0:group_size] if group_size == 1: - expect = indexed_slices.IndexedSlices( + expect = IndexedSlices( values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1]) elif group_size == 2: - expect = indexed_slices.IndexedSlices( + expect = IndexedSlices( values=[[1.], [2.], [3.], [4.]], indices=[0, 1, 1, 2], dense_shape=[10, 1]) elif group_size == 4: - expect = indexed_slices.IndexedSlices( + expect = IndexedSlices( values=[[1.], [2.], [3.], [4.], [5.], [6.], [7.], [8.]], indices=[0, 1, 1, 2, 7, 8, 3, 2], dense_shape=[10, 1]) @@ -366,12 +374,11 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): def testAllReduceSparseVariableLength(self): # One device per process, 2 processes, 2 replicas in total. inputs = [ - indexed_slices.IndexedSlicesValue( - values=[[1.]], indices=[0], dense_shape=[10, 1]), - indexed_slices.IndexedSlicesValue( + IndexedSlicesValue(values=[[1.]], indices=[0], dense_shape=[10, 1]), + IndexedSlicesValue( values=[[2.], [3.], [4.]], indices=[0, 1, 2], dense_shape=[10, 1]), ] - expect = indexed_slices.IndexedSlices( + expect = IndexedSlices( values=[[1.], [2.], [3.], [4.]], indices=[0, 0, 1, 2], dense_shape=[10, 1]) @@ -447,53 +454,53 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): group_size = options.num_processes * (options.gpus_per_process or 1) inputs_data = ([ - indexed_slices.IndexedSlicesValue( + IndexedSlicesValue( values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1]), - indexed_slices.IndexedSlicesValue( + IndexedSlicesValue( values=[[3.], [4.]], indices=[1, 2], dense_shape=[5, 1]) ], [ - indexed_slices.IndexedSlicesValue( + IndexedSlicesValue( values=[[5.], [6.]], indices=[1, 2], dense_shape=[10, 1]), - indexed_slices.IndexedSlicesValue( + IndexedSlicesValue( values=[[7.], [8.]], indices=[0, 1], dense_shape=[5, 1]) ], [ - indexed_slices.IndexedSlicesValue( + IndexedSlicesValue( values=[[9.], [10.]], indices=[3, 4], dense_shape=[10, 1]), - indexed_slices.IndexedSlicesValue( + IndexedSlicesValue( values=[[11.], [12.]], indices=[3, 4], dense_shape=[5, 1]) ], [ - indexed_slices.IndexedSlicesValue( + IndexedSlicesValue( values=[[13.], [14.]], indices=[8, 9], dense_shape=[10, 1]), - indexed_slices.IndexedSlicesValue( + IndexedSlicesValue( values=[[15.], [16.]], indices=[3, 4], dense_shape=[5, 1]) ]) inputs = inputs_data[0:group_size] if group_size == 1: expect = [ - indexed_slices.IndexedSlices( + IndexedSlices( values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1]), - indexed_slices.IndexedSlicesValue( + IndexedSlicesValue( values=[[3.], [4.]], indices=[1, 2], 
dense_shape=[5, 1]) ] if group_size == 2: expect = [ - indexed_slices.IndexedSlices( + IndexedSlices( values=[[1.], [2.], [5.], [6.]], indices=[0, 1, 1, 2], dense_shape=[10, 1]), - indexed_slices.IndexedSlices( + IndexedSlices( values=[[3.], [4.], [7.], [8.]], indices=[1, 2, 3, 4], dense_shape=[5, 1]) ] elif group_size == 4: expect = [ - indexed_slices.IndexedSlices( + IndexedSlices( values=[[1.], [2.], [5.], [6.], [9.], [10.], [13.], [14.]], indices=[0, 1, 1, 2, 3, 4, 8, 9], dense_shape=[10, 1]), - indexed_slices.IndexedSlices( + IndexedSlices( values=[[3.], [4.], [7.], [8.], [11.], [12.], [15.], [16.]], indices=[1, 2, 0, 1, 3, 4, 3, 4], dense_shape=[5, 2]) @@ -521,8 +528,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): value = constant_op.constant([[[1, 2], [1, 2]]], dtype=dtypes.float32) def gather_fn(): - value_fn = lambda device_idx: value - per_replica_value = make_per_replica_value(value_fn, devices) + per_replica_value = make_per_replica_value(value, devices) gathered_values = collective._gather( per_replica_value, per_replica_value, axis=axis) gathered_values = self.as_list(gathered_values) @@ -546,6 +552,219 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): get_global_mpr(num_processes).run(replica_fn) + @combinations.generate( + combinations.combine( + num_processes=1, + required_gpus=2, + communication=[ + CollectiveCommunication.NCCL, CollectiveCommunication.RING + ])) + def testMultiThreadedCollectiveLaunchNoInterleave(self, num_processes, + required_gpus, + communication): + + def replica_fn(): + collective, devices, _ = self.make_collective(num_processes, + required_gpus, + communication) + + # We would like to simulate the following sequence: + # thread-0 device0 device1 + # thread-1 device0 device1 + # If the kernel launch sequence is as-is the program will deadlock since + # NCCL requires the launch order to be same on each device. + v0 = make_per_replica_value(1.0, devices) + v1 = make_per_replica_value(2.0, devices) + + # Add a delay to collective_ops.all_reduce according to the input tensors + # index in `sequence.` + sequence = [v0.values[0], v1.values[0], v1.values[1], v0.values[1]] + all_reduce = collective_ops.all_reduce + + def delayed_all_reduce(input_tensor, *args, **kwargs): + for idx, v in enumerate(sequence): + if input_tensor is v: + time.sleep(idx) + break + return all_reduce(input_tensor, *args, **kwargs) + + with test.mock.patch.object(collective_ops, "all_reduce", + delayed_all_reduce): + # We only use NCCL for batch reduce with two or more values, so we use + # two values here. 
+ + def thread_fn(): + reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, + [(v0, v0), (v0, v0)]) + self.assertAllEqual(reduced[0].values, [2.0, 2.0]) + self.assertAllEqual(reduced[1].values, [2.0, 2.0]) + + t = threading.Thread(target=thread_fn) + t.start() + reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v1, v1), + (v1, v1)]) + self.assertAllEqual(reduced[0].values, [4.0, 4.0]) + self.assertAllEqual(reduced[1].values, [4.0, 4.0]) + t.join() + + get_global_mpr(num_processes).run(replica_fn) + + @combinations.generate( + combinations.combine( + num_processes=1, + required_gpus=2, + communication=[ + CollectiveCommunication.NCCL, CollectiveCommunication.RING + ])) + def testInputsAreFunctionArgs(self, num_processes, required_gpus, + communication): + + def replica_fn(): + collective, devices, _ = self.make_collective(num_processes, + required_gpus, + communication) + + @def_function.function + def reduce_fn(v): + # Function inputs don't have device placement. + self.assertEqual(v.values[0].device, "") + self.assertEqual(v.values[1].device, "") + # We only use NCCL for batch reduce with two or more values, so we use + # two values here. + reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), + (v, v)]) + self.assertEqual(reduced[0].values[0].device, devices[0]) + self.assertEqual(reduced[0].values[1].device, devices[1]) + self.assertEqual(reduced[1].values[0].device, devices[0]) + self.assertEqual(reduced[1].values[1].device, devices[1]) + # Returning Mirrored only evaluates the primary value, which causes + # hanging, + return [reduced[0].values, reduced[1].values] + + v = make_per_replica_value(1.0, devices) + reduced = reduce_fn(v) + self.assertAllClose(reduced, [[2.0, 2.0], [2.0, 2.0]]) + + get_global_mpr(num_processes).run(replica_fn) + + @combinations.generate( + combinations.combine( + num_processes=2, + required_gpus=[0, 1], + communication=[CollectiveCommunication.RING])) + def testTimeoutReduceDense(self, num_processes, communication, required_gpus): + + def replica_fn(): + collective, devices, task_id = self.make_collective( + num_processes, required_gpus, communication) + if task_id != 0: + return + + v = make_per_replica_value(1.0, devices) + hints = collective_util.Hints(timeout_seconds=1) + + @def_function.function + def reduce_dense(): + collective.reduce(reduce_util.ReduceOp.SUM, v, v, hints) + + # The collective should time out because we only launch it on worker-0, + # while there're three workers in total. + with self.assertRaises(errors.DeadlineExceededError): + reduce_dense() + + get_global_mpr(num_processes).run(replica_fn) + + @combinations.generate( + combinations.combine( + num_processes=2, + required_gpus=[0, 1], + communication=[CollectiveCommunication.RING])) + def testTimeoutBatchReduceDense(self, num_processes, communication, + required_gpus): + + def replica_fn(): + collective, devices, task_id = self.make_collective( + num_processes, required_gpus, communication) + if task_id != 0: + return + + v = make_per_replica_value(1.0, devices) + hints = collective_util.Hints(timeout_seconds=1) + + @def_function.function + def batch_reduce_dense(): + collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), (v, v)], + hints) + + # The collective should time out because we only launch it on worker-0, + # while there're two workers in total. 
+ with self.assertRaises(errors.DeadlineExceededError): + batch_reduce_dense() + + get_global_mpr(num_processes).run(replica_fn) + + @combinations.generate( + combinations.combine( + num_processes=2, + required_gpus=[0, 1], + communication=[CollectiveCommunication.RING])) + def testTimeoutReduceSparse(self, num_processes, communication, + required_gpus): + + def replica_fn(): + collective, devices, task_id = self.make_collective( + num_processes, required_gpus, communication) + if task_id != 0: + return + + v = make_per_replica_value( + IndexedSlicesValue( + values=[[4., 6.]], indices=[1], dense_shape=[5, 2]), devices) + hints = collective_util.Hints(timeout_seconds=1) + + @def_function.function + def reduce_sparse(): + collective.reduce(reduce_util.ReduceOp.SUM, v, v, hints) + + # The collective should time out because we only launch it on worker-0, + # while there're two workers in total. + with self.assertRaises(errors.DeadlineExceededError): + reduce_sparse() + + get_global_mpr(num_processes).run(replica_fn) + + @combinations.generate( + combinations.combine( + num_processes=2, + required_gpus=[0, 1], + communication=[CollectiveCommunication.RING])) + def testTimeoutBatchReduceSparse(self, num_processes, required_gpus, + communication): + + def replica_fn(): + collective, devices, task_id = self.make_collective( + num_processes, required_gpus, communication) + if task_id != 0: + return + + v = make_per_replica_value( + IndexedSlicesValue( + values=[[4., 6.]], indices=[1], dense_shape=[5, 2]), devices) + hints = collective_util.Hints(timeout_seconds=1) + + @def_function.function + def batch_reduce_sparse(): + collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), (v, v)], + hints) + + # The collective should time out because we only launch it on worker-0, + # while there're two workers in total. + with self.assertRaises(errors.DeadlineExceededError): + batch_reduce_sparse() + + get_global_mpr(num_processes).run(replica_fn) + + if __name__ == "__main__": # Set default inter op thread pool size to one to ensure we don't exhaust the # thread pool with the additional executors to run collectives in eager. diff --git a/tensorflow/python/distribute/cross_device_utils.py b/tensorflow/python/distribute/cross_device_utils.py index cb442f4a53a..3920c2af6b0 100644 --- a/tensorflow/python/distribute/cross_device_utils.py +++ b/tensorflow/python/distribute/cross_device_utils.py @@ -24,7 +24,6 @@ import threading from tensorflow.python.distribute import values as value_lib from tensorflow.python.eager import backprop from tensorflow.python.eager import context -from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -34,7 +33,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nccl_ops from tensorflow.python.platform import tf_logging as logging -OP_INSTANCE_KEY_START_NUMBER = 100 +INSTANCE_KEY_START_NUMBER = 100 def aggregate_gradients_using_nccl(replica_grads): @@ -181,69 +180,66 @@ class CollectiveKeys(object): *Instance key*: an integer key to identify the set of same counterpart of tensors on different devices in a device group that need to be all-reduced. - "Graph key": an integer key that is unique key graph. This is used to support - multiple graphs per client session. It must be non-zero and set in the - `config` argument of each call to `session.run`. - This class is thread safe. 
""" - def __init__(self, - group_key_start=1, - op_instance_key_start=OP_INSTANCE_KEY_START_NUMBER, - variable_instance_key_start=1000000): + def __init__(self, group_key_start=1): """Initializes the object. Args: group_key_start: the starting integer of group key. - op_instance_key_start: the starting integer of instance key for ops. - variable_instance_key_start: the starting integer of instance key for - variables. """ self._group_key = group_key_start self._group_key_table = {} - - assert op_instance_key_start != variable_instance_key_start - self._op_instance_key = op_instance_key_start - self._variable_instance_key = variable_instance_key_start + self._instance_key_table = {} self._lock = threading.Lock() def get_group_key(self, devices): """Returns a group key for the set of devices. Args: - devices: list of strings naming devices in a collective group. + devices: a list of canonical device strings in a collective group. Returns: int key uniquely identifying the set of device names. """ - parsed = [pydev.DeviceSpec.from_string(d) for d in devices] - # In the between-graph replicated training, different workers need to get - # the same device key. So we remove the task_type and task_id from the - # devices. - # TODO(yuefengz): in the in-graph replicated training, we need to include - # task_type and task_id. - names = sorted(['%s:%d' % (d.device_type, d.device_index) for d in parsed]) - key_id = ','.join(names) + key_id = hash(tuple(sorted(devices))) with self._lock: if key_id not in self._group_key_table: new_key = self._group_key self._group_key += 1 self._group_key_table[key_id] = new_key + self._instance_key_table[new_key] = {} + for device in devices: + self._instance_key_table[new_key][device] = INSTANCE_KEY_START_NUMBER return self._group_key_table[key_id] - def get_op_instance_key(self): - """Returns a new instance key for use in defining a collective op.""" - with self._lock: - v = self._op_instance_key - self._op_instance_key += 1 - return v + def get_instance_key(self, group_key, device): + """Returns a new instance key for use in defining a collective op. - def get_variable_instance_key(self): - """Returns a new instance key for use in creating a Variable.""" + You should call this once per each collective op of a collective instance. + + Args: + group_key: the group key returned by get_group_key(). You should not + assign the group key yourself. + device: a canonical device string. It should be the device this collective + op is on. + + Returns: + a new instance key. + + Raises: + ValueError: when the group key is invalid or the device is not in the + group. 
+ """ with self._lock: - v = self._variable_instance_key - self._variable_instance_key += 1 + group = self._instance_key_table.get(group_key, None) + if group is None: + raise ValueError('group {} not found'.format(group_key)) + if device not in group: + raise ValueError('{} not in group {}'.format(device, group_key)) + v = group[device] + group[device] += 1 return v def __deepcopy__(self, memo): @@ -252,135 +248,146 @@ class CollectiveKeys(object): copied = CollectiveKeys() copied._group_key = self._group_key copied._group_key_table = copy.deepcopy(self._group_key_table, memo) - copied._op_instance_key = self._op_instance_key - copied._variable_instance_key = self._variable_instance_key + copied._instance_key_table = copy.deepcopy(self._instance_key_table, memo) return copied -def build_collective_reduce(input_tensors, - devices, - group_size, - collective_keys, - reduction_op='Add', - unary_op='Id', - communication_hint='AUTO', - control_inputs=None, - executors=None, - timeout=None): - """Build a subgraph that does one full all-reduce, using the collective Op. +class CollectiveReplicaLauncher(object): + """Launch collectives on one replica.""" - If called in eager mode, it's required to supply a list of async executors for - each input Tensor. + def __init__(self, + group_key, + group_size, + collective_keys, + device, + executor=None): + if executor and not executor.is_async(): + raise ValueError('executor must be async') + self._group_key = group_key + self._group_size = group_size + self._collective_keys = collective_keys + self._device = device + self._executor = executor - Args: - input_tensors: tensors within a single worker graph that are to be reduced - together; must be one per device. - devices: a list of device strings to run the collective on. - group_size: total number of devices globally that will be doing this same - reduction. The reduction will actually include the corresponding tensors - at all these workers. - collective_keys: a CollectiveKeys object. - reduction_op: string naming the reduction op. - unary_op: string naming the unary final op. - communication_hint: string providing hint to runtime for choosing collective - implementation. - control_inputs: if not None, add control edges between control_inputs and - (index-wise) corresponding collective_reduce tensors - executors: a list of async executor. Required for eager execution. - timeout: a float or None. The timeout in seconds. - - Returns: - An array of final tensors, one per device, computed by the full reduction. - - Raises: - ValueError: There must be at least two tensors over all the workers. 
- """ - if context.executing_eagerly(): - if (not executors or len(executors) != len(input_tensors) or - not all(e.is_async() for e in executors)): - raise ValueError( - 'collectives requires async executors for each device in eager mode') - if len(input_tensors) != len(devices): - raise ValueError('collective requires one input tensor for each device, ' - 'len(input_tensors) = %d, len(devices) = %d' % - (len(input_tensors), len(devices))) - - if group_size < 2: - return input_tensors - group_key = collective_keys.get_group_key(devices) - instance_key = collective_keys.get_op_instance_key() - subdiv_offsets = [0] # TODO(tucker): maybe support non-default subdiv spec - - out_tensors = [] - for idx, input_tensor in enumerate(input_tensors): + def _executor_scope(self): + if context.executing_eagerly() and not self._executor: + raise ValueError('collectives requires a async executor in eager mode') if context.executing_eagerly(): - executor_scope = context.executor_scope(executors[idx]) - else: - executor_scope = ops.NullContextmanager() - with executor_scope, \ - ops.device(devices[idx]), \ - ops.control_dependencies( - _control_input(devices, control_inputs, idx)): - out_tensor = collective_ops.all_reduce( + return context.executor_scope(self._executor) + return ops.NullContextmanager() + + def _control_input(self, control_input): + if control_input is not None: + return ops.control_dependencies([control_input]) + return ops.NullContextmanager() + + def all_reduce(self, + input_tensor, + control_input=None, + communication_hint='AUTO', + timeout=0): + """All-reduce a dense tensor. + + This can be called in eager mode if a async executor is supplied when + creating the launcher. + + Args: + input_tensor: a dense tensor. It must have the same shape on all replicas. + control_input: if not None, add control edges between control_input and + the all-reduce. + communication_hint: string providing hint to runtime for choosing + collective implementation. + timeout: a float. The timeout in seconds. + + Returns: + The reduced tensor. + """ + instance_key = self._collective_keys.get_instance_key( + self._group_key, self._device) + with self._executor_scope(), \ + ops.device(self._device), \ + self._control_input(control_input): + return collective_ops.all_reduce( input_tensor, - group_size, - group_key, + self._group_size, + self._group_key, instance_key, - reduction_op, - unary_op, - subdiv_offsets, - communication_hint, + communication_hint=communication_hint, timeout=timeout) - out_tensors.append(out_tensor) - return out_tensors + def batch_all_reduce(self, + input_tensor_packs, + communication_hint='AUTO', + timeout=0): + """Batch all-reduce dense tensors. -def build_collective_gather(input_tensors, - devices, - group_size, - collective_keys, - axis, - communication_hint='AUTO', - control_inputs=None, - timeout=None): - """Build a subgraph that does one full all-gather, using the collective Op. + This takes a list of batches of tensors. Using multiple batches have the + benefit that it doesn't need to wait for all inputs to be ready to start the + all-reduce. - This method must be called in graph mode or inside a tf.function. + This can be called in eager mode if a async executor is supplied when + creating the launcher. - Args: - input_tensors: tensors within a single worker graph that are to be gathered - together; must be one per device. Input tensors cannot have rank 0. - devices: a list of device strings to run the collective on. 
- group_size: total number of devices globally that will be doing this same - gathering. The gathering will actually include the corresponding tensors - at all these workers. - collective_keys: a CollectiveKeys object. - axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the - range [0, rank(value)). - communication_hint: string providing hint to runtime for choosing collective - implementation. Available options are `AUTO`, `NCCL`, and `RING`. - control_inputs: if not None, add control edges between control_inputs and - (index-wise) corresponding collective_gather tensors - timeout: a float or None. The timeout in seconds. + Args: + input_tensor_packs: a list of lists of dense tensors. + communication_hint: string providing hint to runtime for choosing + collective implementation. + timeout: a float. The timeout in seconds. - Returns: - An array of final tensors, one per device, computed by the full gather. - """ - if len(input_tensors) != len(devices): - raise ValueError( - 'collective requires one input tensor for each device, %d != %d' % - (len(input_tensors), len(devices))) + Returns: + A flat list of reduced tensors. + """ + outputs = [] + for pack in input_tensor_packs: + # By placing all CollectiveReduce ops in a batch under single name scope, + # we ensure they will be picked up by the `ScopedAllocator` grappler + # optimizer and packed into a single all-reduce. + with ops.name_scope('allreduce'): + # TODO(b/169168846): inserts a parallel all_gather to verify packings + # are the same on each replica. + for input_tensor in pack: + if communication_hint == 'NCCL' and outputs: + control_input = outputs[-1] + else: + control_input = None + outputs.append( + self.all_reduce(input_tensor, control_input, communication_hint, + timeout)) + return outputs - if group_size < 2: - return input_tensors - group_key = collective_keys.get_group_key(devices) - instance_key_tensor = collective_keys.get_op_instance_key() - instance_key_shape = collective_keys.get_op_instance_key() + def all_gather(self, + input_tensor, + axis, + communication_hint='AUTO', + timeout=0): + """All-gather a dense tensor. - out_tensors = [] - for idx, input_tensor in enumerate(input_tensors): - with ops.device(devices[idx]), ops.control_dependencies( - _control_input(devices, control_inputs, idx)): + This method must be called inside a tf.function. + + Args: + input_tensor: a dense tensor. It must have the same rank on all replicas, + and dimensions other than `axis` need to be the same as well. + axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the + range [0, rank(value)). + communication_hint: string providing hint to runtime for choosing + collective implementation. Available options are `AUTO`, `NCCL`, and + `RING`. + timeout: a float. The timeout in seconds. + + Returns: + The gathered Tensor. + + Raises: + RuntimeError: if called in eager mode. + """ + if context.executing_eagerly(): + raise RuntimeError('all_gather in eager mode is not supported') + + instance_key_tensor = self._collective_keys.get_instance_key( + self._group_key, self._device) + instance_key_shape = self._collective_keys.get_instance_key( + self._group_key, self._device) + with ops.device(self._device): # 1. Transpose # E.g. Given an input_tensor with shape [2,2,5,1] and axis to gather is 3, # we use perm_pre=[3 0 1 2] to reshape it to [1,2,2,5], which @@ -394,8 +401,8 @@ def build_collective_gather(input_tensors, # 2. 
Pad gathered_shape = collective_ops.all_gather( array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0), - group_size, - group_key, + self._group_size, + self._group_key, instance_key_shape, communication_hint, timeout=timeout) @@ -406,8 +413,8 @@ def build_collective_gather(input_tensors, # 3. Gather gather_padded_out_tensor = collective_ops.all_gather( padded_input_tensor, - group_size, - group_key, + self._group_size, + self._group_key, instance_key_tensor, communication_hint, timeout=timeout) @@ -424,93 +431,56 @@ def build_collective_gather(input_tensors, (math_ops.range(1, axis + 1), [0], math_ops.range(axis + 1, array_ops.rank(input_tensor_t))), axis=0) - out_tensor = array_ops.transpose(out_tensor_t, perm=perm_after) - out_tensors.append(out_tensor) - return out_tensors + return array_ops.transpose(out_tensor_t, perm=perm_after) + def all_reduce_indexed_slices(self, + input_slices, + communication_hint='AUTO', + timeout=0): + """All-reduce an IndexedSlices. -def _pad_util(input_tensor, full_axis_dim): - """Pad the `input_tensor`'s first dimension to be `full_axis_dim`.""" - missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0] - tensor_rank = array_ops.rank(input_tensor) - paddings_axis = [[0, missing_axis_dim]] - paddings = array_ops.concat([ - paddings_axis, - array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32) - ], - axis=0) - padded_input_tensor = array_ops.pad(input_tensor, paddings) - return padded_input_tensor + This method must be called inside a tf.function. + Args: + input_slices: an IndexedSlices. + communication_hint: string providing hint to runtime for choosing + collective implementation. + timeout: a float. The timeout in seconds. -def build_collective_gather_indexed_slices(input_slices_list, - devices, - group_size, - collective_keys, - communication_hint='AUTO', - control_inputs=None, - timeout=None): - """Build a subgraph that all-gathers IndexedSlices using the collective Op. + Returns: + The reduced IndexedSlices. - This method must be called in graph mode or inside a tf.function. + Raises: + RuntimeError: if called in eager mode. + """ + if context.executing_eagerly(): + raise RuntimeError( + 'all_reduce_indexed_slices in eager mode is not supported') - Args: - input_slices_list: a list of IndexedSlices within a single worker graph that - are to be gathered together; must be one per device. - devices: a list of device strings to run the collective on. - group_size: total number of devices globally that will be doing this same - gathering. The gathering will actually include the corresponding tensors - at all these workers. - collective_keys: a CollectiveKeys object. - communication_hint: string providing hint to runtime for choosing collective - implementation. - control_inputs: if not None, add control edges between control_inputs and - (index-wise) corresponding collective_reduce tensors - timeout: a float or None. The timeout in seconds. + gather_length_key = self._collective_keys.get_instance_key( + self._group_key, self._device) + gather_indices_key = self._collective_keys.get_instance_key( + self._group_key, self._device) + gather_values_key = self._collective_keys.get_instance_key( + self._group_key, self._device) + reduce_densified_key = self._collective_keys.get_instance_key( + self._group_key, self._device) - Returns: - An array of final IndexedSlices, one per device, computed by the full - gather. - - Raises: - ValueError: if control_inputs is not None and doesn't match the length and - devices of inputs. 
- """ - assert not context.executing_eagerly(), ( - 'build_collective_gather_indexed_slices can only be called in graph mode' - ' or inside tf.function') - if len(input_slices_list) != len(devices): - raise ValueError( - 'collective requires one input IndexedSlice for each device, %d != %d' % - (len(input_slices_list), len(devices))) - - if group_size < 2: - return input_slices_list - - group_key = collective_keys.get_group_key(devices) - gather_length_key = collective_keys.get_op_instance_key() - gather_indices_key = collective_keys.get_op_instance_key() - gather_values_key = collective_keys.get_op_instance_key() - reduce_densified_key = collective_keys.get_op_instance_key() - - # Current CollectiveAllGather implementations require input IndexedSlices to - # have consistent length across the board, we handle the reduction of - # IndexedSlices as follows: - # 1. Gather the lengths of IndexedSlices from all participants. - # 2. If they have consistent length, apply all_gather. - # 3. Otherwise convert IndexedSlices to dense tensors and apply - # all_reduce. - out_slices_list = [] - for idx, input_slices in enumerate(input_slices_list): - # pylint: disable = cell-var-from-loop - with ops.device(devices[idx]): + # Current CollectiveAllGather implementations require input IndexedSlices to + # have consistent length across the board, we handle the reduction of + # IndexedSlices as follows: + # 1. Gather the lengths of IndexedSlices from all participants. + # 2. If they have consistent length, apply all_gather. + # 3. Otherwise convert IndexedSlices to dense tensors and apply + # all_reduce. + with ops.device(self._device): def all_gather(): """Use all_gather to aggregate `IndexedSlices`.""" all_values = collective_ops.all_gather( input_slices.values, - group_size, - group_key, + self._group_size, + self._group_key, gather_values_key, communication_hint, timeout=timeout) @@ -519,8 +489,8 @@ def build_collective_gather_indexed_slices(input_slices_list, with ops.control_dependencies(control): all_indices = collective_ops.all_gather( input_slices.indices, - group_size, - group_key, + self._group_size, + self._group_key, gather_indices_key, communication_hint, timeout=timeout) @@ -534,8 +504,8 @@ def build_collective_gather_indexed_slices(input_slices_list, densified = ops.convert_to_tensor(input_slices) reduced = collective_ops.all_reduce( densified, - group_size, - group_key, + self._group_size, + self._group_key, reduce_densified_key, 'Add', 'Id', [0], @@ -550,23 +520,18 @@ def build_collective_gather_indexed_slices(input_slices_list, dense_shape=input_slices.dense_shape) length = array_ops.shape(input_slices.indices) - with ops.control_dependencies( - _control_input(input_slices, control_inputs, idx)): - all_lengths = collective_ops.all_gather( - length, - group_size, - group_key, - gather_length_key, - communication_hint, - timeout=timeout) - out_slices = control_flow_ops.cond( + all_lengths = collective_ops.all_gather( + length, + self._group_size, + self._group_key, + gather_length_key, + communication_hint, + timeout=timeout) + return control_flow_ops.cond( math_ops.equal( math_ops.reduce_max(all_lengths), math_ops.reduce_min(all_lengths)), all_gather, densify_and_all_reduce) - out_slices_list.append(out_slices) - # pylint: enable=cell-var-from-loop - return out_slices_list def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n): @@ -653,56 +618,35 @@ def stitch_values(values_and_indices_list): return result -def per_replica_num_elements(per_replica): - """Returns the 
static number of elements of one replica. +def group_by_size(input_tensors, bytes_per_pack): + """Groups `input_tensors` into chunks of `bytes_per_pack`. - Args: - per_replica: A PerReplica of Tensor or IndexedSlices. - - Returns: - Number of elements. None if some replica has a different or unknown shape. - """ - - values = per_replica._values # pylint: disable=protected-access - s0 = values[0].shape - for v in values: - assert not isinstance(v, ops.IndexedSlices) - if v.shape != s0: - return None - return s0.num_elements() - - -def pack_by_size(per_replica_list, bytes_per_pack): - """Packs `per_replica_list` into chunks of `bytes_per_pack`. - - The method preserves the original order of `per_replica_list`. The packing is + The method preserves the original order of `input_tensors`. The grouping is best effort, each pack could have more or less bytes than `bytes_per_pack`. - It only packs values with known shape. Note that, the usage is different from - `cross_device_ops._pack_tensors`, this function is intended to work with the - ScopeAllocator style batching used in `CollectiveAllReduce`. + It only groups values with known shape. Args: - per_replica_list: A list of PerReplica. - bytes_per_pack: Bytes per pack. + input_tensors: a list of Tensor. + bytes_per_pack: an integer. Returns: - A list of packs of PerReplica. All values are packed into one pack if - `bytes_per_pack` is zero or any of the value has unknown shape. + A list of packs of Tensor. All values are grouped into one pack if + `bytes_per_pack` is zero or any of the value has unknown shape. """ if bytes_per_pack == 0: - return [per_replica_list] + return [input_tensors] packs = [] last_pack_size = 0 - for value in per_replica_list: - num_elements = per_replica_num_elements(value) + for value in input_tensors: + num_elements = value.shape.num_elements() if num_elements is None: # Can't pack values with unknown shape. logging.warning( 'not packing values due to the unknown or inconsistent shape of %s', value) - return [per_replica_list] - size = num_elements * value._primary.dtype.size # pylint: disable=protected-access + return [input_tensors] + size = num_elements * value.dtype.size # Try to keep each pack as close to bytes_per_pack as possible, while each # pack is at least bytes_per_pack large. I.E. we err on the side of having # few but large packs. @@ -714,24 +658,15 @@ def pack_by_size(per_replica_list, bytes_per_pack): return packs -def _control_input(devices, control_inputs, idx): - """Returns the `idx`-th item in control_inputs to be used in ops.control_dependencies. - - This is a helper function for building collective ops. - - Args: - devices: a list of device strings the collective run on. - control_inputs: a list or None. - idx: the index into `inputs` and `control_inputs`. - - Returns: - A one item list of the `idx`-th element of `control_inputs`, or an empty - list if `control_inputs` is None. 
- """ - if control_inputs is None: - return [] - if len(control_inputs) != len(devices): - raise ValueError( - 'control_inputs must match the length of the devices, %s != %s' % - (len(control_inputs), len(devices))) - return [control_inputs[idx]] +def _pad_util(input_tensor, full_axis_dim): + """Pad the `input_tensor`'s first dimension to be `full_axis_dim`.""" + missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0] + tensor_rank = array_ops.rank(input_tensor) + paddings_axis = [[0, missing_axis_dim]] + paddings = array_ops.concat([ + paddings_axis, + array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32) + ], + axis=0) + padded_input_tensor = array_ops.pad(input_tensor, paddings) + return padded_input_tensor diff --git a/tensorflow/python/distribute/cross_device_utils_test.py b/tensorflow/python/distribute/cross_device_utils_test.py index 626ec5cfd60..108d5478ce6 100644 --- a/tensorflow/python/distribute/cross_device_utils_test.py +++ b/tensorflow/python/distribute/cross_device_utils_test.py @@ -23,7 +23,6 @@ from absl.testing import parameterized from tensorflow.python.distribute import combinations from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import device_util -from tensorflow.python.distribute import values as value_lib from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -114,11 +113,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): device_util.resolve(destination), device_util.resolve(result.device)) -class PackBySizeTest(test.TestCase): - - def assertShape(self, per_replica, shape): - for v in per_replica._values: # pylint: disable=protected-access - self.assertEqual(v.shape, shape) +class GroupBySizeTest(test.TestCase): def testPreferLargerPack(self): # Each packs except the last one should be equal or larger than @@ -133,49 +128,38 @@ class PackBySizeTest(test.TestCase): # size = 1 * 4 = 4 array_ops.ones([1], dtype=dtypes.int32), ] - per_replica_values = [value_lib.PerReplica([v, v]) for v in values] - packs = cross_device_utils.pack_by_size( - per_replica_values, bytes_per_pack=200) + packs = cross_device_utils.group_by_size(values, bytes_per_pack=200) self.assertLen(packs, 2) self.assertLen(packs[0], 3) - self.assertShape(packs[0][0], [2, 4, 4]) - self.assertShape(packs[0][1], [8]) - self.assertShape(packs[0][2], [10, 10]) + self.assertEqual(packs[0][0].shape, [2, 4, 4]) + self.assertEqual(packs[0][1].shape, [8]) + self.assertEqual(packs[0][2].shape, [10, 10]) self.assertLen(packs[1], 1) - self.assertShape(packs[1][0], [1]) + self.assertEqual(packs[1][0].shape, [1]) def testZeroBytesPerPack(self): values = [ array_ops.ones([1], dtype=dtypes.float32), array_ops.ones([2], dtype=dtypes.float32), ] - per_replica_values = [value_lib.PerReplica([v, v]) for v in values] - packs = cross_device_utils.pack_by_size( - per_replica_values, bytes_per_pack=0) + packs = cross_device_utils.group_by_size(values, bytes_per_pack=0) self.assertLen(packs, 1) self.assertLen(packs[0], 2) - self.assertShape(packs[0][0], [1]) - self.assertShape(packs[0][1], [2]) + self.assertEqual(packs[0][0].shape, [1]) + self.assertEqual(packs[0][1].shape, [2]) def testUnknownShape(self): def create_placeholder(shape, dtype): with ops.Graph().as_default(): return array_ops.placeholder(dtype=dtype, shape=shape) - per_replica_values = [ - value_lib.PerReplica([ - array_ops.ones([10, 10], dtype=dtypes.float32), - array_ops.ones([10, 
10], dtype=dtypes.float32), - ]), - value_lib.PerReplica([ - array_ops.ones([10, 10], dtype=dtypes.float32), - create_placeholder([None, 10], dtype=dtypes.float32), - ]), + values = [ + array_ops.ones([10, 10], dtype=dtypes.float32), + create_placeholder([None, 10], dtype=dtypes.float32), ] - packs = cross_device_utils.pack_by_size( - per_replica_values, bytes_per_pack=1) + packs = cross_device_utils.group_by_size(values, bytes_per_pack=1) self.assertLen(packs, 1) - self.assertEqual(packs[0], per_replica_values) + self.assertEqual(packs[0], values) if __name__ == "__main__": diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 14e2b6f3f02..7cd5eb1c85a 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -681,9 +681,7 @@ class StrategyBase(object): [see the guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops): - * Start by either creating a `tf.data.Dataset` normally or using - `tf.distribute.experimental_make_numpy_dataset` to make a dataset out of - a `numpy` array. + * Start by creating a `tf.data.Dataset` normally. * Use `tf.distribute.Strategy.experimental_distribute_dataset` to convert a `tf.data.Dataset` to something that produces "per-replica" values. If you want to manually specify how the dataset should be partitioned @@ -908,39 +906,6 @@ class StrategyBase(object): return self.extended._make_input_fn_iterator( # pylint: disable=protected-access input_fn, replication_mode=replication_mode) - @deprecation.deprecated( - "2020-09-30", "Please use tf.data.Dataset.from_tensor_slices instead") - def experimental_make_numpy_dataset(self, numpy_input): - """Makes a `tf.data.Dataset` from a numpy array. - - This avoids adding `numpy_input` as a large constant in the graph, - and copies the data to the machine or machines that will be processing - the input. - - Note that you will likely need to use `experimental_distribute_dataset` - with the returned dataset to further distribute it with the strategy. - - Example: - - >>> strategy = tf.distribute.MirroredStrategy() - >>> numpy_input = np.ones([10], dtype=np.float32) - >>> dataset = strategy.experimental_make_numpy_dataset(numpy_input) - >>> dataset - - >>> dataset = dataset.batch(2) - >>> dist_dataset = strategy.experimental_distribute_dataset(dataset) - - Args: - numpy_input: a nest of NumPy input arrays that will be converted into a - dataset. Note that the NumPy arrays are stacked, as that is normal - `tf.data.Dataset` behavior. - - Returns: - A `tf.data.Dataset` representing `numpy_input`. - """ - return self.extended.experimental_make_numpy_dataset( - numpy_input, session=None) - @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only def experimental_run(self, fn, input_iterator=None): """DEPRECATED TF 1.x ONLY.""" @@ -1087,7 +1052,7 @@ class StrategyBase(object): This method can be used for several purposes. First, it allows you to specify your own batching and sharding logic. (In contrast, `tf.distribute.experimental_distribute_dataset` does batching and sharding - for you.)For example, where + for you.) For example, where `experimental_distribute_dataset` is unable to shard the input files, this method might be used to manually shard the dataset (avoiding the slow fallback behavior in `experimental_distribute_dataset`). 
In cases where the @@ -1561,6 +1526,9 @@ class StrategyBase(object): _require_cross_replica_or_default_context_extended(self._extended) dst = device_util.current( ) or self._extended._default_device or "/device:CPU:0" + if isinstance(value, ops.IndexedSlices): + raise NotImplementedError("gather/all_gather does not support " + "IndexedSlices") return self._extended._local_results( self._extended._gather_to(value, dst, axis))[0] @@ -2880,6 +2848,11 @@ class ReplicaContext(object): raise ValueError( "replica_id_in_sync_group can only be an integer, a Tensor or None.") self._replica_id_in_sync_group = replica_id_in_sync_group + # We need this check becaused TPUContext extends from ReplicaContext and + # does not pass a strategy object since it is used by TPUEstimator. + if strategy: + self._local_replica_id = strategy.extended._get_local_replica_id( + replica_id_in_sync_group) self._summary_recording_distribution_strategy = None @doc_controls.do_not_generate_docs @@ -2979,6 +2952,11 @@ class ReplicaContext(object): dtypes.int32, name="replica_id_in_sync_group") + @property + def _replica_id(self): + """This is the local replica id in a given sync group.""" + return self._local_replica_id + @property def strategy(self): """The current `tf.distribute.Strategy` object.""" @@ -3193,7 +3171,7 @@ class ReplicaContext(object): def grad_wrapper(*xs): ys = self.merge_call(batch_all_gather, args=xs) # The gradient of an all-gather is itself an all-gather. - return ys, lambda *dy_s: self.all_gather(dy_s, axis) + return ys, lambda *dy_s: self._all_gather(dy_s, axis) return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) @@ -3346,6 +3324,10 @@ class _DefaultDistributionExtended(StrategyExtendedV1): del reduce_op, destinations, experimental_hints return value + def _gather_to_implementation(self, value, destinations, axis, experimental_hints): + del destinations, axis, experimental_hints + return value + def _update(self, var, fn, args, kwargs, group): # The implementations of _update() and _update_non_slot() are identical # except _update() passes `var` as the first argument to `fn()`. @@ -3400,6 +3382,12 @@ class _DefaultDistributionExtended(StrategyExtendedV1): def should_save_summary(self): return True + def _get_local_replica_id(self, replica_id_in_sync_group): + return replica_id_in_sync_group + + def _get_replica_id_in_sync_group(self, replica_id): + return replica_id + # TODO(priyag): This should inherit from `InputIterator`, once dependency # issues have been resolved. class DefaultInputIterator(object): diff --git a/tensorflow/python/distribute/distribute_lib_test.py b/tensorflow/python/distribute/distribute_lib_test.py index 0fe05b52d6f..04a1c5dc903 100644 --- a/tensorflow/python/distribute/distribute_lib_test.py +++ b/tensorflow/python/distribute/distribute_lib_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from absl.testing import parameterized -import numpy as np from tensorflow.python.autograph.core import converter_testing from tensorflow.python.data.ops import dataset_ops @@ -99,10 +98,6 @@ class _TestExtended(distribute_lib.StrategyExtendedV1): del reduce_op, destinations, experimental_hints return value - def _experimental_make_numpy_dataset(self, numpy_input, session): - del session - return dataset_ops.DatasetV2.from_tensor_slices(numpy_input) - def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, initial_loop_values=None): # TODO(tomhennigan) This is missing many things (e.g. ctx.run_op). 
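
Editor's note: since this patch removes `Strategy.experimental_make_numpy_dataset` (its old docstring example is visible in the deleted lines above), callers need to build the dataset themselves. The following is a minimal migration sketch, not part of the patch itself; it only uses APIs that the deleted docstring and the release note already reference (`tf.data.Dataset.from_tensor_slices`, `experimental_distribute_dataset`).

```python
# Illustrative migration sketch (assumption: eager TF2 with a MirroredStrategy).
import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
numpy_input = np.ones([10], dtype=np.float32)

# Previously: dataset = strategy.experimental_make_numpy_dataset(numpy_input)
# Now build the tf.data.Dataset directly from the NumPy array.
dataset = tf.data.Dataset.from_tensor_slices(numpy_input).batch(2)

# Distribute it exactly as before.
dist_dataset = strategy.experimental_distribute_dataset(dataset)
```
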
@@ -124,6 +119,9 @@ class _TestExtended(distribute_lib.StrategyExtendedV1): else: return nest.map_structure(self._unwrap, result) + def _get_local_replica_id(self, replica_id_in_sync_group): + return replica_id_in_sync_group + def _assert_in_default_state(t): t.assertIs(ds_context._get_default_replica_context(), @@ -161,7 +159,7 @@ class TestStrategyTest(test.TestCase): def run_fn(): replica_context = ds_context.get_replica_context() - self.assertTrue(replica_context is not None) + self.assertIsNotNone(replica_context) self.assertIs(None, ds_context.get_cross_replica_context()) self.assertFalse(ds_context.in_cross_replica_context()) self.assertTrue(ds_context.has_strategy()) @@ -343,13 +341,6 @@ class TestStrategyTest(test.TestCase): self.assertEqual(self.evaluate(x), self.evaluate(x_r)) self.assertEqual(self.evaluate(y), self.evaluate(y_r)) - @_run_in_and_out_of_scope - def testExperimentalMakeNumpyDataset(self, dist): - numpy_input = np.ones([10], dtype=np.float32) - dataset = dist.experimental_make_numpy_dataset(numpy_input) - self.assertEqual( - self.evaluate(dataset.reduce(0., lambda a, b: a + b)), 10.) - @_run_in_and_out_of_scope def testExperimentalRunStepsOnIterator(self, dist): all_inputs = [] diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index dbefff9cf66..d01cedcead0 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -61,7 +61,7 @@ from tensorflow.tools.docs import doc_controls def get_distributed_dataset(dataset, input_workers, strategy, - split_batch_by=None, + num_replicas_in_sync=None, input_context=None): """Returns a distributed dataset from the given tf.data.Dataset instance. @@ -77,8 +77,10 @@ def get_distributed_dataset(dataset, iterators should be created. strategy: a `tf.distribute.Strategy` object, used to run all-reduce to handle last partial batch. - split_batch_by: Optional integer. If present, we "split" each batch of the - dataset by `split_batch_by` value. + num_replicas_in_sync: Optional integer. If this is not None, the value is + used to decide how to rebatch datasets into smaller batches so that + the total batch size for each step (across all workers and replicas) + adds up to `dataset`'s batch size. input_context: `InputContext` for sharding. Only pass this in for between graph multi-worker cases where there is only one `input_worker`. In these cases, we will shard based on the `input_pipeline_id` and @@ -92,14 +94,14 @@ def get_distributed_dataset(dataset, dataset, input_workers, strategy, - split_batch_by=split_batch_by, + num_replicas_in_sync=num_replicas_in_sync, input_context=input_context) else: return DistributedDatasetV1( dataset, input_workers, strategy, - split_batch_by=split_batch_by, + num_replicas_in_sync=num_replicas_in_sync, input_context=input_context) @@ -552,67 +554,41 @@ def _get_next_as_optional(iterator, strategy, return_per_replica=False): return global_has_value, replicas -def _is_statically_shaped(tensor_class, shape): +def _is_statically_shaped(element_spec): """Test if an iterator output is statically shaped. For sparse and ragged tensors this only tests the batch dimension. Args: - tensor_class: a class from an iterator.output_classes list. - shape: a TensorShape from an iterator.output_shapes list. + element_spec: a nest structure of `tf.TypeSpec`. The element spec of the + dataset of the iterator. Returns: True if the shape is static, false otherwise. 
""" - if (tensor_class == sparse_tensor.SparseTensor or - isinstance(tensor_class, ragged_tensor.RaggedTensorSpec)): - # For sparse or ragged tensor, we should only check the first - # dimension in order to get_next_as_optional. This is because - # when these tensors get batched by dataset only the batch dimension - # is set. - if shape.rank > 0 and shape.as_list()[0] is None: - return False - return True - return shape.is_fully_defined() - -def _get_static_shape(iterators): - """Returns a boolean indicating if the input is fully defined.""" - static_shape = True - for iterator in iterators: - if not isinstance(iterator, (_SingleWorkerOwnedDatasetIterator, - _SingleWorkerDatasetIterator)): - continue - flattened = zip(nest.flatten(iterator.output_shapes), - nest.flatten(iterator.output_classes)) - for output_shape, output_class in flattened: - if not _is_statically_shaped(output_class, output_shape): - static_shape = False - break - return static_shape + for spec in nest.flatten(element_spec): + if isinstance( + spec, (sparse_tensor.SparseTensorSpec, ragged_tensor.RaggedTensorSpec)): + # For sparse or ragged tensor, we should only check the first + # dimension in order to get_next_as_optional. This is because + # when these tensors get batched by dataset only the batch dimension + # is set. + if spec.shape.rank > 0 and spec.shape.as_list()[0] is None: + return False + else: + for component in nest.flatten(spec._component_specs): # pylint: disable=protected-access + if not component.shape.is_fully_defined(): + return False + return True class DistributedIteratorBase(DistributedIteratorInterface): """Common implementation for all input iterators.""" # pylint: disable=super-init-not-called - def __init__(self, input_workers, iterators, strategy): - static_shape = _get_static_shape(iterators) - - # TODO(b/133073708): we currently need a flag to control the usage because - # there is a performance difference between get_next() and - # get_next_as_optional(). And we only enable get_next_as_optional when the - # output shapes are not static. - # - # TODO(rxsang): We want to always enable the get_next_as_optional behavior - # when user passed input_fn instead of dataset. 
- if getattr( - strategy.extended, "experimental_enable_get_next_as_optional", False): - self._enable_get_next_as_optional = ( - not static_shape) or strategy.extended._in_multi_worker_mode() - else: - self._enable_get_next_as_optional = False - + def __init__(self, input_workers, iterators, strategy, + enable_get_next_as_optional): assert isinstance(input_workers, InputWorkers) if not input_workers.worker_devices: raise ValueError("Should have at least one worker for input iterator.") @@ -620,6 +596,7 @@ class DistributedIteratorBase(DistributedIteratorInterface): self._iterators = iterators self._input_workers = input_workers self._strategy = strategy + self._enable_get_next_as_optional = enable_get_next_as_optional def next(self): return self.__next__() @@ -753,9 +730,13 @@ class DistributedIteratorV1(DistributedIteratorBase): class DistributedIteratorSpec(type_spec.TypeSpec): """Type specification for `DistributedIterator`.""" - __slots__ = ["_input_workers", "_element_spec", "_strategy"] + __slots__ = [ + "_input_workers", "_element_spec", "_strategy", + "_enable_get_next_as_optional" + ] - def __init__(self, input_workers, element_spec, strategy): + def __init__(self, input_workers, element_spec, strategy, + enable_get_next_as_optional): # We don't want to allow deserialization of this class because we don't # serialize the strategy object. Currently the only places where # _deserialize is called is when we save/restore using SavedModels. @@ -766,6 +747,7 @@ class DistributedIteratorSpec(type_spec.TypeSpec): self._input_workers = input_workers self._element_spec = element_spec self._strategy = strategy + self._enable_get_next_as_optional = enable_get_next_as_optional @property def value_type(self): @@ -806,7 +788,8 @@ class DistributedIteratorSpec(type_spec.TypeSpec): lambda a, b: a.most_specific_compatible_type(b), self._element_spec, other._element_spec) return DistributedIteratorSpec(self._input_workers, element_spec, - self._strategy) + self._strategy, + self._enable_get_next_as_optional) @property def _component_specs(self): @@ -825,32 +808,41 @@ class DistributedIteratorSpec(type_spec.TypeSpec): return value._iterators # pylint: disable=protected-access def _from_components(self, components): - return DistributedIterator(input_workers=self._input_workers, - iterators=None, - components=components, - element_spec=self._element_spec, - strategy=self._strategy) + return DistributedIterator( + input_workers=self._input_workers, + iterators=None, + components=components, + element_spec=self._element_spec, + strategy=self._strategy, + enable_get_next_as_optional=self._enable_get_next_as_optional) @staticmethod def from_value(value): # pylint: disable=protected-access return DistributedIteratorSpec(value._input_workers, value._element_spec, - value._strategy) + value._strategy, + value._enable_get_next_as_optional) def _with_tensor_ranks_only(self): element_spec = nest.map_structure( lambda s: s._with_tensor_ranks_only(), # pylint: disable=protected-access self._element_spec) return DistributedIteratorSpec(self._input_workers, element_spec, - self._strategy) + self._strategy, + self._enable_get_next_as_optional) class DistributedIterator(DistributedIteratorBase, composite_tensor.CompositeTensor): """Input Iterator for a distributed dataset.""" - def __init__(self, input_workers=None, iterators=None, strategy=None, - components=None, element_spec=None): + def __init__(self, + input_workers=None, + iterators=None, + strategy=None, + components=None, + element_spec=None, + 
enable_get_next_as_optional=False): if input_workers is None: raise ValueError("`input_workers` should be " "provided.") @@ -865,20 +857,15 @@ class DistributedIterator(DistributedIteratorBase, self._element_spec = element_spec self._input_workers = input_workers self._iterators = components - static_shape = _get_static_shape(self._iterators) self._strategy = strategy - if getattr(strategy.extended, - "experimental_enable_get_next_as_optional", False): - self._enable_get_next_as_optional = ( - not static_shape) or strategy.extended._in_multi_worker_mode() - else: - self._enable_get_next_as_optional = False + self._enable_get_next_as_optional = enable_get_next_as_optional else: if (components is not None and element_spec is not None): raise ValueError(error_message) - super(DistributedIterator, self).__init__(input_workers, iterators, - strategy) + super(DistributedIterator, + self).__init__(input_workers, iterators, strategy, + enable_get_next_as_optional) @property def element_spec(self): @@ -886,9 +873,9 @@ class DistributedIterator(DistributedIteratorBase, @property def _type_spec(self): - return DistributedIteratorSpec(self._input_workers, - self.element_spec, - self._strategy) + return DistributedIteratorSpec(self._input_workers, self.element_spec, + self._strategy, + self._enable_get_next_as_optional) class _IterableInput(DistributedDatasetInterface): @@ -932,20 +919,24 @@ class DistributedDataset(_IterableInput): dataset, input_workers, strategy, - split_batch_by=None, + num_replicas_in_sync=None, input_context=None): """Distribute the dataset on all workers. - If `split_batch_by` is not None, we "split" each batch of the dataset by - `split_batch_by` value. + If `num_replicas_in_sync` is not None, we split each batch of the dataset + into `num_replicas_in_sync` smaller batches, to be distributed among that + worker's replicas, so that the batch size for a global step (across all + workers and replicas) is as expected. Args: dataset: `tf.data.Dataset` that will be used as the input source. input_workers: an `InputWorkers` object. strategy: a `tf.distribute.Strategy` object, used to run all-reduce to handle last partial batch. - split_batch_by: Optional integer. If present, we "split" each batch of the - dataset by `split_batch_by` value. + num_replicas_in_sync: Optional integer. If this is not None, the value + is used to decide how to rebatch datasets into smaller batches so that + the total batch size for each step (across all workers and replicas) + adds up to `dataset`'s batch size. input_context: `InputContext` for sharding. Only pass this in for between graph multi-worker cases where there is only one `input_worker`. In these cases, we will shard based on the `input_pipeline_id` and @@ -957,36 +948,29 @@ class DistributedDataset(_IterableInput): # different subset of files. If that is not possible, will attempt to shard # the final input such that each worker will run the entire preprocessing # pipeline and only receive its own shard of the dataset. - if split_batch_by: - try: - # pylint: disable=protected-access - with ops.colocate_with(dataset._variant_tensor): - dataset = distribute._LegacyRebatchDataset(dataset, split_batch_by) - # Add a prefetch to pipeline rebatching for performance. - # TODO(rachelim): Instead of inserting an extra prefetch stage here, - # leverage static graph rewrites to insert _RebatchDataset before - # the final `prefetch` if it exists. 
- dataset = dataset.prefetch(split_batch_by) - except errors.InvalidArgumentError as e: - if "without encountering a batch" in str(e): - six.reraise( - ValueError, - ValueError( - "Call the `batch` method on the input Dataset in order to be " - "able to split your input across {} replicas.\n Please " - "the tf.distribute.Strategy guide. {}".format( - split_batch_by, e)), - sys.exc_info()[2]) - else: - raise + + # Additionally, we rebatch the dataset on each worker into + # `num_replicas_in_sync` smaller batches to be distributed among that + # worker's replicas, so that the batch size for a global step (across all + # workers and replicas) adds up to the original dataset's batch size. + if num_replicas_in_sync is not None: + num_workers = input_context.num_input_pipelines if input_context else len( + input_workers.worker_devices) + rebatch_fn = self._make_rebatch_fn(dataset, num_workers, + num_replicas_in_sync) + else: + rebatch_fn = None self._cloned_datasets = [] if input_context: # Between-graph where we rely on the input_context for sharding assert input_workers.num_workers == 1 + if rebatch_fn is not None: + dataset = rebatch_fn(dataset, input_context.input_pipeline_id) dataset = input_ops.auto_shard_dataset(dataset, input_context.num_input_pipelines, - input_context.input_pipeline_id) + input_context.input_pipeline_id, + num_replicas_in_sync) self._cloned_datasets.append(dataset) else: replicated_ds = distribute.replicate(dataset, @@ -995,14 +979,73 @@ class DistributedDataset(_IterableInput): with ops.device(worker): cloned_dataset = replicated_ds[worker] cloned_dataset = cloned_dataset.with_options(dataset.options()) + if rebatch_fn is not None: + cloned_dataset = rebatch_fn(cloned_dataset, i) cloned_dataset = input_ops.auto_shard_dataset( - cloned_dataset, len(input_workers.worker_devices), i) + cloned_dataset, len(input_workers.worker_devices), i, + num_replicas_in_sync) self._cloned_datasets.append(cloned_dataset) self._input_workers = input_workers self._strategy = strategy - self._element_spec = _create_distributed_tensor_spec(self._strategy, - dataset.element_spec) # pylint: disable=protected-access + self._enable_get_next_as_optional = _enable_get_next_as_optional( + self._strategy, dataset.element_spec) + self._element_spec = _create_distributed_tensor_spec( + self._strategy, self._cloned_datasets[0].element_spec) + + def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync): + """Returns a callable that rebatches the input dataset. + + Args: + dataset: A `tf.data.Dataset` representing the dataset to be distributed. + num_workers: An integer representing the number of workers to distribute + `dataset` among. + num_replicas_in_sync: An integer representing the number of replicas in + sync across all workers. + """ + if num_replicas_in_sync % num_workers: + raise ValueError( + "tf.distribute expects every worker to have the same number of " + "replicas. 
However, encountered `num_replicas_in_sync` ({}) that " + "cannot be divided by `num_workers` ({})".format( + num_replicas_in_sync, num_workers)) + + num_replicas_per_worker = num_replicas_in_sync // num_workers + with ops.colocate_with(dataset._variant_tensor): # pylint: disable=protected-access + batch_size = distribute.compute_batch_size(dataset) + + def rebatch_fn(dataset, worker_index): + try: + # pylint: disable=protected-access + def apply_rebatch(): + batch_sizes = distribute.batch_sizes_for_worker( + batch_size, num_workers, num_replicas_per_worker, worker_index) + return distribute._RebatchDataset( + dataset, batch_sizes).prefetch(num_replicas_per_worker) + + def apply_legacy_rebatch(): + return distribute._LegacyRebatchDataset( + dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker) + + with ops.colocate_with(dataset._variant_tensor): + return control_flow_ops.cond( + math_ops.not_equal(batch_size, -1), + true_fn=apply_rebatch, + false_fn=apply_legacy_rebatch) + except errors.InvalidArgumentError as e: + if "without encountering a batch" in str(e): + six.reraise( + ValueError, + ValueError( + "Call the `batch` method on the input Dataset in order to be " + "able to split your input across {} replicas.\n Please see " + "the tf.distribute.Strategy guide. {}".format( + num_replicas_in_sync, e)), + sys.exc_info()[2]) + else: + raise + + return rebatch_fn def __iter__(self): if not (context.executing_eagerly() or @@ -1019,11 +1062,17 @@ class DistributedDataset(_IterableInput): self._input_workers, enable_legacy_iterators) if enable_legacy_iterators: - iterator = DistributedIteratorV1(self._input_workers, worker_iterators, - self._strategy) + iterator = DistributedIteratorV1( + self._input_workers, + worker_iterators, + self._strategy, + enable_get_next_as_optional=self._enable_get_next_as_optional) else: - iterator = DistributedIterator(self._input_workers, worker_iterators, - self._strategy) + iterator = DistributedIterator( + self._input_workers, + worker_iterators, + self._strategy, + enable_get_next_as_optional=self._enable_get_next_as_optional) iterator._element_spec = self.element_spec # pylint: disable=protected-access # When async eager is enabled, sometimes the iterator may not finish @@ -1047,14 +1096,14 @@ class DistributedDatasetV1(DistributedDataset): dataset, input_workers, strategy, - split_batch_by=None, + num_replicas_in_sync=None, input_context=None): self._input_workers = input_workers super(DistributedDatasetV1, self).__init__( dataset, input_workers, strategy, - split_batch_by=split_batch_by, + num_replicas_in_sync=num_replicas_in_sync, input_context=input_context) def make_one_shot_iterator(self): @@ -1104,7 +1153,8 @@ class DistributedDatasetV1(DistributedDataset): self._input_workers, True) iterator = DistributedIteratorV1(self._input_workers, worker_iterators, - self._strategy) + self._strategy, + self._enable_get_next_as_optional) iterator._element_spec = self.element_spec # pylint: disable=protected-access # When async eager is enabled, sometimes the iterator may not finish @@ -1156,6 +1206,8 @@ class DistributedDatasetsFromFunction(_IterableInput): _create_datasets_per_worker_with_input_context(self._input_contexts, self._input_workers, dataset_fn)) + self._enable_get_next_as_optional = _enable_get_next_as_optional( + self._strategy, element_spec) self._element_spec = _create_distributed_tensor_spec( self._strategy, element_spec) @@ -1174,11 +1226,17 @@ class DistributedDatasetsFromFunction(_IterableInput): enable_legacy_iterators) if 
enable_legacy_iterators: - iterator = DistributedIteratorV1(self._input_workers, iterators, - self._strategy) + iterator = DistributedIteratorV1( + self._input_workers, + iterators, + self._strategy, + enable_get_next_as_optional=self._enable_get_next_as_optional) else: - iterator = DistributedIterator(self._input_workers, iterators, - self._strategy) + iterator = DistributedIterator( + self._input_workers, + iterators, + self._strategy, + enable_get_next_as_optional=self._enable_get_next_as_optional) iterator._element_spec = self._element_spec # pylint: disable=protected-access # When async eager is enabled, sometimes the iterator may not finish @@ -1225,7 +1283,8 @@ class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction): iterators = _create_iterators_per_worker(self._datasets, self._input_workers, True) iterator = DistributedIteratorV1(self._input_workers, iterators, - self._strategy) + self._strategy, + self._enable_get_next_as_optional) iterator._element_spec = self._element_spec # pylint: disable=protected-access # When async eager is enabled, sometimes the iterator may not finish @@ -1289,9 +1348,8 @@ class InputFunctionIterator(DistributedIteratorV1): "input_fn must return a tf.data.Dataset or a callable.") iterators.append(iterator) - super(InputFunctionIterator, self).__init__(input_workers, iterators, - strategy) - self._enable_get_next_as_optional = False + super(InputFunctionIterator, self).__init__( + input_workers, iterators, strategy, enable_get_next_as_optional=False) # TODO(anjalisridhar): This class will soon be removed and users should move @@ -1303,20 +1361,24 @@ class DatasetIterator(DistributedIteratorV1): dataset, input_workers, strategy, - split_batch_by=None, + num_replicas_in_sync=None, input_context=None): """Make an iterator for the dataset on given devices. - If `split_batch_by` is not None, we "split" each batch of the - dataset by `split_batch_by` value. + If `num_replicas_in_sync` is not None, we split each batch of the dataset + into `num_replicas_in_sync` smaller batches, to be distributed among that + worker's replicas, so that the batch size for a global step (across all + workers and replicas) is as expected. Args: dataset: `tf.data.Dataset` that will be used as the input source. input_workers: an `InputWorkers` object. strategy: a `tf.distribute.Strategy` object, used to run all-reduce to handle last partial batch. - split_batch_by: Optional integer. If present, we "split" each batch of the - dataset by `split_batch_by` value. + num_replicas_in_sync: Optional integer. If this is not None, the value is + used to decide how to rebatch datasets into smaller batches so that the + total batch size for each step (across all workers and replicas) adds up + to `dataset`'s batch size. input_context: `InputContext` for sharding. Only pass this in for between graph multi-worker cases where there is only one `input_worker`. 
In these cases, we will shard based on the `input_pipeline_id` and @@ -1326,14 +1388,13 @@ class DatasetIterator(DistributedIteratorV1): dataset, input_workers, strategy, - split_batch_by=split_batch_by, + num_replicas_in_sync=num_replicas_in_sync, input_context=input_context) worker_iterators = _create_iterators_per_worker( dist_dataset._cloned_datasets, input_workers, True) # pylint: disable=protected-access - super(DatasetIterator, self).__init__( - input_workers, - worker_iterators, # pylint: disable=protected-access - strategy) + super(DatasetIterator, + self).__init__(input_workers, worker_iterators, strategy, + dist_dataset._enable_get_next_as_optional) # pylint: disable=protected-access self._element_spec = dist_dataset.element_spec @@ -1952,3 +2013,19 @@ def _replace_per_replica_spec(spec, i): return spec._value_specs[i] # pylint: disable=protected-access else: return spec + + +def _enable_get_next_as_optional(strategy, element_spec): + """Returns whether to enable using partial batch handling.""" + # TODO(b/133073708): we currently need a flag to control the usage because + # there is a performance difference between get_next() and + # get_next_as_optional(). And we only enable get_next_as_optional when the + # output shapes are not static. + # + # TODO(rxsang): We want to always enable the get_next_as_optional behavior + # when user passed input_fn instead of dataset. + if not getattr(strategy.extended, "experimental_enable_get_next_as_optional", + False): + return False + return not _is_statically_shaped( + element_spec) or strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py index ec0b591d710..5abd6f483d3 100644 --- a/tensorflow/python/distribute/input_lib_test.py +++ b/tensorflow/python/distribute/input_lib_test.py @@ -35,6 +35,7 @@ from tensorflow.python.distribute import input_lib from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute import test_util from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import test @@ -61,7 +62,7 @@ class DistributedIteratorTestBase(test.TestCase): dataset_or_input_fn, input_workers, devices, - split_batch_by, + num_replicas_in_sync, strategy, input_context=None): # The `input_context` passed in is to shard dataset for @@ -93,7 +94,7 @@ class DistributedIteratorTestBase(test.TestCase): dataset_or_input_fn, input_workers, strategy, - split_batch_by=split_batch_by, + num_replicas_in_sync=num_replicas_in_sync, input_context=input_context) return iterator @@ -101,7 +102,7 @@ class DistributedIteratorTestBase(test.TestCase): input_type, dataset, input_workers, - split_batch_by, + num_replicas_in_sync, strategy, input_context=None): if input_type == "dataset": @@ -110,14 +111,14 @@ class DistributedIteratorTestBase(test.TestCase): dataset, input_workers, strategy, - split_batch_by=split_batch_by, + num_replicas_in_sync=num_replicas_in_sync, input_context=input_context) else: return input_lib.DistributedDatasetV1( dataset, input_workers, strategy, - split_batch_by=split_batch_by, + num_replicas_in_sync=num_replicas_in_sync, input_context=input_context) else: return strategy.distribute_datasets_from_function(dataset) @@ -163,7 +164,7 @@ class DistributedIteratorTestBase(test.TestCase): 
expected_values, strategy, sess=None, - split_batch_by=None, + num_replicas_in_sync=None, input_context=None): if iteration_type == "for_loop" and not context.executing_eagerly(): self.skipTest("unsupported test combination.") @@ -183,7 +184,7 @@ class DistributedIteratorTestBase(test.TestCase): dataset_or_input_fn, input_workers, devices, - split_batch_by, + num_replicas_in_sync, strategy, input_context=input_context) else: @@ -192,7 +193,7 @@ class DistributedIteratorTestBase(test.TestCase): input_type, dataset_or_input_fn, input_workers, - split_batch_by, + num_replicas_in_sync, strategy, input_context=input_context) @@ -361,10 +362,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase, def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution, enable_get_next_as_optional): worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] - if tf2.enabled(): - dataset_fn = lambda _: dataset_ops.DatasetV2.range(10) - else: - dataset_fn = lambda _: dataset_ops.DatasetV1.range(10) + dataset_fn = lambda _: dataset_ops.Dataset.range(10) dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) @@ -419,10 +417,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase, distribution, enable_get_next_as_optional): worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", "/device:CPU:0"])] - if tf2.enabled(): - dataset_fn = lambda _: dataset_ops.DatasetV2.range(10) - else: - dataset_fn = lambda _: dataset_ops.Dataset.range(10) + dataset_fn = lambda _: dataset_ops.Dataset.range(10) dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) @@ -455,10 +450,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase, worker_device_pairs.setdefault(host_device, []) worker_device_pairs[host_device].append(tpu_device) worker_device_pairs = worker_device_pairs.items() - if tf2.enabled(): - dataset_fn = lambda _: dataset_ops.DatasetV2.range(10) - else: - dataset_fn = lambda _: dataset_ops.Dataset.range(10) + dataset_fn = lambda _: dataset_ops.Dataset.range(10) dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) @@ -493,14 +485,10 @@ class DistributedIteratorTest(DistributedIteratorTestBase, def dataset_fn(ctx): del ctx - if tf2.enabled(): - dataset1 = dataset_ops.DatasetV2.range(10) - dataset2 = dataset_ops.DatasetV2.range(10).map(lambda x: x**2) - return dataset_ops.DatasetV2.zip((dataset1, dataset2)) - else: - dataset1 = dataset_ops.Dataset.range(10) - dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) @@ -563,7 +551,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase, worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] input_workers = input_lib.InputWorkers(worker_device_pairs) - dataset = dataset_ops.DatasetV2.range(10) + dataset = dataset_ops.Dataset.range(10) dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers, distribution) @@ -663,7 +651,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase, worker_device_pairs, expected_values, distribution, - split_batch_by=distribution.num_replicas_in_sync, + num_replicas_in_sync=distribution.num_replicas_in_sync, input_context=distribution.extended._make_input_context()) @combinations.generate( @@ -724,7 
+712,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase, worker_device_pairs, expected_values, distribution, - split_batch_by=distribution.num_replicas_in_sync, + num_replicas_in_sync=distribution.num_replicas_in_sync, input_context=distribution.extended._make_input_context()) @combinations.generate( @@ -733,14 +721,14 @@ class DistributedIteratorTest(DistributedIteratorTestBase, input_type=["dataset"], api_type=["wrap_into_iterator", "wrap_into_dataset"], iteration_type=["get_next", "for_loop"], - split_batch_by=[None, 2], + num_replicas_in_sync=[None, 2], distribution=[ strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.central_storage_strategy_with_gpu_and_cpu ], enable_get_next_as_optional=[True, False])) def testBatchSplitting(self, input_type, api_type, iteration_type, - split_batch_by, distribution, + num_replicas_in_sync, distribution, enable_get_next_as_optional): worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", "/device:CPU:0"])] @@ -750,7 +738,8 @@ class DistributedIteratorTest(DistributedIteratorTestBase, input_type, dataset_fn) updated_batch_size = ( - batch_size // split_batch_by if split_batch_by else batch_size) + batch_size // + num_replicas_in_sync if num_replicas_in_sync else batch_size) expected_values = [[range(i, i+updated_batch_size), range(i+updated_batch_size, i+2*updated_batch_size)] for i in range(0, 100, updated_batch_size*2)] @@ -766,7 +755,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase, expected_values, distribution, sess=None, - split_batch_by=split_batch_by) + num_replicas_in_sync=num_replicas_in_sync) @combinations.generate( combinations.combine( @@ -774,13 +763,13 @@ class DistributedIteratorTest(DistributedIteratorTestBase, input_type=["dataset"], api_type=["wrap_into_dataset"], iteration_type=["get_next", "for_loop"], - split_batch_by=[None, 2], + num_replicas_in_sync=[None, 2], distribution=[ strategy_combinations.multi_worker_mirrored_2x2_gpu, ], enable_get_next_as_optional=[True, False])) def testBatchSplittingMultiWorker(self, input_type, api_type, iteration_type, - split_batch_by, distribution, + num_replicas_in_sync, distribution, enable_get_next_as_optional): worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", "/device:GPU:1"])] @@ -796,7 +785,8 @@ class DistributedIteratorTest(DistributedIteratorTestBase, input_type, dataset_fn) updated_batch_size = ( - batch_size // split_batch_by if split_batch_by else batch_size) + batch_size // + num_replicas_in_sync if num_replicas_in_sync else batch_size) expected_values = [ [ # pylint: disable=g-complex-comprehension range(i, i + updated_batch_size), @@ -815,7 +805,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase, expected_values, distribution, sess=None, - split_batch_by=split_batch_by) + num_replicas_in_sync=num_replicas_in_sync) @combinations.generate( combinations.combine( @@ -1249,6 +1239,187 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase, expected_for_sum = 310. self.assertAllEqual(nest.flatten(sums), [expected_for_sum] * 3) + @combinations.generate( + combinations.combine( + mode=["eager"], + input_type=["dataset"], + api_type=["wrap_into_iterator", "wrap_into_dataset"], + iteration_type=["get_next", "for_loop"], + distribution=[ + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + ])) + def testMWMSPartialBatch(self, input_type, api_type, iteration_type, + distribution): + # Test case: 2 workers, 1 replica each. 
+ # This test simulates the sharded behavior when we have two files each with + # 12 elements and a global batch size of 8. When we consider the dataset in + # aggregate (non-distributed), there are 24 elements divided into 3 batches + # of size 8. Hence, the correct distributed behavior is for each replica to + # see sub-batches of size 4, over three steps. + def dataset_fn(ctx): + del ctx + dataset = dataset_ops.Dataset.range(12).batch(8) + + # Set the sharding behavior to OFF for simplicity of test setup; namely, + # `dataset` defines the per-worker dataset and will not be further + # sharded. Each worker will see a dataset that is + # tf.data.Dataset.range(12).batch(8).rebatch(...). + options = dataset_ops.Options() + options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF + dataset = dataset.with_options(options) + return dataset + + dataset = self._create_dataset_or_input_fn(input_type, dataset_fn) + + # Actual devices don't matter in this test as long as there is 1 local + # replica. + worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] + + # Each test runs individually on each worker, so we compare the + # values on each worker. Each worker should rebatch its dataset into + # smaller batches of size 4. + expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9, 10, 11]]] + self._test_input_iteration( + input_type, + api_type, + iteration_type, + dataset, + worker_device_pairs, + expected_values, + distribution, + num_replicas_in_sync=distribution.num_replicas_in_sync, + input_context=distribution.extended._make_input_context()) + + @combinations.generate( + combinations.combine( + mode=["eager"], + input_type=["dataset"], + api_type=["wrap_into_iterator", "wrap_into_dataset"], + iteration_type=["get_next", "for_loop"], + distribution=[ + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + ])) + def testMWMSPartialBatchWithLegacyRebatch(self, input_type, api_type, + iteration_type, distribution): + # Test case: 2 workers, 1 replica each. + # This test simulates the sharded behavior when we have two files each with + # 12 elements and a global batch size of 8. When we consider the dataset in + # aggregate (non-distributed), there are 24 elements divided into 3 batches + # of size 8. Hence, the correct distributed behavior is for each replica to + # see sub-batches of size 4, over three steps. However, when we create a + # DistributedDataset and cannot statically infer the intended global batch + # size (e.g. if the user does not use a batching dataset), each worker will + # rebatch based on the dynamic batch size of the data encountered, even when + # it encounters partial batches. The last per-worker partial batch (size 4) + # ends up being split into two replicas, resulting in 4 steps in total, of + # (global) batch sizes 8, 8, 4, 4. + def dataset_fn(ctx): + del ctx + # The following dataset is equivalent to + # tf.data.Dataset.range(12).batch(8), but does not use a batching dataset. + # This causes DistributedDataset to use LegacyRebatch instead. 
+ batch_sizes = dataset_ops.Dataset.from_tensor_slices([8, 4]) + offsets = dataset_ops.Dataset.from_tensor_slices([0, 8]) + dataset = dataset_ops.Dataset.zip((offsets, batch_sizes)) + + def map_fn(offset, batch_size): + return math_ops.range(offset, offset + batch_size) + + dataset = dataset.map(map_fn) + + # Set the sharding behavior to OFF for simplicity of test setup; namely, + # `dataset` defines the per-worker dataset and will not be further + # sharded. Each worker will see a dataset that is equivalent to + # tf.data.Dataset.range(12).batch(8).rebatch(...). + options = dataset_ops.Options() + options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF + dataset = dataset.with_options(options) + return dataset + + dataset = self._create_dataset_or_input_fn(input_type, dataset_fn) + + # Actual devices don't matter in this test as long as the number of global + # replicas is 2. + worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] + + # Each test runs individually on each worker, so we compare the + # values on each worker. Each worker should rebatch its dataset into + # smaller batches of size 4. + expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9]], [[10, 11]]] + self._test_input_iteration( + input_type, + api_type, + iteration_type, + dataset, + worker_device_pairs, + expected_values, + distribution, + num_replicas_in_sync=distribution.num_replicas_in_sync, + input_context=distribution.extended._make_input_context()) + + @combinations.generate( + combinations.combine( + mode=["eager"], + input_type=["dataset"], + api_type=["wrap_into_iterator", "wrap_into_dataset"], + iteration_type=["get_next", "for_loop"], + distribution=[ + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + ], + auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.DATA])) + def testMWMSWithDataSharding(self, input_type, api_type, iteration_type, + distribution, auto_shard_policy): + # Test case: 2 workers, 1 replica each. + # This test simulates the sharded behavior the dataset is sharded by data + # and the batch size is indivisible by the number of replicas. This checks + # that the elements are as expected and the batch size across all workers + # adds up to 3. This test will only pass if the autoshard rewrite rewrites + # RebatchDatasetV2 to legacy RebatchDataset when sharding by data. + def dataset_fn(ctx): + del ctx + dataset = dataset_ops.Dataset.range(8).batch(3) + + # Set the sharding behavior to OFF for simplicity of test setup; namely, + # `dataset` defines the per-worker dataset and will not be further + # sharded. Each worker will see a dataset that is + # tf.data.Dataset.range(12).batch(8).rebatch(...). + options = dataset_ops.Options() + options.experimental_distribute.auto_shard_policy = auto_shard_policy + dataset = dataset.with_options(options) + return dataset + + dataset = self._create_dataset_or_input_fn(input_type, dataset_fn) + + # Actual devices don't matter in this test as long as there is 1 local + # replica. + worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] + + # Each test runs individually on each worker, so we compare the + # values on each worker. We expect each worker to see different shards of + # data. 
+ cr = distribution.cluster_resolver + worker_id = multi_worker_util.id_in_cluster(cr.cluster_spec(), cr.task_type, + cr.task_id) + + if worker_id == 0: + expected_values = [[[0, 1]], [[3, 4]], [[6]]] + elif worker_id == 1: + expected_values = [[[2]], [[5]], [[7]]] + + self._test_input_iteration( + input_type, + api_type, + iteration_type, + dataset, + worker_device_pairs, + expected_values, + distribution, + num_replicas_in_sync=distribution.num_replicas_in_sync, + input_context=distribution.extended._make_input_context()) + if __name__ == "__main__": - combinations.main() + test_util.main() diff --git a/tensorflow/python/distribute/input_ops.py b/tensorflow/python/distribute/input_ops.py index 37a7ed642d0..de828f4bcd9 100644 --- a/tensorflow/python/distribute/input_ops.py +++ b/tensorflow/python/distribute/input_ops.py @@ -27,7 +27,7 @@ from tensorflow.python.framework import ops # pylint: disable=protected-access -def auto_shard_dataset(dataset, num_shards, index): +def auto_shard_dataset(dataset, num_shards, index, num_replicas_in_sync=None): """Shard the input pipeline by sharding the underlying list of files. Args: @@ -37,6 +37,8 @@ def auto_shard_dataset(dataset, num_shards, index): shards operating in parallel. Same usage as in `tf.data.Dataset.shard`. index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. Same usage as in `tf.data.Dataset.shard`. + num_replicas_in_sync: An integer representing the total number of replicas + across all workers. This is used in the rewrite when sharding by data. Returns: A modified `Dataset` obtained by updating the pipeline sharded by the @@ -45,10 +47,14 @@ def auto_shard_dataset(dataset, num_shards, index): """ if (dataset.options().experimental_distribute.auto_shard_policy != AutoShardPolicy.OFF): + if num_replicas_in_sync is None: + num_replicas_in_sync = 1 if isinstance(dataset, dataset_ops.DatasetV1): - return distribute._AutoShardDatasetV1(dataset, num_shards, index) + return distribute._AutoShardDatasetV1(dataset, num_shards, index, + num_replicas_in_sync) else: - return distribute._AutoShardDataset(dataset, num_shards, index) + return distribute._AutoShardDataset(dataset, num_shards, index, + num_replicas_in_sync) else: return dataset diff --git a/tensorflow/python/distribute/integration_test/BUILD b/tensorflow/python/distribute/integration_test/BUILD index c1849afcb70..d3a71e0136d 100644 --- a/tensorflow/python/distribute/integration_test/BUILD +++ b/tensorflow/python/distribute/integration_test/BUILD @@ -9,7 +9,6 @@ package( distribute_py_test( name = "saved_model_test", srcs = ["saved_model_test.py"], - disable_mlir_bridge = False, deps = [ "//tensorflow:tensorflow_py", "//tensorflow/python:lookup_ops", @@ -18,6 +17,7 @@ distribute_py_test( "//tensorflow/python/distribute:parameter_server_strategy_v2", "//tensorflow/python/distribute:sharded_variable", "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/distribute:test_util", "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", @@ -40,6 +40,7 @@ cuda_py_test( "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:multi_process_runner", "//tensorflow/python/distribute:multi_worker_test_base", + "//tensorflow/python/distribute:test_util", "//tensorflow/python/eager:test", ], ) diff --git a/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py b/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py index 6d822ca1b97..02dee6f6adb 
100644 --- a/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py +++ b/tensorflow/python/distribute/integration_test/mwms_peer_failure_test.py @@ -27,9 +27,9 @@ import os import tensorflow as tf from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib -from tensorflow.python.distribute import combinations from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_worker_test_base +from tensorflow.python.distribute import test_util from tensorflow.python.eager import test @@ -213,4 +213,4 @@ class PeerFailureRecoverTest(test.TestCase): if __name__ == "__main__": - combinations.main() + test_util.main() diff --git a/tensorflow/python/distribute/integration_test/saved_model_test.py b/tensorflow/python/distribute/integration_test/saved_model_test.py index 704279b4b04..b8ac71ef203 100644 --- a/tensorflow/python/distribute/integration_test/saved_model_test.py +++ b/tensorflow/python/distribute/integration_test/saved_model_test.py @@ -39,6 +39,7 @@ from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import parameter_server_strategy_v2 from tensorflow.python.distribute import sharded_variable from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute import test_util from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.eager import test @@ -612,4 +613,4 @@ class PSStrategySaveAndLoadTest(test.TestCase): if __name__ == "__main__": - combinations.main() + test_util.main() diff --git a/tensorflow/python/distribute/mirrored_run.py b/tensorflow/python/distribute/mirrored_run.py index 2cf23e96e67..4f1f48d30cc 100644 --- a/tensorflow/python/distribute/mirrored_run.py +++ b/tensorflow/python/distribute/mirrored_run.py @@ -246,6 +246,9 @@ class _MirroredReplicaThread(threading.Thread): self.distribution = dist self.devices = devices self.replica_id = replica_id + self.replica_id_in_sync_group = ( + dist.extended._get_replica_id_in_sync_group(replica_id)) # pylint: disable=protected-access + self.variable_creator_fn = variable_creator_fn # State needed to run and return the results of `fn`. 
self.main_fn = fn @@ -310,7 +313,8 @@ class _MirroredReplicaThread(threading.Thread): _enter_graph(self.graph, self.in_eager, self._variable_creator_stack), \ context.device_policy(self.context_device_policy), \ - _MirroredReplicaContext(self.distribution, self.replica_id), \ + _MirroredReplicaContext(self.distribution, + self.replica_id_in_sync_group), \ ops.device(self.devices[self.replica_id]), \ ops.name_scope(self._name_scope), \ variable_scope.variable_scope( diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index 523c71c4fb5..658b45ecec6 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -340,7 +340,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): self._input_workers_devices = ( (device_util.canonicalize("/device:CPU:0", devices[0]), devices),) self._inferred_cross_device_ops = None if self._cross_device_ops else ( - cross_device_ops_lib.choose_the_best(devices)) + cross_device_ops_lib.select_cross_device_ops(devices)) self._host_input_device = numpy_dataset.SingleDevice( self._input_workers_devices[0][0]) self._is_multi_worker_training = False @@ -387,7 +387,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): "supported.") self._inferred_cross_device_ops = self._cross_device_ops else: - # TODO(yuefengz): make `choose_the_best` work with device strings + # TODO(yuefengz): make `select_cross_device_ops` work with device strings # containing job names. self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce() @@ -479,7 +479,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): dataset, self._input_workers, self._container_strategy(), - split_batch_by=self._num_replicas_in_sync) + num_replicas_in_sync=self._num_replicas_in_sync) def _make_input_fn_iterator( self, @@ -501,7 +501,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): dataset, self._input_workers_with_options(options), self._container_strategy(), - split_batch_by=self._num_replicas_in_sync) + num_replicas_in_sync=self._num_replicas_in_sync) def _experimental_make_numpy_dataset(self, numpy_input, session): return numpy_dataset.one_host_numpy_dataset( @@ -632,6 +632,17 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): del value # Unused. return self._cross_device_ops or self._inferred_cross_device_ops + def _gather_to_implementation(self, value, destinations, axis, + experimental_hints): + if not isinstance(value, values.DistributedValues): + # ReductionToOneDevice._gather accepts DistributedValues only. 
+ return value + return self._get_cross_device_ops(value)._gather( # pylint: disable=protected-access + value, + destinations=destinations, + axis=axis, + experimental_hints=experimental_hints) + def _reduce_to(self, reduce_op, value, destinations, experimental_hints): if (distribute_utils.is_mirrored(value) and reduce_op == reduce_util.ReduceOp.MEAN): @@ -758,3 +769,9 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): def _in_multi_worker_mode(self): """Whether this strategy indicates working in multi-worker settings.""" return False + + def _get_local_replica_id(self, replica_id_in_sync_group): + return replica_id_in_sync_group + + def _get_replica_id_in_sync_group(self, replica_id): + return replica_id diff --git a/tensorflow/python/distribute/moving_averages_test.py b/tensorflow/python/distribute/moving_averages_test.py index 577a6c1168f..22522ea2389 100644 --- a/tensorflow/python/distribute/moving_averages_test.py +++ b/tensorflow/python/distribute/moving_averages_test.py @@ -23,6 +23,7 @@ from absl.testing import parameterized from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import combinations from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute import test_util from tensorflow.python.distribute import tpu_strategy from tensorflow.python.eager import def_function from tensorflow.python.eager import test @@ -286,4 +287,4 @@ class ExponentialMovingAverageTest(test.TestCase, parameterized.TestCase): if __name__ == "__main__": - combinations.main() + test_util.main() diff --git a/tensorflow/python/distribute/multi_process_runner.py b/tensorflow/python/distribute/multi_process_runner.py index 62b35070502..107aa1a6a48 100644 --- a/tensorflow/python/distribute/multi_process_runner.py +++ b/tensorflow/python/distribute/multi_process_runner.py @@ -39,6 +39,7 @@ from tensorflow.python import tf2 from tensorflow.python.compat import v2_compat from tensorflow.python.distribute import multi_process_lib from tensorflow.python.eager import context +from tensorflow.python.util.tf_export import tf_export multiprocessing = multi_process_lib.multiprocessing @@ -123,49 +124,51 @@ class MultiProcessRunner(object): """ def __init__(self, - proc_func, + fn, cluster_spec, rpc_layer=None, max_run_time=None, grpc_fail_fast=None, - stream_stdout=True, - list_stdout=False, + stream_output=True, + return_output=False, use_dill_for_args=True, daemon=False, dependence_on_chief=True, auto_restart=False, args=None, kwargs=None): - """Creates a multi-process runner. + """Instantiation of a `MultiProcessRunner`. Args: - proc_func: Function to be run on child processes. This will be run on - processes for all task types. - cluster_spec: Dict for cluster spec. The following is an example of - cluster with three workers and two ps's. + fn: Function to be run on child processes. This will be run on processes + for all task types. + cluster_spec: Dict for cluster spec. The utility function + `tf.__internal__.distribute.multi_process_runner.create_cluster_spec` + can be conveniently used to create such dict. The following is an + example of cluster with three workers and two ps's. {"worker": ["worker0.example.com:2222", "worker1.example.com:2222", "worker2.example.com:2222"], "ps": ["ps0.example.com:2222", "ps1.example.com:2222"]} rpc_layer: RPC layer to use. Default value is 'grpc'. - max_run_time: If set, child processes is forced to exit at approximately - this many seconds after `start` is called. 
We achieve this through - `signal.alarm()` api. Note that this is best effort at Python level - since Python signal handler does not get executed when it runs lower - level C/C++ code. So it can be delayed for arbitrarily long time. - If any of the child process is still running when `max_run_time` is up, - they will be force-terminated and a `UnexpectedSubprocessExitError` - may be raised at `join()`. + max_run_time: `None` or integer. If not `None`, child processes are forced + to exit at approximately this many seconds after this utility is called. + We achieve this through `signal.alarm()` api. Note that this is best + effort at Python level since Python signal handler does not get executed + when it runs lower level C/C++ code. So it can be delayed for + arbitrarily long time. If any of the child process is still running when + `max_run_time` is up, they will be force-terminated and an + `UnexpectedSubprocessExitError` may be raised. If `None`, child + processes are not forced to exit. grpc_fail_fast: Whether GRPC connection between processes should fail without retrying. Defaults to None, in which case the environment variable is not explicitly set. - stream_stdout: True if the output/error from the subprocesses should be + stream_output: True if the output/error from the subprocesses should be streamed to be printed in parent process' log. Defaults to True. - list_stdout: True if the output/error from the subprocesses should be - collected to be attached to the resulting `MultiProcessRunnerResult` - returned from `MultiProcessRunner.join()`. If True, the list of stdout - can be retrieved via `MultiProcessRunnerResult.stdout` attribute. + return_output: If True, the output/error from the subprocesses should be + collected to be attached to the resulting namedtuple returned from + `join()`. The list of output can be retrieved via `stdout` attribute. Defaults to False. use_dill_for_args: Whether to use dill to pickle `args` and `kwargs`. dill can pickle more objects, but doesn't work with types in @@ -176,36 +179,37 @@ class MultiProcessRunner(object): exits with a zero exit code. auto_restart: Whether to automatically restart processes that exit with non-zero exit code. - args: Positional arguments to be sent to functions run on processes. - kwargs: Keyword arguments to be sent to functions run on processes. + args: Positional arguments to be sent to `fn` run on subprocesses. + kwargs: Keyword arguments to be sent to `fn` run on subprocesses. Raises: RuntimeError: if `multi_process_runner.test_main()` is not called. ValueError: if there are more than one chief in the `cluster_spec`. """ + assert cluster_spec is not None if 'chief' in cluster_spec and len(cluster_spec['chief']) > 1: raise ValueError('If chief exists in the cluster, there must be at most ' 'one chief. Current `cluster_spec` has {} chiefs.' .format(len(cluster_spec['chief']))) if not multi_process_lib.initialized(): - raise MultiProcessRunnerNotInitializedError( + raise NotInitializedError( '`multi_process_runner` is not initialized. ' - 'Please call `multi_process_runner.test_main()` ' - 'within `if __name__ == \'__main__\':` block ' + 'Please call `tf.__internal__.distribute.multi_process_runner.' 
+ 'test_main()` within `if __name__ == \'__main__\':` block ' 'in your python module to properly initialize ' '`multi_process_runner`.') - if not callable(proc_func): - raise ValueError('proc_func is not a callable') + if not callable(fn): + raise ValueError('fn is not a callable') - self._proc_func = proc_func + self._fn = fn self._cluster_spec = cluster_spec self._rpc_layer = rpc_layer or 'grpc' self._max_run_time = max_run_time self._grpc_fail_fast = grpc_fail_fast - self._stream_stdout = stream_stdout - # TODO(rchao): Revisit list_stdout argument to consider other solution. - self._list_stdout = list_stdout + self._stream_output = stream_output + # TODO(rchao): Revisit return_output argument to consider other solution. + self._return_output = return_output self._dependence_on_chief = dependence_on_chief self._use_dill_for_args = use_dill_for_args self._daemon = daemon @@ -251,18 +255,18 @@ class MultiProcessRunner(object): for line in reader: task_string = '[{}-{}]:'.format(task_type, task_id) formatted_line = '{} {}'.format(task_string.ljust(14), line) - if self._stream_stdout: + if self._stream_output: # TODO(rchao): Use a lock here to ensure the printed lines are not # broken. print(formatted_line, end='', flush=True) - if self._list_stdout: + if self._return_output: self._streaming_queue.put(formatted_line) def _start_subprocess_and_reading_thread(self, task_type, task_id, cluster_spec=None, - proc_func=None, + fn=None, args=None, kwargs=None): """Start a subprocess and a thread the reads lines from the subprocess.""" @@ -287,11 +291,11 @@ class MultiProcessRunner(object): streaming_pipe_w=pipe_w, barrier=self._barrier, ) - if proc_func is None: - proc_func, args, kwargs = self._proc_func, self._args, self._kwargs - # Always use dill to pickle proc_func so that we support more callable + if fn is None: + fn, args, kwargs = self._fn, self._args, self._kwargs + # Always use dill to pickle fn so that we support more callable # types, e.g. lambda. - proc_func = dill.dumps(proc_func, dill.HIGHEST_PROTOCOL) + fn = dill.dumps(fn, dill.HIGHEST_PROTOCOL) if self._use_dill_for_args: args = dill.dumps(args, dill.HIGHEST_PROTOCOL) kwargs = dill.dumps(kwargs, dill.HIGHEST_PROTOCOL) @@ -299,8 +303,7 @@ class MultiProcessRunner(object): p = _Process( test_env=test_env, target=_ProcFunc(), - args=(resources, test_env, proc_func, args, kwargs, - self._use_dill_for_args), + args=(resources, test_env, fn, args, kwargs, self._use_dill_for_args), daemon=self._daemon) p.start() self._processes[(task_type, task_id)] = p @@ -357,7 +360,7 @@ class MultiProcessRunner(object): called: ```python - def proc_func(): + def fn(): # user code to be run import pdb; pdb.set_trace() @@ -368,7 +371,7 @@ class MultiProcessRunner(object): task_id=0) mpr = multi_process_runner.MultiProcessRunner( - proc_func, + fn, multi_worker_test_base.create_cluster_spec( has_chief=True, num_workers=1)) threading.Thread(target=follow_ups).start() @@ -376,7 +379,7 @@ class MultiProcessRunner(object): mpr.join() ``` - Note that if `list_stdout=True`, the logs/stdout by task + Note that if `return_output=True`, the logs/stdout by task run by the main process is not available in result.stdout. 
Args: @@ -396,19 +399,19 @@ class MultiProcessRunner(object): _set_tf_config(as_task_type, as_task_id, self._cluster_spec, self._rpc_layer) - self._proc_func(*self._args, **self._kwargs) + self._fn(*self._args, **self._kwargs) def start_single_process(self, task_type, task_id, cluster_spec=None, - proc_func=None, + fn=None, args=None, kwargs=None): """Starts a single process. This starts a process in the cluster with the task type, task id, and the - process function (`proc_func`). If process function is `None`, the function + process function (`fn`). If process function is `None`, the function provided at `__init__` will be used. If `cluster_spec` is `None`, the cluster spec provided at `__init__` will be used. @@ -422,11 +425,11 @@ class MultiProcessRunner(object): cluster_spec: The cluster spec to be used on the newly started process. If `None`, the cluster spec provided at `__init__` will be used. - proc_func: The process function to be run on the newly started + fn: The process function to be run on the newly started process. If specified, specify `args` and `kwargs` as well. If `None`, the function provided at `__init__` will be used. - args: Optional positional arguments to be supplied in `proc_func`. - kwargs: Optional keyword arguments to be supplied in `proc_func`. + args: Optional positional arguments to be supplied in `fn`. + kwargs: Optional keyword arguments to be supplied in `fn`. """ with self._process_lock: if self._joined: @@ -436,7 +439,7 @@ class MultiProcessRunner(object): task_type, task_id, cluster_spec=cluster_spec, - proc_func=proc_func, + fn=fn, args=args or (), kwargs=kwargs or {}) @@ -570,11 +573,11 @@ class MultiProcessRunner(object): times out. Returns: - A MultiProcessRunnerResult object, which has two attributes, - `return_value` and `stdout`. `return_value` always contains the return - values from the subprocesses. If `list_stdout` argument is True at - `__init__`, `stdout` is available that contains a list of all messages - from subprocesses' stdout and stderr. + A `MultiProcessRunnerResult` object, which has two attributes, + `return_value` and `stdout`. `return_value` always contains a list of + return values from the subprocesses, although the order is not meaningful. + If `return_output` argument is True at `__init__`, `stdout` is available + that contains a list of all messages from subprocesses' stdout and stderr. 
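To make the renamed arguments concrete, the following is a minimal, self-contained sketch (illustrative only, not part of this patch) of how `fn`, `return_output`, and the `join()` result described above fit together. It assumes only helpers that appear elsewhere in this diff: `multi_worker_test_base.create_cluster_spec`, `multi_worker_test_base.get_task_type`, and `multi_process_runner.test_main`.

```python
# Illustrative sketch of the renamed API; not part of the patch itself.
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import test


def fn():
  # Runs once in every subprocess defined by the cluster spec.
  task_type = multi_worker_test_base.get_task_type()
  print('hello from {}'.format(task_type), flush=True)
  return task_type


class JoinResultExampleTest(test.TestCase):

  def test_join_result(self):
    mpr = multi_process_runner.MultiProcessRunner(
        fn,
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        return_output=True)  # formerly `list_stdout=True`
    mpr.start()
    result = mpr.join()
    # `return_value` holds one entry per subprocess; order is not meaningful.
    self.assertCountEqual(result.return_value, ['worker', 'worker'])
    # `stdout` is populated because `return_output=True`; each line carries a
    # `[type-id]:` prefix added by the runner.
    self.assertTrue(
        any('hello from worker' in line for line in result.stdout))


if __name__ == '__main__':
  multi_process_runner.test_main()
```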
Raises: SubprocessTimeoutError: if not all processes report status approximately @@ -788,15 +791,14 @@ class _ProcFunc(object): sys.stderr.close() self._resources.streaming_pipe_w.close() - def __call__(self, resources, test_env, proc_func, args, kwargs, - use_dill_for_args): + def __call__(self, resources, test_env, fn, args, kwargs, use_dill_for_args): """The wrapper function that actually gets run in child process(es).""" global _barrier self._resources = resources _barrier = self._resources.barrier - proc_func = dill.loads(proc_func) + fn = dill.loads(fn) if use_dill_for_args: args = dill.loads(args) kwargs = dill.loads(kwargs) @@ -830,8 +832,8 @@ class _ProcFunc(object): v2_compat.enable_v2_behavior() with self._runtime_mode(test_env.executing_eagerly): - info = _run_contained(test_env.task_type, test_env.task_id, proc_func, - args, kwargs) + info = _run_contained(test_env.task_type, test_env.task_id, fn, args, + kwargs) self._resources.process_status_queue.put(info) # Re-raise the exception in addition to reporting it to the parent @@ -912,14 +914,14 @@ class MultiProcessPoolRunner(object): def _start(self): """Starts the worker pool.""" # We need different arguments for different processes so we're passing a - # no-op proc_func here and use start_single_process instead. + # no-op fn here and use start_single_process instead. if dill is None: raise unittest.SkipTest( 'TODO(b/150264776): Resolve dependency issue in CI') self._runner = MultiProcessRunner( - proc_func=lambda: None, + fn=lambda: None, cluster_spec=self._cluster_spec, use_dill_for_args=False) if self._initializer: @@ -933,20 +935,20 @@ class MultiProcessPoolRunner(object): self._runner.start_single_process( task_type, task_id, - proc_func=_pool_runner_worker, + fn=_pool_runner_worker, args=(task_type, task_id, initializer, conn2)) # In the case MultiProcessPoolRunner is not GC-ed, we register an atexit # callback to shut them down. For example, when there're global # MultiProcessPoolRunner. atexit.register(_shutdown_all_pool_runners) - def run(self, proc_func, args=None, kwargs=None): - """Runs `proc_func` with `args` and `kwargs` on all jobs. + def run(self, fn, args=None, kwargs=None): + """Runs `fn` with `args` and `kwargs` on all jobs. Args: - proc_func: The function to be run. - args: Optional positional arguments to be supplied in `proc_func`. - kwargs: Optional keyword arguments to be supplied in `proc_func`. + fn: The function to be run. + args: Optional positional arguments to be supplied in `fn`. + kwargs: Optional keyword arguments to be supplied in `fn`. Returns: A list of return values. @@ -956,9 +958,9 @@ class MultiProcessPoolRunner(object): if self._runner is None: self._start() - proc_func = dill.dumps(proc_func, dill.HIGHEST_PROTOCOL) + fn = dill.dumps(fn, dill.HIGHEST_PROTOCOL) for conn in self._conn.values(): - conn.send((proc_func, args or [], kwargs or {})) + conn.send((fn, args or [], kwargs or {})) process_statuses = [] for (task_type, task_id), conn in self._conn.items(): @@ -966,7 +968,7 @@ class MultiProcessPoolRunner(object): try: process_statuses.append(conn.recv()) except EOFError: - # This shouldn't happen due to exceptions in proc_func. This usually + # This shouldn't happen due to exceptions in fn. This usually # means bugs in the runner. self.shutdown() raise RuntimeError('Unexpected EOF. Worker process may have died. 
' @@ -1002,11 +1004,11 @@ def _pool_runner_worker(task_type, task_id, initializer, conn): initializer() while True: try: - proc_func, args, kwargs = conn.recv() + fn, args, kwargs = conn.recv() except EOFError: break - proc_func = dill.loads(proc_func) - info = _run_contained(task_type, task_id, proc_func, args, kwargs) + fn = dill.loads(fn) + info = _run_contained(task_type, task_id, fn, args, kwargs) sys.stdout.flush() sys.stderr.flush() conn.send(info) @@ -1018,8 +1020,8 @@ def _pool_runner_worker(task_type, task_id, initializer, conn): _shutdown_all_pool_runners() -def _run_contained(task_type, task_id, proc_func, args, kwargs): - """Runs `proc_func` with `args` and `kwargs`. +def _run_contained(task_type, task_id, fn, args, kwargs): + """Runs `fn` with `args` and `kwargs`. The function returns _ProcessStatusInfo which captures the return value and the exception. @@ -1027,9 +1029,9 @@ def _run_contained(task_type, task_id, proc_func, args, kwargs): Args: task_type: the task type. task_id: the task index. - proc_func: the function to be run. - args: optional positional arguments to be supplied in `proc_func`. - kwargs: optional keyword arguments to be supplied in `proc_func`. + fn: the function to be run. + args: optional positional arguments to be supplied in `fn`. + kwargs: optional keyword arguments to be supplied in `fn`. Returns: a _ProcessStatusInfo. @@ -1039,7 +1041,7 @@ def _run_contained(task_type, task_id, proc_func, args, kwargs): return_value = None exc_info = None try: - return_value = proc_func(*args, **kwargs) + return_value = fn(*args, **kwargs) is_successful = True return _ProcessStatusInfo( task_type=task_type, @@ -1048,7 +1050,7 @@ def _run_contained(task_type, task_id, proc_func, args, kwargs): exc_info=exc_info, return_value=return_value) - # If `proc_func` ends up exiting with `sys.exit()`, the `SystemExit` is not + # If `fn` ends up exiting with `sys.exit()`, the `SystemExit` is not # handled here. except Exception: # pylint: disable=broad-except exc_info = sys.exc_info() @@ -1060,12 +1062,17 @@ def _run_contained(task_type, task_id, proc_func, args, kwargs): return_value=return_value) +@tf_export('__internal__.distribute.multi_process_runner' + '.SubprocessTimeoutError', + v1=[]) class SubprocessTimeoutError(RuntimeError): """An error that indicates there is at least one subprocess timing out. - When this is raised, a `MultiProcessRunnerResult` object can be retrieved by - `SubprocessTimeoutError`'s mpr_result attribute. See - `MultiProcessRunner.join()` for more information. + When this is raised, a namedtuple object representing the multi-process run + result can be retrieved by + `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError`'s + `mpr_result` attribute. See + `tf.__internal__.distribute.multi_process_runner.run` for more information. """ def __init__(self, msg, mpr_result): @@ -1073,12 +1080,18 @@ class SubprocessTimeoutError(RuntimeError): self.mpr_result = mpr_result +@tf_export('__internal__.distribute.multi_process_runner' + '.UnexpectedSubprocessExitError', + v1=[]) class UnexpectedSubprocessExitError(RuntimeError): """An error indicating there is at least one subprocess with unexpected exit. - When this is raised, a `MultiProcessRunnerResult` object can be retrieved by - `UnexpectedSubprocessExitError`'s mpr_result attribute. See - `MultiProcessRunner.join()` for more information. 
+ When this is raised, a namedtuple object representing the multi-process run + result can be retrieved by + `tf.__internal__.distribute.multi_process_runner + .UnexpectedSubprocessExitError`'s + `mpr_result` attribute. See + `tf.__internal__.distribute.multi_process_runner.run` for more information. """ def __init__(self, msg, mpr_result): @@ -1086,12 +1099,15 @@ class UnexpectedSubprocessExitError(RuntimeError): self.mpr_result = mpr_result -class MultiProcessRunnerNotInitializedError(RuntimeError): - """An error indicating `MultiProcessRunner` is used without initialization. +@tf_export( + '__internal__.distribute.multi_process_runner.NotInitializedError', v1=[]) +class NotInitializedError(RuntimeError): + """An error indicating `multi_process_runner.run` is used without init. When this is raised, user is supposed to call - `multi_process_runner.test_main()` within `if __name__ == '__main__':` block - to properly initialize `multi_process_runner`. + `tf.__internal__.distribute.multi_process_runner.test_main()` within + `if __name__ == '__main__':` block to properly initialize + `multi_process_runner.run`. """ pass @@ -1110,33 +1126,177 @@ def _set_tf_config(task_type, task_id, cluster_spec, rpc_layer=None): os.environ['TF_CONFIG'] = json.dumps(tf_config_dict) -def run(proc_func, +@tf_export('__internal__.distribute.multi_process_runner.run', v1=[]) +def run(fn, cluster_spec, rpc_layer=None, max_run_time=None, - grpc_fail_fast=None, - stream_stdout=True, - list_stdout=False, + return_output=False, timeout=_DEFAULT_TIMEOUT_SEC, args=None, - kwargs=None): # pylint: disable=g-doc-args - """Runs functions in local child processes. + kwargs=None): + """Run `fn` in multiple processes according to `cluster_spec`. - It is a convenience method that creates a `MultiProcessRunner` object and - invokes `start` and `join` method. Please see these methods for detailed - documentations. + Given a callable `fn`, `tf.__internal__.distribute.multi_process_runner.run` + launches multiple processes, each of which runs `fn`. These processes are + referred to as "subprocesses" or "child processes". Each of those subprocesses + will have its `TF_CONFIG` environment variable set, according to + `cluster_spec` and its task type. The stdout of the subprocesses is + streamed to the main process and is thus available in its logs (if `stream_output` + is True), with a [type-id] prefix. + + `tf.__internal__.distribute.multi_process_runner.run` will block until all + subprocesses have successfully exited, and return a namedtuple object that + represents the run result. This object has a `return_value` attribute, which + is a list containing each subprocess `fn`'s return value, for those + subprocesses that successfully returned from `fn`. The order of the `return_value` + list is not meaningful. If the optional arg `return_output` (defaulting to False) + is set to True, the namedtuple object will have an additional attribute + `stdout`, which is a list containing the stdout of the subprocesses. If any + subprocess' `fn` ends up raising an error, that error will be reraised from + `tf.__internal__.distribute.multi_process_runner.run`, and the aforementioned + namedtuple object will be available through the exception's + `mpr_result` attribute. + + This utility is used for simulating running TensorFlow programs across + multiple task types, and each task type may contain more than one task + (except for "chief", where more than one task is prohibited).
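For readers unfamiliar with the `TF_CONFIG` mechanism mentioned above, the sketch below shows the general shape of the environment variable each subprocess would see. This is illustrative only and not taken from the patch: the port numbers are made up, and including an `rpc_layer` key is an assumption about how the RPC layer is conveyed, not something shown in this diff.

```python
# Hypothetical TF_CONFIG for the 'worker' task with index 0; ports are made up.
import json
import os

tf_config = {
    'cluster': {
        'chief': ['localhost:23381'],
        'worker': ['localhost:19197', 'localhost:22903'],
    },
    'task': {'type': 'worker', 'index': 0},
    # Assumption: present only when an RPC layer such as 'grpc' is specified.
    'rpc_layer': 'grpc',
}
os.environ['TF_CONFIG'] = json.dumps(tf_config)
```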
Test coverage of + multi-worker training is the main application of this utility, where code + written for multi-worker training can be realistically covered in unit tests. + + Any test module that uses + `tf.__internal__.distribute.multi_process_runner.run()` must call + `tf.__internal__.distribute.multi_process_runner.test_main()` instead of + regular `test.main()` inside `if __name__ == '__main__':` block for proper + initialization. + + Args: + fn: Function to be run on child processes. This will be run on processes for + all task types. + cluster_spec: Dict for cluster spec. The utility function + `tf.__internal__.distribute.multi_process_runner.create_cluster_spec` can + be conveniently used to create such a dict. The following is an example of + a cluster with three workers and two ps's. + {"worker": ["worker0.example.com:2222", + "worker1.example.com:2222", + "worker2.example.com:2222"], + "ps": ["ps0.example.com:2222", + "ps1.example.com:2222"]} + rpc_layer: RPC layer to use. Default value is 'grpc'. + max_run_time: `None` or integer. If not `None`, child processes are forced + to exit at approximately this many seconds after this utility is called. + We achieve this through the `signal.alarm()` api. Note that this is best + effort at the Python level since the Python signal handler does not get executed + while lower level C/C++ code is running. So it can be delayed for an arbitrarily + long time. If any of the child processes are still running when + `max_run_time` is up, they will be force-terminated and a + `tf.__internal__.distribute.multi_process_runner + .UnexpectedSubprocessExitError` + may be raised. If `None`, child processes are not forced to exit. + return_output: If True, the output/error from the subprocesses should be + collected to be attached to the resulting namedtuple returned from this + utility. The list of output can be retrieved via the `stdout` attribute. + Defaults to False. + timeout: optional integer or `None`. If provided as an integer, and not all + processes report status within roughly `timeout` seconds, a + `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError` + exception will be raised. If `None`, + `tf.__internal__.distribute.multi_process_runner.run` never times out. + Defaults to the constant `_DEFAULT_TIMEOUT_SEC` defined in the + `multi_process_runner` module. + args: Positional arguments to be sent to `fn` run on subprocesses. + kwargs: Keyword arguments to be sent to `fn` run on subprocesses. Returns: - A MultiProcessRunnerResult object returned from `MultiProcessRunner.join()`. + A namedtuple object, which has two attributes, + `return_value` and `stdout`. `return_value` always contains a list of + return values from the subprocesses, although the order is not meaningful. + If the `return_output` argument is True, `stdout` is available and contains a + list of all messages from subprocesses' stdout and stderr, and the order + is mostly chronological. + + Raises: + RuntimeError: if + `tf.__internal__.distribute.multi_process_runner.test_main()` is + not called in the test's `if __name__ == '__main__':` block. + ValueError: if there is more than one chief in the `cluster_spec`. + tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError: if + not all processes report status approximately + within `timeout` seconds. When this is raised, a + namedtuple object can be retrieved by + `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError`'s + `mpr_result` attribute, which has the same + structure as the 'Returns' section above describes.
+ tf.__internal__.distribute.multi_process_runner + .UnexpectedSubprocessExitError: + If any of the subprocesses did not exit + properly (for example, they exit on SIGTERM or SIGKILL signal). When + this is raised, a namedtuple object can be retrieved by + `tf.__internal__.distribute.multi_process_runner + .UnexpectedSubprocessExitError`'s + `mpr_result` attribute, which has the + same structure as above 'Returns' section describes. If `max_run_time` + is not `None`, it is expected that some subprocesses may be + force-killed when `max_run_time` is up, and this is raised in those + cases. + Exception: if there is an Exception propagated from any subprocess. When + this is raised, a namedtuple object can be retrieved by + `tf.__internal__.distribute.multi_process_runner + .UnexpectedSubprocessExitError` + `mpr_result` attribute, which has the + same structure as above 'Returns' section describes. + + Examples: + + ```python + class SimpleMultiProcessTest(tf.test.TestCase): + + def test_simple_printing_and_return(self): + + def fn(): + resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver() + + # This will print "[chief-0]: Task type: chief , task id: 0" + # for chief, for example. + logging.info('Task type: %s, task id: %d', + resolver.task_type, resolver.task_id) + + return resolver.task_type + + result = tf.__internal__.distribute.multi_process_runner.run( + fn=fn, + cluster_spec=( + tf.__internal__ + .distribute.multi_process_runner.create_cluster_spec( + has_chief=True, num_workers=2))) + assert sorted(result.return_value) == ['chief', 'worker', 'worker'] + + def test_error_from_fn(self): + + def fn(): + resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver() + raise ValueError('Task type {}, task id {} is errors out'.format( + resolver.task_type, resolver.task_id)) + + with self.assertRaisesRegexp(ValueError, + 'Task type worker, task id 0 is errors out'): + cluster_spec = ( + tf.__internal__.distribute.multi_process_runner.create_cluster_spec( + num_workers=1)) + tf.__internal__.distribute.multi_process_runner.run( + fn=fn, cluster_spec=cluster_spec) + + + if __name__ == '__main__': + tf.__internal__.distribute.multi_process_runner.test_main() + ``` """ runner = MultiProcessRunner( - proc_func, + fn, cluster_spec, rpc_layer, max_run_time=max_run_time, - grpc_fail_fast=grpc_fail_fast, - stream_stdout=stream_stdout, - list_stdout=list_stdout, + return_output=return_output, args=args, kwargs=kwargs) runner.start() @@ -1147,11 +1307,49 @@ def run(proc_func, _barrier = None -def barrier(): +@tf_export('__internal__.distribute.multi_process_runner.get_barrier', v1=[]) +def get_barrier(): + """Returns a `multiprocessing.Barrier` for `multi_process_runner.run`. + + `tf.__internal__.distribute.multi_process_runner.get_barrier()` returns + a `multiprocessing.Barrier` object which can be used within `fn` of + `tf.__internal__.distribute.multi_process_runner` to wait with + `barrier.wait()` call until all other tasks have also reached the + `barrier.wait()` call, before they can proceed individually. + + Note that all tasks (subprocesses) have to reach `barrier.wait()` call to + proceed. Currently it is not supported to block on only a subset of tasks + in the cluster. 
+ + Example: + ```python + + def fn(): + some_work_to_be_done_by_all_tasks() + + tf.__internal__.distribute.multi_process_runner.get_barrier().wait() + + # The barrier guarantees that at this point, all tasks have finished + # `some_work_to_be_done_by_all_tasks()` + some_other_work_to_be_done_by_all_tasks() + + result = tf.__internal__.distribute.multi_process_runner.run( + fn=fn, + cluster_spec=( + tf.__internal__ + .distribute.multi_process_runner.create_cluster_spec( + num_workers=2))) + ``` + + + Returns: + A `multiprocessing.Barrier` for `multi_process_runner.run`. + """ if _barrier is None: raise ValueError( - 'barrier is not defined. It is likely because you are calling barrier()' - 'in the main process. barrier() can only be called in the subprocesses.' + 'barrier is not defined. It is likely because you are calling ' + 'get_barrier() in the main process. get_barrier() can only be called ' + 'in the subprocesses.' ) return _barrier @@ -1170,7 +1368,7 @@ def manager(): ```python manager = multi_process_runner.manager() some_event_happening_in_subprocess = manager.Event() - mpr = multi_process_runner.MultiProcessRunner(proc_func, cluster_spec, + mpr = multi_process_runner.MultiProcessRunner(fn, cluster_spec, args=(some_event_happening_in_subprocess,)) mpr.start() some_event_happening_in_subprocess.wait() @@ -1191,6 +1389,27 @@ def manager(): return _manager +@tf_export('__internal__.distribute.multi_process_runner.test_main', v1=[]) def test_main(): - """Main function to be called within `__main__` of a test file.""" + """Main function to be called within `__main__` of a test file. + + Any test module that uses + `tf.__internal__.distribute.multi_process_runner.run()` + must call this instead of regular `test.main()` inside + `if __name__ == '__main__':` block, or an error will be raised when + `tf.__internal__.distribute.multi_process_runner.run()` is used. This method + takes + care of needed initialization for launching multiple subprocesses. + + Example: + ```python + class MyTestClass(tf.test.TestCase): + def testSomething(self): + # Testing code making use of + # `tf.__internal__.distribute.multi_process_runner.run()`. 
+ + if __name__ == '__main__': + tf.__internal__.distribute.multi_process_runner.test_main() + ``` + """ multi_process_lib.test_main() diff --git a/tensorflow/python/distribute/multi_process_runner_no_init_test.py b/tensorflow/python/distribute/multi_process_runner_no_init_test.py index 9276555c26b..dad05fc386f 100644 --- a/tensorflow/python/distribute/multi_process_runner_no_init_test.py +++ b/tensorflow/python/distribute/multi_process_runner_no_init_test.py @@ -30,9 +30,8 @@ class MultiProcessRunnerNoInitTest(test.TestCase): def simple_func(): return 'foobar' - with self.assertRaisesRegex( - multi_process_runner.MultiProcessRunnerNotInitializedError, - '`multi_process_runner` is not initialized.'): + with self.assertRaisesRegex(multi_process_runner.NotInitializedError, + '`multi_process_runner` is not initialized.'): multi_process_runner.run( simple_func, multi_worker_test_base.create_cluster_spec(num_workers=1)) diff --git a/tensorflow/python/distribute/multi_process_runner_test.py b/tensorflow/python/distribute/multi_process_runner_test.py index 2a0c0c89676..c164bd490e3 100644 --- a/tensorflow/python/distribute/multi_process_runner_test.py +++ b/tensorflow/python/distribute/multi_process_runner_test.py @@ -33,38 +33,38 @@ from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.eager import test -def proc_func_that_adds_task_type_in_return_data(): +def fn_that_adds_task_type_in_return_data(): return multi_worker_test_base.get_task_type() -def proc_func_that_errors(): +def fn_that_errors(): raise ValueError('This is an error.') -def proc_func_that_does_nothing(): +def fn_that_does_nothing(): pass -def proc_func_that_adds_simple_return_data(): +def fn_that_adds_simple_return_data(): return 'dummy_data' -def proc_func_that_returns_args_and_kwargs(*args, **kwargs): +def fn_that_returns_args_and_kwargs(*args, **kwargs): return list(args) + list(kwargs.items()) -def proc_func_with_barrier(): - return multi_process_runner.barrier() +def fn_with_barrier(): + return multi_process_runner.get_barrier() -def proc_func_that_returns_pid(): +def fn_that_returns_pid(): return os.getpid() V = None -def proc_func_that_sets_global(val): +def fn_that_sets_global(val): global V old_val = V V = val @@ -79,21 +79,21 @@ class MultiProcessRunnerTest(test.TestCase): def test_multi_process_runner(self): mpr_result = multi_process_runner.run( - proc_func_that_adds_task_type_in_return_data, + fn_that_adds_task_type_in_return_data, multi_worker_test_base.create_cluster_spec( - num_workers=2, num_ps=3, has_eval=1)) + num_workers=2, num_ps=3, has_chief=True)) - job_count_dict = {'worker': 2, 'ps': 3, 'evaluator': 1} + job_count_dict = {'worker': 2, 'ps': 3, 'chief': 1} for data in mpr_result.return_value: job_count_dict[data] -= 1 self.assertEqual(job_count_dict['worker'], 0) self.assertEqual(job_count_dict['ps'], 0) - self.assertEqual(job_count_dict['evaluator'], 0) + self.assertEqual(job_count_dict['chief'], 0) def test_multi_process_runner_error_propagates_from_subprocesses(self): runner = multi_process_runner.MultiProcessRunner( - proc_func_that_errors, + fn_that_errors, multi_worker_test_base.create_cluster_spec(num_workers=1, num_ps=1), max_run_time=20) runner.start() @@ -102,18 +102,18 @@ class MultiProcessRunnerTest(test.TestCase): def test_multi_process_runner_queue_emptied_between_runs(self): cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2) - return_value = multi_process_runner.run( - proc_func_that_adds_simple_return_data, cluster_spec).return_value + 
return_value = multi_process_runner.run(fn_that_adds_simple_return_data, + cluster_spec).return_value self.assertTrue(return_value) self.assertEqual(return_value[0], 'dummy_data') self.assertEqual(return_value[1], 'dummy_data') - return_value = multi_process_runner.run(proc_func_that_does_nothing, + return_value = multi_process_runner.run(fn_that_does_nothing, cluster_spec).return_value self.assertFalse(return_value) def test_multi_process_runner_args_passed_correctly(self): return_value = multi_process_runner.run( - proc_func_that_returns_args_and_kwargs, + fn_that_returns_args_and_kwargs, multi_worker_test_base.create_cluster_spec(num_workers=1), args=('a', 'b'), kwargs={ @@ -132,7 +132,7 @@ class MultiProcessRunnerTest(test.TestCase): mpr_result = multi_process_runner.run( simple_print_func, multi_worker_test_base.create_cluster_spec(num_workers=2), - list_stdout=True) + return_output=True) std_stream_results = mpr_result.stdout return_value = mpr_result.return_value self.assertIn('[worker-0]: This is something printed.\n', @@ -143,16 +143,16 @@ class MultiProcessRunnerTest(test.TestCase): def test_termination(self): - def proc_func(): + def fn(): for i in range(0, 10): print( 'index {}, iteration {}'.format(self._worker_idx(), i), flush=True) time.sleep(5) mpr = multi_process_runner.MultiProcessRunner( - proc_func, + fn, multi_worker_test_base.create_cluster_spec(num_workers=2), - list_stdout=True) + return_output=True) mpr.start() time.sleep(5) mpr.terminate('worker', 0) @@ -169,16 +169,16 @@ class MultiProcessRunnerTest(test.TestCase): def test_termination_and_start_single_process(self): - def proc_func(): + def fn(): for i in range(0, 10): print( 'index {}, iteration {}'.format(self._worker_idx(), i), flush=True) time.sleep(1) mpr = multi_process_runner.MultiProcessRunner( - proc_func, + fn, multi_worker_test_base.create_cluster_spec(num_workers=2), - list_stdout=True) + return_output=True) mpr.start() time.sleep(3) mpr.terminate('worker', 0) @@ -196,7 +196,7 @@ class MultiProcessRunnerTest(test.TestCase): def test_streaming(self): - def proc_func(): + def fn(): for i in range(5): logging.info('(logging) %s-%d, i: %d', multi_worker_test_base.get_task_type(), self._worker_idx(), @@ -208,10 +208,10 @@ class MultiProcessRunnerTest(test.TestCase): time.sleep(1) mpr = multi_process_runner.MultiProcessRunner( - proc_func, + fn, multi_worker_test_base.create_cluster_spec( - has_chief=True, num_workers=2, num_ps=2, has_eval=True), - list_stdout=True) + has_chief=True, num_workers=2, num_ps=2), + return_output=True) mpr._dependence_on_chief = False mpr.start() @@ -221,7 +221,7 @@ class MultiProcessRunnerTest(test.TestCase): list_to_assert = mpr_result.stdout - for job in ['chief', 'evaluator']: + for job in ['chief']: for iteration in range(5): self.assertTrue( any('(logging) {}-0, i: {}'.format(job, iteration) in line @@ -249,17 +249,17 @@ class MultiProcessRunnerTest(test.TestCase): def test_start_in_process_as(self): - def proc_func(): + def fn(): for i in range(5): logging.info('%s-%d, i: %d', multi_worker_test_base.get_task_type(), self._worker_idx(), i) time.sleep(1) mpr = multi_process_runner.MultiProcessRunner( - proc_func, + fn, multi_worker_test_base.create_cluster_spec( has_chief=True, num_workers=1), - list_stdout=True) + return_output=True) def eval_func(): time.sleep(1) @@ -278,9 +278,9 @@ class MultiProcessRunnerTest(test.TestCase): def test_terminate_all_does_not_ignore_error(self): mpr = multi_process_runner.MultiProcessRunner( - proc_func_that_errors, + fn_that_errors, 
multi_worker_test_base.create_cluster_spec(num_workers=2), - list_stdout=True) + return_output=True) mpr.start() time.sleep(60) mpr.terminate_all() @@ -289,26 +289,26 @@ class MultiProcessRunnerTest(test.TestCase): def test_barrier(self): multi_process_runner.run( - proc_func_with_barrier, + fn_with_barrier, cluster_spec=multi_worker_test_base.create_cluster_spec( has_chief=True, num_workers=1), ) def test_barrier_called_in_main_process(self): with self.assertRaises(ValueError): - multi_process_runner.barrier() + multi_process_runner.get_barrier() def test_stdout_available_when_timeout(self): - def proc_func(): + def fn(): logging.info('something printed') time.sleep(10000) # Intentionally make the test timeout. with self.assertRaises(multi_process_runner.SubprocessTimeoutError) as cm: mpr = multi_process_runner.MultiProcessRunner( - proc_func, + fn, multi_worker_test_base.create_cluster_spec(num_workers=1), - list_stdout=True) + return_output=True) mpr.start() mpr.join(timeout=60) mpr.terminate_all() @@ -319,15 +319,15 @@ class MultiProcessRunnerTest(test.TestCase): def test_seg_fault_raises_error(self): - def proc_func_expected_to_seg_fault(): + def fn_expected_to_seg_fault(): ctypes.string_at(0) # Intentionally made seg fault. with self.assertRaises( multi_process_runner.UnexpectedSubprocessExitError) as cm: multi_process_runner.run( - proc_func_expected_to_seg_fault, + fn_expected_to_seg_fault, multi_worker_test_base.create_cluster_spec(num_workers=1), - list_stdout=True) + return_output=True) self.assertIn('Subprocess worker-0 exited with exit code', str(cm.exception)) list_to_assert = cm.exception.mpr_result.stdout @@ -335,7 +335,7 @@ class MultiProcessRunnerTest(test.TestCase): def test_seg_fault_in_chief_raises_error(self): - def proc_func_expected_to_seg_fault(): + def fn_expected_to_seg_fault(): if multi_worker_test_base.get_task_type() == 'worker': time.sleep(10000) ctypes.string_at(0) # Intentionally made seg fault. 
@@ -343,10 +343,10 @@ class MultiProcessRunnerTest(test.TestCase): with self.assertRaises( multi_process_runner.UnexpectedSubprocessExitError) as cm: multi_process_runner.run( - proc_func_expected_to_seg_fault, + fn_expected_to_seg_fault, multi_worker_test_base.create_cluster_spec( has_chief=True, num_workers=1), - list_stdout=True) + return_output=True) self.assertIn('Subprocess chief-0 exited with exit code', str(cm.exception)) list_to_assert = cm.exception.mpr_result.stdout @@ -354,13 +354,13 @@ class MultiProcessRunnerTest(test.TestCase): def test_exit_code_is_reported_by_chief_subprocess(self): - def proc_func_expected_to_exit_with_20(): + def fn_expected_to_exit_with_20(): if multi_worker_test_base.get_task_type() == 'worker': time.sleep(10000) sys.exit(20) mpr = multi_process_runner.MultiProcessRunner( - proc_func_expected_to_exit_with_20, + fn_expected_to_exit_with_20, multi_worker_test_base.create_cluster_spec( has_chief=True, num_workers=1)) mpr.start() @@ -372,11 +372,11 @@ class MultiProcessRunnerTest(test.TestCase): def test_exit_code_is_reported_by_subprocess(self): - def proc_func_expected_to_exit_with_10(): + def fn_expected_to_exit_with_10(): sys.exit(10) mpr = multi_process_runner.MultiProcessRunner( - proc_func_expected_to_exit_with_10, + fn_expected_to_exit_with_10, multi_worker_test_base.create_cluster_spec(num_workers=1)) mpr.start() @@ -387,7 +387,7 @@ class MultiProcessRunnerTest(test.TestCase): def test_auto_restart(self): - def proc_func(counter): + def fn(counter): counter.value += 1 if counter.value == 1: raise ValueError @@ -395,7 +395,7 @@ class MultiProcessRunnerTest(test.TestCase): manager = multi_process_runner.manager() counter = manager.Value(int, 0) mpr = multi_process_runner.MultiProcessRunner( - proc_func, + fn, multi_worker_test_base.create_cluster_spec(num_workers=1), args=(counter,), auto_restart=True) @@ -405,16 +405,16 @@ class MultiProcessRunnerTest(test.TestCase): def test_auto_restart_and_timeout(self): - def proc_func(): + def fn(): logging.info('Running') time.sleep(1) raise ValueError mpr = multi_process_runner.MultiProcessRunner( - proc_func, + fn, multi_worker_test_base.create_cluster_spec(num_workers=1), auto_restart=True, - list_stdout=True) + return_output=True) mpr.start() with self.assertRaises(ValueError) as cm: mpr.join(timeout=10) @@ -425,14 +425,14 @@ class MultiProcessRunnerTest(test.TestCase): # If the chief has exited with zero exit code, auto restart should stop # restarting other tasks even if they fail. - def proc_func(): + def fn(): time.sleep(1) if multi_worker_test_base.get_task_type() != 'chief': raise ValueError manager = multi_process_runner.manager() mpr = multi_process_runner.MultiProcessRunner( - proc_func, + fn, multi_worker_test_base.create_cluster_spec( has_chief=True, num_workers=1), auto_restart=True) @@ -443,11 +443,11 @@ class MultiProcessRunnerTest(test.TestCase): def test_auto_restart_failure_immediate_after_restart(self): # Test the case when worker-0 fails immediately after worker-1 restarts. - def proc_func(): + def fn(): time.sleep(5) mpr = multi_process_runner.MultiProcessRunner( - proc_func, + fn, multi_worker_test_base.create_cluster_spec( has_chief=False, num_workers=2), auto_restart=True) @@ -462,7 +462,7 @@ class MultiProcessRunnerTest(test.TestCase): def test_auto_restart_terminate(self): # Tasks terminated by the user should also be restarted. 
- def proc_func(counter): + def fn(counter): counter.value += 1 if counter.value == 1: time.sleep(100) @@ -471,7 +471,7 @@ class MultiProcessRunnerTest(test.TestCase): counter = manager.Value(int, 0) mpr = multi_process_runner.MultiProcessRunner( - proc_func, + fn, multi_worker_test_base.create_cluster_spec( has_chief=False, num_workers=1), args=(counter,), @@ -484,14 +484,13 @@ class MultiProcessRunnerTest(test.TestCase): def test_error_reporting_overrides_timeout_reporting(self): - def proc_func(): + def fn(): if self._worker_idx() == 1: time.sleep(10000) raise ValueError('Worker 0 errored') mpr = multi_process_runner.MultiProcessRunner( - proc_func, - multi_worker_test_base.create_cluster_spec(num_workers=2)) + fn, multi_worker_test_base.create_cluster_spec(num_workers=2)) mpr.start() with self.assertRaisesRegex( @@ -501,12 +500,11 @@ class MultiProcessRunnerTest(test.TestCase): def test_process_exists(self): - def proc_func(): + def fn(): time.sleep(100000) mpr = multi_process_runner.MultiProcessRunner( - proc_func, - multi_worker_test_base.create_cluster_spec(num_workers=1)) + fn, multi_worker_test_base.create_cluster_spec(num_workers=1)) mpr.start() self.assertTrue(mpr.process_exists('worker', 0)) mpr.terminate('worker', 0) @@ -516,12 +514,12 @@ class MultiProcessRunnerTest(test.TestCase): def test_timeout_none(self): - def proc_func(): + def fn(): time.sleep(250) raise ValueError('Worker 0 errored') mpr = multi_process_runner.MultiProcessRunner( - proc_func, multi_worker_test_base.create_cluster_spec(num_workers=1)) + fn, multi_worker_test_base.create_cluster_spec(num_workers=1)) mpr.start() with self.assertRaisesRegex(ValueError, 'Worker 0 errored'): @@ -537,23 +535,23 @@ class MultiProcessPoolRunnerTest(test.TestCase): def test_same_process_across_runs(self): cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2) runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec) - pid = runner.run(proc_func_that_returns_pid) + pid = runner.run(fn_that_returns_pid) for _ in range(3): - self.assertAllEqual(runner.run(proc_func_that_returns_pid), pid) + self.assertAllEqual(runner.run(fn_that_returns_pid), pid) def test_exceptions_in_sub_process(self): cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2) runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec) - pid = runner.run(proc_func_that_returns_pid) + pid = runner.run(fn_that_returns_pid) with self.assertRaisesRegex(ValueError, 'This is an error.'): - runner.run(proc_func_that_errors) - self.assertAllEqual(runner.run(proc_func_that_returns_pid), pid) + runner.run(fn_that_errors) + self.assertAllEqual(runner.run(fn_that_returns_pid), pid) def test_tf_config(self): cluster_spec = multi_worker_test_base.create_cluster_spec( has_chief=True, num_workers=2) runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec) - result = runner.run(proc_func_that_adds_task_type_in_return_data) + result = runner.run(fn_that_adds_task_type_in_return_data) job_count_dict = {'worker': 2, 'chief': 1} for data in result: @@ -570,27 +568,27 @@ class MultiProcessPoolRunnerTest(test.TestCase): cluster_spec = multi_worker_test_base.create_cluster_spec( has_chief=True, num_workers=2) runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec) - runner.run(proc_func_that_returns_pid) + runner.run(fn_that_returns_pid) raise ValueError('failure') def test_initializer(self): cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2) runner = 
multi_process_runner.MultiProcessPoolRunner( - cluster_spec, initializer=lambda: proc_func_that_sets_global(1)) - result = runner.run(proc_func_that_sets_global, args=(2,)) + cluster_spec, initializer=lambda: fn_that_sets_global(1)) + result = runner.run(fn_that_sets_global, args=(2,)) self.assertAllEqual(result, [1, 1]) def test_global_pool(self): - _global_pool.run(proc_func_that_does_nothing) + _global_pool.run(fn_that_does_nothing) def test_nested_pool(self): - def proc_func(): + def fn(): # This runs in sub processes, so they are each using their own # MultiProcessPoolRunner. - _global_pool.run(proc_func_that_does_nothing) + _global_pool.run(fn_that_does_nothing) - _global_pool.run(proc_func) + _global_pool.run(fn) if __name__ == '__main__': diff --git a/tensorflow/python/distribute/multi_worker_continuous_run_test.py b/tensorflow/python/distribute/multi_worker_continuous_run_test.py index 3cfbe022204..971c2c5a8c3 100644 --- a/tensorflow/python/distribute/multi_worker_continuous_run_test.py +++ b/tensorflow/python/distribute/multi_worker_continuous_run_test.py @@ -74,7 +74,7 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase): def worker_step_fn(worker_id): strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy() # Make sure the processeses are in sync after updating the cluster - multi_process_runner.barrier().wait() + multi_process_runner.get_barrier().wait() @def_function.function def run_reduce(): @@ -107,7 +107,7 @@ class MultiWorkerContinuousRunTest(test.TestCase, parameterized.TestCase): def worker_step_fn(worker_id, num_dims): strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy() # Make sure the processeses are in sync after updating the cluster - multi_process_runner.barrier().wait() + multi_process_runner.get_barrier().wait() tensor_shape = [2] * num_dims def variable_fn(): diff --git a/tensorflow/python/distribute/multi_worker_test_base.py b/tensorflow/python/distribute/multi_worker_test_base.py index b0c51f4767f..9e56e6d1bf7 100644 --- a/tensorflow/python/distribute/multi_worker_test_base.py +++ b/tensorflow/python/distribute/multi_worker_test_base.py @@ -32,7 +32,7 @@ import six _portpicker_import_error = None try: import portpicker # pylint: disable=g-import-not-at-top -except ImportError as _error: # pylint: disable=invalid-name +except (ImportError, ModuleNotFoundError) as _error: # pylint: disable=invalid-name _portpicker_import_error = _error portpicker = None @@ -56,6 +56,7 @@ from tensorflow.python.training import server_lib from tensorflow.python.util import deprecation from tensorflow.python.util import nest from tensorflow.python.util.compat import collections_abc +from tensorflow.python.util.tf_export import tf_export original_run_std_server = dc._run_std_server # pylint: disable=protected-access @@ -271,8 +272,8 @@ class MultiProcessCluster(object): self._cluster_spec, args=(self._start_events, self._finish_events), rpc_layer=self._rpc_layer, - stream_stdout=False, - list_stdout=False, + stream_output=False, + return_output=False, use_dill_for_args=False) self._mpr.start() for task_type, task_addresses in self._cluster_spec.items(): @@ -353,16 +354,53 @@ def create_multi_process_cluster(num_workers, return cluster -# TODO(rchao): Remove `test_obj` once estimator repo picks up the updated -# nightly TF. 
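Before the `create_cluster_spec` changes that follow, here is a self-contained sketch of the `MultiProcessPoolRunner` rename in use; it is illustrative only and simply restates the behavior exercised by `test_same_process_across_runs` above with the new `fn`-based naming and the same test helpers.

```python
# Illustrative sketch only; mirrors test_same_process_across_runs above.
import os

from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import test


def fn_that_returns_pid():
  return os.getpid()


class PoolRunnerExampleTest(test.TestCase):

  def test_pool_reuses_processes(self):
    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec)
    # The pool keeps its subprocesses alive between run() calls, so the same
    # PIDs are reported each time `fn` is dispatched to the workers.
    pids = runner.run(fn_that_returns_pid)
    self.assertAllEqual(runner.run(fn_that_returns_pid), pids)


if __name__ == '__main__':
  multi_process_runner.test_main()
```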
+@tf_export( + '__internal__.distribute.multi_process_runner.create_cluster_spec', v1=[]) def create_cluster_spec(has_chief=False, num_workers=1, num_ps=0, - has_eval=False, - test_obj=None): - """Create a cluster spec with tasks with unused local ports.""" - del test_obj + has_eval=False): + """Create a cluster spec with tasks with unused local ports. + This utility finds available ports at localhost, and returns a dict that + represents the cluster spec that utilizes those ports, according to the + arguments. The dict representing the cluster spec contains task types, and + their instances' addresses. Note that this is usually only for testing purpose + using multiple processes in the local machine, and should not be used for real + multi-worker TensorFlow programs, where the addresses need to point to the + processes at separate machines. + + This util is useful when creating the `cluster_spec` arg for + `tf.__internal__.distribute.multi_process_runner.run`. + + Arguments: + has_chief: Whether the generated cluster spec should contain "chief" task + type. + num_workers: Number of workers to use in the cluster spec. + num_ps: Number of parameter servers to use in the cluster spec. + has_eval: Whether this cluster spec has evaluator. + + Returns: + A dict that represents the cluster spec using localhost ports for the tasks. + + Example: + + ```python + cluster_spec = + tf.__internal__.distribute.multi_process_runner.create_cluster_spec( + has_chief=True, num_workers=2, num_ps=2) + # An example of cluster_spec is + # {'chief': ['localhost:23381'], + # 'worker': ['localhost:19197', 'localhost:22903'], + # 'ps': ['localhost:16912', 'localhost:21535']} + + cluster_spec = + tf.__internal__.distribute.multi_process_runner.create_cluster_spec( + has_chief=False, num_workers=0, num_ps=0, has_eval=True) + # An example of cluster_spec is + # {'evaluator': ['localhost:23381']} + ``` + """ if _portpicker_import_error: raise _portpicker_import_error # pylint: disable=raising-bad-type diff --git a/tensorflow/python/distribute/one_device_strategy.py b/tensorflow/python/distribute/one_device_strategy.py index c256c2df78f..3d5175d9055 100644 --- a/tensorflow/python/distribute/one_device_strategy.py +++ b/tensorflow/python/distribute/one_device_strategy.py @@ -383,6 +383,11 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1): del reduce_op, destinations, experimental_hints return value + def _gather_to_implementation(self, value, destinations, axis, + experimental_hints): + del destinations, axis, experimental_hints + return value + def _update(self, var, fn, args, kwargs, group): # The implementations of _update() and _update_non_slot() are identical # except _update() passes `var` as the first argument to `fn()`. 
@@ -453,6 +458,9 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1): def _support_per_replica_values(self): return False + def _get_local_replica_id(self, replica_id_in_sync_group): + return replica_id_in_sync_group + class _OneDeviceReplicaContext(distribute_lib.ReplicaContext): """ReplicaContext for OneDeviceStrategy.""" diff --git a/tensorflow/python/distribute/parallel_device/BUILD b/tensorflow/python/distribute/parallel_device/BUILD index 5fc294a5f5c..331e0a3c3af 100644 --- a/tensorflow/python/distribute/parallel_device/BUILD +++ b/tensorflow/python/distribute/parallel_device/BUILD @@ -1,3 +1,5 @@ +load("//tensorflow/core/platform/default:distribute.bzl", "distribute_py_test") + package( default_visibility = ["//tensorflow:internal"], licenses = ["notice"], # Apache 2.0 @@ -17,6 +19,7 @@ py_library( ":saving", "//tensorflow/python:_pywrap_parallel_device", "//tensorflow/python/distribute:device_util", + "//tensorflow/python/tpu:tpu_ops", ], ) @@ -27,15 +30,13 @@ py_library( deps = ["//tensorflow/python:framework_ops"], ) -py_test( +distribute_py_test( name = "parallel_device_test", srcs = ["parallel_device_test.py"], python_version = "PY3", tags = [ # Dependencies aren't otherwise included in the pip package yet. "no_pip", - # MRO broken; needs investigation - "no_windows", ], deps = [ ":parallel_device", diff --git a/tensorflow/python/distribute/parallel_device/parallel_device.py b/tensorflow/python/distribute/parallel_device/parallel_device.py index 30381e2a95d..218bd68d824 100644 --- a/tensorflow/python/distribute/parallel_device/parallel_device.py +++ b/tensorflow/python/distribute/parallel_device/parallel_device.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import threading +import weakref from tensorflow.python import _pywrap_parallel_device from tensorflow.python.distribute import device_util @@ -32,6 +33,16 @@ from tensorflow.python.tpu.ops import tpu_ops _next_device_number = 0 _next_device_number_lock = threading.Lock() +_all_parallel_devices = weakref.WeakValueDictionary() + + +def unpack(tensor): + """Finds `tensor`'s parallel device and unpacks its components.""" + parallel_device = _all_parallel_devices.get(tensor.device, None) + if parallel_device is None: + raise ValueError("{} is not a parallel device".format(tensor.device)) + return parallel_device.unpack(tensor) + # TODO(allenl): Expand this docstring once things like getting components on and # off the device are stable. @@ -67,6 +78,7 @@ class ParallelDevice(object): self._device_ids = None self._device_scope = None self._saving_scope = None + _all_parallel_devices[self._name] = self def pack(self, tensors): """Create a tensor on the parallel device from a sequence of tensors. 
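The `parallel_device.py` hunk above adds a module-level `unpack()` that locates a tensor's owning `ParallelDevice` through a registry keyed on device name. Below is a rough sketch of how it is intended to pair with `ParallelDevice.pack` (illustrative only; it assumes TF2 eager behavior and borrows the two-virtual-CPU setup used by the test file whose diff follows).

```python
# Illustrative sketch of the new module-level unpack(); not part of the patch.
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute.parallel_device import parallel_device
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op

v2_compat.enable_v2_behavior()  # Assumes eager execution, as in the tests.

# Split the physical CPU into two logical CPUs so the parallel device has two
# components, mirroring the virtual-CPU setUp in the test file below.
ctx = context.context()
cpus = ctx.list_physical_devices("CPU")
ctx.set_logical_device_configuration(cpus[0], [
    context.LogicalDeviceConfiguration(),
    context.LogicalDeviceConfiguration(),
])

device = parallel_device.ParallelDevice(
    components=["/job:localhost/device:CPU:0", "CPU:1"])

# Pack one tensor per component into a single parallel tensor, then recover
# the per-component tensors with the free function added in this patch.
packed = device.pack([constant_op.constant(1.0), constant_op.constant(2.0)])
components = parallel_device.unpack(packed)
assert len(components) == 2
```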
diff --git a/tensorflow/python/distribute/parallel_device/parallel_device_test.py b/tensorflow/python/distribute/parallel_device/parallel_device_test.py index 066f9ea376c..cf86a7362fb 100644 --- a/tensorflow/python/distribute/parallel_device/parallel_device_test.py +++ b/tensorflow/python/distribute/parallel_device/parallel_device_test.py @@ -93,19 +93,30 @@ class _VirtualDeviceTestCase(test.TestCase): def setUp(self): super(_VirtualDeviceTestCase, self).setUp() - cpus = context.context().list_physical_devices("CPU") - # Set 4 virtual CPUs - context.context().set_logical_device_configuration(cpus[0], [ - context.LogicalDeviceConfiguration(), - context.LogicalDeviceConfiguration(), - context.LogicalDeviceConfiguration(), - context.LogicalDeviceConfiguration() - ]) + ctx = context.context() + if ctx.list_physical_devices("TPU"): + self.device_type = "TPU" + elif ctx.list_physical_devices("GPU"): + self.device_type = "GPU" + gpus = ctx.list_physical_devices(self.device_type) + ctx.set_logical_device_configuration(gpus[0], [ + context.LogicalDeviceConfiguration(memory_limit=100), + context.LogicalDeviceConfiguration(memory_limit=100), + ]) + else: + self.device_type = "CPU" + cpus = ctx.list_physical_devices("CPU") + ctx.set_logical_device_configuration(cpus[0], [ + context.LogicalDeviceConfiguration(), + context.LogicalDeviceConfiguration(), + ]) - self.device = parallel_device.ParallelDevice( - components=["/job:localhost/device:CPU:0", "CPU:1"]) - self.assertIn("CPU:0", self.device.components[0]) - self.assertIn("CPU:1", self.device.components[1]) + self.device = parallel_device.ParallelDevice(components=[ + "/job:localhost/device:{}:0".format(self.device_type), + self.device_type + ":1" + ]) + self.assertIn(self.device_type + ":0", self.device.components[0]) + self.assertIn(self.device_type + ":1", self.device.components[1]) class ParallelDeviceTests(_VirtualDeviceTestCase): @@ -124,10 +135,14 @@ class ParallelDeviceTests(_VirtualDeviceTestCase): def test_device_id(self): device_ids = self.device.unpack(self.device.device_ids) self.assertAllClose([0, 1], device_ids) - self.assertIn(self.device.components[0], device_ids[0].backing_device) - self.assertIn(self.device.components[1], device_ids[1].backing_device) + # TODO(allenl): Should device IDs be int64 so they can be placed on GPUs? + # Currently backing_device is CPU. + self.assertIn(self.device.components[0], device_ids[0].device) + self.assertIn(self.device.components[1], device_ids[1].device) def test_collective_reduce(self): + if self.device_type == "TPU": + self.skipTest("ParallelDevice collectives on TPUs need work") with self.device: x = self.device.pack( [constant_op.constant(-1.5), @@ -139,6 +154,8 @@ class ParallelDeviceTests(_VirtualDeviceTestCase): self.assertIn(self.device.components[1], outputs[1].backing_device) def test_collective_reduce_async_scope(self): + if self.device_type == "TPU": + self.skipTest("ParallelDevice collectives on TPUs need work") # Note that ops on the parallel device currently don't execute # asynchronously. The test is just that we don't get deadlocks. 
with context.async_scope(), self.device: @@ -152,6 +169,8 @@ class ParallelDeviceTests(_VirtualDeviceTestCase): self.assertIn(self.device.components[1], outputs[1].backing_device) def test_collective_reduce_async_context(self): + if self.device_type == "TPU": + self.skipTest("ParallelDevice collectives on TPUs need work") previous = config.get_synchronous_execution() try: context._reset_context() @@ -173,6 +192,8 @@ class ParallelDeviceTests(_VirtualDeviceTestCase): config.set_synchronous_execution(previous) def test_collective_in_function(self): + if self.device_type == "TPU": + self.skipTest("ParallelDevice collectives on TPUs need work") c = constant_op.constant([2]) @def_function.function @@ -313,6 +334,33 @@ class ParallelDeviceTests(_VirtualDeviceTestCase): return y, tape.gradient(y, x) self._assert_close_to_non_parallel(_test_fn) + def test_variable_created_in_function(self): + + class M(module.Module): + + def __init__(self): + self.v = None + self.w = None + self.x = None + self.z = None + + @def_function.function(autograph=False) + def __call__(self, x): + if self.v is None: + with ops.init_scope(): + initial_value = constant_op.constant(2.) + self.z = variables.Variable(initial_value) + self.x = variables.Variable(initial_value) + self.w = variables.Variable(lambda: constant_op.constant(2.)) + self.v = variables.Variable(constant_op.constant(2.)) + return x * self.v * self.w * self.x * self.z + + with self.device: + m = M() + packed_outputs = m(array_ops.ones([])) + outputs = self.device.unpack(packed_outputs) + self.assertAllClose([16., 16.], outputs) + class LayerTests(_VirtualDeviceTestCase): @@ -340,6 +388,8 @@ class LayerTests(_VirtualDeviceTestCase): self.assertIn(self.device.components[1], outputs[1].backing_device) def test_layer_sync_training(self): + if self.device_type == "TPU": + self.skipTest("ParallelDevice collectives on TPUs need work") with self.device: layer = _Dense(5) @@ -389,6 +439,8 @@ class LayerTests(_VirtualDeviceTestCase): self.assertIn(self.device.components[1], final_kernels[1].backing_device) def test_training_loop(self): + if self.device_type == "TPU": + self.skipTest("ParallelDevice collectives on TPUs need work") for _ in range(5): layer = _Dense(5) checkpoint = tracking.Checkpoint(layer=layer) diff --git a/tensorflow/python/distribute/parallel_device/saving.py b/tensorflow/python/distribute/parallel_device/saving.py index f1539e49651..5fdd7ae5d3a 100644 --- a/tensorflow/python/distribute/parallel_device/saving.py +++ b/tensorflow/python/distribute/parallel_device/saving.py @@ -20,6 +20,8 @@ from __future__ import print_function import contextlib import functools +import six +import wrapt from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import resource_variable_ops @@ -47,14 +49,32 @@ class _ParallelComponentSaveable(saveable_object.SaveableObject): resource=self._handle, value=restored_tensor) -class ParallelSavingMixin(resource_variable_ops.BaseResourceVariable): - """Mixin to to override variable checkpointing, saving each component.""" +_wrapt_type = type(wrapt.ObjectProxy) +_variable_type = type(resource_variable_ops.BaseResourceVariable) +if issubclass(_variable_type, _wrapt_type): + # Some wrapt versions do not have a meta-class, which would create an invalid + # MRO. + VariableProxyMetaClass = _variable_type +else: + class VariableProxyMetaClass(_wrapt_type, _variable_type): # pylint: disable=duplicate-bases + """A combined MetaClasses for ParallelVariable. 
- def __init__(self, parallel_device, expected_shape=None, use_resource=None, - **kwargs): - del expected_shape, use_resource - self._parallel_device = parallel_device - super(ParallelSavingMixin, self).__init__(**kwargs) + Satisfies the requirement "the metaclass of a derived class must be a + (non-strict) subclass of the metaclasses of all its bases." At the time of + writing these two MetaClasses are compatible (overriding different methods, + both relatively trivial). + """ + pass + + +class ParallelVariable( + six.with_metaclass(VariableProxyMetaClass, wrapt.ObjectProxy, + resource_variable_ops.BaseResourceVariable)): + """Overrides variable checkpointing, saving each component.""" + + def __init__(self, parallel_device, wrapped_variable): + self._self_parallel_device = parallel_device + super(ParallelVariable, self).__init__(wrapped_variable) # TODO(allenl): Consider either adding a boolean argument for # save-primary-only or looking at synchronization/aggregation properties. @@ -63,7 +83,8 @@ class ParallelSavingMixin(resource_variable_ops.BaseResourceVariable): component_saveables = {} # Create one SaveableObject per device, each one of which looks like a # regular ResourceVariable saveable. - for index, handle in enumerate(self._parallel_device.unpack(self.handle)): + for index, handle in enumerate( + self._self_parallel_device.unpack(self.handle)): if index == 0: # This is the name regular tf.Variables use to save. Using it for the # component on the first device means non-parallel tf.Variable objects @@ -80,26 +101,24 @@ class ParallelSavingMixin(resource_variable_ops.BaseResourceVariable): return component_saveables -class ParallelVariable( - ParallelSavingMixin, resource_variable_ops.ResourceVariable): - pass - - -class UninitializedParallelVariable( - ParallelSavingMixin, resource_variable_ops.UninitializedVariable): - pass - - -def _variable_creator(next_creator, parallel_device, initial_value=None, - **kwargs): - del next_creator - if initial_value is not None: +def _variable_creator(next_creator, parallel_device, **kwargs): + """Wraps intercepted variables to add parallel saving.""" + # Depending on the context (SavedModel loading, tf.function, etc.) we may get + # one of several different variable types. For variables placed on the + # parallel device we only want to affect saving and otherwise preserve + # behavior. This wrapping to override behavior is similar to tf.distribute's + # DistributedVariable, but much more limited. + variable = next_creator(**kwargs) + if variable.device == parallel_device._name: # Friend access; pylint: disable=protected-access return ParallelVariable( - parallel_device=parallel_device, initial_value=initial_value, **kwargs) + parallel_device=parallel_device, wrapped_variable=variable) else: - # SavedModel loading does not pass an initial value. - return UninitializedParallelVariable( - parallel_device=parallel_device, **kwargs) + # Variables not placed on the handler (because of a device scope) don't + # need wrapping. + # + # TODO(allenl): Device scopes should merge with parallel devices rather + # than overriding them like this. 
+ return variable @contextlib.contextmanager diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py index b60ea74dd04..0cc2b21c3aa 100644 --- a/tensorflow/python/distribute/parameter_server_strategy.py +++ b/tensorflow/python/distribute/parameter_server_strategy.py @@ -117,10 +117,6 @@ class ParameterServerStrategy(distribute_lib.Strategy): extended = ParameterServerStrategyExtended( self, cluster_resolver=cluster_resolver) super(ParameterServerStrategy, self).__init__(extended) - distribute_lib.distribution_strategy_gauge.get_cell("V2").set( - "ParameterServerStrategy") - distribute_lib.distribution_strategy_replica_gauge.get_cell("num_ps").set( - len(self.extended.parameter_devices)) def experimental_distribute_dataset(self, dataset, options=None): self._raise_pss_error_if_eager() @@ -160,8 +156,6 @@ class ParameterServerStrategyV1(distribute_lib.StrategyV1): self, cluster_resolver=cluster_resolver)) distribute_lib.distribution_strategy_gauge.get_cell("V1").set( "ParameterServerStrategy") - distribute_lib.distribution_strategy_replica_gauge.get_cell("num_ps").set( - len(self.extended.parameter_devices)) __init__.__doc__ = ParameterServerStrategy.__init__.__doc__ @@ -351,14 +345,14 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1): dataset, self._input_workers_with_options(options), self._container_strategy(), - split_batch_by=self._num_replicas_in_sync) + num_replicas_in_sync=self._num_replicas_in_sync) def _make_dataset_iterator(self, dataset): return input_lib.DatasetIterator( dataset, self._input_workers, self._container_strategy(), - split_batch_by=self._num_replicas_in_sync) + num_replicas_in_sync=self._num_replicas_in_sync) def _make_input_fn_iterator( self, @@ -509,6 +503,17 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1): "Cannot reduce to another worker: %r, current worker is %r" % (d, self._worker_device)) + def _gather_to_implementation(self, value, destinations, axis, + experimental_hints): + self._verify_destinations_not_different_worker(destinations) + if not isinstance(value, values.DistributedValues): + return value + return self._cross_device_ops._gather( # pylint: disable=protected-access + value, + destinations=destinations, + axis=axis, + experimental_hints=experimental_hints) + def _reduce_to(self, reduce_op, value, destinations, experimental_hints): self._verify_destinations_not_different_worker(destinations) if not isinstance(value, values.DistributedValues): @@ -699,3 +704,9 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1): Boolean. """ return True + + def _get_local_replica_id(self, replica_id_in_sync_group): + return replica_id_in_sync_group + + def _get_replica_id_in_sync_group(self, replica_id): + return replica_id diff --git a/tensorflow/python/distribute/parameter_server_strategy_v2.py b/tensorflow/python/distribute/parameter_server_strategy_v2.py index d6cacda8ed7..d215be0dd94 100644 --- a/tensorflow/python/distribute/parameter_server_strategy_v2.py +++ b/tensorflow/python/distribute/parameter_server_strategy_v2.py @@ -94,6 +94,8 @@ class ParameterServerStrategyV2(distribute_lib.Strategy): # TODO(b/167894802): Make chief, worker, and ps names customizable. 
self._connect_to_cluster(client_name="chief") super(ParameterServerStrategyV2, self).__init__(self._extended) + distribute_lib.distribution_strategy_gauge.get_cell("V2").set( + "ParameterServerStrategy") def _connect_to_cluster(self, client_name): if client_name in ["worker", "ps"]: @@ -124,6 +126,11 @@ class ParameterServerStrategyV2(distribute_lib.Strategy): protocol=self._cluster_resolver.rpc_layer, cluster_device_filters=device_filters) + distribute_lib.distribution_strategy_replica_gauge.get_cell( + "ps_strategy_num_workers").set(self._num_workers) + distribute_lib.distribution_strategy_replica_gauge.get_cell( + "ps_strategy_num_ps").set(self._num_ps) + def _verify_args_and_config(self, cluster_resolver): if not cluster_resolver.cluster_spec(): raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.") diff --git a/tensorflow/python/distribute/strategy_combinations.py b/tensorflow/python/distribute/strategy_combinations.py index 57273e9bb15..2a7afabf166 100644 --- a/tensorflow/python/distribute/strategy_combinations.py +++ b/tensorflow/python/distribute/strategy_combinations.py @@ -27,22 +27,45 @@ from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import mirrored_strategy as mirrored_lib from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import one_device_strategy as one_device_lib +from tensorflow.python.distribute import test_util from tensorflow.python.distribute import tpu_strategy as tpu_lib from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver from tensorflow.python.eager import context from tensorflow.python.eager import remote -from tensorflow.python.framework import config from tensorflow.python.platform import flags from tensorflow.python.tpu import device_assignment as device_assignment_lib from tensorflow.python.tpu import tpu_strategy_util +from tensorflow.python.util.tf_export import tf_export -FLAGS = flags.FLAGS +_TF_INTERNAL_API_PREFIX = "__internal__.distribute.combinations." _did_connect_to_cluster = False CollectiveAllReduceExtended = ( collective_all_reduce_strategy.CollectiveAllReduceExtended) +def _version_chooser(tf1_cls, tf2_cls): + + def creator(*args, **kwargs): + if tf2.enabled(): + return tf2_cls(*args, **kwargs) + return tf1_cls(*args, **kwargs) + + return creator + + +MirroredStrategy = _version_chooser(mirrored_lib.MirroredStrategyV1, + mirrored_lib.MirroredStrategy) +CentralStorageStrategy = _version_chooser( + central_storage_strategy.CentralStorageStrategyV1, + central_storage_strategy.CentralStorageStrategy) +OneDeviceStrategy = _version_chooser(one_device_lib.OneDeviceStrategyV1, + one_device_lib.OneDeviceStrategy) +# Only V2 CollectiveAllReduceStrategy combinations are supported. +CollectiveAllReduceStrategy = ( + collective_all_reduce_strategy.CollectiveAllReduceStrategy) + + # pylint: disable=missing-docstring def _get_tpu_strategy_creator(steps_per_run, use_single_core=False, @@ -50,6 +73,7 @@ def _get_tpu_strategy_creator(steps_per_run, **kwargs): def _create_tpu_strategy(): + FLAGS = flags.FLAGS # pylint: disable=invalid-name global _did_connect_to_cluster try: @@ -77,8 +101,8 @@ def _get_tpu_strategy_creator(steps_per_run, device_assignment = None if use_single_core: device_assignment = device_assignment_lib.DeviceAssignment( - topology, core_assignment=device_assignment_lib. 
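# ---------------------------------------------------------------------------
# Illustrative standalone sketch (not part of the original change): the
# _version_chooser helper above returns a factory that defers the V1-vs-V2
# class choice until the strategy object is actually constructed, so a test
# can switch tf2.enable()/tf2.disable() in setUp and still receive the
# matching class. The analogue below uses invented names (GreeterV1,
# GreeterV2) and takes the enabled-check as a callable so it has no
# TensorFlow dependency.
class GreeterV1(object):

  def greet(self):
    return 'hello from V1'


class GreeterV2(object):

  def greet(self):
    return 'hello from V2'


def _pick_by_version(v1_cls, v2_cls, v2_enabled):
  """Returns a creator that picks v1_cls or v2_cls at call time."""

  def creator(*args, **kwargs):
    if v2_enabled():
      return v2_cls(*args, **kwargs)
    return v1_cls(*args, **kwargs)

  return creator


Greeter = _pick_by_version(GreeterV1, GreeterV2, v2_enabled=lambda: True)
assert Greeter().greet() == 'hello from V2'
# ---------------------------------------------------------------------------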
- SINGLE_CORE_ASSIGNMENT) + topology, + core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT) # Steps per run is only supported in TF 1.x if tf2.enabled(): @@ -118,16 +142,15 @@ def _get_multi_worker_mirrored_creator(required_gpus): # configures the eager context. The eager context can no longer be # configured after initialization. with context.eager_mode(): - strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy( - cluster_resolver=resolver) + strategy = CollectiveAllReduceStrategy(cluster_resolver=resolver) # TODO(b/152320929): Wait for the cluster before proceeding, otherwise # collectives may hang if any worker launches collectives before the chief # creates the strategy. try: - multi_process_runner.barrier().wait() + multi_process_runner.get_barrier().wait() except ValueError: # If the creator is called in the main process, - # multi_process_runner.barrier() raises ValueError, which is safe to + # multi_process_runner.get_barrier() raises ValueError, which is safe to # ignore. pass return strategy @@ -141,20 +164,16 @@ default_strategy = combinations.NamedDistribution( distribution_strategy_context._get_default_strategy, # pylint: disable=protected-access required_gpus=None) one_device_strategy = combinations.NamedDistribution( - "OneDeviceCPU", - lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), - required_gpus=None) + "OneDeviceCPU", lambda: OneDeviceStrategy("/cpu:0"), required_gpus=None) one_device_strategy_gpu = combinations.NamedDistribution( - "OneDeviceGPU", - lambda: one_device_lib.OneDeviceStrategy("/gpu:0"), - required_gpus=1) + "OneDeviceGPU", lambda: OneDeviceStrategy("/gpu:0"), required_gpus=1) one_device_strategy_on_worker_1 = combinations.NamedDistribution( "OneDeviceOnWorker1CPU", - lambda: one_device_lib.OneDeviceStrategy("/job:worker/replica:0/task:1/cpu:0"), # pylint: disable=line-too-long + lambda: OneDeviceStrategy("/job:worker/replica:0/task:1/cpu:0"), required_gpus=None) one_device_strategy_gpu_on_worker_1 = combinations.NamedDistribution( "OneDeviceOnWorker1GPU", - lambda: one_device_lib.OneDeviceStrategy("/job:worker/replica:0/task:1/gpu:0"), # pylint: disable=line-too-long + lambda: OneDeviceStrategy("/job:worker/replica:0/task:1/gpu:0"), required_gpus=1) tpu_strategy = combinations.NamedDistribution( "TPU", _get_tpu_strategy_creator(steps_per_run=2), required_tpu=True) @@ -178,30 +197,32 @@ cloud_tpu_strategy = combinations.NamedDistribution( required_tpu=True, use_cloud_tpu=True) mirrored_strategy_with_one_cpu = combinations.NamedDistribution( - "Mirrored1CPU", lambda: mirrored_lib.MirroredStrategy(["/cpu:0"])) + "Mirrored1CPU", lambda: MirroredStrategy(["/cpu:0"])) mirrored_strategy_with_one_gpu = combinations.NamedDistribution( - "Mirrored1GPU", - lambda: mirrored_lib.MirroredStrategy(["/gpu:0"]), - required_gpus=1) + "Mirrored1GPU", lambda: MirroredStrategy(["/gpu:0"]), required_gpus=1) mirrored_strategy_with_gpu_and_cpu = combinations.NamedDistribution( "MirroredCPUAndGPU", - lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/cpu:0"]), + lambda: MirroredStrategy(["/gpu:0", "/cpu:0"]), required_gpus=1) mirrored_strategy_with_two_gpus = combinations.NamedDistribution( "Mirrored2GPUs", - lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/gpu:1"]), + lambda: MirroredStrategy(["/gpu:0", "/gpu:1"]), required_gpus=2) # Should call set_virtual_cpus_to_at_least(3) in your test's setUp methods. 
mirrored_strategy_with_cpu_1_and_2 = combinations.NamedDistribution( - "Mirrored2CPU", lambda: mirrored_lib.MirroredStrategy(["/cpu:1", "/cpu:2"])) + "Mirrored2CPU", lambda: MirroredStrategy(["/cpu:1", "/cpu:2"])) +mirrored_strategy_with_cpu_1_and_2.__doc__ = ( + """Mirrored strategy with 2 virtual CPUs. + + Should set up logical devices before use + """) central_storage_strategy_with_two_gpus = combinations.NamedDistribution( "CentralStorage2GPUs", - lambda: central_storage_strategy.CentralStorageStrategy._from_num_gpus(2), # pylint: disable=protected-access + lambda: CentralStorageStrategy(["/gpu:0", "/gpu:1"]), required_gpus=2) central_storage_strategy_with_gpu_and_cpu = combinations.NamedDistribution( "CentralStorageCPUAndGPU", - lambda: central_storage_strategy.CentralStorageStrategy( - ["/gpu:0", "/cpu:0"]), + lambda: CentralStorageStrategy(["/gpu:0", "/cpu:0"]), required_gpus=1) # chief + 1 worker, with CPU. multi_worker_mirrored_2x1_cpu = combinations.NamedDistribution( @@ -246,27 +267,9 @@ multi_worker_mirrored_4x1_cpu = combinations.NamedDistribution( graph_and_eager_modes = ["graph", "eager"] -# This function should be called in a test's `setUp` method with the -# maximum value needed in any test. +# TODO(crccw): remove after tf-nightly picks up the new API. def set_virtual_cpus_to_at_least(num_virtual_cpus): - """Create virtual CPU devices if they haven't yet been created.""" - if num_virtual_cpus < 1: - raise ValueError("`num_virtual_cpus` must be at least 1 not %r" % - (num_virtual_cpus,)) - physical_devices = config.list_physical_devices("CPU") - if not physical_devices: - raise RuntimeError("No CPUs found") - configs = config.get_logical_device_configuration(physical_devices[0]) - if configs is None: - logical_devices = [ - context.LogicalDeviceConfiguration() for _ in range(num_virtual_cpus) - ] - config.set_logical_device_configuration(physical_devices[0], - logical_devices) - else: - if len(configs) < num_virtual_cpus: - raise RuntimeError("Already configured with %d < %d virtual CPUs" % - (len(configs), num_virtual_cpus)) + test_util.set_logical_devices_to_at_least("CPU", num_virtual_cpus) strategies_minus_tpu = [ @@ -321,8 +324,7 @@ multidevice_strategies = [ ] multiworker_strategies = [ - multi_worker_mirrored_2x1_cpu, - multi_worker_mirrored_2x1_gpu, + multi_worker_mirrored_2x1_cpu, multi_worker_mirrored_2x1_gpu, multi_worker_mirrored_2x2_gpu ] @@ -352,3 +354,57 @@ def all_strategy_minus_default_and_tpu_combinations(): def all_strategy_combinations_minus_default(): return (all_strategy_minus_default_and_tpu_combinations() + tpu_strategy_combinations()) + + +tf_export( + _TF_INTERNAL_API_PREFIX + "central_storage_strategy_with_gpu_and_cpu", + v1=[]).export_constant(__name__, + "central_storage_strategy_with_gpu_and_cpu") +tf_export( + _TF_INTERNAL_API_PREFIX + "central_storage_strategy_with_two_gpus", + v1=[]).export_constant(__name__, "central_storage_strategy_with_two_gpus") +tf_export( + _TF_INTERNAL_API_PREFIX + "cloud_tpu_strategy", + v1=[]).export_constant(__name__, "cloud_tpu_strategy") +tf_export( + _TF_INTERNAL_API_PREFIX + "default_strategy", + v1=[]).export_constant(__name__, "default_strategy") +tf_export( + _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_cpu_1_and_2", + v1=[]).export_constant(__name__, "mirrored_strategy_with_cpu_1_and_2") +tf_export( + _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_gpu_and_cpu", + v1=[]).export_constant(__name__, "mirrored_strategy_with_gpu_and_cpu") +tf_export( + _TF_INTERNAL_API_PREFIX + 
"mirrored_strategy_with_one_cpu", + v1=[]).export_constant(__name__, "mirrored_strategy_with_one_cpu") +tf_export( + _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_one_gpu", + v1=[]).export_constant(__name__, "mirrored_strategy_with_one_gpu") +tf_export( + _TF_INTERNAL_API_PREFIX + "mirrored_strategy_with_two_gpus", + v1=[]).export_constant(__name__, "mirrored_strategy_with_two_gpus") +tf_export( + _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x1_cpu", + v1=[]).export_constant(__name__, "multi_worker_mirrored_2x1_cpu") +tf_export( + _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x1_gpu", + v1=[]).export_constant(__name__, "multi_worker_mirrored_2x1_gpu") +tf_export( + _TF_INTERNAL_API_PREFIX + "multi_worker_mirrored_2x2_gpu", + v1=[]).export_constant(__name__, "multi_worker_mirrored_2x2_gpu") +tf_export( + _TF_INTERNAL_API_PREFIX + "one_device_strategy", + v1=[]).export_constant(__name__, "one_device_strategy") +tf_export( + _TF_INTERNAL_API_PREFIX + "one_device_strategy_gpu", + v1=[]).export_constant(__name__, "one_device_strategy_gpu") +tf_export( + _TF_INTERNAL_API_PREFIX + "tpu_strategy", + v1=[]).export_constant(__name__, "tpu_strategy") +tf_export( + _TF_INTERNAL_API_PREFIX + "tpu_strategy_one_core", + v1=[]).export_constant(__name__, "tpu_strategy_one_core") +tf_export( + _TF_INTERNAL_API_PREFIX + "tpu_strategy_packed_var", + v1=[]).export_constant(__name__, "tpu_strategy_packed_var") diff --git a/tensorflow/python/distribute/strategy_combinations_test.py b/tensorflow/python/distribute/strategy_combinations_test.py index 38ace7da42d..9ee5eab93c6 100644 --- a/tensorflow/python/distribute/strategy_combinations_test.py +++ b/tensorflow/python/distribute/strategy_combinations_test.py @@ -20,52 +20,22 @@ from __future__ import print_function from absl.testing import parameterized +from tensorflow.python import tf2 +from tensorflow.python.distribute import central_storage_strategy +from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.distribute import one_device_strategy from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import strategy_combinations -from tensorflow.python.eager import context +from tensorflow.python.distribute import test_util +from tensorflow.python.distribute import tpu_strategy from tensorflow.python.eager import def_function -from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class VirtualDevicesTest(test.TestCase, parameterized.TestCase): - - def setUp(self): - context._reset_context() # pylint: disable=protected-access - # Need to call set_virtual_cpus_to_at_least() in setUp with the maximum - # value needed in any test. 
- strategy_combinations.set_virtual_cpus_to_at_least(3) - super(VirtualDevicesTest, self).setUp() - - def test3VirtualCPUs(self): - cpu_device = config.list_physical_devices("CPU")[0] - self.assertLen(config.get_logical_device_configuration(cpu_device), 3) - - def testSetVirtualCPUsAgain(self): - strategy_combinations.set_virtual_cpus_to_at_least(2) - cpu_device = config.list_physical_devices("CPU")[0] - self.assertLen(config.get_logical_device_configuration(cpu_device), 3) - - def testSetVirtualCPUsErrors(self): - with self.assertRaises(ValueError): - strategy_combinations.set_virtual_cpus_to_at_least(0) - with self.assertRaisesRegex(RuntimeError, "with 3 < 5 virtual CPUs"): - strategy_combinations.set_virtual_cpus_to_at_least(5) - - @combinations.generate(combinations.combine( - distribution=[strategy_combinations.mirrored_strategy_with_cpu_1_and_2], - mode=["graph", "eager"])) - def testMirrored2CPUs(self, distribution): - with distribution.scope(): - one_per_replica = distribution.run(lambda: constant_op.constant(1)) - num_replicas = distribution.reduce( - reduce_util.ReduceOp.SUM, one_per_replica, axis=None) - self.assertEqual(2, self.evaluate(num_replicas)) - - class StrategyCombinationsTest(test.TestCase, parameterized.TestCase): @combinations.generate( @@ -100,6 +70,126 @@ class StrategyCombinationsTest(test.TestCase, parameterized.TestCase): reduce_util.ReduceOp.SUM, one_per_replica, axis=None) self.assertEqual(self.evaluate(num_replicas), 4.) + @combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.mirrored_strategy_with_cpu_1_and_2 + ], + mode=["graph", "eager"])) + def testMirrored2CPUs(self, distribution): + with distribution.scope(): + one_per_replica = distribution.run(lambda: constant_op.constant(1)) + num_replicas = distribution.reduce( + reduce_util.ReduceOp.SUM, one_per_replica, axis=None) + self.assertEqual(2, self.evaluate(num_replicas)) + + +class V1StrategyTest(test.TestCase, parameterized.TestCase): + + def setUp(self): + super().setUp() + tf2.disable() + + @combinations.generate( + combinations.combine(strategy=[ + strategy_combinations.one_device_strategy, + strategy_combinations.one_device_strategy_gpu, + strategy_combinations.one_device_strategy_gpu_on_worker_1, + strategy_combinations.one_device_strategy_on_worker_1 + ])) + def testOneDevice(self, strategy): + self.assertIsInstance(strategy, one_device_strategy.OneDeviceStrategyV1) + + @combinations.generate( + combinations.combine(strategy=[ + strategy_combinations.mirrored_strategy_with_cpu_1_and_2, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.mirrored_strategy_with_one_cpu, + strategy_combinations.mirrored_strategy_with_one_gpu, + strategy_combinations.mirrored_strategy_with_two_gpus, + ])) + def testMirrored(self, strategy): + self.assertIsInstance(strategy, mirrored_strategy.MirroredStrategyV1) + + @combinations.generate( + combinations.combine(strategy=[ + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu, + strategy_combinations.multi_worker_mirrored_4x1_cpu, + ])) + def testMultiWorkerMirrored(self, strategy): + # MultiWorkerMirroredStrategy combinations only supports V2. 
+ self.assertIsInstance( + strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy) + + @combinations.generate( + combinations.combine(strategy=[ + strategy_combinations.central_storage_strategy_with_gpu_and_cpu, + strategy_combinations.central_storage_strategy_with_two_gpus, + ])) + def testCentralStorage(self, strategy): + self.assertIsInstance(strategy, + central_storage_strategy.CentralStorageStrategyV1) + + @combinations.generate( + combinations.combine(strategy=strategy_combinations.tpu_strategies)) + def testTPU(self, strategy): + self.assertIsInstance(strategy, tpu_strategy.TPUStrategyV1) + + +class V2StrategyTest(test.TestCase, parameterized.TestCase): + + def setUp(self): + super().setUp() + tf2.enable() + + @combinations.generate( + combinations.combine(strategy=[ + strategy_combinations.one_device_strategy, + strategy_combinations.one_device_strategy_gpu, + strategy_combinations.one_device_strategy_gpu_on_worker_1, + strategy_combinations.one_device_strategy_on_worker_1 + ])) + def testOneDevice(self, strategy): + self.assertIsInstance(strategy, one_device_strategy.OneDeviceStrategy) + + @combinations.generate( + combinations.combine(strategy=[ + strategy_combinations.mirrored_strategy_with_cpu_1_and_2, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.mirrored_strategy_with_one_cpu, + strategy_combinations.mirrored_strategy_with_one_gpu, + strategy_combinations.mirrored_strategy_with_two_gpus, + ])) + def testMirrored(self, strategy): + self.assertIsInstance(strategy, mirrored_strategy.MirroredStrategy) + + @combinations.generate( + combinations.combine(strategy=[ + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu, + strategy_combinations.multi_worker_mirrored_4x1_cpu, + ])) + def testMultiWorkerMirrored(self, strategy): + self.assertIsInstance( + strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy) + + @combinations.generate( + combinations.combine(strategy=[ + strategy_combinations.central_storage_strategy_with_gpu_and_cpu, + strategy_combinations.central_storage_strategy_with_two_gpus, + ])) + def testCentralStorage(self, strategy): + self.assertIsInstance(strategy, + central_storage_strategy.CentralStorageStrategy) + + @combinations.generate( + combinations.combine(strategy=strategy_combinations.tpu_strategies)) + def testTPU(self, strategy): + self.assertIsInstance(strategy, tpu_strategy.TPUStrategy) + if __name__ == "__main__": - combinations.main() + test_util.main() diff --git a/tensorflow/python/distribute/strategy_common_test.py b/tensorflow/python/distribute/strategy_common_test.py index 5eeeb11fa8f..6b19a744457 100644 --- a/tensorflow/python/distribute/strategy_common_test.py +++ b/tensorflow/python/distribute/strategy_common_test.py @@ -20,7 +20,6 @@ from __future__ import print_function from absl.testing import parameterized -from tensorflow.python.compat import v2_compat from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import combinations from tensorflow.python.distribute import distribution_strategy_context as ds_context @@ -28,12 +27,12 @@ from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_test_lib +from tensorflow.python.distribute import test_util +from 
tensorflow.python.distribute import tpu_strategy from tensorflow.python.distribute.collective_all_reduce_strategy import CollectiveAllReduceStrategy -from tensorflow.python.distribute.tpu_strategy import TPUStrategy from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -51,27 +50,6 @@ from tensorflow.python.util import nest mode=['eager'])) class StrategyTest(test.TestCase, parameterized.TestCase): - def testSimpleReduce(self, strategy): - per_replica_value = strategy.experimental_distribute_values_from_function( - lambda _: array_ops.ones((), dtypes.float32)) - - def fn_eager(): - - return strategy.reduce( - reduce_util.ReduceOp.SUM, value=per_replica_value, axis=None) - - fn_graph = def_function.function(fn_eager) - # Run reduce under the strategy scope to explicitly enter - # strategy default_device scope. - with strategy.scope(): - self.assertEqual(fn_eager().numpy(), 1.0 * strategy.num_replicas_in_sync) - self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync) - - # Run reduce without a strategy scope to implicitly enter - # strategy default_device scope. - self.assertEqual(fn_eager().numpy(), 1.0 * strategy.num_replicas_in_sync) - self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync) - def testCaptureReplicaId(self, strategy): m = {} @@ -92,482 +70,45 @@ class StrategyTest(test.TestCase, parameterized.TestCase): @combinations.generate( combinations.combine( strategy=[ - strategy_combinations.multi_worker_mirrored_2x2_gpu, strategy_combinations.multi_worker_mirrored_2x1_cpu, strategy_combinations.multi_worker_mirrored_2x1_gpu, - ], - mode=['eager'], - pure_eager=[True, False])) -class GatherTest(test.TestCase, parameterized.TestCase): + ] + strategy_combinations.all_strategies, + mode=['eager'])) +class ReduceTest(test.TestCase, parameterized.TestCase): - def _gather_same_shape_and_verify(self, value_on_replica, axis, pure_eager, - strategy): - distributed_values = strategy.experimental_distribute_values_from_function( - lambda _: array_ops.identity(value_on_replica)) - - def run(): - return strategy._gather(distributed_values, axis=axis) - - if not pure_eager: - run = def_function.function(run) - - all_results = [ - value_on_replica for _ in range(strategy.num_replicas_in_sync) - ] - expected_result = array_ops.concat(all_results, axis=axis) - self.assertAllEqual(run().numpy(), expected_result) - - def testGatherPerReplicaDense1D0Axis(self, strategy, pure_eager): - """A DistributedValues object with two tensors of shape [3] on each replica gathers to a tensor of [6].""" - single_value = constant_op.constant([1, 2, 3]) - axis = 0 - self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy) - - def testGatherPerReplicaDense2D0Axis(self, strategy, pure_eager): - """A DistributedValues object with two tensors of [1, 3] on each replica gathers along 0th dim to a tensor of [2, 3].""" - single_value = constant_op.constant([[1, 2, 3]]) - axis = 0 - self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy) - - def testGatherPerReplicaDense2D1Axis(self, strategy, pure_eager): - """A DistributedValues object with two tensors of [1, 3] on each replica gathers along 1st dim to a tensor of [1, 6].""" - single_value = constant_op.constant([[1, 2, 3]]) - 
axis = 1 - self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy) - - def testGatherPerReplicaDense3D0Axis(self, strategy, pure_eager): - """A DistributedValues object with two tensors of [1, 2, 2] on each replica gathers along 0th dim to a tensor of [2, 2, 2].""" - single_value = constant_op.constant([[[1, 2], [1, 2]]]) - axis = 0 - self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy) - - def testGatherPerReplicaDense3D1Axis(self, strategy, pure_eager): - """A DistributedValues object with two tensors of [1, 2, 2] on each replica gathers along 1nd dimension to a tensor of [1, 4, 2].""" - single_value = constant_op.constant([[[1, 2], [1, 2]]]) - axis = 1 - self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy) - - def testGatherPerReplicaDense3D2Axis(self, strategy, pure_eager): - """A DistributedValues object with two tensors of [1, 2, 2] on each replica gathers along 2nd dimension to a tensor of [1, 2, 4].""" - single_value = constant_op.constant([[[1, 2], [1, 2]]]) - axis = 2 - self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy) - - def testGatherDiffShapeAtAxis0(self, strategy, pure_eager): - """Different `Axis`-th (0) dimension: shape [1, 1], [2, 1] -> [3, 1].""" - - def value_fn(ctx): - return constant_op.constant( - 1, shape=(ctx.replica_id_in_sync_group + 1, 1)) - - distributed_values = strategy.experimental_distribute_values_from_function( - value_fn) - axis = 0 - - def run(): - return strategy._gather(distributed_values, axis=axis) - - if not pure_eager: - run = def_function.function(run) - - if strategy.num_replicas_in_sync == 1: - expected_result = constant_op.constant(1, shape=(1, 1)) - elif strategy.num_replicas_in_sync == 2: - expected_result = constant_op.constant(1, shape=(3, 1)) - elif strategy.num_replicas_in_sync == 4: - expected_result = constant_op.constant(1, shape=(10, 1)) - else: - # should follow expected_result = constant_op.constant( - # 1, shape=(sum(range(strategy.num_replicas_in_sync + 1)), 1)) - raise ValueError('Add your own expect according to num_replicas_in sync') - - self.assertAllEqual(run().numpy(), expected_result) - - def testGatherDiffShapeAtAxis1(self, strategy, pure_eager): - """Different `Axis`-th (non-0) dimension: shape [1, 1], [1, 2] -> [1, 3].""" - - def value_fn(ctx): - return constant_op.constant( - 1, shape=(1, ctx.replica_id_in_sync_group + 1)) - - distributed_values = strategy.experimental_distribute_values_from_function( - value_fn) - axis = 1 - - def run(): - return strategy._gather(distributed_values, axis=axis) - - if not pure_eager: - run = def_function.function(run) - - if strategy.num_replicas_in_sync == 1: - expected_result = constant_op.constant(1, shape=(1, 1)) - elif strategy.num_replicas_in_sync == 2: - expected_result = constant_op.constant(1, shape=(1, 3)) - elif strategy.num_replicas_in_sync == 4: - expected_result = constant_op.constant(1, shape=(1, 10)) - else: - # should follow expected_result = constant_op.constant( - # 1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1)))) - raise ValueError('Add your own expect according to num_replicas_in sync') - - self.assertAllEqual(run().numpy(), expected_result) - - def testGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager): - """Different at non-`axis`-th dimension : [1, 1], [1, 2], 0th -> raise error.""" - if _get_num_devices_per_worker(strategy) > 1: - self.skipTest('b/167331966') - def value_fn(ctx): - return constant_op.constant( - 1, shape=(1, 
ctx.replica_id_in_sync_group + 1)) - - distributed_values = strategy.experimental_distribute_values_from_function( - value_fn) - axis = 0 - - def run(): - return strategy._gather(distributed_values, axis=axis) - - error_message = 'Shape mismatch' - if not pure_eager: - run = def_function.function(run) - - with self.assertRaisesRegex(errors.InvalidArgumentError, error_message): - run() - - def testGatherRaiseSparsePerReplicaMultiWorker(self, strategy, pure_eager): - if strategy.num_replicas_in_sync != 2: - self.skipTest('Test for two replicas.') - dense_shape = [5, 2] - if multi_worker_test_base.get_task_type() == 'chief': - t0 = _make_indexed_slices( - values=[[1., 2.]], indices=[2], dense_shape=dense_shape) - if multi_worker_test_base.get_task_type() == 'worker': - t0 = _make_indexed_slices( - values=[[3., 4.], [5., 6.]], indices=[1, 3], dense_shape=dense_shape) - - def run(value): - return strategy._gather(value, axis=0) - - with self.assertRaisesRegex( - NotImplementedError, - r'gather/all_gather does not support IndexedSlices'): - if pure_eager: - run(t0) - else: - def_function.function(run)(t0) - - def testGatherRaiseDifferentRank(self, strategy, pure_eager): - """Different rank: [1,], [1, 2] -> raise error.""" - if strategy.num_replicas_in_sync <= 1: - self.skipTest('Test for more than 1 replicas.') - if _get_num_devices_per_worker(strategy) > 1: - self.skipTest('b/167331966') - def value_fn(ctx): - return array_ops.ones(shape=(range(1, ctx.replica_id_in_sync_group + 2))) - - distributed_values = strategy.experimental_distribute_values_from_function( - value_fn) - axis = 0 - - def run(): - return strategy._gather(distributed_values, axis=axis) - - error_message = 'Shape mismatch' - - if not pure_eager: - run = def_function.function(run) - - with self.assertRaisesRegex(errors.InvalidArgumentError, error_message): - run() - - -@combinations.generate( - combinations.combine( - strategy=[ - strategy_combinations.multi_worker_mirrored_2x2_gpu, - strategy_combinations.multi_worker_mirrored_2x1_cpu, - strategy_combinations.multi_worker_mirrored_2x1_gpu, - ], - mode=['eager'], - pure_eager=[True, False])) -class AllGatherTest(test.TestCase, parameterized.TestCase): - - def _all_gather_same_shape_and_verify(self, value_on_replica, axis, - pure_eager, strategy): + def testBasic(self, strategy): per_replica_value = strategy.experimental_distribute_values_from_function( - lambda _: array_ops.identity(value_on_replica)) + lambda _: array_ops.ones((), dtypes.float32)) - def replica_fn(per_replica_value): - ctx = ds_context.get_replica_context() - local_value = array_ops.identity(per_replica_value) - return ctx._all_gather(local_value, axis=axis) + def fn_eager(): - if not pure_eager: - replica_fn = def_function.function(replica_fn) + return strategy.reduce( + reduce_util.ReduceOp.SUM, value=per_replica_value, axis=None) - result = strategy.experimental_local_results( - strategy.run(replica_fn, args=(per_replica_value,))) + fn_graph = def_function.function(fn_eager) + # Run reduce under the strategy scope to explicitly enter + # strategy default_device scope. 
+ with strategy.scope(): + self.assertEqual(fn_eager().numpy(), 1.0 * strategy.num_replicas_in_sync) + self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync) - all_value = [value_on_replica for _ in range(strategy.num_replicas_in_sync)] - expect = array_ops.concat(all_value, axis=axis) - expected_result = [expect] * _get_num_devices_per_worker(strategy) + # Run reduce without a strategy scope to implicitly enter + # strategy default_device scope. + self.assertEqual(fn_eager().numpy(), 1.0 * strategy.num_replicas_in_sync) + self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync) - self.assertAllClose(result, expected_result) + def testAxis(self, strategy): - def testAllGatherPerReplicaDense1D0Axis(self, strategy, pure_eager): - """all_gather(..., axis=0,...) a DistributedValues with a Tensor of shape (3,) on two replica returns a PerReplica of tensor(s) with shape (6,).""" - single_value = constant_op.constant([1, 2, 3], dtype=dtypes.float32) - axis = 0 - self._all_gather_same_shape_and_verify(single_value, axis, pure_eager, - strategy) + @def_function.function + def fn(): + return constant_op.constant([1., 2.]) - def testAllGatherPerReplicaDense2D0Axis(self, strategy, pure_eager): - """all_gather(..., axis=0,...) a DistributedValues with a Tensor of shape (1,3) on two replica returns PerReplica of tensor(s) with shape (2,3).""" - single_value = constant_op.constant([[1, 2, 3]]) - axis = 0 - self._all_gather_same_shape_and_verify(single_value, axis, pure_eager, - strategy) + x = strategy.run(fn) - def testAllGatherPerReplicaDense2D1Axis(self, strategy, pure_eager): - """all_gather(..., axis=1,...) a DistributedValues with a Tensor of shape (1,3) on two replica returns PerReplica of tensor(s) with shape (1,6).""" - single_value = constant_op.constant([[1, 2, 3]]) - axis = 1 - self._all_gather_same_shape_and_verify(single_value, axis, pure_eager, - strategy) - - def testAllGatherPerReplicaDense3D0Axis(self, strategy, pure_eager): - """all_gather(..., axis=0,...) a DistributedValues with a Tensor of shape (1,2,2) on two replica returns PerReplica of tensor(s) with shape (2,2,2).""" - single_value = constant_op.constant([[[1, 2], [1, 2]]]) - axis = 0 - self._all_gather_same_shape_and_verify(single_value, axis, pure_eager, - strategy) - - def testAllGatherPerReplicaDense3D1Axis(self, strategy, pure_eager): - """all_gather(..., axis=1,...) a DistributedValues with a Tensor of shape (1,2,2) on two replica returns PerReplica of tensor(s) with shape (1,4,2).""" - single_value = constant_op.constant([[[1, 2], [1, 2]]]) - axis = 1 - self._all_gather_same_shape_and_verify(single_value, axis, pure_eager, - strategy) - - def testAllGatherPerReplicaDense3D2Axis(self, strategy, pure_eager): - """all_gather(..., axis=2,...) 
a DistributedValues with a Tensor of shape (1,2,2) on two replica returns PerReplica of tensor(s) with shape (1,2,4).""" - single_value = constant_op.constant([[[1, 2], [1, 2]]]) - axis = 2 - self._all_gather_same_shape_and_verify(single_value, axis, pure_eager, - strategy) - - def testAllGatherDiffShapeAtAxis0(self, strategy, pure_eager): - """Different `Axis==0`-th dimension: shape [1, 1], [2, 1] -> [3, 1].""" - - def value_fn(ctx): - return constant_op.constant( - 1, shape=(ctx.replica_id_in_sync_group + 1, 1)) - - per_replica_value = strategy.experimental_distribute_values_from_function( - value_fn) - - if strategy.num_replicas_in_sync == 1: - expect = constant_op.constant(1, shape=(1, 1)) - elif strategy.num_replicas_in_sync == 2: - expect = constant_op.constant(1, shape=(3, 1)) - elif strategy.num_replicas_in_sync == 4: - expect = constant_op.constant(1, shape=(10, 1)) - else: - # should follow expect = constant_op.constant( - # 1, shape=(sum(range(strategy.num_replicas_in_sync + 1)), 1)) - raise ValueError('Add your own expect according to num_replicas_in sync') - - def run(value): - value_identity = array_ops.identity(value) - ctx = ds_context.get_replica_context() - return ctx._all_gather(value_identity, axis=0) - - if not pure_eager: - run = def_function.function(run) - - expected_result = [expect] * _get_num_devices_per_worker(strategy) - result = strategy.experimental_local_results( - strategy.run(run, args=(per_replica_value,))) - self.assertAllEqual(result, expected_result) - - def testAllGatherDiffShapeAtAxis1(self, strategy, pure_eager): - """Different `Axis`-th (not 0th) dimension: shape [1, 1], [1, 2] -> [1, 3].""" - - def value_fn(ctx): - return constant_op.constant( - 1, shape=(1, ctx.replica_id_in_sync_group + 1)) - - per_replica_value = strategy.experimental_distribute_values_from_function( - value_fn) - - if strategy.num_replicas_in_sync == 1: - expect = constant_op.constant(1, shape=(1, 1)) - elif strategy.num_replicas_in_sync == 2: - expect = constant_op.constant(1, shape=(1, 3)) - elif strategy.num_replicas_in_sync == 4: - expect = constant_op.constant(1, shape=(1, 10)) - else: - # should follow expect = constant_op.constant( - # 1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1)))) - raise ValueError('Add your own expect according to num_replicas_in sync') - - def run(value): - value_identity = array_ops.identity(value) - ctx = ds_context.get_replica_context() - return ctx._all_gather(value_identity, axis=1) - - if not pure_eager: - run = def_function.function(run) - - expected_result = [expect] * _get_num_devices_per_worker(strategy) - result = strategy.experimental_local_results( - strategy.run(run, args=(per_replica_value,))) - self.assertAllEqual(result, expected_result) - - def testAllGatherNest(self, strategy, pure_eager): - axis = 1 - - def value_fn(ctx): - value = constant_op.constant( - 1, shape=(1, ctx.replica_id_in_sync_group + 1)) - return value - per_replica_value = strategy.experimental_distribute_values_from_function( - value_fn) - - if strategy.num_replicas_in_sync == 1: - expect_1 = constant_op.constant(1, shape=(1, 1)) - elif strategy.num_replicas_in_sync == 2: - expect_1 = constant_op.constant(1, shape=(1, 3)) - elif strategy.num_replicas_in_sync == 4: - expect_1 = constant_op.constant(1, shape=(1, 10)) - else: - # should follow expect_1 = constant_op.constant( - # 1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1)))) - raise ValueError('Add your own expect according to num_replicas_in sync') - - expected_per_replica_1 = 
[expect_1] * _get_num_devices_per_worker(strategy) - - value_2 = constant_op.constant([[[1, 2], [1, 2]]]) - - if strategy.num_replicas_in_sync == 1: - expect_2 = constant_op.constant([[[1, 2], [1, 2]]]) - elif strategy.num_replicas_in_sync == 2: - expect_2 = constant_op.constant([[[1, 2], [1, 2], [1, 2], [1, 2]]]) - elif strategy.num_replicas_in_sync == 4: - expect_2 = constant_op.constant([[[1, 2], [1, 2], [1, 2], [1, 2], [1, 2], - [1, 2], [1, 2], [1, 2]]]) - else: - # should follow expect_2 = array_ops.concat( - # [value_2 for _ in range(strategy.num_replicas_in_sync)], axis=axis) - raise ValueError('Add your own expect according to num_replicas_in sync') - - expected_per_replica_2 = [expect_2] * _get_num_devices_per_worker(strategy) - - def run(value): - value_1 = array_ops.identity(value) - value_3 = array_ops.identity(value_2) - ctx = ds_context.get_replica_context() - return ctx._all_gather([value_1, value_3], axis=axis) - - if not pure_eager: - run = def_function.function(run) - - result = strategy.run(run, args=(per_replica_value,)) - self.assertAllEqual( - strategy.experimental_local_results(result[0]), expected_per_replica_1) - self.assertAllEqual( - strategy.experimental_local_results(result[1]), expected_per_replica_2) - - def testAllGatherNest1D0Axis(self, strategy, pure_eager): - """all_gather(..., axis=0,...) a nest of DistributedValues.""" - single_value = constant_op.constant([1, 2, 3]) - axis = 0 - - def run(): - value_identity = array_ops.identity(single_value) - ctx = ds_context.get_replica_context() - return ctx._all_gather([value_identity, value_identity], axis=axis) - - if not pure_eager: - run = def_function.function(run) - - all_value = [single_value for _ in range(strategy.num_replicas_in_sync)] - expect = array_ops.concat(all_value, axis=axis) - expected_per_replica = [expect] * _get_num_devices_per_worker(strategy) - - result = strategy.run(run) - for gathered_result in result: - self.assertAllEqual( - strategy.experimental_local_results(gathered_result), - expected_per_replica) - - def testAllGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager): - """Different at non-`axis`-th dimension : [2, 1], [1, 1], all_gather(...axis=1...) 
-> raise error.""" - if _get_num_devices_per_worker(strategy) > 1: - self.skipTest('b/167331966') - - def value_fn(ctx): - return constant_op.constant( - 1, shape=(1, ctx.replica_id_in_sync_group + 1)) - - per_replica_value = strategy.experimental_distribute_values_from_function( - value_fn) - - def run(value): - value_identity = array_ops.identity(value) - ctx = ds_context.get_replica_context() - return ctx._all_gather(value_identity, axis=0) - - if not pure_eager: - run = def_function.function(run) - - with self.assertRaisesRegex(errors.InvalidArgumentError, r'Shape mismatch'): - strategy.run(run, args=(per_replica_value,)) - - def testAllGatherRaiseSparsePerReplica(self, strategy, pure_eager): - # all_gather supports sparse when using tf.function, because sparse tensors - # are converted to dense in - # third_party/tensorflow/python/ops/custom_gradient.py _graph_mode_decorator - if strategy.num_replicas_in_sync != 2: - self.skipTest('Test for two replicas.') - dense_shape = [5, 2] - t0 = _make_indexed_slices( - values=[[1., 2.]], indices=[2], dense_shape=dense_shape) - - def replica_fn(value): - ctx = ds_context.get_replica_context() - return ctx._all_gather(value, axis=0) - - with self.assertRaisesRegex( - NotImplementedError, - r'gather/all_gather does not support IndexedSlices'): - strategy.run(replica_fn, args=(t0,)) - - def testAllGatherRaiseDifferentRank(self, strategy, pure_eager): - """Different rank: [1,], [1, 2] -> raise error.""" - if strategy.num_replicas_in_sync <= 1: - self.skipTest('Test for more than 1 replicas.') - if _get_num_devices_per_worker(strategy) > 1: - self.skipTest('b/167331966') - def value_fn(ctx): - return array_ops.ones(shape=(range(1, ctx.replica_id_in_sync_group + 2))) - - per_replica_value = strategy.experimental_distribute_values_from_function( - value_fn) - - def run(value): - value_identity = array_ops.identity(value) - ctx = ds_context.get_replica_context() - return ctx._all_gather(value_identity, axis=0) - - error_message = 'Shape mismatch' - - if not pure_eager: - run = def_function.function(run) - - with self.assertRaisesRegex(errors.InvalidArgumentError, error_message): - strategy.run(run, args=(per_replica_value,)) + x_m = strategy.reduce(reduce_util.ReduceOp.MEAN, x, axis=0) + self.assertEqual(1.5, x_m) + x_s = strategy.reduce(reduce_util.ReduceOp.SUM, x, axis=0) + self.assertEqual(3 * strategy.num_replicas_in_sync, x_s) def _make_indexed_slices(values, indices, dense_shape): @@ -578,10 +119,18 @@ def _make_indexed_slices(values, indices, dense_shape): return tensor -def _get_num_devices_per_worker(strategy): - """Returns the number of workers in the current cluster for multi-worker.""" - resolver = strategy.cluster_resolver - return max(nest.flatten(resolver.num_accelerators())[0], 1) +def _get_num_replicas_per_client(strategy): + if isinstance(strategy, CollectiveAllReduceStrategy): + resolver = strategy.cluster_resolver + return max(nest.flatten(resolver.num_accelerators())[0], 1) + else: + return strategy.num_replicas_in_sync + + +def _is_tpu_strategy(strategy): + return isinstance(strategy, + (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1, + tpu_strategy.TPUStrategyV2)) @combinations.generate( @@ -634,8 +183,9 @@ class DistributedCollectiveAllReduceStrategyTest( result = run(input_iterator) expected_data_on_workers = {'chief': [8, 9, 10], 'worker': [11, 12, 13]} self.assertAllEqual( + expected_data_on_workers[multi_worker_test_base.get_task_type()], result.numpy(), - 
expected_data_on_workers[multi_worker_test_base.get_task_type()]) + ) def testSimpleInputFromFnLastPartialBatch(self, strategy): @@ -661,8 +211,8 @@ class DistributedCollectiveAllReduceStrategyTest( expected_data_on_worker = {'chief': [8, 9, 10, 11], 'worker': [12, 13]} self.assertAllEqual( - result.numpy(), - expected_data_on_worker[multi_worker_test_base.get_task_type()]) + expected_data_on_worker[multi_worker_test_base.get_task_type()], + result.numpy()) def testReduceHostTensor(self, strategy): reduced = strategy.reduce( @@ -680,7 +230,7 @@ class DistributedCollectiveAllReduceStrategyTest( reduced = strategy.extended.batch_reduce_to(reduce_util.ReduceOp.SUM, [(value, value), (value, value)]) - self.assertAllEqual(reduced, [2., 2.]) + self.assertAllEqual([2., 2.], reduced) def testReduceDeviceTensors(self, strategy): value = strategy.run(lambda: array_ops.identity(1.)) @@ -698,7 +248,7 @@ class DistributedCollectiveAllReduceStrategyTest( reduced = strategy.extended.batch_reduce_to(reduce_util.ReduceOp.SUM, [(value, value), (value, value)]) - self.assertAllEqual(reduced, [2., 2.]) + self.assertAllEqual([2., 2.], reduced) # TODO(crccw): add a test that mixes device and host tensors after multi # worker strategy combinations can run on a fixed number of GPUs. @@ -716,7 +266,7 @@ class StrategyClusterResolverTest(test.TestCase, parameterized.TestCase): # `None` otherwise. resolver = strategy.cluster_resolver if not isinstance(strategy, CollectiveAllReduceStrategy) and not isinstance( - strategy, TPUStrategy): + strategy, tpu_strategy.TPUStrategy): self.assertIsNone(resolver) return @@ -734,5 +284,4 @@ class StrategyClusterResolverTest(test.TestCase, parameterized.TestCase): if __name__ == '__main__': - v2_compat.enable_v2_behavior() - combinations.main() + test_util.main() diff --git a/tensorflow/python/distribute/strategy_gather_test.py b/tensorflow/python/distribute/strategy_gather_test.py new file mode 100644 index 00000000000..7cefcf396db --- /dev/null +++ b/tensorflow/python/distribute/strategy_gather_test.py @@ -0,0 +1,600 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for common methods in strategy classes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import central_storage_strategy +from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import distribution_strategy_context as ds_context +from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute import test_util +from tensorflow.python.distribute import tpu_strategy +from tensorflow.python.distribute.collective_all_reduce_strategy import CollectiveAllReduceStrategy +from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test +from tensorflow.python.util import nest + + +@combinations.generate( + combinations.combine( + strategy=[ + strategy_combinations.default_strategy, + strategy_combinations.one_device_strategy, + strategy_combinations.one_device_strategy_gpu, + strategy_combinations.central_storage_strategy_with_two_gpus, + strategy_combinations.central_storage_strategy_with_gpu_and_cpu, + strategy_combinations.mirrored_strategy_with_one_cpu, + strategy_combinations.mirrored_strategy_with_one_gpu, + strategy_combinations.mirrored_strategy_with_two_gpus, + strategy_combinations.mirrored_strategy_with_cpu_1_and_2, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu, + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + ], + mode=['eager'], + pure_eager=[True, False]) + combinations.combine( + strategy=[ + strategy_combinations.tpu_strategy, + strategy_combinations.tpu_strategy_packed_var, + strategy_combinations.tpu_strategy_one_step, + strategy_combinations.cloud_tpu_strategy, + ], + mode=['eager'], + pure_eager=[False])) +class GatherTest(test.TestCase, parameterized.TestCase): + + def _gather_same_shape_and_verify(self, value_on_replica, axis, pure_eager, + strategy): + distributed_values = strategy.experimental_distribute_values_from_function( + lambda _: array_ops.identity(value_on_replica)) + + def run(): + return strategy._gather(distributed_values, axis=axis) + + if not pure_eager: + run = def_function.function(run) + + all_results = [ + value_on_replica for _ in range(strategy.num_replicas_in_sync) + ] + expected_result = array_ops.concat(all_results, axis=axis) + self.assertAllEqual(expected_result, run().numpy()) + + def testGatherPerReplicaDense1D0Axis(self, strategy, pure_eager): + """A DistributedValues object with two tensors of shape [3] on each replica gathers to a tensor of [6].""" + single_value = constant_op.constant([1, 2, 3]) + axis = 0 + self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy) + + def testGatherPerReplicaDense2D0Axis(self, strategy, pure_eager): + """A DistributedValues object with two tensors of [1, 3] on each replica gathers along 0th dim to a tensor of [2, 3].""" + single_value = constant_op.constant([[1, 2, 3]]) + axis = 0 + 
self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy) + + def testGatherPerReplicaDense2D1Axis(self, strategy, pure_eager): + """A DistributedValues object with two tensors of [1, 3] on each replica gathers along 1st dim to a tensor of [1, 6].""" + single_value = constant_op.constant([[1, 2, 3]]) + axis = 1 + self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy) + + def testGatherPerReplicaDense3D0Axis(self, strategy, pure_eager): + """A DistributedValues object with two tensors of [1, 2, 2] on each replica gathers along 0th dim to a tensor of [2, 2, 2].""" + single_value = constant_op.constant([[[1, 2], [1, 2]]]) + axis = 0 + self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy) + + def testGatherPerReplicaDense3D1Axis(self, strategy, pure_eager): + """A DistributedValues object with two tensors of [1, 2, 2] on each replica gathers along 1nd dimension to a tensor of [1, 4, 2].""" + single_value = constant_op.constant([[[1, 2], [1, 2]]]) + axis = 1 + self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy) + + def testGatherPerReplicaDense3D2Axis(self, strategy, pure_eager): + """A DistributedValues object with two tensors of [1, 2, 2] on each replica gathers along 2nd dimension to a tensor of [1, 2, 4].""" + single_value = constant_op.constant([[[1, 2], [1, 2]]]) + axis = 2 + self._gather_same_shape_and_verify(single_value, axis, pure_eager, strategy) + + def testGatherDiffShapeAtAxis0(self, strategy, pure_eager): + """Different `Axis`-th (0) dimension: shape [1, 1], [2, 1] -> [3, 1].""" + + def value_fn(ctx): + return constant_op.constant( + 1, shape=(ctx.replica_id_in_sync_group + 1, 1)) + + distributed_values = strategy.experimental_distribute_values_from_function( + value_fn) + axis = 0 + + def run(): + return strategy._gather(distributed_values, axis=axis) + + if not pure_eager: + run = def_function.function(run) + + expected_result = constant_op.constant( + 1, shape=(sum(range(strategy.num_replicas_in_sync + 1)), 1)) + + self.assertAllEqual(expected_result, run().numpy()) + + def testGatherDiffShapeAtAxis1(self, strategy, pure_eager): + """Different `Axis`-th (non-0) dimension: shape [1, 1], [1, 2] -> [1, 3].""" + + def value_fn(ctx): + return constant_op.constant( + 1, shape=(1, ctx.replica_id_in_sync_group + 1)) + + distributed_values = strategy.experimental_distribute_values_from_function( + value_fn) + axis = 1 + + def run(): + return strategy._gather(distributed_values, axis=axis) + + if not pure_eager: + run = def_function.function(run) + + expected_result = constant_op.constant( + 1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1)))) + + self.assertAllEqual(expected_result, run().numpy()) + + def testGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager): + """Different at non-`axis`-th dimension : [1, 1], [1, 2], 0th -> raise error.""" + if isinstance(strategy, CollectiveAllReduceStrategy + ) and _get_num_replicas_per_client(strategy) > 1: + self.skipTest('b/167331966') + + if strategy.num_replicas_in_sync <= 1: + self.skipTest('Test for more than 1 replica only.') + + def value_fn(ctx): + return constant_op.constant( + 1, shape=(1, ctx.replica_id_in_sync_group + 1)) + + distributed_values = strategy.experimental_distribute_values_from_function( + value_fn) + axis = 0 + + def run(): + return strategy._gather(distributed_values, axis=axis) + + if not pure_eager: + run = def_function.function(run) + + if isinstance(strategy, CollectiveAllReduceStrategy): + with 
self.assertRaisesRegex(errors.InvalidArgumentError, + r'Shape mismatch'): + run() + elif isinstance( + strategy, + (mirrored_strategy.MirroredStrategy, + central_storage_strategy.CentralStorageStrategy)) and pure_eager: + with self.assertRaisesRegex(errors.InvalidArgumentError, + r'Dimensions of inputs should match'): + run() + else: + with self.assertRaisesRegex(ValueError, + r'Dimension \d in both shapes must be equal'): + run() + + def testGatherRaiseSparse(self, strategy, pure_eager): + dense_shape = [5, 2] + t0 = _make_indexed_slices( + values=[[1., 2.]], indices=[2], dense_shape=dense_shape) + + def run(value): + return strategy._gather(value, axis=0) + + with self.assertRaisesRegex( + NotImplementedError, + r'gather/all_gather does not support IndexedSlices'): + if pure_eager: + run(t0) + else: + def_function.function(run)(t0) + + def testGatherRaiseDifferentRank(self, strategy, pure_eager): + """Different rank: [1,], [1, 2] -> raise error.""" + if strategy.num_replicas_in_sync <= 1: + self.skipTest('Test for more than 1 replicas.') + if isinstance(strategy, CollectiveAllReduceStrategy + ) and _get_num_replicas_per_client(strategy) > 1: + self.skipTest('b/167331966') + def value_fn(ctx): + return array_ops.ones(shape=(range(1, ctx.replica_id_in_sync_group + 2))) + + distributed_values = strategy.experimental_distribute_values_from_function( + value_fn) + axis = 0 + + def run(): + return strategy._gather(distributed_values, axis=axis) + + if not pure_eager: + run = def_function.function(run) + + if isinstance(strategy, CollectiveAllReduceStrategy): + with self.assertRaisesRegex(errors.InvalidArgumentError, + r'Shape mismatch'): + run() + elif isinstance( + strategy, + (mirrored_strategy.MirroredStrategy, + central_storage_strategy.CentralStorageStrategy)) and pure_eager: + with self.assertRaisesRegex(errors.InvalidArgumentError, + r'Ranks of all input tensors should match'): + run() + elif _is_tpu_strategy(strategy) and pure_eager: + with self.assertRaisesRegex(ValueError, + r'Dimension \d in both shapes must be equal'): + run() + else: + with self.assertRaisesRegex(ValueError, + r'Shape must be rank \d but is rank \d'): + run() + + # Ideally, here we should split them into another test class, AllGatherTest. + # But doing that makes two initialize_tpu_system() calls and one of them times + # out, on Kokoro. Integrating two into one avoids it. + def _all_gather_same_shape_and_verify(self, value_on_replica, axis, + pure_eager, strategy): + per_replica_value = strategy.experimental_distribute_values_from_function( + lambda _: array_ops.identity(value_on_replica)) + + def replica_fn(per_replica_value): + ctx = ds_context.get_replica_context() + local_value = array_ops.identity(per_replica_value) + return ctx._all_gather(local_value, axis=axis) + + if not pure_eager: + replica_fn = def_function.function(replica_fn) + + result = strategy.experimental_local_results( + strategy.run(replica_fn, args=(per_replica_value,))) + + all_value = [value_on_replica for _ in range(strategy.num_replicas_in_sync)] + expect = array_ops.concat(all_value, axis=axis) + expected_result = [expect] * _get_num_replicas_per_client(strategy) + + self.assertAllClose(expected_result, result) + + def testAllGatherPerReplicaDense1D0Axis(self, strategy, pure_eager): + """all_gather(..., axis=0,...) 
a DistributedValues with a Tensor of shape (3,) on two replica returns a PerReplica of tensor(s) with shape (6,).""" + single_value = constant_op.constant([1, 2, 3], dtype=dtypes.float32) + axis = 0 + self._all_gather_same_shape_and_verify(single_value, axis, pure_eager, + strategy) + + def testAllGatherPerReplicaDense2D0Axis(self, strategy, pure_eager): + """all_gather(..., axis=0,...) a DistributedValues with a Tensor of shape (1,3) on two replica returns PerReplica of tensor(s) with shape (2,3).""" + single_value = constant_op.constant([[1, 2, 3]]) + axis = 0 + self._all_gather_same_shape_and_verify(single_value, axis, pure_eager, + strategy) + + def testAllGatherPerReplicaDense2D1Axis(self, strategy, pure_eager): + """all_gather(..., axis=1,...) a DistributedValues with a Tensor of shape (1,3) on two replica returns PerReplica of tensor(s) with shape (1,6).""" + single_value = constant_op.constant([[1, 2, 3]]) + axis = 1 + self._all_gather_same_shape_and_verify(single_value, axis, pure_eager, + strategy) + + def testAllGatherPerReplicaDense3D0Axis(self, strategy, pure_eager): + """all_gather(..., axis=0,...) a DistributedValues with a Tensor of shape (1,2,2) on two replica returns PerReplica of tensor(s) with shape (2,2,2).""" + single_value = constant_op.constant([[[1, 2], [1, 2]]]) + axis = 0 + self._all_gather_same_shape_and_verify(single_value, axis, pure_eager, + strategy) + + def testAllGatherPerReplicaDense3D1Axis(self, strategy, pure_eager): + """all_gather(..., axis=1,...) a DistributedValues with a Tensor of shape (1,2,2) on two replica returns PerReplica of tensor(s) with shape (1,4,2).""" + single_value = constant_op.constant([[[1, 2], [1, 2]]]) + axis = 1 + self._all_gather_same_shape_and_verify(single_value, axis, pure_eager, + strategy) + + def testAllGatherPerReplicaDense3D2Axis(self, strategy, pure_eager): + """all_gather(..., axis=2,...) a DistributedValues with a Tensor of shape (1,2,2) on two replica returns PerReplica of tensor(s) with shape (1,2,4).""" + single_value = constant_op.constant([[[1, 2], [1, 2]]]) + axis = 2 + self._all_gather_same_shape_and_verify(single_value, axis, pure_eager, + strategy) + + def testAllGatherDiffValueTPU(self, strategy, pure_eager): + # Test for TPU only since it can't be tested via testAllGatherDiffShape* + if not _is_tpu_strategy(strategy): + self.skipTest('Test for TPU only. 
For other strategies case already' + ' covered in other tests') + + data = [[1], [2], [3], [4], [5], [6], [7], [8]] + + axis = 0 + dataset = dataset_ops.DatasetV2.from_tensor_slices(data).batch(8) + input_iterator = iter(strategy.experimental_distribute_dataset(dataset)) + + @def_function.function + def replica_fn(per_replica_value): + ctx = ds_context.get_replica_context() + return ctx._all_gather(array_ops.identity(per_replica_value), axis=axis) + + result = strategy.experimental_local_results( + strategy.run(replica_fn, args=(next(input_iterator),))) + + expected_result = [data] * _get_num_replicas_per_client(strategy) + self.assertAllClose(expected_result, result) + + def testAllGatherDiffShapeAtAxis0(self, strategy, pure_eager): + """Different `Axis==0`-th dimension: shape [1, 1], [2, 1] -> [3, 1].""" + + if _is_tpu_strategy(strategy): + self.skipTest('TPU does not support all_gather different shapes') + + def value_fn(ctx): + return constant_op.constant( + 1, shape=(ctx.replica_id_in_sync_group + 1, 1)) + + per_replica_value = strategy.experimental_distribute_values_from_function( + value_fn) + + expect = constant_op.constant( + 1, shape=(sum(range(strategy.num_replicas_in_sync + 1)), 1)) + + def run(value): + value_identity = array_ops.identity(value) + ctx = ds_context.get_replica_context() + return ctx._all_gather(value_identity, axis=0) + + if not pure_eager: + run = def_function.function(run) + + expected_result = [expect] * _get_num_replicas_per_client(strategy) + result = strategy.experimental_local_results( + strategy.run(run, args=(per_replica_value,))) + self.assertAllEqual(expected_result, result) + + def testAllGatherDiffShapeAtAxis1(self, strategy, pure_eager): + """Different `Axis`-th (not 0th) dimension: shape [1, 1], [1, 2] -> [1, 3].""" + if _is_tpu_strategy(strategy): + self.skipTest('TPU does not support all_gather different shapes') + + def value_fn(ctx): + return constant_op.constant( + 1, shape=(1, ctx.replica_id_in_sync_group + 1)) + + per_replica_value = strategy.experimental_distribute_values_from_function( + value_fn) + + expect = constant_op.constant( + 1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1)))) + + def run(value): + value_identity = array_ops.identity(value) + ctx = ds_context.get_replica_context() + return ctx._all_gather(value_identity, axis=1) + + if not pure_eager: + run = def_function.function(run) + + expected_result = [expect] * _get_num_replicas_per_client(strategy) + result = strategy.experimental_local_results( + strategy.run(run, args=(per_replica_value,))) + self.assertAllEqual(expected_result, result) + + def testAllGatherNest(self, strategy, pure_eager): + if _is_tpu_strategy(strategy): + self.skipTest('TPU does not support all_gather different shapes') + + axis = 1 + + def value_fn(ctx): + value = constant_op.constant( + 1, shape=(1, ctx.replica_id_in_sync_group + 1)) + return value + per_replica_value = strategy.experimental_distribute_values_from_function( + value_fn) + + expect_1 = constant_op.constant( + 1, shape=(1, sum(range(strategy.num_replicas_in_sync + 1)))) + + expected_per_replica_1 = [expect_1] * _get_num_replicas_per_client(strategy) + + value_2 = constant_op.constant([[[1, 2], [1, 2]]]) + + expect_2 = array_ops.concat( + [value_2 for _ in range(strategy.num_replicas_in_sync)], axis=axis) + + expected_per_replica_2 = [expect_2] * _get_num_replicas_per_client(strategy) + + def run(value): + value_1 = array_ops.identity(value) + value_3 = array_ops.identity(value_2) + ctx = ds_context.get_replica_context() + 
return ctx._all_gather([value_1, value_3], axis=axis) + + if not pure_eager: + run = def_function.function(run) + + result = strategy.run(run, args=(per_replica_value,)) + self.assertAllEqual(expected_per_replica_1, + strategy.experimental_local_results(result[0])) + self.assertAllEqual(expected_per_replica_2, + strategy.experimental_local_results(result[1])) + + def testAllGatherNest1D0Axis(self, strategy, pure_eager): + """all_gather(..., axis=0,...) a nest of DistributedValues.""" + single_value = constant_op.constant([1, 2, 3]) + axis = 0 + + def run(): + value_identity = array_ops.identity(single_value) + ctx = ds_context.get_replica_context() + return ctx._all_gather([value_identity, value_identity], axis=axis) + + if not pure_eager: + run = def_function.function(run) + + all_value = [single_value for _ in range(strategy.num_replicas_in_sync)] + expect = array_ops.concat(all_value, axis=axis) + expected_per_replica = [expect] * _get_num_replicas_per_client(strategy) + + result = strategy.run(run) + for gathered_result in result: + self.assertAllEqual(expected_per_replica, + strategy.experimental_local_results(gathered_result)) + + def testAllGatherRaiseDiffShapeAtNonAxis(self, strategy, pure_eager): + """Different at non-`axis`-th dimension : [2, 1], [1, 1], all_gather(...axis=1...) -> raise error.""" + if _is_tpu_strategy(strategy): + self.skipTest('TODO(b/169108777): raise a clear error message in xla.') + + if isinstance(strategy, CollectiveAllReduceStrategy + ) and _get_num_replicas_per_client(strategy) > 1: + self.skipTest('b/167331966') + + if strategy.num_replicas_in_sync <= 1: + self.skipTest('Test for more than 1 replica only.') + + def value_fn(ctx): + return constant_op.constant( + 1, shape=(1, ctx.replica_id_in_sync_group + 1)) + + per_replica_value = strategy.experimental_distribute_values_from_function( + value_fn) + + def run(value): + value_identity = array_ops.identity(value) + ctx = ds_context.get_replica_context() + return ctx._all_gather(value_identity, axis=0) + + if not pure_eager: + run = def_function.function(run) + + if isinstance(strategy, CollectiveAllReduceStrategy): + with self.assertRaisesRegex(errors.InvalidArgumentError, + r'Shape mismatch'): + strategy.run(run, args=(per_replica_value,)) + elif isinstance( + strategy, + (mirrored_strategy.MirroredStrategy, + central_storage_strategy.CentralStorageStrategy)) and pure_eager: + with self.assertRaisesRegex(errors.InvalidArgumentError, + r'Dimensions of inputs should match'): + strategy.run(run, args=(per_replica_value,)) + else: + with self.assertRaisesRegex(ValueError, + r'Dimension \d in both shapes must be equal'): + strategy.run(run, args=(per_replica_value,)) + + def testAllGatherRaiseSparse(self, strategy, pure_eager): + dense_shape = [5, 2] + t0 = _make_indexed_slices( + values=[[1., 2.]], indices=[2], dense_shape=dense_shape) + + def replica_fn(value): + ctx = ds_context.get_replica_context() + return ctx._all_gather(value, axis=0) + + with self.assertRaisesRegex( + NotImplementedError, + r'gather/all_gather does not support IndexedSlices'): + if not pure_eager: + strategy.run(def_function.function(replica_fn), args=(t0,)) + else: + strategy.run(replica_fn, args=(t0,)) + + def testAllGatherRaiseDifferentRank(self, strategy, pure_eager): + """Different rank: [1,], [1, 2] -> raise error.""" + if _is_tpu_strategy(strategy): + self.skipTest('TODO(b/169108777): raise a clear error message in xla.') + + if strategy.num_replicas_in_sync <= 1: + self.skipTest('Test for more than 1 replicas.') + if 
isinstance(strategy, CollectiveAllReduceStrategy + ) and _get_num_replicas_per_client(strategy) > 1: + self.skipTest('b/167331966') + def value_fn(ctx): + return array_ops.ones(shape=(range(1, ctx.replica_id_in_sync_group + 2))) + + per_replica_value = strategy.experimental_distribute_values_from_function( + value_fn) + + def run(value): + value_identity = array_ops.identity(value) + ctx = ds_context.get_replica_context() + return ctx._all_gather(value_identity, axis=0) + + if not pure_eager: + run = def_function.function(run) + + if isinstance(strategy, CollectiveAllReduceStrategy): + with self.assertRaisesRegex(errors.InvalidArgumentError, + r'Shape mismatch'): + strategy.run(run, args=(per_replica_value,)) + elif isinstance(strategy, + (mirrored_strategy.MirroredStrategy, + central_storage_strategy.CentralStorageStrategy)): + if pure_eager: + with self.assertRaisesRegex(errors.InvalidArgumentError, + r'Ranks of all input tensors should match'): + strategy.run(run, args=(per_replica_value,)) + else: + with self.assertRaisesRegex(ValueError, + r'Shape must be rank \d but is rank \d'): + strategy.run(run, args=(per_replica_value,)) + else: + with self.assertRaisesRegex(ValueError, + r'Dimension \d in both shapes must be equal'): + strategy.run(run, args=(per_replica_value,)) + + +def _make_indexed_slices(values, indices, dense_shape): + tensor = ops.IndexedSlices( + values=constant_op.constant(values), + indices=constant_op.constant(indices), + dense_shape=constant_op.constant(dense_shape)) + return tensor + + +def _get_num_replicas_per_client(strategy): + if isinstance(strategy, CollectiveAllReduceStrategy): + resolver = strategy.cluster_resolver + return max(nest.flatten(resolver.num_accelerators())[0], 1) + else: + return strategy.num_replicas_in_sync + + +def _is_tpu_strategy(strategy): + return isinstance(strategy, + (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1, + tpu_strategy.TPUStrategyV2)) + + +if __name__ == '__main__': + test_util.main() diff --git a/tensorflow/python/distribute/strategy_reduce_test.py b/tensorflow/python/distribute/strategy_reduce_test.py deleted file mode 100644 index a87cce2f0b8..00000000000 --- a/tensorflow/python/distribute/strategy_reduce_test.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
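The gather and all_gather tests above all reduce to one shape rule: the per-replica components are concatenated along `axis`, so they may differ only in that dimension, and a mismatch in any other dimension or in rank is expected to raise. A minimal, framework-free sketch of that rule (the replica values and error strings below are illustrative, not the strategy implementation):

import numpy as np

def gather_like(per_replica_values, axis):
  """Concatenate per-replica arrays along `axis`, checking every other dim agrees."""
  first = per_replica_values[0]
  for value in per_replica_values[1:]:
    if value.ndim != first.ndim:
      raise ValueError("ranks of all inputs should match")
    for dim in range(first.ndim):
      if dim != axis and value.shape[dim] != first.shape[dim]:
        raise ValueError("dimension %d must be equal in both shapes" % dim)
  return np.concatenate(per_replica_values, axis=axis)

# Same shape on two replicas: two (1, 3) blocks gathered along axis=1 -> (1, 6).
print(gather_like([np.ones((1, 3)), np.ones((1, 3))], axis=1).shape)
# Shapes may differ only along `axis`: (1, 1) and (2, 1) at axis=0 -> (3, 1).
print(gather_like([np.ones((1, 1)), np.ones((2, 1))], axis=0).shape)
# all_gather follows the same rule; every replica just receives the same result.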
-# ============================================================================== -"""Tests for `strategy.reduce`.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl.testing import parameterized - -from tensorflow.python.distribute import combinations -from tensorflow.python.distribute import reduce_util -from tensorflow.python.distribute import strategy_combinations -from tensorflow.python.eager import def_function -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op - - -class StrategyReduceTest(test.TestCase, parameterized.TestCase): - - @combinations.generate( - combinations.combine( - distribution=strategy_combinations.all_strategies, - mode=["eager"] - )) - def test_reduce_with_axis(self, distribution): - - @def_function.function - def fn(): - return constant_op.constant([1., 2.]) - x = distribution.run(fn) - - x_m = distribution.reduce(reduce_util.ReduceOp.MEAN, x, axis=0) - self.assertEqual(1.5, self.evaluate(x_m)) - x_s = distribution.reduce(reduce_util.ReduceOp.SUM, x, axis=0) - self.assertEqual(3 * distribution.num_replicas_in_sync, self.evaluate(x_s)) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/python/distribute/strategy_test_lib.py b/tensorflow/python/distribute/strategy_test_lib.py index f660e3ab9f8..cbdcf51b9dd 100644 --- a/tensorflow/python/distribute/strategy_test_lib.py +++ b/tensorflow/python/distribute/strategy_test_lib.py @@ -28,6 +28,7 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.core.util import event_pb2 from tensorflow.python.client import session as session_lib from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribute_utils from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import reduce_util @@ -429,6 +430,8 @@ class DistributionTestBase(test.TestCase): self.assertEqual((1,) * len(global_step_tensors), global_step_values) def _test_numpy_dataset(self, strategy, session=None, run_in_function=False): + if not isinstance(strategy, distribute_lib.StrategyV1): + self.skipTest("n/a: V1 only") cached_session = session or self.cached_session() with strategy.scope(), cached_session as sess: x = np.asarray([[1, 2], [6, 12], [2, 4], [5, 10], [3, 6], [4, 8]]) diff --git a/tensorflow/python/distribute/test_util.py b/tensorflow/python/distribute/test_util.py index 0f5dbd55491..82867edb4c2 100644 --- a/tensorflow/python/distribute/test_util.py +++ b/tensorflow/python/distribute/test_util.py @@ -20,11 +20,12 @@ from __future__ import print_function import functools +from tensorflow.python.compat import v2_compat from tensorflow.python.distribute import collective_all_reduce_strategy -from tensorflow.python.distribute import cross_device_utils -from tensorflow.python.distribute import distribute_utils +from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import values -from tensorflow.python.eager import def_function +from tensorflow.python.eager import context +from tensorflow.python.framework import config from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.util import nest @@ -57,16 +58,40 @@ def _gather(strategy, value): return array_ops.stack(value._values) assert len(strategy.extended.worker_devices) == len(value._values) inputs = 
[array_ops.expand_dims_v2(v, axis=0) for v in value._values] - collective_keys = strategy.extended._collective_keys - devices = strategy.extended.worker_devices - group_size = strategy.num_replicas_in_sync - - @def_function.function - def gather_fn(): - gathered = cross_device_utils.build_collective_gather( - inputs, devices, group_size, collective_keys, axis=0) - return distribute_utils.update_regroup( - strategy.extended, gathered, group=True) - - return gather_fn() + return strategy._gather(values.PerReplica(inputs), axis=0) # pylint: enable=protected-access + + +def set_logical_devices_to_at_least(device, num): + """Create logical devices of at least a given number.""" + if num < 1: + raise ValueError("`num` must be at least 1 not %r" % (num,)) + physical_devices = config.list_physical_devices(device) + if not physical_devices: + raise RuntimeError("No {} found".format(device)) + if len(physical_devices) >= num: + return + # By default each physical device corresponds to one logical device. We create + # multiple logical devices for the last physical device so that we have `num` + # logical devices. + num = num - len(physical_devices) + 1 + logical_devices = [] + for _ in range(num): + if device.upper() == "GPU": + logical_devices.append( + context.LogicalDeviceConfiguration(memory_limit=2048)) + else: + logical_devices.append(context.LogicalDeviceConfiguration()) + # Create logical devices from the the last device since sometimes the first + # GPU is the primary graphic card and may has less memory available. + config.set_logical_device_configuration(physical_devices[-1], logical_devices) + + +def main(enable_v2_behavior=True): + """All-in-one main function for tf.distribute tests.""" + if enable_v2_behavior: + v2_compat.enable_v2_behavior() + else: + v2_compat.disable_v2_behavior() + # TODO(b/131360402): configure default logical devices. + multi_process_runner.test_main() diff --git a/tensorflow/python/distribute/test_util_test.py b/tensorflow/python/distribute/test_util_test.py index 7dab2e199b1..165f97be6e2 100644 --- a/tensorflow/python/distribute/test_util_test.py +++ b/tensorflow/python/distribute/test_util_test.py @@ -1,13 +1,13 @@ # Copyright 2020 The TensorFlow Authors. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); +# Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, +# distributed under the License is distributed on an 'AS IS' BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. @@ -23,8 +23,10 @@ from absl.testing import parameterized from tensorflow.python.distribute import combinations from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import test_util +from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import test +from tensorflow.python.framework import config from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -71,5 +73,14 @@ class GatherTest(test.TestCase, parameterized.TestCase): self.evaluate(results['bar'][1]), [1.] 
* strategy.num_replicas_in_sync) +class LogicalDevicesTest(test.TestCase): + + def testLogicalCPUs(self): + context._reset_context() + test_util.set_logical_devices_to_at_least('CPU', 3) + cpu_device = config.list_physical_devices('CPU')[0] + self.assertLen(config.get_logical_device_configuration(cpu_device), 3) + + if __name__ == '__main__': - combinations.main() + test_util.main() diff --git a/tensorflow/python/distribute/tf_function_test.py b/tensorflow/python/distribute/tf_function_test.py index 967abebdfb3..38d070c2788 100644 --- a/tensorflow/python/distribute/tf_function_test.py +++ b/tensorflow/python/distribute/tf_function_test.py @@ -145,7 +145,7 @@ class TFFunctionTest(test.TestCase, parameterized.TestCase): @combinations.generate( combinations.combine( distribution=strategy_combinations.all_strategies, mode=["eager"])) - def testRetraceOnSaving(self, distribution): + def testRetraceOnSavingFirstTraceInScope(self, distribution): with distribution.scope(): v = variables.Variable(0.) @@ -167,6 +167,31 @@ class TFFunctionTest(test.TestCase, parameterized.TestCase): func() self.assertEqual(prev_tracing_count, tracing_count[0]) + @combinations.generate( + combinations.combine( + distribution=strategy_combinations.all_strategies, mode=["eager"])) + def testRetraceOnSavingFirstTraceOutsideScope(self, distribution): + with distribution.scope(): + v = variables.Variable(0.) + + tracing_count = [0] + + @def_function.function + def func(): + tracing_count[0] += 1 + return v + 1. + + func() + prev_tracing_count = tracing_count[0] + with save_context.save_context(save_options.SaveOptions()): + func() + self.assertEqual(prev_tracing_count + 1, tracing_count[0]) + + prev_tracing_count = tracing_count[0] + with save_context.save_context(save_options.SaveOptions()): + func() + self.assertEqual(prev_tracing_count, tracing_count[0]) + if __name__ == "__main__": v2_compat.enable_v2_behavior() diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index 1a3d49a2032..719147db5ee 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -23,8 +23,8 @@ import collections import contextlib import copy import weakref -from absl import logging +from absl import logging import numpy as np from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding @@ -52,6 +52,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops.ragged import ragged_tensor @@ -739,7 +740,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): atexit.register(async_wait) # Flag to turn on VariablePolicy - self._use_var_policy = False + self._use_var_policy = True def _validate_colocate_with_variable(self, colocate_with_variable): distribute_utils. 
validate_colocate(colocate_with_variable, self) @@ -752,7 +753,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): dataset, input_workers, self._container_strategy(), - split_batch_by=self._num_replicas_in_sync) + num_replicas_in_sync=self._num_replicas_in_sync) def _make_input_fn_iterator( self, @@ -809,7 +810,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): dataset, self._get_input_workers(options), self._container_strategy(), - split_batch_by=self._num_replicas_in_sync) + num_replicas_in_sync=self._num_replicas_in_sync) def _distribute_datasets_from_function(self, dataset_fn, options): input_workers = self._get_input_workers(options) @@ -1021,6 +1022,54 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): distribute_utils.TPU_VARIABLE_CLASS_MAPPING, distribute_utils.TPU_VARIABLE_POLICY_MAPPING, **kwargs) + def _gather_to_implementation(self, value, destinations, axis, + experimental_hints): + if not isinstance(value, values.DistributedValues): + return value + + value_list = value.values + # pylint: disable=protected-access + if isinstance( + value, + values.DistributedVariable) and value._packed_variable is not None: + value_list = tuple( + value._packed_variable.on_device(d) + for d in value._packed_variable.devices) + # pylint: enable=protected-access + + # Currently XLA op by op mode has a limit for the number of inputs for a + # single op, thus we break one `add_n` op into a group of `add_n` ops to + # work around the constraint. + if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT: + output = array_ops.concat(value_list, axis=axis) + else: + output = array_ops.concat( + value_list[:_XLA_OP_BY_OP_INPUTS_LIMIT], axis=axis) + for i in range(_XLA_OP_BY_OP_INPUTS_LIMIT, len(value_list), + _XLA_OP_BY_OP_INPUTS_LIMIT - 1): + output = array_ops.concat( + [output] + value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT - 1], + axis=axis) + + output = self._broadcast_output(destinations, output) + return output + + def _broadcast_output(self, destinations, output): + devices = cross_device_ops_lib.get_devices_from(destinations) + + if len(devices) == 1: + # If necessary, copy to requested destination. + dest_canonical = device_util.canonicalize(devices[0]) + host_canonical = device_util.canonicalize(self._host_device) + + if dest_canonical != host_canonical: + with ops.device(dest_canonical): + output = array_ops.identity(output) + else: + output = cross_device_ops_lib.simple_broadcast(output, destinations) + + return output + def _reduce_to(self, reduce_op, value, destinations, experimental_hints): if (isinstance(value, values.DistributedValues) or tensor_util.is_tensor(value) @@ -1065,19 +1114,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): if reduce_op == reduce_util.ReduceOp.MEAN: output *= (1. / len(value_list)) - devices = cross_device_ops_lib.get_devices_from(destinations) - - if len(devices) == 1: - # If necessary, copy to requested destination. - dest_canonical = device_util.canonicalize(devices[0]) - host_canonical = device_util.canonicalize(self._host_device) - - if dest_canonical != host_canonical: - with ops.device(dest_canonical): - output = array_ops.identity(output) - else: - output = cross_device_ops_lib.simple_broadcast(output, destinations) - + output = self._broadcast_output(destinations, output) return output def _update(self, var, fn, args, kwargs, group): @@ -1345,6 +1382,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): # TPUStrategy. 
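The `_gather_to_implementation` added to `TPUExtended` above concatenates the per-replica components in groups because XLA's op-by-op mode caps how many inputs a single op may take. A rough, framework-free sketch of that grouping, with a small illustrative limit standing in for `_XLA_OP_BY_OP_INPUTS_LIMIT`:

def chunked_concat(values, limit, concat):
  # The first group uses `limit` inputs; every later group re-feeds the running
  # result plus `limit - 1` fresh inputs, so no single call exceeds `limit`.
  if len(values) <= limit:
    return concat(values)
  output = concat(values[:limit])
  for i in range(limit, len(values), limit - 1):
    output = concat([output] + values[i:i + limit - 1])
  return output

# Plain lists stand in for tensors; list concatenation stands in for tf.concat.
print(chunked_concat([[n] for n in range(10)], limit=4,
                     concat=lambda xs: sum(xs, [])))  # [0, 1, ..., 9]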
return False + def _get_local_replica_id(self, replica_id_in_sync_group): + return replica_id_in_sync_group + class _TPUReplicaContext(distribute_lib.ReplicaContext): """Replication Context class for TPU Strategy.""" @@ -1371,6 +1411,70 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext): """Places variables and ops on the specified logical device.""" return self.strategy.extended.experimental_logical_device(logical_device_id) + # TODO(wxinyi): Investigate whether to use cross_replica_sum to optimize it. + def _all_gather(self, value, axis, experimental_hints=None): + del experimental_hints + for v in nest.flatten(value): + if isinstance(v, ops.IndexedSlices): + raise NotImplementedError("gather/all_gather does not support " + "IndexedSlices") + + def _all_to_all(value, axis): + # The underlying AllToAllOp first do a split of the input value and then + # cross-replica communication and concatenation of the result. So we + # concatenate the local tensor here first. + inputs = array_ops.concat( + [value for _ in range(self.num_replicas_in_sync)], axis=0) + unordered_output = tpu_ops.all_to_all( + inputs, + concat_dimension=axis, + split_dimension=0, + split_count=self.num_replicas_in_sync) + + # Re-order since xla.replica_id and ReplicaContext.replica_id mismatch. + # xla_id = xla.replica_id() + concat_replica_id = array_ops.concat([ + array_ops.expand_dims_v2(self.replica_id_in_sync_group, 0) + for _ in range(self.num_replicas_in_sync) + ], + axis=0) + replica_ids = tpu_ops.all_to_all( + concat_replica_id, + concat_dimension=0, + split_dimension=0, + split_count=self.num_replicas_in_sync) + + splited_unordered = array_ops.split( + unordered_output, + num_or_size_splits=self.num_replicas_in_sync, + axis=axis) + sorted_with_extra_dim = math_ops.unsorted_segment_sum( + array_ops.concat([ + array_ops.expand_dims(replica, axis=0) + for replica in splited_unordered + ], + axis=0), + replica_ids, + num_segments=self.num_replicas_in_sync) + + splited_with_extra_dim = array_ops.split( + sorted_with_extra_dim, + num_or_size_splits=self.num_replicas_in_sync, + axis=0) + squeezed = [ + array_ops.squeeze(replica, axis=0) + for replica in splited_with_extra_dim + ] + result = array_ops.concat(squeezed, axis=axis) + return result + + @custom_gradient.custom_gradient + def grad_wrapper(*xs): + ys = [_all_to_all(t, axis=axis) for t in xs] + return ys, lambda *dy_s: self._all_gather(dy_s, axis) + + return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) + def _set_last_step_outputs(ctx, last_step_tensor_outputs): """Sets the last step outputs on the given context.""" diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py index 70f05a3bb9c..e0afda84359 100644 --- a/tensorflow/python/distribute/tpu_strategy_test.py +++ b/tensorflow/python/distribute/tpu_strategy_test.py @@ -132,6 +132,28 @@ class TPUTest(test.TestCase): result = bar() + 1 self.assertAllEqual(result, 2) + def test_tpu_output_device(self): + + def foo(): + return 1 + 1 + + func1 = function.defun_with_attributes( + foo, attributes={"_XlaMustCompile": False}) + func2 = function.defun_with_attributes( + foo, attributes={ + "_OutputsOnOpDevice": True, + "_XlaMustCompile": False + }) + + with ops.device("/device:TPU:0"): + ret1 = func1() + ret2 = func2() + + self.assertAllEqual(ret1.backing_device, + "/job:localhost/replica:0/task:0/device:CPU:0") + self.assertAllEqual(ret2.backing_device, + "/job:localhost/replica:0/task:0/device:TPU:0") + def 
test_on_demand_op_with_dynamic_output(self): with ops.device("/device:TPU:0"): where_output = array_ops.where([True, False, True]) diff --git a/tensorflow/python/distribute/v1/BUILD b/tensorflow/python/distribute/v1/BUILD index ea973cce272..3c45d9d441e 100644 --- a/tensorflow/python/distribute/v1/BUILD +++ b/tensorflow/python/distribute/v1/BUILD @@ -16,7 +16,6 @@ cuda_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:collective_ops", "//tensorflow/python:constant_op", - "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:kernels", "//tensorflow/python:math_ops", @@ -32,11 +31,10 @@ cuda_py_test( "//tensorflow/python/distribute:multi_worker_util", "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/distribute:test_util", "//tensorflow/python/distribute:values", "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", "//tensorflow/python/eager:context", - "//tensorflow/python/eager:def_function", - "//tensorflow/python/eager:remote", "//tensorflow/python/eager:test", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/python/distribute/v1/cross_device_ops_test.py b/tensorflow/python/distribute/v1/cross_device_ops_test.py index 6026b6135bc..e54e1878d6d 100644 --- a/tensorflow/python/distribute/v1/cross_device_ops_test.py +++ b/tensorflow/python/distribute/v1/cross_device_ops_test.py @@ -27,7 +27,7 @@ from absl.testing import parameterized import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.python.distribute import cluster_resolver -from tensorflow.python.distribute import collective_all_reduce_strategy +from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib from tensorflow.python.distribute import collective_util from tensorflow.python.distribute import combinations from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib @@ -40,11 +40,8 @@ from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import values as value_lib from tensorflow.python.eager import context -from tensorflow.python.eager import def_function -from tensorflow.python.eager import remote from tensorflow.python.eager import test from tensorflow.python.framework import constant_op -from tensorflow.python.framework import errors from tensorflow.python.framework import kernels from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -281,9 +278,6 @@ class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): - # TODO(b/159829194): move eager tests in this test class to - # tensorflow/python/distribute/cross_device_ops_test.py - reduction_to_one_combinations = combinations.combine( cross_device_ops=[ combinations.NamedObject("DefaultReductionToOneDevice", @@ -334,13 +328,13 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): def testChooseAlgorithm(self): # Not use nccl if there is any cpu device. self.assertIsInstance( - cross_device_ops_lib.choose_the_best(["/cpu:0"]), + cross_device_ops_lib.select_cross_device_ops(["/cpu:0"]), cross_device_ops_lib.ReductionToOneDevice) # Not use nccl if requested device is not visible to TensorFlow. 
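The TPU `_all_gather` in the tpu_strategy.py hunk above builds its result with `tpu_ops.all_to_all` and then reorders the per-replica blocks, because the XLA replica numbering need not match `replica_id_in_sync_group`. The real code performs that reordering with `unsorted_segment_sum`; the NumPy helper below only illustrates the intent:

import numpy as np

def reorder_blocks(blocks, source_replica_ids, axis):
  # blocks[j] was produced by logical replica source_replica_ids[j]; place each
  # block at its logical position before concatenating along `axis`.
  ordered = [None] * len(blocks)
  for block, replica_id in zip(blocks, source_replica_ids):
    ordered[replica_id] = block
  return np.concatenate(ordered, axis=axis)

# Blocks arrive in device order (here replica 1 first), but the gathered result
# must be ordered by logical replica id.
print(reorder_blocks([np.array([[10]]), np.array([[0]])],
                     source_replica_ids=[1, 0], axis=0))  # [[0], [10]]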
- # TODO(yuefengz): make `choose_the_best` work with device strings + # TODO(yuefengz): make `select_cross_device_ops` work with device strings # self.assertIsInstance( - # cross_device_ops_lib.choose_the_best(["/gpu:100"]), + # cross_device_ops_lib.select_cross_device_ops(["/gpu:100"]), # cross_device_ops_lib.ReductionToOneDevice) if context.num_gpus() < 1: @@ -358,14 +352,14 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): with test.mock.patch.object(kernels, "get_registered_kernels_for_op", mock_get_registered_kernels_for_op): self.assertIsInstance( - cross_device_ops_lib.choose_the_best(devices), + cross_device_ops_lib.select_cross_device_ops(devices), cross_device_ops_lib.NcclAllReduce) # Not use nccl if nccl kernel is not found. with test.mock.patch.object(kernels, "get_registered_kernels_for_op", lambda _: []): self.assertIsInstance( - cross_device_ops_lib.choose_the_best(devices), + cross_device_ops_lib.select_cross_device_ops(devices), cross_device_ops_lib.ReductionToOneDevice) @combinations.generate(combinations.combine( @@ -444,9 +438,6 @@ CollectiveCommunication = cross_device_ops_lib.CollectiveCommunication class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase, CrossDeviceOpsTestBase): - # TODO(b/159829194): move eager tests in this test class to - # tensorflow/python/distribute/cross_device_ops_test.py - collective_key_base = 100000 @classmethod @@ -460,6 +451,8 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase, # Reusing keys is not supported well. So we have to give a different # collective key base for different tests. CollectiveAllReduceTest.collective_key_base += 100000 + mwms_lib.CollectiveAllReduceStrategy._collective_key_base = ( + CollectiveAllReduceTest.collective_key_base) def _get_test_objects(self, task_type, @@ -469,10 +462,7 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase, use_strategy_object=False, local_mode=False): collective_keys = cross_device_utils.CollectiveKeys( - group_key_start=10 + CollectiveAllReduceTest.collective_key_base, - op_instance_key_start=100 + CollectiveAllReduceTest.collective_key_base, - variable_instance_key_start=10000 + - CollectiveAllReduceTest.collective_key_base) + group_key_start=10 + CollectiveAllReduceTest.collective_key_base) if local_mode: if num_gpus: devices = ["/device:GPU:%d" % i for i in range(num_gpus)] @@ -480,13 +470,8 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase, devices = ["/device:CPU:0"] if use_strategy_object: - strategy = ( - collective_all_reduce_strategy.CollectiveAllReduceStrategy - ._from_local_devices(devices, communication=communication)) # pylint: disable=protected-access - strategy.extended._collective_keys = collective_keys - strategy.extended._cross_device_ops._collective_keys = collective_keys - strategy.extended._host_cross_device_ops._collective_keys = ( - collective_keys) + strategy = (mwms_lib.CollectiveAllReduceStrategy + ._from_local_devices(devices, communication=communication)) # pylint: disable=protected-access return strategy, devices, "" else: collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( @@ -516,10 +501,8 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase, task_type=task_type, task_id=task_id, num_accelerators={"GPU": num_gpus}) - strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + strategy = mwms_lib.CollectiveAllReduceStrategy( cluster_resolver=resolver, communication=communication) - 
strategy.extended._collective_keys = collective_keys - strategy.extended._cross_device_ops._collective_keys = collective_keys return (strategy, devices, "grpc://" + self._cluster_spec[task_type][task_id]) else: @@ -862,176 +845,8 @@ class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase, self.assertAllEqual(reduced[1].values, [4.0, 4.0]) t.join() - @combinations.generate( - combinations.combine( - required_gpus=2, - mode="eager", - communication=[ - CollectiveCommunication.NCCL, CollectiveCommunication.RING - ])) - def testInputsAreFunctionArgs(self, communication): - # Function inputs don't have device placement. - hints = collective_util.Hints(bytes_per_pack=1) - collective, devices, _ = self._get_test_objects( - None, - None, - num_gpus=2, - communication=communication, - use_strategy_object=False, - local_mode=True) - devices = [device_util.canonicalize(d) for d in devices] - - @def_function.function - def reduce_fn(v): - self.assertEqual(v.values[0].device, "") - self.assertEqual(v.values[1].device, "") - # We only use NCCL for batch reduce with two or more values, so we use two - # values here. - reduced = collective.batch_reduce( - reduce_util.ReduceOp.SUM, [(v, v), (v, v)], experimental_hints=hints) - self.assertEqual(reduced[0].values[0].device, devices[0]) - self.assertEqual(reduced[0].values[1].device, devices[1]) - self.assertEqual(reduced[1].values[0].device, devices[0]) - self.assertEqual(reduced[1].values[1].device, devices[1]) - # Returning Mirrored only evaluates the primary value, which causes - # hanging, - return [reduced[0].values, reduced[1].values] - - v = _make_per_replica([1.0, 2.0], devices) - reduced = reduce_fn(v) - self.assertAllEqual(self.evaluate(reduced), [[3.0, 3.0], [3.0, 3.0]]) - - @combinations.generate( - combinations.combine( - required_gpus=[0, 1], - mode="eager", - communication=[CollectiveCommunication.RING])) - def testTimeoutReduceDense(self, communication, required_gpus): - hints = collective_util.Hints(timeout_seconds=1) - collective, devices, _ = self._get_test_objects( - "worker", - 0, - num_gpus=required_gpus, - communication=communication, - use_strategy_object=False) - remote.connect_to_cluster( - multi_worker_util.normalize_cluster_spec(self._cluster_spec), - protocol="grpc") - devices = [device_util.canonicalize(d) for d in devices] - v = _make_per_replica([1.0], devices) - - @def_function.function - def reduce_dense(): - collective.reduce(reduce_util.ReduceOp.SUM, v, v, hints) - - # The collective should time out because we only launch it on worker-0, - # while there're three workers in total. - with self.assertRaises(errors.DeadlineExceededError): - reduce_dense() - - # Reset since collective failures poison the context. 
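The `setUp` change in `CollectiveAllReduceTest` above bumps `collective_key_base` for every test and feeds it into both `CollectiveKeys(group_key_start=...)` and `CollectiveAllReduceStrategy._collective_key_base`, since reusing collective keys across tests is not supported well. A toy allocator sketching the idea of handing each test a disjoint key range (the numbers are illustrative):

class KeyBaseAllocator(object):
  """Hands out non-overlapping key bases, one block per test."""

  def __init__(self, block_size=100000):
    self._block_size = block_size
    self._next_base = 0

  def next_base(self):
    base = self._next_base
    self._next_base += self._block_size
    return base

allocator = KeyBaseAllocator()
# Each test would then derive its keys as e.g. group_key_start = 10 + base.
print([10 + allocator.next_base() for _ in range(3)])  # [10, 100010, 200010]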
- context._reset_context() # pylint: disable=protected-access - - @combinations.generate( - combinations.combine( - required_gpus=[0, 1], - mode="eager", - communication=[CollectiveCommunication.RING])) - def testTimeoutBatchReduceDense(self, communication, required_gpus): - hints = collective_util.Hints(timeout_seconds=1) - collective, devices, _ = self._get_test_objects( - "worker", - 0, - num_gpus=required_gpus, - communication=communication, - use_strategy_object=False) - remote.connect_to_cluster( - multi_worker_util.normalize_cluster_spec(self._cluster_spec), - protocol="grpc") - devices = [device_util.canonicalize(d) for d in devices] - v = _make_per_replica([1.0], devices) - - @def_function.function - def batch_reduce_dense(): - collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), (v, v)], hints) - - # The collective should time out because we only launch it on worker-0, - # while there're three workers in total. - with self.assertRaises(errors.DeadlineExceededError): - batch_reduce_dense() - - # Reset since collective failures poison the context. - context._reset_context() # pylint: disable=protected-access - - @combinations.generate( - combinations.combine( - required_gpus=[0, 1], - mode="eager", - communication=[CollectiveCommunication.RING])) - def testTimeoutReduceSparse(self, communication, required_gpus): - hints = collective_util.Hints(timeout_seconds=1) - collective, devices, _ = self._get_test_objects( - "worker", - 0, - num_gpus=required_gpus, - communication=communication, - use_strategy_object=False) - remote.connect_to_cluster( - multi_worker_util.normalize_cluster_spec(self._cluster_spec), - protocol="grpc") - devices = [device_util.canonicalize(d) for d in devices] - v = value_lib.PerReplica([ - _make_indexed_slices([[4., 6.], [5., 6.]], [1, 3], [5, 2], devices[0]) - ]) - - @def_function.function - def reduce_sparse(): - collective.reduce(reduce_util.ReduceOp.SUM, v, v, hints) - - # The collective should time out because we only launch it on worker-0, - # while there're three workers in total. - with self.assertRaises(errors.DeadlineExceededError): - reduce_sparse() - - # Reset since collective failures poison the context. - context._reset_context() # pylint: disable=protected-access - - @combinations.generate( - combinations.combine( - required_gpus=[0, 1], - mode="eager", - communication=[CollectiveCommunication.RING])) - def testTimeoutBatchReduceSparse(self, communication, required_gpus): - hints = collective_util.Hints(timeout_seconds=1) - collective, devices, _ = self._get_test_objects( - "worker", - 0, - num_gpus=required_gpus, - communication=communication, - use_strategy_object=False) - remote.connect_to_cluster( - multi_worker_util.normalize_cluster_spec(self._cluster_spec), - protocol="grpc") - devices = [device_util.canonicalize(d) for d in devices] - v = value_lib.PerReplica([ - _make_indexed_slices([[4., 6.], [5., 6.]], [1, 3], [5, 2], devices[0]) - ]) - - @def_function.function - def batch_reduce_sparse(): - collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), (v, v)], hints) - - # The collective should time out because we only launch it on worker-0, - # while there're three workers in total. - with self.assertRaises(errors.DeadlineExceededError): - batch_reduce_sparse() - - # Reset since collective failures poison the context. 
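Several tests in this file (see the `_make_indexed_slices` helper and the `testGatherRaiseSparse`/`testAllGatherRaiseSparse` cases earlier) build a `tf.IndexedSlices` only to assert that gather/all_gather reject it with `NotImplementedError`. A standalone sketch of what such a sparse value represents; the shapes are illustrative:

import tensorflow as tf

# An IndexedSlices is a sparse stand-in for a dense [5, 2] tensor in which only
# row 2 is materialized; the gather paths in this change handle dense values only.
sparse_value = tf.IndexedSlices(
    values=tf.constant([[1., 2.]]),
    indices=tf.constant([2]),
    dense_shape=tf.constant([5, 2]))
dense = tf.scatter_nd(
    tf.expand_dims(sparse_value.indices, 1),
    sparse_value.values,
    sparse_value.dense_shape)
print(dense.numpy())  # row 2 is [1., 2.]; every other row is zeros.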
- context._reset_context() # pylint: disable=protected-access - - if __name__ == "__main__": # Set default inter op thread pool size to one to ensure we don't exhaust the # thread pool with the additional executors to run collectives in eager. os.environ["TF_NUM_INTEROP_THREADS"] = "1" - combinations.main() + test.main() diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 7b4232d99cf..1464941523b 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -443,6 +443,12 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable, """Holds a map from replica to variables.""" def __init__(self, strategy, values, aggregation, var_policy=None): + if (aggregation == variables_lib.VariableAggregation.MEAN and + not values[0].dtype.is_floating): + raise ValueError( + "creating distributed tf.Variable with aggregation=MEAN and a " + "non-floating dtype is not supported, please use a different " + "aggregation or dtype") self._distribute_strategy = strategy self._aggregation = aggregation super(DistributedVariable, self).__init__(values) @@ -955,6 +961,30 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable, self._primary.handle] return obj_map, resource_map + def _write_object_proto(self, proto, options): + """Update a SavedObject proto for the caller. + + If a DistributedVariable object supports this method, it will be called when + saving with a pre-built `SavedObject` proto representing the object, plus an + instance of `SaveOptions`. This method is then free to modify that proto + instance. + + `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally + write out information about their components to the + `experimental_distributed_variable_components` field of a + `SavedVariable` (depending on the `SaveOptions` variable policy). + + Args: + proto: A pre-built `SavedObject` proto for this object. It is assumed this + will be a `SavedVariable` instance. + options: A `SaveOptions` instance. + """ + if self._policy: + if self._policy._is_mirrored(): # pylint: disable=protected-access + self._policy._write_object_proto(self, proto, options) # pylint: disable=protected-access + else: + self._write_object_proto(proto, options) + # We extend from `saveable_object.SaveableObject` instead of # `saveable_object_util.ResourceVariableSaveable` since we need to read the @@ -1055,6 +1085,26 @@ class MirroredVariable(DistributedVariable, Mirrored): return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} + def _write_object_proto(self, proto, options): + """Update a SavedObject proto for the caller. + + If a DistributedVariable object supports this method, it will be called when + saving with a pre-built `SavedObject` proto representing the object, plus an + instance of `SaveOptions`. This method is then free to modify that proto + instance. + + `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally + write out information about their components to the + `experimental_distributed_variable_components` field of a + `SavedVariable` (depending on the `SaveOptions` variable policy). + + Args: + proto: A pre-built `SavedObject` proto for this object. It is assumed this + will be a `SavedVariable` instance. + options: A `SaveOptions` instance. 
+ """ + values_util.write_object_proto(self, proto, options) + def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): """Converts a variable to a tensor.""" # TODO(b/154017756): Make _dense_var_to_tensor consistent between ON_READ @@ -1224,6 +1274,26 @@ class SyncOnReadVariable(DistributedVariable): return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} + def _write_object_proto(self, proto, options): + """Update a SavedObject proto for the caller. + + If a DistributedVariable object supports this method, it will be called when + saving with a pre-built `SavedObject` proto representing the object, plus an + instance of `SaveOptions`. This method is then free to modify that proto + instance. + + `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally + write out information about their components to the + `experimental_distributed_variable_components` field of a + `SavedVariable` (depending on the `SaveOptions` variable policy). + + Args: + proto: A pre-built `SavedObject` proto for this object. It is assumed this + will be a `SavedVariable` instance. + options: A `SaveOptions` instance. + """ + pass + # Register a conversion functions which reads the value of the variable, # allowing instances of the class to be used as tensors. @@ -1508,6 +1578,27 @@ class AutoPolicy(VariablePolicy): def get_restore_ops(self, var, tensor): return values_util.get_on_write_restore_ops(var, tensor) + def _write_object_proto(self, var, proto, options): + """Update a SavedObject proto for the caller. + + If a DistributedVariable object supports this method, it will be called when + saving with a pre-built `SavedObject` proto representing the object, plus an + instance of `SaveOptions`. This method is then free to modify that proto + instance. + + `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally + write out information about their components to the + `experimental_distributed_variable_components` field of a + `SavedVariable` (depending on the `SaveOptions` variable policy). + + Args: + var : A DistributedVariable object + proto: A pre-built `SavedObject` proto for this object. It is assumed this + will be a `SavedVariable` instance. + options: A `SaveOptions` instance. + """ + values_util.write_object_proto(var, proto, options) + class OnWritePolicy(AutoPolicy): """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization. 
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index 02a9926ea18..8a9f0acbd75 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -113,23 +113,6 @@ def mirrored_and_tpu_strategy_combinations(): class DistributedValuesTest(test.TestCase, parameterized.TestCase): - def testGetEager(self): - one = constant_op.constant(1) - two = constant_op.constant(2) - v = values_lib.DistributedValues((one, two)) - self.assertEqual(one, v._get()) - with distribute_lib.ReplicaContext(None, 1): - self.assertEqual(two, v._get()) - - def testGetGraph(self): - with context.graph_mode(), ops.Graph().as_default(): - one = constant_op.constant(1) - two = constant_op.constant(2) - v = values_lib.DistributedValues((one, two)) - self.assertEqual(one, v._get()) - with distribute_lib.ReplicaContext(None, 1): - self.assertEqual(two, v._get()) - @combinations.generate( combinations.combine( distribution=(strategy_combinations.all_strategies_minus_default + @@ -682,13 +665,13 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase): with distribution.scope(): # We use four variables for convenience reasons. They have no special # meaning. - # - v is used whenever possible, and for the methods that require the - # dtype to be integer. + # - v is used whenever possible. # - w is used for scatter and gather, which require the variable to be # non-scalar. - # - y is used when the dtype needs to be float. + # - y is used when the dtype needs to be integer. Note that aggregation + # cannot be MEAN for integers. v = variables_lib.Variable( - 0, + 0., synchronization=synchronization, aggregation=aggregation, trainable=True) @@ -696,10 +679,11 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase): synchronization=synchronization, aggregation=aggregation, trainable=True) - y = variables_lib.Variable( - 7., - synchronization=synchronization, - aggregation=aggregation) + if aggregation != variables_lib.VariableAggregation.MEAN: + y = variables_lib.Variable( + 0, + synchronization=synchronization, + aggregation=aggregation) # pylint: disable=g-long-lambda @@ -708,7 +692,7 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase): _test(lambda: self.assertIs(v.constraint, None), v) # TODO(crccw): should we raise an error instead? _test(lambda: self.assertEqual(v.device, v._primary.device), v) - _test(lambda: self.assertEqual(v.dtype, dtypes.int32), v) + _test(lambda: self.assertEqual(v.dtype, dtypes.float32), v) if not context.executing_eagerly(): _test(lambda: self.assertIs(v.graph, v._primary.graph), v) if not context.executing_eagerly(): @@ -722,9 +706,9 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase): _test(lambda: self.assertTrue(v.trainable, True), v) # tf.Variable methods. - _test(lambda: check_ops.assert_equal_v2(v.assign(1), 1), v) - _test(lambda: check_ops.assert_equal_v2(v.assign_add(1), 2), v) - _test(lambda: check_ops.assert_equal_v2(v.assign_sub(1), 1), v) + _test(lambda: check_ops.assert_equal_v2(v.assign(1.), 1.), v) + _test(lambda: check_ops.assert_equal_v2(v.assign_add(1.), 2.), v) + _test(lambda: check_ops.assert_equal_v2(v.assign_sub(1.), 1.), v) # TODO(b/148689177): Implement batch_scatter_update. # count_up_to() is skipped since it's deprecated. # eval() is skipped since it shouldn't called in a tf.function. 
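The `DistributedVariable.__init__` check added in values.py, and the variable reshuffling in these tests, encode a new constraint: a distributed variable with `aggregation=MEAN` must have a floating dtype. A hedged usage sketch under a single-device `MirroredStrategy` (the device list is illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
with strategy.scope():
  ok = tf.Variable(0., aggregation=tf.VariableAggregation.MEAN)    # float: allowed
  try:
    bad = tf.Variable(0, aggregation=tf.VariableAggregation.MEAN)  # int32: rejected
  except ValueError as e:
    print("rejected:", e)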
@@ -736,7 +720,7 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase): tensor_shape.TensorShape(())), v) # initialized_value() is skipped since it shouldn't called in a tf.function. # load() is skipped since it shouldn't called in a tf.function. - _test(lambda: check_ops.assert_equal_v2(v.read_value(), 1), v) + _test(lambda: check_ops.assert_equal_v2(v.read_value(), 1.), v) # ref() is skipped since it shouldn't called in a tf.function. _test( lambda: check_ops.assert_equal_v2( @@ -770,62 +754,65 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase): [2., 0.5, 1.]), w) # set_shape() is skipped since ResourceVariable doesn't implement it. # to_proto() is skipped since it shouldn't called in a tf.function. - _test(lambda: check_ops.assert_equal_v2(v.value(), 1), v) + _test(lambda: check_ops.assert_equal_v2(v.value(), 1.), v) # DistributedVariable should be treated as ResourceVariable, so it needs to # conform to ResourceVariable interface as well. _test(lambda: self.assertIs(v.handle, v._primary.handle), v) # Convert to tensor. - _test(lambda: check_ops.assert_equal_v2(ops.convert_to_tensor(v), 1), v) + _test(lambda: check_ops.assert_equal_v2(ops.convert_to_tensor(v), 1.), v) # Control dependency. def _with_control_dep(): - with ops.control_dependencies([v.assign(1)]): + with ops.control_dependencies([v.assign(1.)]): return array_ops.identity(1) _test(_with_control_dep, v) # Operator overloads. - _test(lambda: check_ops.assert_equal_v2(v.assign(7), 7), v) - _test(lambda: check_ops.assert_equal_v2(v + 1, 8), v) - _test(lambda: check_ops.assert_equal_v2(3 + v, 10), v) - _test(lambda: check_ops.assert_equal_v2(v + v, 14), v) - _test(lambda: check_ops.assert_equal_v2(v - 2, 5), v) - _test(lambda: check_ops.assert_equal_v2(v - v, 0), v) - _test(lambda: check_ops.assert_equal_v2(v * 2, 14), v) - _test(lambda: check_ops.assert_equal_v2(3 * v, 21), v) - _test(lambda: check_ops.assert_equal_v2(v * v, 49), v) + _test(lambda: check_ops.assert_equal_v2(v.assign(7.), 7.), v) + _test(lambda: check_ops.assert_equal_v2(v + 1., 8.), v) + _test(lambda: check_ops.assert_equal_v2(3 + v, 10.), v) + _test(lambda: check_ops.assert_equal_v2(v + v, 14.), v) + _test(lambda: check_ops.assert_equal_v2(v - 2., 5.), v) + _test(lambda: check_ops.assert_equal_v2(v - v, 0.), v) + _test(lambda: check_ops.assert_equal_v2(v * 2., 14.), v) + _test(lambda: check_ops.assert_equal_v2(3 * v, 21.), v) + _test(lambda: check_ops.assert_equal_v2(v * v, 49.), v) _test( lambda: check_ops.assert_equal_v2( - math_ops.cast(v / 2, dtypes.float32), 3.5), v) + math_ops.cast(v / 2., dtypes.float32), 3.5), v) _test( lambda: check_ops.assert_equal_v2( - math_ops.cast(14 / v, dtypes.float32), 2.), v) - _test(lambda: check_ops.assert_equal_v2(v // 2, 3), v) - _test(lambda: check_ops.assert_equal_v2(15 // v, 2), v) - _test(lambda: check_ops.assert_equal_v2(v % 2, 1), v) - _test(lambda: check_ops.assert_equal_v2(16 % v, 2), v) - _test(lambda: _assert(v < 12), v) - _test(lambda: _assert(v <= 12), v) - _test(lambda: _assert(not v > 12), v) - _test(lambda: _assert(not v >= 12), v) - _test(lambda: _assert(not 12 < v), v) - _test(lambda: _assert(not 12 <= v), v) - _test(lambda: _assert(12 > v), v) - _test(lambda: _assert(12 >= v), v) - # XLA doesn't implement pow() with integers. 
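The reshuffled checks in this hunk split operator coverage between a float variable `v` and an integer variable `y`: bitwise operators are only defined for integer dtypes, and `pow` stays with the float variable because of the XLA note above. A quick standalone illustration with plain tensors (the values are arbitrary):

import tensorflow as tf

i = tf.constant(7)    # int32
f = tf.constant(7.)   # float32

# Bitwise operator overloads exist only for integer (and bool) dtypes.
print((i & 3).numpy(), (i | 8).numpy(), (i ^ 3).numpy(), (~i).numpy())
# Floordiv and mod, grouped with the integer variable in the test.
print((i // 2).numpy(), (i % 2).numpy())
# pow and abs are exercised on the float variable instead.
print(pow(f, 3.).numpy(), abs(f).numpy())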
- _test(lambda: check_ops.assert_near_v2(pow(y, 3.), 343.), y) - _test(lambda: check_ops.assert_near_v2(pow(2., y), 128.), y) - _test(lambda: check_ops.assert_equal_v2(abs(v), 7), v) - _test(lambda: check_ops.assert_equal_v2(v & 3, 3), v) - _test(lambda: check_ops.assert_equal_v2(3 & v, 3), v) - _test(lambda: check_ops.assert_equal_v2(v | 8, 15), v) - _test(lambda: check_ops.assert_equal_v2(16 | v, 23), v) - _test(lambda: check_ops.assert_equal_v2(v ^ 3, 4), v) - _test(lambda: check_ops.assert_equal_v2(11 ^ v, 12), v) - _test(lambda: check_ops.assert_equal_v2(-v, -7), v) - _test(lambda: check_ops.assert_equal_v2(~v, ~7), v) + math_ops.cast(14. / v, dtypes.float32), 2.), v) + _test(lambda: _assert(v < 12.), v) + _test(lambda: _assert(v <= 12.), v) + _test(lambda: _assert(not v > 12.), v) + _test(lambda: _assert(not v >= 12.), v) + _test(lambda: _assert(not 12. < v), v) + _test(lambda: _assert(not 12. <= v), v) + _test(lambda: _assert(12. > v), v) + _test(lambda: _assert(12. >= v), v) + _test(lambda: check_ops.assert_near_v2(pow(v, 3.), 343.), v) + _test(lambda: check_ops.assert_near_v2(pow(2., v), 128.), v) + _test(lambda: check_ops.assert_equal_v2(abs(v), 7.), v) + + # Operator overloads that only works for integers. + if aggregation != variables_lib.VariableAggregation.MEAN: + _test(lambda: check_ops.assert_equal_v2(y.assign(7), 7), y) + _test(lambda: check_ops.assert_equal_v2(y // 2, 3), y) + _test(lambda: check_ops.assert_equal_v2(15 // y, 2), y) + _test(lambda: check_ops.assert_equal_v2(y % 2, 1), y) + _test(lambda: check_ops.assert_equal_v2(16 % y, 2), y) + _test(lambda: check_ops.assert_equal_v2(y & 3, 3), y) + _test(lambda: check_ops.assert_equal_v2(3 & y, 3), y) + _test(lambda: check_ops.assert_equal_v2(y | 8, 15), y) + _test(lambda: check_ops.assert_equal_v2(16 | y, 23), y) + _test(lambda: check_ops.assert_equal_v2(y ^ 3, 4), y) + _test(lambda: check_ops.assert_equal_v2(11 ^ y, 12), y) + _test(lambda: check_ops.assert_equal_v2(-y, -7), y) + _test(lambda: check_ops.assert_equal_v2(~y, ~7), y) # Index. if isinstance(distribution.extended, tpu_strategy.TPUExtended): @@ -1451,4 +1438,4 @@ def _make_index_slices(values, indices, dense_shape=None): if __name__ == "__main__": - combinations.main() + ds_test_util.main() diff --git a/tensorflow/python/distribute/values_util.py b/tensorflow/python/distribute/values_util.py index c9bcf8f9a56..0071ee67b67 100644 --- a/tensorflow/python/distribute/values_util.py +++ b/tensorflow/python/distribute/values_util.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import reduce_util +from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import control_flow_ops @@ -31,11 +32,43 @@ from tensorflow.python.saved_model import save_options from tensorflow.python.training.saving import saveable_object +def write_object_proto(var, proto, options): + """Update a SavedObject proto for the caller. + + If a DistributedVariable object supports this method, it will be called when + saving with a pre-built `SavedObject` proto representing the object, plus an + instance of `SaveOptions`. This method is then free to modify that proto + instance. 
+ + `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally + write out information about their components to the + `experimental_distributed_variable_components` field of a + `SavedVariable` (depending on the `SaveOptions` variable policy). + + Args: + var: The DistributedVariable object. + proto: A pre-built `SavedObject` proto for this object. It is assumed this + will be a `SavedVariable` instance. + options: A `SaveOptions` instance. + """ + if options.experimental_variable_policy._expand_distributed_variables( # pylint: disable=protected-access + ): + for var in var.values: + var_proto = ( + proto.variable.experimental_distributed_variable_components.add()) + var_proto.name = var.name.split(":")[0] + var_proto.device = var.device + + def get_on_write_saveable(var, primary_var, name): """Return saveable spec for AUTO and ON_WRITE variables.""" # We use a callable so that we don't have to evaluate this expression # in the case where we are trying to restore instead of save. def tensor(): + if context.executing_eagerly() and not primary_var.is_initialized(): + # A SaveSpec tensor value of `None` indicates that the variable is + # uninitialized. + return None strategy = var.distribute_strategy return strategy.extended.read_var(var) @@ -256,7 +289,7 @@ def get_current_replica_id_as_int(): """Returns the current replica ID as an integer, or `None`.""" replica_context = ds_context.get_replica_context() if replica_context: - replica_id = replica_context.replica_id_in_sync_group + replica_id = replica_context._replica_id # pylint: disable=protected-access if not isinstance(replica_id, int): replica_id = tensor_util.constant_value(replica_id) else: diff --git a/tensorflow/python/distribute/vars_test.py b/tensorflow/python/distribute/vars_test.py index ba77384a83a..e9eb9b77460 100644 --- a/tensorflow/python/distribute/vars_test.py +++ b/tensorflow/python/distribute/vars_test.py @@ -26,6 +26,7 @@ from absl.testing import parameterized from tensorflow.python.distribute import combinations from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute import test_util from tensorflow.python.distribute import tpu_strategy from tensorflow.python.distribute import values from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver @@ -302,63 +303,6 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase): for component in mirrored.values: self.assertEqual(self.evaluate(component.read_value()), 3.) - @combinations.generate(strategy_with_var_policy()) - def testAssignAggregationMeanDTypeNonFloat(self, distribution): - if isinstance(distribution, _TPU_STRATEGIES): - self.skipTest("Fix sponge/6e8ab540-4c0f-4da5-aedf-86505ff810c9 before " - "reenabling test.") - - with distribution.scope(): - v = variables_lib.Variable( - 1, - aggregation=variable_scope.VariableAggregation.MEAN, - dtype=dtypes.int32) - self.evaluate(v.initializer) - - @def_function.function - def assign(): - ctx = ds_context.get_replica_context() - return v.assign(ctx.replica_id_in_sync_group) - - # disallow assign() with distributed value in replica context. - with self.assertRaisesRegex(ValueError, - "Cannot update non-float variables"): - self.evaluate( - distribution.experimental_local_results( - distribution.run(assign))) - - # allow assign() with same value in replica context. 
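# --- Illustrative sketch (not part of the patch): how the new
# `write_object_proto` above is expected to be driven from the saving side.
# With `VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES`, each component of a
# DistributedVariable should be recorded in the
# `experimental_distributed_variable_components` field of its SavedVariable
# proto. The single-device MirroredStrategy/module setup here is an assumption
# made only for the example.
import tensorflow as tf
from tensorflow.python.saved_model import save_options

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
with strategy.scope():
  module = tf.Module()
  module.v = tf.Variable(1.)   # a DistributedVariable under the strategy

options = save_options.SaveOptions(
    experimental_variable_policy=save_options.VariablePolicy
    .EXPAND_DISTRIBUTED_VARIABLES)
tf.saved_model.save(module, "/tmp/expanded_dist_var", options=options)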
- @def_function.function - def assign_same(): - return v.assign(2) - - self.evaluate( - distribution.experimental_local_results( - distribution.run(assign_same))) - self.assertEqual(self.evaluate(v.read_value()), 2) - - # allow assign() with mirrored variable in replica context. - with distribution.scope(): - v2 = variables_lib.Variable( - 3, - aggregation=variable_scope.VariableAggregation.SUM, - dtype=dtypes.int32) - self.evaluate(v2.initializer) - - @def_function.function - def assign_mirrored(): - return v.assign(v2) - - self.evaluate( - distribution.experimental_local_results( - distribution.run(assign_mirrored))) - self.assertEqual(self.evaluate(v.read_value()), 3) - - # allow assign() in cross replica context. - with distribution.scope(): - self.evaluate(v.assign(4)) - self.assertEqual(self.evaluate(v.read_value()), 4) - @combinations.generate(strategy_with_var_policy()) def testInitializedToSameValueInsideEagerRun(self, distribution): if not context.executing_eagerly(): self.skipTest("eager only test") @@ -415,24 +359,24 @@ class OnWriteVariableSync(test.TestCase, parameterized.TestCase): with ops.init_scope(): if obj.w is None: obj.w = variables_lib.Variable( - 0, aggregation=variables_lib.VariableAggregation.MEAN) + 0., aggregation=variables_lib.VariableAggregation.MEAN) obj.v = variables_lib.Variable( obj.w.read_value(), aggregation=variables_lib.VariableAggregation.MEAN) self.evaluate(variables_lib.global_variables_initializer()) - return obj.v.assign_add(2) + return obj.v.assign_add(2.) per_replica_results = self.evaluate( distribution.experimental_local_results(distribution.run(assign))) - self.assertAllEqual([2, 2], per_replica_results) + self.assertAllEqual([2., 2.], per_replica_results) @combinations.generate(strategy_with_var_policy()) def testOperatorOverride(self, distribution): with distribution.scope(): v = variable_scope.variable( - 1, aggregation=variables_lib.VariableAggregation.MEAN) + 1, aggregation=variables_lib.VariableAggregation.SUM) self.evaluate(variables_lib.global_variables_initializer()) self.assertEqual(2, self.evaluate(v + 1)) @@ -1329,4 +1273,4 @@ class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase): if __name__ == "__main__": - combinations.main() + test_util.main() diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 50c50ce3c42..b743d6f736c 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -212,6 +212,7 @@ py_library( "//tensorflow/python:pywrap_tf_session", "//tensorflow/python:pywrap_tfe", "//tensorflow/python:tf2", + "//tensorflow/python:tfrt_utils", "//tensorflow/python:util", "//third_party/py/numpy", ], @@ -263,6 +264,7 @@ cuda_py_test( size = "small", srcs = ["context_test.py"], python_version = "PY3", + tfrt_enabled = True, deps = [ ":context", ":test", @@ -374,6 +376,7 @@ cuda_py_test( python_version = "PY3", tags = [ "no_windows", #TODO(b/139745667) + "notsan", #TODO(b/139745667) ], deps = [ ":backprop", @@ -432,7 +435,7 @@ cuda_py_test( size = "medium", srcs = ["function_argument_naming_test.py"], python_version = "PY3", - tfrt_enabled = True, + # TODO(b/169371527): insert transfer op in eager lowering for TFRT. 
deps = [ ":backprop", ":def_function", @@ -793,6 +796,7 @@ tf_py_test( name = "pywrap_tfe_test", srcs = ["pywrap_tfe_test.py"], python_version = "PY3", + tfrt_enabled = True, deps = [ ":backprop", ":context", @@ -835,6 +839,7 @@ py_library( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:while_v2", # TODO(b/118513001): Imported via control_flow_ops; remove. + "//tensorflow/python/distribute/parallel_device", "//tensorflow/python/profiler:trace", "//tensorflow/python/training/tracking:base", ], @@ -871,7 +876,7 @@ cuda_py_test( name = "def_function_test", srcs = ["def_function_test.py"], python_version = "PY3", - tfrt_enabled = True, + # TODO(b/169371527): insert transfer op in eager lowering for TFRT. deps = [ ":def_function", "//tensorflow/python:client_testlib", @@ -980,6 +985,26 @@ tf_py_test( ], ) +cuda_py_test( + name = "wrap_function_device_test", + srcs = ["wrap_function_device_test.py"], + python_version = "PY3", + xla_enable_strict_auto_jit = False, + xla_enabled = False, + deps = [ + ":def_function", + ":wrap_function", + "//tensorflow/python:client_testlib", + "//tensorflow/python:config", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", + ], +) + py_library( name = "remote", srcs = ["remote.py"], diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 0adb4698529..584fed73158 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -42,6 +42,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gradients from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn @@ -1722,6 +1723,43 @@ class JacobianTest(test.TestCase): g.jacobian(output, inp, experimental_use_pfor=True), g.jacobian(output, inp, experimental_use_pfor=False)) + def test_foldl_partial_function(self): + x = array_ops.zeros([3]) + with backprop.GradientTape(persistent=True) as tape: + tape.watch(x) + result = def_function.function( + functools.partial(functional_ops.foldl_v2, lambda a, b: a + b))( + x) + self.assertAllClose([1., 1., 1.], + tape.jacobian(result, x, experimental_use_pfor=True)) + self.assertAllClose([1., 1., 1.], + tape.jacobian(result, x, experimental_use_pfor=False)) + + # Non-persistent tapes take a different function gradient path, but also + # work with pfor=True. 
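# --- Illustrative sketch (not part of the patch): the user-level pattern the
# foldl Jacobian tests above (continued just below) exercise, written with
# public symbols. `tf.foldl` corresponds to the `functional_ops.foldl_v2`
# used in those tests.
import functools
import tensorflow as tf

x = tf.zeros([3])
with tf.GradientTape(persistent=True) as tape:
  tape.watch(x)
  # Running sum over the elements of x, wrapped in tf.function via partial.
  result = tf.function(functools.partial(tf.foldl, lambda a, b: a + b))(x)
# Both the pfor path and the while-loop fallback should agree on [1., 1., 1.].
print(tape.jacobian(result, x, experimental_use_pfor=True).numpy())
print(tape.jacobian(result, x, experimental_use_pfor=False).numpy())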
+ x = array_ops.zeros([3]) + with backprop.GradientTape() as tape: + tape.watch(x) + result = def_function.function( + functools.partial(functional_ops.foldl_v2, lambda a, b: a + b))( + x) + self.assertAllClose([1., 1., 1.], + tape.jacobian(result, x, experimental_use_pfor=True)) + + def test_foldl_pure_function(self): + + @def_function.function + def compute_jacobian(use_pfor): + x = array_ops.zeros([3]) + with backprop.GradientTape(persistent=True) as tape: + tape.watch(x) + result = functools.partial(functional_ops.foldl_v2, lambda a, b: a + b)( + x) + return tape.jacobian(result, x, experimental_use_pfor=use_pfor) + + self.assertAllClose(compute_jacobian(use_pfor=True), + compute_jacobian(use_pfor=False)) + @test_util.run_all_in_graph_and_eager_modes class BatchJacobianTest(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py b/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py index e034cf0e296..52803be7bf9 100644 --- a/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py +++ b/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py @@ -108,16 +108,12 @@ class ResNet50Test(tf.test.TestCase): def test_apply(self): self._apply(defun=False) - @test_util.disable_tfrt( - 'TFE_ContextGetExecutorForThread not implemented b/156188669') def test_apply_async(self): self._apply(defun=False, execution_mode=context.ASYNC) - @test_util.disable_tfrt('Graph is not supported yet. b/156187905') def test_apply_with_defun(self): self._apply(defun=True) - @test_util.disable_tfrt('Graph is not supported yet. b/156187905') def test_apply_with_defun_async(self): self._apply(defun=True, execution_mode=context.ASYNC) diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index 82e4d1e74cd..bf6e43f5e65 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -322,7 +322,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): func = lambda: math_ops.multiply(m, m) self._run(func, num_iters) - @test_util.disable_tfrt("numpy() not supported") def benchmark_np_multiply(self): self._benchmark_np_multiply(self._m_2, 30000) @@ -331,7 +330,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): m = self._m_2.cpu() self._benchmark_tf_multiply(m, 30000) - @test_util.disable_tfrt("copy to GPU not supported") def benchmark_tf_multiply_GPU(self): if not context.num_gpus(): return @@ -344,7 +342,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): m = self._m_2.cpu() self._benchmark_tf_multiply_op(m, 30000) - @test_util.disable_tfrt("copy to GPU not supported") def benchmark_tf_multiply_op_GPU(self): if not context.num_gpus(): return @@ -358,7 +355,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): m2 = self._m_1_3_3_1.cpu() self._benchmark_tf_conv2d(m1, m2, 30000) - @test_util.disable_tfrt("copy to GPU not supported") def benchmark_tf_conv2d_GPU(self): if not context.num_gpus(): return @@ -371,7 +367,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): m = self._m_2 self._run(lambda: gen_array_ops.identity(m), 30000) - @test_util.disable_tfrt("identity not supported") def benchmark_slowpath_tf_identity(self): self._run(lambda: gen_array_ops.identity(1), 30000) @@ -386,7 +381,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._run(f, 30000) - @test_util.disable_tfrt("identity not supported") def benchmark_tf_gradient_function_identity(self): 
with context.device(CPU): m = gen_array_ops.identity(self._m_2) @@ -394,14 +388,12 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): lambda: backprop.gradients_function(gen_array_ops.identity, [0])(m), 30000) - @test_util.disable_tfrt("identity not supported") def benchmark_tf_gradient_forward_identity(self): with backprop.GradientTape() as tape: m = self._m_2 tape.watch(m) self._run(lambda: gen_array_ops.identity(m), 30000) - @test_util.disable_tfrt("gradients not supported") def benchmark_tf_gradient_tape_push_pop(self): def f(): @@ -410,7 +402,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._run(f, 30000) - @test_util.disable_tfrt("gradients not supported") def benchmark_tf_gradient_function_no_op(self): with context.device(CPU): m = gen_array_ops.identity(self._m_2) @@ -558,7 +549,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_tf_matmul( m, transpose_b=False, num_iters=self._num_iters_2_by_2) - @test_util.disable_tfrt("async not supported") def benchmark_tf_matmul_2_by_2_CPU_async(self): with context.device(CPU): m = self._m_2_by_2.cpu() @@ -604,13 +594,11 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_defun_matmul_relaxed_shape( m, num_iters=self._num_iters_2_by_2) - @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_defun_args_matmul_2_by_2_CPU(self): with context.device(CPU): m = self._m_2_by_2.cpu() self._benchmark_defun_args_matmul(m, num_iters=self._num_iters_2_by_2) - @test_util.disable_tfrt("async not supported") def benchmark_defun_matmul_2_by_2_CPU_async(self): with context.device(CPU): m = self._m_2_by_2.cpu() @@ -628,7 +616,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): m, transpose_b=False, num_iters=self._num_iters_2_by_2) def_function.run_functions_eagerly(False) - @test_util.disable_tfrt("async not supported") def _benchmark_matmul_forward_backward_2_by_2_CPU_async( self, run_eager=False): def_function.run_functions_eagerly(run_eager) @@ -643,14 +630,12 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): def benchmark_defun_matmul_forward_backward_2_by_2_CPU(self): self._benchmark_matmul_forward_backward_2_by_2_CPU(False) - @test_util.disable_tfrt("async not supported") def benchmark_defun_matmul_forward_backward_2_by_2_CPU_async(self): self._benchmark_matmul_forward_backward_2_by_2_CPU_async(False) def benchmark_defun_eager_matmul_forward_backward_2_by_2_CPU(self): self._benchmark_matmul_forward_backward_2_by_2_CPU(True) - @test_util.disable_tfrt("async not supported") def benchmark_defun_eager_matmul_forward_backward_2_by_2_CPU_async(self): self._benchmark_matmul_forward_backward_2_by_2_CPU_async(True) @@ -662,7 +647,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_tf_matmul( m, transpose_b=False, num_iters=self._num_iters_2_by_2) - @test_util.disable_tfrt("async not supported") def benchmark_tf_matmul_2_by_2_GPU_async(self): if not context.num_gpus(): return @@ -674,7 +658,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): num_iters=self._num_iters_2_by_2, execution_mode=context.ASYNC) - @test_util.disable_tfrt("copy to GPU not supported") def benchmark_gen_math_ops_matmul_2_by_2_GPU(self): if not context.num_gpus(): return @@ -683,7 +666,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_gen_math_ops_matmul( m, transpose_b=False, num_iters=self._num_iters_2_by_2) - 
@test_util.disable_tfrt("copy to GPU not supported") def benchmark_tfe_py_execute_matmul_2_by_2_GPU(self): if not context.num_gpus(): return @@ -692,7 +674,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_tfe_py_execute_matmul( m, transpose_b=False, num_iters=self._num_iters_2_by_2) - @test_util.disable_tfrt("copy to GPU not supported") def benchmark_defun_matmul_2_by_2_GPU(self): if not context.num_gpus(): return @@ -701,7 +682,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_defun_matmul( m, transpose_b=False, num_iters=self._num_iters_2_by_2) - @test_util.disable_tfrt("copy to GPU not supported") def benchmark_defun_matmul_2_by_2_with_signature_GPU(self): if not context.num_gpus(): return @@ -710,7 +690,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_defun_matmul_with_signature( m, num_iters=self._num_iters_2_by_2) - @test_util.disable_tfrt("copy to GPU not supported") def benchmark_defun_matmul_2_by_2_relaxed_shape_GPU(self): if not context.num_gpus(): return @@ -719,7 +698,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_defun_matmul_relaxed_shape( m, num_iters=self._num_iters_2_by_2) - @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_defun_args_matmul_2_by_2_GPU(self): if not context.num_gpus(): return @@ -727,7 +705,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): m = self._m_2_by_2.gpu() self._benchmark_defun_args_matmul(m, num_iters=self._num_iters_2_by_2) - @test_util.disable_tfrt("async not supported") def benchmark_defun_matmul_2_by_2_GPU_async(self): if not context.num_gpus(): return @@ -757,7 +734,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_tf_matmul( m, transpose_b=True, num_iters=self._num_iters_100_by_784) - @test_util.disable_tfrt("async not supported") def benchmark_tf_matmul_100_by_784_CPU_async(self): with context.device(CPU): m = self._m_100_by_784.cpu() @@ -779,7 +755,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_tfe_py_fastpath_execute_matmul( m, transpose_b=True, num_iters=self._num_iters_100_by_784) - @test_util.disable_tfrt("copy to GPU not supported") def benchmark_tfe_py_execute_matmul_100_by_784_CPU(self): with context.device(CPU): m = self._m_100_by_784.cpu() @@ -792,7 +767,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_defun_matmul( m, transpose_b=True, num_iters=self._num_iters_100_by_784) - @test_util.disable_tfrt("copy to GPU not supported") def benchmark_tf_matmul_100_by_784_GPU(self): if not context.num_gpus(): return @@ -801,7 +775,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_tf_matmul( m, transpose_b=True, num_iters=self._num_iters_100_by_784) - @test_util.disable_tfrt("async not supported") def benchmark_tf_matmul_100_by_784_GPU_async(self): if not context.num_gpus(): return @@ -813,7 +786,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): num_iters=self._num_iters_100_by_784, execution_mode=context.ASYNC) - @test_util.disable_tfrt("copy to GPU not supported") def benchmark_gen_math_ops_matmul_100_by_784_GPU(self): if not context.num_gpus(): return @@ -822,7 +794,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_gen_math_ops_matmul( m, transpose_b=True, num_iters=self._num_iters_100_by_784) - @test_util.disable_tfrt("copy to GPU not 
supported") def benchmark_tfe_py_execute_matmul_100_by_784_GPU(self): if not context.num_gpus(): return @@ -831,7 +802,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_tfe_py_execute_matmul( m, transpose_b=True, num_iters=self._num_iters_100_by_784) - @test_util.disable_tfrt("copy to GPU not supported") def benchmark_defun_matmul_100_by_784_GPU(self): if not context.num_gpus(): return @@ -840,7 +810,8 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_defun_matmul( m, transpose_b=True, num_iters=self._num_iters_100_by_784) - @test_util.disable_tfrt("copy to GPU not supported") + @test_util.disable_tfrt( + "b/169371527: Support inserting transfer op in lowering.") def benchmark_nested_defun_matmul_100_by_784_GPU(self): m = self._m_100_by_784.gpu() self._benchmark_nested_defun_matmul( @@ -952,47 +923,45 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): func = lambda: math_ops.reduce_logsumexp(x) self._run(func, 3000, execution_mode=execution_mode) - @test_util.disable_tfrt("reduce logsumexp not supported") + @test_util.disable_tfrt("b/169371018: Support ScalarHost in RTFB.") def benchmark_tf_reduce_logsumexp_CPU(self): self._benchmark_tf_reduce_logsumexp() - @test_util.disable_tfrt("reduce logsumexp not supported") + @test_util.disable_tfrt("b/169371018: Support ScalarHost in RTFB.") def benchmark_tf_reduce_logsumexp_CPU_async(self): self._benchmark_tf_reduce_logsumexp(execution_mode=context.ASYNC) - @test_util.disable_tfrt("reduce logsumexp not supported") + @test_util.disable_tfrt("b/169371018: Support ScalarHost in RTFB.") def benchmark_tf_reduce_logsumexp_GPU(self): self._benchmark_tf_reduce_logsumexp(device=GPU) - @test_util.disable_tfrt("reduce logsumexp not supported") + @test_util.disable_tfrt("b/169371018: Support ScalarHost in RTFB.") def benchmark_tf_reduce_logsumexp_GPU_async(self): self._benchmark_tf_reduce_logsumexp(device=GPU, execution_mode=context.ASYNC) - @test_util.disable_tfrt("reduce logsumexp not supported") + @test_util.disable_tfrt( + "b/169371527: Support inserting transfer op in lowering.") def benchmark_tf_reduce_logsumexp_CPU_defunc(self): self._benchmark_tf_reduce_logsumexp(defunc=True) - @test_util.disable_tfrt("reduce logsumexp not supported") + @test_util.disable_tfrt( + "b/169371527: Support inserting transfer op in lowering.") def benchmark_tf_reduce_logsumexp_CPU_async_defun(self): self._benchmark_tf_reduce_logsumexp( execution_mode=context.ASYNC, defunc=True) - @test_util.disable_tfrt("reduce logsumexp not supported") def benchmark_tf_reduce_logsumexp_GPU_defun(self): self._benchmark_tf_reduce_logsumexp(device=GPU, defunc=True) - @test_util.disable_tfrt("reduce logsumexp not supported") def benchmark_tf_reduce_logsumexp_GPU_async_defun(self): self._benchmark_tf_reduce_logsumexp( device=GPU, execution_mode=context.ASYNC, defunc=True) - @test_util.disable_tfrt("reduce logsumexp not supported") def benchmark_tf_reduce_logsumexp_GPU_defun_compile(self): self._benchmark_tf_reduce_logsumexp( device=GPU, defunc=True, xla_compile=True) - @test_util.disable_tfrt("reduce logsumexp not supported") def benchmark_tf_reduce_logsumexp_GPU_async_defun_compile(self): self._benchmark_tf_reduce_logsumexp( device=GPU, execution_mode=context.ASYNC, defunc=True, xla_compile=True) @@ -1004,19 +973,15 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): func = lambda: math_ops.tensordot(a, b, [[1], [0]]) self._run(func, 30000, execution_mode=execution_mode) - 
@test_util.disable_tfrt("tensordot not supported") def benchmark_tf_tensordot_CPU(self): self._benchmark_tf_tensordot() - @test_util.disable_tfrt("tensordot not supported") def benchmark_tf_tensordot_CPU_async(self): self._benchmark_tf_tensordot(execution_mode=context.ASYNC) - @test_util.disable_tfrt("tensordot not supported") def benchmark_tf_tensordot_GPU(self): self._benchmark_tf_tensordot(device=GPU) - @test_util.disable_tfrt("tensordot not supported") def benchmark_tf_tensordot_GPU_async(self): self._benchmark_tf_tensordot(device=GPU, execution_mode=context.ASYNC) @@ -1025,63 +990,48 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): func = lambda: array_ops.zeros(shape, dtype) self._run(func, 3000) - @test_util.disable_tfrt("context.device not supported") def benchmark_tf_zeros_2_by_2_float32_CPU(self): self._benchmark_tf_zeros((2, 2), dtypes.float32) - @test_util.disable_tfrt("context.device not supported") def benchmark_tf_zeros_2_by_2_bool_CPU(self): self._benchmark_tf_zeros((2, 2), dtypes.bool) - @test_util.disable_tfrt("context.device not supported") def benchmark_tf_zeros_2_by_2_string_CPU(self): self._benchmark_tf_zeros((2, 2), dtypes.string) - @test_util.disable_tfrt("context.device not supported") def benchmark_tf_zeros_2_by_2_float32_GPU(self): self._benchmark_tf_zeros((2, 2), dtypes.float32, device=GPU) - @test_util.disable_tfrt("context.device not supported") def benchmark_tf_zeros_2_by_2_bool_GPU(self): self._benchmark_tf_zeros((2, 2), dtypes.bool, device=GPU) - @test_util.disable_tfrt("context.device not supported") def benchmark_tf_zeros_30_by_30_float32_CPU(self): self._benchmark_tf_zeros((30, 30), dtypes.float32) - @test_util.disable_tfrt("context.device not supported") def benchmark_tf_zeros_30_by_30_bool_CPU(self): self._benchmark_tf_zeros((30, 30), dtypes.bool) - @test_util.disable_tfrt("context.device not supported") def benchmark_tf_zeros_30_by_30_string_CPU(self): self._benchmark_tf_zeros((30, 30), dtypes.string) - @test_util.disable_tfrt("context.device not supported") def benchmark_tf_zeros_30_by_30_float32_GPU(self): self._benchmark_tf_zeros((30, 30), dtypes.float32, device=GPU) - @test_util.disable_tfrt("context.device not supported") def benchmark_tf_zeros_30_by_30_bool_GPU(self): self._benchmark_tf_zeros((30, 30), dtypes.bool, device=GPU) - @test_util.disable_tfrt("context.device not supported") def benchmark_tf_zeros_100_by_100_float32_CPU(self): self._benchmark_tf_zeros((100, 100), dtypes.float32) - @test_util.disable_tfrt("context.device not supported") def benchmark_tf_zeros_100_by_100_bool_CPU(self): self._benchmark_tf_zeros((100, 100), dtypes.bool) - @test_util.disable_tfrt("context.device not supported") def benchmark_tf_zeros_100_by_100_string_CPU(self): self._benchmark_tf_zeros((100, 100), dtypes.string) - @test_util.disable_tfrt("context.device not supported") def benchmark_tf_zeros_100_by_100_float32_GPU(self): self._benchmark_tf_zeros((100, 100), dtypes.float32, device=GPU) - @test_util.disable_tfrt("context.device not supported") def benchmark_tf_zeros_100_by_100_bool_GPU(self): self._benchmark_tf_zeros((100, 100), dtypes.bool, device=GPU) @@ -1180,7 +1130,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): m = self._m_2_by_2.cpu() self._benchmark_transpose(m, num_iters=self._num_iters_2_by_2) - @test_util.disable_tfrt("copy to GPU not supported") def benchmark_tf_transpose_2_by_2_GPU(self): with context.device(GPU): m = self._m_2_by_2.gpu() @@ -1191,7 +1140,6 @@ class 
MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): m = resource_variable_ops.ResourceVariable(self._m_2_by_2) self._benchmark_transpose(m, num_iters=self._num_iters_2_by_2) - @test_util.disable_tfrt("Cannot convert array to EagerTensor of dtype int32") def benchmark_tf_transpose_variable_2_by_2_GPU(self): with context.device(GPU): m = resource_variable_ops.ResourceVariable(self._m_2_by_2) @@ -1261,7 +1209,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): m = resource_variable_ops.ResourceVariable(self._m_2_by_2) self._benchmark_read_variable(m, num_iters=self._num_iters_2_by_2) - @test_util.disable_tfrt("copy to GPU not supported") def benchmark_read_variable_op_2_by_2_GPU(self): if not context.num_gpus(): return @@ -1275,7 +1222,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_read_variable_with_tape( m, num_iters=self._num_iters_2_by_2) - @test_util.disable_tfrt("copy to GPU not supported") def benchmark_read_variable_op_with_tape_2_by_2_GPU(self): if not context.num_gpus(): return @@ -1293,8 +1239,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._run(scan, 100) - @test_util.disable_tfrt( - "tf.While not supported in TF to CoreRT lowing. b/162685874") + @test_util.disable_tfrt("tf.While not supported RTFB tensor. b/169374895") def benchmarkScanDefun(self): elems = math_ops.range(1600) @@ -1361,12 +1306,10 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): def benchmark_convert_numpy_float_uncached(self): self._benchmark_convert_constant(np.array(42.0), cached=False) - @test_util.disable_tfrt("convert to tensor not supported") def benchmark_convert_3x_list_to_tensor(self): xs = [1, 2, 3] self._run(lambda: ops.convert_to_tensor(xs), 1000) - @test_util.disable_tfrt("convert to tensor not supported") def benchmark_convert_3x_array_to_tensor(self): xs = np.array([1, 2, 3], dtype=np.int32) self._run(lambda: ops.convert_to_tensor(xs), 1000) @@ -1375,7 +1318,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): xs = [[0] * 2] * 40 self._run(lambda: constant_op.constant(xs), 1000) - @test_util.disable_tfrt("convert to tensor not supported") def benchmark_constant_40x2_array_to_tensor(self): xs = np.array([[0] * 2] * 40, dtype=np.int32) self._run(lambda: constant_op.constant(xs), 1000) @@ -1454,15 +1396,12 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): with context.device(CPU): self._run(benchmark_fn, 10) - @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmarkTenThousandResourceReadsInCondInInnerFunc(self): self._benchmarkResourceReadsInCondInInnerFunc(10000) - @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmarkHundredResourceReadsInCondInInnerFunc(self): self._benchmarkResourceReadsInCondInInnerFunc(100) - @test_util.disable_tfrt("Graph is not supported yet. 
b/156187905") def benchmarkTenResourceReadsInCondInInnerFunc(self): self._benchmarkResourceReadsInCondInInnerFunc(10) diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index c412b338b58..8fe158cff93 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -39,6 +39,7 @@ from tensorflow.python.eager import executor from tensorflow.python.eager import monitoring from tensorflow.python.framework import c_api_util from tensorflow.python.framework import device as pydev +from tensorflow.python.framework import tfrt_utils from tensorflow.python.util import compat from tensorflow.python.util import is_in_graph_mode from tensorflow.python.util import tf_contextlib @@ -74,6 +75,9 @@ _python_eager_context_create_counter = monitoring.Counter( "/tensorflow/api/python/eager_context_create_counter", "Counter for number of eager contexts created in Python.") +# Re-exporting through context. +is_tfrt_enabled = tfrt_utils.enabled + class _EagerTensorCache(object): """Simple cache which evicts items based on length in a FIFO manner.""" @@ -335,19 +339,6 @@ class _TensorCacheDeleter(object): del _tensor_caches_map[self._context_id] -# If the below import is made available through the BUILD rule, then this -# function is overridden and will instead return True and cause Tensorflow -# graphs to run with TFRT. -def is_tfrt_enabled(): - return None - - -try: - from tensorflow.python.framework.is_tfrt_test_true import is_tfrt_enabled # pylint: disable=g-import-not-at-top -except: # pylint: disable=bare-except - pass - - # TODO(agarwal): rename to EagerContext / EagerRuntime ? # TODO(agarwal): consider keeping the corresponding Graph here. class Context(object): @@ -423,10 +414,10 @@ class Context(object): raise ValueError( "execution_mode should be None/SYNC/ASYNC. Got %s" % execution_mode) if execution_mode is None: - execution_mode = ASYNC if is_tfrt_enabled() else SYNC + execution_mode = ASYNC if tfrt_utils.enabled() else SYNC self._default_is_async = execution_mode == ASYNC self._lazy_remote_inputs_copy = None - self._use_tfrt = is_tfrt_enabled() + self._use_tfrt = tfrt_utils.enabled() self._server_def = server_def self._collective_ops_server_def = None self._collective_leader = None @@ -1500,9 +1491,9 @@ class Context(object): self._virtual_device_map[dev] = virtual_devices - def get_compiler_ir(self, function_name, args, stage="hlo"): + def get_compiler_ir(self, device_name, function_name, args, stage="hlo"): return pywrap_tfe.TF_GetCompilerIr(self._context_handle, function_name, - stage, self.device_name, args) + stage, device_name, args) @deprecated( None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True) diff --git a/tensorflow/python/eager/context_test.py b/tensorflow/python/eager/context_test.py index 086f943b3b0..fe86104cc0b 100644 --- a/tensorflow/python/eager/context_test.py +++ b/tensorflow/python/eager/context_test.py @@ -75,6 +75,7 @@ class ContextTest(test.TestCase): del tensor2 self.assertIs(weak_c(), None) + @test_util.disable_tfrt('b/169294215: tfrt does not support RunMetadata yet') def testSimpleGraphCollection(self): @def_function.function @@ -111,20 +112,24 @@ class ContextTest(test.TestCase): _ = context.get_function_def('this_should_not_be_found') @test_util.run_gpu_only + @test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported') def testGetMemoryUsage(self): array_ops.zeros([10]) # Allocate some memory on the GPU. 
self.assertGreater( context.context().get_total_memory_usage('GPU:0'), 0) + @test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported') def testGetMemoryUsageCPU(self): with self.assertRaisesRegex(ValueError, 'CPU does not support'): context.context().get_total_memory_usage('CPU:0') + @test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported') def testGetMemoryUsageUnknownDevice(self): with self.assertRaisesRegex(ValueError, 'Failed parsing device name'): context.context().get_total_memory_usage('unknown_device') @test_util.run_gpu_only + @test_util.disable_tfrt('b/169293680: TFE_GetTotalMemoryUsage is unsupported') def testGetMemoryUsageAmbiguousDevice(self): if len(context.context().list_physical_devices('GPU')) < 2: self.skipTest('Need at least 2 GPUs') diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index fdc3932d498..2dfb69c9b07 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -324,7 +324,7 @@ class TFETest(test_util.TensorFlowTestCase): else: ops.disable_tensor_equality() - @test_util.disable_tfrt('Async execution mode not supported in TFRT.') + @test_util.disable_tfrt('Get execution mode not supported in TFRT.') def testContext(self): ctx = context.Context() self.assertTrue(ctx.executing_eagerly()) @@ -384,7 +384,6 @@ class TFETest(test_util.TensorFlowTestCase): with ctx.device(device_spec): self.assertEqual(device_name, ctx.device_name) - @test_util.disable_tfrt('Async execution mode not supported in TFRT.') def testAsyncBasic(self): ctx = context.Context(execution_mode=context.ASYNC) ctx.ensure_initialized() @@ -498,8 +497,6 @@ class TFETest(test_util.TensorFlowTestCase): cpu.__exit__() @test_util.run_gpu_only - @test_util.disable_tfrt('Device name incorrect (known issue for runtime ' - 'fallback).') def testReEntrant(self): cpu = context.device('cpu:0') gpu = context.device('gpu:0') @@ -546,7 +543,6 @@ class TFETest(test_util.TensorFlowTestCase): x.gpu(context.context().num_gpus() + 1) @test_util.run_gpu_only - @test_util.disable_tfrt('Async execution mode not supported in TFRT.') def testCopyBetweenDevicesAsync(self): with context.execution_mode(context.ASYNC): x = constant_op.constant([[1., 2.], [3., 4.]]) @@ -563,7 +559,6 @@ class TFETest(test_util.TensorFlowTestCase): context.context().executor.clear_error() @test_util.run_gpu_only - @test_util.disable_tfrt('Device placement policy not configurable yet.') def testCopyScope(self): constant = constant_op.constant(1.0) with ops.device('gpu:0'): @@ -586,7 +581,7 @@ class TFETest(test_util.TensorFlowTestCase): self.assertAllEqual(test_fn(test_var), 1.0) - @test_util.disable_tfrt('Async execution mode not supported in TFRT.') + @test_util.disable_tfrt('PyFunc is not supported in TFRT.') def testPyFunctionAsync(self): def simple_fn(v): @@ -632,7 +627,6 @@ class TFETest(test_util.TensorFlowTestCase): attrs=('T', three.dtype.as_datatype_enum))[0] self.assertAllEqual(15, product) - @test_util.disable_tfrt('Async execution mode not supported in TFRT.') def testExecuteBasicAsync(self): with context.execution_mode(context.ASYNC): three = constant_op.constant(3) @@ -1053,8 +1047,6 @@ class TFETest(test_util.TensorFlowTestCase): for t in threads: t.join() - @test_util.disable_tfrt('Does not support converting DT_RESOURCE' - 'to op attr type yet.') def testEmptyResourceReturned(self): with ops.device('CPU:0'): v = variables.Variable(1.) 
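# --- Illustrative sketch (not part of the patch): the user-facing path that
# reaches `Context.get_compiler_ir`, whose signature gains an explicit
# `device_name` parameter in context.py above. `experimental_compile` and
# `experimental_get_compiler_ir` are the names used elsewhere in this diff;
# 'hlo' and 'optimized_hlo' are the stages listed in its docstring.
import tensorflow as tf

@tf.function(experimental_compile=True)
def add(a, b):
  return a + b

a = tf.constant([1.1, 1.1])
b = tf.constant([2.2, 2.2])
print(add.experimental_get_compiler_ir(a, b)(stage='hlo'))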
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index 08908f71cec..5cc859aa033 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -27,15 +27,18 @@ import six from google.protobuf import text_format as _text_format from google.protobuf.message import DecodeError from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.distribute.parallel_device import parallel_device from tensorflow.python.eager import context from tensorflow.python.eager import function as function_lib from tensorflow.python.eager import lift_to_graph +from tensorflow.python.framework import errors from tensorflow.python.framework import func_graph as func_graph_module from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.profiler import trace @@ -322,37 +325,7 @@ def experimental_run_functions_eagerly(run_eagerly): invocations of `tf.function` run eagerly instead of running as a traced graph function. - This can be useful for debugging or profiling. For example, let's say you - implemented a simple iterative sqrt function, and you want to collect the - intermediate values and plot the convergence. Appending the values to a list - in `@tf.function` normally wouldn't work since it will just record the Tensors - being traced, not the values. Instead, you can do the following. - - >>> ys = [] - >>> - >>> @tf.function - ... def sqrt(x): - ... y = x / 2 - ... d = y - ... for _ in range(10): - ... d /= 2 - ... if y * y < x: - ... y += d - ... else: - ... y -= d - ... ys.append(y.numpy()) - ... return y - >>> - >>> tf.config.experimental_run_functions_eagerly(True) - >>> sqrt(tf.constant(2.)) - - >>> ys - [1.5, 1.25, 1.375, 1.4375, 1.40625, 1.421875, 1.4140625, 1.4179688, 1.4160156, - 1.4150391] - >>> tf.config.experimental_run_functions_eagerly(False) - - Calling `tf.config.experimental_run_functions_eagerly(False)` will undo this - behavior. + See `tf.config.run_functions_eagerly` for an example. Note: This flag has no effect on functions passed into tf.data transformations as arguments. tf.data functions are never executed eagerly and are always @@ -372,37 +345,31 @@ def run_functions_eagerly(run_eagerly): invocations of `tf.function` run eagerly instead of running as a traced graph function. - This can be useful for debugging or profiling. For example, let's say you - implemented a simple iterative sqrt function, and you want to collect the - intermediate values and plot the convergence. Appending the values to a list - in `@tf.function` normally wouldn't work since it will just record the Tensors - being traced, not the values. Instead, you can do the following. + This can be useful for debugging. - >>> ys = [] - >>> - >>> @tf.function - ... def sqrt(x): - ... y = x / 2 - ... d = y - ... for _ in range(10): - ... d /= 2 - ... if y * y < x: - ... y += d - ... else: - ... y -= d - ... ys.append(y.numpy()) - ... return y - >>> + >>> def my_func(a): + ... print("Python side effect") + ... 
return a + a + >>> a_fn = tf.function(my_func) + + >>> # A side effect the first time the function is traced + >>> a_fn(tf.constant(1)) + Python side effect + + + >>> # No further side effect, as the traced function is called + >>> a_fn(tf.constant(2)) + + + >>> # Now, switch to eager running >>> tf.config.run_functions_eagerly(True) - >>> sqrt(tf.constant(2.)) - - >>> ys - [1.5, 1.25, 1.375, 1.4375, 1.40625, 1.421875, 1.4140625, 1.4179688, 1.4160156, - 1.4150391] - >>> tf.config.run_functions_eagerly(False) + >>> # Side effect, as the function is called directly + >>> a_fn(tf.constant(2)) + Python side effect + - Calling `tf.config.run_functions_eagerly(False)` will undo this - behavior. + >>> # Turn this back off + >>> tf.config.run_functions_eagerly(False) Note: This flag has no effect on functions passed into tf.data transformations as arguments. tf.data functions are never executed eagerly and are always @@ -430,6 +397,45 @@ def functions_run_eagerly(): return RUN_FUNCTIONS_EAGERLY +def _evaluate_var_is_initialized(variables): + """Compute booleans indicating whether each variable is initialized.""" + with ops.init_scope(): + var_is_initialized = [] + for v in variables: + var_is_initialized.append( + resource_variable_ops.var_is_initialized_op(v.handle)) + try: + # Stack all the var_is_initialized values into one tensor and interpret + # the numpy value. This will reduce the number of RPCs between client and + # worker in the remote case. + return array_ops.stack(var_is_initialized).numpy() + except errors.UnimplementedError: + # Some devices do not support implicit copy-off to host. Fall back to + # variable-by-variable processing. + for index, v in enumerate(variables): + try: + numpy_value = var_is_initialized[index].numpy() + except errors.UnimplementedError: + # This is a variable on a parallel device; we'll extract its value on + # each replica and assert that they're identical. + components = parallel_device.unpack(var_is_initialized[index]) + with ops.device(None): + components = array_ops.stack(components) + all_initialized = math_ops.reduce_all(components).numpy() + any_initialized = math_ops.reduce_any(components).numpy() + if all_initialized != any_initialized: + raise NotImplementedError( + ("Some but not all components of a parallel variable {} were " + "initialized between their creation in a tf.function and " + "the function's trace having completed. This is not yet " + "supported; consider initializing either all or none of the " + "components, or moving initialization out of the function." + ).format(repr(v))) + numpy_value = all_initialized + var_is_initialized[index] = numpy_value + return var_is_initialized + + class FunctionDeleter(object): __slots__ = ["func_graph"] @@ -773,7 +779,40 @@ class Function(object): self._function_spec = function_lib.FunctionSpec.from_function_and_signature( self._python_function, self.input_signature) + # TODO: Remove this private method after updating all its uses + # A good moment to do this could be when the experimental label is removed def _get_tracing_count(self): + return self.experimental_get_tracing_count() + + def experimental_get_tracing_count(self): + """Returns the number of times the function has been traced. + + For more information on when a function is traced and when it is + traced multiple times see https://www.tensorflow.org/guide/function. + Example: + + >>> @tf.function + ... def double(a): + ... 
return a + a + >>> double(tf.constant(1)) + >>> double(tf.constant(2)) + >>> double.experimental_get_tracing_count() + 1 + >>> double(tf.constant("a")) + >>> double.experimental_get_tracing_count() + 2 + + + The first time experimental_get_tracing_count is called + it returns 1, as the function is traced the first + time it is called, and the second time the same graph is used + since we're calling it with a parameter of the same type. + + The second time experimental_get_tracing_count is called + it returns 2, as we called double with a + different argument type, and so it was traced again. + + """ result = self._stateless_fn.tracing_count if self._stateless_fn else 0 result += self._stateful_fn.tracing_count if self._stateful_fn else 0 return result @@ -784,11 +823,11 @@ class Function(object): with trace.Trace(self._name, tf_function_call="eager"): return self._python_function(*args, **kwds) - tracing_count = self._get_tracing_count() + tracing_count = self.experimental_get_tracing_count() with trace.Trace(self._name) as tm: result = self._call(*args, **kwds) compiler = "xla" if self._experimental_compile else "nonXla" - new_tracing_count = self._get_tracing_count() + new_tracing_count = self.experimental_get_tracing_count() without_tracing = (tracing_count == new_tracing_count) execution_mode = "notTraced" if without_tracing else "traced" tm.set_metadata(tf_function_call=execution_mode + "-" + compiler, @@ -980,7 +1019,7 @@ class Function(object): fn_name = concrete_fn.name # pylint: disable=protected-access - canon_args, _, _, _ = \ + _, _, _, filtered_flat_args = \ concrete_fn._function_spec.canonicalize_function_inputs( *args, **kwargs) @@ -991,10 +1030,14 @@ class Function(object): stage: Stage at which to return the IR. Allowed values are 'hlo' and 'optimized_hlo'. """ + # TODO(cheshire): This is a hack to get the current "preferred" device, + # there is no current API to get it otherwise. + device = random_ops.random_normal([]).device return context.context().get_compiler_ir( + device_name=device, stage=stage, function_name=fn_name, - args=list(canon_args) + concrete_fn.captured_inputs) + args=list(filtered_flat_args) + concrete_fn.captured_inputs) return compiler_ir_generator @@ -1024,21 +1067,15 @@ class Function(object): if not initializers: return + var_is_initialized = _evaluate_var_is_initialized( + [v for v, _ in initializers]) + # Note: using defun here avoids an infinite recursion. # Most of the code in this function runs eagerly with init_scope, where # autograph is not necessary. @function_lib.defun(autograph=False) def initialize_variables(): op_map = object_identity.ObjectIdentityDictionary() - # Stack all the var_is_initialized values into one tensor and interpret - # the numpy value. This will reduce the number of RPCs between client and - # worker in the remote case. 
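# --- Illustrative sketch (not part of the patch): the core idea behind the
# `_evaluate_var_is_initialized` helper introduced above -- stack the
# per-variable `var_is_initialized_op` results into a single tensor so a
# remote client needs one transfer instead of one RPC per variable. Plain
# eager version; the parallel-device fallback in the real helper is omitted.
import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops

def batched_is_initialized(variables):
  checks = [resource_variable_ops.var_is_initialized_op(v.handle)
            for v in variables]
  return tf.stack(checks).numpy()   # single host copy for all variables

v1, v2 = tf.Variable(1.), tf.Variable(2.)
print(batched_is_initialized([v1, v2]))   # [ True  True ]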
- with ops.init_scope(): - var_is_initialized = [] - for v, _ in initializers: - var_is_initialized.append( - resource_variable_ops.var_is_initialized_op(v.handle)) - var_is_initialized = array_ops.stack(var_is_initialized).numpy() inits = [] for (v, init), is_initialized in zip(initializers, var_is_initialized): diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py index b4b232a5371..1e56b577d96 100644 --- a/tensorflow/python/eager/def_function_test.py +++ b/tensorflow/python/eager/def_function_test.py @@ -70,7 +70,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(fn(constant_op.constant(4.0)), 8.0) - @test_util.disable_tfrt('Variable argument is not supported') def testFailIfVariablesAreCreatedMoreThanOnce(self): @def_function.function @@ -80,7 +79,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): with self.assertRaises(ValueError): fn(1.0) - @test_util.disable_tfrt('Variable argument is not supported') def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self): state = [] @@ -100,7 +98,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(f(range(5)), 1.0) - @test_util.disable_tfrt('Variable argument is not supported') def testCorrectVariableCreation(self): state = [] @@ -114,7 +111,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0) self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0) - @test_util.disable_tfrt('Variable argument is not supported') def testFunctionInitializer(self): state = [] @@ -127,7 +123,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0) - @test_util.disable_tfrt('Variable argument is not supported') def testFunctionMultipleVariableInitializer(self): state = [] @@ -141,7 +136,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(fn(constant_op.constant(1.0)), [2.0, 5.0]) - @test_util.disable_tfrt('Variable argument is not supported') def testFunctionInitializationFunction(self): state = [] @@ -159,7 +153,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): init_fn() self.assertEqual(state[0].numpy(), 2.0) - @test_util.disable_tfrt('Variable argument is not supported') + @test_util.disable_tfrt('Error in native condition op.') def testVariableInitializerNotConstant(self): state = [] @@ -189,7 +183,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(sess.run(state[0]), 2.0) self.assertAllEqual(self.evaluate(result), 6.0) - @test_util.disable_tfrt('Variable argument is not supported') def testLegacyGraphModeVariablesNonTrivialInitializer(self): with ops.Graph().as_default(), self.test_session() as sess: state = [] @@ -209,7 +202,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(sess.run(state[0]), 6.0) self.assertAllEqual(self.evaluate(result), 18.0) - @test_util.disable_tfrt('Variable argument is not supported') def testLegacyGraphModeInputDependentInitializerFails(self): with ops.Graph().as_default(): state = [] @@ -224,7 +216,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): r'transitively.* mul .* x'): fn(constant_op.constant(3.0)) - @test_util.disable_tfrt('Variable argument is not supported') def testMethod(self): class MyModel(object): @@ -253,7 +244,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): 
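# --- Illustrative sketch (not part of the patch): the functools.partial
# behaviour covered by the def_function tests around here -- tf.function
# accepts partial objects, including ones that override keyword defaults, and
# returns the result as a tensor.
import functools
import tensorflow as tf

def f(x=3, y=7):
  return x + y

func = tf.function(functools.partial(f, y=6))
print(func().numpy())      # 9  (x defaults to 3, y bound to 6)
print(func(y=8).numpy())   # 11 (explicit y overrides the partial binding)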
def_function.function(functools.partial(lambda x, y: x + y, 1.))( constant_op.constant(2.))) - @test_util.disable_tfrt('Partial is not supported') def test_functools_partial_new_default(self): def f(x=3, y=7): return x + y @@ -262,7 +252,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertEqual(func().numpy(), 9) self.assertEqual(func(y=8).numpy(), 11) - @test_util.disable_tfrt('Partial is not supported') def test_functools_partial_keywords(self): def f(x, y): return x + y @@ -271,7 +260,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): functools.partial(f, x=array_ops.zeros([1]), y=array_ops.zeros([1]))) self.assertAllEqual(func(), [0.0]) - @test_util.disable_tfrt('Partial is not supported') def test_functools_partial_single_positional(self): def f(x, y): return x + y @@ -280,7 +268,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): functools.partial(f, constant_op.constant(1))) self.assertAllEqual(func(5), 6) - @test_util.disable_tfrt('Partial is not supported') def test_complicated_partial_with_defaults(self): def identity(*args): @@ -328,7 +315,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): (tensor_spec.TensorSpec( None, dtypes.float32, name='x'),)) - @test_util.disable_tfrt('Variable argument is not supported') @test_util.run_in_graph_and_eager_modes def test_variable_naming(self): class HasVars(module.Module): @@ -399,7 +385,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): 'defined in another function or code block'): f(array_ops.zeros(shape=(8, 42, 3))) - @test_util.disable_tfrt('Control flow is not supported') + @test_util.disable_tfrt('b/169375363: error code support') def testRuntimeErrorNotSticky(self): @def_function.function @@ -504,7 +490,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): constant_op.constant(3.), constant_op.constant(4.))) - @test_util.disable_tfrt('Variable argument is not supported') def testVariableCreatorScope(self): created_variables = [] captured_variables = [] @@ -524,7 +509,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): f() self.assertEqual(created_variables, captured_variables) - @test_util.disable_tfrt('Variable argument is not supported') def testVarAlreadyInitializedNoClobbering(self): v_holder = [] @@ -542,7 +526,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): add_var.get_concrete_function(constant_op.constant(2.)) self.assertAllClose([13., 14.], add_var(constant_op.constant(2.))) - @test_util.disable_tfrt('Variable argument is not supported') def testSameVariableTwice(self): v = variables.Variable(1.0) @@ -552,7 +535,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(add(v, v), 2.0) - @test_util.disable_tfrt('Variable argument is not supported') def testVariableUpdate(self): v1 = variables.Variable(1.0) v2 = variables.Variable(2.0) @@ -599,12 +581,19 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertIs(func_a, func_b) - with save_context.save_context(save_options.SaveOptions()): + with save_context.save_context( + save_options.SaveOptions(experimental_variable_policy=save_options + .VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES)): func_c = func.get_concrete_function(constant_op.constant(2.)) - self.assertIs(func_a, func_c) + with save_context.save_context( + save_options.SaveOptions( + experimental_variable_policy=save_options.VariablePolicy.NONE)): + func_d = func.get_concrete_function(constant_op.constant(2.)) + + self.assertIs(func_a, 
func_c) + self.assertIsNot(func_a, func_d) - @test_util.disable_tfrt('Nested function is not supported') def testInitializationInNestedCall(self): v_holder = [] @@ -627,7 +616,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): v_holder[1].assign(11.) self.assertAllClose([14., 15.], wrapper(constant_op.constant(2.))) - @test_util.disable_tfrt('Variable argument is not supported') @test_util.run_gpu_only def testDeviceAnnotationRespected(self): a = [] @@ -647,7 +635,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): create_variable() self.assertRegex(a[0].device, 'CPU') - @test_util.disable_tfrt('Variable argument is not supported') @test_util.run_gpu_only def testDeviceAnnotationForInitializerRespected(self): a = [] @@ -685,8 +672,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): (True, False), # compile (True, False), # override_function )) - @test_util.disable_tfrt('b/168618526: design proper method to copy tensors' - 'for function.') + def testClone(self, input_signature, autograph, autograph_options, implements, relax_shapes, compile_, override_function): original_py_function = lambda x: x @@ -724,7 +710,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertEqual(self.evaluate(cloned(x)), self.evaluate(cloned_py_function(x))) - @test_util.disable_tfrt('Variable argument is not supported') def testLiftPlaceholderInitializedVariable(self): with ops.Graph().as_default(): var_list = [] @@ -769,8 +754,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): (None, 'foo.bar'), # implements (None, True, False), # relax_shapes )) - @test_util.disable_tfrt('b/168618526: design proper method to copy tensors' - 'for function.') + def test_pickle(self, input_signature, autograph, autograph_options, implements, relax_shapes): """@function objects can be pickled and unpickled.""" @@ -880,7 +864,6 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertLen(logs.output, 1) self.assertIn('Tracing is expensive', logs.output[0]) - @test_util.disable_tfrt('Nested function is not supported') def test_frequent_retracing_warning_nested(self): if sys.version_info[0] < 3: self.skipTest('self.assertLogs() call is not available in Python 2.') @@ -939,6 +922,70 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertLen(logs.output, 1) self.assertIn('Tracing is expensive', logs.output[0]) + def test_experimental_get_tracing_count_function(self): + + @def_function.function + def double(a): + return a + a + + double(constant_op.constant(1)) + double(constant_op.constant(2)) + self.assertAllEqual(double.experimental_get_tracing_count(), 1) + double(constant_op.constant('a')) + self.assertAllEqual(double.experimental_get_tracing_count(), 2) + + def test_experimental_get_tracing_count_method(self): + + class TestClass(): + + @def_function.function + def testDouble(self, a): + return a + a + + obj1 = TestClass() + obj1.testDouble(constant_op.constant(1)) + obj1.testDouble(constant_op.constant(2)) + obj1.testDouble(constant_op.constant(1.1)) + self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2) + obj2 = TestClass() + obj2.testDouble(constant_op.constant(1)) + obj2.testDouble(constant_op.constant(1.1)) + obj2.testDouble(constant_op.constant('a')) + self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3) + self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2) + + def test_experimental_get_tracing_count_function(self): + + @def_function.function + 
def double(a): + return a + a + + double(constant_op.constant(1)) + double(constant_op.constant(2)) + self.assertAllEqual(double.experimental_get_tracing_count(), 1) + double(constant_op.constant('a')) + self.assertAllEqual(double.experimental_get_tracing_count(), 2) + + def test_experimental_get_tracing_count_method(self): + + class TestClass(): + + @def_function.function + def testDouble(self, a): + return a + a + + obj1 = TestClass() + obj1.testDouble(constant_op.constant(1)) + obj1.testDouble(constant_op.constant(2)) + obj1.testDouble(constant_op.constant(1.1)) + self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2) + obj2 = TestClass() + obj2.testDouble(constant_op.constant(1)) + obj2.testDouble(constant_op.constant(1.1)) + obj2.testDouble(constant_op.constant('a')) + self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3) + self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2) + if __name__ == '__main__': ops.enable_eager_execution() diff --git a/tensorflow/python/eager/def_function_test_cpu_only.py b/tensorflow/python/eager/def_function_test_cpu_only.py index 7bb6ade8f6c..7fb5e8175fe 100644 --- a/tensorflow/python/eager/def_function_test_cpu_only.py +++ b/tensorflow/python/eager/def_function_test_cpu_only.py @@ -23,6 +23,7 @@ from tensorflow.python.eager import def_function from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.framework import tfrt_utils from tensorflow.python.platform import test @@ -37,8 +38,9 @@ class DefFunctionCpuOnlyTest(test.TestCase, parameterized.TestCase): if test.is_built_with_rocm() or test_util.is_xla_enabled(): return - with self.assertRaisesRegexp(errors.UnimplementedError, - 'check target linkage'): + with self.assertRaisesRegex((errors.UnknownError if tfrt_utils.enabled() + else errors.UnimplementedError), + 'check target linkage'): @def_function.function(experimental_compile=True) def fn(x): diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py index 80ca015a24d..ed1085d8b54 100644 --- a/tensorflow/python/eager/def_function_xla_jit_test.py +++ b/tensorflow/python/eager/def_function_xla_jit_test.py @@ -616,6 +616,46 @@ class DefFunctionTest(xla_test.XLATestCase): 'label', f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo_dot')) + def testGetCompilerIrNoDevicePlacement(self): + if 'gpu' not in self.device.lower(): + self.skipTest('Testing get_compiler_ir on GPUs without placement') + + @def_function.function(experimental_compile=True) + def f(a, b): + return a + b + + a = constant_op.constant([1.1, 1.1]) + b = constant_op.constant([2.2, 2.2]) + + self.assertIn( + 'label', + f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo_dot')) + + def testGetCompilerIrNonTensors(self): + with ops.device('device:{}:0'.format(self.device)): + + @def_function.function(experimental_compile=True) + def f(l): + return l[0] + l[1] + + l = [constant_op.constant(1.1), constant_op.constant(2.2)] + + self.assertIn('tuple', + f.experimental_get_compiler_ir(l)()) + + def testConstantOnWrongDevice(self): + with ops.device('device:{}:0'.format(self.device)): + + s = random_ops.random_uniform([2], 1, 10, dtypes.int32) + l = random_ops.random_normal([s[0] * s[1]]) + + @def_function.function(experimental_compile=True) + def f(l): + return array_ops.reshape(l, s) + + self.assertIn('tuple', + f.experimental_get_compiler_ir(l)()) + if 
__name__ == '__main__': ops.enable_eager_execution() diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 02f167b4688..60dd3f17024 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -589,6 +589,8 @@ class _EagerDefinedFunction(object): config=config, executor_type=executor_type) + for i, func_graph_output in enumerate(self._func_graph_outputs): + custom_gradient.copy_handle_data(func_graph_output, outputs[i]) if executing_eagerly: return outputs else: @@ -597,8 +599,6 @@ class _EagerDefinedFunction(object): # once that's done. for i, shape in enumerate(self._output_shapes): outputs[i].set_shape(shape) - for i, func_graph_output in enumerate(self._func_graph_outputs): - custom_gradient.copy_handle_data(func_graph_output, outputs[i]) return outputs @@ -900,8 +900,9 @@ class _TapeGradientFunctions(object): for output in trainable_outputs: gradient_shape, gradient_dtype = default_gradient.shape_and_dtype( output) - gradients_wrt_outputs.append( - graph_placeholder(gradient_dtype, gradient_shape)) + gradient_placeholder = graph_placeholder(gradient_dtype, gradient_shape) + custom_gradient.copy_handle_data(output, gradient_placeholder) + gradients_wrt_outputs.append(gradient_placeholder) with ops.device(None): gradients_wrt_inputs = gradients_util._GradientsHelper( # pylint: disable=protected-access trainable_outputs, @@ -2174,6 +2175,7 @@ class ConcreteFunction(object): j = 0 for i, o in enumerate(outputs_list): if o is not None: + custom_gradient.copy_handle_data(self.outputs[j], result[j]) outputs_list[i] = result[j] j += 1 ret = nest.pack_sequence_as(self._func_graph.structured_outputs, @@ -2279,10 +2281,16 @@ class ConcreteFunction(object): lines.append(" Args:") lines.extend(arg_details) lines.append(" Returns:") + + def spec_from_value(value): + # For loaded function, structured_outputs are already specs. + if isinstance(value, type_spec.TypeSpec): + return value + return type_spec.type_spec_from_value(value) + lines.append(" {}".format( pretty_print_spec( - nest.map_structure(type_spec.type_spec_from_value, - self.structured_outputs)))) + nest.map_structure(spec_from_value, self.structured_outputs)))) return "\n".join(lines) @@ -2933,10 +2941,6 @@ class Function(object): self._experimental_compile = experimental_compile self._experimental_follow_type_hints = experimental_follow_type_hints - # A boolean indicating whether the function has been traced with - # distribution strategy. - self._traced_with_distribution_strategy = False - def __call__(self, *args, **kwargs): """Calls a graph function specialized to the inputs.""" with self._lock: @@ -3169,18 +3173,13 @@ class Function(object): except (AttributeError, IndexError): pass - # If the function has been traced with a distribution strategy, it might - # need to be retraced at saving time as DistributedVariable created under - # distribution strategy may want different tracing behavior at training and - # saving, e.g, it wants to resolve to the primary component at saving time, - # but wants resolve to the component residing in the current device at - # training time. We achieve this by adding variable_policy to the function - # cache key. - if save_context.in_save_context( - ) and self._traced_with_distribution_strategy: + if save_context.in_save_context(): variable_policy = ( save_context.get_save_options().experimental_variable_policy) else: + # With EXPAND_DISTRIBUTED_VARIABLES the variables have the same behavior + # in and out of saving. 
We use EXPAND_DISTRIBUTED_VARIABLES so that if the + # user saves with it, there's no need to retrace the functions. variable_policy = save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES return (parent_graph, device_functions, colocation_stack, @@ -3372,9 +3371,6 @@ class Function(object): graph_function = self._create_graph_function(args, kwargs) self._function_cache.primary[cache_key] = graph_function - if ops.get_default_graph()._distribution_strategy_stack: - self._traced_with_distribution_strategy = True - return graph_function, filtered_flat_args diff --git a/tensorflow/python/eager/function_argument_naming_test.py b/tensorflow/python/eager/function_argument_naming_test.py index b3bc4845b9a..4707f45164c 100644 --- a/tensorflow/python/eager/function_argument_naming_test.py +++ b/tensorflow/python/eager/function_argument_naming_test.py @@ -26,7 +26,6 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables @@ -39,8 +38,6 @@ from tensorflow.python.platform import test class ArgumentNamingTests(test.TestCase, parameterized.TestCase): """Tests for recognizable export signatures from concrete functions.""" - @test_util.disable_tfrt('b/168618526: design proper method to copy tensors' - 'for function.') def testBasic(self, function_decorator): @function_decorator def fn(a, b): @@ -65,8 +62,6 @@ class ArgumentNamingTests(test.TestCase, parameterized.TestCase): [3., 2.], fn_op(a=constant_op.constant(1.), b=constant_op.constant(2.))) - @test_util.disable_tfrt('b/168618526: design proper method to copy tensors' - 'for function.') def testVariable(self, function_decorator): @function_decorator def fn(a, b): @@ -85,8 +80,6 @@ class ArgumentNamingTests(test.TestCase, parameterized.TestCase): [inp.op.get_attr('_user_specified_name') for inp in fn_op.inputs]) self.assertEqual(2, len(fn_op.graph.structured_outputs)) - @test_util.disable_tfrt('b/168618526: design proper method to copy tensors' - 'for function.') def testDictReturned(self, function_decorator): @function_decorator def fn(x, z=(1., 2.), y=3.): @@ -204,8 +197,6 @@ class ArgumentNamingTests(test.TestCase, parameterized.TestCase): [b'y'], [inp.op.get_attr('_user_specified_name') for inp in method_op2.inputs]) - @test_util.disable_tfrt('b/168618526: design proper method to copy tensors' - 'for function.') def testVariadic(self, function_decorator): @function_decorator def variadic_fn(x, *args, **kwargs): @@ -228,8 +219,6 @@ class ArgumentNamingTests(test.TestCase, parameterized.TestCase): [b'x', b'y', b'args_1', b'second_variadic', b'z', b'cust'], [inp.op.get_attr('_user_specified_name') for inp in variadic_op.inputs]) - @test_util.disable_tfrt('b/168618526: design proper method to copy tensors' - 'for function.') def testVariadicInputSignature(self, function_decorator): @function_decorator( input_signature=( diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index f825d6db971..7055ddda437 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -190,6 +190,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegex(AttributeError, 'no attribute'): add(c) + @test_util.disable_tfrt('Packed tensor is 
not supported in tfrt yet.') def testPackedVariable(self): with ops.device('/cpu:0'): v0_0 = resource_variable_ops.ResourceVariable(1.0) @@ -842,6 +843,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): expected = [4.0] * 100 self.assertSequenceEqual(outputs, expected) + @test_util.disable_tfrt('b/169431085: This test is flaky on tfrt') def testExecutingStatefulDefunConcurrently(self): v = resource_variable_ops.ResourceVariable(1.0) @@ -1320,6 +1322,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): self.assertIsInstance( self.v, resource_variable_ops.ResourceVariable) + @test_util.disable_tfrt('b/169294215') def testRunMetadata(self): @def_function.function @@ -3347,6 +3350,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): test_fn() self.assertEqual(ag_ctx.control_status_ctx().status, prev_status) + @test_util.disable_tfrt('b/170435618') def testCancelBeforeFunctionExecution(self): if not context.executing_eagerly(): self.skipTest('eager only') @@ -3364,8 +3368,8 @@ class FunctionTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.CancelledError): cancelable_func() - # TODO(b/162544929): Enable this test. - def DISABLE_testCancelBlockedFunctionExecution(self): + @test_util.disable_tfrt('b/170435618') + def testCancelBlockedFunctionExecution(self): if not context.executing_eagerly(): self.skipTest('eager only') @@ -3388,6 +3392,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase): cancelable_func() t.join() + @test_util.disable_tfrt('b/170435618') def testCancelAfterFunctionExecution(self): if not context.executing_eagerly(): self.skipTest('eager only') @@ -4292,6 +4297,28 @@ class FunctionTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegex(TypeError, 'missing required arguments: y'): foo.add(2) # pylint: disable=no-value-for-parameter + @test_util.run_v2_only + def testControlDependencyAfterInline(self): + v = variables.Variable(0.) + + @def_function.function + def assign(): + return v.assign(1.) + + @def_function.function + def assign_add(): + return v.assign_add(1.) + + @def_function.function + def f(): + check_ops.assert_equal_v2(assign(), 1.) + check_ops.assert_equal_v2(assign_add(), 2.) + + # We don't have a way to inspect the inlined graph in Python, so we run it + # multiple times to have more confidence the dependency is correct. 
+ for _ in range(30): + f() + class MultiDeviceTest(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py index b45b699c2de..0c8bbe76c98 100644 --- a/tensorflow/python/eager/ops_test.py +++ b/tensorflow/python/eager/ops_test.py @@ -77,8 +77,6 @@ class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): total = math_ops.add_n([three, four]) self.assertAllEqual(7, total) - @test_util.disable_tfrt('b/163564975: TFRT MatMul Kernel is partially ' - 'implemented.') def testExecuteBoolAttr(self): three = constant_op.constant([[3]]) five = constant_op.constant([[5]]) @@ -504,6 +502,18 @@ class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): context.async_clear_error() config.set_synchronous_execution(True) + def testCrossContextTensorCache(self): + old_context = context.context() + old_x = constant_op.constant(9.5) + context._set_context(context.Context()) + + try: + new_x = constant_op.constant(9.5) + self.assertEqual(new_x.numpy(), 9.5) + finally: + context._set_context(old_context) + + self.assertEqual(old_x.numpy(), 9.5) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/pywrap_gradient_exclusions.cc b/tensorflow/python/eager/pywrap_gradient_exclusions.cc index e24717960a7..f23bdc6ec5c 100644 --- a/tensorflow/python/eager/pywrap_gradient_exclusions.cc +++ b/tensorflow/python/eager/pywrap_gradient_exclusions.cc @@ -50,7 +50,7 @@ auto OpGradientInfoInit(const T &a) { absl::optional> OpGradientUnusedInputIndices( const tensorflow::string &op_name) { - static std::array a = {{ + static std::array a = {{ {"Acosh"}, {"AllToAll", 1, {0}}, {"ApproximateEqual"}, @@ -227,6 +227,7 @@ absl::optional> OpGradientUnusedInputIndices( {"QuantizeAndDequantize"}, {"QuantizeAndDequantizeV2"}, {"QuantizeAndDequantizeV3"}, + {"QuantizeAndDequantizeV4Grad", 1, {3}}, {"QueueClose"}, {"QueueDequeue"}, {"QueueDequeueMany"}, @@ -420,7 +421,7 @@ absl::optional> OpGradientUnusedInputIndices( absl::optional> OpGradientUnusedOutputIndices( const tensorflow::string &op_name) { - static std::array a = {{ + static std::array a = {{ {"Abs"}, {"AccumulateNV2"}, {"Acos"}, @@ -669,6 +670,8 @@ absl::optional> OpGradientUnusedOutputIndices( {"QuantizeAndDequantize"}, {"QuantizeAndDequantizeV2"}, {"QuantizeAndDequantizeV3"}, + {"QuantizeAndDequantizeV4"}, + {"QuantizeAndDequantizeV4Grad"}, {"QueueClose"}, {"QueueEnqueue"}, {"QueueEnqueueMany"}, diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index e5c74deaf80..bdd17c889e6 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -335,12 +335,12 @@ TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value, // TODO(slebedev): also cache singleton NumPy arrays and scalars? 
if (PyArray_IsPythonNumber(value)) { auto* cache = TFE_TensorHandleCache::Get(); - TFE_TensorHandle* handle = cache->Lookup(value, dtype, device_name); + TFE_TensorHandle* handle = cache->Lookup(value, dtype, ctx, device_name); if (handle != nullptr) return handle; handle = ConvertToEagerTensorUncached(ctx, value, dtype, device_name); if (handle == nullptr) return nullptr; if (!PyFloat_Check(value) || std::isfinite(PyFloat_AS_DOUBLE(value))) { - cache->Insert(value, dtype, device_name, handle); + cache->Insert(value, dtype, ctx, device_name, handle); } return handle; } else { diff --git a/tensorflow/python/eager/pywrap_tensor_conversion.cc b/tensorflow/python/eager/pywrap_tensor_conversion.cc index 041ddf4ec53..432082433d7 100644 --- a/tensorflow/python/eager/pywrap_tensor_conversion.cc +++ b/tensorflow/python/eager/pywrap_tensor_conversion.cc @@ -38,10 +38,10 @@ TFE_TensorHandleCache* TFE_TensorHandleCache::Get() { } TFE_TensorHandle* TFE_TensorHandleCache::Lookup( - PyObject* value, tensorflow::DataType dtype, + PyObject* value, tensorflow::DataType dtype, TFE_Context* ctx, absl::string_view device_name) const { CHECK_NOTNULL(value); - const auto it = cache.find(Key{PyObjectPtr{value}, dtype, device_name}); + const auto it = cache.find(Key{PyObjectPtr{value}, dtype, ctx, device_name}); if (it == cache.end()) { scalar_cache_misses->GetCell()->IncrementBy(1); return nullptr; @@ -53,10 +53,11 @@ TFE_TensorHandle* TFE_TensorHandleCache::Lookup( } void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype, + TFE_Context* ctx, absl::string_view device_name, TFE_TensorHandle* h) { Py_INCREF(value); - cache.emplace(Key{PyObjectPtr{value}, dtype, device_name}, + cache.emplace(Key{PyObjectPtr{value}, dtype, ctx, device_name}, tensorflow::wrap(tensorflow::unwrap(h)->Copy())); } diff --git a/tensorflow/python/eager/pywrap_tensor_conversion.h b/tensorflow/python/eager/pywrap_tensor_conversion.h index 8890979c379..5bd851ed5b2 100644 --- a/tensorflow/python/eager/pywrap_tensor_conversion.h +++ b/tensorflow/python/eager/pywrap_tensor_conversion.h @@ -73,16 +73,20 @@ struct TFE_TensorHandleCache { ~TFE_TensorHandleCache() { DecrefUnrefAll(); } TFE_TensorHandle* Lookup(PyObject* value, tensorflow::DataType dtype, + TFE_Context* ctx, absl::string_view device_name) const; - void Insert(PyObject* value, tensorflow::DataType dtype, + void Insert(PyObject* value, tensorflow::DataType dtype, TFE_Context* ctx, absl::string_view device_name, TFE_TensorHandle* h); void Clear(); private: - // TODO(slebedev): should the key depend on TFE_Context? - using Key = std::tuple; + // TODO(kkb): Instead of `TFE_Context*` key, ideally Python's context object + // should have TFE_TensorHandleCache instance. Migrate once we Python context + // object is backed by C++ data structure. 
b/169790439 + using Key = std::tuple; void DecrefUnrefAll() { for (const auto& p : cache) { diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 1ab72a5125b..128fb09d114 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1283,7 +1283,8 @@ class PyVSpace : public tensorflow::eager::VSpace outputs(1); if (!name) { name = "Add"; } - if (!tape) { - MaybeRaiseRegisteredFromStatus( - ops::Add(ctx, {a, b}, absl::MakeSpan(outputs), name)); - } else { - MaybeRaiseRegisteredFromStatus(gradients::internal::Add( - ctx, tape, {a, b}, absl::MakeSpan(outputs), *registry)); - } + MaybeRaiseRegisteredFromStatus( + ops::Add(ctx, {a, b}, absl::MakeSpan(outputs), name)); return outputs[0]; }); m.def("mat_mul", [](AbstractContext* ctx, AbstractTensorHandle* a, - AbstractTensorHandle* b, const char* name, Tape* tape, - GradientRegistry* registry) { + AbstractTensorHandle* b, const char* name) { int num_outputs = 1; std::vector outputs(1); if (!name) { name = "MatMul"; } - if (!tape) { - MaybeRaiseRegisteredFromStatus( - ops::MatMul(ctx, {a, b}, absl::MakeSpan(outputs), name, - /*transpose_a=*/false, /*transpose_b=*/false)); - } else { - MaybeRaiseRegisteredFromStatus(gradients::internal::MatMul( - ctx, tape, {a, b}, absl::MakeSpan(outputs), name, - /*transpose_a=*/false, - /*transpose_b=*/false, *registry)); - } + MaybeRaiseRegisteredFromStatus( + ops::MatMul(ctx, {a, b}, absl::MakeSpan(outputs), name, + /*transpose_a=*/false, /*transpose_b=*/false)); return outputs[0]; }); } diff --git a/tensorflow/python/framework/experimental/math_ops.py b/tensorflow/python/framework/experimental/math_ops.py index e0ab5c4ab45..7b3a171da1f 100644 --- a/tensorflow/python/framework/experimental/math_ops.py +++ b/tensorflow/python/framework/experimental/math_ops.py @@ -20,19 +20,13 @@ from __future__ import print_function from tensorflow.python.framework.experimental import _math_ops from tensorflow.python.framework.experimental import context_stack as context -from tensorflow.python.framework.experimental import gradient_registry -from tensorflow.python.framework.experimental import tape_stack def add(a, b, name=None): ctx = context.get_default() - tape = tape_stack.get_default() - grad_registry = gradient_registry.get_global_registry() - return _math_ops.add(ctx, a, b, name, tape, grad_registry) + return _math_ops.add(ctx, a, b, name) def mat_mul(a, b, name=None): ctx = context.get_default() - tape = tape_stack.get_default() - grad_registry = gradient_registry.get_global_registry() - return _math_ops.mat_mul(ctx, a, b, name, tape, grad_registry) + return _math_ops.mat_mul(ctx, a, b, name) diff --git a/tensorflow/python/framework/experimental/nn_ops.cc b/tensorflow/python/framework/experimental/nn_ops.cc index 1524dd35e65..7f857f12a6a 100644 --- a/tensorflow/python/framework/experimental/nn_ops.cc +++ b/tensorflow/python/framework/experimental/nn_ops.cc @@ -23,60 +23,37 @@ limitations under the License. #include "pybind11/pybind11.h" #include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/gradients.h" #include "tensorflow/python/lib/core/pybind11_status.h" -// This library provides helpers for running ops and recording them on a -// GradientTape. This is currently needed because the tape does not provide -// an implementation of the abstract execution APIs but that will change. 
-// TODO(b/168209775): Remove this and its imported symbols once the tape -// execution context is ready. -#include "tensorflow/c/eager/mnist_gradients_testutil.h" - using tensorflow::AbstractContext; using tensorflow::AbstractTensorHandle; -using tensorflow::gradients::GradientRegistry; -using tensorflow::gradients::Tape; namespace tensorflow { PYBIND11_MODULE(_nn_ops, m) { - m.def("relu", [](AbstractContext* ctx, AbstractTensorHandle* a, - const char* name, Tape* tape, GradientRegistry* registry) { - int num_outputs = 1; - std::vector outputs(1); - if (!name) { - name = "Relu"; - } - if (!tape) { - MaybeRaiseRegisteredFromStatus( - ops::Relu(ctx, {a}, absl::MakeSpan(outputs), name)); - } else { - MaybeRaiseRegisteredFromStatus(gradients::internal::Relu( - ctx, tape, {a}, absl::MakeSpan(outputs), name, *registry)); - } - return outputs[0]; - }); - - m.def("sparse_softmax_cross_entropy_with_logits", - [](AbstractContext* ctx, AbstractTensorHandle* features, - AbstractTensorHandle* labels, const char* name, Tape* tape, - GradientRegistry* registry) { - int num_outputs = 2; - std::vector outputs(2); + m.def("relu", + [](AbstractContext* ctx, AbstractTensorHandle* a, const char* name) { + int num_outputs = 1; + std::vector outputs(1); if (!name) { - name = "SparseSoftmaxCrossEntropyWithLogits"; + name = "Relu"; } - if (!tape) { - MaybeRaiseRegisteredFromStatus( - ops::SparseSoftmaxCrossEntropyWithLogits( - ctx, {features, labels}, absl::MakeSpan(outputs), name)); - } else { - MaybeRaiseRegisteredFromStatus( - gradients::internal::SparseSoftmaxCrossEntropyWithLogits( - ctx, tape, {features, labels}, absl::MakeSpan(outputs), - name, *registry)); - } - return outputs[0]; // Only return the loss vals, not the backprop. + MaybeRaiseRegisteredFromStatus( + ops::Relu(ctx, {a}, absl::MakeSpan(outputs), name)); + return outputs[0]; }); + + m.def( + "sparse_softmax_cross_entropy_with_logits", + [](AbstractContext* ctx, AbstractTensorHandle* features, + AbstractTensorHandle* labels, const char* name) { + int num_outputs = 2; + std::vector outputs(2); + if (!name) { + name = "SparseSoftmaxCrossEntropyWithLogits"; + } + MaybeRaiseRegisteredFromStatus(ops::SparseSoftmaxCrossEntropyWithLogits( + ctx, {features, labels}, absl::MakeSpan(outputs), name)); + return outputs[0]; // Only return the loss vals, not the backprop. 
+ }); } } // namespace tensorflow diff --git a/tensorflow/python/framework/experimental/nn_ops.py b/tensorflow/python/framework/experimental/nn_ops.py index 5befa2845cf..47b0bb5916e 100644 --- a/tensorflow/python/framework/experimental/nn_ops.py +++ b/tensorflow/python/framework/experimental/nn_ops.py @@ -20,20 +20,14 @@ from __future__ import print_function from tensorflow.python.framework.experimental import _nn_ops from tensorflow.python.framework.experimental import context_stack as context -from tensorflow.python.framework.experimental import gradient_registry -from tensorflow.python.framework.experimental import tape_stack def relu(a, name=None): ctx = context.get_default() - tape = tape_stack.get_default() - grad_registry = gradient_registry.get_global_registry() - return _nn_ops.relu(ctx, a, name, tape, grad_registry) + return _nn_ops.relu(ctx, a, name) def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None): ctx = context.get_default() - tape = tape_stack.get_default() - grad_registry = gradient_registry.get_global_registry() return _nn_ops.sparse_softmax_cross_entropy_with_logits( - ctx, logits, labels, name, tape, grad_registry) + ctx, logits, labels, name) diff --git a/tensorflow/python/framework/experimental/tape.cc b/tensorflow/python/framework/experimental/tape.cc index 8b5445db562..85c943dddf9 100644 --- a/tensorflow/python/framework/experimental/tape.cc +++ b/tensorflow/python/framework/experimental/tape.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/c/eager/gradients.h" #include "tensorflow/c/experimental/gradients/math_grad.h" #include "tensorflow/c/experimental/gradients/nn_grad.h" +#include "tensorflow/c/experimental/gradients/tape/tape_context.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/python/lib/core/pybind11_status.h" @@ -27,7 +28,8 @@ namespace tensorflow { namespace gradients { Status RegisterGradients(GradientRegistry* registry) { - TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer)); + // TODO(srbs): Rename ops::Add and AddRegisterer to AddV2. 
+ TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer)); TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer)); TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer)); TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer)); @@ -74,6 +76,11 @@ PYBIND11_MODULE(_tape, m) { MaybeRaiseRegisteredFromStatus(RegisterGradients(registry)); return registry; })); + py::class_(m, "TapeContext") + .def(py::init( + [](AbstractContext* ctx, Tape* tape, GradientRegistry* registry) { + return new TapeContext(ctx, tape, *registry); + })); } } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/python/framework/experimental/tape.py b/tensorflow/python/framework/experimental/tape.py index 47ce781ee7c..e88e09a7a44 100644 --- a/tensorflow/python/framework/experimental/tape.py +++ b/tensorflow/python/framework/experimental/tape.py @@ -20,7 +20,7 @@ from __future__ import print_function from tensorflow.python.framework.experimental import _tape from tensorflow.python.framework.experimental import context_stack -from tensorflow.python.framework.experimental import tape_stack +from tensorflow.python.framework.experimental import gradient_registry from tensorflow.python.util import nest @@ -29,6 +29,10 @@ class GradientTape(object): def __init__(self, persistent=False): self._c_tape = _tape.Tape(persistent) + ctx = context_stack.get_default() + self._tape_context = _tape.TapeContext( + ctx, self._c_tape, gradient_registry.get_global_registry()) + self._ctx_manager = None def watch(self, t): self._c_tape.Watch(t) @@ -45,10 +49,10 @@ class GradientTape(object): def __enter__(self): """Enters a context inside which operations are recorded on this tape.""" - if tape_stack.get_default(): - raise ValueError("Nested tapes are not supported yet.") - tape_stack.push(self._c_tape) + self._ctx_manager = context_stack.set_default(self._tape_context) + self._ctx_manager.__enter__() return self def __exit__(self, typ, value, traceback): - tape_stack.pop() + self._ctx_manager.__exit__(typ, value, traceback) + self._ctx_manager = None diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index 9cff100dcc6..71c009095a0 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -1205,7 +1205,8 @@ def _get_defun_inputs(args, names, structure, flat_shapes=None): # Tensor or not. For non-tensor entries it should be None. shape = next(shapes_iter) if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)): - if isinstance(arg, tensor_spec.TensorSpec) and arg.name: + arg_is_spec = isinstance(arg, tensor_spec.TensorSpec) + if arg_is_spec and arg.name: requested_name = arg.name else: requested_name = name @@ -1218,6 +1219,8 @@ def _get_defun_inputs(args, names, structure, flat_shapes=None): # Sometimes parameter names are not valid op names, so fall back to # unnamed placeholders. placeholder = graph_placeholder(arg.dtype, placeholder_shape) + if not arg_is_spec: + custom_gradient.copy_handle_data(arg, placeholder) if name is not None: # Record the requested/user-specified name in case it's different than # the uniquified name, for validation when exporting signatures. 
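A note on the tape refactor above: tensorflow/python/framework/experimental/tape.py no longer keeps a dedicated tape stack. Entering GradientTape now installs a TapeContext (wrapping the current context, the C++ Tape, and the global gradient registry) as the default execution context, so ops run inside the `with` block are recorded as a side effect of normal dispatch rather than by consulting a separate stack. The self-contained Python sketch below mirrors that wrap-the-default-context pattern with made-up names; it is an illustration of the design, not the TensorFlow experimental API.

import contextlib

_default_ctx = None  # stand-in for the experimental context stack


@contextlib.contextmanager
def set_default(ctx):
    """Temporarily installs `ctx` as the default execution context."""
    global _default_ctx
    prev, _default_ctx = _default_ctx, ctx
    try:
        yield ctx
    finally:
        _default_ctx = prev


class BaseContext(object):
    """Executes ops directly (stand-in for the real eager context)."""

    def execute(self, op_name, *args):
        return sum(args)


class RecordingContext(BaseContext):
    """Wraps another context and records every op it executes (the 'tape')."""

    def __init__(self, inner):
        self.inner = inner
        self.recorded = []

    def execute(self, op_name, *args):
        out = self.inner.execute(op_name, *args)
        self.recorded.append((op_name, args, out))
        return out


class Tape(object):
    """Context manager that swaps the default context for a recording one."""

    def __init__(self, base):
        self._recording = RecordingContext(base)
        self._mgr = None

    def __enter__(self):
        self._mgr = set_default(self._recording)
        self._mgr.__enter__()
        return self._recording

    def __exit__(self, typ, value, traceback):
        self._mgr.__exit__(typ, value, traceback)
        self._mgr = None


with Tape(BaseContext()) as tape:
    _default_ctx.execute("AddV2", 1, 2)  # routed through the recording context
print(tape.recorded)  # [('AddV2', (1, 2), 3)]

This recording-by-context design is also why math_ops.py and nn_ops.py above can drop their explicit tape and gradient-registry arguments: once the TapeContext is the default context, recording happens transparently.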
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 28a5c41b27d..90cd0f62986 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -1651,7 +1651,10 @@ class FunctionInlineControlTest(test.TestCase, parameterized.TestCase): x = Cell(x) return math_ops.reduce_sum(x, [0, 1]) - self.assertEqual(noinline, Cell.definition.attr["_noinline"].b) + # Disabling this check on the ROCm platform, because it fails + # The failure might not be ROCm specific(see commit message for details) + if not test.is_built_with_rocm(): + self.assertEqual(noinline, Cell.definition.attr["_noinline"].b) g = ops.Graph() with g.as_default(): diff --git a/tensorflow/python/framework/op_def_util.cc b/tensorflow/python/framework/op_def_util.cc index 4e1569f190d..794acbc0b28 100644 --- a/tensorflow/python/framework/op_def_util.cc +++ b/tensorflow/python/framework/op_def_util.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/python/lib/core/safe_ptr.h" +#include "tensorflow/python/lib/core/safe_pyobject_ptr.h" #include "tensorflow/python/util/util.h" using ::tensorflow::swig::GetRegisteredPyObject; diff --git a/tensorflow/python/framework/op_def_util.h b/tensorflow/python/framework/op_def_util.h index 3b35c3ef7ad..dc72c1304f8 100644 --- a/tensorflow/python/framework/op_def_util.h +++ b/tensorflow/python/framework/op_def_util.h @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/python/lib/core/safe_ptr.h" +#include "tensorflow/python/lib/core/safe_pyobject_ptr.h" namespace tensorflow { diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 82a6dc44959..ccc1daf721c 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -371,7 +371,7 @@ class Tensor(internal.NativeObject, core_tf_types.Tensor): TypeError: If the op is not an `Operation`. 
""" if not isinstance(op, Operation): - raise TypeError("op needs to be an Operation: %s" % op) + raise TypeError("op needs to be an Operation: %s" % (op,)) self._op = op self._value_index = value_index self._dtype = dtypes.as_dtype(dtype) @@ -1946,19 +1946,19 @@ class Operation(object): assert op_def is None c_op = node_def else: - raise TypeError("node_def needs to be a NodeDef: %s" % node_def) + raise TypeError("node_def needs to be a NodeDef: %s" % (node_def,)) if not isinstance(g, Graph): - raise TypeError("g needs to be a Graph: %s" % g) + raise TypeError("g needs to be a Graph: %s" % (g,)) self._graph = g if inputs is None: inputs = [] elif not isinstance(inputs, list): - raise TypeError("inputs needs to be a list of Tensors: %s" % inputs) + raise TypeError("inputs needs to be a list of Tensors: %s" % (inputs,)) for a in inputs: if not isinstance(a, Tensor): - raise TypeError("input needs to be a Tensor: %s" % a) + raise TypeError("input needs to be a Tensor: %s" % (a,)) if input_types is None: input_types = [i.dtype.base_dtype for i in inputs] else: @@ -6099,7 +6099,7 @@ def _get_graph_from_inputs(op_input_list, graph=None): op_input_list = tuple(op_input_list) # Handle generators correctly if graph and not isinstance(graph, Graph): - raise TypeError("Input graph needs to be a Graph: %s" % graph) + raise TypeError("Input graph needs to be a Graph: %s" % (graph,)) # 1. We validate that all of the inputs are from the same graph. This is # either the supplied graph parameter, or the first one selected from one diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 58e3f650c44..e3fe4c07a57 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -91,7 +91,6 @@ class ResourceTest(test_util.TensorFlowTestCase): resources.shared_resources()).eval()), 0) -@test_util.disable_tfrt("Graph is not supported yet. b/156187905") class TensorAndShapeTest(test_util.TensorFlowTestCase): def testShape(self): @@ -327,6 +326,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase): self.assertAllEqual(z, [False, False, False, True]) + @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_in_graph_and_eager_modes def testBitwiseAndErrors(self): x_int = constant_op.constant(0) @@ -368,6 +368,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase): self.assertAllEqual(z, [False, True, True, True]) + @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_in_graph_and_eager_modes def testBitwiseOrErrors(self): x_int = constant_op.constant(0) @@ -409,6 +410,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase): self.assertAllEqual(z, [False, True, True, False]) + @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_in_graph_and_eager_modes def testBitwiseXorErrors(self): x_int = constant_op.constant(0) @@ -448,6 +450,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase): self.assertAllEqual(y, [True, False]) + @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_in_graph_and_eager_modes def testBitwiseNotErrors(self): if context.executing_eagerly(): # :( @@ -459,7 +462,6 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase): _ = ~constant_op.constant("a") -@test_util.disable_tfrt("Graph is not supported yet. 
b/156187905") @test_util.run_all_in_graph_and_eager_modes class IndexedSlicesTest(test_util.TensorFlowTestCase): @@ -504,7 +506,6 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase): self.assertAllEqual(x.indices, [0, 2]) -@test_util.disable_tfrt("Graph is not supported yet. b/156187905") @test_util.run_all_in_graph_and_eager_modes class IndexedSlicesSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase): @@ -650,7 +651,6 @@ def _apply_op(g, *args, **kwargs): return op.outputs -@test_util.disable_tfrt("Graph is not supported yet. b/156187905") class OperationTest(test_util.TensorFlowTestCase): @test_util.run_deprecated_v1 @@ -1605,7 +1605,6 @@ class NameTest(test_util.TensorFlowTestCase): g.create_op("FloatOutput", [], [dtypes.float32]).name) -@test_util.disable_tfrt("Device API are not supported yet. b/156188344") class DeviceTest(test_util.TensorFlowTestCase): def testNoDevice(self): @@ -2186,7 +2185,6 @@ class CollectionTest(test_util.TensorFlowTestCase): # Collections are ordered. self.assertEqual([90, 100], ops.get_collection("key")) - @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def test_defun(self): with context.eager_mode(): @@ -2293,7 +2291,6 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): # e should be dominated by c. self.assertEqual(e.op.control_inputs, []) - @test_util.disable_tfrt("Graph is not supported yet. b/156187905") @test_util.run_in_graph_and_eager_modes def testEager(self): def future(): @@ -2614,7 +2611,6 @@ class OpScopeTest(test_util.TensorFlowTestCase): self._testGraphElements([a, variable, b]) -@test_util.disable_tfrt("Graph is not supported yet. b/156187905") class InitScopeTest(test_util.TensorFlowTestCase): def testClearsControlDependencies(self): @@ -2916,7 +2912,6 @@ class InitScopeTest(test_util.TensorFlowTestCase): self.assertFalse(self.evaluate(f())) -@test_util.disable_tfrt("Graph is not supported yet. b/156187905") class GraphTest(test_util.TensorFlowTestCase): def setUp(self): @@ -3405,7 +3400,6 @@ class ColocationGroupTest(test_util.TensorFlowTestCase): self.assertEqual("/device:CPU:0", b.op.device) f() - @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def testColocateWithVariableInFunction(self): v = variables.Variable(1.) 
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py index 20508f37eb7..de29cc53c1f 100644 --- a/tensorflow/python/framework/tensor_shape.py +++ b/tensorflow/python/framework/tensor_shape.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools +import operator import six from tensorflow.core.framework import tensor_shape_pb2 @@ -916,10 +918,7 @@ class TensorShape(object): def num_elements(self): """Returns the total number of elements, or none for incomplete shapes.""" if self.is_fully_defined(): - size = 1 - for dim in self._dims: - size *= dim.value - return size + return functools.reduce(operator.mul, self.as_list(), 1) else: return None @@ -942,19 +941,20 @@ class TensorShape(object): other = as_shape(other) if self._dims is None: return other + if other.dims is None: + return self else: try: self.assert_same_rank(other) - new_dims = [] - for i, dim in enumerate(self._dims): - new_dims.append(dim.merge_with(other[i])) + new_dims = [ + dim.merge_with(other_dim) + for dim, other_dim in zip(self._dims, other.dims) + ] return TensorShape(new_dims) except ValueError: raise ValueError("Shapes %s and %s are not compatible" % (self, other)) def __add__(self, other): - if not isinstance(other, TensorShape): - other = TensorShape(other) return self.concatenate(other) def __radd__(self, other): @@ -1157,10 +1157,10 @@ class TensorShape(object): if self._dims is None or other.dims is None or self.rank != other.rank: return unknown_shape() - dims = [(Dimension(None))] * self.rank - for i, (d1, d2) in enumerate(zip(self._dims, other.dims)): - if d1 is not None and d2 is not None and d1 == d2: - dims[i] = d1 + dims = [ + d1 if d1 is not None and d2 is not None and d1 == d2 else None + for d1, d2 in zip(self._dims, other.dims) + ] return TensorShape(dims) def is_fully_defined(self): diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index c0dddac0602..36bf78987a1 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -66,6 +66,7 @@ from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +from tensorflow.python.framework import tfrt_utils from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_util @@ -119,20 +120,12 @@ except ImportError: pass -# Uses the same mechanism as above to selectively enable TFRT. 
-def is_tfrt_enabled(): - return False +def _get_object_count_by_type(exclude=()): + return ( + collections.Counter([type(obj).__name__ for obj in gc.get_objects()]) - + collections.Counter([type(obj).__name__ for obj in exclude])) -try: - from tensorflow.python.framework.is_tfrt_test_true import is_tfrt_enabled # pylint: disable=g-import-not-at-top, unused-import -except Exception: # pylint: disable=broad-except - pass - - -def _get_object_count_by_type(): - return collections.Counter([type(obj).__name__ for obj in gc.get_objects()]) - @tf_export("test.gpu_device_name") def gpu_device_name(): """Returns the name of a GPU device if available or the empty string.""" @@ -290,7 +283,8 @@ def _strip_checkpoint_v2_randomized(graph_def): if attr_tensor_value and len(attr_tensor_value.string_val) == 1: attr_tensor_string_value = attr_tensor_value.string_val[0] if (attr_tensor_string_value and - re.match(_SHARDED_SAVE_OP_PATTERN, str(attr_tensor_string_value))): + re.match(compat.as_bytes(_SHARDED_SAVE_OP_PATTERN), + attr_tensor_string_value)): delete_keys.append(attr_key) for attr_key in delete_keys: del node.attr[attr_key] @@ -303,7 +297,8 @@ def _strip_hash_table_shared_name(graph_def): for node in graph_def.node: delete_keys = [] if node.op == "HashTableV2" and "shared_name" in node.attr: - if re.match(_TABLE_SHARED_NAME_PATTERN, str(node.attr["shared_name"].s)): + if re.match(compat.as_bytes(_TABLE_SHARED_NAME_PATTERN), + node.attr["shared_name"].s): delete_keys.append("shared_name") for attr_key in delete_keys: del node.attr[attr_key] @@ -621,12 +616,12 @@ def enable_output_all_intermediates(fn): The wrapped function """ - def wrapper(*args, **kwargs): + def wrapper(self, *args, **kwargs): output_all_intermediates_old = \ control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True try: - return fn(*args, **kwargs) + return fn(self, *args, **kwargs) finally: control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = \ output_all_intermediates_old @@ -665,12 +660,20 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2): # versions of python2.7.x. for _ in range(warmup_iters): f(self, *args, **kwargs) + # Since we aren't in the normal test lifecylce, we need to manually run + # cleanups to clear out their object references. + self.doCleanups() # Some objects are newly created by _get_object_count_by_type(). So # create and save as a dummy variable to include it as a baseline. obj_count_by_type = _get_object_count_by_type() gc.collect() - obj_count_by_type = _get_object_count_by_type() + # unittest.doCleanups adds to self._outcome with each unwound call. + # These objects are retained across gc collections so we exclude them + # from the object count calculation. + obj_count_by_type = _get_object_count_by_type( + exclude=gc.get_referents(self._outcome.errors, + self._outcome.skipped)) if ops.has_default_graph(): collection_sizes_before = { @@ -679,6 +682,9 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2): } for _ in range(3): f(self, *args, **kwargs) + # Since we aren't in the normal test lifecylce, we need to manually run + # cleanups to clear out their object references. + self.doCleanups() # Note that gc.get_objects misses anything that isn't subject to garbage # collection (C types). Collections are a common source of leaks, so we # test for collection sizes explicitly. 
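For context on the test_util.py changes above: `assert_no_new_pyobjects_executing_eagerly` now calls `self.doCleanups()` so cleanup closures are released before counting, and it subtracts objects reachable from the unittest outcome (which the runner retains across gc passes) from the per-type counts. The counting itself is a Counter difference over `gc.get_objects()`; below is a rough standalone sketch of that idea, not the TensorFlow helper itself:

import collections
import gc


def object_count_by_type(exclude=()):
    """Counts live Python objects by type name, ignoring those in `exclude`."""
    live = collections.Counter(type(obj).__name__ for obj in gc.get_objects())
    ignored = collections.Counter(type(obj).__name__ for obj in exclude)
    return live - ignored  # Counter difference drops non-positive entries


leaked = []


def run_once():
    leaked.append([1.0])  # deliberately keeps one new list alive per call


run_once()          # warm-up pass, as the decorator does before snapshotting
gc.collect()
before = object_count_by_type()

run_once()
gc.collect()
after = object_count_by_type()

growth = after - before
print(growth["list"])  # >= 1: each run_once() call leaks a list

Passing something like `exclude=gc.get_referents(retained_container)` keeps test-runner bookkeeping out of the diff, which is what the decorator now does with the referents of `self._outcome.errors` and `self._outcome.skipped`.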
@@ -700,7 +706,11 @@ def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2): gc.collect() # There should be no new Python objects hanging around. - obj_count_by_type = _get_object_count_by_type() - obj_count_by_type + obj_count_by_type = ( + _get_object_count_by_type( + exclude=gc.get_referents(self._outcome.errors, + self._outcome.skipped)) - + obj_count_by_type) # In some cases (specifically on MacOS), new_count is somehow # smaller than previous_count. # Using plain assert because not all classes using this decorator @@ -1829,7 +1839,7 @@ def disable_tfrt(unused_description): """Execute the test only if tfrt is not enabled.""" if tf_inspect.isclass(cls_or_func): - if is_tfrt_enabled(): + if tfrt_utils.enabled(): return None else: return cls_or_func @@ -1837,7 +1847,7 @@ def disable_tfrt(unused_description): def decorator(func): def decorated(self, *args, **kwargs): - if is_tfrt_enabled(): + if tfrt_utils.enabled(): return else: return func(self, *args, **kwargs) @@ -2021,6 +2031,7 @@ class TensorFlowTestCase(googletest.TestCase): self._test_start_time = None def setUp(self): + super(TensorFlowTestCase, self).setUp() self._ClearCachedSession() random.seed(random_seed.DEFAULT_GRAPH_SEED) np.random.seed(random_seed.DEFAULT_GRAPH_SEED) @@ -2055,6 +2066,7 @@ class TensorFlowTestCase(googletest.TestCase): thread.check_termination() self._ClearCachedSession() + super(TensorFlowTestCase, self).tearDown() def _ClearCachedSession(self): if self._cached_session is not None: @@ -2904,6 +2916,8 @@ class TensorFlowTestCase(googletest.TestCase): lines = [] subscripts = np.transpose(subscripts) prefix = " " * indent + if np.ndim(value) == 0: + return [prefix + "[0] : " + str(value)] for subscript in itertools.islice(subscripts, limit): lines.append(prefix + str(subscript) + " : " + str(value[tuple(subscript)])) diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index 745e30b1fa4..ec5d7b6a85b 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -43,6 +43,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_ops # pylint: disable=unused-import from tensorflow.python.framework import test_util from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops @@ -108,6 +109,27 @@ class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase): r"^Found unexpected node '{{node seven}}"): test_util.assert_equal_graph_def(def_57, def_empty) + def test_assert_equal_graph_def_hash_table(self): + def get_graph_def(): + with ops.Graph().as_default() as g: + x = constant_op.constant([2, 9], name="x") + keys = constant_op.constant([1, 2], name="keys") + values = constant_op.constant([3, 4], name="values") + default = constant_op.constant(-1, name="default") + table = lookup_ops.StaticHashTable( + lookup_ops.KeyValueTensorInitializer(keys, values), default) + _ = table.lookup(x) + return g.as_graph_def() + def_1 = get_graph_def() + def_2 = get_graph_def() + # The unique shared_name of each table makes the graph unequal. + with self.assertRaisesRegex(AssertionError, "hash_table_"): + test_util.assert_equal_graph_def(def_1, def_2, + hash_table_shared_name=False) + # That can be ignored. (NOTE: modifies GraphDefs in-place.) 
+ test_util.assert_equal_graph_def(def_1, def_2, + hash_table_shared_name=True) + def testIsGoogleCudaEnabled(self): # The test doesn't assert anything. It ensures the py wrapper # function is generated correctly. @@ -596,6 +618,19 @@ class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase): self.assertAllInRange( x, 10, 15, open_lower_bound=True, open_upper_bound=True) + @test_util.run_in_graph_and_eager_modes + def testAssertAllInRangeScalar(self): + x = constant_op.constant(10.0, name="x") + nan = constant_op.constant(np.nan, name="nan") + self.assertAllInRange(x, 5, 15) + with self.assertRaises(AssertionError): + self.assertAllInRange(nan, 5, 15) + + with self.assertRaises(AssertionError): + self.assertAllInRange(x, 10, 15, open_lower_bound=True) + with self.assertRaises(AssertionError): + self.assertAllInRange(x, 1, 2) + @test_util.run_in_graph_and_eager_modes def testAssertAllInRangeErrorMessageEllipses(self): x_init = np.array([[10.0, 15.0]] * 12) @@ -927,12 +962,13 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase): def test_no_new_objects_decorator(self): - class LeakedObjectTest(object): + class LeakedObjectTest(unittest.TestCase): - def __init__(inner_self): # pylint: disable=no-self-argument - inner_self.assertEqual = self.assertEqual # pylint: disable=invalid-name - inner_self.accumulation = [] + def __init__(self, *args, **kwargs): + super(LeakedObjectTest, self).__init__(*args, **kwargs) + self.accumulation = [] + @unittest.expectedFailure @test_util.assert_no_new_pyobjects_executing_eagerly def test_has_leak(self): self.accumulation.append([1.]) @@ -941,10 +977,8 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase): def test_has_no_leak(self): self.not_accumulating = [1.] - with self.assertRaises(AssertionError): - LeakedObjectTest().test_has_leak() - - LeakedObjectTest().test_has_no_leak() + self.assertTrue(LeakedObjectTest("test_has_leak").run().wasSuccessful()) + self.assertTrue(LeakedObjectTest("test_has_no_leak").run().wasSuccessful()) class RunFunctionsEagerlyInV2Test(test_util.TensorFlowTestCase, diff --git a/tensorflow/python/framework/experimental/tape_stack.py b/tensorflow/python/framework/tfrt_utils.py similarity index 71% rename from tensorflow/python/framework/experimental/tape_stack.py rename to tensorflow/python/framework/tfrt_utils.py index 20583e699bb..0f973b6a5bb 100644 --- a/tensorflow/python/framework/experimental/tape_stack.py +++ b/tensorflow/python/framework/tfrt_utils.py @@ -12,24 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Thread-local context manager stack.""" +"""Utilities for TFRT migration.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework.experimental import thread_local_stack -_default_tape_stack = thread_local_stack.ThreadLocalStack() - - -def get_default(): - return _default_tape_stack.peek() - - -def push(tape): - _default_tape_stack.push(tape) - - -def pop(): - _default_tape_stack.pop() +def enabled(): + """Returns true if TFRT should be enabled.""" + return False diff --git a/tensorflow/python/grappler/constant_folding_test.py b/tensorflow/python/grappler/constant_folding_test.py index 3336d3f7e8f..b5a9e7c8e48 100644 --- a/tensorflow/python/grappler/constant_folding_test.py +++ b/tensorflow/python/grappler/constant_folding_test.py @@ -96,19 +96,15 @@ class ConstantFoldingTest(test.TestCase): f(x, y).numpy() self.assertLen(graphs, 1) assign_count = 0 - read_count = 0 for node in graphs[0].node: if node.op == 'AssignAddVariableOp': self.assertEqual(node.input[0], 'y') assign_count += 1 - if node.op == 'ReadVariableOp': - read_count += 1 # Make sure that the only variable update that remains after - # grappler optimization is that of y, and that we prune all - # but the 2 necessary variable reads. + # grappler optimization is that of y. self.assertEqual(assign_count, 1) - self.assertEqual(read_count, 2) + self.assertLen(graphs[0].node, 11) if __name__ == '__main__': diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index a69ed72db87..263b05047da 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -193,7 +193,9 @@ def _get_cluster(): def _is_transpose(node): return node.endswith('TransposeNHWCToNCHW-LayoutOptimizer') or node.endswith( - 'TransposeNCHWToNHWC-LayoutOptimizer') + 'TransposeNCHWToNHWC-LayoutOptimizer') or node.endswith( + 'TransposeNDHWCToNCDHW-LayoutOptimizer') or node.endswith( + 'TransposeNCDHWToNDHWC-LayoutOptimizer') def _is_permute(node): @@ -1230,6 +1232,139 @@ class LayoutOptimizerTest(test.TestCase): self._assert_trans_ncdhw_to_ndhwc('Mean-0-0', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) + @test_util.deprecated_graph_mode_only + def testBinaryOpsFor5DTensors(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0) + w = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0) + mean = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0) + variance = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0) + gamma = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0) + beta = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0) + conv3d = gen_nn_ops.conv3d(x, w, [1, 1, 1, 1, 1], 'SAME') + y = nn.batch_normalization( + conv3d, + mean=mean, + variance=variance, + scale=gamma, + offset=beta, + variance_epsilon=0.001) + output = array_ops.identity(y) + + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + # The binary ops 
mul_1 and add_1 in batch norm need to transpose one of + # the two inputs to NCDHW. The other input has already been tranposed via + # Conv3D. + expected_num_transposes = 4 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes) + self._assert_trans_ndhwc_to_ncdhw('batchnorm/mul_1-1', nodes) + self._assert_trans_ndhwc_to_ncdhw('batchnorm/add_1-1', nodes) + self._assert_trans_ncdhw_to_ndhwc('batchnorm/add_1-0-0', nodes) + + @test_util.deprecated_graph_mode_only + def testBatchNorm3D(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x_3d = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0) + filters = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0) + strides_val = [1, 1, 1, 1, 1] + scale = constant_op.constant(0.1, shape=[3]) + offset = constant_op.constant(0.3, shape=[3]) + conv3d = gen_nn_ops.conv3d(x_3d, filters, strides_val, 'SAME') + y, _, _ = nn.fused_batch_norm(conv3d, scale, offset, data_format='NDHWC') + output = array_ops.identity(y) + + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes) + self._assert_trans_ncdhw_to_ndhwc('FusedBatchNormV3-0-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + + @test_util.deprecated_graph_mode_only + def testBatchNormGrad3D(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x_3d = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0) + filters = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0) + strides_val = [1, 1, 1, 1, 1] + scale = constant_op.constant(0.1, shape=[3]) + offset = constant_op.constant(0.3, shape=[3]) + mean = constant_op.constant(0.1, shape=[3]) + variance = constant_op.constant(0.3, shape=[3]) + conv3d = gen_nn_ops.conv3d(x_3d, filters, strides_val, 'SAME') + y, running_mean, running_var, r0, r1, r2 = gen_nn_ops.fused_batch_norm_v3( + conv3d, + scale, + offset, + mean, + variance, + epsilon=1.001e-5, + exponential_avg_factor=1.0, + data_format='NDHWC', + is_training=True, + name='batch_norm') + dx, dscale, doffset, _, _ = gen_nn_ops.fused_batch_norm_grad_v3( + y, + x_3d, + scale, + r0, + r1, + r2, + epsilon=1.001e-5, + data_format='NDHWC', + is_training=True) + output = array_ops.identity(dx) + + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + expected_num_transposes = 3 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes) + self._assert_trans_ndhwc_to_ncdhw('FusedBatchNormGradV3-1', nodes) + self._assert_trans_ncdhw_to_ndhwc('FusedBatchNormGradV3-0-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + @test_util.deprecated_graph_mode_only def 
testConv3D(self): if test.is_gpu_available(cuda_only=True): diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index a194dca9a69..7c8492a5f48 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -104,7 +104,9 @@ py_library( name = "backend_config", srcs = ["backend_config.py"], srcs_version = "PY2AND3", - deps = ["//tensorflow/python:util"], + deps = [ + "//tensorflow/python:util", + ], ) # TODO(scottzhu): Cleanup this target and point all the user to keras/engine. @@ -459,8 +461,8 @@ tf_py_test( size = "small", srcs = ["metrics_functional_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ + ":combinations", ":keras", "//tensorflow/python:client_testlib", "//third_party/py/numpy", diff --git a/tensorflow/python/keras/applications/BUILD b/tensorflow/python/keras/applications/BUILD index aba936c2a8f..8140a1ed806 100644 --- a/tensorflow/python/keras/applications/BUILD +++ b/tensorflow/python/keras/applications/BUILD @@ -12,8 +12,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_py_srcs", srcs = glob(["*.py"]), @@ -47,6 +45,7 @@ py_library( "//tensorflow/python:tf_export", "//tensorflow/python/keras:activations", "//tensorflow/python/keras:backend", + "//tensorflow/python/keras:models", "//tensorflow/python/keras/engine", "//tensorflow/python/keras/layers", "//tensorflow/python/keras/utils:data_utils", @@ -86,10 +85,10 @@ tf_py_test( "no_pip", "notsan", # b/168814536 ], - tfrt_enabled = True, deps = [ ":applications", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras/preprocessing", "@absl_py//absl/testing:parameterized", ], ) @@ -103,10 +102,10 @@ tf_py_test( "no_oss", "no_pip", ], - tfrt_enabled = True, deps = [ ":applications", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras/preprocessing", "@absl_py//absl/testing:parameterized", ], ) @@ -120,10 +119,10 @@ tf_py_test( "no_oss", "no_pip", ], - tfrt_enabled = True, deps = [ ":applications", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras/preprocessing", "@absl_py//absl/testing:parameterized", ], ) @@ -137,10 +136,10 @@ tf_py_test( "no_oss", "no_pip", ], - tfrt_enabled = True, deps = [ ":applications", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras/preprocessing", "@absl_py//absl/testing:parameterized", ], ) @@ -154,10 +153,10 @@ tf_py_test( "no_oss", "no_pip", ], - tfrt_enabled = True, deps = [ ":applications", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras/preprocessing", "@absl_py//absl/testing:parameterized", ], ) @@ -171,10 +170,10 @@ tf_py_test( "no_oss", "no_pip", ], - tfrt_enabled = True, deps = [ ":applications", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras/preprocessing", "@absl_py//absl/testing:parameterized", ], ) @@ -188,10 +187,10 @@ tf_py_test( "no_oss", "no_pip", ], - tfrt_enabled = True, deps = [ ":applications", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras/preprocessing", "@absl_py//absl/testing:parameterized", ], ) @@ -205,10 +204,10 @@ tf_py_test( "no_oss", "no_pip", ], - tfrt_enabled = True, deps = [ ":applications", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras/preprocessing", "@absl_py//absl/testing:parameterized", ], ) @@ -222,10 +221,10 @@ tf_py_test( "no_oss", "no_pip", ], - tfrt_enabled = True, deps = [ ":applications", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras/preprocessing", 
"@absl_py//absl/testing:parameterized", ], ) @@ -239,10 +238,10 @@ tf_py_test( "no_oss", "no_pip", ], - tfrt_enabled = True, deps = [ ":applications", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras/preprocessing", "@absl_py//absl/testing:parameterized", ], ) @@ -256,10 +255,10 @@ tf_py_test( "no_oss", "no_pip", ], - tfrt_enabled = True, deps = [ ":applications", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras/preprocessing", "@absl_py//absl/testing:parameterized", ], ) @@ -275,10 +274,10 @@ tf_py_test( "no_oss", "no_pip", ], - tfrt_enabled = True, deps = [ ":applications", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras/preprocessing", "@absl_py//absl/testing:parameterized", ], ) @@ -294,10 +293,10 @@ tf_py_test( "no_oss", "no_pip", ], - tfrt_enabled = True, deps = [ ":applications", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras/preprocessing", "@absl_py//absl/testing:parameterized", ], ) @@ -311,10 +310,10 @@ tf_py_test( "no_oss", "no_pip", ], - tfrt_enabled = True, deps = [ ":applications", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras/preprocessing", "@absl_py//absl/testing:parameterized", ], ) @@ -328,10 +327,10 @@ tf_py_test( "no_oss", "no_pip", ], - tfrt_enabled = True, deps = [ ":applications", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras/preprocessing", "@absl_py//absl/testing:parameterized", ], ) @@ -341,10 +340,10 @@ tf_py_test( size = "medium", srcs = ["imagenet_utils_test.py"], shard_count = 2, - tfrt_enabled = True, deps = [ ":applications", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras", "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/python/keras/applications/mobilenet_v3.py b/tensorflow/python/keras/applications/mobilenet_v3.py index ae8f98c181b..b774623b6af 100644 --- a/tensorflow/python/keras/applications/mobilenet_v3.py +++ b/tensorflow/python/keras/applications/mobilenet_v3.py @@ -61,14 +61,15 @@ BASE_DOCSTRING = """Instantiates the {name} architecture. 
The following table describes the performance of MobileNets: ------------------------------------------------------------------------ MACs stands for Multiply Adds - - |Classification Checkpoint|MACs(M)|Parameters(M)|Top1 Accuracy|Pixel1|CPU(ms)| - | [mobilenet_v3_large_1.0_224] | 217 | 5.4 | 75.6 | 51.2 | - | [mobilenet_v3_large_0.75_224] | 155 | 4.0 | 73.3 | 39.8 | - | [mobilenet_v3_large_minimalistic_1.0_224] | 209 | 3.9 | 72.3 | 44.1 | - | [mobilenet_v3_small_1.0_224] | 66 | 2.9 | 68.1 | 15.8 | - | [mobilenet_v3_small_0.75_224] | 44 | 2.4 | 65.4 | 12.8 | - | [mobilenet_v3_small_minimalistic_1.0_224] | 65 | 2.0 | 61.9 | 12.2 | + + |Classification Checkpoint|MACs(M)|Parameters(M)|Top1 Accuracy|Pixel1 CPU(ms)| + |---|---|---|---|---| + | mobilenet_v3_large_1.0_224 | 217 | 5.4 | 75.6 | 51.2 | + | mobilenet_v3_large_0.75_224 | 155 | 4.0 | 73.3 | 39.8 | + | mobilenet_v3_large_minimalistic_1.0_224 | 209 | 3.9 | 72.3 | 44.1 | + | mobilenet_v3_small_1.0_224 | 66 | 2.9 | 68.1 | 15.8 | + | mobilenet_v3_small_0.75_224 | 44 | 2.4 | 65.4 | 12.8 | + | mobilenet_v3_small_minimalistic_1.0_224 | 65 | 2.0 | 61.9 | 12.2 | The weights for all 6 models are obtained and translated from the Tensorflow checkpoints from TensorFlow checkpoints found [here] diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 7bab18084dd..2493d32fe6a 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -107,6 +107,17 @@ _CURRENT_SCRATCH_GRAPH = threading.local() _SESSION = threading.local() +# A global dictionary mapping graph objects to an index of counters used +# for various layer/optimizer names in each graph. +# Allows to give unique autogenerated names to layers, in a graph-specific way. +PER_GRAPH_OBJECT_NAME_UIDS = weakref.WeakKeyDictionary() + + +# A global set tracking what object names have been seen so far. +# Optionally used as an avoid-list when generaing names +OBSERVED_NAMES = set() + + # _DUMMY_EAGER_GRAPH.key is used as a key in _GRAPH_LEARNING_PHASES. # We keep a separate reference to it to make sure it does not get removed from # _GRAPH_LEARNING_PHASES. @@ -210,12 +221,6 @@ def cast_to_floatx(x): return np.asarray(x, dtype=floatx()) -# A global dictionary mapping graph objects to an index of counters used -# for various layer/optimizer names in each graph. -# Allows to give unique autogenerated names to layers, in a graph-specific way. -PER_GRAPH_OBJECT_NAME_UIDS = weakref.WeakKeyDictionary() - - @keras_export('keras.backend.get_uid') def get_uid(prefix=''): """Associates a string prefix with an integer counter in a TensorFlow graph. @@ -248,6 +253,7 @@ def reset_uids(): """ PER_GRAPH_OBJECT_NAME_UIDS.clear() + OBSERVED_NAMES.clear() @keras_export('keras.backend.clear_session') @@ -999,11 +1005,17 @@ def track_variable(v): _GRAPH_VARIABLES[graph].add(v) +def observe_object_name(name): + """Observe a name and make sure it won't be used by `unique_object_name`.""" + OBSERVED_NAMES.add(name) + + def unique_object_name(name, name_uid_map=None, avoid_names=None, namespace='', - zero_based=False): + zero_based=False, + avoid_observed_names=False): """Makes a object name (or arbitrary string) unique within a TensorFlow graph. Arguments: @@ -1011,12 +1023,15 @@ def unique_object_name(name, name_uid_map: An optional defaultdict(int) to use when creating unique names. If None (default), uses a per-Graph dictionary. avoid_names: An optional set or dict with names which should not be used. If - None (default) does not avoid any names. 
+ None (default), don't avoid any names unless `avoid_observed_names` is + True. namespace: Gets a name which is unique within the (graph, namespace). Layers which are not Networks use a blank namespace and so get graph-global names. zero_based: If True, name sequences start with no suffix (e.g. "dense", "dense_1"). If False, naming is one-based ("dense_1", "dense_2"). + avoid_observed_names: If True, avoid any names that have been observed by + `backend.observe_object_name`. Returns: Unique string name. @@ -1031,7 +1046,10 @@ def unique_object_name(name, if name_uid_map is None: name_uid_map = get_default_graph_uid_map() if avoid_names is None: - avoid_names = set() + if avoid_observed_names: + avoid_names = OBSERVED_NAMES + else: + avoid_names = set() proposed_name = None while proposed_name is None or proposed_name in avoid_names: name_key = (namespace, name) @@ -1207,11 +1225,10 @@ def placeholder(shape=None, if ndim: shape = (None,) * ndim if keras_tensor.keras_tensors_enabled(): - spec = tensor_spec.TensorSpec( - shape=shape, dtype=dtype, name=name) if sparse: spec = sparse_tensor.SparseTensorSpec( shape=shape, dtype=dtype) + x = keras_tensor.SparseKerasTensor(spec, name=name) elif ragged: ragged_rank = 0 for i in range(1, len(shape)): @@ -1224,7 +1241,11 @@ def placeholder(shape=None, spec = ragged_tensor.RaggedTensorSpec( shape=shape, dtype=dtype, ragged_rank=ragged_rank) - x = keras_tensor.KerasTensor(spec, name=name) + x = keras_tensor.RaggedKerasTensor(spec, name=name) + else: + spec = tensor_spec.TensorSpec( + shape=shape, dtype=dtype, name=name) + x = keras_tensor.KerasTensor(spec, name=name) else: with get_graph().as_default(): if sparse: @@ -1817,6 +1838,8 @@ def moving_average_update(x, value, momentum): def dot(x, y): """Multiplies 2 tensors (and/or variables) and returns a tensor. + This operation corresponds to `numpy.dot(a, b, out=None)`. + Arguments: x: Tensor or variable. y: Tensor or variable. @@ -1826,6 +1849,7 @@ def dot(x, y): Examples: + If inputs `x` and `y` are 2-D arrays, then it is equivalent to `tf.matmul`. >>> x = tf.keras.backend.placeholder(shape=(2, 3)) >>> y = tf.keras.backend.placeholder(shape=(3, 4)) >>> xy = tf.keras.backend.dot(x, y) @@ -1838,6 +1862,8 @@ def dot(x, y): >>> xy + If `x` is an N-D array and `y` is an M-D array (where M>=2), it is a sum + product over the last axis of `x` and the second-to-last axis of `y`. >>> x = tf.keras.backend.random_uniform_variable(shape=(2, 3), low=0, high=1) >>> y = tf.keras.backend.ones((4, 3, 5)) >>> xy = tf.keras.backend.dot(x, y) @@ -2420,6 +2446,9 @@ def abs(x): def sqrt(x): """Element-wise square root. + This function clips tensor values to a specified min(0) and max(inf) + before taking sqrt. + Arguments: x: Tensor or variable. diff --git a/tensorflow/python/keras/benchmarks/BUILD b/tensorflow/python/keras/benchmarks/BUILD index a2209513b62..d25afb24f9a 100644 --- a/tensorflow/python/keras/benchmarks/BUILD +++ b/tensorflow/python/keras/benchmarks/BUILD @@ -8,8 +8,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_py_srcs", srcs = glob(["*.py"]), @@ -42,7 +40,10 @@ py_library( # This lib is mainly for running benchmarks on mlcompass infra. 
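The `backend.py` hunks above expand the `dot` and `sqrt` docstrings rather than change behavior. A small sketch of what they describe, using the public `tf.keras.backend` aliases in eager mode (illustrative, not taken from this diff):

```python
import tensorflow as tf

# sqrt clips its input to [0, inf) before taking the square root,
# so negative entries come out as 0 rather than NaN.
print(tf.keras.backend.sqrt(tf.constant([-4.0, 4.0])).numpy())  # [0. 2.]

# dot follows numpy.dot semantics: for an N-D `x` and an M-D `y` (M >= 2),
# it sums over the last axis of `x` and the second-to-last axis of `y`.
x = tf.random.uniform((2, 3))
y = tf.ones((4, 3, 5))
print(tf.keras.backend.dot(x, y).shape)  # (2, 4, 5)
```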
py_library( name = "profiler_lib", - visibility = ["//tensorflow:internal"], + visibility = [ + "//learning/brain/contrib/keras_benchmark:__subpackages__", + "//tensorflow/python/keras:__subpackages__", + ], ) COMMON_TAGS = [ @@ -76,6 +77,7 @@ cuda_py_test( deps = [ ":profiler_lib", "//tensorflow:tensorflow_py", + "//tensorflow/python/keras/utils:tf_inspect", ], ) diff --git a/tensorflow/python/keras/benchmarks/saved_model_benchmarks/BUILD b/tensorflow/python/keras/benchmarks/saved_model_benchmarks/BUILD index 4debe1d0eac..5501aedcd4e 100644 --- a/tensorflow/python/keras/benchmarks/saved_model_benchmarks/BUILD +++ b/tensorflow/python/keras/benchmarks/saved_model_benchmarks/BUILD @@ -8,8 +8,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_py_srcs", srcs = glob(["*.py"]), diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 5759cf772e4..de97c80fd62 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -264,7 +264,7 @@ class CallbackList(object): if self._progbar is None and add_progbar: self._progbar = ProgbarLogger(count_mode='steps') - self.callbacks.append(self._progbar) + self.callbacks.insert(0, self._progbar) if self._history is None and add_history: self._history = History() @@ -1143,15 +1143,22 @@ class ModelCheckpoint(Callback): the end of every epoch, or after a fixed number of training batches. - Whether only weights are saved, or the whole model is saved. + Note: If you get `WARNING:tensorflow:Can save best model only with + available, skipping` see the description of the `monitor` argument for + details on how to get this right. + Example: ```python + model.compile(loss=..., optimizer=..., + metrics=['accuracy']) + EPOCHS = 10 checkpoint_filepath = '/tmp/checkpoint' model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( filepath=checkpoint_filepath, save_weights_only=True, - monitor='val_acc', + monitor='val_accuracy', mode='max', save_best_only=True) @@ -1170,17 +1177,30 @@ class ModelCheckpoint(Callback): `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model checkpoints will be saved with the epoch number and the validation loss in the filename. - monitor: quantity to monitor. + monitor: The metric name to monitor. Typically the metrics are set by the + `Model.compile` method. Note: + + * Prefix the name with `"val_`" to monitor validation metrics. + * Use `"loss"` or "`val_loss`" to monitor the model's total loss. + * If you specify metrics as strings, like `"accuracy"`, pass the same + string (with or without the `"val_"` prefix). + * If you pass `metrics.Metric` objects, `monitor` should be set to + `metric.name` + * If you're not sure about the metric names you can check the contents + of the `history.history` dictionary returned by + `history = model.fit()` + * Multi-output models set additional prefixes on the metric names. + verbose: verbosity mode, 0 or 1. - save_best_only: if `save_best_only=True`, it only saves when the model - is considered the "best" and the latest best model according to the - quantity monitored will not be overwritten. If `filepath` doesn't - contain formatting options like `{epoch}` then `filepath` will be + save_best_only: if `save_best_only=True`, it only saves when the model + is considered the "best" and the latest best model according to the + quantity monitored will not be overwritten. 
If `filepath` doesn't + contain formatting options like `{epoch}` then `filepath` will be overwritten by each new better model. - mode: one of {'auto', 'min', 'max'}. If `save_best_only=True`, the - decision to overwrite the current save file is made based on either - the maximization or the minimization of the monitored quantity. - For `val_acc`, this should be `max`, for `val_loss` this should be + mode: one of {'auto', 'min', 'max'}. If `save_best_only=True`, the + decision to overwrite the current save file is made based on either + the maximization or the minimization of the monitored quantity. + For `val_acc`, this should be `max`, for `val_loss` this should be `min`, etc. In `auto` mode, the direction is automatically inferred from the name of the monitored quantity. save_weights_only: if True, then only the model's weights will be saved @@ -1924,30 +1944,6 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): You can find more information about TensorBoard [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard). - Example (Basic): - - ```python - tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs") - model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) - # run the tensorboard command to view the visualizations. - ``` - - Example (Profile): - - ```python - # profile a single batch, e.g. the 5th batch. - tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs', - profile_batch=5) - model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) - # Now run the tensorboard command to view the visualizations (profile plugin). - - # profile a range of batches, e.g. from 10 to 20. - tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs', - profile_batch='10,20') - model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) - # Now run the tensorboard command to view the visualizations (profile plugin). - ``` - Arguments: log_dir: the path of the directory where to save the log files to be parsed by TensorBoard. @@ -1979,8 +1975,72 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector): about metadata files format. In case if the same metadata file is used for all embedding layers, string can be passed. - Raises: - ValueError: If histogram_freq is set and no validation data is provided. + Examples: + + Basic usage: + + ```python + tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs") + model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) + # Then run the tensorboard command to view the visualizations. + ``` + + Custom batch-level summaries in a subclassed Model: + + ```python + class MyModel(tf.keras.Model): + + def build(self, _): + self.dense = tf.keras.layers.Dense(10) + + def call(self, x): + outputs = self.dense(x) + tf.summary.histogram('outputs', outputs) + return outputs + + model = MyModel() + model.compile('sgd', 'mse') + + # Make sure to set `update_freq=N` to log a batch-level summary every N batches. + # In addition to any `tf.summary` contained in `Model.call`, metrics added in + # `Model.compile` will be logged every N batches. 
+ tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1) + model.fit(x_train, y_train, callbacks=[tb_callback]) + ``` + + Custom batch-level summaries in a Functional API Model: + + ```python + def my_summary(x): + tf.summary.histogram('x', x) + return x + + inputs = tf.keras.Input(10) + x = tf.keras.layers.Dense(10)(inputs) + outputs = tf.keras.layers.Lambda(my_summary)(x) + model = tf.keras.Model(inputs, outputs) + model.compile('sgd', 'mse') + + # Make sure to set `update_freq=N` to log a batch-level summary every N batches. + # In addition to any `tf.summary` contained in `Model.call`, metrics added in + # `Model.compile` will be logged every N batches. + tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1) + model.fit(x_train, y_train, callbacks=[tb_callback]) + ``` + + Profiling: + + ```python + # Profile a single batch, e.g. the 5th batch. + tensorboard_callback = tf.keras.callbacks.TensorBoard( + log_dir='./logs', profile_batch=5) + model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) + + # Profile a range of batches, e.g. from 10 to 20. + tensorboard_callback = tf.keras.callbacks.TensorBoard( + log_dir='./logs', profile_batch=(10,20)) + model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) + ``` """ # pylint: enable=line-too-long diff --git a/tensorflow/python/keras/callbacks_v1.py b/tensorflow/python/keras/callbacks_v1.py index 251fb3476dc..5c0c3ff6e96 100644 --- a/tensorflow/python/keras/callbacks_v1.py +++ b/tensorflow/python/keras/callbacks_v1.py @@ -245,8 +245,8 @@ class TensorBoard(callbacks.TensorBoard): # visualize embeddings. if self.embeddings_freq and self.embeddings_data is not None: # Avoid circular dependency. - from tensorflow.python.keras.engine import training_utils # pylint: disable=g-import-not-at-top - self.embeddings_data = training_utils.standardize_input_data( + from tensorflow.python.keras.engine import training_utils_v1 # pylint: disable=g-import-not-at-top + self.embeddings_data = training_utils_v1.standardize_input_data( self.embeddings_data, model.input_names) # If embedding_layer_names are not provided, get all of the embedding diff --git a/tensorflow/python/keras/datasets/BUILD b/tensorflow/python/keras/datasets/BUILD index de15f9a53ef..67f400db4b0 100644 --- a/tensorflow/python/keras/datasets/BUILD +++ b/tensorflow/python/keras/datasets/BUILD @@ -8,8 +8,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_py_srcs", srcs = glob(["*.py"]), diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD index 5ca0d5a27a3..79c36c0f559 100644 --- a/tensorflow/python/keras/distribute/BUILD +++ b/tensorflow/python/keras/distribute/BUILD @@ -16,8 +16,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_py_srcs", srcs = glob(["*.py"]), @@ -337,8 +335,10 @@ distribute_py_test( distribute_py_test( name = "distribute_strategy_test", srcs = ["distribute_strategy_test.py"], + disable_mlir_bridge = True, # TODO(b/170352626) full_precision = True, main = "distribute_strategy_test.py", + python_version = "PY3", shard_count = 10, tags = [ "multi_and_single_gpu", @@ -403,7 +403,7 @@ distribute_py_test( name = "keras_dnn_correctness_test", size = "medium", srcs = ["keras_dnn_correctness_test.py"], - disable_mlir_bridge = False, + disable_mlir_bridge = True, # TODO(b/170352626) full_precision = True, main = "keras_dnn_correctness_test.py", # Shard count is set to an 
odd number to distribute tasks across @@ -424,6 +424,7 @@ distribute_py_test( name = "keras_embedding_model_correctness_test", size = "medium", srcs = ["keras_embedding_model_correctness_test.py"], + disable_mlir_bridge = True, # TODO(b/170352626) full_precision = True, main = "keras_embedding_model_correctness_test.py", shard_count = 8, @@ -441,7 +442,7 @@ distribute_py_test( name = "keras_image_model_correctness_test", size = "medium", srcs = ["keras_image_model_correctness_test.py"], - disable_mlir_bridge = False, + disable_mlir_bridge = True, # TODO(b/170352626) full_precision = True, main = "keras_image_model_correctness_test.py", shard_count = 16, @@ -488,6 +489,7 @@ distribute_py_test( "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:strategy_combinations", "//tensorflow/python/eager:test", + "//tensorflow/python/keras", "@absl_py//absl/testing:parameterized", ], ) @@ -553,6 +555,7 @@ distribute_py_test( distribute_py_test( name = "keras_utils_test", srcs = ["keras_utils_test.py"], + disable_mlir_bridge = True, # TODO(b/170352626) full_precision = True, main = "keras_utils_test.py", shard_count = 4, @@ -677,6 +680,7 @@ cuda_py_test( "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", + "//tensorflow/python/keras:metrics", "//tensorflow/python/keras/layers:core", ], ) @@ -685,7 +689,7 @@ cuda_py_test( name = "multi_worker_test", srcs = ["multi_worker_test.py"], python_version = "PY3", - shard_count = 32, + shard_count = 2, tags = [ "multi_and_single_gpu", "no_oss", # TODO(b/130369494): Investigate why it times out on OSS. diff --git a/tensorflow/python/keras/distribute/collective_all_reduce_strategy_test.py b/tensorflow/python/keras/distribute/collective_all_reduce_strategy_test.py index 7dd285a585f..f22a954d78d 100644 --- a/tensorflow/python/keras/distribute/collective_all_reduce_strategy_test.py +++ b/tensorflow/python/keras/distribute/collective_all_reduce_strategy_test.py @@ -24,9 +24,8 @@ import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.python.compat import v2_compat from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.distribute import collective_all_reduce_strategy +from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib from tensorflow.python.distribute import combinations as ds_combinations -from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import multi_worker_util @@ -80,7 +79,7 @@ def create_test_objects(cluster_spec=None, ClusterSpec({}), num_accelerators={'GPU': num_gpus}) target = '' - strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + strategy = mwms_lib.CollectiveAllReduceStrategy( cluster_resolver=cluster_resolver) sess_config = strategy.update_config_proto(sess_config) @@ -95,9 +94,7 @@ class CollectiveAllReduceStrategyTestBase( def setUp(self): # We use a different key_base for each test so that collective keys won't be # reused. - # TODO(yuefengz, ayushd): enable it to reuse collective keys in different - # tests. 
- CollectiveAllReduceStrategyTestBase.collective_key_base += 100000 + mwms_lib.CollectiveAllReduceStrategy._collective_key_base += 100000 super(CollectiveAllReduceStrategyTestBase, self).setUp() def _get_test_object(self, task_type, task_id, num_gpus=0): @@ -106,18 +103,6 @@ class CollectiveAllReduceStrategyTestBase( task_type=task_type, task_id=task_id, num_gpus=num_gpus) - - collective_keys = cross_device_utils.CollectiveKeys( - group_key_start=10 + - CollectiveAllReduceStrategyTestBase.collective_key_base, - op_instance_key_start=100 + - CollectiveAllReduceStrategyTestBase.collective_key_base, - variable_instance_key_start=10000 + - CollectiveAllReduceStrategyTestBase.collective_key_base) - strategy.extended._collective_keys = collective_keys - strategy.extended._cross_device_ops._collective_keys = collective_keys - strategy.extended._host_cross_device_ops._collective_keys = collective_keys - return strategy, target, session_config def _test_complex_model(self, task_type, task_id, num_gpus): diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py index 18485749b18..7bc70101fb4 100644 --- a/tensorflow/python/keras/distribute/distribute_strategy_test.py +++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py @@ -17,15 +17,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + from absl.testing import parameterized import numpy as np + from tensorflow.python import keras from tensorflow.python.data.experimental.ops import cardinality +from tensorflow.python.data.experimental.ops import writers +from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers from tensorflow.python.distribute import central_storage_strategy +from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import combinations as ds_combinations from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import parameter_server_strategy @@ -36,6 +44,7 @@ from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import def_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_combinations as combinations from tensorflow.python.keras import testing_utils @@ -50,6 +59,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn +from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables from tensorflow.python.ops.losses import loss_reduction from tensorflow.python.ops.ragged import ragged_tensor @@ -152,8 +162,7 @@ def batch_wrapper(dataset, batch_size, distribution, repeat=None): dataset = dataset.repeat(repeat) # TPUs currently require fully defined input shapes, drop_remainder ensures # the input will have fully defined 
shapes. - if isinstance(distribution, - (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)): + if _is_tpu_strategy(distribution): return dataset.batch(batch_size, drop_remainder=True) else: return dataset.batch(batch_size) @@ -237,11 +246,18 @@ strategies_minus_tpu = [ strategy_combinations.central_storage_strategy_with_gpu_and_cpu ] +multi_worker_mirrored_strategies = [ + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + strategy_combinations.multi_worker_mirrored_2x2_gpu, +] + tpu_strategies = [ strategy_combinations.tpu_strategy, ] -all_strategies = strategies_minus_tpu + tpu_strategies +all_strategies = ( + strategies_minus_tpu + tpu_strategies + multi_worker_mirrored_strategies) def strategy_minus_tpu_combinations(): @@ -258,8 +274,14 @@ def tpu_strategy_combinations_graph_only(): return combinations.combine(distribution=tpu_strategies, mode=['graph']) +def multi_worker_strategy_combinations_eager_only(): + return combinations.combine( + distribution=multi_worker_mirrored_strategies, mode=['eager']) + + def all_strategy_combinations(): - return strategy_minus_tpu_combinations() + tpu_strategy_combinations() + return strategy_minus_tpu_combinations() + tpu_strategy_combinations( + ) + multi_worker_strategy_combinations_eager_only() def all_strategy_minus_default_and_tpu_combinations(): @@ -275,7 +297,8 @@ def all_strategy_minus_default_and_tpu_combinations(): def all_strategy_combinations_minus_default(): return (all_strategy_minus_default_and_tpu_combinations() + - tpu_strategy_combinations()) + tpu_strategy_combinations() + + multi_worker_strategy_combinations_eager_only()) def strategy_and_optimizer_combinations(): @@ -318,7 +341,21 @@ def strategy_and_optimizer_combinations(): optimizer_combinations.gradient_descent_optimizer_keras_v2_fn, optimizer_combinations.rmsprop_optimizer_keras_v2_fn ]) - return non_tpu_strategies + tpu_strategies_eager + tpu_strategies_graph + multi_worker_eager = combinations.combine( + distribution=multi_worker_mirrored_strategies, + mode=['eager'], + optimizer=[ + optimizer_combinations.adadelta_optimizer_keras_v2_fn, + optimizer_combinations.adagrad_optimizer_keras_v2_fn, + optimizer_combinations.adam_optimizer_keras_v2_fn, + optimizer_combinations.adamax_optimizer_keras_v2_fn, + optimizer_combinations.gradient_descent_optimizer_keras_v2_fn, + optimizer_combinations.nadam_optimizer_keras_v2_fn, + optimizer_combinations.rmsprop_optimizer_keras_v2_fn, + optimizer_combinations.ftrl_optimizer_keras_v2_fn + ]) + return (non_tpu_strategies + tpu_strategies_eager + tpu_strategies_graph + + multi_worker_eager) class BatchCountingCB(keras.callbacks.Callback): @@ -494,8 +531,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, if isinstance(distribution.extended, parameter_server_strategy.ParameterServerStrategyExtended): self.skipTest('b/152097775') - if isinstance(distribution, - (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)): + if _is_tpu_strategy(distribution): policy_name = 'mixed_bfloat16' else: policy_name = 'mixed_float16' @@ -545,8 +581,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, parameter_server_strategy.ParameterServerStrategyExtended): self.skipTest('b/152097775') - if isinstance(distribution, - (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)): + if _is_tpu_strategy(distribution): policy_name = 'mixed_bfloat16' else: policy_name = 'mixed_float16' @@ -627,8 +662,9 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, 
@ds_combinations.generate( combinations.combine( - distribution=strategies_minus_tpu, - mode=['graph', 'eager'])) + distribution=strategies_minus_tpu, mode=['graph', 'eager']) + + combinations.combine( + distribution=multi_worker_mirrored_strategies, mode=['eager'])) def test_numpy_with_sample_weights(self, distribution): with self.cached_session(), distribution.scope(): model = get_sample_weights_model() @@ -989,8 +1025,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase, def test_fit_with_dictionary_in_the_dataset_b135161171( self, distribution): - if isinstance(distribution, - (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)): + if _is_tpu_strategy(distribution): self.skipTest('b/142805125') def custom_loss(predict, label, weight): @@ -1069,19 +1104,56 @@ class TestDistributionStrategyWithDatasets(test.TestCase, self.assertAllClose( predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4) + @ds_combinations.generate(all_strategy_combinations()) + def test_predict_on_dataset_with_unknown_cardinality_without_steps( + self, distribution, mode): + # TODO(b/155867206): Investigate why this test occasionally segfaults on TPU + # in eager mode. + if mode == 'eager' and _is_tpu_strategy(distribution): + self.skipTest('caused segfault with TPU in eager mode.') + + if mode == 'graph' and _is_tpu_strategy(distribution): + self.skipTest('partial batch not supported with TPU in graph mode.') + + with self.cached_session(): + with distribution.scope(): + optimizer_fn = gradient_descent_keras.SGD + optimizer = optimizer_fn(0.001) + model = get_model() + loss = 'mse' + metrics = ['mae', keras.metrics.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + inputs = np.zeros((20, 3), dtype=np.float32) + # steps/steps_per_epoch are calculated when using numpy arrays as + # input data. + predict_with_numpy = model.predict(inputs, batch_size=10) + + predict_dataset = convert_numpy_to_dataset_with_unknown_cardinality( + inputs) + + self.assertEqual( + keras.backend.get_value(cardinality.cardinality(predict_dataset)), + cardinality.UNKNOWN) + + predict_with_ds = model.predict(predict_dataset) + self.assertAllClose( + predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4) + @ds_combinations.generate(all_strategy_combinations()) def test_on_dataset_with_unknown_cardinality_without_steps( self, distribution, mode): # TODO(b/155867206): Investigate why this test occasionally segfaults on TPU # in eager mode. - if mode == 'eager' and isinstance( - distribution, (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)): + if mode == 'eager' and _is_tpu_strategy(distribution): self.skipTest('caused segfault with TPU in eager mode.') - if mode == 'graph' and isinstance( - distribution, (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)): + if mode == 'graph' and _is_tpu_strategy(distribution): self.skipTest('partial batch not supported with TPU in graph mode.') + if isinstance(distribution, + collective_all_reduce_strategy.CollectiveAllReduceStrategy): + self.skipTest('EOF error causes subsequent collective ops fail.') with self.cached_session(): with distribution.scope(): optimizer_fn = gradient_descent_keras.SGD @@ -1299,7 +1371,8 @@ class TestDistributionStrategyWithDatasets(test.TestCase, metrics=metrics) batch_size = 8 - if isinstance(distribution, mirrored_strategy.MirroredStrategy): + if isinstance(distribution, (mirrored_strategy.MirroredStrategy, + mirrored_strategy.MirroredStrategyV1)): # MirroredStrategy uses global batch size. 
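Because `MirroredStrategy` (and now also `MirroredStrategyV1`, matched above) consumes a global batch size, the test scales its per-replica batch of 8 by `num_replicas_in_sync`. The same scaling applies when batching an input pipeline outside the test, roughly as follows (sketch, assuming a default `MirroredStrategy`):

```python
import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

per_replica_batch = 8
# The dataset is batched with the *global* batch size; the strategy splits
# each global batch across its replicas.
global_batch = per_replica_batch * strategy.num_replicas_in_sync

features = np.zeros((64, 3), dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices(features).batch(global_batch)
```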
batch_size = 8 * distribution.num_replicas_in_sync @@ -1514,8 +1587,9 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @ds_combinations.generate( combinations.combine( - distribution=strategies_minus_tpu, - mode=['graph', 'eager'])) + distribution=strategies_minus_tpu, mode=['graph', 'eager']) + + combinations.combine( + distribution=multi_worker_mirrored_strategies, mode=['eager'])) def test_dataset_with_sample_weights(self, distribution): with self.cached_session(), distribution.scope(): model = get_sample_weights_model() @@ -1551,6 +1625,57 @@ class TestDistributionStrategyWithDatasets(test.TestCase, self.assertAllClose(result, 13.5) +def _is_tpu_strategy(strategy): + if isinstance(strategy, + (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)): + return True + return False + + +class TestDistributionStrategyWithDatasetsFile(test.TestCase, + parameterized.TestCase): + + def setUp(self): + super(TestDistributionStrategyWithDatasetsFile, self).setUp() + self.input_file_name = os.path.join(self.get_temp_dir(), 'input.tfrecord') + inputs = np.zeros((20, 3), dtype=np.float32) + input_dataset = dataset_ops.Dataset.from_tensor_slices(inputs) + input_dataset = input_dataset.map(parsing_ops.serialize_tensor) + writer = writers.TFRecordWriter(self.input_file_name) + writer.write(input_dataset) + + # TODO(wxinyi): add a multi-worker test for TPU + @ds_combinations.generate(multi_worker_strategy_combinations_eager_only()) + def test_predict_on_dataset_shard_options_file_multi_worker_mirrored( + self, distribution, mode): + # This test is to verify if we successfully switch auto_shard_policy of a + # input dataset inside model.predict with MultiWorkerMirroredStrategy to + # AutoShardPolicy.DATA. Since there is only one input file for multiple + # workers, AutoShardPolicy.AUTO or AutoShardPolicy.FILE will lead to an + # error. However, since we switch to AutoShardPolicy.DATA in model.predict, + # no error is raised. 
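The comment above explains the sharding switch that `model.predict` performs internally for this case. The same policy can also be set explicitly on a dataset, roughly like this (sketch; `input.tfrecord` is a placeholder path):

```python
import tensorflow as tf

# Several workers reading one TFRecord file cannot be sharded by file, so
# shard by data (elements) instead of relying on AUTO/FILE sharding.
dataset = tf.data.TFRecordDataset('input.tfrecord').batch(8, drop_remainder=True)

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA)
dataset = dataset.with_options(options)

# model.predict(dataset, steps=1)  # inside a MultiWorkerMirroredStrategy scope
```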
+ del mode + with distribution.scope(): + optimizer_fn = gradient_descent_keras.SGD + optimizer = optimizer_fn(0.001) + model = get_model() + loss = 'mse' + model.compile(optimizer, loss) + + dataset = readers.TFRecordDataset(self.input_file_name) + dataset = dataset.map(lambda x: parsing_ops.parse_tensor(x, dtypes.float32)) + + dummy_op = lambda inp: True + + dataset = dataset.filter(dummy_op).batch(8, drop_remainder=True) + + options = dataset_ops.Options() + options.experimental_distribute.auto_shard_policy = AutoShardPolicy.FILE + dataset = dataset.with_options(options) + + model.predict(dataset, steps=1) + + class TestRegularizerLoss(test.TestCase, parameterized.TestCase): class IdentityRegularizer(keras.regularizers.Regularizer): @@ -1887,7 +2012,8 @@ class TestDistributionStrategyWithKerasModels(test.TestCase, model.compile(optimizer, 'mae') if isinstance(distribution, - central_storage_strategy.CentralStorageStrategy): + (central_storage_strategy.CentralStorageStrategy, + central_storage_strategy.CentralStorageStrategyV1)): with self.assertRaisesRegex(ValueError, 'not supported'): model.fit(x, y, batch_size=10, epochs=1) else: @@ -1899,7 +2025,8 @@ class TestDistributionStrategyWithKerasModels(test.TestCase, combinations.combine(distribution=all_strategies, mode=['eager'])) def test_custom_gradient_transformation(self, distribution): if isinstance(distribution, - central_storage_strategy.CentralStorageStrategy): + (central_storage_strategy.CentralStorageStrategy, + central_storage_strategy.CentralStorageStrategyV1)): self.skipTest('Not supported with `CentralStorageStrategy`') class MyLayer(keras.layers.Layer): @@ -2196,7 +2323,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase, @ds_combinations.generate( combinations.combine( - distribution=strategies_minus_tpu, + distribution=strategies_minus_tpu + multi_worker_mirrored_strategies, mode=['eager'])) def test_sparse_tensor_outputs(self, distribution): @@ -2225,7 +2352,7 @@ class TestDistributionStrategyWithKerasModels(test.TestCase, @ds_combinations.generate( combinations.combine( - distribution=strategies_minus_tpu, + distribution=strategies_minus_tpu + multi_worker_mirrored_strategies, mode=['eager'])) def test_ragged_tensor_outputs(self, distribution): @@ -2252,7 +2379,8 @@ class TestDistributionStrategyWithKerasModels(test.TestCase, @ds_combinations.generate( combinations.combine( - distribution=strategies_minus_default_minus_tpu + tpu_strategies, + distribution=strategies_minus_default_minus_tpu + tpu_strategies + + multi_worker_mirrored_strategies, mode=['eager'])) def test_correctness_of_add_loss_with_merge_call(self, distribution): batch_size = 32 @@ -2568,4 +2696,4 @@ class TestModelCapturesStrategy(test.TestCase, parameterized.TestCase): if __name__ == '__main__': base_layer_utils.enable_v2_dtype_behavior() - test.main() + multi_process_runner.test_main() diff --git a/tensorflow/python/keras/distribute/distributed_training_utils_v1.py b/tensorflow/python/keras/distribute/distributed_training_utils_v1.py index 83426016412..2e7a8299e43 100644 --- a/tensorflow/python/keras/distribute/distributed_training_utils_v1.py +++ b/tensorflow/python/keras/distribute/distributed_training_utils_v1.py @@ -36,9 +36,9 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend as K from tensorflow.python.keras import callbacks from tensorflow.python.keras import metrics as metrics_module -from tensorflow.python.keras import optimizer_v1 +from tensorflow.python.keras import optimizers 
from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils -from tensorflow.python.keras.engine import training_utils +from tensorflow.python.keras.engine import training_utils_v1 from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.keras.utils import tf_contextlib from tensorflow.python.keras.utils.mode_keys import ModeKeys @@ -639,9 +639,9 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode): # TODO(b/124535720): Remove once this standarize data logic is shared with # main flow. inputs, targets = nest.map_structure( - training_utils.standardize_single_array, (inputs, targets)) + training_utils_v1.standardize_single_array, (inputs, targets)) else: - inputs = training_utils.ModelInputs(inputs).as_list() + inputs = training_utils_v1.ModelInputs(inputs).as_list() if mode == ModeKeys.PREDICT: sample_weights = [] @@ -779,7 +779,7 @@ def _clone_and_build_model(model, mode, inputs=None, targets=None): cloned_model = models.clone_model(model, input_tensors=inputs) # Compile and build model. - if isinstance(model.optimizer, optimizer_v1.TFOptimizer): + if isinstance(model.optimizer, optimizers.TFOptimizer): optimizer = model.optimizer else: optimizer_config = model.optimizer.get_config() diff --git a/tensorflow/python/keras/distribute/keras_correctness_test_base.py b/tensorflow/python/keras/distribute/keras_correctness_test_base.py index 2639b6a78b8..77a5f290439 100644 --- a/tensorflow/python/keras/distribute/keras_correctness_test_base.py +++ b/tensorflow/python/keras/distribute/keras_correctness_test_base.py @@ -24,7 +24,6 @@ import numpy as np import six from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import strategy_combinations @@ -287,12 +286,7 @@ def fit_eval_and_predict(initial_weights, result['weights_1'] = model.get_weights() - # TODO(b/157924053): Now model.predict() doesn't support - # MultiWorkerMirroredStrategy. Enable model.predict() after it's supported. - if predict_inputs is not None and not isinstance( - distribution, - (collective_all_reduce_strategy.CollectiveAllReduceStrategy, - collective_all_reduce_strategy.CollectiveAllReduceStrategyV1)): + if predict_inputs is not None: # Check correctness of the result of predict() invoked # multiple times -- as for stateful models, result of # predict may differ for each batch. @@ -346,6 +340,7 @@ def compare_results(results_with_ds, # so use larger tolerance for now. Predict should be related to weights. 
if (isinstance(distribution, (mirrored_strategy.MirroredStrategy, + mirrored_strategy.MirroredStrategyV1, distribute_lib._DefaultDistributionStrategy)) and # pylint: disable=protected-access key.startswith(('weights_1', 'weights_2', 'predict_result'))): return relaxed_tolerance diff --git a/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py b/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py index e47b3ba519c..0d93d925497 100644 --- a/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py +++ b/tensorflow/python/keras/distribute/keras_image_model_correctness_test.py @@ -51,7 +51,10 @@ class DistributionStrategyCnnCorrectnessTest( if self.with_batch_norm == 'regular': c1 = keras.layers.BatchNormalization(name='bn1')(c1) elif self.with_batch_norm == 'sync': - c1 = keras.layers.SyncBatchNormalization(name='bn1')(c1) + # Test with parallel batch norms to verify all-reduce works OK. + bn1 = keras.layers.SyncBatchNormalization(name='bn1')(c1) + bn2 = keras.layers.SyncBatchNormalization(name='bn2')(c1) + c1 = keras.layers.Add()([bn1, bn2]) c1 = keras.layers.MaxPooling2D(pool_size=(2, 2))(c1) logits = keras.layers.Dense( 10, activation='softmax', name='pred')( diff --git a/tensorflow/python/keras/distribute/keras_utils_test.py b/tensorflow/python/keras/distribute/keras_utils_test.py index 3f94aae2948..090f526f7ed 100644 --- a/tensorflow/python/keras/distribute/keras_utils_test.py +++ b/tensorflow/python/keras/distribute/keras_utils_test.py @@ -23,9 +23,11 @@ import tempfile from absl.testing import parameterized import numpy as np + from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import combinations as ds_combinations +from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import tpu_strategy from tensorflow.python.distribute import values @@ -575,4 +577,4 @@ class TestDistributionStrategyWithStaticShapes(test.TestCase, if __name__ == '__main__': - test.main() + multi_process_runner.test_main() diff --git a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py index 07bc974a352..ea4b349b1cf 100644 --- a/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py +++ b/tensorflow/python/keras/distribute/multi_worker_callback_tf2_test.py @@ -204,7 +204,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase): if 'Interrupting!' not in str(e): raise - multi_process_runner.barrier().wait() + multi_process_runner.get_barrier().wait() backup_filepath = os.path.join(bar_dir, 'checkpoint') test_obj.assertTrue(file_io.file_exists_v2(backup_filepath)) test_obj.assertTrue(file_io.file_exists_v2(saving_filepath)) @@ -218,7 +218,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase): callbacks.BackupAndRestore(backup_dir=bar_dir), AssertCallback() ]) - multi_process_runner.barrier().wait() + multi_process_runner.get_barrier().wait() test_obj.assertFalse(file_io.file_exists_v2(backup_filepath)) test_obj.assertTrue(file_io.file_exists_v2(saving_filepath)) @@ -306,7 +306,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase): # The saving_filepath shouldn't exist at the beginning (as it's unique). 
test_obj.assertFalse(file_io.file_exists_v2(saving_filepath)) - multi_process_runner.barrier().wait() + multi_process_runner.get_barrier().wait() model.fit( x=train_ds, @@ -314,7 +314,7 @@ class KerasCallbackMultiProcessTest(parameterized.TestCase, test.TestCase): steps_per_epoch=steps, callbacks=[callbacks.TensorBoard(log_dir=saving_filepath)]) - multi_process_runner.barrier().wait() + multi_process_runner.get_barrier().wait() test_obj.assertTrue(file_io.list_directory_v2(saving_filepath)) diff --git a/tensorflow/python/keras/distribute/multi_worker_tutorial_test.py b/tensorflow/python/keras/distribute/multi_worker_tutorial_test.py index 556ffe30236..42d0e4d4630 100644 --- a/tensorflow/python/keras/distribute/multi_worker_tutorial_test.py +++ b/tensorflow/python/keras/distribute/multi_worker_tutorial_test.py @@ -109,7 +109,7 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase): num_workers = 4 - def proc_func(model_path, checkpoint_dir): + def fn(model_path, checkpoint_dir): global_batch_size = per_worker_batch_size * num_workers strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy() with strategy.scope(): @@ -158,7 +158,7 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase): file_io.delete_recursively_v2(os.path.dirname(write_model_path)) # Make sure chief finishes saving before non-chief's assertions. - multi_process_runner.barrier().wait() + multi_process_runner.get_barrier().wait() if not file_io.file_exists_v2(model_path): raise RuntimeError() @@ -179,7 +179,7 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase): file_io.delete_recursively_v2(write_checkpoint_dir) # Make sure chief finishes saving before non-chief's assertions. - multi_process_runner.barrier().wait() + multi_process_runner.get_barrier().wait() if not file_io.file_exists_v2(checkpoint_dir): raise RuntimeError() @@ -198,10 +198,10 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase): checkpoint_dir = os.path.join(self.get_temp_dir(), 'ckpt') with test_util.skip_if_error(self, errors_impl.UnavailableError): mpr_result = multi_process_runner.run( - proc_func, + fn, multi_worker_test_base.create_cluster_spec(num_workers=num_workers), args=(model_path, checkpoint_dir), - list_stdout=True) + return_output=True) self.assertTrue( any([ diff --git a/tensorflow/python/keras/distribute/parameter_server_training_test.py b/tensorflow/python/keras/distribute/parameter_server_training_test.py index f4e546b6e14..e4801d909ec 100644 --- a/tensorflow/python/keras/distribute/parameter_server_training_test.py +++ b/tensorflow/python/keras/distribute/parameter_server_training_test.py @@ -73,7 +73,7 @@ class KPLTest(test.TestCase): with self.client.strategy.scope(): - # Define KPLs under client's context. Right now, if they have look up + # Define KPLs under strategy's scope. Right now, if they have look up # tables, they will be created on the client. Their variables will be # created on PS. Ideally they should be cached on each worker since they # will not be changed in a training step. 
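The multi-worker tests above move to the renamed runner helpers (`get_barrier` instead of `barrier`, `return_output` instead of `list_stdout`). A condensed sketch of the pattern, using only the internal test utilities and signatures that appear in the hunks above:

```python
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base


def fn():
  # Each spawned worker builds its strategy, trains, and saves here.
  # Synchronize so the chief finishes writing before other workers assert
  # on the files it produced.
  multi_process_runner.get_barrier().wait()  # previously: barrier().wait()


mpr_result = multi_process_runner.run(
    fn,
    multi_worker_test_base.create_cluster_spec(num_workers=2),
    return_output=True)  # previously: list_stdout=True
```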
diff --git a/tensorflow/python/keras/distribute/worker_training_state_test.py b/tensorflow/python/keras/distribute/worker_training_state_test.py index 1df136827e6..cb65747c239 100644 --- a/tensorflow/python/keras/distribute/worker_training_state_test.py +++ b/tensorflow/python/keras/distribute/worker_training_state_test.py @@ -41,26 +41,28 @@ class ModelCheckpointTest(test_base.IndependentWorkerTestBase, file_format=['h5', 'tf'], save_weights_only=[True, False])) def testCheckpointExists(self, file_format, save_weights_only): - train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(64, 2) - model = multi_worker_testing_utils.get_mnist_model((28, 28, 1)) - saving_dir = self.get_temp_dir() - saving_filepath = os.path.join(saving_dir, 'checkpoint.' + file_format) - callbacks_list = [ - callbacks.ModelCheckpoint( - filepath=saving_filepath, save_weights_only=save_weights_only) - ] - self.assertFalse(file_io.file_exists_v2(saving_filepath)) + with self.cached_session(): + train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(64, 2) + model = multi_worker_testing_utils.get_mnist_model((28, 28, 1)) + saving_dir = self.get_temp_dir() + saving_filepath = os.path.join(saving_dir, 'checkpoint.' + file_format) + callbacks_list = [ + callbacks.ModelCheckpoint( + filepath=saving_filepath, save_weights_only=save_weights_only) + ] + self.assertFalse(file_io.file_exists_v2(saving_filepath)) - try: - model.fit( - x=train_ds, epochs=2, steps_per_epoch=2, callbacks=callbacks_list) - except NotFoundError as e: - if 'Failed to create a NewWriteableFile' in e.message: - self.skipTest('b/138941852, path not found error in Windows py35.') - tf_saved_model_exists = file_io.file_exists_v2(saving_filepath) - tf_weights_only_checkpoint_exists = file_io.file_exists_v2( - saving_filepath + '.index') - self.assertTrue(tf_saved_model_exists or tf_weights_only_checkpoint_exists) + try: + model.fit( + x=train_ds, epochs=2, steps_per_epoch=2, callbacks=callbacks_list) + except NotFoundError as e: + if 'Failed to create a NewWriteableFile' in e.message: + self.skipTest('b/138941852, path not found error in Windows py35.') + tf_saved_model_exists = file_io.file_exists_v2(saving_filepath) + tf_weights_only_checkpoint_exists = file_io.file_exists_v2( + saving_filepath + '.index') + self.assertTrue( + tf_saved_model_exists or tf_weights_only_checkpoint_exists) if __name__ == '__main__': diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD index 7606af6f4a6..ca6803101d5 100644 --- a/tensorflow/python/keras/engine/BUILD +++ b/tensorflow/python/keras/engine/BUILD @@ -14,8 +14,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_py_srcs", srcs = glob(["*.py"]), @@ -38,6 +36,7 @@ py_library( "training_eager_v1.py", "training_generator_v1.py", "training_utils.py", + "training_utils_v1.py", "training_v1.py", ], srcs_version = "PY2AND3", @@ -64,6 +63,7 @@ py_library( "//tensorflow/python/keras:constraints", "//tensorflow/python/keras:initializers", "//tensorflow/python/keras:losses", + "//tensorflow/python/keras:metrics", "//tensorflow/python/keras:optimizers", "//tensorflow/python/keras:regularizers", "//tensorflow/python/keras/distribute", @@ -254,6 +254,7 @@ tf_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:lookup_ops", + "//tensorflow/python/keras", "//tensorflow/python/keras:backend", "//tensorflow/python/keras:combinations", ], @@ -272,6 +273,7 @@ tf_py_test( deps = [ 
":data_adapter", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras", "//third_party/py/numpy", ], ) @@ -300,7 +302,6 @@ cuda_py_test( tags = [ "nomac", # TODO(mihaimaruseac): b/127695564 ], - tfrt_enabled = True, deps = [ ":engine", "//tensorflow/python:client_testlib", @@ -347,6 +348,22 @@ tf_py_test( ], ) +tf_py_test( + name = "ragged_keras_tensor_test", + size = "small", + srcs = ["ragged_keras_tensor_test.py"], + python_version = "PY3", + tags = [ + "nomac", # TODO(mihaimaruseac): b/127695564 + ], + tfrt_enabled = True, + deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python/keras", + "@absl_py//absl/testing:parameterized", + ], +) + tf_py_test( name = "input_spec_test", size = "small", @@ -542,15 +559,14 @@ tf_py_test( ) tf_py_test( - name = "training_utils_test", + name = "training_utils_v1_test", size = "medium", - srcs = ["training_utils_test.py"], + srcs = ["training_utils_v1_test.py"], python_version = "PY3", tags = [ "no_oss", # TODO(b/135021748) reenable "notsan", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -571,6 +587,8 @@ tf_py_test( deps = [ ":base_layer", ":engine", + "//tensorflow/python/keras", + "//tensorflow/python/keras:combinations", "//tensorflow/python/keras:testing_utils", "//tensorflow/python/keras/utils:layer_utils", ], @@ -598,6 +616,7 @@ tf_py_test( "//tensorflow/python:state_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python/eager:context", + "//tensorflow/python/keras", "//tensorflow/python/keras:backend", "//tensorflow/python/keras:combinations", "//tensorflow/python/keras:initializers", @@ -625,6 +644,7 @@ tf_py_test( deps = [ ":base_layer", ":engine", + "//tensorflow/python/keras", "//tensorflow/python/keras:testing_utils", "//tensorflow/python/keras/utils:layer_utils", ], @@ -661,11 +681,13 @@ tf_py_test( "//tensorflow/python:variables", "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", + "//tensorflow/python/keras", "//tensorflow/python/keras:backend", "//tensorflow/python/keras:combinations", "//tensorflow/python/keras:regularizers", "//tensorflow/python/keras:testing_utils", "//tensorflow/python/keras/layers", + "//tensorflow/python/keras/legacy_tf_layers:core", "//tensorflow/python/keras/mixed_precision/experimental:policy", "//tensorflow/python/keras/optimizer_v2", "//tensorflow/python/keras/utils:tf_utils", @@ -709,7 +731,7 @@ tf_py_test( tf_py_test( name = "deferred_sequential_test", - size = "small", + size = "medium", srcs = ["deferred_sequential_test.py"], python_version = "PY3", tags = [ diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 4b732c07bec..bde0a63b212 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -86,6 +86,12 @@ from tensorflow.python.util import object_identity from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls +# pylint: disable=g-inconsistent-quotes +metrics_mod = generic_utils.LazyLoader( + "metrics_mod", globals(), + "tensorflow.python.keras.metrics") +# pylint: enable=g-inconsistent-quotes + # Prefix that is added to the TF op layer names. 
_TF_OP_LAYER_NAME_PREFIX = 'tf_op_layer_' @@ -96,9 +102,11 @@ _AUTOCAST_TYPES = (ops.Tensor, sparse_tensor.SparseTensor, keras_layers_gauge = monitoring.BoolGauge('/tensorflow/api/keras/layers', 'keras layers usage', 'method') +keras_models_gauge = monitoring.BoolGauge( + '/tensorflow/api/keras/models', 'keras model usage', 'method') keras_api_gauge = monitoring.BoolGauge('/tensorflow/api/keras', 'keras api usage', 'method') -keras_model_gauge = monitoring.BoolGauge( +keras_premade_model_gauge = monitoring.BoolGauge( '/tensorflow/api/keras/premade_models', 'premade keras model usage', 'type') @@ -304,7 +312,10 @@ class Layer(module.Module, version_utils.LayerVersionSelector): dynamic=False, **kwargs): keras_api_gauge.get_cell('layer').set(True) - keras_layers_gauge.get_cell(self.__class__.__name__).set(True) + if getattr(self, '_is_model_for_instrumentation', False): + keras_models_gauge.get_cell(self.__class__.__name__).set(True) + else: + keras_layers_gauge.get_cell(self.__class__.__name__).set(True) # These properties should be set by the user via keyword arguments. # note that 'dtype', 'input_shape' and 'batch_input_shape' # are only applicable to input layers: do not pass these keywords @@ -1602,7 +1613,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): ['max', 'min'] Returns: - A list of tensors. + A list of `Metric` objects. """ collected_metrics = [] for layer in self._flatten_layers(): @@ -1620,11 +1631,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector): class MyMetricLayer(tf.keras.layers.Layer): def __init__(self): super(MyMetricLayer, self).__init__(name='my_metric_layer') - self.mean = metrics_module.Mean(name='metric_1') + self.mean = tf.keras.metrics.Mean(name='metric_1') def call(self, inputs): self.add_metric(self.mean(x)) - self.add_metric(math_ops.reduce_sum(x), name='metric_2') + self.add_metric(tf.reduce_sum(x), name='metric_2') return inputs ``` @@ -1716,7 +1727,6 @@ class Layer(module.Module, version_utils.LayerVersionSelector): elif metric_obj: self._metrics.append(metric_obj) else: - from tensorflow.python.keras import metrics as metrics_mod # pylint:disable=g-import-not-at-top # Build the metric object with the value's dtype if it defines one metric_obj = metrics_mod.Mean( name=name, dtype=getattr(value, 'dtype', None)) @@ -2428,6 +2438,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): generic_utils.to_snake_case(self.__class__.__name__), zero_based=zero_based) else: + backend.observe_object_name(name) self._name = name def _get_existing_metric(self, name=None): @@ -2798,9 +2809,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector): pass # Keep track of metric instance created in subclassed layer. - from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top for val in nest.flatten(value): - if isinstance(val, metrics_module.Metric) and hasattr(self, '_metrics'): + if isinstance(val, metrics_mod.Metric) and hasattr(self, '_metrics'): self._metrics.append(val) # TODO(scottzhu): Need to track Module object as well for weight tracking. @@ -2877,7 +2887,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector): continue seen_object_ids.add(layer_or_container_id) - if isinstance(layer_or_container, Layer): + if (isinstance(layer_or_container, Layer) and + not isinstance(layer_or_container, metrics_mod.Metric)): yield layer_or_container # Introspect recursively through sublayers. 
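The corrected `add_metric` docstring above is easiest to read as runnable code. A minimal sketch under eager execution (layer and metric names are illustrative), showing both the stateful-metric-object path and the plain-tensor path that wraps the value in a `Mean` metric:

```
import tensorflow as tf

class MyMetricLayer(tf.keras.layers.Layer):

  def __init__(self):
    super(MyMetricLayer, self).__init__(name='my_metric_layer')
    self.mean = tf.keras.metrics.Mean(name='metric_1')

  def call(self, inputs):
    # Stateful metric object: its running mean is aggregated across batches.
    self.add_metric(self.mean(inputs))
    # Plain tensor: a Mean metric named 'metric_2' is created to track it.
    self.add_metric(tf.reduce_sum(inputs), name='metric_2')
    return inputs

layer = MyMetricLayer()
layer(tf.ones((2, 3)))
print([m.name for m in layer.metrics])  # ['metric_1', 'metric_2']
```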
if recursive: diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py index 321eefcadb1..f5cb63226f3 100644 --- a/tensorflow/python/keras/engine/base_layer_test.py +++ b/tensorflow/python/keras/engine/base_layer_test.py @@ -43,9 +43,9 @@ from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import input_layer from tensorflow.python.keras.engine import sequential from tensorflow.python.keras.engine import training as training_lib +from tensorflow.python.keras.legacy_tf_layers import core as legacy_core from tensorflow.python.keras.optimizer_v2 import rmsprop from tensorflow.python.keras.utils import control_flow_util -from tensorflow.python.layers import core as legacy_core from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops @@ -469,6 +469,32 @@ class BaseLayerTest(keras_parameterized.TestCase): ] self.assertAllEqual(actual_names, expected_names) + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + def test_layer_names_after_loading(self): + if context.executing_eagerly(): + backend.clear_session() + with testing_utils.use_keras_tensors_scope(True): + # Mimic loading a model that already contained add layers with + # name = 'add_1' and 'tf.__operators__.add' + layers.Add(name='add_1') + layers.Add(name='tf.__operators__.add') + + inputs = input_layer.Input(shape=[2]) + add1 = inputs + inputs + add2 = layers.Add()([inputs, inputs]) + add3 = inputs + inputs + add4 = layers.Add()([inputs, inputs]) + model = training_lib.Model( + inputs=[inputs], outputs=[add1, add2, add3, add4]) + actual_names = [l.name for l in model.layers] + # The generated op layer names should have avoided layer names seen in + # the loaded model. (This avoiance should not apply to non-op-layers) + expected_names = [ + 'input_1', 'tf.__operators__.add_1', + 'add', 'tf.__operators__.add_2', 'add_1' + ] + self.assertAllEqual(actual_names, expected_names) + def test_add_trainable_weight_on_frozen_layer(self): class TestLayer(base_layer.Layer): @@ -789,7 +815,6 @@ class SymbolicSupportTest(keras_parameterized.TestCase): with ops.Graph().as_default(): x1 = array_ops.ones((3, 3)) x2 = array_ops.ones((3, 3)) - self.assertIsInstance(x2, ops.EagerTensor) with self.assertRaisesRegex(TypeError, 'Graph tensors'): math_ops.matmul(x1, x2) diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index 98414bc9d49..d6b32907593 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -847,7 +847,7 @@ def no_ragged_support(inputs, layer_name): def is_split_variable(v): - """Returns True if `v` is either a PartionedVariable or a SharedVariable.""" + """Returns True if `v` is either a PartionedVariable or a ShardedVariable.""" return hasattr(v, '_variable_list') or hasattr(v, '_variables') diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py index dc34dd524a3..fd9db0e4346 100644 --- a/tensorflow/python/keras/engine/base_layer_v1.py +++ b/tensorflow/python/keras/engine/base_layer_v1.py @@ -835,8 +835,9 @@ class Layer(base_layer.Layer): def _assert_built_as_v1(self): if not hasattr(self, '_originally_built_as_v1'): raise ValueError( - 'Your Layer or Model is in an invalid state. 
This can happen if you ' - 'are interleaving estimator/non-estimator models or ' + 'Your Layer or Model is in an invalid state. ' + 'This can happen for the following cases:\n ' + '1. You might be interleaving estimator/non-estimator models or ' 'interleaving models/layers made in tf.compat.v1.Graph.as_default() ' 'with models/layers created outside of it. ' 'Converting a model to an estimator (via model_to_estimator) ' @@ -844,7 +845,11 @@ class Layer(base_layer.Layer): 'if they were not the model converted to an estimator). ' 'Similarly, making a layer or a model inside a ' 'a tf.compat.v1.Graph invalidates all layers/models you previously ' - 'made outside of the graph.') + 'made outside of the graph.\n' + '2. You might be using a custom keras layer implementation with ' + ' custom __init__ which didn\'t call super().__init__. ' + ' Please check the implementation of %s and its bases.' % + (type(self),)) @property def dtype(self): diff --git a/tensorflow/python/keras/engine/base_preprocessing_layer.py b/tensorflow/python/keras/engine/base_preprocessing_layer.py index 0818825669d..318323b395a 100644 --- a/tensorflow/python/keras/engine/base_preprocessing_layer.py +++ b/tensorflow/python/keras/engine/base_preprocessing_layer.py @@ -271,7 +271,7 @@ def convert_to_list(values, sparse_default_value=None): values, default_value=sparse_default_value) values = K.get_value(dense_tensor) - if isinstance(values, (ops.EagerTensor, ops.Tensor)): + if isinstance(values, ops.Tensor): values = K.get_value(values) # We may get passed a ndarray or the code above may give us a ndarray. diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index f4c80feddc8..2cc6f69403e 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -34,6 +34,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import input_lib from tensorflow.python.eager import context +from tensorflow.python.eager import monitoring from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -52,6 +53,9 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest from tensorflow.python.util.tf_export import keras_export +keras_data_adapter_gauge = monitoring.BoolGauge( + "/tensorflow/api/keras/data_adapters", "keras data adapter usage", "method") + try: from scipy import sparse as scipy_sparse # pylint: disable=g-import-not-at-top except ImportError: @@ -961,6 +965,8 @@ def select_data_adapter(x, y): "handling inputs. 
Found multiple adapters {} to handle " "input: {}, {}".format( adapter_cls, _type_name(x), _type_name(y))) + # Instrument the data adapter usage before returning it + keras_data_adapter_gauge.get_cell(adapter_cls[0].__name__).set(True) return adapter_cls[0] diff --git a/tensorflow/python/keras/engine/feature_columns_integration_test.py b/tensorflow/python/keras/engine/feature_columns_integration_test.py index 04d708feb5e..b70e33b9c8f 100644 --- a/tensorflow/python/keras/engine/feature_columns_integration_test.py +++ b/tensorflow/python/keras/engine/feature_columns_integration_test.py @@ -26,6 +26,7 @@ from tensorflow.python.feature_column import feature_column_lib as fc from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.feature_column import dense_features as df from tensorflow.python.keras.utils import np_utils from tensorflow.python.platform import test @@ -34,7 +35,7 @@ class TestDNNModel(keras.models.Model): def __init__(self, feature_columns, units, name=None, **kwargs): super(TestDNNModel, self).__init__(name=name, **kwargs) - self._input_layer = fc.DenseFeatures(feature_columns, name='input_layer') + self._input_layer = df.DenseFeatures(feature_columns, name='input_layer') self._dense_layer = keras.layers.Dense(units, name='dense_layer') def call(self, features): @@ -52,7 +53,7 @@ class FeatureColumnsIntegrationTest(keras_parameterized.TestCase): def test_sequential_model(self): columns = [fc.numeric_column('a')] model = keras.models.Sequential([ - fc.DenseFeatures(columns), + df.DenseFeatures(columns), keras.layers.Dense(64, activation='relu'), keras.layers.Dense(20, activation='softmax') ]) @@ -74,7 +75,7 @@ class FeatureColumnsIntegrationTest(keras_parameterized.TestCase): def test_sequential_model_with_ds_input(self): columns = [fc.numeric_column('a')] model = keras.models.Sequential([ - fc.DenseFeatures(columns), + df.DenseFeatures(columns), keras.layers.Dense(64, activation='relu'), keras.layers.Dense(20, activation='softmax') ]) @@ -112,7 +113,7 @@ class FeatureColumnsIntegrationTest(keras_parameterized.TestCase): crossed_feature = fc.indicator_column(crossed_feature) feature_columns.append(crossed_feature) - feature_layer = fc.DenseFeatures(feature_columns) + feature_layer = df.DenseFeatures(feature_columns) model = keras.models.Sequential([ feature_layer, @@ -185,7 +186,7 @@ class FeatureColumnsIntegrationTest(keras_parameterized.TestCase): col_a = fc.numeric_column('a') col_b = fc.numeric_column('b') - feature_layer = fc.DenseFeatures([col_a, col_b], name='fc') + feature_layer = df.DenseFeatures([col_a, col_b], name='fc') dense = keras.layers.Dense(4) # This seems problematic.... We probably need something for DenseFeatures @@ -213,8 +214,8 @@ class FeatureColumnsIntegrationTest(keras_parameterized.TestCase): col_b = fc.numeric_column('b') col_c = fc.numeric_column('c') - fc1 = fc.DenseFeatures([col_a, col_b], name='fc1') - fc2 = fc.DenseFeatures([col_b, col_c], name='fc2') + fc1 = df.DenseFeatures([col_a, col_b], name='fc1') + fc2 = df.DenseFeatures([col_b, col_c], name='fc2') dense = keras.layers.Dense(4) # This seems problematic.... 
We probably need something for DenseFeatures @@ -289,7 +290,7 @@ class FeatureColumnsIntegrationTest(keras_parameterized.TestCase): categorical_cols = [fc.categorical_column_with_hash_bucket('cabin', 10)] feature_cols = ([fc.numeric_column('age')] + [fc.indicator_column(cc) for cc in categorical_cols]) - layers = [fc.DenseFeatures(feature_cols), + layers = [df.DenseFeatures(feature_cols), keras.layers.Dense(128), keras.layers.Dense(1)] diff --git a/tensorflow/python/keras/engine/keras_tensor.py b/tensorflow/python/keras/engine/keras_tensor.py index 3aa9b595d4f..055752026c7 100644 --- a/tensorflow/python/keras/engine/keras_tensor.py +++ b/tensorflow/python/keras/engine/keras_tensor.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -25,6 +26,8 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import type_spec as type_spec_module from tensorflow.python.ops import array_ops +from tensorflow.python.ops.ragged import ragged_operators # pylint: disable=unused-import +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.util import nest from tensorflow.python.util import object_identity @@ -50,6 +53,13 @@ def keras_tensors_enabled(): return _KERAS_TENSORS_ENABLED and ops.executing_eagerly_outside_functions() +# Tensorflow tensors have a maximum rank of 254 +# (See `MaxDimensions()` in //tensorflow/core/framework/tensor_shape.h ) +# So we do not try to infer values for int32 tensors larger than this, +# As they cannot represent shapes. +_MAX_TENSOR_RANK = 254 + + class KerasTensor(object): """A representation of a Keras in/output during Functional API construction. @@ -153,6 +163,76 @@ class KerasTensor(object): # it can't access shape or dtype return self._type_spec._shape # pylint: disable=protected-access + @classmethod + def from_tensor(cls, tensor): + """Convert a traced (composite)tensor to a representative KerasTensor.""" + if isinstance(tensor, ops.Tensor): + name = getattr(tensor, 'name', None) + type_spec = type_spec_module.type_spec_from_value(tensor) + inferred_value = None + if (type_spec.dtype == dtypes.int32 and type_spec.shape.rank is not None + and type_spec.shape.rank < 2): + # If this tensor might be representing shape information, + # (dtype=int32, rank of 0 or 1, not too large to represent a shape) + # we attempt to capture any value information tensorflow's + # shape handling can extract from the current scratch graph. + # + # Even though keras layers each trace in their own scratch + # graph, this shape value info extraction allows us to capture + # a sizable and useful subset of the C++ shape value inference TF can do + # if all tf ops appear in the same graph when using shape ops. 
+ # + # Examples of things this cannot infer concrete dimensions for + # that the full single-graph C++ shape inference sometimes can are: + # * cases where the shape tensor is cast out of int32 before being + # manipulated w/ floating point numbers then converted back + # * cases where int32 tensors w/ rank >= 2 are manipulated before being + # used as a shape tensor + # * cases where int32 tensors too large to represent shapes are + # manipulated to a smaller size before being used as a shape tensor + inferred_value = array_ops.ones(shape=tensor).shape + if inferred_value.dims: + inferred_value = inferred_value.as_list() + if len(inferred_value) > _MAX_TENSOR_RANK: + inferred_value = None + else: + inferred_value = None + + return KerasTensor(type_spec, inferred_value=inferred_value, name=name) + else: + # Fallback to the generic arbitrary-typespec KerasTensor + name = getattr(tensor, 'name', None) + type_spec = type_spec_module.type_spec_from_value(tensor) + return cls(type_spec, name=name) + + def _to_placeholder(self): + """Convert this KerasTensor to a placeholder in a graph.""" + # If there is an inferred value for this tensor, inject the inferred value + if self._inferred_value is not None: + # If we suspect this KerasTensor might be representing a shape tensor, + # and we were able to extract value information with TensorFlow's shape + # handling when making the KerasTensor, we construct the placeholder by + # re-injecting the inferred value information into the graph. We + # do this injection through the shape of a placeholder, because that + # allows us to specify partially-unspecified shape values. + # + # See the comment on value extraction inside `from_tensor` for more info. + inferred_value = array_ops.shape( + array_ops.placeholder( + shape=self._inferred_value, dtype=dtypes.int32)) + if self.type_spec.shape.rank == 0: + # `tf.shape` always returns a rank-1, we may need to turn it back to a + # scalar. + inferred_value = inferred_value[0] + return inferred_value + + # Use the generic conversion from typespec to a placeholder. + def component_to_placeholder(component): + return array_ops.placeholder(component.dtype, component.shape) + + return nest.map_structure( + component_to_placeholder, self.type_spec, expand_composites=True) + def get_shape(self): return self.shape @@ -298,26 +378,27 @@ class KerasTensor(object): return self._name @classmethod - def _overload_all_operators(cls): # pylint: disable=invalid-name + def _overload_all_operators(cls, tensor_class): # pylint: disable=invalid-name """Register overloads for all operators.""" for operator in ops.Tensor.OVERLOADABLE_OPERATORS: - cls._overload_operator(operator) + cls._overload_operator(tensor_class, operator) # We include `experimental_ref` for versions of TensorFlow that # still include the deprecated method in Tensors. - if hasattr(ops.Tensor, 'experimental_ref'): - cls._overload_operator('experimental_ref') + if hasattr(tensor_class, 'experimental_ref'): + cls._overload_operator(tensor_class, 'experimental_ref') @classmethod - def _overload_operator(cls, operator): # pylint: disable=invalid-name - """Overload an operator with the same overloading as `ops.Tensor`. + def _overload_operator(cls, tensor_class, operator): # pylint: disable=invalid-name + """Overload an operator with the same implementation as a base Tensor class. - We pull the operator out of ops.Tensor dynamically to avoid ordering issues. + We pull the operator out of the class dynamically to avoid ordering issues. 
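The value-inference trick documented in `from_tensor` and `_to_placeholder` above can be seen in isolation. A rough sketch in a scratch graph (tensor names are illustrative; exactly which dimensions are recovered depends on what graph-mode shape inference can prove):

```
import tensorflow as tf

with tf.Graph().as_default():
  x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3])
  # A rank-1 int32 tensor that happens to carry shape-like values.
  shape_like = tf.stack([tf.shape(x)[0], 4, 5])

  # `from_tensor`: build ones() with that tensor as its shape, then read back
  # whatever dimensions C++ shape inference could prove statically.
  inferred = tf.ones(shape=shape_like).shape.as_list()
  print(inferred)  # [None, 4, 5] -- only the batch dimension stays unknown.

  # `_to_placeholder`: re-inject the partial values in a fresh graph by taking
  # tf.shape() of a placeholder whose *shape* is the inferred value.
  reinjected = tf.shape(tf.compat.v1.placeholder(tf.int32, shape=inferred))
```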
Args: + tensor_class: The (Composite)Tensor to get the method from. operator: string. The operator name. """ - tensor_oper = getattr(ops.Tensor, operator) + tensor_oper = getattr(tensor_class, operator) # Compatibility with Python 2: # Python 2 unbound methods have type checks for the first arg, @@ -327,81 +408,91 @@ class KerasTensor(object): setattr(cls, operator, tensor_oper) -KerasTensor._overload_all_operators() # pylint: disable=protected-access +KerasTensor._overload_all_operators(ops.Tensor) # pylint: disable=protected-access -class _KerasTensorIterator(object): - """Iterates over the leading dim of a KerasTensor. Performs 0 error checks.""" +class SparseKerasTensor(KerasTensor): + """A specialized KerasTensor representation for `tf.sparse.SparseTensor`s. - def __init__(self, tensor, dim0): - self._tensor = tensor - self._index = 0 - self._limit = dim0 + Specifically, it specializes the conversion to a placeholder in order + to maintain dense shape information. + """ - def __iter__(self): - return self + def _to_placeholder(self): + spec = self.type_spec - def __next__(self): - if self._index == self._limit: - raise StopIteration - result = self._tensor[self._index] - self._index += 1 + # nest.map_structure loses dense shape information for sparse tensors. + # So, we special-case sparse placeholder creation. + # This only preserves shape information for top-level sparse tensors; + # not for sparse tensors that are nested inside another composite + # tensor. + return array_ops.sparse_placeholder(dtype=spec.dtype, shape=spec.shape) + + +class RaggedKerasTensor(KerasTensor): + """A specialized KerasTensor representation for `tf.RaggedTensor`s. + + Specifically, it: + + 1. Specializes the conversion to a placeholder in order + to maintain shape information for non-ragged dimensions. + 2. Overloads the KerasTensor's operators with the RaggedTensor versions + when they don't match the `tf.Tensor` versions + 3. Exposes some of the instance method/attribute that are unique to + the RaggedTensor API (such as ragged_rank). + """ + + def _to_placeholder(self): + ragged_spec = self.type_spec + if ragged_spec.ragged_rank == 0 or ragged_spec.shape.rank is None: + return super(RaggedKerasTensor, self)._to_placeholder() + + flat_shape = ragged_spec.shape[ragged_spec.ragged_rank:] + result = array_ops.placeholder(ragged_spec.dtype, flat_shape) + + known_num_splits = [] + prod = 1 + for axis_size in ragged_spec.shape: + if prod is not None: + if axis_size is None or ( + getattr(axis_size, 'value', True) is None): + prod = None + else: + prod = prod * axis_size + known_num_splits.append(prod) + + for axis in range(ragged_spec.ragged_rank, 0, -1): + axis_size = ragged_spec.shape[axis] + if axis_size is None or (getattr(axis_size, 'value', True) is None): + num_splits = known_num_splits[axis-1] + if num_splits is not None: + num_splits = num_splits + 1 + splits = array_ops.placeholder( + ragged_spec.row_splits_dtype, [num_splits]) + result = ragged_tensor.RaggedTensor.from_row_splits( + result, splits, validate=False) + else: + rowlen = constant_op.constant(axis_size, ragged_spec.row_splits_dtype) + result = ragged_tensor.RaggedTensor.from_uniform_row_length( + result, rowlen, validate=False) return result - next = __next__ # python2.x compatibility. 
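The `RaggedKerasTensor._to_placeholder` logic above reconstructs a ragged placeholder from a dense placeholder for the flat values plus one placeholder (or constant) per ragged or uniform dimension. A minimal sketch of the same construction for a `(batch, ragged, 5)` spec with `ragged_rank=1` (names are illustrative):

```
import tensorflow as tf

with tf.Graph().as_default():
  # The flat values keep the trailing dense dimensions...
  flat_values = tf.compat.v1.placeholder(tf.float32, shape=[None, 5])
  # ...and an unknown-length row_splits vector carries the ragged structure.
  row_splits = tf.compat.v1.placeholder(tf.int64, shape=[None])
  rt = tf.RaggedTensor.from_row_splits(flat_values, row_splits, validate=False)
  print(rt.shape)        # (None, None, 5)
  print(rt.ragged_rank)  # 1
```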
- - -def keras_tensor_to_placeholder(x): - """Construct a graph placeholder to represent a KerasTensor when tracing.""" - if hasattr(x, '_user_registered_symbolic_object'): - return x._user_registered_symbolic_object # pylint: disable=protected-access - - if isinstance(x, KerasTensor): - spec = x.type_spec - - if x._inferred_value is not None: # pylint: disable=protected-access - # If we suspect this KerasTensor might be representing a shape tensor, - # and we were able to extract value information with TensorFlow's shape - # handling when making the KerasTensor, we construct the placeholder by - # re-injecting the inferred value information into the graph. - # Even though keras layers each trace in their own scratch - # graph, this shape value info injection allows us to capture - # a sizable and useful subset of the C++ shape value inference TF can do - # if all tf ops appear in the same graph when using shape ops. - # - # Examples of things this cannot infer concrete dimensions for - # that the full single-graph C++ shape inference sometimes can are: - # * cases where the shape tensor is cast out of int32 before being - # manipulated w/ floating point numbers then converted back - # * cases where int32 tensors w/ rank > 2 are manipulated before being - # used as a shape tensor - inferred_value = array_ops.shape( - array_ops.placeholder( - shape=x._inferred_value, dtype=dtypes.int32)) # pylint: disable=protected-access - if spec.shape.rank == 0: - # `tf.shape` always returns a rank-1, we may need to turn it back to a - # scalar. - inferred_value = inferred_value[0] - return inferred_value # pylint: disable=protected-access - - if isinstance(spec, sparse_tensor.SparseTensorSpec): - # nest.map_structure loses dense shape information for sparse tensors. - # So, we special-case sparse placeholder creation. - # This only preserves shape information for top-level sparse tensors; - # not for sparse tensors that are nested inside another composite - # tensor. - return array_ops.sparse_placeholder(dtype=spec.dtype, shape=spec.shape) - - def component_to_placeholder(component): - return array_ops.placeholder(component.dtype, component.shape) - - ph = nest.map_structure( - component_to_placeholder, spec, expand_composites=True) - return ph - else: - return x + @property + def ragged_rank(self): + return self.type_spec.ragged_rank + +RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__add__') # pylint: disable=protected-access +RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__radd__') # pylint: disable=protected-access +RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__mul__') # pylint: disable=protected-access +RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__rmul__') # pylint: disable=protected-access +# TODO(b/161487382): +# Special-case user-registered symbolic objects (registered by the +# private `register_symbolic_tensor_type` method) by passing them between +# scratch graphs directly. +# This is needed to not break Tensorflow probability +# while they finish migrating to composite tensors. 
class UserRegisteredSpec(type_spec_module.TypeSpec): """TypeSpec to represent user-registered symbolic objects.""" @@ -425,72 +516,95 @@ class UserRegisteredSpec(type_spec_module.TypeSpec): def value_type(self): raise NotImplementedError -# Tensorflow tensors have a maximum dimension of 254 -# (See //tensorflow/core/framework/tensor_shape.h ) -# So we do not try to infer values for int32 tensors larger than this, -# As they cannot represent shapes. -_MAX_TENSOR_DIMS = 254 +# TODO(b/161487382): +# Special-case user-registered symbolic objects (registered by the +# private `register_symbolic_tensor_type` method) by passing them between +# scratch graphs directly. +# This is needed to not break Tensorflow probability +# while they finish migrating to composite tensors. +class UserRegisteredTypeKerasTensor(KerasTensor): + """KerasTensor that represents legacy register_symbolic_tensor_type.""" -def keras_tensor_from_tensor(x): - """Convert a traced (composite)tensor to a representative KerasTensor.""" - name = getattr(x, 'name', None) - inferred_value = None - - # TODO(b/161487382): - # Special-case user-registered symbolic objects (registered by the - # private `register_symbolic_tensor_type` method) by passing them between - # scratch graphs directly. - # This is needed to not break Tensorflow probability - # while they finish migrating to composite tensors. - user_registered_symbolic = False - try: - from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top to prevent circular imports - if isinstance(x, tuple(tf_utils._user_convertible_tensor_types)): # pylint: disable=protected-access - user_registered_symbolic = True - except ImportError: - pass - if user_registered_symbolic: + def __init__(self, user_registered_symbolic_object): + x = user_registered_symbolic_object + self._user_registered_symbolic_object = x type_spec = UserRegisteredSpec(x.shape, x.dtype) + name = getattr(x, 'name', None) + + super(UserRegisteredTypeKerasTensor, self).__init__(type_spec, name) + + @classmethod + def from_tensor(cls, tensor): + return cls(tensor) + + def _to_placeholder(self): + return self._user_registered_symbolic_object + + +class _KerasTensorIterator(object): + """Iterates over the leading dim of a KerasTensor. Performs 0 error checks.""" + + def __init__(self, tensor, dim0): + self._tensor = tensor + self._index = 0 + self._limit = dim0 + + def __iter__(self): + return self + + def __next__(self): + if self._index == self._limit: + raise StopIteration + result = self._tensor[self._index] + self._index += 1 + return result + + next = __next__ # python2.x compatibility. + + +# Specify the mappings of tensor class to KerasTensor class. +# This is specifically a list instead of a dict for now because +# 1. we do a check w/ isinstance because a key lookup based on class +# would miss subclasses +# 2. a list allows us to control lookup ordering +# We include ops.Tensor -> KerasTensor in the first position as a fastpath, +# *and* include object -> KerasTensor at the end as a catch-all. +# We can re-visit these choices in the future as needed. 
+keras_tensor_classes = [ + (ops.Tensor, KerasTensor), + (sparse_tensor.SparseTensor, SparseKerasTensor), + (ragged_tensor.RaggedTensor, RaggedKerasTensor), + (object, KerasTensor) +] + + +def register_keras_tensor_specialization(cls, keras_tensor_subclass): + """Register a specialized KerasTensor subclass for a Tensor type.""" + # We always leave (object, KerasTensor) at the end as a generic fallback + keras_tensor_classes.insert(-1, (cls, keras_tensor_subclass)) + + +def keras_tensor_to_placeholder(x): + """Construct a graph placeholder to represent a KerasTensor when tracing.""" + if isinstance(x, KerasTensor): + return x._to_placeholder() # pylint: disable=protected-access else: - type_spec = type_spec_module.type_spec_from_value(x) + return x - if (isinstance(type_spec, tensor_spec.TensorSpec) - and type_spec.dtype == dtypes.int32 - and type_spec.shape.rank < 2): - # If this tensor might be representing shape information, - # (dtype=int32, rank of 0 or 1, not too large to represent a shape) - # we attempt to capture any value information tensorflow's - # shape handling can extract from the current scratch graph. - # - # Even though keras layers each trace in their own scratch - # graph, this shape value info extraction allows us to capture - # a sizable and useful subset of the C++ shape value inference TF can do - # if all tf ops appear in the same graph when using shape ops. - # - # Examples of things this cannot infer concrete dimensions for - # that the full single-graph C++ shape inference sometimes can are: - # * cases where the shape tensor is cast out of int32 before being - # manipulated w/ floating point numbers then converted back - # * cases where int32 tensors w/ rank > 2 are manipulated before being - # used as a shape tensor - # * cases where int32 tensors too large to represent shapes are manipulated - # to a smaller size before being used as a shape tensor - inferred_value = array_ops.ones(shape=x).shape - if inferred_value.dims: - inferred_value = inferred_value.as_list() - if len(inferred_value) > _MAX_TENSOR_DIMS: - inferred_value = None - else: - inferred_value = None - out = KerasTensor(type_spec, - inferred_value=inferred_value, name=name) - if user_registered_symbolic: - out._user_registered_symbolic_object = x # pylint: disable=protected-access +def keras_tensor_from_tensor(tensor): + """Convert a traced (composite)tensor to a representative KerasTensor.""" + # Create a specialized KerasTensor that supports instance methods, + # operators, and additional value inference if possible + keras_tensor_cls = None + for tensor_type, cls in keras_tensor_classes: + if isinstance(tensor, tensor_type): + keras_tensor_cls = cls + break - if hasattr(x, '_keras_mask'): - out._keras_mask = KerasTensor( # pylint: disable=protected-access - type_spec_module.type_spec_from_value(x._keras_mask)) # pylint: disable=protected-access + out = keras_tensor_cls.from_tensor(tensor) + if hasattr(tensor, '_keras_mask'): + out._keras_mask = keras_tensor_from_tensor(tensor._keras_mask) # pylint: disable=protected-access return out diff --git a/tensorflow/python/keras/engine/ragged_keras_tensor_test.py b/tensorflow/python/keras/engine/ragged_keras_tensor_test.py new file mode 100644 index 00000000000..92abdc82240 --- /dev/null +++ b/tensorflow/python/keras/engine/ragged_keras_tensor_test.py @@ -0,0 +1,96 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RaggedKerasTensor tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.python.framework import func_graph +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.keras import layers +from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.engine import training +from tensorflow.python.ops.ragged import ragged_factory_ops +from tensorflow.python.platform import test + + +class RaggedKerasTensorTest(keras_parameterized.TestCase): + + @parameterized.parameters( + {'batch_size': None, 'shape': (None, 5), 'ragged_rank': 1}, + {'batch_size': None, 'shape': (None, 3, 5), 'ragged_rank': 1}, + {'batch_size': None, 'shape': (5, None), 'ragged_rank': 2}, + {'batch_size': None, 'shape': (3, 5, None), 'ragged_rank': 3}, + {'batch_size': None, 'shape': (None, 3, 5, None), 'ragged_rank': 4}, + {'batch_size': None, 'shape': (2, 3, None, 4, 5, None), 'ragged_rank': 6}, + {'batch_size': 8, 'shape': (None, 5), 'ragged_rank': 1}, + {'batch_size': 9, 'shape': (None, 3, 5), 'ragged_rank': 1}, + {'batch_size': 1, 'shape': (5, None), 'ragged_rank': 2}, + {'batch_size': 4, 'shape': (3, 5, None), 'ragged_rank': 3}, + {'batch_size': 7, 'shape': (None, 3, 5, None), 'ragged_rank': 4}, + {'batch_size': 12, 'shape': (2, 3, None, 4, 5, None), 'ragged_rank': 6}, + ) + def test_to_placeholder(self, shape, batch_size, ragged_rank): + with testing_utils.use_keras_tensors_scope(True): + inp = layers.Input(shape=shape, batch_size=batch_size, ragged=True) + self.assertEqual(inp.ragged_rank, ragged_rank) + self.assertAllEqual(inp.shape, [batch_size] + list(shape)) + with func_graph.FuncGraph('test').as_default(): + placeholder = inp._to_placeholder() + self.assertEqual(placeholder.ragged_rank, ragged_rank) + self.assertAllEqual(placeholder.shape, [batch_size] + list(shape)) + + def test_add(self): + inp = layers.Input(shape=[None], ragged=True) + out = inp + inp + model = training.Model(inp, out) + + x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]]) + self.assertAllEqual(model(x), x + x) + + def test_mul(self): + inp = layers.Input(shape=[None], ragged=True) + out = inp * inp + model = training.Model(inp, out) + + x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]]) + self.assertAllEqual(model(x), x * x) + + def test_sub(self): + inp = layers.Input(shape=[None], ragged=True) + out = inp - inp + model = training.Model(inp, out) + + x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]]) + self.assertAllEqual(model(x), x - x) + + def test_div(self): + inp = layers.Input(shape=[None], ragged=True) + out = inp / inp + model = training.Model(inp, out) + + x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]]) + 
self.assertAllEqual(model(x), x / x) + + +if __name__ == '__main__': + ops.enable_eager_execution() + tensor_shape.enable_v2_tensorshape() + test.main() diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 8e6a31a98b5..d0053599751 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -23,9 +23,13 @@ import itertools import json import os import warnings + import six from tensorflow.python.autograph.lang import directives +from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import values as ds_values from tensorflow.python.eager import backprop @@ -218,6 +222,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): @trackable.no_automatic_dependency_tracking def __init__(self, *args, **kwargs): + self._is_model_for_instrumentation = True base_layer.keras_api_gauge.get_cell('model').set(True) # Special case for Subclassed Functional Model, which we couldn't detect @@ -329,13 +334,13 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): never throw unexpected errors in an unrelated workflow). Args: - input_shape: Single tuple, TensorShape, or list of shapes, where shapes - are tuples, integers, or TensorShapes. + input_shape: Single tuple, TensorShape, or list/dict of shapes, where + shapes are tuples, integers, or TensorShapes. Raises: ValueError: 1. In case of invalid user-provided data (not of type tuple, - list, or TensorShape). + list, TensorShape, or dict). 2. If the model requires call arguments that are agnostic to the input shapes (positional or kwarg in call signature). 3. If not all layers were properly built. @@ -351,7 +356,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): if input_shape is None: raise ValueError('Input shape must be defined when calling build on a ' 'model subclass network.') - valid_types = (tuple, list, tensor_shape.TensorShape) + valid_types = (tuple, list, tensor_shape.TensorShape, dict) if not isinstance(input_shape, valid_types): raise ValueError('Specified input shape is not one of the valid types. ' 'Please specify a batch input shape of type tuple or ' @@ -863,7 +868,11 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): interactively (eg, in a production environment). callbacks: List of `keras.callbacks.Callback` instances. List of callbacks to apply during training. - See `tf.keras.callbacks`. + See `tf.keras.callbacks`. Note `tf.keras.callbacks.ProgbarLogger` + and `tf.keras.callbacks.History` callbacks are created automatically + and need not be passed into `model.fit`. + `tf.keras.callbacks.ProgbarLogger` is created or not based on + `verbose` argument to `model.fit`. validation_split: Float between 0 and 1. Fraction of the training data to be used as validation data. 
The model will set apart this fraction of the training data, @@ -1470,7 +1479,6 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): self.predict_function = predict_function return self.predict_function - @disable_multi_worker def predict(self, x, batch_size=None, @@ -1553,6 +1561,20 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): outputs = None with self.distribute_strategy.scope(): # Creates a `tf.data.Dataset` and handles batch and epoch iteration. + dataset_types = (dataset_ops.DatasetV1, dataset_ops.DatasetV2) + if (self._in_multi_worker_mode() or _is_tpu_multi_host( + self.distribute_strategy)) and isinstance(x, dataset_types): + try: + options = dataset_ops.Options() + data_option = AutoShardPolicy.DATA + options.experimental_distribute.auto_shard_policy = data_option + x = x.with_options(options) + except ValueError: + warnings.warn('Using Model.predict with ' + 'MultiWorkerDistributionStrategy or TPUStrategy and ' + 'AutoShardPolicy.FILE might lead to out-of-order result' + '. Consider setting it to AutoShardPolicy.DATA.') + data_handler = data_adapter.DataHandler( x=x, batch_size=batch_size, @@ -2657,6 +2679,8 @@ def reduce_per_replica(values, strategy, reduction='first'): def _reduce(v): """Reduce a single `PerReplica` object.""" + if reduction == 'concat' and _collective_all_reduce_multi_worker(strategy): + return _multi_worker_concat(v, strategy) if not isinstance(v, ds_values.PerReplica): return v elif reduction == 'first': @@ -2701,6 +2725,42 @@ def _tpu_multi_host_concat(v, strategy): return concat(ordered_replicas) +def _collective_all_reduce_multi_worker(strategy): + return (isinstance(strategy, + collective_all_reduce_strategy.CollectiveAllReduceStrategy) + ) and strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access + + +# TODO(wxinyi): merge this with _tpu_multi_host_concat once we have all_gather +# for all strategies +def _multi_worker_concat(v, strategy): + """Order PerReplica objects for CollectiveAllReduceStrategy and concat.""" + replicas = strategy._gather(v, axis=0) # pylint: disable=protected-access + # v might not have the same shape on different replicas + if isinstance(v, ds_values.PerReplica): + shapes = array_ops.concat([ + array_ops.expand_dims_v2(array_ops.shape(single_value)[0], axis=0) + for single_value in v.values + ], + axis=0) + all_shapes = strategy._gather(shapes, axis=0) # pylint: disable=protected-access + else: + # v is a tensor. This may happen when, say, we have 2x1 multi-worker. 
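For reference, the auto-shard override that `Model.predict` now applies above corresponds to the following user-facing pattern; a minimal sketch with a toy model and dataset (all names are illustrative), useful when the same behaviour is wanted explicitly:

```
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
dataset = tf.data.Dataset.from_tensor_slices(
    np.ones((8, 4), np.float32)).batch(2)

# Shard each worker's input by DATA rather than by FILE so that predictions
# come back in the original element order.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA)
predictions = model.predict(dataset.with_options(options))
```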
+ all_shapes = strategy._gather( # pylint: disable=protected-access + array_ops.expand_dims_v2(array_ops.shape(v)[0], axis=0), + axis=0) + + replicas = array_ops.split( + replicas, + num_or_size_splits=all_shapes, + num=strategy.num_replicas_in_sync) + ordered_replicas = [] + num_replicas_per_worker = len(strategy.extended.worker_devices) + for replica_id in range(num_replicas_per_worker): + ordered_replicas += replicas[replica_id::num_replicas_per_worker] + return concat(ordered_replicas) + + def _is_scalar(x): return isinstance(x, (ops.Tensor, variables.Variable)) and x.shape.rank == 0 diff --git a/tensorflow/python/keras/engine/training_arrays_v1.py b/tensorflow/python/keras/engine/training_arrays_v1.py index 4cdbcc9a02f..ad9b37f569e 100644 --- a/tensorflow/python/keras/engine/training_arrays_v1.py +++ b/tensorflow/python/keras/engine/training_arrays_v1.py @@ -30,7 +30,7 @@ from tensorflow.python.framework import errors from tensorflow.python.keras import backend as K from tensorflow.python.keras import callbacks as cbks from tensorflow.python.keras.distribute import distributed_training_utils_v1 -from tensorflow.python.keras.engine import training_utils +from tensorflow.python.keras.engine import training_utils_v1 from tensorflow.python.keras.utils.generic_utils import make_batches from tensorflow.python.keras.utils.generic_utils import slice_arrays from tensorflow.python.keras.utils.mode_keys import ModeKeys @@ -139,7 +139,7 @@ def model_iteration(model, if is_dataset: if steps_per_epoch is None: reset_dataset_after_each_epoch = True - steps_per_epoch = training_utils.infer_steps_for_dataset( + steps_per_epoch = training_utils_v1.infer_steps_for_dataset( model, inputs, steps_per_epoch, epochs=epochs, steps_name=steps_name) input_iterator = _get_iterator(inputs, model._distribution_strategy) @@ -153,10 +153,6 @@ def model_iteration(model, use_steps = is_dataset or steps_per_epoch is not None do_validation = val_inputs is not None - # Convert Eager Tensors to NumPy arrays to support batching/shuffling. - inputs, targets, sample_weights = training_utils. \ - convert_eager_tensors_to_numpy((inputs, targets, sample_weights)) - # Prepare input data. inputs = input_iterator or inputs if validation_in_fit and prepared_feed_values_from_dataset: @@ -197,7 +193,7 @@ def model_iteration(model, # model_iteration() call, it will not trigger the dataset-input path # that determines the number of steps required. To avoid this issue, # set validation_steps here if validation_steps is None. - validation_steps = training_utils.infer_steps_for_dataset( + validation_steps = training_utils_v1.infer_steps_for_dataset( model, val_inputs, validation_steps, @@ -240,12 +236,12 @@ def model_iteration(model, # Select aggregation method. if mode == ModeKeys.PREDICT: - aggregator = training_utils.OutputsAggregator( + aggregator = training_utils_v1.OutputsAggregator( use_steps, num_samples=None if steps_per_epoch else num_samples_or_steps, steps=steps_per_epoch) else: - aggregator = training_utils.MetricsAggregator( + aggregator = training_utils_v1.MetricsAggregator( use_steps, num_samples=None if steps_per_epoch else num_samples_or_steps, steps=steps_per_epoch) @@ -350,7 +346,7 @@ def model_iteration(model, # Sample-wise loop. 
index_array = np.arange(num_samples_or_steps) if shuffle == 'batch': - index_array = training_utils.batch_shuffle(index_array, batch_size) + index_array = training_utils_v1.batch_shuffle(index_array, batch_size) elif shuffle: np.random.shuffle(index_array) batches = make_batches(num_samples_or_steps, batch_size) @@ -409,7 +405,7 @@ def model_iteration(model, # Run the test loop every `validation_freq` epochs during training. if (do_validation and - training_utils.should_run_validation(validation_freq, epoch) and + training_utils_v1.should_run_validation(validation_freq, epoch) and not callbacks.model.stop_training): if model._compile_distribution: @@ -483,8 +479,8 @@ def _get_num_samples_or_steps(ins, batch_size, steps_per_epoch): """Returns total number of samples (when training in batch mode) or steps.""" if steps_per_epoch: return steps_per_epoch - return training_utils.check_num_samples(ins, batch_size, steps_per_epoch, - 'steps_per_epoch') + return training_utils_v1.check_num_samples(ins, batch_size, steps_per_epoch, + 'steps_per_epoch') def _prepare_feed_values(model, inputs, targets, sample_weights, mode): @@ -527,7 +523,7 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode): inputs, extract_tensors_from_dataset=True) - inputs = training_utils.ModelInputs(inputs).as_list() + inputs = training_utils_v1.ModelInputs(inputs).as_list() targets = list(targets or []) sample_weights = list(sample_weights or []) ins = inputs + targets + sample_weights @@ -541,7 +537,7 @@ def _get_iterator(inputs, distribution_strategy=None): if distribution_strategy: return distributed_training_utils_v1.get_iterator( inputs, distribution_strategy) - return training_utils.get_iterator(inputs) + return training_utils_v1.get_iterator(inputs) def _reinitialize_iterator(iterator, distribution_strategy=None): @@ -549,7 +545,7 @@ def _reinitialize_iterator(iterator, distribution_strategy=None): distributed_training_utils_v1.initialize_iterator( iterator, distribution_strategy) else: - training_utils.initialize_iterator(iterator) + training_utils_v1.initialize_iterator(iterator) def _make_execution_function(model, mode): @@ -593,7 +589,7 @@ predict_loop = functools.partial( model_iteration, mode=ModeKeys.PREDICT, shuffle=False) -class ArrayLikeTrainingLoop(training_utils.TrainingLoop): +class ArrayLikeTrainingLoop(training_utils_v1.TrainingLoop): """TrainingLoop that handle inputs like array. This is the default handler for most of the input data types, includes @@ -639,9 +635,9 @@ class ArrayLikeTrainingLoop(training_utils.TrainingLoop): val_x, val_y, val_sample_weights = model._prepare_validation_data( validation_data, batch_size, validation_steps) elif validation_split and 0. 
< validation_split < 1.: - (x, y, sample_weights, val_x, val_y, - val_sample_weights) = training_utils.split_training_and_validation_data( - x, y, sample_weights, validation_split) + (x, y, sample_weights, val_x, val_y, val_sample_weights + ) = training_utils_v1.split_training_and_validation_data( + x, y, sample_weights, validation_split) else: if validation_steps: raise ValueError('`validation_steps` should not be specified if ' diff --git a/tensorflow/python/keras/engine/training_distributed_v1.py b/tensorflow/python/keras/engine/training_distributed_v1.py index 22fc19c4f62..0e73b21adc1 100644 --- a/tensorflow/python/keras/engine/training_distributed_v1.py +++ b/tensorflow/python/keras/engine/training_distributed_v1.py @@ -35,7 +35,7 @@ from tensorflow.python.keras.distribute import distributed_training_utils as dis from tensorflow.python.keras.distribute import distributed_training_utils_v1 as dist_utils from tensorflow.python.keras.engine import partial_batch_padding_handler as padding_util from tensorflow.python.keras.engine import training_arrays_v1 -from tensorflow.python.keras.engine import training_utils +from tensorflow.python.keras.engine import training_utils_v1 from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.keras.utils.mode_keys import ModeKeys from tensorflow.python.ops import array_ops @@ -258,7 +258,7 @@ def experimental_tpu_fit_loop(model, break if (do_validation and - training_utils.should_run_validation(validation_freq, epoch)): + training_utils_v1.should_run_validation(validation_freq, epoch)): logging.info('Running validation at fit epoch: %s', epoch) if model._compile_distribution: @@ -386,8 +386,9 @@ def experimental_tpu_test_loop(model, try: _, batch_outs = K.batch_get_value([test_op, output_tensors]) except errors.OutOfRangeError: - warning_msg = 'Make sure that your dataset can generate at least ' - '`steps` batches (in this case, {} batches).'.format(steps) + warning_msg = ( + 'Make sure that your dataset can generate at least ' + '`steps` batches (in this case, {} batches).'.format(steps)) logging.warning('Your dataset iterator ran out of data; ' 'interrupting evaluation. ' + warning_msg) @@ -533,8 +534,9 @@ def experimental_tpu_predict_loop(model, _, batch_outs = K.batch_get_value([predict_ops, output_tensors]) except errors.OutOfRangeError: - warning_msg = 'Make sure that your dataset can generate at least ' - '`steps` batches (in this case, {} batches).'.format(steps) + warning_msg = ( + 'Make sure that your dataset can generate at least ' + '`steps` batches (in this case, {} batches).'.format(steps)) logging.warning('Your dataset iterator ran out of data; ' 'interrupting evaluation. 
' + warning_msg) @@ -575,7 +577,7 @@ def experimental_tpu_predict_loop(model, return prediction_result -class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop): +class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop): """Training loop for distribution strategy with single worker.""" def fit(self, @@ -630,8 +632,8 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop): val_dataset = None if validation_data: - val_x, val_y, val_sample_weights = training_utils.unpack_validation_data( - validation_data) + val_x, val_y, val_sample_weights = ( + training_utils_v1.unpack_validation_data(validation_data)) dist_utils.validate_inputs(val_x, val_y) _, validation_steps = dist_utils.process_batch_and_step_size( model._distribution_strategy, val_x, batch_size, validation_steps, @@ -650,7 +652,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop): 'distribution strategies.') if dist_utils_v2.is_tpu_strategy(model._distribution_strategy): - steps_per_epoch = training_utils.infer_steps_for_dataset( + steps_per_epoch = training_utils_v1.infer_steps_for_dataset( model, dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch') if steps_per_epoch is None: raise ValueError('Number of steps could not be inferred from the data, ' @@ -707,7 +709,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop): allow_partial_batch=True) if dist_utils_v2.is_tpu_strategy(model._distribution_strategy): - steps = training_utils.infer_steps_for_dataset( + steps = training_utils_v1.infer_steps_for_dataset( model, dataset, steps, steps_name='steps') if steps is None: raise ValueError('Number of steps could not be inferred from the data, ' @@ -744,7 +746,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop): batch_size=batch_size, allow_partial_batch=True) if dist_utils_v2.is_tpu_strategy(model._distribution_strategy): - steps = training_utils.infer_steps_for_dataset( + steps = training_utils_v1.infer_steps_for_dataset( model, dataset, steps, steps_name='steps') if steps is None: raise ValueError('Number of steps could not be inferred from the data, ' @@ -780,7 +782,7 @@ def _train_with_multi_worker(method): return wrapper -class DistributionMultiWorkerTrainingLoop(training_utils.TrainingLoop): +class DistributionMultiWorkerTrainingLoop(training_utils_v1.TrainingLoop): """Training loop for distribution strategy with multiple worker.""" def __init__(self, single_worker_loop): diff --git a/tensorflow/python/keras/engine/training_eager_v1.py b/tensorflow/python/keras/engine/training_eager_v1.py index 09e6f0d1edd..2acd7493cb0 100644 --- a/tensorflow/python/keras/engine/training_eager_v1.py +++ b/tensorflow/python/keras/engine/training_eager_v1.py @@ -25,6 +25,7 @@ from tensorflow.python.eager.backprop import GradientTape from tensorflow.python.framework import ops from tensorflow.python.keras import backend from tensorflow.python.keras.engine import training_utils +from tensorflow.python.keras.engine import training_utils_v1 from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer from tensorflow.python.keras.utils import losses_utils from tensorflow.python.ops import math_ops @@ -127,11 +128,12 @@ def _model_loss(model, outs = nest.flatten(outs) if targets: - targets = training_utils.cast_if_floating_dtype_and_mismatch(targets, outs) + targets = training_utils_v1.cast_if_floating_dtype_and_mismatch( + targets, outs) # TODO(sallymatson/psv): check if we should do same mismatch fix 
for weights if sample_weights: sample_weights = [ - training_utils.cast_if_floating_dtype( + training_utils_v1.cast_if_floating_dtype( ops.convert_to_tensor_v2_with_dispatch(val)) if val is not None else None for val in sample_weights ] @@ -304,7 +306,7 @@ def train_on_batch(model, model output. Could be a empty list when model has only one output. 'metrics': list of tensors for metric specified. """ - inputs = training_utils.cast_to_model_input_dtypes(inputs, model) + inputs = training_utils_v1.cast_to_model_input_dtypes(inputs, model) outs, total_loss, output_losses, masks = ( _process_single_batch( model, @@ -345,7 +347,7 @@ def test_on_batch(model, model output. Could be a empty list when model has only one output. 'metrics': list of tensors for metric specified. """ - inputs = training_utils.cast_to_model_input_dtypes(inputs, model) + inputs = training_utils_v1.cast_to_model_input_dtypes(inputs, model) with backend.eager_learning_phase_scope(0): outs, total_loss, output_losses, masks = ( diff --git a/tensorflow/python/keras/engine/training_generator_v1.py b/tensorflow/python/keras/engine/training_generator_v1.py index 1fcf3ef25e4..9b6fc1577bb 100644 --- a/tensorflow/python/keras/engine/training_generator_v1.py +++ b/tensorflow/python/keras/engine/training_generator_v1.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import errors from tensorflow.python.keras import backend from tensorflow.python.keras import callbacks as cbks from tensorflow.python.keras.engine import training_utils +from tensorflow.python.keras.engine import training_utils_v1 from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils.mode_keys import ModeKeys @@ -132,7 +133,7 @@ def model_iteration(model, original_dataset = data if steps_per_epoch is None: reset_dataset_after_each_epoch = True - steps_per_epoch = training_utils.infer_steps_for_dataset( + steps_per_epoch = training_utils_v1.infer_steps_for_dataset( model, data, steps_per_epoch, epochs=epochs, steps_name=steps_name) # Convert to a format that supports `next(generator)`. @@ -179,9 +180,11 @@ def model_iteration(model, mode=mode) if mode == ModeKeys.PREDICT: - aggregator = training_utils.OutputsAggregator(True, steps=steps_per_epoch) + aggregator = training_utils_v1.OutputsAggregator( + True, steps=steps_per_epoch) else: - aggregator = training_utils.MetricsAggregator(True, steps=steps_per_epoch) + aggregator = training_utils_v1.MetricsAggregator( + True, steps=steps_per_epoch) should_set_learning_phase = context.executing_eagerly() and model.run_eagerly if should_set_learning_phase: @@ -293,7 +296,7 @@ def model_iteration(model, # Run the test loop every epoch during training. if (do_validation and - training_utils.should_run_validation(validation_freq, epoch) and + training_utils_v1.should_run_validation(validation_freq, epoch) and not callbacks.model.stop_training): val_results = model_iteration( model, @@ -538,7 +541,7 @@ def _get_num_samples_or_steps(data, steps_per_epoch): return steps_per_epoch, True -class GeneratorOrSequenceTrainingLoop(training_utils.TrainingLoop): +class GeneratorOrSequenceTrainingLoop(training_utils_v1.TrainingLoop): """Generator-like. Input is Python generator, or Sequence object. 
@@ -569,7 +572,7 @@ class GeneratorOrSequenceTrainingLoop(training_utils.TrainingLoop): workers=1, use_multiprocessing=False): model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x) - training_utils.check_generator_arguments( + training_utils_v1.check_generator_arguments( y, sample_weight, validation_split=validation_split) return fit_generator( model, @@ -602,7 +605,7 @@ class GeneratorOrSequenceTrainingLoop(training_utils.TrainingLoop): workers=1, use_multiprocessing=False): model._validate_or_infer_batch_size(batch_size, steps, x) - training_utils.check_generator_arguments(y, sample_weight) + training_utils_v1.check_generator_arguments(y, sample_weight) return evaluate_generator( model, x, @@ -635,7 +638,7 @@ class GeneratorOrSequenceTrainingLoop(training_utils.TrainingLoop): use_multiprocessing=use_multiprocessing) -class EagerDatasetOrIteratorTrainingLoop(training_utils.TrainingLoop): +class EagerDatasetOrIteratorTrainingLoop(training_utils_v1.TrainingLoop): """A non-distributed Dataset or iterator in eager execution.""" def fit(self, @@ -658,10 +661,11 @@ class EagerDatasetOrIteratorTrainingLoop(training_utils.TrainingLoop): **kwargs): model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x) # Make sure that y, sample_weights, validation_split are not passed. - training_utils.validate_dataset_input(x, y, sample_weight, validation_split) + training_utils_v1.validate_dataset_input(x, y, sample_weight, + validation_split) if (isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)) and shuffle): - training_utils.verify_dataset_shuffled(x) + training_utils_v1.verify_dataset_shuffled(x) return fit_generator( model, @@ -691,7 +695,7 @@ class EagerDatasetOrIteratorTrainingLoop(training_utils.TrainingLoop): **kwargs): model._validate_or_infer_batch_size(batch_size, steps, x) # Make sure that y, sample_weights, validation_split are not passed. - training_utils.validate_dataset_input(x, y, sample_weight) + training_utils_v1.validate_dataset_input(x, y, sample_weight) return evaluate_generator( model, x, steps=steps, verbose=verbose, workers=0, callbacks=callbacks) @@ -708,7 +712,7 @@ class EagerDatasetOrIteratorTrainingLoop(training_utils.TrainingLoop): model, x, steps=steps, verbose=verbose, workers=0, callbacks=callbacks) -class GeneratorLikeTrainingLoop(training_utils.TrainingLoop): +class GeneratorLikeTrainingLoop(training_utils_v1.TrainingLoop): """TrainingLoop that handle inputs like python generator. This is the default handler for most of the input data types, includes @@ -755,8 +759,9 @@ class GeneratorLikeTrainingLoop(training_utils.TrainingLoop): validation_steps) elif validation_split and 0. 
< validation_split < 1.: (x, y, sample_weights, val_x, val_y, - val_sample_weights) = training_utils.split_training_and_validation_data( - x, y, sample_weights, validation_split) + val_sample_weights) = ( + training_utils_v1.split_training_and_validation_data( + x, y, sample_weights, validation_split)) validation_data = (val_x, val_y, val_sample_weights) else: if validation_steps: diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 1f8f8cb1b52..6a833560cff 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -44,7 +44,7 @@ from tensorflow.python.keras.callbacks import Callback from tensorflow.python.keras.engine import input_layer from tensorflow.python.keras.engine import sequential from tensorflow.python.keras.engine import training as training_module -from tensorflow.python.keras.engine import training_utils +from tensorflow.python.keras.engine import training_utils_v1 from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import np_utils from tensorflow.python.ops import array_ops @@ -2019,7 +2019,7 @@ class LossWeightingTest(keras_parameterized.TestCase): [[0, .4, 1, 1], [2, .4, .3, 1]]) dataset = dataset_ops.Dataset.from_tensor_slices(sample_weights) sample_weights = dataset_ops.make_one_shot_iterator(dataset).get_next() - sample_weights = training_utils.standardize_sample_weights( + sample_weights = training_utils_v1.standardize_sample_weights( sample_weights, model.output_names) # Update model loss with sample weight tensor. @@ -3707,6 +3707,22 @@ class TestBuildCustomModel(keras_parameterized.TestCase): model.build([None, 1]) self.assertEqual(model.l1.kernel.shape.as_list(), [1, 1]) + @keras_parameterized.run_all_keras_modes + def test_build_dict_inputs(self): + + class MyModel(training_module.Model): + + def __init__(self): + super(MyModel, self).__init__() + self.l1 = layers_module.Dense(1) + + def call(self, inputs): + return self.l1(inputs['x']) + + model = MyModel() + model.build({'x': [None, 16]}) + self.assertEqual(model.l1.kernel.shape.as_list(), [16, 1]) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py index 1df48401f33..4180c0b7e1d 100644 --- a/tensorflow/python/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -17,350 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import abc -import atexit -from collections import OrderedDict -import functools -import multiprocessing.pool -import threading -import time - import numpy as np -import six -from six.moves import zip # pylint: disable=redefined-builtin -from tensorflow.core.framework import graph_pb2 -from tensorflow.python import tf2 -from tensorflow.python.data.experimental.ops import cardinality -from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.eager import context -from tensorflow.python.framework import composite_tensor_utils -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.framework import smart_cond from tensorflow.python.framework import tensor_shape -from 
tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util -from tensorflow.python.keras import backend as K -from tensorflow.python.keras import callbacks as cbks -from tensorflow.python.keras import losses -from tensorflow.python.keras import metrics as metrics_module -from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import generic_utils -from tensorflow.python.keras.utils import losses_utils -from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.ragged import ragged_tensor -from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest -from tensorflow.python.util.compat import collections_abc - - -@six.add_metaclass(abc.ABCMeta) -class Aggregator(object): - """Abstract base class used to aggregate batch-level outputs of a loop. - - Attributes: - use_steps: Whether the loop is using `step` or `batch_size`. - num_samples: Total number of samples: `batch_size * num_batches`. - steps: Total number of steps. - batch_size: Batch size. It is used for validation checks between inputs and - outputs. - results: What to return at the end of the aggregation loop. - """ - - def __init__(self, use_steps, num_samples=None, steps=None, batch_size=None): - self.use_steps = use_steps - self.num_samples = num_samples - self.steps = steps - self.batch_size = batch_size - self.results = [] - - @abc.abstractmethod - def create(self, batch_outs): - """Creates the initial results from the first batch outputs. - - Arguments: - batch_outs: A list of batch-level outputs. - """ - raise NotImplementedError('Must be implemented in subclasses.') - - @abc.abstractmethod - def aggregate(self, batch_outs, batch_start=None, batch_end=None): - """Aggregates batch-level results into total results. - - Arguments: - batch_outs: A list of batch-level outputs. - batch_start: The start index of this batch. Always `None` if `use_steps` - is `True`. - batch_end: The end index of this batch. Always `None` if `use_steps` is - `True`. - """ - raise NotImplementedError('Must be implemented in subclasses.') - - @abc.abstractmethod - def finalize(self): - """Prepares the total results to be returned.""" - raise NotImplementedError('Must be implemented in subclasses.') - - -class MetricsAggregator(Aggregator): - """Aggregator that calculates loss and metrics info. - - Attributes: - use_steps: Whether the loop is using `step` or `batch_size`. - num_samples: Total number of samples: `batch_size*num_batches`. - steps: Total number of steps, ie number of times to iterate over a dataset - to cover all samples. - """ - - def __init__(self, use_steps, num_samples=None, steps=None): - super(MetricsAggregator, self).__init__( - use_steps=use_steps, - num_samples=num_samples, - steps=steps, - batch_size=None) - - def create(self, batch_outs): - self.results = [0.] * len(batch_outs) - - def aggregate(self, batch_outs, batch_start=None, batch_end=None): - # Loss. - if self.use_steps: - self.results[0] += batch_outs[0] - else: - self.results[0] += batch_outs[0] * (batch_end - batch_start) - # Metrics (always stateful, just grab current values.) 
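# A quick numeric sketch (separate from the patch) of the loss averaging the
# MetricsAggregator above performs when `use_steps` is False: per-batch mean
# losses are re-weighted by batch size, then divided by the total sample count
# in `finalize`. The values below are made up purely for illustration.
batch_losses = [0.5, 0.3]             # mean loss reported by each batch
batch_sizes = [32, 8]                 # the last batch is partial
total = sum(l * n for l, n in zip(batch_losses, batch_sizes))
epoch_loss = total / sum(batch_sizes)
assert abs(epoch_loss - 0.46) < 1e-9  # sample-weighted mean, not (0.5 + 0.3) / 2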
- self.results[1:] = batch_outs[1:] - - def finalize(self): - if not self.results: - raise ValueError('Empty training data.') - self.results[0] /= (self.num_samples or self.steps) - - -class ConcatAggregator(Aggregator): - """Combine tensor-likes which cannot be merged on the fly. - - This class expects to aggregate a single tensor-like rather than a nested - structure of tensor-likes. - """ - - def __init__(self, batch_size): - self.composite = None - super(ConcatAggregator, self).__init__( - use_steps=True, num_samples=None, steps=None, batch_size=batch_size) - - def create(self, batch_element): - self.composite = composite_tensor_utils.is_composite_or_composite_value( - batch_element) - - def aggregate(self, batch_element, batch_start=None, batch_end=None): - - # TODO(psv): Add num_samples check here to detect when output batch - # #samples is < batch size and != input batch #samples. - if self.batch_size and self.batch_size < batch_element.shape[0]: - raise ValueError( - 'Mismatch between expected batch size and model output batch size. ' - 'Output shape = {}, expected output shape = shape {}'.format( - batch_element.shape, - (self.batch_size,) + batch_element.shape[1:])) - self.results.append(batch_element) - - def finalize(self): - # Special case of single batch inference which skips a copy. - if len(self.results) == 1: - self.results = self.results[0] - - elif self.composite: - # TODO(taylorrobie): efficiently concatenate. - results = self.results[0] - for r in self.results[1:]: - results = composite_tensor_utils.append_composite_tensor(results, r) - self.results = results - - else: - self.results = np.concatenate(self.results, axis=0) - - if isinstance(self.results, ops.EagerTensor): - self.results = self.results._numpy() # pylint: disable=protected-access - - -_COPY_THREADS = 4 -_COPY_POOL = None - - -def get_copy_pool(): - """Shared threadpool for copying arrays. - - Pool instantiation takes ~ 2ms, so a singleton pool is used rather than - creating a pool per SliceAggregator. - - Returns: - The global copy threadpool. - """ - global _COPY_POOL - if _COPY_POOL is None: - _COPY_POOL = multiprocessing.pool.ThreadPool(_COPY_THREADS) - atexit.register(_COPY_POOL.close) - return _COPY_POOL - - -class SliceAggregator(Aggregator): - """Combine arrays where the final size is known. - - This class expects to aggregate a single tensor-like rather than a nested - structure of tensor-likes. - - NumPy copies are an operation that threads handle quite well because all of - the heavy lifting is in c and does not need the GIL. Moreover, we can perform - lock-free writes to the same buffer in multiple threads because the nature of - result aggregation guarantees that either the indices are disjoint or the - aggregator will throw an exception in finalize. Moreover, because aggregation - is performed on the slowest varying dimension, assignments for a given batch - will write to contiguous blocks of memory, further minimizing contention. - - There is, however, some scheduling and context switching overhead which will - offset the gains from pipelining the slice assignment. Below a given threshold - it is faster to simply assign in the main thread rather than enqueue the - assignment in a side thread. The exact threshold will vary from system to - system, but the time is not very sensitive to the exact transition so a value - of 2 ** 14 was chosen which should be reasonable on most systems. 
- """ - - _BINARY_SIZE_THRESHOLD = 2 ** 14 - _MAX_COPY_SECONDS = 300 - - def __init__(self, num_samples, batch_size): - self._async_copies = [] - self._pool = get_copy_pool() - self._errors = [] - super(SliceAggregator, self).__init__( - use_steps=False, - num_samples=num_samples, - steps=None, - batch_size=batch_size) - - def create(self, batch_element): - # This step does not need to be pipelined because NumPy empty array - # initialization is effectively instantaneous. - shape = (self.num_samples,) + batch_element.shape[1:] - dtype = batch_element.dtype - if isinstance(batch_element, ops.EagerTensor): - dtype = dtype.as_numpy_dtype - - self.results = np.empty(shape=shape, dtype=dtype) - - def aggregate(self, batch_element, batch_start, batch_end): - # Fail early. - if self._errors: - six.reraise(type(self._errors[0]), self._errors[0]) - - # In the special case of single batch inference, no copy is needed. - if batch_end - batch_start == self.num_samples: - if self.num_samples != batch_element.shape[0]: - raise ValueError( - 'Mismatch between expected batch size and model output batch size. ' - 'Output shape = {}, expected output shape = shape {}'.format( - batch_element.shape, self.results.shape)) - - self.results = batch_element - return - - # This is an approximate threshold, so we don't need to consider the number - # of bytes per element. - num_elements = np.prod(batch_element.shape) - if num_elements < self._BINARY_SIZE_THRESHOLD: - self.results[batch_start:batch_end] = batch_element - else: - is_finished = threading.Event() - self._pool.apply_async( - self._slice_assign, - args=(batch_element, batch_start, batch_end, is_finished)) - self._async_copies.append(is_finished) - - def _slice_assign(self, batch_element, batch_start, batch_end, is_finished): - try: - self.results[batch_start:batch_end] = batch_element - - except Exception as e: # pylint: disable=broad-except - # `_slice_assign` should only be called in threads and exceptions raised - # in threads do not carry over to the main thread. So instead we perform a - # a broad catch in the thread and then store the exception to be re-raised - # in the main thread. - self._errors.append(e) - - finally: - is_finished.set() - - def finalize(self): - start_time = time.time() - for is_finished in self._async_copies: - timeout = max([0., self._MAX_COPY_SECONDS - (time.time() - start_time)]) - if not is_finished.wait(timeout): - raise ValueError('Timed out waiting for copy to complete.') - - if self._errors: - six.reraise(self._errors[0].__class__, self._errors[0]) - - -class OutputsAggregator(Aggregator): - """Aggregator that concatenates outputs.""" - - _structure = None - - def create(self, batch_outs): - # SparseTensorValue is a named tuple which nest will flatten, so we need - # to guard it to properly handle the structure. - self._structure = nest.get_traverse_shallow_structure( - lambda x: not composite_tensor_utils.is_composite_or_composite_value(x), - batch_outs) - batch_outs = nest.flatten_up_to(self._structure, batch_outs) - - for batch_element in batch_outs: - if composite_tensor_utils.is_composite_or_composite_value(batch_element): - # If the output is not a ndarray, it will be either a composite tensor - # or a composite tensor's Value object. In either case, we can't - # allocate an array to hold the object - we'll handle it later. 
- self.results.append(ConcatAggregator(self.batch_size)) - elif isinstance(batch_element, (np.ndarray, ops.EagerTensor)): - self.results.append( - (ConcatAggregator(self.batch_size) if self.use_steps else - SliceAggregator(self.num_samples, self.batch_size))) - else: - # This is not a ndarray, a CompositeTensor, or a CompositeTensorValue. - # Fail fast rather than trying to concatenate it. - raise RuntimeError('Attempted to aggregate unsupported object {}.' - .format(batch_element)) - - self.results[-1].create(batch_element) - - def aggregate(self, batch_outs, batch_start=None, batch_end=None): - batch_outs = nest.flatten_up_to(self._structure, batch_outs) - for batch_element, result in zip(batch_outs, self.results): - result.aggregate(batch_element, batch_start, batch_end) - - def finalize(self): - for result in self.results: - result.finalize() - self.results = [i.results for i in self.results] - self.results = nest.pack_sequence_as(self._structure, self.results) - - -def get_progbar(model, count_mode, include_metrics=True): - """Get Progbar.""" - if include_metrics: - stateful_metric_names = getattr(model, 'metrics_names', None) - if stateful_metric_names: - stateful_metric_names = stateful_metric_names[1:] # Exclude `loss` - else: - stateful_metric_names = None - return cbks.ProgbarLogger(count_mode, stateful_metrics=stateful_metric_names) def slice_arrays(arrays, indices, contiguous=True): @@ -399,245 +62,6 @@ def slice_arrays(arrays, indices, contiguous=True): return slices -def check_num_samples(ins, batch_size=None, steps=None, steps_name='steps'): - """Determine the number of samples provided for training and evaluation. - - The number of samples is not defined when running with `steps`, - in which case the number of samples is set to `None`. - - Arguments: - ins: List of tensors to be fed to the Keras function. - batch_size: Integer batch size or `None` if not defined. - steps: Total number of steps (batches of samples) before declaring - `_predict_loop` finished. Ignored with the default value of `None`. - steps_name: The public API's parameter name for `steps`. - - Raises: - ValueError: when `steps` is `None` and the attribute `ins.shape` - does not exist. Also raises ValueError when `steps` is not `None` - and `batch_size` is not `None` because they are mutually - exclusive. - - Returns: - When steps is `None`, returns the number of samples to be - processed based on the size of the first dimension of the - first input numpy array. When steps is not `None` and - `batch_size` is `None`, returns `None`. 
- """ - if steps is not None and batch_size is not None: - raise ValueError('If ' + steps_name + - ' is set, the `batch_size` must be None.') - if check_steps_argument(ins, steps, steps_name): - return None - - if hasattr(ins[0], 'shape'): - return int(ins[0].shape[0]) - return None # Edge case where ins == [static_learning_phase] - - -def standardize_single_array(x, expected_shape=None): - """Expand data of shape (x,) to (x, 1), unless len(expected_shape)==1.""" - if x is None: - return None - - if composite_tensor_utils.is_composite_or_composite_value(x): - return x - - if isinstance(x, int): - raise ValueError( - 'Expected an array data type but received an integer: {}'.format(x)) - - if (x.shape is not None and len(x.shape) == 1 and - (expected_shape is None or len(expected_shape) != 1)): - if tensor_util.is_tensor(x): - x = array_ops.expand_dims(x, axis=1) - else: - x = np.expand_dims(x, 1) - return x - - -def standardize_input_data(data, - names, - shapes=None, - check_batch_axis=True, - exception_prefix=''): - """Normalizes inputs and targets provided by users. - - Users may pass data as a list of arrays, dictionary of arrays, - or as a single array. We normalize this to an ordered list of - arrays (same order as `names`), while checking that the provided - arrays have shapes that match the network's expectations. - - Arguments: - data: User-provided input data (polymorphic). - names: List of expected array names. - shapes: Optional list of expected array shapes. - check_batch_axis: Boolean; whether to check that the batch axis of the - arrays matches the expected value found in `shapes`. - exception_prefix: String prefix used for exception formatting. - - Returns: - List of standardized input arrays (one array per model input). - - Raises: - ValueError: in case of improperly formatted user-provided data. - """ - try: - data_len = len(data) - except TypeError: - # For instance if data is `None` or a symbolic Tensor. - data_len = None - - if not names: - if data_len and not isinstance(data, dict): - raise ValueError( - 'Error when checking model ' + exception_prefix + ': ' - 'expected no data, but got:', data) - return [] - if data is None: - return [None for _ in range(len(names))] - - if isinstance(data, dict): - try: - data = [ - data[x].values - if data[x].__class__.__name__ == 'DataFrame' else data[x] - for x in names - ] - except KeyError as e: - raise ValueError('No data provided for "' + e.args[0] + '". Need data ' - 'for each key in: ' + str(names)) - elif isinstance(data, (list, tuple)): - if isinstance(data[0], (list, tuple)): - data = [np.asarray(d) for d in data] - elif len(names) == 1 and isinstance(data[0], (float, int)): - data = [np.asarray(data)] - else: - data = [ - x.values if x.__class__.__name__ == 'DataFrame' else x for x in data - ] - else: - data = data.values if data.__class__.__name__ == 'DataFrame' else data - data = [data] - - if shapes is not None: - data = [ - standardize_single_array(x, shape) for (x, shape) in zip(data, shapes) - ] - else: - data = [standardize_single_array(x) for x in data] - - if len(data) != len(names): - if data and hasattr(data[0], 'shape'): - raise ValueError('Error when checking model ' + exception_prefix + - ': the list of Numpy arrays that you are passing to ' - 'your model is not the size the model expected. 
' - 'Expected to see ' + str(len(names)) + ' array(s), ' + - 'for inputs ' + str(names) + ' but instead got the ' - 'following list of ' + str(len(data)) + ' arrays: ' + - str(data)[:200] + '...') - elif len(names) > 1: - raise ValueError('Error when checking model ' + exception_prefix + - ': you are passing a list as input to your model, ' - 'but the model expects a list of ' + str(len(names)) + - ' Numpy arrays instead. The list you passed was: ' + - str(data)[:200]) - elif len(data) == 1 and not hasattr(data[0], 'shape'): - raise TypeError('Error when checking model ' + exception_prefix + - ': data should be a Numpy array, or list/dict of ' - 'Numpy arrays. Found: ' + str(data)[:200] + '...') - elif len(names) == 1: - data = [np.asarray(data)] - - # Check shapes compatibility. - if shapes: - for i in range(len(names)): - if shapes[i] is not None: - if tensor_util.is_tensor(data[i]): - tensorshape = data[i].shape - if not tensorshape: - continue - data_shape = tuple(tensorshape.as_list()) - elif composite_tensor_utils.is_composite_or_composite_value(data[i]): - tensorshape = composite_tensor_utils.get_shape(data[i]) - data_shape = tuple(tensorshape.as_list()) - else: - data_shape = data[i].shape - - shape = shapes[i] - if len(data_shape) != len(shape): - raise ValueError('Error when checking ' + exception_prefix + - ': expected ' + names[i] + ' to have ' + - str(len(shape)) + ' dimensions, but got array ' - 'with shape ' + str(data_shape)) - if not check_batch_axis: - data_shape = data_shape[1:] - shape = shape[1:] - for dim, ref_dim in zip(data_shape, shape): - if ref_dim != dim and ref_dim is not None and dim is not None: - raise ValueError('Error when checking ' + exception_prefix + - ': expected ' + names[i] + ' to have shape ' + - str(shape) + ' but got array with shape ' + - str(data_shape)) - return data - - -def standardize_sample_or_class_weights(x_weight, output_names, weight_type): - """Maps `sample_weight` or `class_weight` to model outputs. - - Arguments: - x_weight: User-provided `sample_weight` or `class_weight` argument. - output_names: List of output names (strings) in the model. - weight_type: A string used purely for exception printing. - - Returns: - A list of `sample_weight` or `class_weight` where there are exactly - one element per model output. - - Raises: - ValueError: In case of invalid user-provided argument. - """ - if x_weight is None or (isinstance(x_weight, (list, tuple)) and - len(x_weight) == 0): # pylint: disable=g-explicit-length-test - return [None for _ in output_names] - if len(output_names) == 1: - if isinstance(x_weight, (list, tuple)) and len(x_weight) == 1: - return x_weight - if isinstance(x_weight, dict) and output_names[0] in x_weight: - return [x_weight[output_names[0]]] - else: - return [x_weight] - if isinstance(x_weight, (list, tuple)): - if len(x_weight) != len(output_names): - raise ValueError('Provided `' + weight_type + '` was a list of ' + - str(len(x_weight)) + ' elements, but the model has ' + - str(len(output_names)) + ' outputs. ' - 'You should provide one `' + weight_type + '`' - 'array per model output.') - return x_weight - if isinstance(x_weight, collections_abc.Mapping): - generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names) - x_weights = [] - for name in output_names: - x_weights.append(x_weight.get(name)) - return x_weights - else: - raise TypeError('The model has multiple outputs, so `' + weight_type + '` ' - 'should be either a list or a dict. 
' - 'Provided `' + weight_type + '` type not understood: ' + - str(x_weight)) - - -def standardize_class_weights(class_weight, output_names): - return standardize_sample_or_class_weights(class_weight, output_names, - 'class_weight') - - -def standardize_sample_weights(sample_weight, output_names): - return standardize_sample_or_class_weights(sample_weight, output_names, - 'sample_weight') - - def handle_partial_sample_weights(outputs, sample_weights, sample_weight_modes, check_all_flat=False): """Adds 1.0 as sample weights for the outputs for which there is no weight. @@ -697,506 +121,6 @@ def handle_partial_sample_weights(outputs, sample_weights, sample_weight_modes, any_sample_weight, partial_sample_weight) -def check_array_lengths(inputs, targets, weights=None): - """Does user input validation for numpy arrays. - - Arguments: - inputs: list of Numpy arrays of inputs. - targets: list of Numpy arrays of targets. - weights: list of Numpy arrays of sample weights. - - Raises: - ValueError: in case of incorrectly formatted data. - """ - - def is_tensor_or_composite_tensor(x): - return tensor_util.is_tensor( - x) or composite_tensor_utils.is_composite_or_composite_value(x) - - def set_of_lengths(x): - # Returns a set with the variation between - # different shapes, with None => 0 - if x is None: - return {} - else: - return set([ - y.shape[0] - for y in x - if y is not None and not is_tensor_or_composite_tensor(y) - ]) - - set_x = set_of_lengths(inputs) - set_y = set_of_lengths(targets) - set_w = set_of_lengths(weights) - if len(set_x) > 1: - raise ValueError('All input arrays (x) should have ' - 'the same number of samples. Got array shapes: ' + - str([x.shape for x in inputs])) - if len(set_y) > 1: - raise ValueError('All target arrays (y) should have ' - 'the same number of samples. Got array shapes: ' + - str([y.shape for y in targets])) - if set_x and set_y and list(set_x)[0] != list(set_y)[0]: - raise ValueError('Input arrays should have ' - 'the same number of samples as target arrays. ' - 'Found ' + str(list(set_x)[0]) + ' input samples ' - 'and ' + str(list(set_y)[0]) + ' target samples.') - if len(set_w) > 1: - raise ValueError('All sample_weight arrays should have ' - 'the same number of samples. Got array shapes: ' + - str([w.shape for w in weights])) - if set_y and set_w and list(set_y)[0] != list(set_w)[0]: - raise ValueError('Sample_weight arrays should have ' - 'the same number of samples as target arrays. Got ' + - str(list(set_y)[0]) + ' input samples and ' + - str(list(set_w)[0]) + ' target samples.') - - -def check_loss_and_target_compatibility(targets, loss_fns, output_shapes): - """Does validation on the compatibility of targets and loss functions. - - This helps prevent users from using loss functions incorrectly. This check - is purely for UX purposes. - - Arguments: - targets: list of Numpy arrays of targets. - loss_fns: list of loss functions. - output_shapes: list of shapes of model outputs. - - Raises: - ValueError: if a loss function or target array - is incompatible with an output. 
- """ - key_loss_fns = { - losses.mean_squared_error, losses.binary_crossentropy, - losses.categorical_crossentropy - } - key_loss_classes = (losses.MeanSquaredError, losses.BinaryCrossentropy, - losses.CategoricalCrossentropy) - for y, loss, shape in zip(targets, loss_fns, output_shapes): - if y is None or loss is None or tensor_util.is_tensor(y): - continue - if losses.is_categorical_crossentropy(loss): - if y.shape[-1] == 1: - raise ValueError('You are passing a target array of shape ' + - str(y.shape) + - ' while using as loss `categorical_crossentropy`. ' - '`categorical_crossentropy` expects ' - 'targets to be binary matrices (1s and 0s) ' - 'of shape (samples, classes). ' - 'If your targets are integer classes, ' - 'you can convert them to the expected format via:\n' - '```\n' - 'from keras.utils import to_categorical\n' - 'y_binary = to_categorical(y_int)\n' - '```\n' - '\n' - 'Alternatively, you can use the loss function ' - '`sparse_categorical_crossentropy` instead, ' - 'which does expect integer targets.') - - is_loss_wrapper = isinstance(loss, losses.LossFunctionWrapper) - if (isinstance(loss, key_loss_classes) or (is_loss_wrapper and - (loss.fn in key_loss_fns))): - for target_dim, out_dim in zip(y.shape[1:], shape[1:]): - if out_dim is not None and target_dim != out_dim: - loss_name = loss.name - if loss_name is None: - loss_type = loss.fn if is_loss_wrapper else type(loss) - loss_name = loss_type.__name__ - raise ValueError('A target array with shape ' + str(y.shape) + - ' was passed for an output of shape ' + str(shape) + - ' while using as loss `' + loss_name + '`. ' - 'This loss expects targets to have the same shape ' - 'as the output.') - - -def collect_per_output_metric_info(metrics, - output_names, - output_shapes, - loss_fns, - is_weighted=False): - """Maps metric names and functions to model outputs. - - Arguments: - metrics: a list or a list of lists or a dict of metric functions. - output_names: a list of the names (strings) of model outputs. - output_shapes: a list of the shapes (strings) of model outputs. - loss_fns: a list of the loss functions corresponding to the model outputs. - is_weighted: Boolean indicating whether the given metrics are weighted. - - Returns: - A list (one entry per model output) of dicts. - For instance, if the model has 2 outputs, and for the first output - we want to compute "binary_accuracy" and "binary_crossentropy", - and just "binary_accuracy" for the second output, - the list would look like: `[{ - 'acc': binary_accuracy(), - 'ce': binary_crossentropy(), - }, { - 'acc': binary_accuracy(), - }]` - - Raises: - TypeError: if an incorrect type is passed for the `metrics` argument. - """ - if not metrics: - return [{} for _ in output_names] - - if isinstance(metrics, list): - any_sub_list = any(isinstance(m, list) for m in metrics) - if any_sub_list: - if len(metrics) != len(output_names): - raise ValueError('When passing a list of lists as `metrics`, ' - 'it should have one entry per model output. ' - 'The model has ' + str(len(output_names)) + - ' outputs, but you passed metrics=' + str(metrics)) - # User has provided a list of len = len(outputs). - nested_metrics = [generic_utils.to_list(m) for m in metrics] - else: - # If it is a single list we then apply all metrics to all outputs. 
- if len(output_names) > 1: - nested_metrics = [] - for _ in output_names: - nested_metrics.append( - [metrics_module.clone_metric(m) for m in metrics]) - else: - nested_metrics = [metrics] - elif isinstance(metrics, collections_abc.Mapping): - generic_utils.check_for_unexpected_keys('metrics', metrics, output_names) - nested_metrics = [] - for name in output_names: - output_metrics = generic_utils.to_list(metrics.get(name, [])) - nested_metrics.append(output_metrics) - else: - raise TypeError('Type of `metrics` argument not understood. ' - 'Expected a list or dictionary, found: ' + str(metrics)) - - per_output_metrics = [] - for i, metrics in enumerate(nested_metrics): - metrics_dict = OrderedDict() - for metric in metrics: - metric_name = get_metric_name(metric, is_weighted) - metric_fn = get_metric_function( - metric, output_shape=output_shapes[i], loss_fn=loss_fns[i]) - - # If the metric function is not stateful, we create a stateful version. - if not isinstance(metric_fn, metrics_module.Metric): - metric_fn = metrics_module.MeanMetricWrapper( - metric_fn, name=metric_name) - metrics_dict[metric_name] = metric_fn - per_output_metrics.append(metrics_dict) - - return per_output_metrics - - -def batch_shuffle(index_array, batch_size): - """Shuffles an array in a batch-wise fashion. - - Useful for shuffling HDF5 arrays - (where one cannot access arbitrary indices). - - Arguments: - index_array: array of indices to be shuffled. - batch_size: integer. - - Returns: - The `index_array` array, shuffled in a batch-wise fashion. - """ - batch_count = int(len(index_array) / batch_size) - # to reshape we need to be cleanly divisible by batch size - # we stash extra items and reappend them after shuffling - last_batch = index_array[batch_count * batch_size:] - index_array = index_array[:batch_count * batch_size] - index_array = index_array.reshape((batch_count, batch_size)) - np.random.shuffle(index_array) - index_array = index_array.flatten() - return np.append(index_array, last_batch) - - -def standardize_weights(y, - sample_weight=None, - class_weight=None, - sample_weight_mode=None): - """Performs sample weight validation and standardization. - - Everything gets normalized to a single sample-wise (or timestep-wise) - weight array. If both `sample_weight` and `class_weight` are provided, - the weights are multiplied. - - Arguments: - y: Numpy array or Tensor of model targets to be weighted. - sample_weight: User-provided `sample_weight` argument. - class_weight: User-provided `class_weight` argument. - sample_weight_mode: One of `None` or `"temporal"`. `"temporal"` indicated - that we expect 2D weight data that will be applied to the last 2 - dimensions of the targets (i.e. we are weighting timesteps, not - samples). - - Returns: - A numpy array of target weights, one entry per sample to weight. - - Raises: - ValueError: In case of invalid user-provided arguments. - """ - # Iterator may return sample_weight as 1-tuple - if isinstance(sample_weight, tuple): - sample_weight = sample_weight[0] - if sample_weight_mode is not None and sample_weight_mode != 'samplewise': - if sample_weight_mode != 'temporal': - raise ValueError('"sample_weight_mode ' - 'should be None or "temporal". ' - 'Found: ' + str(sample_weight_mode)) - if len(y.shape) < 3: - raise ValueError('Found a sample_weight array for ' - 'an input with shape ' + str(y.shape) + '. ' - 'Timestep-wise sample weighting (use of ' - 'sample_weight_mode="temporal") is restricted to ' - 'outputs that are at least 3D, i.e. 
that have ' - 'a time dimension.') - if sample_weight is not None and len(sample_weight.shape) != 2: - raise ValueError('Found a sample_weight array with shape ' + - str(sample_weight.shape) + '. ' - 'In order to use timestep-wise sample weighting, ' - 'you should pass a 2D sample_weight array.') - else: - if sample_weight is not None and len(sample_weight.shape) != 1: - raise ValueError('Found a sample_weight array with shape {}. In order to ' - 'use timestep-wise sample weights, you should specify ' - 'sample_weight_mode="temporal" in compile(); found "{}" ' - 'instead. If you just mean to use sample-wise weights, ' - 'make sure your sample_weight array is 1D.' - .format(sample_weight.shape, sample_weight_mode)) - - if sample_weight is not None: - if len(sample_weight.shape) > len(y.shape): - raise ValueError('Found a sample_weight with shape' + - str(sample_weight.shape) + '.' - 'Expected sample_weight with rank ' - 'less than or equal to ' + str(len(y.shape))) - - if (not tensor_util.is_tensor(sample_weight) and - y.shape[:sample_weight.ndim] != sample_weight.shape): - raise ValueError('Found a sample_weight array with shape ' + - str(sample_weight.shape) + ' for an input with shape ' + - str(y.shape) + '. ' - 'sample_weight cannot be broadcast.') - - # Class weights applied per-sample. - class_sample_weight = None - if isinstance(class_weight, dict): - if len(y.shape) > 2: - raise ValueError('`class_weight` not supported for ' - '3+ dimensional targets.') - - if tensor_util.is_tensor(y): - # Few classes are expected, so densifying is reasonable. - keys = np.array(sorted(class_weight.keys())) - values = np.array([class_weight[i] for i in keys]) - weight_vector = np.zeros(np.max(keys) + 1) - weight_vector[:] = np.nan - weight_vector[keys] = values - - y_classes = smart_cond.smart_cond( - len(y.shape.as_list()) == 2 and K.shape(y)[1] > 1, - lambda: K.argmax(y, axis=1), - lambda: math_ops.cast(K.reshape(y, (-1,)), dtypes.int64)) - class_sample_weight = array_ops.gather(weight_vector, y_classes) - gen_array_ops.check_numerics( - class_sample_weight, - 'Invalid classes or class weights detected. NaN values indicate that ' - 'an appropriate class weight could not be determined.') - class_sample_weight = math_ops.cast(class_sample_weight, K.floatx()) - if sample_weight is not None: - sample_weight = math_ops.cast( - ops.convert_to_tensor_v2_with_dispatch(sample_weight), K.floatx()) - else: - y_classes = y - if len(y.shape) == 2: - if y.shape[1] > 1: - y_classes = np.argmax(y, axis=1) - elif y.shape[1] == 1: - y_classes = np.reshape(y, y.shape[0]) - - class_sample_weight = np.asarray( - [class_weight[cls] for cls in y_classes if cls in class_weight]) - - if len(class_sample_weight) != len(y_classes): - # subtract the sets to pick all missing classes - existing_classes = set(y_classes) - existing_class_weight = set(class_weight.keys()) - raise ValueError( - '`class_weight` must contain all classes in the data.' - ' The classes %s exist in the data but not in ' - '`class_weight`.' % (existing_classes - existing_class_weight)) - - if class_sample_weight is not None and sample_weight is not None: - # Multiply weights if both are provided. 
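# A NumPy-only sketch (separate from the patch) of the class_weight branch
# directly above: one-hot targets are reduced to class ids, which then index
# into the user-supplied weight mapping to give one weight per sample.
import numpy as np

class_weight = {0: 1.0, 1: 5.0}
y = np.array([[1, 0], [0, 1], [0, 1]])    # one-hot targets
y_classes = np.argmax(y, axis=1)          # -> [0, 1, 1]
class_sample_weight = np.asarray([class_weight[c] for c in y_classes])
assert class_sample_weight.tolist() == [1.0, 5.0, 5.0]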
- return class_sample_weight * sample_weight - if sample_weight is not None: - return sample_weight - if class_sample_weight is not None: - return class_sample_weight - return None - - -def has_symbolic_tensors(ls): - if context.executing_eagerly(): - return False - return has_tensors(ls) - - -def has_tensors(ls): - """Returns true if `ls` contains tensors.""" - # Note: at some point in time ragged tensors didn't count as tensors, so this - # returned false for ragged tensors. Making this return true fails some tests - # which would then require a steps_per_epoch argument. - if isinstance(ls, (list, tuple)): - return any( - tensor_util.is_tensor(v) and - not isinstance(v, ragged_tensor.RaggedTensor) for v in ls) - if isinstance(ls, dict): - return any( - tensor_util.is_tensor(v) and - not isinstance(v, ragged_tensor.RaggedTensor) - for _, v in six.iteritems(ls)) - return tensor_util.is_tensor(ls) and not isinstance( - ls, ragged_tensor.RaggedTensor) - - -def get_metric_name(metric, weighted=False): - """Returns the name corresponding to the given metric input. - - Arguments: - metric: Metric function name or reference. - weighted: Boolean indicating if the given metric is weighted. - - Returns: - The metric name. - """ - if tf2.enabled(): - # We keep the string that the user has set in compile as the metric name. - if isinstance(metric, six.string_types): - return metric - - metric = metrics_module.get(metric) - return metric.name if hasattr(metric, 'name') else metric.__name__ - else: - metric_name_prefix = 'weighted_' if weighted else '' - if metric in ('accuracy', 'acc', 'crossentropy', 'ce'): - if metric in ('accuracy', 'acc'): - suffix = 'acc' - elif metric in ('crossentropy', 'ce'): - suffix = 'ce' - else: - metric_fn = metrics_module.get(metric) - # Get metric name as string - if hasattr(metric_fn, 'name'): - suffix = metric_fn.name - else: - suffix = metric_fn.__name__ - metric_name = metric_name_prefix + suffix - return metric_name - - -def get_metric_function(metric, output_shape=None, loss_fn=None): - """Returns the metric function corresponding to the given metric input. - - Arguments: - metric: Metric function name or reference. - output_shape: The shape of the output that this metric will be calculated - for. - loss_fn: The loss function used. - - Returns: - The metric function. - """ - if metric not in ['accuracy', 'acc', 'crossentropy', 'ce']: - return metrics_module.get(metric) - - is_sparse_categorical_crossentropy = ( - isinstance(loss_fn, losses.SparseCategoricalCrossentropy) or - (isinstance(loss_fn, losses.LossFunctionWrapper) and - loss_fn.fn == losses.sparse_categorical_crossentropy)) - - is_binary_crossentropy = ( - isinstance(loss_fn, losses.BinaryCrossentropy) or - (isinstance(loss_fn, losses.LossFunctionWrapper) and - loss_fn.fn == losses.binary_crossentropy)) - - if metric in ['accuracy', 'acc']: - if output_shape[-1] == 1 or is_binary_crossentropy: - return metrics_module.binary_accuracy - elif is_sparse_categorical_crossentropy: - return metrics_module.sparse_categorical_accuracy - # If the output_shape[-1] is not 1, then we know output is `categorical`. - # We assume it is sparse categorical only if loss is explicitly given - # as sparse categorical crossentropy loss. 
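# The accuracy-resolution rule spelled out above, reduced to plain Python
# (separate from the patch; `resolve_accuracy` is a hypothetical stand-in, not
# a Keras function): 'accuracy' maps to the binary, sparse-categorical, or
# categorical variant depending on output width and the configured loss.
def resolve_accuracy(output_dim, is_binary_ce, is_sparse_ce):
    if output_dim == 1 or is_binary_ce:
        return 'binary_accuracy'
    if is_sparse_ce:
        return 'sparse_categorical_accuracy'
    return 'categorical_accuracy'

assert resolve_accuracy(1, False, False) == 'binary_accuracy'
assert resolve_accuracy(10, False, True) == 'sparse_categorical_accuracy'
assert resolve_accuracy(10, False, False) == 'categorical_accuracy'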
- return metrics_module.categorical_accuracy - else: - if output_shape[-1] == 1 or is_binary_crossentropy: - return metrics_module.binary_crossentropy - elif is_sparse_categorical_crossentropy: - return metrics_module.sparse_categorical_crossentropy - return metrics_module.categorical_crossentropy - - -def call_metric_function(metric_fn, - y_true, - y_pred=None, - weights=None, - mask=None): - """Invokes metric function and returns the metric result tensor.""" - if mask is not None: - mask = math_ops.cast(mask, y_pred.dtype) - if weights is None: - # Use mask as sample weight. - weights = mask - else: - # Update dimensions of weights to match with mask. - weights = math_ops.cast(weights, dtype=y_pred.dtype) - mask, _, weights = losses_utils.squeeze_or_expand_dimensions( - mask, sample_weight=weights) - weights *= mask - - if y_pred is not None: - return metric_fn(y_true, y_pred, sample_weight=weights) - # `Mean` metric only takes a single value. - return metric_fn(y_true, sample_weight=weights) - - -def get_loss_function(loss): - """Returns the loss corresponding to the loss input in `compile` API.""" - if loss is None or isinstance(loss, losses.Loss): - return loss - - if tf_inspect.isclass(loss) and issubclass(loss, losses.Loss): - # It is not safe to assume that the loss takes no constructor arguments. - raise ValueError( - 'Received uninstantiated Loss class: {}\nPlease call loss ""classes ' - 'before passing them to Model.compile.'.format(loss)) - - # Deserialize loss configuration, if needed. - if isinstance(loss, collections_abc.Mapping): - loss = losses.get(loss) - - # Custom callable class. - if callable(loss) and not hasattr(loss, '__name__'): - return loss - - # Wrap loss function with signature `(y_true, y_pred, **kwargs)` - # in `LossFunctionWrapper` class. - loss_fn = losses.get(loss) - - # For losses which are given as strings/functions in the compile API, - # we always set the loss reduction type to be `SUM_OVER_BATCH_SIZE` - # (both in distribution strategy context and otherwise). - return losses.LossFunctionWrapper( - loss_fn, - name=loss_fn.__name__, - reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE) - - class RespectCompiledTrainableState(object): """Set and restore trainable state if it has changed since compile. @@ -1243,568 +167,6 @@ class RespectCompiledTrainableState(object): return False # False values do not suppress exceptions -def validate_dataset_input(x, y, sample_weight, validation_split=None): - """Validates user input arguments when a dataset iterator is passed. - - Arguments: - x: Input data. A `tf.data` dataset or iterator. - y: Target data. It could be either Numpy array(s) or TensorFlow tensor(s). - Expected to be `None` when `x` is a dataset iterator. - sample_weight: An optional sample-weight array passed by the user to weight - the importance of each sample in `x`. Expected to be `None` when `x` is a - dataset iterator - validation_split: Float between 0 and 1. Fraction of the training data to be - used as validation data. Expected to be `None` when `x` is a dataset - iterator. - - Raises: - ValueError: if argument `y` or `sample_weight` or `validation_split` are - provided by user. - """ - if y is not None: - raise ValueError('You passed a dataset or dataset iterator (%s) as ' - 'input `x` to your model. In that case, you should ' - 'not specify a target (`y`) argument, since the dataset ' - 'or dataset iterator generates both input data and ' - 'target data. 
' - 'Received: %s' % (x, y)) - if sample_weight is not None: - raise ValueError('`sample_weight` argument is not supported when input ' - '`x` is a dataset or a dataset iterator. Instead, you' - 'can provide sample_weight as the third element of your' - 'dataset, i.e. (inputs, targets, sample_weight). ' - 'Received: x=%s, sample_weight=%s' % (x, sample_weight)) - if validation_split is not None and validation_split != 0.0: - raise ValueError( - '`validation_split` argument is not supported when ' - 'input `x` is a dataset or a dataset iterator. ' - 'Received: x=%s, validation_split=%f' % (x, validation_split)) - - -def validate_input_types(inp, orig_inp, allow_dict=True, field_name='inputs'): - """Helper function to validate either inputs or targets.""" - if isinstance(inp, (list, tuple)): - if not all(isinstance(v, np.ndarray) or - tensor_util.is_tensor(v) for v in inp): - raise ValueError( - 'Please provide as model inputs either a single array or a list of ' - 'arrays. You passed: {}={}'.format(field_name, str(orig_inp))) - elif isinstance(inp, dict): - if not allow_dict: - raise ValueError( - 'You cannot pass a dictionary as model {}.'.format(field_name)) - elif not isinstance(inp, np.ndarray) and not tensor_util.is_tensor(inp): - raise ValueError( - 'Please provide as model inputs either a single array or a list of ' - 'arrays. You passed: {}={}'.format(field_name, orig_inp)) - - -def check_generator_arguments(y=None, sample_weight=None, - validation_split=None): - """Validates arguments passed when using a generator.""" - if y is not None: - raise ValueError('`y` argument is not supported when data is' - 'a generator or Sequence instance. Instead pass targets' - ' as the second element of the generator.') - if sample_weight is not None: - raise ValueError('`sample_weight` argument is not supported when data is' - 'a generator or Sequence instance. Instead pass sample' - ' weights as the third element of the generator.') - if validation_split: - raise ValueError('If your data is in the form of a Python generator, ' - 'you cannot use `validation_split`.') - - -def check_steps_argument(input_data, steps, steps_name): - """Validates `steps` argument based on input data's type. - - The cases when `steps` value must be provided are when - 1. input data passed is an iterator. - 2. model was built on top of symbolic tensors, input data is not - required and is `None`. - 3. input data passed is a symbolic tensor. - - Arguments: - input_data: Input data. Can be Numpy array(s) or TensorFlow tensor(s) or - tf.data.Dataset iterator or `None`. - steps: Integer or `None`. Total number of steps (batches of samples) to - execute. - steps_name: The public API's parameter name for `steps`. - - Returns: - boolean, True if `steps` argument is required, else False. - - Raises: - ValueError: if `steps` argument is required for given input data type - but not provided. 
- """ - is_x_iterator = isinstance( - input_data, (iterator_ops.Iterator, iterator_ops.OwnedIterator)) - if (input_data is None or is_x_iterator or has_symbolic_tensors(input_data) or - (isinstance(input_data, list) and not input_data)): - if steps is None: - input_type_str = 'a Dataset iterator' if is_x_iterator else 'data tensors' - raise ValueError('When using {input_type} as input to a model, you should' - ' specify the `{steps_name}` argument.'.format( - input_type=input_type_str, steps_name=steps_name)) - return True - - if isinstance(input_data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): - return True - - if steps is not None: - list_types = (np.ndarray, list, tuple) - if (isinstance(input_data, list_types) or - (isinstance(input_data, dict) and - any(isinstance(v, list_types) for v in input_data.values()))): - logging.warning('When passing input data as arrays, do not specify ' - '`steps_per_epoch`/`steps` argument. ' - 'Please use `batch_size` instead.') - return False - - -def cast_single_tensor(x, dtype=None): - if isinstance(x, np.ndarray): - x = ops.convert_to_tensor_v2_with_dispatch(x) - dtype = dtype or K.floatx() - if x.dtype.is_floating: - return math_ops.cast(x, dtype=dtype) - return x - - -def cast_if_floating_dtype_and_mismatch(targets, outputs): - """Returns target data tensors using correct datatype. - - Checks that each target and output pair are the same datatype. If not, casts - the target to the output's datatype. - - Args: - targets: tensor or list of targets. - outputs: tensor or list of outputs. - - Returns: - Targets in appropriate datatype. - """ - if tensor_util.is_tensor(targets): - # There is one target, so output[0] should be the only output. - return cast_single_tensor(targets, dtype=outputs[0].dtype) - new_targets = [] - for target, out in zip(targets, outputs): - if isinstance(target, np.ndarray): - target = ops.convert_to_tensor_v2_with_dispatch(target) - if target.dtype != out.dtype: - new_targets.append(cast_single_tensor(target, dtype=out.dtype)) - else: - new_targets.append(target) - return new_targets - - -def cast_if_floating_dtype(x, dtype=None): - """Casts the given data tensors to the default floating point type. - - Casts only if the input is already a floating point type. - Args: - x: tensor or list/tuple of tensors. - dtype: The dtype to which Tensors should be cast. - - Returns: - Converted input. - """ - return nest.map_structure(functools.partial(cast_single_tensor, dtype=dtype), - x) - - -def cast_to_model_input_dtypes(x, model): - """Casts the given data tensors to the dtypes of the model inputs. - - Args: - x: tensor or list/tuple of tensors. - model: The model. - - Returns: - Converted input. Each tensor is casted to the corresponding input in - `model.inputs`. - """ - input_dtypes = nest.map_structure(lambda t: t.dtype, model.inputs) - return nest.map_structure(math_ops.cast, x, input_dtypes) - - -def prepare_sample_weight_modes(training_endpoints, sample_weight_mode): - """Prepares sample weight modes for the model. - - Args: - training_endpoints: List of model _TrainingEndpoints. - sample_weight_mode: sample weight mode user input passed from compile API. - - Raises: - ValueError: In case of invalid `sample_weight_mode` input. 
- """ - - if isinstance(sample_weight_mode, collections_abc.Mapping): - generic_utils.check_for_unexpected_keys( - 'sample_weight_mode', sample_weight_mode, - [e.output_name for e in training_endpoints]) - - for end_point in training_endpoints: - if not end_point.should_skip_target_weights(): - if end_point.output_name not in sample_weight_mode: - raise ValueError('Output ' + end_point.output_name + - 'missing from `_sample_weight_modes` dictionary') - else: - end_point.sample_weight_mode = sample_weight_mode.get( - end_point.output_name) - elif isinstance(sample_weight_mode, (list, tuple)): - if len(sample_weight_mode) != len(training_endpoints): - raise ValueError('When passing a list as sample_weight_mode, ' - 'it should have one entry per model output. ' - 'The model has ' + str(len(training_endpoints)) + - ' outputs, but you passed ' + - str(len(sample_weight_mode)) + '_sample_weight_modes.') - for mode, endpoint in zip(sample_weight_mode, training_endpoints): - if not endpoint.should_skip_target_weights(): - endpoint.sample_weight_mode = mode - else: - for endpoint in training_endpoints: - if not endpoint.should_skip_target_weights(): - endpoint.sample_weight_mode = sample_weight_mode - - -def prepare_loss_functions(loss, output_names): - """Converts loss to a list of loss functions. - - Arguments: - loss: String (name of objective function), objective function or - `tf.losses.Loss` instance. See `tf.losses`. If the model has multiple - outputs, you can use a different loss on each output by passing a - dictionary or a list of losses. The loss value that will be minimized by - the model will then be the sum of all individual losses. - output_names: List of model output names. - - Returns: - A list of loss objective functions. - - Raises: - ValueError: If loss is a dict with keys not in model output names, - or if loss is a list with len not equal to model outputs. - """ - if isinstance(loss, collections_abc.Mapping): - generic_utils.check_for_unexpected_keys('loss', loss, output_names) - loss_functions = [] - for name in output_names: - if name not in loss: - logging.warning( - 'Output {0} missing from loss dictionary. We assume ' - 'this was done on purpose. The fit and evaluate APIs will not be ' - 'expecting any data to be passed to {0}.'.format(name)) - loss_functions.append(get_loss_function(loss.get(name, None))) - elif isinstance(loss, six.string_types): - loss_functions = [get_loss_function(loss) for _ in output_names] - elif isinstance(loss, collections_abc.Sequence): - if len(loss) != len(output_names): - raise ValueError('When passing a list as loss, it should have one entry ' - 'per model outputs. The model has {} outputs, but you ' - 'passed loss={}'.format(len(output_names), loss)) - loss_functions = nest.map_structure(get_loss_function, loss) - else: - loss_functions = [get_loss_function(loss) for _ in range(len(output_names))] - - return loss_functions - - -def prepare_loss_weights(training_endpoints, loss_weights=None): - """Converts loss weights to a list of loss weights. - - The result loss weights will be populated on the training endpoint. - - Arguments: - training_endpoints: List of model training endpoints. - loss_weights: Optional list or dictionary specifying scalar coefficients - (Python floats) to weight the loss contributions of different model - outputs. The loss value that will be minimized by the model will then be - the *weighted sum* of all individual losses, weighted by the - `loss_weights` coefficients. 
If a list, it is expected to have a 1:1 - mapping to the model's outputs. If a dict, it is expected to map - output names (strings) to scalar coefficients. - - Raises: - ValueError: If loss weight is a dict with key not in model output names, - or if loss is a list with len not equal to model outputs. - """ - if loss_weights is None: - for e in training_endpoints: - e.loss_weight = 1. - elif isinstance(loss_weights, collections_abc.Mapping): - generic_utils.check_for_unexpected_keys( - 'loss_weights', loss_weights, - [e.output_name for e in training_endpoints]) - for e in training_endpoints: - e.loss_weight = loss_weights.get(e.output_name, 1.) - elif isinstance(loss_weights, list): - if len(loss_weights) != len(training_endpoints): - raise ValueError('When passing a list as loss_weights, ' - 'it should have one entry per model output. ' - 'The model has ' + str(len(training_endpoints)) + - ' outputs, but you passed loss_weights=' + - str(loss_weights)) - for w, e in zip(loss_weights, training_endpoints): - e.loss_weight = w - else: - raise TypeError('Could not interpret loss_weights argument: ' + - str(loss_weights) + ' - expected a list of dicts.') - - -# TODO(rohanj): This is a hack to get around not depending on feature_column and -# create a cyclical dependency. Figure out a cleaner solution -def is_feature_layer(layer): - """Returns whether `layer` is a FeatureLayer or not.""" - return getattr(layer, '_is_feature_layer', False) - - -def is_eager_dataset_or_iterator(data): - return context.executing_eagerly() and isinstance( - data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2, - iterator_ops.OwnedIterator)) - - -# pylint: disable=protected-access -def get_dataset_graph_def(dataset): - if context.executing_eagerly(): - graph_def_str = dataset._as_serialized_graph().numpy() - else: - graph_def_str = K.get_value(dataset._as_serialized_graph()) - return graph_pb2.GraphDef().FromString(graph_def_str) - - -def verify_dataset_shuffled(x): - """Verifies that the dataset is shuffled. - - Args: - x: Dataset passed as an input to the model. - - Returns: - boolean, whether the input dataset is shuffled or not. - """ - assert isinstance(x, dataset_ops.DatasetV2) - graph_def = get_dataset_graph_def(x) - for node in graph_def.node: - if node.op.startswith('ShuffleDataset'): - return True - # Also check graph_def.library.function for ds.interleave or ds.flat_map - for function in graph_def.library.function: - for node in function.node_def: - if node.op.startswith('ShuffleDataset'): - return True - logging.warning('Expected a shuffled dataset but input dataset `x` is ' - 'not shuffled. Please invoke `shuffle()` on input dataset.') - return False - - -def is_dataset_or_iterator(data): - return isinstance(data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2, - iterator_ops.Iterator, iterator_ops.OwnedIterator)) - - -def get_iterator(dataset): - """Create and initialize an iterator from a dataset.""" - if context.executing_eagerly(): - iterator = dataset_ops.make_one_shot_iterator(dataset) - else: - iterator = dataset_ops.make_initializable_iterator(dataset) - initialize_iterator(iterator) - return iterator - - -def initialize_iterator(iterator): - if not context.executing_eagerly(): - init_op = iterator.initializer - K.get_session((init_op,)).run(init_op) - - -def extract_tensors_from_dataset(dataset): - """Extract a tuple of tensors `inputs, targets, sample_weight` from a dataset. - - Arguments: - dataset: Dataset instance. - - Returns: - Tuple of tensors `x, y, weights`. 
`y` and `weights` entry may be None. - """ - iterator = get_iterator(dataset) - inputs, targets, sample_weight = unpack_iterator_input(iterator) - return inputs, targets, sample_weight - - -def unpack_iterator_input(iterator): - """Convert a dataset iterator to a tuple of tensors `x, y, sample_weights`. - - Arguments: - iterator: Instance of a dataset iterator. - - Returns: - Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None. - """ - try: - next_element = iterator.get_next() - except errors.OutOfRangeError: - raise RuntimeError('Your dataset iterator ran out of data; ' - 'Make sure that your dataset can generate ' - 'required number of samples.') - - if isinstance(next_element, (list, tuple)): - if len(next_element) not in [2, 3]: - raise ValueError( - 'Please provide model inputs as a list or tuple of 2 or 3 ' - 'elements: (input, target) or (input, target, sample_weights) ' - 'Received %s' % next_element) - if len(next_element) == 2: - x, y = next_element - weights = None - else: - x, y, weights = next_element - else: - x = next_element - y = None - weights = None - return x, y, weights - - -def infer_steps_for_dataset(model, - dataset, - steps, - epochs=1, - steps_name='steps'): - """Infers steps_per_epoch needed to loop through a dataset. - - Arguments: - model: Keras model instance. - dataset: Input data of type tf.data.Dataset. - steps: Number of steps to draw from the dataset (may be None if unknown). - epochs: Number of times to iterate over the dataset. - steps_name: The string name of the steps argument, either `steps`, - `validation_steps`, or `steps_per_epoch`. Only used for error message - formatting. - - Returns: - Integer or `None`. Inferred number of steps to loop through the dataset. - `None` is returned if 1) the size of the dataset is unknown and `steps` was - not specified, or 2) this is multi-worker training and auto sharding is - enabled. - - Raises: - ValueError: In case of invalid argument values. - """ - assert isinstance(dataset, dataset_ops.DatasetV2) - if (model._in_multi_worker_mode() and - (dataset.options().experimental_distribute.auto_shard_policy != - AutoShardPolicy.OFF)): - # If the dataset would be auto-sharded, we should not infer a local - # steps_per_epoch due to the possible inbalanced sharding between workers. - return None - - size = K.get_value(cardinality.cardinality(dataset)) - if size == cardinality.INFINITE and steps is None: - raise ValueError('When passing an infinitely repeating dataset, you ' - 'must specify the `%s` argument.' % (steps_name,)) - if size >= 0: - if steps is not None and steps * epochs > size: - if epochs > 1: - raise ValueError('The dataset you passed contains %s batches, but you ' - 'passed `epochs=%s` and `%s=%s`, which is a total of ' - '%s steps. We cannot draw that many steps from this ' - 'dataset. We suggest to set `%s=%s`.' % - (size, epochs, steps_name, steps, steps * epochs, - steps_name, size // epochs)) - else: - raise ValueError('The dataset you passed contains %s batches, but you ' - 'passed `%s=%s`. We cannot draw that many steps from ' - 'this dataset. We suggest to set `%s=%s`.' % - (size, steps_name, steps, steps_name, size)) - if steps is None: - if size >= 0: - return size - return None - return steps - - -class ModelInputs(object): - """Encapsulates model inputs. - - Allows for transforming model inputs while keeping the same structure. 
- """ - - def __init__(self, inputs): - self._inputs = inputs - self._is_dict = isinstance(self._inputs, dict) - self._is_single_input = not isinstance(self._inputs, (list, tuple, dict)) - - self._flattened_inputs = [] - self._input_names = [] - - if self._is_dict: - for k in sorted(self._inputs.keys()): - self._flattened_inputs.append(self._inputs[k]) - self._input_names.append(k) - else: - self._flattened_inputs = nest.flatten(self._inputs) - self._input_names = [ - 'input_%d' % (i + 1) for i in range(len(self._flattened_inputs)) - ] - - def get_input_names(self): - """Returns keys to name inputs by. - - In case inputs provided were a list, tuple or single entry, we make up a - key 'input_%d'. For dictionary case, we return a sorted list of keys. - """ - return self._input_names - - def get_symbolic_inputs(self, return_single_as_list=False): - """Returns inputs to be set as self.inputs for a model.""" - # TODO(karmel): There is a side-effect here where what you get - # with as_list and as_dict depends on whether you have called this - # method first, since it modifies in place. - for i, (k, v) in enumerate(zip(self._input_names, self._flattened_inputs)): - if isinstance(v, (list, float, int)): - v = np.asarray(v) - if v.ndim == 1: - v = np.expand_dims(v, 1) - - if isinstance(v, (np.ndarray, ops.EagerTensor)): - # We fix the placeholder shape except the batch size. - # This is suboptimal, but it is the best we can do with the info - # we have. The user should call `model._set_inputs(placeholders)` - # to specify custom placeholders if the need arises. - shape = (None,) + tuple(v.shape[1:]) - if shape == (None,): - shape = (None, 1) - dtype = dtypes.as_dtype(v.dtype) - if dtype.is_floating: - dtype = K.floatx() - v = K.placeholder(shape=shape, name=k, dtype=dtype) - elif isinstance(v, tensor_spec.TensorSpec): - shape = (None,) + tuple(v.shape.as_list()[1:]) - if shape == (None,): - shape = (None, 1) - v = K.placeholder(shape=shape, name=k, dtype=v.dtype) - - self._flattened_inputs[i] = v - - if self._is_dict: - return dict(zip(self._input_names, self._flattened_inputs)) - if self._is_single_input and not return_single_as_list: - return self._flattened_inputs[0] - return self._flattened_inputs - - def as_dict(self): - """An iterable over a dictionary version of inputs.""" - for k, v in zip(self._input_names, self._flattened_inputs): - yield k, v - - def as_list(self): - """Returning the inputs as a list.""" - return self._flattened_inputs - - # Allow use of methods not exposed to the user. # pylint: disable=protected-access def get_input_shape_and_dtype(layer): @@ -1856,187 +218,8 @@ def get_static_batch_size(layer): return None -def generic_output_names(outputs_list): - return ['output_%d' % (i + 1) for i in range(len(outputs_list))] - - -def convert_eager_tensors_to_numpy(structure): - """Convert every EagerTensor in `structure` to NumPy. - - Arguments: - structure: An arbitrary structure of elements to be converted to NumPy - arrays. - - Returns: - An identical structure with EagerTensors converted to NumPy arrays. - """ - - def _convert(element): - if isinstance(element, ops.EagerTensor): - return element.numpy() - return element - - return nest.map_structure(_convert, structure) - - def list_to_tuple(maybe_list): """Datasets will stack the list of tensor, so switch them to tuples.""" if isinstance(maybe_list, list): return tuple(maybe_list) return maybe_list - - -def should_run_validation(validation_freq, epoch): - """Checks if validation should be run this epoch. 
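# Illustrative sketch, not part of the patch: how `ModelInputs` (shown above)
# derives input names -- sorted keys for a dict, made-up 'input_%d' names for a
# list/tuple/single array. The import path is an assumption: before this patch
# the class lives in training_utils, afterwards presumably in training_utils_v1.
import numpy as np
from tensorflow.python.keras.engine import training_utils_v1 as tu  # assumed location

mi = tu.ModelInputs({'b': np.zeros((4, 3)), 'a': np.zeros((4, 2))})
print(mi.get_input_names())   # ['a', 'b'] -- dict keys, sorted
mi = tu.ModelInputs([np.zeros((4, 3)), np.zeros((4, 2))])
print(mi.get_input_names())   # ['input_1', 'input_2'] -- generated names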
- - Arguments: - validation_freq: Integer or list. If an integer, specifies how many training - epochs to run before a new validation run is performed. If a list, - specifies the epochs on which to run validation. - epoch: Integer, the number of the training epoch just completed. - - Returns: - Bool, True if validation should be run. - - Raises: - ValueError: if `validation_freq` is an Integer and less than 1, or if - it is neither an Integer nor a Sequence. - """ - # `epoch` is 0-indexed internally but 1-indexed in the public API. - one_indexed_epoch = epoch + 1 - - if isinstance(validation_freq, int): - if validation_freq < 1: - raise ValueError('`validation_freq` can not be less than 1.') - return one_indexed_epoch % validation_freq == 0 - - if not isinstance(validation_freq, collections_abc.Container): - raise ValueError('`validation_freq` must be an Integer or ' - '`collections_abc.Container` (e.g. list, tuple, etc.)') - return one_indexed_epoch in validation_freq - - -def split_training_and_validation_data(x, y, sample_weights, validation_split): - """Split input data into train/eval section based on validation_split.""" - if has_symbolic_tensors(x): - raise ValueError('If your data is in the form of symbolic tensors, ' - 'you cannot use `validation_split`.') - if hasattr(x[0], 'shape'): - split_at = int(x[0].shape[0] * (1. - validation_split)) - else: - split_at = int(len(x[0]) * (1. - validation_split)) - x, val_x = (generic_utils.slice_arrays(x, 0, split_at), - generic_utils.slice_arrays(x, split_at)) - y, val_y = (generic_utils.slice_arrays(y, 0, split_at), - generic_utils.slice_arrays(y, split_at)) - if sample_weights: - sample_weights, val_sample_weights = ( - generic_utils.slice_arrays(sample_weights, 0, split_at), - generic_utils.slice_arrays(sample_weights, split_at), - ) - else: - val_sample_weights = None - return x, y, sample_weights, val_x, val_y, val_sample_weights - - -def unpack_validation_data(validation_data, raise_if_ambiguous=True): - """Unpack validation data based input type. - - The validation data is not touched if its dataset or dataset iterator. - For other type of input (Numpy or tensor), it will be unpacked into tuple of - 3 which is x, y and sample weights. - - Args: - validation_data: dataset, dataset iterator, or numpy, tensor tuple. - raise_if_ambiguous: boolean on whether to fail if validation_data cannot be - parsed. Otherwise simply return validation_data, None, None and defer the - decision to the caller. - - Returns: - tuple of 3, (x, y, sample_weights) for numpy and tensor input. 
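# Illustrative sketch, not part of the patch: `should_run_validation` treats an
# integer `validation_freq` as "every N epochs" and a list as explicit 1-indexed
# epochs. Assumes the helper keeps this name in the new training_utils_v1 module.
from tensorflow.python.keras.engine import training_utils_v1 as tu  # assumed location

# `epoch` is 0-indexed internally, so epoch=2 is the 3rd epoch.
tu.should_run_validation(validation_freq=3, epoch=2)       # True  (3 % 3 == 0)
tu.should_run_validation(validation_freq=3, epoch=3)       # False (4 % 3 != 0)
tu.should_run_validation(validation_freq=[1, 5], epoch=4)  # True  (epoch 5 is listed)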
- """ - if (isinstance(validation_data, (iterator_ops.Iterator, - iterator_ops.OwnedIterator, - dataset_ops.DatasetV2, - data_utils.Sequence)) - or not hasattr(validation_data, '__len__')): - val_x = validation_data - val_y = None - val_sample_weight = None - elif len(validation_data) == 2: - try: - val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence - val_sample_weight = None - except ValueError: - val_x, val_y, val_sample_weight = validation_data, None, None - elif len(validation_data) == 3: - try: - val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence - except ValueError: - val_x, val_y, val_sample_weight = validation_data, None, None - else: - if raise_if_ambiguous: - raise ValueError( - 'When passing a `validation_data` argument, ' - 'it must contain either 2 items (x_val, y_val), ' - 'or 3 items (x_val, y_val, val_sample_weights), ' - 'or alternatively it could be a dataset or a ' - 'dataset or a dataset iterator. ' - 'However we received `validation_data=%s`' % validation_data) - val_x, val_y, val_sample_weight = validation_data, None, None - return val_x, val_y, val_sample_weight - - -class TrainingLoop(object): - """TrainingLoop is a wrapper class around the training logic. - - This class is trying to encapsulate the different logic of fit/eval/predict - with regard to different data input and model condition. - - Note that TrainingLoop is stateless, which means it doesn't contain any - internal field and can be reused with different model and inputs. - """ - - def fit(self, - model, - x=None, - y=None, - batch_size=None, - epochs=1, - verbose=1, - callbacks=None, - validation_split=0., - validation_data=None, - shuffle=True, - class_weight=None, - sample_weight=None, - initial_epoch=0, - steps_per_epoch=None, - validation_steps=None, - validation_freq=1, - **kwargs): - """Train the model with the inputs and targets.""" - raise NotImplementedError() - - def evaluate(self, - model, - x=None, - y=None, - batch_size=None, - verbose=1, - sample_weight=None, - steps=None, - callbacks=None, - **kwargs): - """Returns the loss value & metrics values for the model in test mode.""" - raise NotImplementedError() - - def predict(self, - model, - x, - batch_size=None, - verbose=0, - steps=None, - callbacks=None, - **kwargs): - raise NotImplementedError() diff --git a/tensorflow/python/keras/engine/training_utils_v1.py b/tensorflow/python/keras/engine/training_utils_v1.py new file mode 100644 index 00000000000..c198bad1511 --- /dev/null +++ b/tensorflow/python/keras/engine/training_utils_v1.py @@ -0,0 +1,1827 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Training-related utilities.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import atexit +import collections +import functools +import multiprocessing.pool +import threading +import time + +import numpy as np +import six +from six.moves import zip # pylint: disable=redefined-builtin + +from tensorflow.core.framework import graph_pb2 +from tensorflow.python import tf2 +from tensorflow.python.data.experimental.ops import cardinality +from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.eager import context +from tensorflow.python.framework import composite_tensor_utils +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond +from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor_util +from tensorflow.python.keras import backend as K +from tensorflow.python.keras import callbacks as cbks +from tensorflow.python.keras import losses +from tensorflow.python.keras import metrics as metrics_module +from tensorflow.python.keras.utils import data_utils +from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.keras.utils import losses_utils +from tensorflow.python.keras.utils import tf_inspect +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import nest +from tensorflow.python.util.compat import collections_abc + + +@six.add_metaclass(abc.ABCMeta) +class Aggregator(object): + """Abstract base class used to aggregate batch-level outputs of a loop. + + Attributes: + use_steps: Whether the loop is using `step` or `batch_size`. + num_samples: Total number of samples: `batch_size * num_batches`. + steps: Total number of steps. + batch_size: Batch size. It is used for validation checks between inputs and + outputs. + results: What to return at the end of the aggregation loop. + """ + + def __init__(self, use_steps, num_samples=None, steps=None, batch_size=None): + self.use_steps = use_steps + self.num_samples = num_samples + self.steps = steps + self.batch_size = batch_size + self.results = [] + + @abc.abstractmethod + def create(self, batch_outs): + """Creates the initial results from the first batch outputs. + + Arguments: + batch_outs: A list of batch-level outputs. + """ + raise NotImplementedError('Must be implemented in subclasses.') + + @abc.abstractmethod + def aggregate(self, batch_outs, batch_start=None, batch_end=None): + """Aggregates batch-level results into total results. + + Arguments: + batch_outs: A list of batch-level outputs. + batch_start: The start index of this batch. Always `None` if `use_steps` + is `True`. + batch_end: The end index of this batch. Always `None` if `use_steps` is + `True`. 
+ """ + raise NotImplementedError('Must be implemented in subclasses.') + + @abc.abstractmethod + def finalize(self): + """Prepares the total results to be returned.""" + raise NotImplementedError('Must be implemented in subclasses.') + + +class MetricsAggregator(Aggregator): + """Aggregator that calculates loss and metrics info. + + Attributes: + use_steps: Whether the loop is using `step` or `batch_size`. + num_samples: Total number of samples: `batch_size*num_batches`. + steps: Total number of steps, ie number of times to iterate over a dataset + to cover all samples. + """ + + def __init__(self, use_steps, num_samples=None, steps=None): + super(MetricsAggregator, self).__init__( + use_steps=use_steps, + num_samples=num_samples, + steps=steps, + batch_size=None) + + def create(self, batch_outs): + self.results = [0.] * len(batch_outs) + + def aggregate(self, batch_outs, batch_start=None, batch_end=None): + # Loss. + if self.use_steps: + self.results[0] += batch_outs[0] + else: + self.results[0] += batch_outs[0] * (batch_end - batch_start) + # Metrics (always stateful, just grab current values.) + self.results[1:] = batch_outs[1:] + + def finalize(self): + if not self.results: + raise ValueError('Empty training data.') + self.results[0] /= (self.num_samples or self.steps) + + +class ConcatAggregator(Aggregator): + """Combine tensor-likes which cannot be merged on the fly. + + This class expects to aggregate a single tensor-like rather than a nested + structure of tensor-likes. + """ + + def __init__(self, batch_size): + self.composite = None + super(ConcatAggregator, self).__init__( + use_steps=True, num_samples=None, steps=None, batch_size=batch_size) + + def create(self, batch_element): + self.composite = composite_tensor_utils.is_composite_or_composite_value( + batch_element) + + def aggregate(self, batch_element, batch_start=None, batch_end=None): + + # TODO(psv): Add num_samples check here to detect when output batch + # #samples is < batch size and != input batch #samples. + if self.batch_size and self.batch_size < batch_element.shape[0]: + raise ValueError( + 'Mismatch between expected batch size and model output batch size. ' + 'Output shape = {}, expected output shape = shape {}'.format( + batch_element.shape, + (self.batch_size,) + batch_element.shape[1:])) + self.results.append(batch_element) + + def finalize(self): + # Special case of single batch inference which skips a copy. + if len(self.results) == 1: + self.results = self.results[0] + + elif self.composite: + # TODO(taylorrobie): efficiently concatenate. + results = self.results[0] + for r in self.results[1:]: + results = composite_tensor_utils.append_composite_tensor(results, r) + self.results = results + + else: + self.results = np.concatenate(self.results, axis=0) + + +_COPY_THREADS = 4 +_COPY_POOL = None + + +def get_copy_pool(): + """Shared threadpool for copying arrays. + + Pool instantiation takes ~ 2ms, so a singleton pool is used rather than + creating a pool per SliceAggregator. + + Returns: + The global copy threadpool. + """ + global _COPY_POOL + if _COPY_POOL is None: + _COPY_POOL = multiprocessing.pool.ThreadPool(_COPY_THREADS) + atexit.register(_COPY_POOL.close) + return _COPY_POOL + + +class SliceAggregator(Aggregator): + """Combine arrays where the final size is known. + + This class expects to aggregate a single tensor-like rather than a nested + structure of tensor-likes. 
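# Illustrative sketch, not part of the patch: the create/aggregate/finalize
# contract shared by the Aggregator subclasses above, driven here with
# MetricsAggregator and made-up per-batch outputs of the form [loss, metric].
from tensorflow.python.keras.engine import training_utils_v1 as tu  # added by this patch

agg = tu.MetricsAggregator(use_steps=False, num_samples=6)
batches = [([0.4, 0.9], 0, 4),   # (batch_outs, batch_start, batch_end)
           ([0.1, 0.8], 4, 6)]
for i, (outs, start, end) in enumerate(batches):
  if i == 0:
    agg.create(outs)             # allocate result slots from the first batch
  agg.aggregate(outs, start, end)
agg.finalize()
print(agg.results)  # roughly [0.3, 0.8]: loss is the sample-weighted mean,
                    # the metric is simply the last (stateful) value.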
+ + NumPy copies are an operation that threads handle quite well because all of + the heavy lifting is in c and does not need the GIL. Moreover, we can perform + lock-free writes to the same buffer in multiple threads because the nature of + result aggregation guarantees that either the indices are disjoint or the + aggregator will throw an exception in finalize. Moreover, because aggregation + is performed on the slowest varying dimension, assignments for a given batch + will write to contiguous blocks of memory, further minimizing contention. + + There is, however, some scheduling and context switching overhead which will + offset the gains from pipelining the slice assignment. Below a given threshold + it is faster to simply assign in the main thread rather than enqueue the + assignment in a side thread. The exact threshold will vary from system to + system, but the time is not very sensitive to the exact transition so a value + of 2 ** 14 was chosen which should be reasonable on most systems. + """ + + _BINARY_SIZE_THRESHOLD = 2 ** 14 + _MAX_COPY_SECONDS = 300 + + def __init__(self, num_samples, batch_size): + self._async_copies = [] + self._pool = get_copy_pool() + self._errors = [] + super(SliceAggregator, self).__init__( + use_steps=False, + num_samples=num_samples, + steps=None, + batch_size=batch_size) + + def create(self, batch_element): + # This step does not need to be pipelined because NumPy empty array + # initialization is effectively instantaneous. + shape = (self.num_samples,) + batch_element.shape[1:] + dtype = batch_element.dtype + + self.results = np.empty(shape=shape, dtype=dtype) + + def aggregate(self, batch_element, batch_start, batch_end): + # Fail early. + if self._errors: + six.reraise(type(self._errors[0]), self._errors[0]) + + # In the special case of single batch inference, no copy is needed. + if batch_end - batch_start == self.num_samples: + if self.num_samples != batch_element.shape[0]: + raise ValueError( + 'Mismatch between expected batch size and model output batch size. ' + 'Output shape = {}, expected output shape = shape {}'.format( + batch_element.shape, self.results.shape)) + + self.results = batch_element + return + + # This is an approximate threshold, so we don't need to consider the number + # of bytes per element. + num_elements = np.prod(batch_element.shape) + if num_elements < self._BINARY_SIZE_THRESHOLD: + self.results[batch_start:batch_end] = batch_element + else: + is_finished = threading.Event() + self._pool.apply_async( + self._slice_assign, + args=(batch_element, batch_start, batch_end, is_finished)) + self._async_copies.append(is_finished) + + def _slice_assign(self, batch_element, batch_start, batch_end, is_finished): + """Legacy utility method to slice input arrays.""" + try: + self.results[batch_start:batch_end] = batch_element + + except Exception as e: # pylint: disable=broad-except + # `_slice_assign` should only be called in threads and exceptions raised + # in threads do not carry over to the main thread. So instead we perform a + # a broad catch in the thread and then store the exception to be re-raised + # in the main thread. 
+ self._errors.append(e) + + finally: + is_finished.set() + + def finalize(self): + start_time = time.time() + for is_finished in self._async_copies: + timeout = max([0., self._MAX_COPY_SECONDS - (time.time() - start_time)]) + if not is_finished.wait(timeout): + raise ValueError('Timed out waiting for copy to complete.') + + if self._errors: + six.reraise(self._errors[0].__class__, self._errors[0]) + + +class OutputsAggregator(Aggregator): + """Aggregator that concatenates outputs.""" + + _structure = None + + def create(self, batch_outs): + # SparseTensorValue is a named tuple which nest will flatten, so we need + # to guard it to properly handle the structure. + self._structure = nest.get_traverse_shallow_structure( + lambda x: not composite_tensor_utils.is_composite_or_composite_value(x), + batch_outs) + batch_outs = nest.flatten_up_to(self._structure, batch_outs) + + for batch_element in batch_outs: + if composite_tensor_utils.is_composite_or_composite_value(batch_element): + # If the output is not a ndarray, it will be either a composite tensor + # or a composite tensor's Value object. In either case, we can't + # allocate an array to hold the object - we'll handle it later. + self.results.append(ConcatAggregator(self.batch_size)) + elif isinstance(batch_element, np.ndarray): + self.results.append( + (ConcatAggregator(self.batch_size) if self.use_steps else + SliceAggregator(self.num_samples, self.batch_size))) + else: + # This is not a ndarray, a CompositeTensor, or a CompositeTensorValue. + # Fail fast rather than trying to concatenate it. + raise RuntimeError('Attempted to aggregate unsupported object {}.' + .format(batch_element)) + + self.results[-1].create(batch_element) + + def aggregate(self, batch_outs, batch_start=None, batch_end=None): + batch_outs = nest.flatten_up_to(self._structure, batch_outs) + for batch_element, result in zip(batch_outs, self.results): + result.aggregate(batch_element, batch_start, batch_end) + + def finalize(self): + for result in self.results: + result.finalize() + self.results = [i.results for i in self.results] + self.results = nest.pack_sequence_as(self._structure, self.results) + + +def get_progbar(model, count_mode, include_metrics=True): + """Get Progbar.""" + if include_metrics: + stateful_metric_names = getattr(model, 'metrics_names', None) + if stateful_metric_names: + stateful_metric_names = stateful_metric_names[1:] # Exclude `loss` + else: + stateful_metric_names = None + return cbks.ProgbarLogger(count_mode, stateful_metrics=stateful_metric_names) + + +def check_num_samples(ins, batch_size=None, steps=None, steps_name='steps'): + """Determine the number of samples provided for training and evaluation. + + The number of samples is not defined when running with `steps`, + in which case the number of samples is set to `None`. + + Arguments: + ins: List of tensors to be fed to the Keras function. + batch_size: Integer batch size or `None` if not defined. + steps: Total number of steps (batches of samples) before declaring + `_predict_loop` finished. Ignored with the default value of `None`. + steps_name: The public API's parameter name for `steps`. + + Raises: + ValueError: when `steps` is `None` and the attribute `ins.shape` + does not exist. Also raises ValueError when `steps` is not `None` + and `batch_size` is not `None` because they are mutually + exclusive. + + Returns: + When steps is `None`, returns the number of samples to be + processed based on the size of the first dimension of the + first input numpy array. 
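# Illustrative sketch, not part of the patch: the size threshold described in
# SliceAggregator's docstring. Small batches are assigned inline in the main
# thread; larger ones are handed to the shared copy pool. Array sizes are
# hypothetical.
import numpy as np

_BINARY_SIZE_THRESHOLD = 2 ** 14   # mirrors SliceAggregator._BINARY_SIZE_THRESHOLD

small_batch = np.zeros((32, 10))     # 320 elements    -> copied inline
large_batch = np.zeros((256, 1024))  # 262144 elements -> copied via the thread pool
for batch in (small_batch, large_batch):
  use_pool = np.prod(batch.shape) >= _BINARY_SIZE_THRESHOLD
  print(batch.shape, 'thread pool' if use_pool else 'inline')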
When steps is not `None` and + `batch_size` is `None`, returns `None`. + """ + if steps is not None and batch_size is not None: + raise ValueError('If ' + steps_name + + ' is set, the `batch_size` must be None.') + if check_steps_argument(ins, steps, steps_name): + return None + + if hasattr(ins[0], 'shape'): + return int(ins[0].shape[0]) + return None # Edge case where ins == [static_learning_phase] + + +def standardize_single_array(x, expected_shape=None): + """Expand data of shape (x,) to (x, 1), unless len(expected_shape)==1.""" + if x is None: + return None + + if composite_tensor_utils.is_composite_or_composite_value(x): + return x + + if isinstance(x, int): + raise ValueError( + 'Expected an array data type but received an integer: {}'.format(x)) + + if (x.shape is not None and len(x.shape) == 1 and + (expected_shape is None or len(expected_shape) != 1)): + if tensor_util.is_tensor(x): + x = array_ops.expand_dims(x, axis=1) + else: + x = np.expand_dims(x, 1) + return x + + +def standardize_input_data(data, + names, + shapes=None, + check_batch_axis=True, + exception_prefix=''): + """Normalizes inputs and targets provided by users. + + Users may pass data as a list of arrays, dictionary of arrays, + or as a single array. We normalize this to an ordered list of + arrays (same order as `names`), while checking that the provided + arrays have shapes that match the network's expectations. + + Arguments: + data: User-provided input data (polymorphic). + names: List of expected array names. + shapes: Optional list of expected array shapes. + check_batch_axis: Boolean; whether to check that the batch axis of the + arrays matches the expected value found in `shapes`. + exception_prefix: String prefix used for exception formatting. + + Returns: + List of standardized input arrays (one array per model input). + + Raises: + ValueError: in case of improperly formatted user-provided data. + """ + try: + data_len = len(data) + except TypeError: + # For instance if data is `None` or a symbolic Tensor. + data_len = None + + if not names: + if data_len and not isinstance(data, dict): + raise ValueError( + 'Error when checking model ' + exception_prefix + ': ' + 'expected no data, but got:', data) + return [] + if data is None: + return [None for _ in range(len(names))] + + if isinstance(data, dict): + try: + data = [ + data[x].values + if data[x].__class__.__name__ == 'DataFrame' else data[x] + for x in names + ] + except KeyError as e: + raise ValueError('No data provided for "' + e.args[0] + '". Need data ' + 'for each key in: ' + str(names)) + elif isinstance(data, (list, tuple)): + if isinstance(data[0], (list, tuple)): + data = [np.asarray(d) for d in data] + elif len(names) == 1 and isinstance(data[0], (float, int)): + data = [np.asarray(data)] + else: + data = [ + x.values if x.__class__.__name__ == 'DataFrame' else x for x in data + ] + else: + data = data.values if data.__class__.__name__ == 'DataFrame' else data + data = [data] + + if shapes is not None: + data = [ + standardize_single_array(x, shape) for (x, shape) in zip(data, shapes) + ] + else: + data = [standardize_single_array(x) for x in data] + + if len(data) != len(names): + if data and hasattr(data[0], 'shape'): + raise ValueError('Error when checking model ' + exception_prefix + + ': the list of Numpy arrays that you are passing to ' + 'your model is not the size the model expected. 
' + 'Expected to see ' + str(len(names)) + ' array(s), ' + + 'for inputs ' + str(names) + ' but instead got the ' + 'following list of ' + str(len(data)) + ' arrays: ' + + str(data)[:200] + '...') + elif len(names) > 1: + raise ValueError('Error when checking model ' + exception_prefix + + ': you are passing a list as input to your model, ' + 'but the model expects a list of ' + str(len(names)) + + ' Numpy arrays instead. The list you passed was: ' + + str(data)[:200]) + elif len(data) == 1 and not hasattr(data[0], 'shape'): + raise TypeError('Error when checking model ' + exception_prefix + + ': data should be a Numpy array, or list/dict of ' + 'Numpy arrays. Found: ' + str(data)[:200] + '...') + elif len(names) == 1: + data = [np.asarray(data)] + + # Check shapes compatibility. + if shapes: + for i in range(len(names)): + if shapes[i] is not None: + if tensor_util.is_tensor(data[i]): + tensorshape = data[i].shape + if not tensorshape: + continue + data_shape = tuple(tensorshape.as_list()) + elif composite_tensor_utils.is_composite_or_composite_value(data[i]): + tensorshape = composite_tensor_utils.get_shape(data[i]) + data_shape = tuple(tensorshape.as_list()) + else: + data_shape = data[i].shape + + shape = shapes[i] + if len(data_shape) != len(shape): + raise ValueError('Error when checking ' + exception_prefix + + ': expected ' + names[i] + ' to have ' + + str(len(shape)) + ' dimensions, but got array ' + 'with shape ' + str(data_shape)) + if not check_batch_axis: + data_shape = data_shape[1:] + shape = shape[1:] + for dim, ref_dim in zip(data_shape, shape): + if ref_dim != dim and ref_dim is not None and dim is not None: + raise ValueError('Error when checking ' + exception_prefix + + ': expected ' + names[i] + ' to have shape ' + + str(shape) + ' but got array with shape ' + + str(data_shape)) + return data + + +def standardize_sample_or_class_weights(x_weight, output_names, weight_type): + """Maps `sample_weight` or `class_weight` to model outputs. + + Arguments: + x_weight: User-provided `sample_weight` or `class_weight` argument. + output_names: List of output names (strings) in the model. + weight_type: A string used purely for exception printing. + + Returns: + A list of `sample_weight` or `class_weight` where there are exactly + one element per model output. + + Raises: + ValueError: In case of invalid user-provided argument. + """ + if x_weight is None or (isinstance(x_weight, (list, tuple)) and + len(x_weight) == 0): # pylint: disable=g-explicit-length-test + return [None for _ in output_names] + if len(output_names) == 1: + if isinstance(x_weight, (list, tuple)) and len(x_weight) == 1: + return x_weight + if isinstance(x_weight, dict) and output_names[0] in x_weight: + return [x_weight[output_names[0]]] + else: + return [x_weight] + if isinstance(x_weight, (list, tuple)): + if len(x_weight) != len(output_names): + raise ValueError('Provided `' + weight_type + '` was a list of ' + + str(len(x_weight)) + ' elements, but the model has ' + + str(len(output_names)) + ' outputs. ' + 'You should provide one `' + weight_type + '`' + 'array per model output.') + return x_weight + if isinstance(x_weight, collections_abc.Mapping): + generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names) + x_weights = [] + for name in output_names: + x_weights.append(x_weight.get(name)) + return x_weights + else: + raise TypeError('The model has multiple outputs, so `' + weight_type + '` ' + 'should be either a list or a dict. 
' + 'Provided `' + weight_type + '` type not understood: ' + + str(x_weight)) + + +def standardize_class_weights(class_weight, output_names): + return standardize_sample_or_class_weights(class_weight, output_names, + 'class_weight') + + +def standardize_sample_weights(sample_weight, output_names): + return standardize_sample_or_class_weights(sample_weight, output_names, + 'sample_weight') + + +def check_array_lengths(inputs, targets, weights=None): + """Does user input validation for numpy arrays. + + Arguments: + inputs: list of Numpy arrays of inputs. + targets: list of Numpy arrays of targets. + weights: list of Numpy arrays of sample weights. + + Raises: + ValueError: in case of incorrectly formatted data. + """ + + def is_tensor_or_composite_tensor(x): + return tensor_util.is_tensor( + x) or composite_tensor_utils.is_composite_or_composite_value(x) + + def set_of_lengths(x): + # Returns a set with the variation between + # different shapes, with None => 0 + if x is None: + return {} + else: + return set([ + y.shape[0] + for y in x + if y is not None and not is_tensor_or_composite_tensor(y) + ]) + + set_x = set_of_lengths(inputs) + set_y = set_of_lengths(targets) + set_w = set_of_lengths(weights) + if len(set_x) > 1: + raise ValueError('All input arrays (x) should have ' + 'the same number of samples. Got array shapes: ' + + str([x.shape for x in inputs])) + if len(set_y) > 1: + raise ValueError('All target arrays (y) should have ' + 'the same number of samples. Got array shapes: ' + + str([y.shape for y in targets])) + if set_x and set_y and list(set_x)[0] != list(set_y)[0]: + raise ValueError('Input arrays should have ' + 'the same number of samples as target arrays. ' + 'Found ' + str(list(set_x)[0]) + ' input samples ' + 'and ' + str(list(set_y)[0]) + ' target samples.') + if len(set_w) > 1: + raise ValueError('All sample_weight arrays should have ' + 'the same number of samples. Got array shapes: ' + + str([w.shape for w in weights])) + if set_y and set_w and list(set_y)[0] != list(set_w)[0]: + raise ValueError('Sample_weight arrays should have ' + 'the same number of samples as target arrays. Got ' + + str(list(set_y)[0]) + ' input samples and ' + + str(list(set_w)[0]) + ' target samples.') + + +def check_loss_and_target_compatibility(targets, loss_fns, output_shapes): + """Does validation on the compatibility of targets and loss functions. + + This helps prevent users from using loss functions incorrectly. This check + is purely for UX purposes. + + Arguments: + targets: list of Numpy arrays of targets. + loss_fns: list of loss functions. + output_shapes: list of shapes of model outputs. + + Raises: + ValueError: if a loss function or target array + is incompatible with an output. + """ + key_loss_fns = { + losses.mean_squared_error, losses.binary_crossentropy, + losses.categorical_crossentropy + } + key_loss_classes = (losses.MeanSquaredError, losses.BinaryCrossentropy, + losses.CategoricalCrossentropy) + for y, loss, shape in zip(targets, loss_fns, output_shapes): + if y is None or loss is None or tensor_util.is_tensor(y): + continue + if losses.is_categorical_crossentropy(loss): + if y.shape[-1] == 1: + raise ValueError('You are passing a target array of shape ' + + str(y.shape) + + ' while using as loss `categorical_crossentropy`. ' + '`categorical_crossentropy` expects ' + 'targets to be binary matrices (1s and 0s) ' + 'of shape (samples, classes). 
' + 'If your targets are integer classes, ' + 'you can convert them to the expected format via:\n' + '```\n' + 'from keras.utils import to_categorical\n' + 'y_binary = to_categorical(y_int)\n' + '```\n' + '\n' + 'Alternatively, you can use the loss function ' + '`sparse_categorical_crossentropy` instead, ' + 'which does expect integer targets.') + + is_loss_wrapper = isinstance(loss, losses.LossFunctionWrapper) + if (isinstance(loss, key_loss_classes) or (is_loss_wrapper and + (loss.fn in key_loss_fns))): + for target_dim, out_dim in zip(y.shape[1:], shape[1:]): + if out_dim is not None and target_dim != out_dim: + loss_name = loss.name + if loss_name is None: + loss_type = loss.fn if is_loss_wrapper else type(loss) + loss_name = loss_type.__name__ + raise ValueError('A target array with shape ' + str(y.shape) + + ' was passed for an output of shape ' + str(shape) + + ' while using as loss `' + loss_name + '`. ' + 'This loss expects targets to have the same shape ' + 'as the output.') + + +def collect_per_output_metric_info(metrics, + output_names, + output_shapes, + loss_fns, + is_weighted=False): + """Maps metric names and functions to model outputs. + + Arguments: + metrics: a list or a list of lists or a dict of metric functions. + output_names: a list of the names (strings) of model outputs. + output_shapes: a list of the shapes (strings) of model outputs. + loss_fns: a list of the loss functions corresponding to the model outputs. + is_weighted: Boolean indicating whether the given metrics are weighted. + + Returns: + A list (one entry per model output) of dicts. + For instance, if the model has 2 outputs, and for the first output + we want to compute "binary_accuracy" and "binary_crossentropy", + and just "binary_accuracy" for the second output, + the list would look like: `[{ + 'acc': binary_accuracy(), + 'ce': binary_crossentropy(), + }, { + 'acc': binary_accuracy(), + }]` + + Raises: + TypeError: if an incorrect type is passed for the `metrics` argument. + """ + if not metrics: + return [{} for _ in output_names] + + if isinstance(metrics, list): + any_sub_list = any(isinstance(m, list) for m in metrics) + if any_sub_list: + if len(metrics) != len(output_names): + raise ValueError('When passing a list of lists as `metrics`, ' + 'it should have one entry per model output. ' + 'The model has ' + str(len(output_names)) + + ' outputs, but you passed metrics=' + str(metrics)) + # User has provided a list of len = len(outputs). + nested_metrics = [generic_utils.to_list(m) for m in metrics] + else: + # If it is a single list we then apply all metrics to all outputs. + if len(output_names) > 1: + nested_metrics = [] + for _ in output_names: + nested_metrics.append( + [metrics_module.clone_metric(m) for m in metrics]) + else: + nested_metrics = [metrics] + elif isinstance(metrics, collections_abc.Mapping): + generic_utils.check_for_unexpected_keys('metrics', metrics, output_names) + nested_metrics = [] + for name in output_names: + output_metrics = generic_utils.to_list(metrics.get(name, [])) + nested_metrics.append(output_metrics) + else: + raise TypeError('Type of `metrics` argument not understood. 
' + 'Expected a list or dictionary, found: ' + str(metrics)) + + per_output_metrics = [] + for i, metrics in enumerate(nested_metrics): + metrics_dict = collections.OrderedDict() + for metric in metrics: + metric_name = get_metric_name(metric, is_weighted) + metric_fn = get_metric_function( + metric, output_shape=output_shapes[i], loss_fn=loss_fns[i]) + + # If the metric function is not stateful, we create a stateful version. + if not isinstance(metric_fn, metrics_module.Metric): + metric_fn = metrics_module.MeanMetricWrapper( + metric_fn, name=metric_name) + metrics_dict[metric_name] = metric_fn + per_output_metrics.append(metrics_dict) + + return per_output_metrics + + +def batch_shuffle(index_array, batch_size): + """Shuffles an array in a batch-wise fashion. + + Useful for shuffling HDF5 arrays + (where one cannot access arbitrary indices). + + Arguments: + index_array: array of indices to be shuffled. + batch_size: integer. + + Returns: + The `index_array` array, shuffled in a batch-wise fashion. + """ + batch_count = int(len(index_array) / batch_size) + # to reshape we need to be cleanly divisible by batch size + # we stash extra items and reappend them after shuffling + last_batch = index_array[batch_count * batch_size:] + index_array = index_array[:batch_count * batch_size] + index_array = index_array.reshape((batch_count, batch_size)) + np.random.shuffle(index_array) + index_array = index_array.flatten() + return np.append(index_array, last_batch) + + +def standardize_weights(y, + sample_weight=None, + class_weight=None, + sample_weight_mode=None): + """Performs sample weight validation and standardization. + + Everything gets normalized to a single sample-wise (or timestep-wise) + weight array. If both `sample_weight` and `class_weight` are provided, + the weights are multiplied. + + Arguments: + y: Numpy array or Tensor of model targets to be weighted. + sample_weight: User-provided `sample_weight` argument. + class_weight: User-provided `class_weight` argument. + sample_weight_mode: One of `None` or `"temporal"`. `"temporal"` indicated + that we expect 2D weight data that will be applied to the last 2 + dimensions of the targets (i.e. we are weighting timesteps, not + samples). + + Returns: + A numpy array of target weights, one entry per sample to weight. + + Raises: + ValueError: In case of invalid user-provided arguments. + """ + # Iterator may return sample_weight as 1-tuple + if isinstance(sample_weight, tuple): + sample_weight = sample_weight[0] + if sample_weight_mode is not None and sample_weight_mode != 'samplewise': + if sample_weight_mode != 'temporal': + raise ValueError('"sample_weight_mode ' + 'should be None or "temporal". ' + 'Found: ' + str(sample_weight_mode)) + if len(y.shape) < 3: + raise ValueError('Found a sample_weight array for ' + 'an input with shape ' + str(y.shape) + '. ' + 'Timestep-wise sample weighting (use of ' + 'sample_weight_mode="temporal") is restricted to ' + 'outputs that are at least 3D, i.e. that have ' + 'a time dimension.') + if sample_weight is not None and len(sample_weight.shape) != 2: + raise ValueError('Found a sample_weight array with shape ' + + str(sample_weight.shape) + '. ' + 'In order to use timestep-wise sample weighting, ' + 'you should pass a 2D sample_weight array.') + else: + if sample_weight is not None and len(sample_weight.shape) != 1: + raise ValueError( + 'Found a sample_weight array with shape {}. 
In order to ' + 'use timestep-wise sample weights, you should specify ' + 'sample_weight_mode="temporal" in compile(); founssd "{}" ' + 'instead. If you just mean to use sample-wise weights, ' + 'make sure your sample_weight array is 1D.'.format( + sample_weight.shape, sample_weight_mode)) + + if sample_weight is not None: + if len(sample_weight.shape) > len(y.shape): + raise ValueError('Found a sample_weight with shape' + + str(sample_weight.shape) + '.' + 'Expected sample_weight with rank ' + 'less than or equal to ' + str(len(y.shape))) + + if (not tensor_util.is_tensor(sample_weight) and + y.shape[:sample_weight.ndim] != sample_weight.shape): + raise ValueError('Found a sample_weight array with shape ' + + str(sample_weight.shape) + ' for an input with shape ' + + str(y.shape) + '. ' + 'sample_weight cannot be broadcast.') + + # Class weights applied per-sample. + class_sample_weight = None + if isinstance(class_weight, dict): + if len(y.shape) > 2: + raise ValueError('`class_weight` not supported for ' + '3+ dimensional targets.') + + if tensor_util.is_tensor(y): + # Few classes are expected, so densifying is reasonable. + keys = np.array(sorted(class_weight.keys())) + values = np.array([class_weight[i] for i in keys]) + weight_vector = np.zeros(np.max(keys) + 1) + weight_vector[:] = np.nan + weight_vector[keys] = values + + y_classes = smart_cond.smart_cond( + len(y.shape.as_list()) == 2 and K.shape(y)[1] > 1, + lambda: K.argmax(y, axis=1), + lambda: math_ops.cast(K.reshape(y, (-1,)), dtypes.int64)) + class_sample_weight = array_ops.gather(weight_vector, y_classes) + gen_array_ops.check_numerics( + class_sample_weight, + 'Invalid classes or class weights detected. NaN values indicate that ' + 'an appropriate class weight could not be determined.') + class_sample_weight = math_ops.cast(class_sample_weight, K.floatx()) + if sample_weight is not None: + sample_weight = math_ops.cast( + ops.convert_to_tensor_v2_with_dispatch(sample_weight), K.floatx()) + else: + y_classes = y + if len(y.shape) == 2: + if y.shape[1] > 1: + y_classes = np.argmax(y, axis=1) + elif y.shape[1] == 1: + y_classes = np.reshape(y, y.shape[0]) + + class_sample_weight = np.asarray( + [class_weight[cls] for cls in y_classes if cls in class_weight]) + + if len(class_sample_weight) != len(y_classes): + # subtract the sets to pick all missing classes + existing_classes = set(y_classes) + existing_class_weight = set(class_weight.keys()) + raise ValueError( + '`class_weight` must contain all classes in the data.' + ' The classes %s exist in the data but not in ' + '`class_weight`.' % (existing_classes - existing_class_weight)) + + if class_sample_weight is not None and sample_weight is not None: + # Multiply weights if both are provided. + return class_sample_weight * sample_weight + if sample_weight is not None: + return sample_weight + if class_sample_weight is not None: + return class_sample_weight + return None + + +def has_symbolic_tensors(ls): + if context.executing_eagerly(): + return False + return has_tensors(ls) + + +def has_tensors(ls): + """Returns true if `ls` contains tensors.""" + # Note: at some point in time ragged tensors didn't count as tensors, so this + # returned false for ragged tensors. Making this return true fails some tests + # which would then require a steps_per_epoch argument. 
+ if isinstance(ls, (list, tuple)): + return any( + tensor_util.is_tensor(v) and + not isinstance(v, ragged_tensor.RaggedTensor) for v in ls) + if isinstance(ls, dict): + return any( + tensor_util.is_tensor(v) and + not isinstance(v, ragged_tensor.RaggedTensor) + for _, v in six.iteritems(ls)) + return tensor_util.is_tensor(ls) and not isinstance( + ls, ragged_tensor.RaggedTensor) + + +def get_metric_name(metric, weighted=False): + """Returns the name corresponding to the given metric input. + + Arguments: + metric: Metric function name or reference. + weighted: Boolean indicating if the given metric is weighted. + + Returns: + The metric name. + """ + if tf2.enabled(): + # We keep the string that the user has set in compile as the metric name. + if isinstance(metric, six.string_types): + return metric + + metric = metrics_module.get(metric) + return metric.name if hasattr(metric, 'name') else metric.__name__ + else: + metric_name_prefix = 'weighted_' if weighted else '' + if metric in ('accuracy', 'acc', 'crossentropy', 'ce'): + if metric in ('accuracy', 'acc'): + suffix = 'acc' + elif metric in ('crossentropy', 'ce'): + suffix = 'ce' + else: + metric_fn = metrics_module.get(metric) + # Get metric name as string + if hasattr(metric_fn, 'name'): + suffix = metric_fn.name + else: + suffix = metric_fn.__name__ + metric_name = metric_name_prefix + suffix + return metric_name + + +def get_metric_function(metric, output_shape=None, loss_fn=None): + """Returns the metric function corresponding to the given metric input. + + Arguments: + metric: Metric function name or reference. + output_shape: The shape of the output that this metric will be calculated + for. + loss_fn: The loss function used. + + Returns: + The metric function. + """ + if metric not in ['accuracy', 'acc', 'crossentropy', 'ce']: + return metrics_module.get(metric) + + is_sparse_categorical_crossentropy = ( + isinstance(loss_fn, losses.SparseCategoricalCrossentropy) or + (isinstance(loss_fn, losses.LossFunctionWrapper) and + loss_fn.fn == losses.sparse_categorical_crossentropy)) + + is_binary_crossentropy = ( + isinstance(loss_fn, losses.BinaryCrossentropy) or + (isinstance(loss_fn, losses.LossFunctionWrapper) and + loss_fn.fn == losses.binary_crossentropy)) + + if metric in ['accuracy', 'acc']: + if output_shape[-1] == 1 or is_binary_crossentropy: + return metrics_module.binary_accuracy + elif is_sparse_categorical_crossentropy: + return metrics_module.sparse_categorical_accuracy + # If the output_shape[-1] is not 1, then we know output is `categorical`. + # We assume it is sparse categorical only if loss is explicitly given + # as sparse categorical crossentropy loss. + return metrics_module.categorical_accuracy + else: + if output_shape[-1] == 1 or is_binary_crossentropy: + return metrics_module.binary_crossentropy + elif is_sparse_categorical_crossentropy: + return metrics_module.sparse_categorical_crossentropy + return metrics_module.categorical_crossentropy + + +def call_metric_function(metric_fn, + y_true, + y_pred=None, + weights=None, + mask=None): + """Invokes metric function and returns the metric result tensor.""" + if mask is not None: + mask = math_ops.cast(mask, y_pred.dtype) + if weights is None: + # Use mask as sample weight. + weights = mask + else: + # Update dimensions of weights to match with mask. 
+ weights = math_ops.cast(weights, dtype=y_pred.dtype) + mask, _, weights = losses_utils.squeeze_or_expand_dimensions( + mask, sample_weight=weights) + weights *= mask + + if y_pred is not None: + return metric_fn(y_true, y_pred, sample_weight=weights) + # `Mean` metric only takes a single value. + return metric_fn(y_true, sample_weight=weights) + + +def get_loss_function(loss): + """Returns the loss corresponding to the loss input in `compile` API.""" + if loss is None or isinstance(loss, losses.Loss): + return loss + + if tf_inspect.isclass(loss) and issubclass(loss, losses.Loss): + # It is not safe to assume that the loss takes no constructor arguments. + raise ValueError( + 'Received uninstantiated Loss class: {}\nPlease call loss ""classes ' + 'before passing them to Model.compile.'.format(loss)) + + # Deserialize loss configuration, if needed. + if isinstance(loss, collections_abc.Mapping): + loss = losses.get(loss) + + # Custom callable class. + if callable(loss) and not hasattr(loss, '__name__'): + return loss + + # Wrap loss function with signature `(y_true, y_pred, **kwargs)` + # in `LossFunctionWrapper` class. + loss_fn = losses.get(loss) + + # For losses which are given as strings/functions in the compile API, + # we always set the loss reduction type to be `SUM_OVER_BATCH_SIZE` + # (both in distribution strategy context and otherwise). + return losses.LossFunctionWrapper( + loss_fn, + name=loss_fn.__name__, + reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE) + + +def validate_dataset_input(x, y, sample_weight, validation_split=None): + """Validates user input arguments when a dataset iterator is passed. + + Arguments: + x: Input data. A `tf.data` dataset or iterator. + y: Target data. It could be either Numpy array(s) or TensorFlow tensor(s). + Expected to be `None` when `x` is a dataset iterator. + sample_weight: An optional sample-weight array passed by the user to weight + the importance of each sample in `x`. Expected to be `None` when `x` is a + dataset iterator + validation_split: Float between 0 and 1. Fraction of the training data to be + used as validation data. Expected to be `None` when `x` is a dataset + iterator. + + Raises: + ValueError: if argument `y` or `sample_weight` or `validation_split` are + provided by user. + """ + if y is not None: + raise ValueError('You passed a dataset or dataset iterator (%s) as ' + 'input `x` to your model. In that case, you should ' + 'not specify a target (`y`) argument, since the dataset ' + 'or dataset iterator generates both input data and ' + 'target data. ' + 'Received: %s' % (x, y)) + if sample_weight is not None: + raise ValueError('`sample_weight` argument is not supported when input ' + '`x` is a dataset or a dataset iterator. Instead, you' + 'can provide sample_weight as the third element of your' + 'dataset, i.e. (inputs, targets, sample_weight). ' + 'Received: x=%s, sample_weight=%s' % (x, sample_weight)) + if validation_split is not None and validation_split != 0.0: + raise ValueError( + '`validation_split` argument is not supported when ' + 'input `x` is a dataset or a dataset iterator. 
' + 'Received: x=%s, validation_split=%f' % (x, validation_split)) + + +def validate_input_types(inp, orig_inp, allow_dict=True, field_name='inputs'): + """Helper function to validate either inputs or targets.""" + if isinstance(inp, (list, tuple)): + if not all(isinstance(v, np.ndarray) or + tensor_util.is_tensor(v) for v in inp): + raise ValueError( + 'Please provide as model inputs either a single array or a list of ' + 'arrays. You passed: {}={}'.format(field_name, str(orig_inp))) + elif isinstance(inp, dict): + if not allow_dict: + raise ValueError( + 'You cannot pass a dictionary as model {}.'.format(field_name)) + elif not isinstance(inp, np.ndarray) and not tensor_util.is_tensor(inp): + raise ValueError( + 'Please provide as model inputs either a single array or a list of ' + 'arrays. You passed: {}={}'.format(field_name, orig_inp)) + + +def check_generator_arguments(y=None, sample_weight=None, + validation_split=None): + """Validates arguments passed when using a generator.""" + if y is not None: + raise ValueError('`y` argument is not supported when data is' + 'a generator or Sequence instance. Instead pass targets' + ' as the second element of the generator.') + if sample_weight is not None: + raise ValueError('`sample_weight` argument is not supported when data is' + 'a generator or Sequence instance. Instead pass sample' + ' weights as the third element of the generator.') + if validation_split: + raise ValueError('If your data is in the form of a Python generator, ' + 'you cannot use `validation_split`.') + + +def check_steps_argument(input_data, steps, steps_name): + """Validates `steps` argument based on input data's type. + + The cases when `steps` value must be provided are when + 1. input data passed is an iterator. + 2. model was built on top of symbolic tensors, input data is not + required and is `None`. + 3. input data passed is a symbolic tensor. + + Arguments: + input_data: Input data. Can be Numpy array(s) or TensorFlow tensor(s) or + tf.data.Dataset iterator or `None`. + steps: Integer or `None`. Total number of steps (batches of samples) to + execute. + steps_name: The public API's parameter name for `steps`. + + Returns: + boolean, True if `steps` argument is required, else False. + + Raises: + ValueError: if `steps` argument is required for given input data type + but not provided. + """ + is_x_iterator = isinstance( + input_data, (iterator_ops.Iterator, iterator_ops.OwnedIterator)) + if (input_data is None or is_x_iterator or has_symbolic_tensors(input_data) or + (isinstance(input_data, list) and not input_data)): + if steps is None: + input_type_str = 'a Dataset iterator' if is_x_iterator else 'data tensors' + raise ValueError('When using {input_type} as input to a model, you should' + ' specify the `{steps_name}` argument.'.format( + input_type=input_type_str, steps_name=steps_name)) + return True + + if isinstance(input_data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): + return True + + if steps is not None: + list_types = (np.ndarray, list, tuple) + if (isinstance(input_data, list_types) or + (isinstance(input_data, dict) and + any(isinstance(v, list_types) for v in input_data.values()))): + logging.warning('When passing input data as arrays, do not specify ' + '`steps_per_epoch`/`steps` argument. 
' + 'Please use `batch_size` instead.') + return False + + +def cast_single_tensor(x, dtype=None): + if isinstance(x, np.ndarray): + x = ops.convert_to_tensor_v2_with_dispatch(x) + dtype = dtype or K.floatx() + if x.dtype.is_floating: + return math_ops.cast(x, dtype=dtype) + return x + + +def cast_if_floating_dtype_and_mismatch(targets, outputs): + """Returns target data tensors using correct datatype. + + Checks that each target and output pair are the same datatype. If not, casts + the target to the output's datatype. + + Args: + targets: tensor or list of targets. + outputs: tensor or list of outputs. + + Returns: + Targets in appropriate datatype. + """ + if tensor_util.is_tensor(targets): + # There is one target, so output[0] should be the only output. + return cast_single_tensor(targets, dtype=outputs[0].dtype) + new_targets = [] + for target, out in zip(targets, outputs): + if isinstance(target, np.ndarray): + target = ops.convert_to_tensor_v2_with_dispatch(target) + if target.dtype != out.dtype: + new_targets.append(cast_single_tensor(target, dtype=out.dtype)) + else: + new_targets.append(target) + return new_targets + + +def cast_if_floating_dtype(x, dtype=None): + """Casts the given data tensors to the default floating point type. + + Casts only if the input is already a floating point type. + Args: + x: tensor or list/tuple of tensors. + dtype: The dtype to which Tensors should be cast. + + Returns: + Converted input. + """ + return nest.map_structure(functools.partial(cast_single_tensor, dtype=dtype), + x) + + +def cast_to_model_input_dtypes(x, model): + """Casts the given data tensors to the dtypes of the model inputs. + + Args: + x: tensor or list/tuple of tensors. + model: The model. + + Returns: + Converted input. Each tensor is casted to the corresponding input in + `model.inputs`. + """ + input_dtypes = nest.map_structure(lambda t: t.dtype, model.inputs) + return nest.map_structure(math_ops.cast, x, input_dtypes) + + +def prepare_sample_weight_modes(training_endpoints, sample_weight_mode): + """Prepares sample weight modes for the model. + + Args: + training_endpoints: List of model _TrainingEndpoints. + sample_weight_mode: sample weight mode user input passed from compile API. + + Raises: + ValueError: In case of invalid `sample_weight_mode` input. + """ + + if isinstance(sample_weight_mode, collections_abc.Mapping): + generic_utils.check_for_unexpected_keys( + 'sample_weight_mode', sample_weight_mode, + [e.output_name for e in training_endpoints]) + + for end_point in training_endpoints: + if not end_point.should_skip_target_weights(): + if end_point.output_name not in sample_weight_mode: + raise ValueError('Output ' + end_point.output_name + + 'missing from `_sample_weight_modes` dictionary') + else: + end_point.sample_weight_mode = sample_weight_mode.get( + end_point.output_name) + elif isinstance(sample_weight_mode, (list, tuple)): + if len(sample_weight_mode) != len(training_endpoints): + raise ValueError('When passing a list as sample_weight_mode, ' + 'it should have one entry per model output. 
' + 'The model has ' + str(len(training_endpoints)) + + ' outputs, but you passed ' + + str(len(sample_weight_mode)) + '_sample_weight_modes.') + for mode, endpoint in zip(sample_weight_mode, training_endpoints): + if not endpoint.should_skip_target_weights(): + endpoint.sample_weight_mode = mode + else: + for endpoint in training_endpoints: + if not endpoint.should_skip_target_weights(): + endpoint.sample_weight_mode = sample_weight_mode + + +def prepare_loss_functions(loss, output_names): + """Converts loss to a list of loss functions. + + Arguments: + loss: String (name of objective function), objective function or + `tf.losses.Loss` instance. See `tf.losses`. If the model has multiple + outputs, you can use a different loss on each output by passing a + dictionary or a list of losses. The loss value that will be minimized by + the model will then be the sum of all individual losses. + output_names: List of model output names. + + Returns: + A list of loss objective functions. + + Raises: + ValueError: If loss is a dict with keys not in model output names, + or if loss is a list with len not equal to model outputs. + """ + if isinstance(loss, collections_abc.Mapping): + generic_utils.check_for_unexpected_keys('loss', loss, output_names) + loss_functions = [] + for name in output_names: + if name not in loss: + logging.warning( + 'Output {0} missing from loss dictionary. We assume ' + 'this was done on purpose. The fit and evaluate APIs will not be ' + 'expecting any data to be passed to {0}.'.format(name)) + loss_functions.append(get_loss_function(loss.get(name, None))) + elif isinstance(loss, six.string_types): + loss_functions = [get_loss_function(loss) for _ in output_names] + elif isinstance(loss, collections_abc.Sequence): + if len(loss) != len(output_names): + raise ValueError('When passing a list as loss, it should have one entry ' + 'per model outputs. The model has {} outputs, but you ' + 'passed loss={}'.format(len(output_names), loss)) + loss_functions = nest.map_structure(get_loss_function, loss) + else: + loss_functions = [get_loss_function(loss) for _ in range(len(output_names))] + + return loss_functions + + +def prepare_loss_weights(training_endpoints, loss_weights=None): + """Converts loss weights to a list of loss weights. + + The result loss weights will be populated on the training endpoint. + + Arguments: + training_endpoints: List of model training endpoints. + loss_weights: Optional list or dictionary specifying scalar coefficients + (Python floats) to weight the loss contributions of different model + outputs. The loss value that will be minimized by the model will then be + the *weighted sum* of all individual losses, weighted by the + `loss_weights` coefficients. If a list, it is expected to have a 1:1 + mapping to the model's outputs. If a dict, it is expected to map + output names (strings) to scalar coefficients. + + Raises: + ValueError: If loss weight is a dict with key not in model output names, + or if loss is a list with len not equal to model outputs. + """ + if loss_weights is None: + for e in training_endpoints: + e.loss_weight = 1. + elif isinstance(loss_weights, collections_abc.Mapping): + generic_utils.check_for_unexpected_keys( + 'loss_weights', loss_weights, + [e.output_name for e in training_endpoints]) + for e in training_endpoints: + e.loss_weight = loss_weights.get(e.output_name, 1.) 
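[Editorial note, not part of the patch: `prepare_loss_functions` and `prepare_loss_weights` above implement the per-output `loss`/`loss_weights` semantics of the public `compile` API. A minimal illustration of that public usage follows; the model, output names and coefficients are made up for the example.]

```python
import tensorflow as tf

# Two-output functional model; the output layer names become the output names.
inputs = tf.keras.Input(shape=(8,))
shared = tf.keras.layers.Dense(16, activation='relu')(inputs)
out_a = tf.keras.layers.Dense(1, name='score')(shared)
out_b = tf.keras.layers.Dense(3, activation='softmax', name='label')(shared)
model = tf.keras.Model(inputs, [out_a, out_b])

# Dicts map output names to losses and scalar weights; the total loss minimized
# is the weighted sum of the per-output losses, as the helpers above describe.
model.compile(
    optimizer='adam',
    loss={'score': 'mse', 'label': 'sparse_categorical_crossentropy'},
    loss_weights={'score': 0.5, 'label': 1.0})
```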
+ elif isinstance(loss_weights, list): + if len(loss_weights) != len(training_endpoints): + raise ValueError('When passing a list as loss_weights, ' + 'it should have one entry per model output. ' + 'The model has ' + str(len(training_endpoints)) + + ' outputs, but you passed loss_weights=' + + str(loss_weights)) + for w, e in zip(loss_weights, training_endpoints): + e.loss_weight = w + else: + raise TypeError('Could not interpret loss_weights argument: ' + + str(loss_weights) + ' - expected a list of dicts.') + + +# TODO(rohanj): This is a hack to get around not depending on feature_column and +# create a cyclical dependency. Figure out a cleaner solution +def is_feature_layer(layer): + """Returns whether `layer` is a FeatureLayer or not.""" + return getattr(layer, '_is_feature_layer', False) + + +def is_eager_dataset_or_iterator(data): + return context.executing_eagerly() and isinstance( + data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2, + iterator_ops.OwnedIterator)) + + +# pylint: disable=protected-access +def get_dataset_graph_def(dataset): + if context.executing_eagerly(): + graph_def_str = dataset._as_serialized_graph().numpy() + else: + graph_def_str = K.get_value(dataset._as_serialized_graph()) + return graph_pb2.GraphDef().FromString(graph_def_str) + + +def verify_dataset_shuffled(x): + """Verifies that the dataset is shuffled. + + Args: + x: Dataset passed as an input to the model. + + Returns: + boolean, whether the input dataset is shuffled or not. + """ + assert isinstance(x, dataset_ops.DatasetV2) + graph_def = get_dataset_graph_def(x) + for node in graph_def.node: + if node.op.startswith('ShuffleDataset'): + return True + # Also check graph_def.library.function for ds.interleave or ds.flat_map + for function in graph_def.library.function: + for node in function.node_def: + if node.op.startswith('ShuffleDataset'): + return True + logging.warning('Expected a shuffled dataset but input dataset `x` is ' + 'not shuffled. Please invoke `shuffle()` on input dataset.') + return False + + +def is_dataset_or_iterator(data): + return isinstance(data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2, + iterator_ops.Iterator, iterator_ops.OwnedIterator)) + + +def get_iterator(dataset): + """Create and initialize an iterator from a dataset.""" + if context.executing_eagerly(): + iterator = dataset_ops.make_one_shot_iterator(dataset) + else: + iterator = dataset_ops.make_initializable_iterator(dataset) + initialize_iterator(iterator) + return iterator + + +def initialize_iterator(iterator): + if not context.executing_eagerly(): + init_op = iterator.initializer + K.get_session((init_op,)).run(init_op) + + +def extract_tensors_from_dataset(dataset): + """Extract a tuple of tensors `inputs, targets, sample_weight` from a dataset. + + Arguments: + dataset: Dataset instance. + + Returns: + Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None. + """ + iterator = get_iterator(dataset) + inputs, targets, sample_weight = unpack_iterator_input(iterator) + return inputs, targets, sample_weight + + +def unpack_iterator_input(iterator): + """Convert a dataset iterator to a tuple of tensors `x, y, sample_weights`. + + Arguments: + iterator: Instance of a dataset iterator. + + Returns: + Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None. 
+ """ + try: + next_element = iterator.get_next() + except errors.OutOfRangeError: + raise RuntimeError('Your dataset iterator ran out of data; ' + 'Make sure that your dataset can generate ' + 'required number of samples.') + + if isinstance(next_element, (list, tuple)): + if len(next_element) not in [2, 3]: + raise ValueError( + 'Please provide model inputs as a list or tuple of 2 or 3 ' + 'elements: (input, target) or (input, target, sample_weights) ' + 'Received %s' % next_element) + if len(next_element) == 2: + x, y = next_element + weights = None + else: + x, y, weights = next_element + else: + x = next_element + y = None + weights = None + return x, y, weights + + +def infer_steps_for_dataset(model, + dataset, + steps, + epochs=1, + steps_name='steps'): + """Infers steps_per_epoch needed to loop through a dataset. + + Arguments: + model: Keras model instance. + dataset: Input data of type tf.data.Dataset. + steps: Number of steps to draw from the dataset (may be None if unknown). + epochs: Number of times to iterate over the dataset. + steps_name: The string name of the steps argument, either `steps`, + `validation_steps`, or `steps_per_epoch`. Only used for error message + formatting. + + Returns: + Integer or `None`. Inferred number of steps to loop through the dataset. + `None` is returned if 1) the size of the dataset is unknown and `steps` was + not specified, or 2) this is multi-worker training and auto sharding is + enabled. + + Raises: + ValueError: In case of invalid argument values. + """ + assert isinstance(dataset, dataset_ops.DatasetV2) + if (model._in_multi_worker_mode() and + (dataset.options().experimental_distribute.auto_shard_policy != + AutoShardPolicy.OFF)): + # If the dataset would be auto-sharded, we should not infer a local + # steps_per_epoch due to the possible inbalanced sharding between workers. + return None + + size = K.get_value(cardinality.cardinality(dataset)) + if size == cardinality.INFINITE and steps is None: + raise ValueError('When passing an infinitely repeating dataset, you ' + 'must specify the `%s` argument.' % (steps_name,)) + if size >= 0: + if steps is not None and steps * epochs > size: + if epochs > 1: + raise ValueError('The dataset you passed contains %s batches, but you ' + 'passed `epochs=%s` and `%s=%s`, which is a total of ' + '%s steps. We cannot draw that many steps from this ' + 'dataset. We suggest to set `%s=%s`.' % + (size, epochs, steps_name, steps, steps * epochs, + steps_name, size // epochs)) + else: + raise ValueError('The dataset you passed contains %s batches, but you ' + 'passed `%s=%s`. We cannot draw that many steps from ' + 'this dataset. We suggest to set `%s=%s`.' % + (size, steps_name, steps, steps_name, size)) + if steps is None: + if size >= 0: + return size + return None + return steps + + +class ModelInputs(object): + """Encapsulates model inputs. + + Allows for transforming model inputs while keeping the same structure. 
+ """ + + def __init__(self, inputs): + self._inputs = inputs + self._is_dict = isinstance(self._inputs, dict) + self._is_single_input = not isinstance(self._inputs, (list, tuple, dict)) + + self._flattened_inputs = [] + self._input_names = [] + + if self._is_dict: + for k in sorted(self._inputs.keys()): + self._flattened_inputs.append(self._inputs[k]) + self._input_names.append(k) + else: + self._flattened_inputs = nest.flatten(self._inputs) + self._input_names = [ + 'input_%d' % (i + 1) for i in range(len(self._flattened_inputs)) + ] + + def get_input_names(self): + """Returns keys to name inputs by. + + In case inputs provided were a list, tuple or single entry, we make up a + key 'input_%d'. For dictionary case, we return a sorted list of keys. + """ + return self._input_names + + def get_symbolic_inputs(self, return_single_as_list=False): + """Returns inputs to be set as self.inputs for a model.""" + # TODO(karmel): There is a side-effect here where what you get + # with as_list and as_dict depends on whether you have called this + # method first, since it modifies in place. + for i, (k, v) in enumerate(zip(self._input_names, self._flattened_inputs)): + if isinstance(v, (list, float, int)): + v = np.asarray(v) + if v.ndim == 1: + v = np.expand_dims(v, 1) + + if isinstance(v, np.ndarray): + # We fix the placeholder shape except the batch size. + # This is suboptimal, but it is the best we can do with the info + # we have. The user should call `model._set_inputs(placeholders)` + # to specify custom placeholders if the need arises. + shape = (None,) + tuple(v.shape[1:]) + if shape == (None,): + shape = (None, 1) + dtype = dtypes.as_dtype(v.dtype) + if dtype.is_floating: + dtype = K.floatx() + v = K.placeholder(shape=shape, name=k, dtype=dtype) + elif isinstance(v, tensor_spec.TensorSpec): + shape = (None,) + tuple(v.shape.as_list()[1:]) + if shape == (None,): + shape = (None, 1) + v = K.placeholder(shape=shape, name=k, dtype=v.dtype) + + self._flattened_inputs[i] = v + + if self._is_dict: + return dict(zip(self._input_names, self._flattened_inputs)) + if self._is_single_input and not return_single_as_list: + return self._flattened_inputs[0] + return self._flattened_inputs + + def as_dict(self): + """An iterable over a dictionary version of inputs.""" + for k, v in zip(self._input_names, self._flattened_inputs): + yield k, v + + def as_list(self): + """Returning the inputs as a list.""" + return self._flattened_inputs + + +# Allow use of methods not exposed to the user. +# pylint: disable=protected-access + + +# pylint: enable=protected-access + + +def generic_output_names(outputs_list): + return ['output_%d' % (i + 1) for i in range(len(outputs_list))] + + +def should_run_validation(validation_freq, epoch): + """Checks if validation should be run this epoch. + + Arguments: + validation_freq: Integer or list. If an integer, specifies how many training + epochs to run before a new validation run is performed. If a list, + specifies the epochs on which to run validation. + epoch: Integer, the number of the training epoch just completed. + + Returns: + Bool, True if validation should be run. + + Raises: + ValueError: if `validation_freq` is an Integer and less than 1, or if + it is neither an Integer nor a Sequence. + """ + # `epoch` is 0-indexed internally but 1-indexed in the public API. 
+ one_indexed_epoch = epoch + 1 + + if isinstance(validation_freq, int): + if validation_freq < 1: + raise ValueError('`validation_freq` can not be less than 1.') + return one_indexed_epoch % validation_freq == 0 + + if not isinstance(validation_freq, collections_abc.Container): + raise ValueError('`validation_freq` must be an Integer or ' + '`collections_abc.Container` (e.g. list, tuple, etc.)') + return one_indexed_epoch in validation_freq + + +def split_training_and_validation_data(x, y, sample_weights, validation_split): + """Split input data into train/eval section based on validation_split.""" + if has_symbolic_tensors(x): + raise ValueError('If your data is in the form of symbolic tensors, ' + 'you cannot use `validation_split`.') + if hasattr(x[0], 'shape'): + split_at = int(x[0].shape[0] * (1. - validation_split)) + else: + split_at = int(len(x[0]) * (1. - validation_split)) + x, val_x = (generic_utils.slice_arrays(x, 0, split_at), + generic_utils.slice_arrays(x, split_at)) + y, val_y = (generic_utils.slice_arrays(y, 0, split_at), + generic_utils.slice_arrays(y, split_at)) + if sample_weights: + sample_weights, val_sample_weights = ( + generic_utils.slice_arrays(sample_weights, 0, split_at), + generic_utils.slice_arrays(sample_weights, split_at), + ) + else: + val_sample_weights = None + return x, y, sample_weights, val_x, val_y, val_sample_weights + + +def unpack_validation_data(validation_data, raise_if_ambiguous=True): + """Unpack validation data based input type. + + The validation data is not touched if its dataset or dataset iterator. + For other type of input (Numpy or tensor), it will be unpacked into tuple of + 3 which is x, y and sample weights. + + Args: + validation_data: dataset, dataset iterator, or numpy, tensor tuple. + raise_if_ambiguous: boolean on whether to fail if validation_data cannot be + parsed. Otherwise simply return validation_data, None, None and defer the + decision to the caller. + + Returns: + tuple of 3, (x, y, sample_weights) for numpy and tensor input. + """ + if (isinstance(validation_data, (iterator_ops.Iterator, + iterator_ops.OwnedIterator, + dataset_ops.DatasetV2, + data_utils.Sequence)) + or not hasattr(validation_data, '__len__')): + val_x = validation_data + val_y = None + val_sample_weight = None + elif len(validation_data) == 2: + try: + val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence + val_sample_weight = None + except ValueError: + val_x, val_y, val_sample_weight = validation_data, None, None + elif len(validation_data) == 3: + try: + val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence + except ValueError: + val_x, val_y, val_sample_weight = validation_data, None, None + else: + if raise_if_ambiguous: + raise ValueError( + 'When passing a `validation_data` argument, ' + 'it must contain either 2 items (x_val, y_val), ' + 'or 3 items (x_val, y_val, val_sample_weights), ' + 'or alternatively it could be a dataset or a ' + 'dataset or a dataset iterator. ' + 'However we received `validation_data=%s`' % validation_data) + val_x, val_y, val_sample_weight = validation_data, None, None + return val_x, val_y, val_sample_weight + + +class TrainingLoop(object): + """TrainingLoop is a wrapper class around the training logic. + + This class is trying to encapsulate the different logic of fit/eval/predict + with regard to different data input and model condition. 
+ + Note that TrainingLoop is stateless, which means it doesn't contain any + internal field and can be reused with different model and inputs. + """ + + def fit(self, + model, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose=1, + callbacks=None, + validation_split=0., + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + validation_freq=1, + **kwargs): + """Train the model with the inputs and targets.""" + raise NotImplementedError() + + def evaluate(self, + model, + x=None, + y=None, + batch_size=None, + verbose=1, + sample_weight=None, + steps=None, + callbacks=None, + **kwargs): + """Returns the loss value & metrics values for the model in test mode.""" + raise NotImplementedError() + + def predict(self, + model, + x, + batch_size=None, + verbose=0, + steps=None, + callbacks=None, + **kwargs): + raise NotImplementedError() diff --git a/tensorflow/python/keras/engine/training_utils_test.py b/tensorflow/python/keras/engine/training_utils_v1_test.py similarity index 83% rename from tensorflow/python/keras/engine/training_utils_test.py rename to tensorflow/python/keras/engine/training_utils_v1_test.py index 8a3fd3926cf..64d44cb7955 100644 --- a/tensorflow/python/keras/engine/training_utils_test.py +++ b/tensorflow/python/keras/engine/training_utils_v1_test.py @@ -35,7 +35,7 @@ from tensorflow.python.keras import backend from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import keras_tensor -from tensorflow.python.keras.engine import training_utils +from tensorflow.python.keras.engine import training_utils_v1 from tensorflow.python.keras.utils import tf_utils from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging @@ -45,7 +45,7 @@ class ModelInputsTest(test.TestCase): def test_single_thing(self): a = np.ones(10) - model_inputs = training_utils.ModelInputs(a) + model_inputs = training_utils_v1.ModelInputs(a) self.assertEqual(['input_1'], model_inputs.get_input_names()) vals = model_inputs.get_symbolic_inputs() self.assertTrue(tensor_util.is_tensor(vals)) @@ -59,7 +59,7 @@ class ModelInputsTest(test.TestCase): self.skipTest('Run in eager mode only.') with testing_utils.use_keras_tensors_scope(False): a = np.ones(10, dtype=np.int32) - model_inputs = training_utils.ModelInputs(a) + model_inputs = training_utils_v1.ModelInputs(a) self.assertEqual(['input_1'], model_inputs.get_input_names()) val = model_inputs.get_symbolic_inputs() self.assertTrue(tf_utils.is_symbolic_tensor(val)) @@ -69,7 +69,7 @@ class ModelInputsTest(test.TestCase): self.assertEqual(dtypes.int32, vals[0].dtype) with testing_utils.use_keras_tensors_scope(True): a = np.ones(10, dtype=np.int32) - model_inputs = training_utils.ModelInputs(a) + model_inputs = training_utils_v1.ModelInputs(a) self.assertEqual(['input_1'], model_inputs.get_input_names()) val = model_inputs.get_symbolic_inputs() self.assertIsInstance(val, keras_tensor.KerasTensor) @@ -80,7 +80,7 @@ class ModelInputsTest(test.TestCase): def test_list(self): a = [np.ones(10), np.ones(20)] - model_inputs = training_utils.ModelInputs(a) + model_inputs = training_utils_v1.ModelInputs(a) self.assertEqual(['input_1', 'input_2'], model_inputs.get_input_names()) vals = model_inputs.get_symbolic_inputs() self.assertTrue(tensor_util.is_tensor(vals[0])) @@ -91,14 +91,14 @@ class ModelInputsTest(test.TestCase): self.skipTest('Run in 
eager mode only.') with testing_utils.use_keras_tensors_scope(False): a = [np.ones(10), np.ones(20)] - model_inputs = training_utils.ModelInputs(a) + model_inputs = training_utils_v1.ModelInputs(a) self.assertEqual(['input_1', 'input_2'], model_inputs.get_input_names()) vals = model_inputs.get_symbolic_inputs() self.assertTrue(tf_utils.is_symbolic_tensor(vals[0])) self.assertTrue(tf_utils.is_symbolic_tensor(vals[1])) with testing_utils.use_keras_tensors_scope(True): a = [np.ones(10), np.ones(20)] - model_inputs = training_utils.ModelInputs(a) + model_inputs = training_utils_v1.ModelInputs(a) self.assertEqual(['input_1', 'input_2'], model_inputs.get_input_names()) vals = model_inputs.get_symbolic_inputs() self.assertIsInstance(vals[0], keras_tensor.KerasTensor) @@ -106,7 +106,7 @@ class ModelInputsTest(test.TestCase): def test_dict(self): a = {'b': np.ones(10), 'a': np.ones(20)} - model_inputs = training_utils.ModelInputs(a) + model_inputs = training_utils_v1.ModelInputs(a) self.assertEqual(['a', 'b'], model_inputs.get_input_names()) vals = model_inputs.get_symbolic_inputs() self.assertTrue(tensor_util.is_tensor(vals['a'])) @@ -117,14 +117,14 @@ class ModelInputsTest(test.TestCase): self.skipTest('Run in eager mode only.') with testing_utils.use_keras_tensors_scope(False): a = {'b': np.ones(10), 'a': np.ones(20)} - model_inputs = training_utils.ModelInputs(a) + model_inputs = training_utils_v1.ModelInputs(a) self.assertEqual(['a', 'b'], model_inputs.get_input_names()) vals = model_inputs.get_symbolic_inputs() self.assertTrue(tf_utils.is_symbolic_tensor(vals['a'])) self.assertTrue(tf_utils.is_symbolic_tensor(vals['b'])) with testing_utils.use_keras_tensors_scope(True): a = {'b': np.ones(10), 'a': np.ones(20)} - model_inputs = training_utils.ModelInputs(a) + model_inputs = training_utils_v1.ModelInputs(a) self.assertEqual(['a', 'b'], model_inputs.get_input_names()) vals = model_inputs.get_symbolic_inputs() self.assertIsInstance(vals['a'], keras_tensor.KerasTensor) @@ -182,12 +182,12 @@ class DatasetUtilsTest(test.TestCase, parameterized.TestCase): if not expect_shuffled: with test.mock.patch.object(logging, 'warning') as mock_log: - shuffled = training_utils.verify_dataset_shuffled(dataset) + shuffled = training_utils_v1.verify_dataset_shuffled(dataset) self.assertRegex( str(mock_log.call_args), 'input dataset `x` is not shuffled.') self.assertFalse(shuffled) else: - self.assertTrue(training_utils.verify_dataset_shuffled(dataset)) + self.assertTrue(training_utils_v1.verify_dataset_shuffled(dataset)) class StandardizeWeightsTest(keras_parameterized.TestCase): @@ -195,21 +195,22 @@ class StandardizeWeightsTest(keras_parameterized.TestCase): def test_sample_weights(self): y = np.array([0, 1, 0, 0, 2]) sample_weights = np.array([0.5, 1., 1., 0., 2.]) - weights = training_utils.standardize_weights(y, sample_weights) + weights = training_utils_v1.standardize_weights(y, sample_weights) self.assertAllClose(weights, sample_weights) def test_class_weights(self): y = np.array([0, 1, 0, 0, 2]) class_weights = {0: 0.5, 1: 1., 2: 1.5} - weights = training_utils.standardize_weights(y, class_weight=class_weights) + weights = training_utils_v1.standardize_weights( + y, class_weight=class_weights) self.assertAllClose(weights, np.array([0.5, 1., 0.5, 0.5, 1.5])) def test_sample_weights_and_class_weights(self): y = np.array([0, 1, 0, 0, 2]) sample_weights = np.array([0.5, 1., 1., 0., 2.]) class_weights = {0: 0.5, 1: 1., 2: 1.5} - weights = training_utils.standardize_weights(y, sample_weights, - class_weights) + 
weights = training_utils_v1.standardize_weights(y, sample_weights, + class_weights) expected = sample_weights * np.array([0.5, 1., 0.5, 0.5, 1.5]) self.assertAllClose(weights, expected) @@ -276,32 +277,35 @@ class AggregationTest(keras_parameterized.TestCase): def setUp(self): super(AggregationTest, self).setUp() - self._old_pool = training_utils._COPY_POOL - self._old_threshold = training_utils.SliceAggregator._BINARY_SIZE_THRESHOLD - self._old_timeout = training_utils.SliceAggregator._MAX_COPY_SECONDS - training_utils._COPY_POOL = MonitoredPool(training_utils._COPY_THREADS) + self._old_pool = training_utils_v1._COPY_POOL + self._old_threshold = ( + training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD) + self._old_timeout = training_utils_v1.SliceAggregator._MAX_COPY_SECONDS + training_utils_v1._COPY_POOL = MonitoredPool( + training_utils_v1._COPY_THREADS) def tearDown(self): super(AggregationTest, self).tearDown() - training_utils._COPY_POOL = self._old_pool - training_utils.SliceAggregator._BINARY_SIZE_THRESHOLD = self._old_threshold - training_utils.SliceAggregator._MAX_COPY_SECONDS = self._old_timeout + training_utils_v1._COPY_POOL = self._old_pool + training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = ( + self._old_threshold) + training_utils_v1.SliceAggregator._MAX_COPY_SECONDS = self._old_timeout def _run_with_steps(self): - aggregator = training_utils.OutputsAggregator(use_steps=True) + aggregator = training_utils_v1.OutputsAggregator(use_steps=True) for i, batch in enumerate(np.array_split(_TEST_DATA, 4)): if i == 0: aggregator.create(batch) aggregator.aggregate(batch) assert len(aggregator.results) == 1 - assert isinstance(aggregator.results[0], training_utils.ConcatAggregator) + assert isinstance(aggregator.results[0], training_utils_v1.ConcatAggregator) aggregator.finalize() return aggregator.results def _run_without_steps(self): - aggregator = training_utils.OutputsAggregator( + aggregator = training_utils_v1.OutputsAggregator( use_steps=False, num_samples=6) batch_start = 0 @@ -314,7 +318,7 @@ class AggregationTest(keras_parameterized.TestCase): batch_start = batch_end assert len(aggregator.results) == 1 - assert isinstance(aggregator.results[0], training_utils.SliceAggregator) + assert isinstance(aggregator.results[0], training_utils_v1.SliceAggregator) aggregator.finalize() return aggregator.results @@ -326,7 +330,7 @@ class AggregationTest(keras_parameterized.TestCase): self.assertAllEqual(self._run_without_steps(), _TEST_DATA) def test_nested_aggregation(self): - aggregator = training_utils.OutputsAggregator( + aggregator = training_utils_v1.OutputsAggregator( use_steps=False, num_samples=6) batches = np.array_split(_TEST_DATA, 4) @@ -344,46 +348,46 @@ class AggregationTest(keras_parameterized.TestCase): self.assertAllEqual(aggregator.results, (_TEST_DATA, _TEST_DATA)) def test_concat_single_batch(self): - aggregator = training_utils.OutputsAggregator(use_steps=True) + aggregator = training_utils_v1.OutputsAggregator(use_steps=True) data = _TEST_DATA.copy() aggregator.create(data) assert len(aggregator.results) == 1 - assert isinstance(aggregator.results[0], training_utils.ConcatAggregator) + assert isinstance(aggregator.results[0], training_utils_v1.ConcatAggregator) aggregator.aggregate(data) aggregator.finalize() assert aggregator.results is data # No copy. 
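[Editorial note, not part of the patch: the `StandardizeWeightsTest` cases above assert how `class_weight` and `sample_weight` combine. A plain NumPy sketch of that arithmetic, using the same values as the tests; it is not the internal helper itself.]

```python
import numpy as np

y = np.array([0, 1, 0, 0, 2])
sample_weights = np.array([0.5, 1., 1., 0., 2.])
class_weights = {0: 0.5, 1: 1., 2: 1.5}

# class_weight maps each label to a per-sample weight.
per_class = np.array([class_weights[label] for label in y])
assert np.allclose(per_class, [0.5, 1., 0.5, 0.5, 1.5])

# When both are given, the effective weight is their element-wise product.
combined = sample_weights * per_class
assert np.allclose(combined, [0.25, 1., 0.5, 0., 3.])
```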
def test_slice_single_batch(self): - aggregator = training_utils.OutputsAggregator( + aggregator = training_utils_v1.OutputsAggregator( use_steps=False, num_samples=6) data = _TEST_DATA.copy() aggregator.create(data) assert len(aggregator.results) == 1 - assert isinstance(aggregator.results[0], training_utils.SliceAggregator) + assert isinstance(aggregator.results[0], training_utils_v1.SliceAggregator) aggregator.aggregate(data, 0, 6) aggregator.finalize() assert aggregator.results is data # No copy. def test_async_copy(self): - training_utils.SliceAggregator._BINARY_SIZE_THRESHOLD = 15 + training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = 15 self.assertAllEqual(self._run_without_steps(), _TEST_DATA) # Two of the four batches will have 20 elements and two will have 10. - self.assertEqual(training_utils._COPY_POOL._apply_counter, 2) + self.assertEqual(training_utils_v1._COPY_POOL._apply_counter, 2) def test_async_copy_timeout(self): - training_utils.SliceAggregator._BINARY_SIZE_THRESHOLD = 15 - training_utils.SliceAggregator._MAX_COPY_SECONDS = 0.1 - training_utils._COPY_POOL._func_wrapper = add_sleep + training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = 15 + training_utils_v1.SliceAggregator._MAX_COPY_SECONDS = 0.1 + training_utils_v1._COPY_POOL._func_wrapper = add_sleep with self.assertRaisesRegex(ValueError, 'Timed out waiting for copy'): self._run_without_steps() def test_async_copy_reraise(self): - training_utils.SliceAggregator._BINARY_SIZE_THRESHOLD = 15 - training_utils.SliceAggregator._MAX_COPY_SECONDS = 1. - training_utils._COPY_POOL._func_wrapper = cause_error + training_utils_v1.SliceAggregator._BINARY_SIZE_THRESHOLD = 15 + training_utils_v1.SliceAggregator._MAX_COPY_SECONDS = 1. + training_utils_v1._COPY_POOL._func_wrapper = cause_error with self.assertRaisesRegex(TypeError, 'NoneType'): self._run_without_steps() diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py index 7dca6ae3da7..54969bb5e83 100644 --- a/tensorflow/python/keras/engine/training_v1.py +++ b/tensorflow/python/keras/engine/training_v1.py @@ -51,6 +51,7 @@ from tensorflow.python.keras.engine import training_distributed_v1 from tensorflow.python.keras.engine import training_eager_v1 from tensorflow.python.keras.engine import training_generator_v1 from tensorflow.python.keras.engine import training_utils +from tensorflow.python.keras.engine import training_utils_v1 from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.keras.saving.saved_model import model_serialization @@ -413,7 +414,7 @@ class Model(training_lib.Model): base_layer.keras_api_gauge.get_cell('compile').set(True) # Prepare list of loss functions, same size of model outputs. - self.loss_functions = training_utils.prepare_loss_functions( + self.loss_functions = training_utils_v1.prepare_loss_functions( self.loss, self.output_names) target_tensors = self._process_target_tensor_for_compile(target_tensors) @@ -425,7 +426,8 @@ class Model(training_lib.Model): self._training_endpoints.append(endpoint) # Prepare list loss weights, same size of model outputs. - training_utils.prepare_loss_weights(self._training_endpoints, loss_weights) + training_utils_v1.prepare_loss_weights(self._training_endpoints, + loss_weights) # Initialization for Eager mode execution. 
if self.run_eagerly: @@ -447,7 +449,7 @@ class Model(training_lib.Model): masks=self._prepare_output_masks()) # Prepare sample weight modes. List with the same length as model outputs. - training_utils.prepare_sample_weight_modes( + training_utils_v1.prepare_sample_weight_modes( self._training_endpoints, sample_weight_mode) # Creates the model loss and weighted metrics sub-graphs. @@ -593,7 +595,7 @@ class Model(training_lib.Model): # or a non-distributed Dataset or iterator in eager execution. if data_utils.is_generator_or_sequence(inputs): return training_generator_v1.GeneratorOrSequenceTrainingLoop() - if training_utils.is_eager_dataset_or_iterator(inputs): + if training_utils_v1.is_eager_dataset_or_iterator(inputs): return training_generator_v1.EagerDatasetOrIteratorTrainingLoop() # Case 3: Symbolic tensors or Numpy array-like. @@ -1074,7 +1076,7 @@ class Model(training_lib.Model): + output_dict['metrics']) outputs = [_non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access else: - x = training_utils.ModelInputs(x).as_list() + x = training_utils_v1.ModelInputs(x).as_list() ins = x + list(y or []) + list(sample_weights or []) if not isinstance(K.symbolic_learning_phase(), int): @@ -1153,7 +1155,7 @@ class Model(training_lib.Model): + output_dict['metrics']) outputs = [_non_none_constant_value(v) for v in outputs] # pylint: disable=protected-access else: - x = training_utils.ModelInputs(x).as_list() + x = training_utils_v1.ModelInputs(x).as_list() inputs = x + list(y or []) + list(sample_weights or []) self._update_sample_weight_modes(sample_weights=sample_weights) @@ -1198,7 +1200,7 @@ class Model(training_lib.Model): # If `self._distribution_strategy` is True, then we are in a replica context # at this point. if self.run_eagerly or self._distribution_strategy: - inputs = training_utils.cast_if_floating_dtype(inputs) + inputs = training_utils_v1.cast_if_floating_dtype(inputs) if isinstance(inputs, collections_abc.Sequence): # Unwrap lists with only one input, as we do when training on batch if len(inputs) == 1: @@ -1370,7 +1372,7 @@ class Model(training_lib.Model): def _prepare_validation_data(self, validation_data, batch_size, validation_steps): """Unpack and check the validation data.""" - val_x, val_y, val_sample_weights = training_utils.unpack_validation_data( + val_x, val_y, val_sample_weights = training_utils_v1.unpack_validation_data( validation_data) return self._standardize_user_data( val_x, @@ -1449,7 +1451,7 @@ class Model(training_lib.Model): def _compile_eagerly(self, metrics, weighted_metrics, sample_weight_mode): # Prepare sample weight modes. List with the same length as model outputs. - training_utils.prepare_sample_weight_modes( + training_utils_v1.prepare_sample_weight_modes( self._training_endpoints, sample_weight_mode) # Prepare sample weights. 
self._prepare_sample_weights() @@ -1788,10 +1790,10 @@ class Model(training_lib.Model): output_shapes.append(None) else: output_shapes.append(output.shape.as_list()) - self._per_output_metrics = training_utils.collect_per_output_metric_info( + self._per_output_metrics = training_utils_v1.collect_per_output_metric_info( metrics, self.output_names, output_shapes, self.loss_functions) self._per_output_weighted_metrics = ( - training_utils.collect_per_output_metric_info( + training_utils_v1.collect_per_output_metric_info( weighted_metrics, self.output_names, output_shapes, @@ -1901,7 +1903,7 @@ class Model(training_lib.Model): metric_results = [] for metric_name, metric_fn in metrics_dict.items(): with K.name_scope(metric_name): - metric_result = training_utils.call_metric_function( + metric_result = training_utils_v1.call_metric_function( metric_fn, y_true, y_pred, weights=weights, mask=mask) metric_results.append(metric_result) return metric_results @@ -2138,7 +2140,7 @@ class Model(training_lib.Model): # in the codebase. if isinstance(x, dataset_ops.DatasetV2): if shuffle: - training_utils.verify_dataset_shuffled(x) + training_utils_v1.verify_dataset_shuffled(x) strategy = self._distribution_strategy with strategy.scope(): @@ -2190,8 +2192,8 @@ class Model(training_lib.Model): x = ds.batch(batch_size, drop_remainder=drop_remainder) else: assert isinstance(x, dataset_ops.DatasetV2) - training_utils.validate_dataset_input(x, y, sample_weight, - validation_split) + training_utils_v1.validate_dataset_input(x, y, sample_weight, + validation_split) return x def _standardize_user_data(self, @@ -2268,28 +2270,28 @@ class Model(training_lib.Model): # Graph mode dataset. We'll pass the dataset as-is (unless # `extract_tensors_from_dataset` is True, in which case we extract # the tensors from the dataset and we output them. - training_utils.validate_dataset_input(x, y, sample_weight, - validation_split) + training_utils_v1.validate_dataset_input(x, y, sample_weight, + validation_split) if shuffle: - training_utils.verify_dataset_shuffled(x) + training_utils_v1.verify_dataset_shuffled(x) is_dataset = True if extract_tensors_from_dataset: # We do this for `train_on_batch`/etc. - x, y, sample_weight = training_utils.extract_tensors_from_dataset(x) + x, y, sample_weight = training_utils_v1.extract_tensors_from_dataset(x) elif isinstance(x, iterator_ops.Iterator): # Graph mode iterator. We extract the symbolic tensors. - training_utils.validate_dataset_input(x, y, sample_weight, - validation_split) + training_utils_v1.validate_dataset_input(x, y, sample_weight, + validation_split) iterator = x - x, y, sample_weight = training_utils.unpack_iterator_input(iterator) + x, y, sample_weight = training_utils_v1.unpack_iterator_input(iterator) is_dataset = True else: is_dataset = False # Validates `steps` argument based on x's type. if check_steps: - training_utils.check_steps_argument(x, steps, steps_name) + training_utils_v1.check_steps_argument(x, steps, steps_name) # First, we build the model on the fly if necessary. if not self.inputs: @@ -2352,7 +2354,7 @@ class Model(training_lib.Model): # Standardize the inputs. if not isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): # TODO(fchollet): run static checks with dataset output shape(s). - x = training_utils.standardize_input_data( + x = training_utils_v1.standardize_input_data( x, feed_input_names, feed_input_shapes, @@ -2399,8 +2401,8 @@ class Model(training_lib.Model): if y is not None: # Prepare self._sample_weight_modes. 
List with the same length as # model outputs. - training_utils.prepare_sample_weight_modes(self._training_endpoints, - self.sample_weight_mode) + training_utils_v1.prepare_sample_weight_modes(self._training_endpoints, + self.sample_weight_mode) feed_output_names = self._feed_output_names feed_sample_weight_modes = self._sample_weight_modes if not self._is_graph_network: @@ -2409,7 +2411,7 @@ class Model(training_lib.Model): feed_output_shapes = self._feed_output_shapes # Standardize the outputs. - y = training_utils.standardize_input_data( + y = training_utils_v1.standardize_input_data( y, feed_output_names, # Don't enforce target shapes to match output shapes. @@ -2420,22 +2422,22 @@ class Model(training_lib.Model): # Generate sample-wise weight values given the `sample_weight` and # `class_weight` arguments. - sample_weights = training_utils.standardize_sample_weights( + sample_weights = training_utils_v1.standardize_sample_weights( sample_weight, feed_output_names) - class_weights = training_utils.standardize_class_weights( + class_weights = training_utils_v1.standardize_class_weights( class_weight, feed_output_names) sample_weights = [ - training_utils.standardize_weights(ref, sw, cw, mode) + training_utils_v1.standardize_weights(ref, sw, cw, mode) for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights, feed_sample_weight_modes) ] # Check that all arrays have the same length. if not self._distribution_strategy: - training_utils.check_array_lengths(x, y, sample_weights) + training_utils_v1.check_array_lengths(x, y, sample_weights) if self._is_graph_network and not run_eagerly: # Additional checks to avoid users mistakenly using improper loss fns. - training_utils.check_loss_and_target_compatibility( + training_utils_v1.check_loss_and_target_compatibility( y, self._feed_loss_fns, feed_output_shapes) sample_weights, _, _ = training_utils.handle_partial_sample_weights( @@ -2470,11 +2472,12 @@ class Model(training_lib.Model): # iterator and only one batch of samples is required, we fetch the data # tensors from the iterator and then standardize them. if isinstance(inputs, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)): - inputs, targets, _ = training_utils.extract_tensors_from_dataset(inputs) + inputs, targets, _ = training_utils_v1.extract_tensors_from_dataset( + inputs) # We type-check that `inputs` and `targets` are either single arrays # or lists of arrays, and extract a flat list of inputs from the passed # structure. - training_utils.validate_input_types(inputs, orig_inputs) + training_utils_v1.validate_input_types(inputs, orig_inputs) if isinstance(inputs, (list, tuple)): processed_inputs += list(inputs) @@ -2509,14 +2512,14 @@ class Model(training_lib.Model): if not self.inputs: # For subclassed models, a robust input spec is not available so we # must cast to the model dtype. 
- inputs = training_utils.cast_if_floating_dtype(inputs, self.dtype) + inputs = training_utils_v1.cast_if_floating_dtype(inputs, self.dtype) def create_tensor_spec(t): return tensor_spec.TensorSpec(t.shape, t.dtype) cast_inputs = nest.map_structure(create_tensor_spec, inputs) - elif training_utils.has_tensors(inputs): - cast_inputs = training_utils.cast_if_floating_dtype(inputs) + elif training_utils_v1.has_tensors(inputs): + cast_inputs = training_utils_v1.cast_if_floating_dtype(inputs) else: cast_inputs = inputs self._set_inputs(cast_inputs) @@ -2525,11 +2528,11 @@ class Model(training_lib.Model): def _compile_from_inputs(self, all_inputs, target, orig_inputs, orig_target): if target is not None: # We need to use `y` to set the model targets. - if training_utils.has_tensors(target): - target = training_utils.cast_if_floating_dtype_and_mismatch( + if training_utils_v1.has_tensors(target): + target = training_utils_v1.cast_if_floating_dtype_and_mismatch( target, self.outputs) - training_utils.validate_input_types(target, orig_target, - allow_dict=False, field_name='target') + training_utils_v1.validate_input_types( + target, orig_target, allow_dict=False, field_name='target') if isinstance(target, (list, tuple)): all_inputs += list(target) else: @@ -2628,7 +2631,7 @@ class Model(training_lib.Model): input_shape = (None,) + tuple(inputs.as_list()[1:]) elif isinstance(inputs, dict): # We assert that the first layer is a FeatureLayer. - if not training_utils.is_feature_layer(self.layers[0]): + if not training_utils_v1.is_feature_layer(self.layers[0]): raise ValueError('Passing a dictionary input to a Sequential Model ' 'which doesn\'t have FeatureLayer as the first layer' ' is an error.') @@ -2643,7 +2646,7 @@ class Model(training_lib.Model): # On-the-fly setting of symbolic model inputs (either by using the tensor # provided, or by creating a placeholder if Numpy data was provided). - model_inputs = training_utils.ModelInputs(inputs) + model_inputs = training_utils_v1.ModelInputs(inputs) inputs = model_inputs.get_symbolic_inputs() self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True) self.input_names = model_inputs.get_input_names() @@ -2667,7 +2670,7 @@ class Model(training_lib.Model): # data adapter since it assumes nest.flatten ordering. outputs = nest.flatten(outputs) self.outputs = outputs - self.output_names = training_utils.generic_output_names(outputs) + self.output_names = training_utils_v1.generic_output_names(outputs) # TODO(scottzhu): Should we cleanup the self._training_endpoints here? 
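[Editorial note, not part of the patch: the `_set_inputs` code above handles dictionary inputs keyed by input name (sorted keys for Sequential/FeatureLayer models). For comparison, a sketch of the functional-API analogue with named `tf.keras.Input`s; the feature names and data are made up.]

```python
import numpy as np
import tensorflow as tf

age = tf.keras.Input(shape=(1,), name='age')
income = tf.keras.Input(shape=(1,), name='income')
out = tf.keras.layers.Dense(1)(tf.keras.layers.concatenate([age, income]))
model = tf.keras.Model({'age': age, 'income': income}, out)
model.compile(optimizer='sgd', loss='mse')

# Features are passed as a dict keyed by the input names.
features = {'age': np.random.rand(32, 1), 'income': np.random.rand(32, 1)}
model.fit(features, np.random.rand(32, 1), epochs=1, verbose=0)
```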
self.built = True @@ -3138,7 +3141,7 @@ class _TrainingTarget(object): def _is_symbolic_tensor(x): - return tensor_util.is_tensor(x) and not isinstance(x, ops.EagerTensor) + return tensor_util.is_tensor(x) def _convert_scipy_sparse_tensor(value, expected_input): diff --git a/tensorflow/python/keras/feature_column/BUILD b/tensorflow/python/keras/feature_column/BUILD index 424e114674d..89eaa6d1dd2 100644 --- a/tensorflow/python/keras/feature_column/BUILD +++ b/tensorflow/python/keras/feature_column/BUILD @@ -8,8 +8,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_py_srcs", srcs = glob(["*.py"]), @@ -24,7 +22,6 @@ py_library( ":dense_features", ":dense_features_v2", ":sequence_feature_column", - "//tensorflow/python/keras:combinations", ], ) @@ -66,7 +63,6 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:tf_export", "//tensorflow/python/feature_column:feature_column_v2", - "//tensorflow/python/keras:combinations", ], ) @@ -91,6 +87,8 @@ tf_py_test( "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/feature_column:feature_column_v2", + "//tensorflow/python/keras", + "//tensorflow/python/keras:combinations", ], ) @@ -114,6 +112,8 @@ tf_py_test( "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/feature_column:feature_column_v2", + "//tensorflow/python/keras", + "//tensorflow/python/keras:combinations", ], ) @@ -171,6 +171,7 @@ py_test( "//tensorflow/python:variables", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/feature_column:feature_column_v2", + "//tensorflow/python/keras:metrics", # Import it here since base_layer didn't import it due to circular dependency. "//tensorflow/python/keras/layers:recurrent", ], ) diff --git a/tensorflow/python/keras/integration_test/BUILD b/tensorflow/python/keras/integration_test/BUILD index 20e1a886d4e..3b4db66ab55 100644 --- a/tensorflow/python/keras/integration_test/BUILD +++ b/tensorflow/python/keras/integration_test/BUILD @@ -11,8 +11,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - tf_py_test( name = "forwardprop_test", srcs = ["forwardprop_test.py"], diff --git a/tensorflow/python/keras/integration_test/gradients_test.py b/tensorflow/python/keras/integration_test/gradients_test.py index 706150ddaa8..e9a851e02a8 100644 --- a/tensorflow/python/keras/integration_test/gradients_test.py +++ b/tensorflow/python/keras/integration_test/gradients_test.py @@ -79,6 +79,34 @@ class GradientsTest(tf.test.TestCase): for g, g_re in zip(grads, grads_re): self.assertAllClose(g, g_re) + def testLSTMBatchJacobian(self): + class HasLSTM(tf.keras.Model): + + def __init__(self): + super(HasLSTM, self).__init__() + self.lstm = tf.keras.layers.LSTM(units=5) + self.dense = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid) + + def call(self, x): + return self.dense(self.lstm(x)) + + m = HasLSTM() + + def jacobian(x): + with tf.GradientTape() as tape: + tape.watch(x) + y = m(x) # pylint: disable=not-callable + return tape.batch_jacobian(y, x) + + inp = tf.nn.l2_normalize(tf.ones([1, 2, 3]), axis=[1, 2]) + eager_result = jacobian(inp) + function_result = tf.function(jacobian)(inp) + self.assertAllClose(eager_result, function_result) + backprop_result, numeric_result = tf.test.compute_gradient( + m, [inp], delta=1e-3) + self.assertAllClose(numeric_result, backprop_result, rtol=1e-2) + self.assertAllClose(tf.reshape(numeric_result, [-1]), + 
tf.reshape(eager_result, [-1]), rtol=1e-2) if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/keras/layers/BUILD b/tensorflow/python/keras/layers/BUILD index 8e8b2fbc387..4846df4f213 100644 --- a/tensorflow/python/keras/layers/BUILD +++ b/tensorflow/python/keras/layers/BUILD @@ -16,8 +16,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_py_srcs", srcs = glob(["*.py"]), @@ -54,7 +52,9 @@ py_library( ":recurrent_v2", ":rnn_cell_wrapper_v2", ":wrappers", + "//tensorflow/python/keras/feature_column", "//tensorflow/python/keras/layers/preprocessing", + "//tensorflow/python/keras/premade", "//tensorflow/python/keras/utils:tf_utils", ], ) @@ -439,8 +439,8 @@ py_library( srcs_version = "PY2AND3", deps = [ ":recurrent", - "//tensorflow/python:rnn_cell", "//tensorflow/python:util", + "//tensorflow/python/keras/layers/legacy_rnn:rnn_cell_wrapper_impl", ], ) @@ -541,10 +541,8 @@ cuda_py_test( python_version = "PY3", shard_count = 4, tags = [ - "no_cuda11", "no_windows_gpu", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -825,7 +823,6 @@ cuda_py_test( tags = [ "no_oss", ], - xla_enable_strict_auto_jit = False, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -844,7 +841,6 @@ cuda_py_test( "no_cuda11", "no_oss", ], - xla_enable_strict_auto_jit = False, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -873,7 +869,6 @@ tf_py_test( size = "small", srcs = ["kernelized_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":layers", "//tensorflow/python:array_ops", @@ -932,6 +927,8 @@ tf_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python/keras", "//tensorflow/python/keras:combinations", + "//tensorflow/python/keras/layers/legacy_rnn:rnn_cell_impl", + "//tensorflow/python/keras/legacy_tf_layers:layers_base", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py index 39e0eaa935f..c383ef2b8da 100644 --- a/tensorflow/python/keras/layers/convolutional.py +++ b/tensorflow/python/keras/layers/convolutional.py @@ -2024,7 +2024,7 @@ class SeparableConv1D(SeparableConv): def call(self, inputs): if self.padding == 'causal': - inputs = array_ops.pad(inputs, self._compute_causal_padding()) + inputs = array_ops.pad(inputs, self._compute_causal_padding(inputs)) if self.data_format == 'channels_last': strides = (1,) + self.strides * 2 + (1,) spatial_start_dim = 1 diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index c0932e7297b..212ce42ddaa 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -30,6 +30,7 @@ import numpy as np from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import monitoring from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -62,6 +63,13 @@ from tensorflow.python.util.tf_export import get_canonical_name_for_symbol from tensorflow.python.util.tf_export import get_symbol_from_name from tensorflow.python.util.tf_export import keras_export +# TODO(b/168039935): track dropout rate to decide whether/how to make a +# dropout rate fastpath. 
+keras_temporary_dropout_rate = monitoring.BoolGauge( + '/tensorflow/api/keras/dropout/temp_rate_is_zero', + 'Temporarily record if Keras dropout layer was created w/' + 'constant rate = 0') + # pylint: disable=g-classes-have-attributes @keras_export('keras.layers.Masking') @@ -186,6 +194,10 @@ class Dropout(Layer): def __init__(self, rate, noise_shape=None, seed=None, **kwargs): super(Dropout, self).__init__(**kwargs) self.rate = rate + if isinstance(rate, (int, float)) and not rate: + keras_temporary_dropout_rate.get_cell().set(True) + else: + keras_temporary_dropout_rate.get_cell().set(False) self.noise_shape = noise_shape self.seed = seed self.supports_masking = True @@ -1292,8 +1304,20 @@ class TFOpLambda(Layer): api_name='keras', add_prefix_to_v1_names=True)) if 'name' not in kwargs: + # Generate a name. + # TFOpLambda layers avoid already-observed names, + # because users cannot easily control the generated names. + # Without this avoidance, users would be more likely to run + # into unavoidable duplicate layer name collisions. + # (For standard layers users could just set `name` when creating the + # layer to work around a collision, but they can't do that for + # auto-generated layers) + if self.symbol: + name = 'tf.' + self.symbol + else: + name = self.function.__name__ kwargs['name'] = K.unique_object_name( - 'tf.' + self.symbol, zero_based=True) + name, zero_based=True, avoid_observed_names=True) kwargs['autocast'] = False # Decorate the function to produce this layer's call method diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py index b7a11d32c71..3f113bf8cbb 100644 --- a/tensorflow/python/keras/layers/core_test.py +++ b/tensorflow/python/keras/layers/core_test.py @@ -559,5 +559,20 @@ class CoreLayersTest(keras_parameterized.TestCase): self.assertAllEqual(np.ones((10, 20)), layer([x, y])) +@keras_parameterized.run_all_keras_modes +class TFOpLambdaTest(keras_parameterized.TestCase): + + def test_non_tf_symbol(self): + def dummy_func(a, b): + return a + b + + layer = core.TFOpLambda(dummy_func) + self.assertIsNone(layer.symbol) + self.assertEqual(layer.name, 'dummy_func') + + with self.assertRaisesRegex(ValueError, 'was generated from .*dummy_func'): + layer.get_config() + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/layers/dense_attention.py b/tensorflow/python/keras/layers/dense_attention.py index ab2912505ef..89750606c17 100644 --- a/tensorflow/python/keras/layers/dense_attention.py +++ b/tensorflow/python/keras/layers/dense_attention.py @@ -49,8 +49,6 @@ class BaseDenseAttention(Layer): flow of information from the future towards the past. dropout: Float between 0 and 1. Fraction of the units to drop for the attention scores. - return_attention_scores: bool, it `True`, returns the attention scores - (after masking and softmax) as an additional output argument. Call Arguments: @@ -69,6 +67,8 @@ class BaseDenseAttention(Layer): `mask==False` do not contribute to the result. training: Python boolean indicating whether the layer should behave in training mode (adding dropout) or in inference mode (no dropout). + return_attention_scores: bool, it `True`, returns the attention scores + (after masking and softmax) as an additional output argument. Output: @@ -77,12 +77,11 @@ class BaseDenseAttention(Layer): `[batch_size, Tq, Tv]`. 
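[Editorial note, not part of the patch: the dense_attention change above moves `return_attention_scores` from the constructor to a call argument. A minimal usage sketch of the updated call signature with the public `tf.keras.layers.Attention` layer; shapes are made up.]

```python
import tensorflow as tf

query = tf.random.normal([2, 4, 8])   # [batch_size, Tq, dim]
value = tf.random.normal([2, 6, 8])   # [batch_size, Tv, dim]

layer = tf.keras.layers.Attention()
# The flag is now passed at call time rather than at construction time.
output, scores = layer([query, value], return_attention_scores=True)
print(output.shape)   # (2, 4, 8)  -- attended values
print(scores.shape)   # (2, 4, 6)  -- attention distribution [batch_size, Tq, Tv]
```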
""" - def __init__(self, causal=False, dropout=0.0, return_attention_scores=False, + def __init__(self, causal=False, dropout=0.0, **kwargs): super(BaseDenseAttention, self).__init__(**kwargs) self.causal = causal self.dropout = dropout - self.return_attention_scores = return_attention_scores self.supports_masking = True def _calculate_scores(self, query, key): @@ -140,7 +139,11 @@ class BaseDenseAttention(Layer): return math_ops.matmul(weights, value), weights # TODO(b/125916026): Consider exposing a __call__ method with named args. - def call(self, inputs, mask=None, training=None): + def call(self, + inputs, + mask=None, + training=None, + return_attention_scores=False): self._validate_call_args(inputs=inputs, mask=mask) q = inputs[0] v = inputs[1] @@ -170,7 +173,7 @@ class BaseDenseAttention(Layer): # Mask of shape [batch_size, Tq, 1]. q_mask = array_ops.expand_dims(q_mask, axis=-1) result *= math_ops.cast(q_mask, dtype=result.dtype) - if self.return_attention_scores: + if return_attention_scores: return result, attention_scores return result @@ -209,7 +212,6 @@ class BaseDenseAttention(Layer): config = { 'causal': self.causal, 'dropout': self.dropout, - 'return_attention_scores': self.return_attention_scores, } base_config = super(BaseDenseAttention, self).get_config() return dict(list(base_config.items()) + list(config.items())) @@ -239,8 +241,6 @@ class Attention(BaseDenseAttention): flow of information from the future towards the past. dropout: Float between 0 and 1. Fraction of the units to drop for the attention scores. - return_attention_scores: bool, it `True`, returns the attention scores - (after masking and softmax) as an additional output argument. Call Arguments: @@ -257,6 +257,8 @@ class Attention(BaseDenseAttention): * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`. If given, will apply the mask such that values at positions where `mask==False` do not contribute to the result. + return_attention_scores: bool, it `True`, returns the attention scores + (after masking and softmax) as an additional output argument. training: Python boolean indicating whether the layer should behave in training mode (adding dropout) or in inference mode (no dropout). @@ -378,8 +380,6 @@ class AdditiveAttention(BaseDenseAttention): flow of information from the future towards the past. dropout: Float between 0 and 1. Fraction of the units to drop for the attention scores. - return_attention_scores: bool, it `True`, returns the attention scores - (after masking and softmax) as an additional output argument. Call Arguments: @@ -398,6 +398,8 @@ class AdditiveAttention(BaseDenseAttention): `mask==False` do not contribute to the result. training: Python boolean indicating whether the layer should behave in training mode (adding dropout) or in inference mode (no dropout). + return_attention_scores: bool, it `True`, returns the attention scores + (after masking and softmax) as an additional output argument. 
Output: diff --git a/tensorflow/python/keras/layers/dense_attention_test.py b/tensorflow/python/keras/layers/dense_attention_test.py index 5df53a8d1fb..8570f41b34f 100644 --- a/tensorflow/python/keras/layers/dense_attention_test.py +++ b/tensorflow/python/keras/layers/dense_attention_test.py @@ -82,8 +82,8 @@ class BaseDenseAttentionTest(test.TestCase, parameterized.TestCase): # => softmax_scores000 = exp(1)/(exp(1) + exp(0)) = 0.73105857863 # softmax_scores001 = exp(0)/(exp(1) + exp(0)) = 0.26894142137 # softmax_scores002 = 0 - expected_scores = np.array( - [[[0.73105857863, 0.26894142137, 0.]]], dtype=np.float32) + expected_scores = np.array([[[0.73105857863, 0.26894142137, 0.]]], + dtype=np.float32) self.assertAllClose(expected_scores, actual_scores) # Expected tensor of shape [1, 1, 1]. # expected000 = 0.73105857863 * 1.6 + 0.26894142137 * 0.7 - 0 * 0.8 @@ -187,8 +187,7 @@ class AttentionTest(test.TestCase, parameterized.TestCase): def test_calculate_scores_multi_dim(self): # Query tensor of shape [1, 2, 4] - q = np.array( - [[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) + q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) # Key tensor of shape [1, 3, 4] k = np.array( [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]], @@ -204,8 +203,8 @@ class AttentionTest(test.TestCase, parameterized.TestCase): # expected010 = 2.*1.5+2.1*1.6+2.2*1.7+2.3*1.8 = 14.24 # expected011 = 2.*2.5+2.1*2.6+2.2*2.7+2.3*2.8 = 22.84 # expected012 = 2.*3.5+2.1*3.6+2.2*3.7+2.3*3.8 = 31.44 - expected = np.array( - [[[7.64, 12.24, 16.84], [14.24, 22.84, 31.44]]], dtype=np.float32) + expected = np.array([[[7.64, 12.24, 16.84], [14.24, 22.84, 31.44]]], + dtype=np.float32) self.assertAllClose(expected, actual) def test_calculate_scores_one_dim_batch_size_two(self): @@ -241,8 +240,7 @@ class AttentionTest(test.TestCase, parameterized.TestCase): def test_shape(self): # Query tensor of shape [1, 2, 4] - q = np.array( - [[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) + q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) # Value tensor of shape [1, 3, 4] v = np.array( [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]], @@ -257,8 +255,7 @@ class AttentionTest(test.TestCase, parameterized.TestCase): def test_shape_with_key(self): # Query tensor of shape [1, 2, 4] - q = np.array( - [[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) + q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) # Value tensor of shape [1, 3, 4] v = np.array( [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]], @@ -342,12 +339,16 @@ class AttentionTest(test.TestCase, parameterized.TestCase): q_mask = np.array([[True, False]], dtype=np.bool_) # Value mask tensor of shape [1, 3] v_mask = np.array([[True, True, False]], dtype=np.bool_) - attention_layer = dense_attention.Attention( - return_attention_scores=return_attention_scores) + attention_layer = dense_attention.Attention() if return_attention_scores: - actual, actual_scores = attention_layer([q, v], mask=[q_mask, v_mask]) + actual, actual_scores = attention_layer( + [q, v], + mask=[q_mask, v_mask], + return_attention_scores=return_attention_scores) else: - actual = attention_layer([q, v], mask=[q_mask, v_mask]) + actual = attention_layer([q, v], + mask=[q_mask, v_mask], + return_attention_scores=return_attention_scores) # Expected scores of shape [1, 2, 3] # scores = [[[1.1*1.6, 1.1*0.7, -1.1*0.8], [-0.5*1.6, -0.5*0.7, 0.5*0.8]]] 
@@ -365,10 +366,9 @@ class AttentionTest(test.TestCase, parameterized.TestCase): # = 0.61063923394 # attention_distribution012 = 0 if return_attention_scores: - expected_scores = np.array( - [[[0.72908792234, 0.27091207765, 0.], - [0.38936076605, 0.61063923394, 0.]]], - dtype=np.float32) + expected_scores = np.array([[[0.72908792234, 0.27091207765, 0.], + [0.38936076605, 0.61063923394, 0.]]], + dtype=np.float32) self.assertAllClose(expected_scores, actual_scores) # Expected tensor of shape [1, 2, 1] with zeros where q_mask == False. # expected000 = 0.72908792234 * 1.6 + 0.27091207765 * 0.7 - 0 * 0.8 @@ -406,12 +406,13 @@ class AttentionTest(test.TestCase, parameterized.TestCase): def test_self_attention_causal(self, return_attention_scores): # Query-value tensor of shape [1, 3, 1] q = np.array([[[0.5], [0.8], [-0.3]]], dtype=np.float32) - attention_layer = dense_attention.Attention( - causal=True, return_attention_scores=return_attention_scores) + attention_layer = dense_attention.Attention(causal=True) if return_attention_scores: - actual, actual_scores = attention_layer([q, q]) + actual, actual_scores = attention_layer( + [q, q], return_attention_scores=return_attention_scores) else: - actual = attention_layer([q, q]) + actual = attention_layer([q, q], + return_attention_scores=return_attention_scores) # Expected scores of shape [1, 3, 3] # scores = [[0.25, 0.4, -0.15], [0.4, 0.64, -0.24], [-0.15, -0.24, 0.09]] @@ -426,8 +427,7 @@ class AttentionTest(test.TestCase, parameterized.TestCase): # = [0.31395396638, 0.28693232061, 0.399113713] if return_attention_scores: expected_scores = np.array( - [[[1., 0., 0.], - [0.44028635073, 0.55971364926, 0.], + [[[1., 0., 0.], [0.44028635073, 0.55971364926, 0.], [0.31395396638, 0.28693232061, 0.399113713]]], dtype=np.float32) self.assertAllClose(expected_scores, actual_scores) @@ -437,8 +437,8 @@ class AttentionTest(test.TestCase, parameterized.TestCase): # = 0.66791409477 # expected020 = 0.31395396638 * 0.5 +0.28693232061 * 0.8 -0.399113713 * 0.3 # = 0.26678872577 - expected = np.array( - [[[0.5], [0.66791409477], [0.26678872577]]], dtype=np.float32) + expected = np.array([[[0.5], [0.66791409477], [0.26678872577]]], + dtype=np.float32) self.assertAllClose(expected, actual) def test_inputs_not_list(self): @@ -501,24 +501,20 @@ class AttentionTest(test.TestCase, parameterized.TestCase): self.assertAllClose([[[0], [1]]], actual) @parameterized.named_parameters( - ('', False, False), - ('use_scale', True, False), - ('return_attention_scores', False, True), + ('', False), + ('use_scale', True), ) - def test_serialization(self, use_scale, return_attention_scores): + def test_serialization(self, use_scale): # Test serialization with use_scale - layer = dense_attention.Attention( - use_scale=use_scale, return_attention_scores=return_attention_scores) + layer = dense_attention.Attention(use_scale=use_scale) config = keras.layers.serialize(layer) new_layer = keras.layers.deserialize(config) self.assertEqual(new_layer.use_scale, use_scale) - self.assertEqual(new_layer.return_attention_scores, return_attention_scores) config = layer.get_config() new_layer = dense_attention.Attention.from_config(config) self.assertEqual(new_layer.use_scale, use_scale) - self.assertEqual(new_layer.return_attention_scores, return_attention_scores) @combinations.generate(combinations.combine(mode=['graph', 'eager'])) @@ -542,8 +538,7 @@ class AdditiveAttentionTest(test.TestCase, parameterized.TestCase): def test_calculate_scores_multi_dim(self): # Query tensor of shape [1, 2, 4] - 
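The attention test changes above track an API change in the layer itself: `return_attention_scores` is no longer a constructor/`get_config` argument and is instead requested per call. A minimal sketch of the call-time usage, assuming the public `tf.keras.layers.Attention` alias behaves like the internal `dense_attention.Attention` exercised in these tests:

    import numpy as np
    import tensorflow as tf

    q = np.array([[[1.1], [-0.5]]], dtype=np.float32)          # query [1, 2, 1]
    v = np.array([[[1.6], [0.7], [-0.8]]], dtype=np.float32)   # value [1, 3, 1]

    layer = tf.keras.layers.Attention()  # no return_attention_scores in __init__
    out, scores = layer([q, v], return_attention_scores=True)  # requested per call
    print(out.shape, scores.shape)  # (1, 2, 1) (1, 2, 3)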
q = np.array( - [[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) + q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) # Key tensor of shape [1, 3, 4] k = np.array( [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]], @@ -562,10 +557,9 @@ class AdditiveAttentionTest(test.TestCase, parameterized.TestCase): # expected011 = 0.5*tanh(2.+2.5) + 0.6*tanh(2.1+2.6) + 0.7*tanh(2.2+2.7) + 0.8*tanh(2.3+2.8) = 2.59964024652 # expected012 = 0.5*tanh(2.+3.5) + 0.6*tanh(2.1+3.6) + 0.7*tanh(2.2+3.7) + 0.8*tanh(2.3+3.8) = 2.59995130916 # pylint:enable=line-too-long - expected = np.array( - [[[2.58044532581, 2.59734317449, 2.59964024652], - [2.59734317449, 2.59964024652, 2.59995130916]]], - dtype=np.float32) + expected = np.array([[[2.58044532581, 2.59734317449, 2.59964024652], + [2.59734317449, 2.59964024652, 2.59995130916]]], + dtype=np.float32) self.assertAllClose(expected, actual) def test_calculate_scores_one_dim_batch_size_two(self): @@ -582,14 +576,13 @@ class AdditiveAttentionTest(test.TestCase, parameterized.TestCase): # Expected tensor of shape [2, 1, 1]. # expected000 = 0.5 * tanh(1.1 + 1.6) = 0.49550372683 # expected100 = 0.5 * tanh(2.1 + 2.6) = 0.49991728277 - expected = np.array( - [[[0.49550372683]], [[0.49991728277]]], dtype=np.float32) + expected = np.array([[[0.49550372683]], [[0.49991728277]]], + dtype=np.float32) self.assertAllClose(expected, actual) def test_shape(self): # Query tensor of shape [1, 2, 4] - q = np.array( - [[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) + q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) # Value tensor of shape [1, 3, 4] v = np.array( [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]], @@ -604,8 +597,7 @@ class AdditiveAttentionTest(test.TestCase, parameterized.TestCase): def test_shape_no_scale(self): # Query tensor of shape [1, 2, 4] - q = np.array( - [[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) + q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) # Value tensor of shape [1, 3, 4] v = np.array( [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]], @@ -620,8 +612,7 @@ class AdditiveAttentionTest(test.TestCase, parameterized.TestCase): def test_shape_with_key(self): # Query tensor of shape [1, 2, 4] - q = np.array( - [[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) + q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32) # Value tensor of shape [1, 3, 4] v = np.array( [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]], @@ -779,8 +770,8 @@ class LowerTriangularMaskTest(test.TestCase, parameterized.TestCase): def test_orthogonal_shape(self): actual = dense_attention._lower_triangular_mask([3, 2]) - expected = np.array( - [[True, False], [True, True], [True, True]], dtype=np.bool_) + expected = np.array([[True, False], [True, True], [True, True]], + dtype=np.bool_) self.assertAllEqual(expected, actual) def test_three_dim(self): diff --git a/tensorflow/python/keras/layers/gru_v2_test.py b/tensorflow/python/keras/layers/gru_v2_test.py index 8760728eb99..db2b0a2e7b9 100644 --- a/tensorflow/python/keras/layers/gru_v2_test.py +++ b/tensorflow/python/keras/layers/gru_v2_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy import os import shutil @@ -639,6 +640,31 @@ class GRUV2Test(keras_parameterized.TestCase): # Make sure it doesn't crash 
with cudnn kernel. model.predict(inputs) + # TODO (b/169895267): test with xla_gpu is disabled. + def test_deepcopy(self): + if not context.executing_eagerly(): + self.skipTest('v2-only test') + original_layer = rnn.GRU(5) + copied_layer = copy.deepcopy(original_layer) + self.assertEqual(copied_layer.units, 5) + self.assertEqual(original_layer.get_config(), original_layer.get_config()) + + # Copy layer before layer call on inputs without weight initialization. + inputs = np.random.normal(size=[32, 10, 8]).astype(np.float32) + original_layer = rnn.GRU(4) + copied_layer = copy.deepcopy(original_layer) + outputs = original_layer(inputs) + copied_outputs = copied_layer(inputs) + self.assertNotAllClose( + self.evaluate(outputs), self.evaluate(copied_outputs)) + + # Copy layer after layer call on inputs with weight initialization. + original_layer = rnn.GRU(4) + outputs = original_layer(inputs) + copied_layer = copy.deepcopy(original_layer) + copied_outputs = copied_layer(inputs) + self.assertAllClose(self.evaluate(outputs), self.evaluate(copied_outputs)) + class GRULayerGradientTapeTest(keras_parameterized.TestCase): diff --git a/tensorflow/python/keras/layers/kernelized_test.py b/tensorflow/python/keras/layers/kernelized_test.py index e6178afaf7d..da4fd16e029 100644 --- a/tensorflow/python/keras/layers/kernelized_test.py +++ b/tensorflow/python/keras/layers/kernelized_test.py @@ -74,19 +74,20 @@ class RandomFourierFeaturesTest(test.TestCase, parameterized.TestCase): @testing_utils.run_v2_only def test_state_saving_and_loading(self): - input_data = np.random.random((1, 2)) - rff_layer = kernel_layers.RandomFourierFeatures(output_dim=10, scale=3.0) - inputs = input_layer.Input((2,)) - outputs = rff_layer(inputs) - model = training.Model(inputs, outputs) - output_data = model.predict(input_data) - temp_dir = self.get_temp_dir() - self.addCleanup(shutil.rmtree, temp_dir) - saved_model_dir = os.path.join(temp_dir, 'rff_model') - model.save(saved_model_dir) - new_model = save.load_model(saved_model_dir) - new_output_data = new_model.predict(input_data) - self.assertAllClose(output_data, new_output_data, atol=1e-4) + with self.cached_session(): + input_data = np.random.random((1, 2)) + rff_layer = kernel_layers.RandomFourierFeatures(output_dim=10, scale=3.0) + inputs = input_layer.Input((2,)) + outputs = rff_layer(inputs) + model = training.Model(inputs, outputs) + output_data = model.predict(input_data) + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + saved_model_dir = os.path.join(temp_dir, 'rff_model') + model.save(saved_model_dir) + new_model = save.load_model(saved_model_dir) + new_output_data = new_model.predict(input_data) + self.assertAllClose(output_data, new_output_data, atol=1e-4) def test_invalid_output_dim(self): with self.assertRaisesRegex( diff --git a/tensorflow/python/keras/layers/legacy_rnn/BUILD b/tensorflow/python/keras/layers/legacy_rnn/BUILD index 203c6d71f92..2da1445bec4 100644 --- a/tensorflow/python/keras/layers/legacy_rnn/BUILD +++ b/tensorflow/python/keras/layers/legacy_rnn/BUILD @@ -2,7 +2,10 @@ # Contains the legacy TF RNN APIs (internal TensorFlow version). 
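The new `test_deepcopy` case above documents that a v2 `GRU` layer supports `copy.deepcopy`: copying before the first call yields independently initialized weights, while copying after the weights are built yields identical outputs. A small sketch of the built case, assuming eager execution as in the test:

    import copy
    import numpy as np
    import tensorflow as tf

    inputs = np.random.normal(size=[32, 10, 8]).astype(np.float32)

    original = tf.keras.layers.GRU(4)
    outputs = original(inputs)          # builds and initializes the weights
    copied = copy.deepcopy(original)    # the copy carries the same weights
    copied_outputs = copied(inputs)

    np.testing.assert_allclose(outputs.numpy(), copied_outputs.numpy())

The same check is added for `LSTM` further down in this diff.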
package( - default_visibility = ["//tensorflow:__subpackages__"], + default_visibility = [ + "//tensorflow:__subpackages__", + "//tensorflow/python/keras:__subpackages__", + ], licenses = ["notice"], # Apache 2.0 ) @@ -37,6 +40,7 @@ py_library( "//tensorflow/python/keras:activations", "//tensorflow/python/keras:initializers", "//tensorflow/python/keras/engine:input_spec", + "//tensorflow/python/keras/legacy_tf_layers:layers_base", "//tensorflow/python/keras/utils:tf_utils", "//tensorflow/python/training/tracking:base", ], diff --git a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py index 94222e27669..91f3aede33a 100644 --- a/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py +++ b/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py @@ -40,8 +40,8 @@ from tensorflow.python.keras import initializers from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import input_spec from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_wrapper_impl +from tensorflow.python.keras.legacy_tf_layers import base as base_layer from tensorflow.python.keras.utils import tf_utils -from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import init_ops diff --git a/tensorflow/python/keras/layers/local.py b/tensorflow/python/keras/layers/local.py index c33c88f3a3d..88c8fede08c 100644 --- a/tensorflow/python/keras/layers/local.py +++ b/tensorflow/python/keras/layers/local.py @@ -29,6 +29,9 @@ from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.engine.input_spec import InputSpec from tensorflow.python.keras.utils import conv_utils from tensorflow.python.keras.utils import tf_utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.util.tf_export import keras_export @@ -774,7 +777,7 @@ def local_conv_matmul(inputs, kernel, kernel_mask, output_shape): kernel = kernel_mask * kernel kernel = make_2d(kernel, split_dim=K.ndim(kernel) // 2) - output_flat = K.math_ops.sparse_matmul(inputs_flat, kernel, b_is_sparse=True) + output_flat = math_ops.sparse_matmul(inputs_flat, kernel, b_is_sparse=True) output = K.reshape(output_flat, [K.shape(output_flat)[0],] + output_shape.as_list()[1:]) return output @@ -807,7 +810,7 @@ def local_conv_sparse_matmul(inputs, kernel, kernel_idxs, kernel_shape, Output (N+2)-D dense tensor with shape `output_shape`. """ inputs_flat = K.reshape(inputs, (K.shape(inputs)[0], -1)) - output_flat = K.sparse_ops.sparse_tensor_dense_mat_mul( + output_flat = sparse_ops.sparse_tensor_dense_mat_mul( kernel_idxs, kernel, kernel_shape, inputs_flat, adjoint_b=True) output_flat_transpose = K.transpose(output_flat) @@ -833,11 +836,11 @@ def make_2d(tensor, split_dim): Tensor of shape `(d0 * ... * d(split_dim-1), d(split_dim) * ... * d(N-1))`. 
""" - shape = K.array_ops.shape(tensor) + shape = array_ops.shape(tensor) in_dims = shape[:split_dim] out_dims = shape[split_dim:] - in_size = K.math_ops.reduce_prod(in_dims) - out_size = K.math_ops.reduce_prod(out_dims) + in_size = math_ops.reduce_prod(in_dims) + out_size = math_ops.reduce_prod(out_dims) - return K.array_ops.reshape(tensor, (in_size, out_size)) + return array_ops.reshape(tensor, (in_size, out_size)) diff --git a/tensorflow/python/keras/layers/local_test.py b/tensorflow/python/keras/layers/local_test.py index 78611317972..2505eed6bab 100644 --- a/tensorflow/python/keras/layers/local_test.py +++ b/tensorflow/python/keras/layers/local_test.py @@ -22,9 +22,13 @@ from absl.testing import parameterized import numpy as np from tensorflow.python import keras +from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras import combinations from tensorflow.python.keras import testing_utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn from tensorflow.python.platform import test from tensorflow.python.training.rmsprop import RMSPropOptimizer @@ -423,9 +427,9 @@ def get_inputs(data_format, filters, height, num_samples, width): def xent(y_true, y_pred): y_true = keras.backend.cast( keras.backend.reshape(y_true, (-1,)), - keras.backend.dtypes_module.int32) + dtypes.int32) - return keras.backend.nn.sparse_softmax_cross_entropy_with_logits( + return nn.sparse_softmax_cross_entropy_with_logits( labels=y_true, logits=y_pred) @@ -496,9 +500,9 @@ def copy_lc_weights_2_to_1(lc_layer_2_from, lc_layer_1_to): lc_2_kernel_masked = keras.backend.permute_dimensions( lc_2_kernel_masked, permutation) - lc_2_kernel_mask = keras.backend.math_ops.not_equal( + lc_2_kernel_mask = math_ops.not_equal( lc_2_kernel_masked, 0) - lc_2_kernel_flat = keras.backend.array_ops.boolean_mask( + lc_2_kernel_flat = array_ops.boolean_mask( lc_2_kernel_masked, lc_2_kernel_mask) lc_2_kernel_reshaped = keras.backend.reshape(lc_2_kernel_flat, lc_layer_1_to.kernel.shape) @@ -516,8 +520,8 @@ def copy_lc_weights_2_to_3(lc_layer_2_from, lc_layer_3_to): lc_2_kernel_masked = keras.layers.local.make_2d( lc_2_kernel_masked, split_dim=keras.backend.ndim(lc_2_kernel_masked) // 2) lc_2_kernel_masked = keras.backend.transpose(lc_2_kernel_masked) - lc_2_kernel_mask = keras.backend.math_ops.not_equal(lc_2_kernel_masked, 0) - lc_2_kernel_flat = keras.backend.array_ops.boolean_mask( + lc_2_kernel_mask = math_ops.not_equal(lc_2_kernel_masked, 0) + lc_2_kernel_flat = array_ops.boolean_mask( lc_2_kernel_masked, lc_2_kernel_mask) lc_2_kernel_flat = keras.backend.get_value(lc_2_kernel_flat) diff --git a/tensorflow/python/keras/layers/lstm_v2_test.py b/tensorflow/python/keras/layers/lstm_v2_test.py index c0f41b4bf7c..c6cb9208357 100644 --- a/tensorflow/python/keras/layers/lstm_v2_test.py +++ b/tensorflow/python/keras/layers/lstm_v2_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy import os import shutil import time @@ -840,6 +841,31 @@ class LSTMV2Test(keras_parameterized.TestCase): # Make sure it doesn't crash with cudnn kernel. model.predict(inputs) + # TODO (b/169895267): test with xla_gpu is disabled. 
+ def test_deepcopy(self): + if not context.executing_eagerly(): + self.skipTest('v2-only test') + original_layer = rnn.LSTM(5) + copied_layer = copy.deepcopy(original_layer) + self.assertEqual(copied_layer.units, 5) + self.assertEqual(original_layer.get_config(), original_layer.get_config()) + + # Copy layer before layer call on inputs without weight initialization. + inputs = np.random.normal(size=[32, 10, 8]).astype(np.float32) + original_layer = rnn.LSTM(4) + copied_layer = copy.deepcopy(original_layer) + outputs = original_layer(inputs) + copied_outputs = copied_layer(inputs) + self.assertNotAllClose( + self.evaluate(outputs), self.evaluate(copied_outputs)) + + # Copy layer after layer call on inputs with weight initialization. + original_layer = rnn.LSTM(4) + outputs = original_layer(inputs) + copied_layer = copy.deepcopy(original_layer) + copied_outputs = copied_layer(inputs) + self.assertAllClose(self.evaluate(outputs), self.evaluate(copied_outputs)) + @keras_parameterized.run_all_keras_modes(config=_config) class LSTMGraphRewriteTest(keras_parameterized.TestCase): diff --git a/tensorflow/python/keras/layers/multi_head_attention.py b/tensorflow/python/keras/layers/multi_head_attention.py index 7ddce8caceb..bda0056fe7e 100644 --- a/tensorflow/python/keras/layers/multi_head_attention.py +++ b/tensorflow/python/keras/layers/multi_head_attention.py @@ -191,11 +191,15 @@ class MultiHeadAttention(Layer): value: Value `Tensor` of shape `[B, S, dim]`. key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use `value` for both `key` and `value`, which is the most common case. - attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention - to certain positions. + attention_mask: a boolean mask of shape `[B, T, S]`, that prevents + attention to certain positions. return_attention_scores: A boolean to indicate whether the output should be attention output if True, or (attention_output, attention_scores) if False. Defaults to False. + training: Python boolean indicating whether the layer should behave in + training mode (adding dropout) or in inference mode (no dropout). + Defaults to either using the training mode of the parent layer/model, + or False (inference) if there is no parent layer. Returns: attention_output: The result of the computation, of shape [B, T, E], @@ -396,7 +400,12 @@ class MultiHeadAttention(Layer): attention_mask, axis=mask_expansion_axes) return self._softmax(attention_scores, attention_mask) - def _compute_attention(self, query, key, value, attention_mask=None): + def _compute_attention(self, + query, + key, + value, + attention_mask=None, + training=None): """Applies Dot-product attention with query, key, value tensors. This function defines the computation inside `call` with projected @@ -409,6 +418,8 @@ class MultiHeadAttention(Layer): value: Projected value `Tensor` of shape `[B, T, N, value_dim]`. attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention to certain positions. + training: Python boolean indicating whether the layer should behave in + training mode (adding dropout) or in inference mode (doing nothing). Returns: attention_output: Multi-headed outputs of attention computation. @@ -428,15 +439,21 @@ class MultiHeadAttention(Layer): # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. 
- attention_scores_dropout = self._dropout_layer(attention_scores) + attention_scores_dropout = self._dropout_layer( + attention_scores, training=training) # `context_layer` = [B, T, N, H] attention_output = special_math_ops.einsum(self._combine_equation, attention_scores_dropout, value) return attention_output, attention_scores - def call(self, query, value, key=None, attention_mask=None, - return_attention_scores=False): + def call(self, + query, + value, + key=None, + attention_mask=None, + return_attention_scores=False, + training=None): if not self._built_from_signature: self._build_from_signature(query=query, value=value, key=key) if key is None: @@ -454,10 +471,9 @@ class MultiHeadAttention(Layer): value = self._value_dense(value) attention_output, attention_scores = self._compute_attention( - query, key, value, attention_mask) + query, key, value, attention_mask, training) attention_output = self._output_dense(attention_output) if return_attention_scores: return attention_output, attention_scores return attention_output - diff --git a/tensorflow/python/keras/layers/multi_head_attention_test.py b/tensorflow/python/keras/layers/multi_head_attention_test.py index a50fefd05ba..4c957b8973b 100644 --- a/tensorflow/python/keras/layers/multi_head_attention_test.py +++ b/tensorflow/python/keras/layers/multi_head_attention_test.py @@ -225,6 +225,22 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): model.predict([query, value, mask_data]), model.predict([query, value, null_mask_data])) + def test_dropout(self): + test_layer = multi_head_attention.MultiHeadAttention( + num_heads=2, key_dim=2, dropout=0.5) + + # Generate data for the input (non-mask) tensors. + from_data = keras.backend.ones(shape=(32, 4, 8)) + to_data = keras.backend.ones(shape=(32, 2, 8)) + train_out = test_layer(from_data, to_data, None, None, None, True) + test_out = test_layer(from_data, to_data, None, None, None, False) + + # Output should be close when not in training mode, + # and should not be close when enabling dropout in training mode. 
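The `MultiHeadAttention` hunks above thread a `training` flag through `call()` and `_compute_attention()` so the dropout applied to the attention scores is only active in training mode; `test_dropout` checks that train-mode and inference-mode outputs differ. A minimal sketch of the new argument, assuming the public `tf.keras.layers.MultiHeadAttention` alias:

    import tensorflow as tf

    mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2, dropout=0.5)
    query = tf.ones((32, 4, 8))
    value = tf.ones((32, 2, 8))

    train_out = mha(query, value, training=True)   # dropout on attention scores
    infer_out = mha(query, value, training=False)  # dropout disabled
    # With dropout=0.5 the two outputs are expected to differ.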
+ self.assertNotAllClose( + keras.backend.eval(train_out), + keras.backend.eval(test_out)) + class SubclassAttention(multi_head_attention.MultiHeadAttention): @@ -235,7 +251,8 @@ class SubclassAttention(multi_head_attention.MultiHeadAttention): query_tensor, key_tensor, value_tensor, - attention_mask=None): + attention_mask=None, + training=None): return value_tensor, None diff --git a/tensorflow/python/keras/layers/noise_test.py b/tensorflow/python/keras/layers/noise_test.py index 96c6d595bdf..47133891aa8 100644 --- a/tensorflow/python/keras/layers/noise_test.py +++ b/tensorflow/python/keras/layers/noise_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.python import keras -from tensorflow.python.keras.backend import dtypes_module +from tensorflow.python.framework import dtypes from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils from tensorflow.python.platform import test @@ -48,7 +48,7 @@ class NoiseLayersTest(keras_parameterized.TestCase): @staticmethod def _make_model(dtype, class_type): - assert dtype in (dtypes_module.float32, dtypes_module.float64) + assert dtype in (dtypes.float32, dtypes.float64) assert class_type in ('gaussian_noise', 'gaussian_dropout', 'alpha_noise') model = keras.Sequential() model.add(keras.layers.Dense(8, input_shape=(32,), dtype=dtype)) @@ -70,22 +70,22 @@ class NoiseLayersTest(keras_parameterized.TestCase): model.train_on_batch(np.zeros((8, 32)), np.zeros((8, 8))) def test_noise_float32(self): - self._train_model(dtypes_module.float32, 'gaussian_noise') + self._train_model(dtypes.float32, 'gaussian_noise') def test_noise_float64(self): - self._train_model(dtypes_module.float64, 'gaussian_noise') + self._train_model(dtypes.float64, 'gaussian_noise') def test_dropout_float32(self): - self._train_model(dtypes_module.float32, 'gaussian_dropout') + self._train_model(dtypes.float32, 'gaussian_dropout') def test_dropout_float64(self): - self._train_model(dtypes_module.float64, 'gaussian_dropout') + self._train_model(dtypes.float64, 'gaussian_dropout') def test_alpha_dropout_float32(self): - self._train_model(dtypes_module.float32, 'alpha_noise') + self._train_model(dtypes.float32, 'alpha_noise') def test_alpha_dropout_float64(self): - self._train_model(dtypes_module.float64, 'alpha_noise') + self._train_model(dtypes.float64, 'alpha_noise') if __name__ == '__main__': diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index 92178bc8fb1..2809cbb0108 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -330,13 +330,13 @@ class BatchNormalizationBase(Layer): # output back to its original shape accordingly. if self._USE_V2_BEHAVIOR: if self.fused is None: - self.fused = (ndims == 4) - elif self.fused and ndims != 4: + self.fused = ndims in (4, 5) + elif self.fused and ndims not in (4, 5): raise ValueError('Batch normalization layers with fused=True only ' - 'support 4D input tensors.') + 'support 4D or 5D input tensors.') else: assert self.fused is not None - self.fused = (ndims == 4 and self._fused_can_be_used()) + self.fused = (ndims in (4, 5) and self._fused_can_be_used()) # TODO(chrisying): fused batch norm is currently not supported for # multi-axis batch norm and by extension virtual batches. 
In some cases, # it might be possible to use fused batch norm but would require reshaping @@ -345,13 +345,22 @@ class BatchNormalizationBase(Layer): # common use case (turning 5D w/ virtual batch to NCHW) if self.fused: - if self.axis == [1]: + if self.axis == [1] and ndims == 4: self._data_format = 'NCHW' - elif self.axis == [3]: + elif self.axis == [1] and ndims == 5: + self._data_format = 'NCDHW' + elif self.axis == [3] and ndims == 4: self._data_format = 'NHWC' + elif self.axis == [4] and ndims == 5: + self._data_format = 'NDHWC' + elif ndims == 5: + # 5D tensors that can be passed in but should not use fused batch norm + # due to unsupported axis. + self.fused = False else: raise ValueError('Unsupported axis, fused batch norm only supports ' - 'axis == [1] or axis == [3]') + 'axis == [1] or axis == [3] for 4D input tensors or ' + 'axis == [1] or axis == [4] for 5D input tensors') axis_to_dim = {x: input_shape.dims[x].value for x in self.axis} for x in axis_to_dim: @@ -961,7 +970,8 @@ def enclosing_xla_context(): return None -@keras_export(v1=['keras.layers.BatchNormalization']) # pylint: disable=missing-docstring +# pylint: disable=missing-docstring +@keras_export(v1=['keras.layers.BatchNormalization']) class BatchNormalization(BatchNormalizationBase): __doc__ = replace_in_base_docstring([(""" diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py index f89a615bee5..79ecc3c3fe1 100644 --- a/tensorflow/python/keras/layers/normalization_test.py +++ b/tensorflow/python/keras/layers/normalization_test.py @@ -66,6 +66,15 @@ class BatchNormalizationTest(keras_parameterized.TestCase): kwargs={'scale': False, 'center': False}, input_shape=(3, 3)) + testing_utils.layer_test( + keras.layers.BatchNormalization, + kwargs={ + 'gamma_initializer': 'ones', + 'beta_initializer': 'ones', + 'moving_mean_initializer': 'zeros', + 'moving_variance_initializer': 'ones' + }, + input_shape=(3, 2, 4, 2)) @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_batchnorm_weights(self): @@ -319,7 +328,7 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase): norm = normalization_v2.BatchNormalization(fused=True) self.assertEqual(norm.fused, True) inp = keras.layers.Input(shape=(4, 4)) - with self.assertRaisesRegex(ValueError, '4D input tensors'): + with self.assertRaisesRegex(ValueError, '4D or 5D input tensors'): norm(inp) def test_updates_in_wrap_function(self): diff --git a/tensorflow/python/keras/layers/normalization_v2.py b/tensorflow/python/keras/layers/normalization_v2.py index afc5c4ba412..ce487d2359a 100644 --- a/tensorflow/python/keras/layers/normalization_v2.py +++ b/tensorflow/python/keras/layers/normalization_v2.py @@ -28,7 +28,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.layers.experimental.SyncBatchNormalization', v1=[]) # pylint: disable=g-classes-have-attributes +# pylint: disable=g-classes-have-attributes +@keras_export('keras.layers.experimental.SyncBatchNormalization', v1=[]) class SyncBatchNormalization(normalization.BatchNormalizationBase): r"""Normalize and scale inputs or activations synchronously across replicas. 
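With the `normalization.py` change above, `fused=True` batch normalization accepts 5D inputs, mapping the channel axis to the NCDHW/NDHWC fused formats (and quietly falling back to the non-fused path for other axes). A sketch of the newly supported case, assuming a build that includes this change:

    import tensorflow as tf

    # 5D channels-last input: (batch, depth, height, width, channels)
    x = tf.random.normal((2, 4, 8, 8, 3))
    bn = tf.keras.layers.BatchNormalization(fused=True)  # previously raised for 5D
    y = bn(x, training=True)
    print(y.shape)  # (2, 4, 8, 8, 3)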
@@ -169,9 +170,14 @@ class SyncBatchNormalization(normalization.BatchNormalizationBase): local_squared_sum = math_ops.reduce_sum(math_ops.square(y), axis=axes, keepdims=True) batch_size = math_ops.cast(array_ops.shape_v2(y)[0], dtypes.float32) - y_sum, y_squared_sum, global_batch_size = ( - replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, [ - local_sum, local_squared_sum, batch_size])) + # TODO(b/163099951): batch the all-reduces once we sort out the ordering + # issue for NCCL. We don't have a mechanism to launch NCCL in the same + # order in each replica nowadays, so we limit NCCL to batch all-reduces. + y_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, local_sum) + y_squared_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, + local_squared_sum) + global_batch_size = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, + batch_size) axes_vals = [(array_ops.shape_v2(y))[i] for i in range(1, len(axes))] multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals), @@ -204,7 +210,8 @@ class SyncBatchNormalization(normalization.BatchNormalizationBase): return (mean, variance) -@keras_export('keras.layers.BatchNormalization', v1=[]) # pylint: disable=missing-docstring +# pylint: disable=missing-docstring +@keras_export('keras.layers.BatchNormalization', v1=[]) class BatchNormalization(normalization.BatchNormalizationBase): __doc__ = normalization.replace_in_base_docstring([ diff --git a/tensorflow/python/keras/layers/ops/BUILD b/tensorflow/python/keras/layers/ops/BUILD index e63bc36f368..ea70ee982db 100644 --- a/tensorflow/python/keras/layers/ops/BUILD +++ b/tensorflow/python/keras/layers/ops/BUILD @@ -8,8 +8,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_py_srcs", srcs = glob(["*.py"]), diff --git a/tensorflow/python/keras/layers/preprocessing/BUILD b/tensorflow/python/keras/layers/preprocessing/BUILD index 34a644b70d2..bca8898fcca 100644 --- a/tensorflow/python/keras/layers/preprocessing/BUILD +++ b/tensorflow/python/keras/layers/preprocessing/BUILD @@ -16,8 +16,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_py_srcs", srcs = glob(["*.py"]), @@ -52,10 +50,18 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:array_ops", "//tensorflow/python:boosted_trees_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:math_ops", - "//tensorflow/python/keras/engine:base_layer", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_spec", + "//tensorflow/python:tf_export", + "//tensorflow/python/keras/engine", + "//tensorflow/python/keras/utils:tf_utils", "//tensorflow/python/ops/parallel_for:control_flow_ops", + "//tensorflow/python/ops/ragged:ragged_functional_ops", + "//third_party/py/numpy", ], ) @@ -66,12 +72,19 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", + "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_spec", - "//tensorflow/python/keras/engine:base_layer", + "//tensorflow/python:tf_export", + "//tensorflow/python/keras/engine", + "//tensorflow/python/keras/utils:tf_utils", + "//tensorflow/python/ops/ragged:ragged_array_ops", + "//tensorflow/python/ops/ragged:ragged_tensor", + "//third_party/py/numpy", ], ) @@ -84,10 +97,18 @@ py_library( deps = [ "//tensorflow/python:dtypes", 
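The `SyncBatchNormalization` hunk above replaces the single batched all-reduce with three separate `all_reduce` calls because, per the TODO, NCCL launch order is not yet guaranteed to match across replicas for batched reductions. In public `tf.distribute` terms the pattern is roughly the following (to be run inside `strategy.run`):

    import tensorflow as tf

    def reduce_batch_norm_stats(local_sum, local_squared_sum, batch_size):
        # One all-reduce per statistic instead of a single batched call.
        ctx = tf.distribute.get_replica_context()
        y_sum = ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_sum)
        y_squared_sum = ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_squared_sum)
        global_batch_size = ctx.all_reduce(tf.distribute.ReduceOp.SUM, batch_size)
        return y_sum, y_squared_sum, global_batch_size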
"//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", + "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python:string_ops", + "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_spec", - "//tensorflow/python/keras/engine:base_layer", + "//tensorflow/python:tensor_util", + "//tensorflow/python:tf_export", + "//tensorflow/python/keras/engine", + "//tensorflow/python/keras/utils:tf_utils", + "//tensorflow/python/ops/ragged:ragged_functional_ops", + "//tensorflow/python/ops/ragged:ragged_tensor", + "//third_party/py/numpy", ], ) @@ -108,9 +129,16 @@ py_library( "//tensorflow/python:stateful_random_ops", "//tensorflow/python:stateless_random_ops", "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", + "//tensorflow/python:tf_export", + "//tensorflow/python:variables", + "//tensorflow/python/compat", + "//tensorflow/python/eager:context", "//tensorflow/python/keras:backend", - "//tensorflow/python/keras/engine:base_layer", - "//tensorflow/python/keras/utils:tf_utils", + "//tensorflow/python/keras/engine", + "//tensorflow/python/keras/engine:input_spec", + "//tensorflow/python/keras/utils:control_flow_util", + "//third_party/py/numpy", ], ) @@ -123,20 +151,14 @@ py_library( srcs_version = "PY2AND3", deps = [ ":table_utils", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", "//tensorflow/python:lookup_ops", "//tensorflow/python:math_ops", - "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_spec", "//tensorflow/python:util", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/keras:backend", - "//tensorflow/python/keras/engine:base_preprocessing_layer", - "//tensorflow/python/ops/ragged", + "//tensorflow/python/keras/engine", + "//third_party/py/numpy", ], ) @@ -150,11 +172,15 @@ py_library( deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tf_export", "//tensorflow/python:util", "//tensorflow/python/keras:backend", - "//tensorflow/python/keras/engine:base_preprocessing_layer", + "//tensorflow/python/keras/engine", + "//third_party/py/numpy", ], ) @@ -169,6 +195,8 @@ py_library( ":index_lookup", ":table_utils", "//tensorflow/python:dtypes", + "//tensorflow/python:tf_export", + "//tensorflow/python/keras/engine", ], ) @@ -180,19 +208,17 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:tensor_spec", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/keras:backend", - "//tensorflow/python/keras/engine:base_preprocessing_layer", - "//tensorflow/python/ops/ragged", + "//tensorflow/python/keras/utils:tf_utils", + "//tensorflow/python/ops/ragged:ragged_functional_ops", + "//tensorflow/python/ops/ragged:ragged_tensor", + "//third_party/py/numpy", ], ) @@ -210,16 +236,18 @@ py_library( "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", 
"//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:math_ops", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_spec", - "//tensorflow/python:util", + "//tensorflow/python:tf_export", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/keras:backend", - "//tensorflow/python/keras/engine:base_preprocessing_layer", - "//tensorflow/python/ops/ragged", + "//tensorflow/python/keras/engine", + "//tensorflow/python/keras/utils:layer_utils", + "//tensorflow/python/keras/utils:tf_utils", + "//tensorflow/python/ops/ragged:ragged_functional_ops", + "//tensorflow/python/ops/ragged:ragged_string_ops", + "//third_party/py/numpy", ], ) @@ -232,19 +260,23 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", + "//tensorflow/python:bincount_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", + "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", - "//tensorflow/python:string_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_spec", + "//tensorflow/python:tf_export", "//tensorflow/python:util", - "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/keras:backend", - "//tensorflow/python/keras/engine:base_preprocessing_layer", - "//tensorflow/python/ops/ragged", + "//tensorflow/python/keras/engine", + "//tensorflow/python/keras/engine:input_spec", + "//tensorflow/python/keras/utils:layer_utils", + "//tensorflow/python/ops/ragged:ragged_tensor", + "//third_party/py/numpy", ], ) @@ -255,8 +287,9 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:array_ops", "//tensorflow/python:math_ops", - "//tensorflow/python/keras:backend", + "//tensorflow/python:platform", "//tensorflow/python/keras/engine:base_layer", ], ) @@ -272,6 +305,8 @@ py_library( ":index_lookup", ":table_utils", "//tensorflow/python:dtypes", + "//tensorflow/python:tf_export", + "//tensorflow/python/keras/engine", ], ) @@ -283,9 +318,11 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:framework_ops", + "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/keras:engine", - "//tensorflow/python/keras/engine:base_preprocessing_layer", + "//tensorflow/python/keras/engine", + "//tensorflow/python/keras/utils:tf_utils", + "//third_party/py/numpy", ], ) @@ -295,6 +332,8 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:client_testlib", + "//tensorflow/python:util", + "//third_party/py/numpy", ], ) @@ -309,9 +348,20 @@ cuda_py_test( ], deps = [ ":category_crossing", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_spec", + "//tensorflow/python/keras", + "//tensorflow/python/keras:testing_utils", + "//tensorflow/python/keras/engine", + "//tensorflow/python/ops/ragged:ragged_factory_ops", + "//tensorflow/python/ops/ragged:ragged_tensor", "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", ], ) @@ -345,8 +395,15 @@ distribute_py_test( ], deps = [ ":category_encoding", + ":preprocessing_test_utils", + "//tensorflow/python:client_testlib", + 
"//tensorflow/python:config", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_test_combinations_lib", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/distribute:tpu_strategy", "//tensorflow/python/keras", ], ) @@ -364,8 +421,15 @@ distribute_py_test( ], deps = [ ":category_crossing", + ":preprocessing_test_utils", + "//tensorflow/python:client_testlib", + "//tensorflow/python:config", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_test_combinations_lib", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/distribute:tpu_strategy", "//tensorflow/python/keras", ], ) @@ -380,7 +444,12 @@ distribute_py_test( ], tpu_tags = ["no_oss"], deps = [ - ":category_crossing", + ":image_preprocessing", + ":preprocessing_test_utils", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_test_combinations_lib", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:strategy_combinations", "//tensorflow/python/keras", @@ -397,6 +466,7 @@ tf_py_test( ":discretization", ":preprocessing_test_utils", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras", "@absl_py//absl/testing:parameterized", ], ) @@ -406,9 +476,15 @@ distribute_py_test( srcs = ["discretization_distribution_test.py"], main = "discretization_distribution_test.py", python_version = "PY3", - tags = ["multi_and_single_gpu"], + tags = [ + "multi_and_single_gpu", + ], deps = [ ":discretization", + ":preprocessing_test_utils", + "//tensorflow/python:client_testlib", + "//tensorflow/python:config", + "//tensorflow/python:framework_test_combinations_lib", "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:strategy_combinations", "//tensorflow/python/keras", @@ -424,8 +500,16 @@ cuda_py_test( deps = [ ":hashing", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_spec", + "//tensorflow/python/keras", + "//tensorflow/python/keras:testing_utils", + "//tensorflow/python/keras/engine", + "//tensorflow/python/ops/ragged:ragged_factory_ops", "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", ], ) @@ -484,7 +568,18 @@ cuda_py_test( deps = [ ":image_preprocessing", "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:image_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:stateful_random_ops", + "//tensorflow/python:stateless_random_ops", + "//tensorflow/python/compat", + "//tensorflow/python/distribute:mirrored_strategy", + "//tensorflow/python/keras", "//tensorflow/python/keras:testing_utils", + "//tensorflow/python/keras/engine", + "//tensorflow/python/keras/utils:generic_utils", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], @@ -499,6 +594,7 @@ tf_py_test( ":normalization", ":preprocessing_test_utils", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras", "@absl_py//absl/testing:parameterized", ], ) @@ -524,11 +620,18 @@ distribute_py_test( srcs = ["normalization_distribution_test.py"], main = 
"normalization_distribution_test.py", python_version = "PY3", - tags = ["no_oss"], + tags = [ + "no_oss", + ], deps = [ ":normalization", + ":preprocessing_test_utils", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_combinations_lib", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/eager:context", "//tensorflow/python/keras", ], ) @@ -580,10 +683,16 @@ distribute_py_test( "no_oss", # b/155502591 ], deps = [ + ":preprocessing_test_utils", ":text_vectorization", + "//tensorflow/python:client_testlib", + "//tensorflow/python:config", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_test_combinations_lib", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:strategy_combinations", - "//tensorflow/python/eager:test", + "//tensorflow/python/eager:context", "//tensorflow/python/keras", ], ) @@ -595,6 +704,7 @@ tf_py_test( deps = [ ":reduction", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras", "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD index 7c8d904acb1..88693c7fa25 100644 --- a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD +++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD @@ -8,8 +8,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_py_srcs", srcs = glob(["*.py"]), @@ -22,8 +20,15 @@ tf_py_test( python_version = "PY3", tfrt_enabled = True, deps = [ - "//tensorflow:tensorflow_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:platform_benchmark", + "//tensorflow/python:random_ops", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/keras/layers/preprocessing:category_encoding", + "@absl_py//absl/flags", ], ) @@ -33,8 +38,16 @@ tf_py_test( python_version = "PY3", tfrt_enabled = True, deps = [ - "//tensorflow:tensorflow_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:platform_benchmark", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/keras/layers/preprocessing:category_crossing", + "@absl_py//absl/flags", ], ) @@ -44,8 +57,16 @@ tf_py_test( python_version = "PY3", tfrt_enabled = True, deps = [ - "//tensorflow:tensorflow_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:platform_benchmark", + "//tensorflow/python:string_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/keras/layers/preprocessing:hashing", + "@absl_py//absl/flags", ], ) @@ -55,8 +76,15 @@ tf_py_test( python_version = "PY3", tfrt_enabled = True, deps = [ - "//tensorflow:tensorflow_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:platform_benchmark", + "//tensorflow/python:tensor_shape", 
+ "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/keras/layers/preprocessing:index_lookup", + "@absl_py//absl/flags", ], ) @@ -66,8 +94,17 @@ tf_py_test( python_version = "PY3", tfrt_enabled = True, deps = [ - "//tensorflow:tensorflow_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_benchmark", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/keras/layers/preprocessing:normalization", + "@absl_py//absl/flags", ], ) @@ -77,7 +114,17 @@ cuda_py_test( python_version = "PY3", tfrt_enabled = True, deps = [ - "//tensorflow:tensorflow_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:image_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_benchmark", + "//tensorflow/python:random_ops", + "//tensorflow/python/compat:v2_compat", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/keras/layers/preprocessing:image_preprocessing", + "@absl_py//absl/flags", ], ) diff --git a/tensorflow/python/keras/layers/preprocessing/category_encoding.py b/tensorflow/python/keras/layers/preprocessing/category_encoding.py index 87112fa3d04..35c4363ccdd 100644 --- a/tensorflow/python/keras/layers/preprocessing/category_encoding.py +++ b/tensorflow/python/keras/layers/preprocessing/category_encoding.py @@ -34,6 +34,7 @@ from tensorflow.python.keras.engine.input_spec import InputSpec from tensorflow.python.keras.utils import layer_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import bincount_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops @@ -128,7 +129,7 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): # We need to call super() before we call _add_state_variable(). combiner = _CategoryEncodingCombiner( - compute_max_element=max_tokens is None, + max_tokens=max_tokens, compute_idf=output_mode == TFIDF) super(CategoryEncoding, self).__init__(combiner=combiner, **kwargs) base_preprocessing_layer._kpl_gauge.get_cell("V2").set("CategoryEncoding") @@ -138,15 +139,6 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): self._sparse = sparse self._called = False - # We are adding these here instead of in build() since they do not depend - # on the input shape at all. - if max_tokens is None: - self.num_elements = self._add_state_variable( - name=_NUM_ELEMENTS_NAME, - shape=(), - dtype=dtypes.int32, - initializer=init_ops.zeros_initializer) - if self._output_mode == TFIDF: # The TF-IDF weight may have a (None,) tensorshape. This creates # a 1D variable with arbitrary shape, which we can assign any weight to @@ -159,6 +151,8 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): else: initializer = init_ops.zeros_initializer + # We are adding these here instead of in build() since they do not depend + # on the input shape at all. 
self.tf_idf_weights = self._add_state_variable( name=_IDF_NAME, shape=tensor_shape.TensorShape((max_tokens,)), @@ -198,16 +192,17 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): if not reset_state: raise ValueError("CategoryEncoding does not support streaming adapts.") - if self._called and self._max_tokens is None: - raise RuntimeError("CategoryEncoding can't be adapted after being called " - "if max_tokens is None.") super(CategoryEncoding, self).adapt(data, reset_state) def _set_state_variables(self, updates): if not self.built: raise RuntimeError("_set_state_variables() must be called after build().") - if self._max_tokens is None: - self.set_num_elements(updates[_NUM_ELEMENTS_NAME]) + if _NUM_ELEMENTS_NAME in updates: + if self._max_tokens is None: + self.set_num_elements(updates[_NUM_ELEMENTS_NAME]) + elif self._max_tokens != updates[_NUM_ELEMENTS_NAME]: + raise RuntimeError("Cannot update states if you construct the layer " + "with `max_tokens`={}".format(self._max_tokens)) if self._output_mode == TFIDF: self.set_tfidf_data(updates[_IDF_NAME]) @@ -248,7 +243,7 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): if self._called: raise RuntimeError("num_elements cannot be changed after the layer is " "called.") - K.set_value(self.num_elements, num_elements) + self._max_tokens = num_elements def set_tfidf_data(self, tfidf_data): tfidf_data = self._convert_to_ndarray(tfidf_data) @@ -278,12 +273,10 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): "or `output_mode='binary'`. Please pass a single input.") self._called = True if self._max_tokens is None: - out_depth = K.get_value(self.num_elements) - if out_depth == 0: - raise RuntimeError( - "If you construct a `CategoryEncoding` layer with " - "`max_tokens=None`, you need to call `adapt()` " - "on it before using it") + raise RuntimeError( + "If you construct a `CategoryEncoding` layer with " + "`max_tokens=None`, you need to call `adapt()` " + "on it before using it") else: out_depth = self._max_tokens @@ -303,11 +296,26 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): return tf_idf_data binary_output = (self._output_mode == BINARY) + if isinstance(inputs, sparse_tensor.SparseTensor): + max_value = math_ops.reduce_max(inputs.values) + min_value = math_ops.reduce_min(inputs.values) + else: + max_value = math_ops.reduce_max(inputs) + min_value = math_ops.reduce_min(inputs) + condition = math_ops.logical_and( + math_ops.greater_equal( + math_ops.cast(out_depth, max_value.dtype), max_value), + math_ops.greater_equal( + min_value, math_ops.cast(0, min_value.dtype))) + control_flow_ops.Assert( + condition, ["Input values must be in the range 0 <= values < max_tokens" + " with max_tokens={}".format(out_depth)]) if self._sparse: result = bincount_ops.sparse_bincount( inputs, weights=count_weights, minlength=out_depth, + maxlength=out_depth, axis=-1, binary_output=binary_output) result = math_ops.cast(result, K.floatx()) @@ -322,6 +330,7 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): inputs, weights=count_weights, minlength=out_depth, + maxlength=out_depth, dtype=K.floatx(), axis=-1, binary_output=binary_output) @@ -350,9 +359,9 @@ class _CategoryEncodingCombiner(base_preprocessing_layer.Combiner): MAX_VALUE_IDX = 0 DOC_ID_IDX = 1 - def __init__(self, compute_max_element=True, compute_idf=False): + def __init__(self, max_tokens=None, compute_idf=False): + self._max_tokens = max_tokens 
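The `CategoryEncoding.call()` changes above clamp the bincount width with `maxlength=out_depth` and assert that every input id lies in the documented range `0 <= value < max_tokens`. The dense path reduces to roughly the following sketch using the public `tf.math.bincount` (the layer reaches the same op via `bincount_ops`):

    import tensorflow as tf

    inputs = tf.constant([[1, 2, 0], [2, 2, 1]])
    out_depth = 3  # max_tokens

    in_range = tf.logical_and(
        tf.reduce_max(inputs) < out_depth,
        tf.reduce_min(inputs) >= 0)
    tf.debugging.Assert(in_range, ["ids must be in [0, max_tokens)"])

    counts = tf.math.bincount(
        inputs, minlength=out_depth, maxlength=out_depth,
        dtype=tf.float32, axis=-1, binary_output=False)
    print(counts.numpy())  # [[1. 1. 1.] [0. 1. 2.]]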
self._compute_idf = compute_idf - self._compute_max_element = compute_max_element def compute(self, values, accumulator=None): """Computes a step in this computation, returning a new accumulator.""" @@ -367,9 +376,10 @@ class _CategoryEncodingCombiner(base_preprocessing_layer.Combiner): element = [element] current_doc_id = accumulator.data[self.DOC_ID_IDX] for value in element: - current_max_value = accumulator.data[self.MAX_VALUE_IDX] - if value > current_max_value: - accumulator.data[self.MAX_VALUE_IDX] = value + if self._max_tokens is None: + current_max_value = accumulator.data[self.MAX_VALUE_IDX] + if value > current_max_value: + accumulator.data[self.MAX_VALUE_IDX] = value if self._compute_idf: doc_count = accumulator.per_doc_count_dict[value] if doc_count["last_doc_id"] != current_doc_id: @@ -389,9 +399,10 @@ class _CategoryEncodingCombiner(base_preprocessing_layer.Combiner): for accumulator in accumulators[1:]: base_accumulator.data[self.DOC_ID_IDX] += accumulator.data[ self.DOC_ID_IDX] - base_accumulator.data[self.MAX_VALUE_IDX] = max( - base_accumulator.data[self.MAX_VALUE_IDX], - accumulator.data[self.MAX_VALUE_IDX]) + if self._max_tokens is None: + base_accumulator.data[self.MAX_VALUE_IDX] = max( + base_accumulator.data[self.MAX_VALUE_IDX], + accumulator.data[self.MAX_VALUE_IDX]) if self._compute_idf: for token, value in accumulator.per_doc_count_dict.items(): # Any newly created token counts in 'base_accumulator''s @@ -431,10 +442,13 @@ class _CategoryEncodingCombiner(base_preprocessing_layer.Combiner): the IDF value for index i. Only returned if `compute_idf` is True. """ data, document_counts = accumulator - max_element = data[self.MAX_VALUE_IDX] + if data[self.MAX_VALUE_IDX] is not None: + max_element = data[self.MAX_VALUE_IDX] + 1 + else: + max_element = self._max_tokens output_dict = {} - if self._compute_max_element: - output_dict[_NUM_ELEMENTS_NAME] = max_element + 1 + if self._max_tokens is None: + output_dict[_NUM_ELEMENTS_NAME] = max_element if self._compute_idf: num_documents = data[self.DOC_ID_IDX] @@ -444,7 +458,7 @@ class _CategoryEncodingCombiner(base_preprocessing_layer.Combiner): # the dict directly for those values gives us meaningful counts (of 0). # However, this also means we can't just extract the values in # document_counts - we need to do a deliberate indexing using range(). 
- doc_counts = [document_counts[i]["count"] for i in range(max_element + 1)] + doc_counts = [document_counts[i]["count"] for i in range(max_element)] idf = self._inverse_document_frequency(doc_counts, num_documents) output_dict[_IDF_NAME] = idf @@ -493,5 +507,8 @@ class _CategoryEncodingCombiner(base_preprocessing_layer.Combiner): per_doc_count_dict = collections.defaultdict(create_default_dict) else: per_doc_count_dict = None - data = [0, 0] + if self._max_tokens is None: + data = [0, 0] + else: + data = [None, 0] return _CategoryEncodingAccumulator(data, per_doc_count_dict) diff --git a/tensorflow/python/keras/layers/preprocessing/category_encoding_test.py b/tensorflow/python/keras/layers/preprocessing/category_encoding_test.py index 7e7f7f32be0..3b2026e5048 100644 --- a/tensorflow/python/keras/layers/preprocessing/category_encoding_test.py +++ b/tensorflow/python/keras/layers/preprocessing/category_encoding_test.py @@ -28,6 +28,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor from tensorflow.python.keras import backend from tensorflow.python.keras import keras_parameterized @@ -267,6 +268,34 @@ class CategoryEncodingInputTest(keras_parameterized.TestCase, model = keras.Model(inputs=input_data, outputs=output_data) _ = model.predict(input_array, steps=1) + def test_dense_oov_input(self): + input_array = constant_op.constant([[1, 2, 3], [4, 3, 4]]) + max_tokens = 3 + expected_output_shape = [None, max_tokens] + encoder_layer = get_layer_class()(max_tokens) + input_data = keras.Input(shape=(3,), dtype=dtypes.int32) + int_data = encoder_layer(input_data) + self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) + model = keras.Model(inputs=input_data, outputs=int_data) + with self.assertRaisesRegex( + errors.InvalidArgumentError, + ".*must be in the range 0 <= values < max_tokens.*"): + _ = model.predict(input_array, steps=1) + + def test_dense_negative(self): + input_array = constant_op.constant([[1, 2, 0], [2, 2, -1]]) + max_tokens = 3 + expected_output_shape = [None, max_tokens] + encoder_layer = get_layer_class()(max_tokens) + input_data = keras.Input(shape=(3,), dtype=dtypes.int32) + int_data = encoder_layer(input_data) + self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) + model = keras.Model(inputs=input_data, outputs=int_data) + with self.assertRaisesRegex( + errors.InvalidArgumentError, + ".*must be in the range 0 <= values < max_tokens.*"): + _ = model.predict(input_array, steps=1) + @keras_parameterized.run_all_keras_modes class CategoryEncodingAdaptTest(keras_parameterized.TestCase, @@ -324,28 +353,6 @@ class CategoryEncodingAdaptTest(keras_parameterized.TestCase, output_dataset = model.predict(input_array, steps=1) self.assertAllEqual(expected_output, output_dataset) - def test_adapt_after_build(self): - vocab_data = np.array([[1, 1, 1, 1, 2, 2, 2, 3, 3, 4]]) - input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]]) - - # pyformat: disable - expected_output = [[0, 1, 1, 1, 0], - [1, 1, 0, 1, 0]] - # pyformat: enable - max_tokens = 5 - expected_output_shape = [None, max_tokens] - - input_data = keras.Input(shape=(None,), dtype=dtypes.int32) - layer = get_layer_class()( - max_tokens=max_tokens, output_mode=category_encoding.BINARY) - int_data = layer(input_data) - layer.adapt(vocab_data) - 
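The new `test_dense_oov_input`/`test_dense_negative` cases above, together with the removal of the post-call adapt restrictions, pin down the user-visible behaviour: with `max_tokens=None` the layer must be adapted before its first call, and with a fixed `max_tokens` ids outside `[0, max_tokens)` are rejected at run time. A short sketch, assuming the TF 2.4-era `tf.keras.layers.experimental.preprocessing.CategoryEncoding` export:

    import numpy as np
    import tensorflow as tf

    CategoryEncoding = tf.keras.layers.experimental.preprocessing.CategoryEncoding

    # max_tokens=None: adapt() learns the output width (here 5 elements, ids 0..4).
    layer = CategoryEncoding(max_tokens=None, output_mode="binary")
    layer.adapt(np.array([[1, 1, 1, 1, 2, 2, 2, 3, 3, 4]]))
    print(layer(np.array([[1, 2, 3, 1]])).shape)  # (1, 5)

    # Fixed max_tokens: out-of-range ids now raise instead of being dropped.
    fixed = CategoryEncoding(max_tokens=3, output_mode="binary")
    try:
        fixed(np.array([[1, 2, 3], [4, 3, 4]]))
    except tf.errors.InvalidArgumentError as err:
        print("rejected:", err)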
self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) - - model = keras.Model(inputs=input_data, outputs=int_data) - output_dataset = model.predict(input_array) - self.assertAllEqual(expected_output, output_dataset) - def test_hard_maximum_set_state_variables_after_build(self): state_variables = {category_encoding._NUM_ELEMENTS_NAME: 5} input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]]) @@ -403,20 +410,10 @@ class CategoryEncodingAdaptTest(keras_parameterized.TestCase, max_tokens=None, output_mode=category_encoding.BINARY) layer.adapt([1, 2]) _ = layer(input_data) - with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"): + with self.assertRaisesRegex( + RuntimeError, ".*'max_tokens' arg must be set to None."): layer.set_num_elements(5) - def test_adapt_after_call_fails(self): - vocab_data = np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 4]) - - input_data = keras.Input(shape=(None,), dtype=dtypes.int32) - layer = get_layer_class()( - max_tokens=None, output_mode=category_encoding.BINARY) - layer.adapt(vocab_data) - _ = layer(input_data) - with self.assertRaisesRegex(RuntimeError, "can't be adapted"): - layer.adapt(vocab_data) - def test_set_state_variables_after_call_fails(self): state_variables = {category_encoding._NUM_ELEMENTS_NAME: 5} @@ -425,7 +422,7 @@ class CategoryEncodingAdaptTest(keras_parameterized.TestCase, max_tokens=None, output_mode=category_encoding.BINARY) layer.adapt([1, 2]) _ = layer(input_data) - with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"): + with self.assertRaisesRegex(RuntimeError, "Cannot update states.*"): layer._set_state_variables(state_variables) @@ -468,7 +465,7 @@ class CategoryEncodingOutputTest(keras_parameterized.TestCase, input_data = keras.Input(shape=(None,), dtype=dtypes.int32) layer = get_layer_class()( max_tokens=None, output_mode=category_encoding.BINARY) - layer.set_weights([np.array(max_tokens)]) + layer.set_num_elements(max_tokens) int_data = layer(input_data) self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) @@ -508,7 +505,7 @@ class CategoryEncodingOutputTest(keras_parameterized.TestCase, input_data = keras.Input(shape=(None,), dtype=dtypes.int32) layer = get_layer_class()( max_tokens=None, output_mode=category_encoding.COUNT) - layer.set_weights([np.array(max_tokens)]) + layer.set_num_elements(max_tokens) int_data = layer(input_data) self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) @@ -604,7 +601,7 @@ class CategoryEncodingModelBuildingTest( weights = [] if max_tokens is None: - weights.append(np.array(5)) + layer.set_num_elements(5) if output_mode == category_encoding.TFIDF: weights.append(tfidf_data) @@ -755,6 +752,20 @@ class CategoryEncodingCombinerTest( RuntimeError, r".*you need to call.*"): _ = layer([1, 2, 3]) + def test_saving_loading(self): + encoder = category_encoding.CategoryEncoding() + encoder.adapt([1, 2, 3]) + model = keras.Sequential([encoder]) + model.save("/tmp/model", save_format="tf") + loaded_model = keras.models.load_model("/tmp/model") + self.assertAllClose(model.predict([[1]]), loaded_model.predict([[1]])) + + def test_serialize(self): + encoder = category_encoding.CategoryEncoding() + encoder.adapt([1, 2, 3]) + model = keras.Sequential([encoder]) + _ = keras.models.clone_model(model) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/keras/layers/preprocessing/normalization.py b/tensorflow/python/keras/layers/preprocessing/normalization.py index b8cf233d780..1de1c30d8e6 100644 --- 
a/tensorflow/python/keras/layers/preprocessing/normalization.py +++ b/tensorflow/python/keras/layers/preprocessing/normalization.py @@ -29,6 +29,7 @@ from tensorflow.python.keras.engine import base_preprocessing_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables from tensorflow.python.util import compat from tensorflow.python.util.tf_export import keras_export @@ -37,7 +38,15 @@ _MEAN_NAME = 'mean' _VARIANCE_NAME = 'variance' -# TODO(momernick): Find a good example of normalization? +def convert_to_ndarray(values): + if isinstance(values, np.ndarray): + return values + elif isinstance(values, ops.Tensor): + return K.get_value(values) + else: + return np.array(values) + + @keras_export('keras.layers.experimental.preprocessing.Normalization', v1=[]) class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer): """Feature-wise normalization of the data. @@ -56,9 +65,17 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer): normalization statistics. By default the last axis, the `features` axis is kept and any `space` or `time` axes are summed. Each element in the the axes that are kept is normalized independently. If `axis` is set to - 'None', the layer will perform scalar normalization (diving the input + 'None', the layer will perform scalar normalization (dividing the input by a single scalar value). The `batch` axis, 0, is always summed over (`axis=0` is not allowed). + mean: The mean value(s) to use during normalization. The passed value(s) + will be broadcast to the shape of the kept axes above; if the value(s) + cannot be broadcast, an error will be raised when this layer's build() + method is called. + variance: The variance value(s) to use during normalization. The passed + value(s) will be broadcast to the shape of the kept axes above; if the + value(s)cannot be broadcast, an error will be raised when this layer's + build() method is called. Examples: @@ -70,12 +87,22 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer): >>> layer.adapt(adapt_data) >>> layer(input_data) + + Pass the mean and variance directly. + + >>> input_data = np.array([[1.], [2.], [3.]], np.float32) + >>> layer = Normalization(mean=3., variance=2.) + >>> layer(input_data) + """ - def __init__(self, axis=-1, dtype=None, **kwargs): + def __init__(self, axis=-1, dtype=None, mean=None, variance=None, **kwargs): # This ensures that if the value of K.floatx() changes after file-loading # time, the dtype value will change to reflect it. dtype = dtype or K.floatx() @@ -97,6 +124,24 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer): self.axis = axis + if isinstance(mean, variables.Variable): + raise ValueError('Normalization does not support passing a Variable ' + 'for the `mean` init arg.') + if isinstance(variance, variables.Variable): + raise ValueError('Normalization does not support passing a Variable ' + 'for the `variance` init arg.') + + if mean is not None and variance is not None: + mean = convert_to_ndarray(mean) + variance = convert_to_ndarray(variance) + elif mean is not None or variance is not None: + raise ValueError( + 'When setting values directly, both `mean` and `variance` ' + 'must be set. 
Got mean: {} and variance: {}'.format(mean, variance)) + + self.mean_val = mean + self.variance_val = variance + def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape).as_list() if len(input_shape) == 1: @@ -144,6 +189,11 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer): super(Normalization, self).build(input_shape) + if (self.mean_val is not None and self.variance_val is not None): + mean_val = self.mean_val * np.ones(mean_and_var_shape) + variance_val = self.variance_val * np.ones(mean_and_var_shape) + self.set_weights([mean_val, variance_val]) + def call(self, inputs): inputs = ops.convert_to_tensor_v2_with_dispatch(inputs) if inputs.shape.rank == 1: @@ -297,8 +347,9 @@ class _NormalizingCombiner(base_preprocessing_layer.Combiner): if (count == 0 and (mean.any() != 0.0 or var.any() != 0.0)): raise RuntimeError( 'The mean and/or variance of a Normalization preprocessing layer ' - "were set without also setting 'count'. If 'count' is not also set," - " 'adapt' cannot be called unless the 'reset_state' arg is True.") + "were set without also setting 'count'. If 'count' is not also set, " + " or was set to 0, 'adapt' cannot be called unless the 'reset_state'" + 'arg is True.') return self._create_accumulator(output[_COUNT_NAME], output[_MEAN_NAME], output[_VARIANCE_NAME]) diff --git a/tensorflow/python/keras/layers/preprocessing/normalization_test.py b/tensorflow/python/keras/layers/preprocessing/normalization_test.py index 69eafc54adc..f629b88f369 100644 --- a/tensorflow/python/keras/layers/preprocessing/normalization_test.py +++ b/tensorflow/python/keras/layers/preprocessing/normalization_test.py @@ -25,12 +25,14 @@ import numpy as np from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils from tensorflow.python.keras.layers.preprocessing import normalization from tensorflow.python.keras.layers.preprocessing import normalization_v1 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils from tensorflow.python.keras.utils.generic_utils import CustomObjectScope +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -216,7 +218,7 @@ class NormalizationTest(keras_parameterized.TestCase, @parameterized.named_parameters(*_get_layer_computation_test_cases()) def test_layer_computation(self, adapt_data, axis, test_data, use_dataset, expected): - input_shape = tuple([None for _ in range(test_data.ndim - 1)]) + input_shape = tuple([test_data.shape[i] for i in range(1, test_data.ndim)]) if use_dataset: # Keras APIs expect batched datasets adapt_data = dataset_ops.Dataset.from_tensor_slices(adapt_data).batch( @@ -235,6 +237,45 @@ class NormalizationTest(keras_parameterized.TestCase, output_data = model.predict(test_data) self.assertAllClose(expected, output_data) + weights = layer.get_weights() + mean = weights[0] + var = weights[1] + + direct_set_layer = cls(axis=axis, mean=mean, variance=var) + input_data = keras.Input(shape=input_shape) + output = direct_set_layer(input_data) + model = keras.Model(input_data, output) + model._run_eagerly = testing_utils.should_run_eagerly() + output_data = model.predict(test_data) + self.assertAllClose(expected, output_data) + + def test_broadcasting_during_direct_setting(self): + cls = get_layer_class() + layer = 
cls(axis=-1, mean=[1.0], variance=[2.0]) + layer.build((None, 2)) + weights = layer.get_weights() + self.assertAllClose([1.0, 1.0], weights[0]) + self.assertAllClose([2.0, 2.0], weights[1]) + + def test_broadcasting_during_direct_setting_with_tensors(self): + cls = get_layer_class() + layer = cls( + axis=-1, + mean=constant_op.constant([1.0]), + variance=constant_op.constant([2.0])) + layer.build((None, 2)) + weights = layer.get_weights() + self.assertAllClose([1.0, 1.0], weights[0]) + self.assertAllClose([2.0, 2.0], weights[1]) + + def test_broadcasting_during_direct_setting_with_variables_fails(self): + cls = get_layer_class() + with self.assertRaisesRegex(ValueError, "passing a Variable"): + _ = cls( + axis=-1, + mean=variables.Variable([1.0]), + variance=variables.Variable([2.0])) + def test_mean_setting_continued_adapt_failure(self): if not context.executing_eagerly(): diff --git a/tensorflow/python/keras/layers/preprocessing/preprocessing_stage.py b/tensorflow/python/keras/layers/preprocessing/preprocessing_stage.py index 5d4b5d2e2cd..6bcae297d51 100644 --- a/tensorflow/python/keras/layers/preprocessing/preprocessing_stage.py +++ b/tensorflow/python/keras/layers/preprocessing/preprocessing_stage.py @@ -161,9 +161,9 @@ class FunctionalPreprocessingStage(base_preprocessing_layer.PreprocessingLayer, >>> stage = FunctionalPreprocessingStage(inputs, outputs) >>> ds = tf.data.Dataset.from_tensor_slices({'x1': tf.ones((4,5)), ... 'x2': tf.ones((4,1))}) - >>> ds.element_spec # Check element_spec - {'x1': TensorSpec(shape=(5,), dtype=tf.float32, name=None), - 'x2': TensorSpec(shape=(1,), dtype=tf.float32, name=None)} + >>> sorted(ds.element_spec.items()) # Check element_spec + [('x1', TensorSpec(shape=(5,), dtype=tf.float32, name=None)), + ('x2', TensorSpec(shape=(1,), dtype=tf.float32, name=None))] >>> stage.adapt(ds) >>> data_np = {'x1': np.ones((4, 5)), 'x2': np.ones((4, 1))} >>> stage.adapt(data_np) diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py index 36e326bdc5c..6449d8afaf7 100644 --- a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py +++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py @@ -155,6 +155,10 @@ class TextVectorization(base_preprocessing_layer.CombinerPreprocessingLayer): the number of unique tokens in the vocabulary is less than max_tokens, resulting in a tensor of shape [batch_size, max_tokens] regardless of vocabulary size. Defaults to True. + vocabulary: An optional list of vocabulary terms, or a path to a text file + containing a vocabulary to load into this layer. The file should contain + one token per line. If the list or file contains the same token multiple + times, an error will be thrown. Example: This example instantiates a TextVectorization layer that lowercases text, @@ -196,6 +200,43 @@ class TextVectorization(base_preprocessing_layer.CombinerPreprocessingLayer): array([[2, 1, 4, 0], [1, 3, 0, 0]]) + Example: + This example instantiates a TextVectorization layer by passing a list + of vocabulary terms to the layer's __init__ method. 
+ + input_array = np.array([["earth", "wind", "and", "fire"], + ["fire", "and", "earth", "michigan"]]) + expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]] + + input_data = keras.Input(shape=(None,), dtype=dtypes.string) + layer = get_layer_class()( + max_tokens=None, + standardize=None, + split=None, + output_mode=text_vectorization.INT, + vocabulary=vocab_data) + int_data = layer(input_data) + model = keras.Model(inputs=input_data, outputs=int_data) + + output_dataset = model.predict(input_array) + >>> vocab_data = ["earth", "wind", "and", "fire"] + >>> max_len = 4 # Sequence length to pad the outputs to. + >>> + >>> # Create the layer, passing the vocab directly. You can also pass the + >>> # vocabulary arg a path to a file containing one vocabulary word per + >>> # line. + >>> vectorize_layer = TextVectorization( + ... max_tokens=max_features, + ... output_mode='int', + ... output_sequence_length=max_len, + ... vocabulary=vocab_data) + >>> + >>> # Because we've passed the vocabulary directly, we don't need to adapt + >>> # the layer - the vocabulary is already set. The vocabulary contains the + >>> # padding token ('') and OOV token ('[UNK]') as well as the passed tokens. + >>> vectorize_layer.get_vocabulary() + ['', '[UNK]', 'earth', 'wind', 'and', 'fire'] + """ # TODO(momernick): Add an examples section to the docstring. @@ -207,6 +248,7 @@ class TextVectorization(base_preprocessing_layer.CombinerPreprocessingLayer): output_mode=INT, output_sequence_length=None, pad_to_max_tokens=True, + vocabulary=None, **kwargs): # This layer only applies to string processing, and so should only have @@ -295,7 +337,7 @@ class TextVectorization(base_preprocessing_layer.CombinerPreprocessingLayer): mask_token = "" if output_mode in [None, INT] else None self._index_lookup_layer = self._get_index_lookup_class()( - max_tokens=max_tokens, mask_token=mask_token) + max_tokens=max_tokens, mask_token=mask_token, vocabulary=vocabulary) # If this layer is configured for string or integer output, we do not # create a vectorization layer (as the output is not vectorized). @@ -404,6 +446,9 @@ class TextVectorization(base_preprocessing_layer.CombinerPreprocessingLayer): return self._index_lookup_layer.get_vocabulary() def get_config(self): + # This does not include the 'vocabulary' arg, since if the vocab was passed + # at init time it's now stored in variable state - we don't need to + # pull it off disk again. 
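# --- Illustration (not part of the original diff) ----------------------------
# Hedged sketch of what the comment above implies: a vocabulary passed at
# __init__ time lives in the layer's lookup-table state, so get_config() does
# not repeat it. The public import path is an assumption about this TF version.
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

layer = TextVectorization(vocabulary=["earth", "wind", "and", "fire"])
config = layer.get_config()
assert "vocabulary" not in config
# The tokens are still available from the layer itself, including the padding
# and OOV entries described in the docstring above:
print(layer.get_vocabulary())  # ['', '[UNK]', 'earth', 'wind', 'and', 'fire']
# -----------------------------------------------------------------------------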
config = { "max_tokens": self._max_tokens, "standardize": self._standardize, diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py index e7f61e94724..a1f9f54a39f 100644 --- a/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py +++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py @@ -690,6 +690,25 @@ class TextVectorizationPreprocessingTest( self.assertAllEqual(expected_output, output) + def test_vocab_setting_via_init(self): + vocab_data = ["earth", "wind", "and", "fire"] + input_array = np.array([["earth", "wind", "and", "fire"], + ["fire", "and", "earth", "michigan"]]) + expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]] + + input_data = keras.Input(shape=(None,), dtype=dtypes.string) + layer = get_layer_class()( + max_tokens=None, + standardize=None, + split=None, + output_mode=text_vectorization.INT, + vocabulary=vocab_data) + int_data = layer(input_data) + model = keras.Model(inputs=input_data, outputs=int_data) + + output_dataset = model.predict(input_array) + self.assertAllEqual(expected_output, output_dataset) + @keras_parameterized.run_all_keras_modes class TextVectorizationDistributionTest( @@ -1042,24 +1061,6 @@ class TextVectorizationOutputTest( with self.assertRaisesRegex(RuntimeError, "vocabulary cannot be changed"): layer.set_vocabulary(vocab_data) - def test_bag_output_soft_maximum_adapt_after_call_fails(self): - vocab_data = np.array([ - "earth", "earth", "earth", "earth", "wind", "wind", "wind", "and", - "and", "fire" - ]) - - input_data = keras.Input(shape=(None,), dtype=dtypes.string) - layer = get_layer_class()( - max_tokens=None, - standardize=None, - split=None, - output_mode=text_vectorization.BINARY, - pad_to_max_tokens=False) - layer.adapt(vocab_data) - _ = layer(input_data) - with self.assertRaisesRegex(RuntimeError, "can't be adapted after being"): - layer.adapt(vocab_data) - def test_bag_output_soft_maximum_set_state_variables_after_call_fails(self): state_variables = { text_vectorization._VOCAB_NAME: ["earth", "wind", "and", "fire"] diff --git a/tensorflow/python/keras/layers/recurrent_v2.py b/tensorflow/python/keras/layers/recurrent_v2.py index 94bc06a067f..8eddc14f5f8 100644 --- a/tensorflow/python/keras/layers/recurrent_v2.py +++ b/tensorflow/python/keras/layers/recurrent_v2.py @@ -67,9 +67,42 @@ _CUDNN_NOT_AVAILABLE_MSG = ('Layer %s will not use cuDNN kernel since it ' def _use_new_code(): - # TODO(b/168313799): Enable when the new codepath doesn't break deepcopy of - # built LSTM layers. - return False + return True + + +# TODO(b/169707691): The wrapper can be removed if TFLite doesn't need to rely +# on supportive attributes from LSTM/GRU. +class _DefunWrapper(object): + """A wrapper with no deep copy of the Defun in LSTM/GRU layer.""" + + def __init__(self, time_major, go_backwards, layer_name): + self.time_major = time_major + self.go_backwards = go_backwards + self.layer_name = layer_name + if self.layer_name not in ['lstm', 'gru']: + raise ValueError('Defun wrapper only applies to LSTM and GRU layer, ' + 'but given {}'.format(self.layer_name)) + # The first two attributes are added to support TFLite use case. 
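# --- Illustration (not part of the original diff) ----------------------------
# Why this wrapper exists: with _use_new_code() now returning True, the defun
# must survive copy.deepcopy of a built LSTM/GRU layer, which the __deepcopy__
# method below provides. A rough usage sketch, assuming a standard TF 2.x
# install:
import copy
import numpy as np
import tensorflow as tf

lstm = tf.keras.layers.LSTM(4)
_ = lstm(np.zeros((2, 3, 5), dtype="float32"))  # build the layer once
clone = copy.deepcopy(lstm)                     # previously failed on the defun
assert clone.units == lstm.units
# -----------------------------------------------------------------------------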
+ supportive_attributes = { + 'time_major': self.time_major, + 'go_backwards': self.go_backwards, + _FUNCTION_API_NAME_ATTRIBUTE: self.layer_name + '_' + str(uuid.uuid4()) + } + if self.layer_name == 'lstm': + layer_func = lstm_with_backend_selection + else: + layer_func = gru_with_backend_selection + + self.defun_layer = function.defun_with_attributes( + layer_func, + attributes=supportive_attributes, + autograph=False) + + def __deepcopy__(self, memo): + new_wrapper = type(self)( + self.time_major, self.go_backwards, self.layer_name) + memo[id(self)] = new_wrapper + return new_wrapper @keras_export('keras.layers.GRUCell', v1=[]) @@ -379,19 +412,8 @@ class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU): else: logging.warn(_CUDNN_NOT_AVAILABLE_MSG % self.name) - # TODO(b/162616551): Remove all compat statements - # This follows b/161915509 and is mainly to test the stateless Case op. if _use_new_code(): - # The first two attributes are added to support TFLite use case. - supportive_attributes = { - 'time_major': time_major, - 'go_backwards': go_backwards, - _FUNCTION_API_NAME_ATTRIBUTE: 'gru_' + str(uuid.uuid4()) - } - self.defun_gru_with_backend_selection = function.defun_with_attributes( - gru_with_backend_selection, - attributes=supportive_attributes, - autograph=False) + self._defun_wrapper = _DefunWrapper(time_major, go_backwards, 'gru') def build(self, input_shape): super(GRU, self).build(input_shape) @@ -489,7 +511,7 @@ class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU): 'zero_output_for_mask': self.zero_output_for_mask } (last_output, outputs, new_h, - runtime) = self.defun_gru_with_backend_selection(**gru_kwargs) + runtime) = self._defun_wrapper.defun_layer(**gru_kwargs) else: gpu_gru_kwargs = { 'inputs': inputs, @@ -1122,17 +1144,7 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM): logging.warn(_CUDNN_NOT_AVAILABLE_MSG % self.name) if _use_new_code(): - # The first two attributes are added to support TFLite use case. - supportive_attributes = { - 'time_major': time_major, - 'go_backwards': go_backwards, - _FUNCTION_API_NAME_ATTRIBUTE: 'lstm_' + str(uuid.uuid4()) - } - - self.defun_lstm_with_backend_selection = function.defun_with_attributes( - lstm_with_backend_selection, - attributes=supportive_attributes, - autograph=False) + self._defun_wrapper = _DefunWrapper(time_major, go_backwards, 'lstm') def call(self, inputs, mask=None, training=None, initial_state=None): # The input should be dense, padded with zeros. 
If a ragged input is fed @@ -1208,7 +1220,7 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM): self.zero_output_for_mask, } (last_output, outputs, new_h, new_c, - runtime) = self.defun_lstm_with_backend_selection(**lstm_kwargs) + runtime) = self._defun_wrapper.defun_layer(**lstm_kwargs) else: gpu_lstm_kwargs = { 'inputs': diff --git a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py index 7b6ade79ae5..652322c810e 100644 --- a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py +++ b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py @@ -26,8 +26,8 @@ from __future__ import print_function from tensorflow.python.keras.layers import recurrent +from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_wrapper_impl from tensorflow.python.keras.utils import tf_inspect -from tensorflow.python.ops import rnn_cell_wrapper_impl from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py index 19ea3dcce90..08de645328e 100644 --- a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py +++ b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py @@ -26,11 +26,11 @@ from tensorflow.python.framework import ops from tensorflow.python.keras import combinations from tensorflow.python.keras import layers from tensorflow.python.keras.layers import rnn_cell_wrapper_v2 +from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_impl +from tensorflow.python.keras.legacy_tf_layers import base as legacy_base_layer from tensorflow.python.keras.utils import generic_utils -from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops -from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test @@ -140,7 +140,7 @@ class RNNCellWrapperTest(test.TestCase, parameterized.TestCase): def testWrapperV2Caller(self, wrapper): """Tests that wrapper V2 is using the LayerRNNCell's caller.""" - with base_layer.keras_style_scope(): + with legacy_base_layer.keras_style_scope(): base_cell = rnn_cell_impl.MultiRNNCell( [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)]) rnn_cell = wrapper(base_cell) diff --git a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py index bde6dd137d7..b89bafde8d2 100644 --- a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py +++ b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py @@ -158,15 +158,15 @@ def _int32_manipulation_at_max_shape_dims_limit(): # of the max tensor size Keras can try inferring values for. 
inputs = keras.Input(batch_size=2, shape=(10,)) batch_size = array_ops.shape(inputs)[0] - num_features = int(keras_tensor._MAX_TENSOR_DIMS / int(inputs.shape[0])) + num_features = int(keras_tensor._MAX_TENSOR_RANK / int(inputs.shape[0])) x = math_ops.range(batch_size * num_features, dtype='int32') - assert x.shape.as_list() == [keras_tensor._MAX_TENSOR_DIMS] + assert x.shape.as_list() == [keras_tensor._MAX_TENSOR_RANK] # Verify that a value was actually inferred for a tensor that *might* # represent the shape, bying checking that a value in # the range appears in the printed inferred value if keras_tensor.keras_tensors_enabled(): - assert str(keras_tensor._MAX_TENSOR_DIMS - 1) in str(x) + assert str(keras_tensor._MAX_TENSOR_RANK - 1) in str(x) x = array_ops.reshape(x, (batch_size, num_features)) x = math_ops.cast(x, dtype='float32') diff --git a/tensorflow/python/keras/legacy_tf_layers/BUILD b/tensorflow/python/keras/legacy_tf_layers/BUILD index df24b618efa..45ccd958db0 100644 --- a/tensorflow/python/keras/legacy_tf_layers/BUILD +++ b/tensorflow/python/keras/legacy_tf_layers/BUILD @@ -5,7 +5,10 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") package( - default_visibility = ["//tensorflow:__subpackages__"], + default_visibility = [ + "//tensorflow:__subpackages__", + "//tensorflow/python/keras:__subpackages__", + ], licenses = ["notice"], # Apache 2.0 ) @@ -144,7 +147,6 @@ tf_py_test( srcs = ["convolutional_test.py"], main = "convolutional_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":convolutional", "//tensorflow/python:array_ops", diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py index c66bc55a9a2..d739c16f116 100644 --- a/tensorflow/python/keras/losses.py +++ b/tensorflow/python/keras/losses.py @@ -183,7 +183,7 @@ class Loss(object): Returns: Loss values with the shape `[batch_size, d0, .. dN-1]`. """ - NotImplementedError('Must be implemented in subclasses.') + raise NotImplementedError('Must be implemented in subclasses.') def _get_reduction(self): """Handles `AUTO` reduction cases and returns the reduction value.""" @@ -1288,7 +1288,7 @@ def mean_squared_logarithmic_error(y_true, y_pred): >>> assert loss.shape == (2,) >>> y_true = np.maximum(y_true, 1e-7) >>> y_pred = np.maximum(y_pred, 1e-7) - >>> assert np.array_equal( + >>> assert np.allclose( ... loss.numpy(), ... np.mean( ... np.square(np.log(y_true + 1.) - np.log(y_pred + 1.)), axis=-1)) @@ -1456,7 +1456,8 @@ def huber(y_true, y_pred, delta=1.0): axis=-1) -@keras_export('keras.losses.log_cosh', 'keras.losses.logcosh') +@keras_export('keras.losses.log_cosh', 'keras.losses.logcosh', + 'keras.metrics.log_cosh', 'keras.metrics.logcosh') @dispatch.add_dispatch_support def log_cosh(y_true, y_pred): """Logarithm of the hyperbolic cosine of the prediction error. 
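Illustration (not part of the original diff): the losses.py hunk above makes the abstract `Loss.call` actually raise `NotImplementedError` and exports `keras.metrics.log_cosh` / `keras.metrics.logcosh` as aliases of the existing loss function. A hedged usage sketch, assuming those aliases are present in the installed TF build:

import numpy as np
import tensorflow as tf

y_true = np.array([[0., 1.], [0., 0.]])
y_pred = np.array([[1., 1.], [0., 0.]])

# The metrics alias resolves to the same function as the loss, so both calls
# should agree element-wise.
loss_val = tf.keras.losses.log_cosh(y_true, y_pred)
metric_val = tf.keras.metrics.log_cosh(y_true, y_pred)
np.testing.assert_allclose(loss_val.numpy(), metric_val.numpy())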
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py index b297063e0d3..1ce86e0f355 100644 --- a/tensorflow/python/keras/metrics_test.py +++ b/tensorflow/python/keras/metrics_test.py @@ -37,6 +37,8 @@ from tensorflow.python.keras import layers from tensorflow.python.keras import metrics from tensorflow.python.keras import Model from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.engine import training as training_mod from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables @@ -2059,6 +2061,47 @@ class CustomMetricsTest(test.TestCase): metric_result = tf_functioned_metric_fn(sum_metric, y_true, y_pred) self.assertAllClose(self.evaluate(metric_result), 10, 1e-2) + def test_metric_not_tracked_as_sublayer_in_layer(self): + + class MyLayer(base_layer.Layer): + + def __init__(self, **kwargs): + super(MyLayer, self).__init__(**kwargs) + self.mean_obj = metrics.Mean(name='my_mean_obj') + + def call(self, x): + self.add_metric( + math_ops.reduce_sum(x), aggregation='mean', name='my_mean_tensor') + self.add_metric(self.mean_obj(x)) + return x + + layer = MyLayer() + x = np.ones((1, 1)) + layer(x) + self.assertLen(list(layer._flatten_layers(include_self=False)), 0) + self.assertLen(layer.metrics, 2) + + def test_metric_not_tracked_as_sublayer_in_model(self): + + class MyModel(training_mod.Model): + + def __init__(self, **kwargs): + super(MyModel, self).__init__(**kwargs) + self.mean_obj = metrics.Mean(name='my_mean_obj') + + def call(self, x): + self.add_metric( + math_ops.reduce_sum(x), aggregation='mean', name='my_mean_tensor') + self.add_metric(self.mean_obj(x)) + return x + + model = MyModel() + x = np.ones((1, 1)) + model(x) + self.assertLen(list(model._flatten_layers(include_self=False)), 0) + self.assertLen(model.layers, 0) + self.assertLen(model.metrics, 2) + def _get_model(compile_metrics): model_layers = [ diff --git a/tensorflow/python/keras/mixed_precision/experimental/BUILD b/tensorflow/python/keras/mixed_precision/experimental/BUILD index d41b3686512..d1bd18f85a5 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/BUILD +++ b/tensorflow/python/keras/mixed_precision/experimental/BUILD @@ -153,6 +153,7 @@ py_test( "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:mirrored_strategy", "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/distribute:test_util", "//tensorflow/python/eager:context", "//tensorflow/python/keras/optimizer_v2", "@absl_py//absl/testing:parameterized", @@ -232,6 +233,8 @@ cuda_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:mixed_precision", "//tensorflow/python:tf2", + "//tensorflow/python/keras", + "//tensorflow/python/keras:combinations", "//tensorflow/python/keras:testing_utils", "//tensorflow/python/keras/optimizer_v2", "@absl_py//absl/testing:parameterized", diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py index 8bb1dd1a2d4..738333039da 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable_test.py @@ -28,6 +28,7 @@ from tensorflow.python.distribute import combinations as ds_combinations from tensorflow.python.distribute import 
distribution_strategy_context as ds_context from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute import test_util from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op @@ -58,7 +59,7 @@ def get_var(val, dtype, name=None): class AutoCastVariableTest(test.TestCase, parameterized.TestCase): def setUp(self): - strategy_combinations.set_virtual_cpus_to_at_least(3) + test_util.set_logical_devices_to_at_least('CPU', 3) super(AutoCastVariableTest, self).setUp() @ds_combinations.generate(maybe_distribute) diff --git a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py index 8eafe725514..dd754e87bb4 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py @@ -53,6 +53,7 @@ from tensorflow.python.keras.saving import save from tensorflow.python.keras.utils import generic_utils from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables +from tensorflow.python.platform import flags from tensorflow.python.platform import test from tensorflow.python.training.experimental import loss_scale as loss_scale_module from tensorflow.python.training.tracking import util as trackable_utils @@ -128,10 +129,11 @@ class KerasLayerTest(keras_parameterized.TestCase): @parameterized.named_parameters(*TESTCASES) def test_mixed_policies_(self, strategy_fn): + strategy = strategy_fn() for dtype in 'float16', 'bfloat16': x = constant_op.constant([1.]) policy_name = 'mixed_' + dtype - with strategy_fn().scope(), policy.policy_scope(policy_name): + with strategy.scope(), policy.policy_scope(policy_name): layer = mp_test_util.MultiplyLayer(assert_type=dtype) self.assertEqual(layer.dtype, dtypes.float32) self.assertEqual(get_layer_policy.get_layer_policy(layer).name, @@ -1006,8 +1008,12 @@ class KerasModelTest(keras_parameterized.TestCase): # The checkpoint and expected values were obtained from the program in # testdata/BUILD. 
- ckpt_dir = test.test_src_dir_path( - 'python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2') + ckpt_dir = os.path.join( + flags.FLAGS['test_srcdir'].value, + 'org_tensorflow/tensorflow/python/keras', + 'mixed_precision/experimental/testdata/lso_ckpt_tf2.2') + # ckpt_dir = test.test_src_dir_path( + # 'python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2') model.load_weights(os.path.join(ckpt_dir, 'ckpt')) model.compile(opt, 'mse', run_eagerly=testing_utils.should_run_eagerly()) model(np.zeros((2, 2))) # Create model weights @@ -1037,9 +1043,13 @@ class KerasModelTest(keras_parameterized.TestCase): self.assertEqual(self.evaluate(opt.loss_scale._num_good_steps), 1) def test_restore_old_saved_model(self): - saved_model_dir = test.test_src_dir_path( - 'python/keras/mixed_precision/experimental/testdata/' - 'lso_savedmodel_tf2.2') + saved_model_dir = os.path.join( + flags.FLAGS['test_srcdir'].value, + 'org_tensorflow/tensorflow/python/keras', + 'mixed_precision/experimental/testdata/lso_savedmodel_tf2.2') + # saved_model_dir = test.test_src_dir_path( + # 'python/keras/mixed_precision/experimental/testdata/' + # 'lso_savedmodel_tf2.2') model = save.load_model(saved_model_dir) expected_kernel = np.array([[9.229685, 10.901115], [10.370763, 9.757362]]) self.assertAllClose(backend.eval(model.weights[0]), expected_kernel) diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py index eb31c647ca3..dd7bf6a682d 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py +++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py @@ -229,6 +229,41 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): >>> opt.apply_gradients([(grad, var)]) # Loss scale is updated here >>> var.numpy() 0.25 + + Hyperparameters can be accessed and set on the LossScaleOptimizer, which will + be delegated to the wrapped optimizer. + + >>> opt = tf.keras.optimizers.Adam(beta_1=0.8, epsilon=1e-5) + >>> lso = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, + ... "dynamic") + >>> opt.beta_1 + 0.8 + >>> lso.beta_1 # Equivalent to `opt.beta_1` + 0.8 + >>> lso.beta_1 = 0.7 # Equivalent to `opt.beta_1 = 0.7` + >>> opt.beta_1 + 0.7 + >>> lso.beta_1 + 0.7 + + However, accessing or setting non-hyperparameters is not delegated to the + LossScaleOptimizer. In an Adam optimizer, `beta_1` is a hyperparameter but + `epsilon` is not, as the Adam optimizer only calls `Optimizer._set_hyper` on + `beta_1`. + + >>> opt.epsilon + 1e-5 + >>> lso.epsilon + Traceback (most recent call last): + ... + AttributeError: 'LossScaleOptimizer' object has no attribute 'epsilon' + >>> lso.epsilon = 1e-4 + >>> opt.epsilon + >>> 1e-5 + + In the above example, despite epsilon being set on the LossScaleOptimizer, the + old epsilon value will still be used when training as epsilon was not set on + the Adam optimizer. """ _HAS_AGGREGATE_GRAD = True @@ -268,9 +303,6 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): backend.track_variable(weight) self._track_trackable(self._loss_scale, 'loss_scale') - # Needed because the superclass's __getattribute__ checks this. - self._hyper = {} - # To support restoring TensorFlow 2.2 checkpoints. 
self._track_trackable(FakeOptimizerForRestoration(self._optimizer), 'base_optimizer') @@ -516,26 +548,47 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): def add_slot(self, var, slot_name, initializer='zeros'): return self._optimizer.add_slot(var, slot_name, initializer) - # For the most part, we only expose methods in the base OptimizerV2, not - # individual subclasses like Adam. However, although "learning_rate" and "lr" - # properties are not part of the base OptimizerV2 class, they are part of most - # subclasses, so we expose them here for convenience. + def __getattribute__(self, name): + try: + return object.__getattribute__(self, name) + except AttributeError as e: + if name == '_optimizer' or name == '_hyper': + # Avoid infinite recursion + raise e - @property - def learning_rate(self): - return self._optimizer.learning_rate + # Delegate hyperparameter accesses to inner optimizer. + if name == 'lr': + name = 'learning_rate' + if name in self._optimizer._hyper: + return self._optimizer._get_hyper(name) + raise e - @learning_rate.setter - def learning_rate(self, lr): - self._optimizer.learning_rate = lr + def __dir__(self): + result = set(super(LossScaleOptimizer, self).__dir__()) + if '_optimizer' in result: + result |= self._optimizer._hyper.keys() + if 'learning_rate' in self._optimizer._hyper.keys(): + result.add('lr') + return list(result) - @property - def lr(self): - return self._optimizer.lr - - @lr.setter - def lr(self, lr): - self._optimizer.lr = lr + def __setattr__(self, name, value): + if name == 'lr': + name = 'learning_rate' + # Delegate setting hyperparameter to inner optimizer if the attribute does + # not exist on the LossScaleOptimizer + try: + # We cannot check for the 'iterations' attribute as it cannot be set after + # it is accessed. + if name != 'iterations': + object.__getattribute__(self, name) + has_attribute = True + except AttributeError: + has_attribute = False + if (name != '_optimizer' and name in self._optimizer._hyper + and not has_attribute): + self._optimizer._set_hyper(name, value) + else: + super(LossScaleOptimizer, self).__setattr__(name, value) # We do not override some OptimizerV2 methods. 
For each, we describe why we do # not delegate them to self._optimizer: diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py index e375fcf557b..fe3a237ef83 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py @@ -382,47 +382,71 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegex(ValueError, r'loss_scale cannot be None'): loss_scale_optimizer.LossScaleOptimizer(opt, None) - @parameterized.named_parameters(*TESTCASES) - def testGettingAndSettingLearningRate(self, strategy_fn): - with self.test_session(), strategy_fn().scope() as strategy: - var = variables.Variable([5.0]) - opt = adam.Adam(learning_rate=1.0) - loss = lambda: var * 2.0 - run_fn = lambda: opt.minimize(loss, [var]) - run_op = strategy.experimental_run(run_fn) + def testHyperParametersExposed(self): + with self.cached_session(): + opt = adam.Adam(learning_rate=1.0, beta_1=0.5, beta_2=0.9) + lso = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic') + # Force hyperparameters to be created + opt.lr # pylint: disable=pointless-statement self.evaluate(variables.global_variables_initializer()) - self._run_if_in_graph_mode(run_op) - lr = self.evaluate(opt.lr) - self.assertEqual(1.0, lr) + self.assertEqual(self.evaluate(lso.beta_1), 0.5) + self.assertIsInstance(lso.beta_1, variables.Variable) + self.assertEqual(self.evaluate(lso.lr), 1.0) + self.assertIs(lso.lr, opt.lr) + self.assertIs(lso.lr, lso.learning_rate) - opt.lr = 2.0 - lr = self.evaluate(opt.lr) - self.assertEqual(2.0, lr) + lso.beta_1 = 0.25 + self.assertEqual(self.evaluate(lso.beta_1), 0.25) + self.assertEqual(self.evaluate(opt.beta_1), 0.25) + self.assertIs(lso.beta_1, opt.beta_1) + opt.beta_1 = 0.75 + self.assertEqual(self.evaluate(lso.beta_1), 0.75) + self.assertEqual(self.evaluate(opt.beta_1), 0.75) + self.assertIs(lso.beta_1, opt.beta_1) + lso.lr = 2.0 + self.assertEqual(self.evaluate(lso.lr), 2.0) + self.assertEqual(self.evaluate(lso.learning_rate), 2.0) + self.assertEqual(self.evaluate(opt.lr), 2.0) + self.assertEqual(self.evaluate(opt.learning_rate), 2.0) + self.assertIs(lso.lr, opt.lr) - self.evaluate(opt.lr.assign(3.0)) - lr = self.evaluate(opt.lr) - self.assertEqual(3.0, lr) + # Test setting attribute that is both attribute on LossScaleOptimizer and + # hyperparameter on wrapped optimizer. + class MyOpt(gradient_descent.SGD): + def __init__(self): + super().__init__() + self._set_hyper('loss_scale', 123.) + + opt = MyOpt() + lso = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic') with self.assertRaises(AttributeError): - opt.not_an_attr += 3 + lso.loss_scale = loss_scale_module.FixedLossScale(2.) def testArbitraryAttributesNotExposed(self): - opt = adam.Adam(learning_rate=1.0) - # Test that Adam has attributes 'epsilon' and 'beta1' - opt.epsilon # pylint: disable=pointless-statement - opt.beta_1 # pylint: disable=pointless-statement - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=10.) - # Test that attributes defined by OptimizerV2 subclasses are not exposed in - # LossScaleOptimizer, and that the error message is sensible. 
+ opt = gradient_descent.SGD() + lso = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic') + self.assertFalse(opt.nesterov) with self.assertRaisesRegex( AttributeError, - "'LossScaleOptimizer' object has no attribute 'epsilon'"): - opt.epsilon # pylint: disable=pointless-statement - with self.assertRaisesRegex( - AttributeError, - "'LossScaleOptimizer' object has no attribute 'beta_1'"): - opt.beta_1 # pylint: disable=pointless-statement + "'LossScaleOptimizer' object has no attribute 'nesterov'"): + lso.nesterov # pylint: disable=pointless-statement + + lso.nesterov = True + self.assertTrue(lso.nesterov) + self.assertFalse(opt.nesterov) + + def testDir(self): + lso = loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD(), + 'dynamic') + dir_result = dir(lso) + self.assertIn('learning_rate', dir_result) # Hyperparameter + self.assertIn('lr', dir_result) # Hyperparameter + self.assertIn('minimize', dir_result) # Attribute + self.assertIn('loss_scale', dir_result) # Attribute + self.assertNotIn('nesterov', dir_result) # Attribute on inner optimizer + self.assertIn('nesterov', dir(lso._optimizer)) def testApplyGradientsGetsUnwrappedTensors(self): # Tests that gradients passed to apply_gradients are not wrapped in a diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy.py b/tensorflow/python/keras/mixed_precision/experimental/policy.py index c4b414d7bf5..33f6562f796 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/policy.py +++ b/tensorflow/python/keras/mixed_precision/experimental/policy.py @@ -32,11 +32,6 @@ from tensorflow.python.training.experimental import mixed_precision_global_state from tensorflow.python.util.tf_export import keras_export -# Default value of certain arguments, indicating the default behavior for -# that argument should be used. -USE_DEFAULT = 'USE_DEFAULT' - - @keras_export('keras.mixed_precision.experimental.Policy', v1=[]) class Policy(object): """A dtype policy for a Keras layer. @@ -293,7 +288,7 @@ class Policy(object): layer would only work if the inputs were float32. """ - def __init__(self, name, loss_scale=USE_DEFAULT): + def __init__(self, name, loss_scale='auto'): """Constructs the policy. The `name` argument determines the compute and variable dtype, the default @@ -310,9 +305,10 @@ class Policy(object): a dynamic loss scale is used. These policies are used for mixed precision training. loss_scale: A `tf.mixed_precision.experimental.LossScale`, an int (which - uses a `FixedLossScale`), or the string "dynamic" (which uses a - `DynamicLossScale`). Defaults to using no loss scaling unless `name` is - "mixed_float16", in which case this defaults to "dynamic". Only + uses a `FixedLossScale`), the string "dynamic" (which uses a + `DynamicLossScale`), or None (which uses no loss scale). Defaults to + `"auto"`. In the `"auto"` case: 1) if `name` is `"mixed_float16"`, then + use `loss_scale="dynamic"`. 2) otherwise, do not use a loss scale. Only `tf.keras.Model`s, not layers, use the loss scale, and it is only used during `Model.fit`, `Model.train_on_batch`, and other similar methods. 
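# --- Illustration (not part of the original diff) ----------------------------
# How the new 'auto' default described above resolves, assuming the
# experimental Policy API and its `loss_scale` property as in TF 2.3/2.4:
import tensorflow as tf

mixed = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
plain = tf.keras.mixed_precision.experimental.Policy('float32')
print(mixed.loss_scale)  # 'auto' -> a DynamicLossScale for mixed_float16
print(plain.loss_scale)  # 'auto' -> None for every other dtype policy
# -----------------------------------------------------------------------------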
""" @@ -324,7 +320,7 @@ class Policy(object): self._name = name self._compute_dtype, self._variable_dtype = self._parse_name(name) - if loss_scale == USE_DEFAULT: + if loss_scale == 'auto': loss_scale = 'dynamic' if name == 'mixed_float16' else None self._using_default_loss_scale = True else: diff --git a/tensorflow/python/keras/optimizer_v2/BUILD b/tensorflow/python/keras/optimizer_v2/BUILD index 11966ce8211..6d774f606b1 100644 --- a/tensorflow/python/keras/optimizer_v2/BUILD +++ b/tensorflow/python/keras/optimizer_v2/BUILD @@ -14,8 +14,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_py_srcs", srcs = glob(["*.py"]), @@ -136,7 +134,7 @@ cuda_py_test( size = "medium", srcs = ["adamax_test.py"], shard_count = 4, - tfrt_enabled = True, + # TODO(b/168527439): invalid resource variable reference on GPU for TFRT. deps = [ ":optimizer_v2", "//tensorflow/python:client_testlib", @@ -158,7 +156,7 @@ cuda_py_test( srcs = ["adadelta_test.py"], shard_count = 4, tags = ["no_rocm"], - tfrt_enabled = True, + # TODO(b/168527439): invalid resource variable reference on GPU for TFRT. deps = [ ":optimizer_v2", "//tensorflow/python:client_testlib", @@ -290,6 +288,7 @@ cuda_py_test( "//tensorflow/python:training_lib", "//tensorflow/python:variables", "//tensorflow/python/eager:context", + "//tensorflow/python/keras", "//tensorflow/python/keras:combinations", ], ) @@ -300,7 +299,7 @@ cuda_py_test( srcs = ["rmsprop_test.py"], shard_count = 2, tags = ["no_rocm"], - tfrt_enabled = True, + # TODO(b/168527439): invalid resource variable reference on GPU for TFRT. deps = [ ":optimizer_v2", "//tensorflow/python:array_ops", diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index 8eb7e1725e1..ca3f1a3a9b1 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -30,6 +30,7 @@ from tensorflow.python.distribute import parameter_server_strategy from tensorflow.python.distribute import values as ds_values from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import monitoring from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util @@ -54,6 +55,9 @@ from tensorflow.python.util import nest from tensorflow.python.util.tf_export import keras_export +keras_optimizers_gauge = monitoring.BoolGauge( + "/tensorflow/api/keras/optimizers", "keras optimizer usage", "method") + _DEFAULT_VALID_DTYPES = frozenset([ dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128 @@ -326,6 +330,9 @@ class OptimizerV2(trackable.Trackable): Raises: ValueError: in case of any invalid argument. 
""" + # Instrument optimizer usages + keras_optimizers_gauge.get_cell(self.__class__.__name__).set(True) + allowed_kwargs = {"clipnorm", "clipvalue", "lr", "decay", "global_clipnorm"} for k in kwargs: if k not in allowed_kwargs: @@ -790,6 +797,14 @@ class OptimizerV2(trackable.Trackable): return self._get_hyper(name) raise e + def __dir__(self): + result = set(super(OptimizerV2, self).__dir__()) + if "_hyper" in result: + result |= self._hyper.keys() + if "learning_rate" in self._hyper.keys(): + result.add("lr") + return list(result) + def __setattr__(self, name, value): """Override setattr to support dynamic hyperparameter setting.""" # Backwards compatibility with Keras optimizers. diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py index 9a8946e34cc..d56ec49ad00 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py @@ -485,6 +485,16 @@ class OptimizerTest(test.TestCase, parameterized.TestCase): lr = self.evaluate(opt.lr) self.assertEqual(4.0, lr) + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + def testDir(self): + opt = gradient_descent.SGD(learning_rate=1.0, momentum=0.1) + dir_result = set(dir(opt)) + self.assertIn('learning_rate', dir_result) # Hyperparameter + self.assertIn('lr', dir_result) # Hyperparameter + self.assertIn('momentum', dir_result) # Hyperparameter + self.assertIn('nesterov', dir_result) # Attribute + self.assertIn('minimize', dir_result) # Attribute + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def testOptimizerWithKerasModel(self): a = input_layer.Input(shape=(3,), name='input_a') diff --git a/tensorflow/python/keras/optimizer_v2/utils.py b/tensorflow/python/keras/optimizer_v2/utils.py index 44958792c10..90f9a4975e7 100644 --- a/tensorflow/python/keras/optimizer_v2/utils.py +++ b/tensorflow/python/keras/optimizer_v2/utils.py @@ -92,7 +92,8 @@ def make_gradient_clipnorm_fn(clipnorm): def gradient_clipnorm_fn(grads_and_vars): if isinstance(distribute_ctx.get_strategy(), - central_storage_strategy.CentralStorageStrategy): + (central_storage_strategy.CentralStorageStrategy, + central_storage_strategy.CentralStorageStrategyV1)): raise ValueError( "`clipnorm` is not supported with `CenteralStorageStrategy`") @@ -112,7 +113,8 @@ def make_global_gradient_clipnorm_fn(clipnorm): def gradient_clipnorm_fn(grads_and_vars): if isinstance(distribute_ctx.get_strategy(), - central_storage_strategy.CentralStorageStrategy): + (central_storage_strategy.CentralStorageStrategy, + central_storage_strategy.CentralStorageStrategyV1)): raise ValueError( "`global_clipnorm` is not supported with `CenteralStorageStrategy`") @@ -132,7 +134,8 @@ def make_gradient_clipvalue_fn(clipvalue): def gradient_clipvalue_fn(grads_and_vars): if isinstance(distribute_ctx.get_strategy(), - central_storage_strategy.CentralStorageStrategy): + (central_storage_strategy.CentralStorageStrategy, + central_storage_strategy.CentralStorageStrategyV1)): raise ValueError( "`clipvalue` is not supported with `CenteralStorageStrategy`") diff --git a/tensorflow/python/keras/premade/BUILD b/tensorflow/python/keras/premade/BUILD index fc85ae04811..9b3852e7ce2 100644 --- a/tensorflow/python/keras/premade/BUILD +++ b/tensorflow/python/keras/premade/BUILD @@ -9,8 +9,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_py_srcs", srcs = glob(["*.py"]), diff --git 
a/tensorflow/python/keras/premade/linear.py b/tensorflow/python/keras/premade/linear.py index 438e3270021..f8ea38fa5f6 100644 --- a/tensorflow/python/keras/premade/linear.py +++ b/tensorflow/python/keras/premade/linear.py @@ -95,7 +95,7 @@ class LinearModel(training.Model): self.kernel_regularizer = regularizers.get(kernel_regularizer) self.bias_regularizer = regularizers.get(bias_regularizer) super(LinearModel, self).__init__(**kwargs) - base_layer.keras_model_gauge.get_cell('Linear').set(True) + base_layer.keras_premade_model_gauge.get_cell('Linear').set(True) def build(self, input_shape): if isinstance(input_shape, dict): diff --git a/tensorflow/python/keras/premade/wide_deep.py b/tensorflow/python/keras/premade/wide_deep.py index edb0124276f..1f70a38cc93 100644 --- a/tensorflow/python/keras/premade/wide_deep.py +++ b/tensorflow/python/keras/premade/wide_deep.py @@ -85,7 +85,7 @@ class WideDeepModel(keras_training.Model): Allowed keyword arguments include `name`. """ super(WideDeepModel, self).__init__(**kwargs) - base_layer.keras_model_gauge.get_cell('WideDeep').set(True) + base_layer.keras_premade_model_gauge.get_cell('WideDeep').set(True) self.linear_model = linear_model self.dnn_model = dnn_model self.activation = activations.get(activation) diff --git a/tensorflow/python/keras/preprocessing/BUILD b/tensorflow/python/keras/preprocessing/BUILD index 80bc1d5459d..04925088f88 100644 --- a/tensorflow/python/keras/preprocessing/BUILD +++ b/tensorflow/python/keras/preprocessing/BUILD @@ -13,8 +13,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_py_srcs", srcs = glob(["*.py"]), @@ -31,6 +29,7 @@ py_library( ":sequence", ":text", ":timeseries", + "//tensorflow/python/keras/utils:all_utils", ], ) diff --git a/tensorflow/python/keras/preprocessing/image_dataset.py b/tensorflow/python/keras/preprocessing/image_dataset.py index 892fbe59709..66164086e7e 100644 --- a/tensorflow/python/keras/preprocessing/image_dataset.py +++ b/tensorflow/python/keras/preprocessing/image_dataset.py @@ -203,6 +203,8 @@ def image_dataset_from_directory(directory, dataset = dataset.batch(batch_size) # Users may need to reference `class_names`. dataset.class_names = class_names + # Include file paths for images as attribute. 
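# --- Illustration (not part of the original diff) ----------------------------
# Usage sketch for the attribute assigned on the next line; "flower_photos/"
# is a placeholder directory, not something referenced by the patch.
import tensorflow as tf

ds = tf.keras.preprocessing.image_dataset_from_directory(
    "flower_photos/", image_size=(180, 180), batch_size=32)
print(ds.class_names)     # already exposed a few lines above
print(ds.file_paths[:3])  # new: the on-disk paths of the images in the dataset
# -----------------------------------------------------------------------------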
+ dataset.file_paths = image_paths return dataset diff --git a/tensorflow/python/keras/preprocessing/image_test.py b/tensorflow/python/keras/preprocessing/image_test.py index 78cec643570..2f22b92f05a 100644 --- a/tensorflow/python/keras/preprocessing/image_test.py +++ b/tensorflow/python/keras/preprocessing/image_test.py @@ -22,6 +22,7 @@ import os import shutil import tempfile +from absl.testing import parameterized import numpy as np from tensorflow.python.data import Dataset @@ -71,18 +72,21 @@ class TestImage(keras_parameterized.TestCase): output = preprocessing_image.smart_resize(test_input, size=(5, 15)) self.assertListEqual(list(output.shape), [5, 15, 3]) + @parameterized.named_parameters( + ('size1', (50, 50)), + ('size2', (10, 10)), + ('size3', (100, 50)), + ('size4', (5, 15))) @testing_utils.run_v2_only - def test_smart_resize_tf_dataset(self): + def test_smart_resize_tf_dataset(self, size): test_input_np = np.random.random((2, 20, 40, 3)) test_ds = Dataset.from_tensor_slices(test_input_np) resize = lambda img: preprocessing_image.smart_resize(img, size=size) - - for size in [(50, 50), (10, 10), (100, 50), (5, 15)]: - test_ds = test_ds.map(resize) - for sample in test_ds.as_numpy_iterator(): - self.assertIsInstance(sample, np.ndarray) - self.assertListEqual(list(sample.shape), [size[0], size[1], 3]) + test_ds = test_ds.map(resize) + for sample in test_ds.as_numpy_iterator(): + self.assertIsInstance(sample, np.ndarray) + self.assertListEqual(list(sample.shape), [size[0], size[1], 3]) def test_smart_resize_errors(self): with self.assertRaisesRegex(ValueError, 'a tuple of 2 integers'): diff --git a/tensorflow/python/keras/protobuf/BUILD b/tensorflow/python/keras/protobuf/BUILD index 644c32bb507..df253252f36 100644 --- a/tensorflow/python/keras/protobuf/BUILD +++ b/tensorflow/python/keras/protobuf/BUILD @@ -9,8 +9,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - tf_proto_library( name = "projector_config_proto", srcs = ["projector_config.proto"], diff --git a/tensorflow/python/keras/saving/BUILD b/tensorflow/python/keras/saving/BUILD index 18aa8990dc0..4b7d0e2fef5 100644 --- a/tensorflow/python/keras/saving/BUILD +++ b/tensorflow/python/keras/saving/BUILD @@ -53,11 +53,13 @@ py_library( "//tensorflow/python:tensor_spec", "//tensorflow/python/eager:def_function", "//tensorflow/python/keras:backend", + "//tensorflow/python/keras:losses", "//tensorflow/python/keras:optimizers", "//tensorflow/python/keras:regularizers", "//tensorflow/python/keras/engine:input_spec", "//tensorflow/python/keras/mixed_precision/experimental:autocast_variable", "//tensorflow/python/keras/utils:engine_utils", + "//tensorflow/python/keras/utils:metrics_utils", "//tensorflow/python/keras/utils:mode_keys", "//tensorflow/python/saved_model", "//tensorflow/python/saved_model/model_utils", @@ -137,7 +139,6 @@ tf_py_test( "no_oss", # TODO(b/119349471): Re-enable "no_windows", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -207,6 +208,7 @@ tf_py_test( python_version = "PY3", tfrt_enabled = True, deps = [ + ":saving", "//tensorflow/python:client_testlib", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", diff --git a/tensorflow/python/keras/saving/hdf5_format_test.py b/tensorflow/python/keras/saving/hdf5_format_test.py index e91b9b323ce..9ed90bc4999 100644 --- a/tensorflow/python/keras/saving/hdf5_format_test.py +++ b/tensorflow/python/keras/saving/hdf5_format_test.py @@ -831,29 +831,30 @@ class 
TestWholeModelSaving(keras_parameterized.TestCase): saved_model_dir = self._save_model_dir() save_format = testing_utils.get_save_format() - model = _make_model() - model.compile( - loss=keras.losses.SparseCategoricalCrossentropy(), - optimizer=optimizers.gradient_descent_v2.SGD(), - metrics=[keras.metrics.SparseCategoricalCrossentropy()]) - x = np.random.normal(size=(32, 4)) - y = np.random.randint(0, 3, size=32) - model.train_on_batch(x, y) - evaluation_results = model.evaluate(x, y) - # Save and reload model. - model.save(saved_model_dir, save_format=save_format) - del model # Prevent misuse. - loaded_model = keras.models.load_model(saved_model_dir) - loaded_model_eval_results = loaded_model.evaluate(x, y) - # Assert all evaluation results are the same. - self.assertAllClose(evaluation_results, loaded_model_eval_results, 1e-9) - # Check correctness of the loss calculation. - self.assertAllGreater(evaluation_results, 0.) - evaluation_results = dict( - zip(loaded_model.metrics_names, evaluation_results)) - self.assertNear( - evaluation_results['sparse_categorical_crossentropy'] + - evaluation_results['custom_loss'], evaluation_results['loss'], 1e-6) + with self.cached_session(): + model = _make_model() + model.compile( + loss=keras.losses.SparseCategoricalCrossentropy(), + optimizer=optimizers.gradient_descent_v2.SGD(), + metrics=[keras.metrics.SparseCategoricalCrossentropy()]) + x = np.random.normal(size=(32, 4)) + y = np.random.randint(0, 3, size=32) + model.train_on_batch(x, y) + evaluation_results = model.evaluate(x, y) + # Save and reload model. + model.save(saved_model_dir, save_format=save_format) + del model # Prevent misuse. + loaded_model = keras.models.load_model(saved_model_dir) + loaded_model_eval_results = loaded_model.evaluate(x, y) + # Assert all evaluation results are the same. + self.assertAllClose(evaluation_results, loaded_model_eval_results, 1e-9) + # Check correctness of the loss calculation. + self.assertAllGreater(evaluation_results, 0.) + evaluation_results = dict( + zip(loaded_model.metrics_names, evaluation_results)) + self.assertNear( + evaluation_results['sparse_categorical_crossentropy'] + + evaluation_results['custom_loss'], evaluation_results['loss'], 1e-6) @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_save_uncompiled_model_with_optimizer(self): diff --git a/tensorflow/python/keras/saving/save_test.py b/tensorflow/python/keras/saving/save_test.py index fb2114a292b..7330a4a189f 100644 --- a/tensorflow/python/keras/saving/save_test.py +++ b/tensorflow/python/keras/saving/save_test.py @@ -33,6 +33,7 @@ from tensorflow.python.keras import losses from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import sequential from tensorflow.python.keras.feature_column import dense_features +from tensorflow.python.keras.feature_column import sequence_feature_column as ksfc from tensorflow.python.keras.layers import core from tensorflow.python.keras.saving import model_config from tensorflow.python.keras.saving import save @@ -190,7 +191,7 @@ class TestSaveModel(test.TestCase, parameterized.TestCase): shape=(None, 1), sparse=True, name='b', dtype='string') } - fc_layer, _ = feature_column_lib.SequenceFeatures(cols)(input_layers) + fc_layer, _ = ksfc.SequenceFeatures(cols)(input_layers) # TODO(tibell): Figure out the right dtype and apply masking. 
# sequence_length_mask = array_ops.sequence_mask(sequence_length) # x = keras.layers.GRU(32)(fc_layer, mask=sequence_length_mask) diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py index 556675d4bb5..e0426e74f6b 100644 --- a/tensorflow/python/keras/saving/saved_model/load.py +++ b/tensorflow/python/keras/saving/saved_model/load.py @@ -23,6 +23,7 @@ import types from tensorflow.python.eager import context from tensorflow.python.eager import function as defun from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_spec from tensorflow.python.keras import backend from tensorflow.python.keras import regularizers @@ -975,8 +976,11 @@ def infer_inputs_from_restored_call_function(fn): TensorSpec of call function inputs. """ def common_spec(x, y): - return tensor_spec.TensorSpec(defun.common_shape(x.shape, y.shape), - x.dtype, x.name) + common_shape = defun.common_shape(x.shape, y.shape) + if isinstance(x, sparse_tensor.SparseTensorSpec): + return sparse_tensor.SparseTensorSpec(common_shape, x.dtype) + return tensor_spec.TensorSpec(common_shape, x.dtype, x.name) + spec = fn.concrete_functions[0].structured_input_signature[0][0] for concrete in fn.concrete_functions[1:]: spec2 = concrete.structured_input_signature[0][0] diff --git a/tensorflow/python/keras/saving/saved_model/metric_serialization.py b/tensorflow/python/keras/saving/saved_model/metric_serialization.py index efe977ec55f..419d02811d5 100644 --- a/tensorflow/python/keras/saving/saved_model/metric_serialization.py +++ b/tensorflow/python/keras/saving/saved_model/metric_serialization.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.keras.saving.saved_model import layer_serialization +from tensorflow.python.keras.utils import generic_utils from tensorflow.python.training.tracking import data_structures @@ -31,7 +32,7 @@ class MetricSavedModelSaver(layer_serialization.LayerSavedModelSaver): def _python_properties_internal(self): metadata = dict( - class_name=type(self.obj).__name__, + class_name=generic_utils.get_registered_name(type(self.obj)), name=self.obj.name, dtype=self.obj.dtype) metadata.update(layer_serialization.get_config(self.obj)) diff --git a/tensorflow/python/keras/saving/saved_model/revive_test.py b/tensorflow/python/keras/saving/saved_model/revive_test.py index 693568d4848..80170dbd708 100644 --- a/tensorflow/python/keras/saving/saved_model/revive_test.py +++ b/tensorflow/python/keras/saving/saved_model/revive_test.py @@ -33,12 +33,15 @@ from tensorflow.python import keras from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.keras import backend from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils from tensorflow.python.keras.saving.saved_model import load as keras_load from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -72,6 +75,36 @@ class SubclassedModelNoConfig(keras.Model): return x +class 
SparseDense(keras.layers.Dense): + + def call(self, inputs): + input_shape = array_ops.stack( + (math_ops.reduce_prod(array_ops.shape(inputs)[:-1]), + self.kernel.shape[0])) + output_shape = array_ops.concat( + (array_ops.shape(inputs)[:-1], [self.kernel.shape[1]]), -1) + x = sparse_ops.sparse_reshape(inputs, input_shape) + return array_ops.reshape( + self.activation( + sparse_ops.sparse_tensor_dense_matmul(x, self.kernel) + self.bias), + output_shape) + + +class SubclassedSparseModelNoConfig(keras.Model): + + def __init__(self, a, b): + super(SubclassedSparseModelNoConfig, self).__init__() + self.a = a + self.shared = CustomLayerNoConfig(a, b) + self.all_layers = [SparseDense(4)] + + def call(self, inputs): + x = inputs + for layer in self.all_layers: + x = layer(x) + return self.shared(x + self.a) + + class SubclassedModelWithConfig(SubclassedModelNoConfig): def get_config(self): @@ -171,6 +204,9 @@ class ReviveTestBase(keras_parameterized.TestCase): self.evaluate(revived.weights)) input_arr = constant_op.constant( np.random.random((2, 2, 3)).astype(np.float32)) + if isinstance(revived._saved_model_inputs_spec, + sparse_tensor.SparseTensorSpec): + input_arr = sparse_ops.from_dense(input_arr) self.assertAllClose(model(input_arr), revived(input_arr)) self.assertAllClose(sum(model.losses), sum(revived.losses)) @@ -276,6 +312,15 @@ class TestModelRevive(ReviveTestBase): revived = keras_load.load(self.path) self._assert_revived_correctness(model, revived) + def test_revive_subclassed_with_sparse_model(self): + model = SubclassedSparseModelNoConfig(1., 2.) + # Run data through the Model to create save spec and weights. + x = sparse_ops.from_dense(np.ones((10, 2, 3), dtype=np.float32)) + model.predict(x, batch_size=10) + model.save(self.path, save_format='tf') + revived = keras_load.load(self.path) + self._assert_revived_correctness(model, revived) + def test_revive_sequential_inputs(self): model = keras.models.Sequential([ keras.Input((None,), dtype=dtypes.string), diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py index 9615fef54b9..3cb960d0971 100644 --- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py +++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py @@ -965,33 +965,34 @@ class MetricTest(test.TestCase, parameterized.TestCase): num_tensor_args, shape=(1, 5), test_sample_weight=True): - tf_save.save(metric, save_dir) - loaded = keras_load.load(save_dir) - self.evaluate([v.initializer for v in loaded.variables]) - self.assertEqual(metric.name, loaded.name) - self.assertEqual(metric.dtype, loaded.dtype) + with self.cached_session(): + tf_save.save(metric, save_dir) + loaded = keras_load.load(save_dir) + self.evaluate([v.initializer for v in loaded.variables]) + self.assertEqual(metric.name, loaded.name) + self.assertEqual(metric.dtype, loaded.dtype) - inputs = self.generate_inputs(num_tensor_args, shape) - actual = self.evaluate(metric(*inputs)) - self.assertAllClose(actual, loaded(*inputs)) - self.assertAllClose(metric.variables, loaded.variables) - - # Test with separate calls to update state and result. - inputs = self.generate_inputs(num_tensor_args, shape) - self.evaluate(metric.update_state(*inputs)) - self.evaluate(loaded.update_state(*inputs)) - actual = self.evaluate(metric.result()) - self.assertAllClose(actual, loaded.result()) - - if test_sample_weight: - # Test with sample weights input. 
inputs = self.generate_inputs(num_tensor_args, shape) - sample_weight = self.generate_inputs(1, [])[0] - inputs.append(sample_weight) - actual = self.evaluate(metric(*inputs)) self.assertAllClose(actual, loaded(*inputs)) - return loaded + self.assertAllClose(metric.variables, loaded.variables) + + # Test with separate calls to update state and result. + inputs = self.generate_inputs(num_tensor_args, shape) + self.evaluate(metric.update_state(*inputs)) + self.evaluate(loaded.update_state(*inputs)) + actual = self.evaluate(metric.result()) + self.assertAllClose(actual, loaded.result()) + + if test_sample_weight: + # Test with sample weights input. + inputs = self.generate_inputs(num_tensor_args, shape) + sample_weight = self.generate_inputs(1, [])[0] + inputs.append(sample_weight) + + actual = self.evaluate(metric(*inputs)) + self.assertAllClose(actual, loaded(*inputs)) + return loaded @parameterized.named_parameters([ ('mean', keras.metrics.Mean, 1, (1, 5)), @@ -1054,6 +1055,33 @@ class MetricTest(test.TestCase, parameterized.TestCase): num_tensor_args, test_sample_weight=False) + def test_registered_custom_metric(self): + + @generic_utils.register_keras_serializable('Testing') + class CustomMeanMetric(keras.metrics.Mean): + + def update_state(self, *args): # pylint: disable=useless-super-delegation + # Sometimes built-in metrics return an op in update_state. Custom + # metrics don't support returning ops, so wrap the update_state method + # while returning nothing. + super(CustomMeanMetric, self).update_state(*args) + + with self.cached_session(): + metric = CustomMeanMetric() + save_dir = self._save_model_dir('first_save') + self.evaluate([v.initializer for v in metric.variables]) + loaded = self._test_metric_save_and_load( + metric, + save_dir, + num_tensor_args=1, + test_sample_weight=False) + + self._test_metric_save_and_load( + loaded, + self._save_model_dir('second_save'), + num_tensor_args=1, + test_sample_weight=False) + def test_custom_metric_wrapped_call(self): class NegativeMean(keras.metrics.Mean): diff --git a/tensorflow/python/keras/saving/saving_utils_test.py b/tensorflow/python/keras/saving/saving_utils_test.py index 49b6fde9ec7..85f421a8507 100644 --- a/tensorflow/python/keras/saving/saving_utils_test.py +++ b/tensorflow/python/keras/saving/saving_utils_test.py @@ -40,6 +40,7 @@ from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import sequential +from tensorflow.python.keras.feature_column import dense_features from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.keras.saving import saving_utils from tensorflow.python.ops import array_ops @@ -156,7 +157,7 @@ class TraceModelCallTest(keras_parameterized.TestCase): @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_trace_features_layer(self): columns = [feature_column_lib.numeric_column('x')] - model = sequential.Sequential([feature_column_lib.DenseFeatures(columns)]) + model = sequential.Sequential([dense_features.DenseFeatures(columns)]) model_input = {'x': constant_op.constant([[1.]])} model.predict(model_input, steps=1) fn = saving_utils.trace_model_call(model) @@ -166,7 +167,7 @@ class TraceModelCallTest(keras_parameterized.TestCase): feature_column_lib.numeric_column('x'), feature_column_lib.numeric_column('y') ] - model = sequential.Sequential([feature_column_lib.DenseFeatures(columns)]) + model = 
sequential.Sequential([dense_features.DenseFeatures(columns)]) model_input = {'x': constant_op.constant([[1.]]), 'y': constant_op.constant([[2.]])} model.predict(model_input, steps=1) diff --git a/tensorflow/python/keras/tests/BUILD b/tensorflow/python/keras/tests/BUILD index fb6027a8c21..0ac8abcbe18 100644 --- a/tensorflow/python/keras/tests/BUILD +++ b/tensorflow/python/keras/tests/BUILD @@ -15,8 +15,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_py_srcs", srcs = glob(["*.py"]), @@ -94,7 +92,6 @@ tf_py_test( srcs = ["convert_to_constants_test.py"], python_version = "PY3", shard_count = 4, - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -106,6 +103,8 @@ tf_py_test( "//tensorflow/python:tensor_spec", "//tensorflow/python:util", "//tensorflow/python/eager:def_function", + "//tensorflow/python/keras", + "//tensorflow/python/keras:testing_utils", "//tensorflow/python/saved_model:load", "//tensorflow/python/saved_model:save", "//tensorflow/python/training/tracking", @@ -140,6 +139,7 @@ tf_py_test( "//tensorflow/python:saver", "//tensorflow/python:session", "//tensorflow/python:tf_optimizer", + "//tensorflow/python/keras", ], ) @@ -157,6 +157,8 @@ tf_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:nn_ops", "//tensorflow/python/keras", + "//tensorflow/python/keras/layers/legacy_rnn:rnn_cell_impl", + "//tensorflow/python/keras/legacy_tf_layers:layers_base", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], @@ -178,7 +180,6 @@ tf_py_test( srcs = ["model_architectures_test.py"], python_version = "PY3", shard_count = 16, - tfrt_enabled = True, deps = [ ":model_architectures", "//tensorflow/python/keras", @@ -295,6 +296,7 @@ cuda_py_test( "//tensorflow/python:util", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", + "//tensorflow/python/keras", "//tensorflow/python/keras:combinations", ], ) @@ -311,6 +313,7 @@ cuda_py_test( "//tensorflow/python:lib", "//tensorflow/python:platform", "//tensorflow/python:summary_ops_v2", + "//tensorflow/python/keras:testing_utils", "//tensorflow/python/keras/engine", "//tensorflow/python/keras/layers:core", ], @@ -330,6 +333,8 @@ tf_py_test( "//tensorflow/python:tensor_spec", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:function", + "//tensorflow/python/keras", + "//tensorflow/python/keras:metrics", "//tensorflow/python/keras/layers:core", "//tensorflow/python/keras/optimizer_v2", "//tensorflow/python/saved_model:save", @@ -371,6 +376,7 @@ tf_py_test( "//tensorflow/python:constant_op", "//tensorflow/python:framework_test_lib", "//tensorflow/python:util", + "//tensorflow/python/keras", "//tensorflow/python/keras:combinations", "//tensorflow/python/keras/engine", "//tensorflow/python/keras/layers:core", @@ -408,6 +414,7 @@ tf_py_test( "//tensorflow/python:variables", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", + "//tensorflow/python/keras", "//tensorflow/python/keras:combinations", "//tensorflow/python/keras/engine", "//tensorflow/python/keras/layers:core", @@ -447,6 +454,7 @@ tf_py_test( "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", + "//tensorflow/python/keras", "//tensorflow/python/keras:combinations", "//tensorflow/python/keras/engine", "//tensorflow/python/keras/layers:core", @@ -481,6 +489,7 @@ tf_py_test( "//tensorflow/python/eager:context", 
"//tensorflow/python/eager:def_function", "//tensorflow/python/eager:test", + "//tensorflow/python/keras", "//tensorflow/python/keras:combinations", "//tensorflow/python/keras/engine", "//tensorflow/python/keras/layers:core", diff --git a/tensorflow/python/keras/tests/integration_test.py b/tensorflow/python/keras/tests/integration_test.py index 64a7b694355..9e6c9693c0a 100644 --- a/tensorflow/python/keras/tests/integration_test.py +++ b/tensorflow/python/keras/tests/integration_test.py @@ -28,10 +28,10 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.layers.legacy_rnn import rnn_cell_impl as rnn_cell +from tensorflow.python.keras.legacy_tf_layers import base as base_layer from tensorflow.python.keras.utils import np_utils -from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import nn_ops as nn -from tensorflow.python.ops import rnn_cell from tensorflow.python.platform import test diff --git a/tensorflow/python/keras/tests/tracking_util_test.py b/tensorflow/python/keras/tests/tracking_util_test.py index 5f672faa751..ed0bb17adbd 100644 --- a/tensorflow/python/keras/tests/tracking_util_test.py +++ b/tensorflow/python/keras/tests/tracking_util_test.py @@ -90,7 +90,7 @@ class InterfaceTests(test.TestCase): def testSaveWithOnlyKerasSession(self): - with ops.Graph().as_default(): + with ops.Graph().as_default(), self.cached_session(): inp = input_layer.Input([1]) dense = core.Dense(1)(inp) model = training.Model(inp, dense) @@ -356,6 +356,7 @@ class CheckpointingTests(keras_parameterized.TestCase): with self.test_session(): num_training_steps = 10 checkpoint_directory = self.get_temp_dir() + optimizer = adam.Adam(0.001) def _train_fn(model, input_value): with backprop.GradientTape() as tape: loss = model(input_value) @@ -365,7 +366,6 @@ class CheckpointingTests(keras_parameterized.TestCase): for training_continuation in range(3): with testing_utils.device(should_use_gpu=True): model = MyModel() - optimizer = adam.Adam(0.001) root = trackable_utils.Checkpoint( optimizer=optimizer, model=model) manager = checkpoint_management.CheckpointManager( diff --git a/tensorflow/python/keras/tests/tracking_util_with_v1_optimizers_test.py b/tensorflow/python/keras/tests/tracking_util_with_v1_optimizers_test.py index 5b5283af1eb..18c91f3b622 100644 --- a/tensorflow/python/keras/tests/tracking_util_with_v1_optimizers_test.py +++ b/tensorflow/python/keras/tests/tracking_util_with_v1_optimizers_test.py @@ -283,7 +283,7 @@ class CheckpointingTests(keras_parameterized.TestCase): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - def _train_fn(optimizer, model): + def _train_fn(optimizer, model, root): input_value = constant_op.constant([[3.]]) optimizer.minimize( functools.partial(model, input_value), @@ -303,7 +303,7 @@ class CheckpointingTests(keras_parameterized.TestCase): for _ in range(num_training_steps): strategy.extended.call_for_each_replica( - functools.partial(_train_fn, optimizer, model)) + functools.partial(_train_fn, optimizer, model, root)) root.save(file_prefix=checkpoint_prefix) self.assertEqual((training_continuation + 1) * num_training_steps, root.optimizer_step.numpy()) @@ -314,7 +314,7 @@ class CheckpointingTests(keras_parameterized.TestCase): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = 
os.path.join(checkpoint_directory, "ckpt") - def _train_fn(optimizer, model): + def _train_fn(optimizer, model, root): input_value = constant_op.constant([[3.]]) return optimizer.minimize( functools.partial(model, input_value), @@ -332,7 +332,7 @@ class CheckpointingTests(keras_parameterized.TestCase): status = root.restore(checkpoint_management.latest_checkpoint( checkpoint_directory)) train_op = strategy.extended.call_for_each_replica( - functools.partial(_train_fn, optimizer, model)) + functools.partial(_train_fn, optimizer, model, root)) with self.session() as session: if training_continuation > 0: status.assert_consumed() diff --git a/tensorflow/python/keras/utils/BUILD b/tensorflow/python/keras/utils/BUILD index e7a712545f3..30b93fa95d8 100644 --- a/tensorflow/python/keras/utils/BUILD +++ b/tensorflow/python/keras/utils/BUILD @@ -13,8 +13,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_py_srcs", srcs = glob(["*.py"]), @@ -58,24 +56,37 @@ py_library( name = "data_utils", srcs = ["data_utils.py"], srcs_version = "PY2AND3", - deps = [":generic_utils"], + deps = [ + ":generic_utils", + ":io_utils", + ":tf_inspect", + ], ) py_library( name = "engine_utils", srcs = [ "conv_utils.py", - "io_utils.py", "losses_utils.py", ], srcs_version = "PY2AND3", deps = [ ":data_utils", + ":io_utils", "//tensorflow/python/keras:backend", "//tensorflow/python/ops/losses:loss_reduction", ], ) +py_library( + name = "io_utils", + srcs = ["io_utils.py"], + srcs_version = "PY2AND3", + deps = [ + "@six_archive//:six", + ], +) + py_library( name = "tf_utils", srcs = ["tf_utils.py"], @@ -101,6 +112,8 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":tf_contextlib", + ":tf_inspect", "//tensorflow/python:util", "//third_party/py/numpy", ], @@ -250,7 +263,6 @@ tf_py_test( size = "small", srcs = ["generic_utils_test.py"], python_version = "PY3", - tfrt_enabled = True, deps = [ ":generic_utils", "//tensorflow/python:client_testlib", @@ -293,7 +305,6 @@ tf_py_test( srcs = ["composite_tensor_support_test.py"], python_version = "PY3", shard_count = 8, - tags = ["no_windows"], # b/135752236 deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -303,6 +314,7 @@ tf_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/keras", "//tensorflow/python/keras:engine", "//tensorflow/python/keras/layers", "//tensorflow/python/ops/ragged:ragged_tensor", @@ -364,6 +376,7 @@ tf_py_test( python_version = "PY3", tfrt_enabled = True, deps = [ + ":layer_utils", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:layers", diff --git a/tensorflow/python/keras/utils/losses_utils.py b/tensorflow/python/keras/utils/losses_utils.py index 08ef613c3e2..442067d2ea5 100644 --- a/tensorflow/python/keras/utils/losses_utils.py +++ b/tensorflow/python/keras/utils/losses_utils.py @@ -31,8 +31,8 @@ from tensorflow.python.util.tf_export import keras_export # TODO(joshl/psv): Update references to ReductionV2 to point to its # new location. 
-ReductionV2 = keras_export( # pylint: disable=invalid-name - 'keras.losses.Reduction', v1=[])(loss_reduction.ReductionV2) +ReductionV2 = loss_reduction.ReductionV2 +keras_export('keras.losses.Reduction', v1=[])(loss_reduction.ReductionV2) def remove_squeezable_dimensions( diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py index a7334bc6132..3515dcc87a1 100644 --- a/tensorflow/python/keras/utils/tf_utils.py +++ b/tensorflow/python/keras/utils/tf_utils.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec from tensorflow.python.keras import backend as K +from tensorflow.python.keras.engine import keras_tensor from tensorflow.python.keras.utils import tf_contextlib from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables @@ -363,6 +364,9 @@ def register_symbolic_tensor_type(cls): cls: A `class` type which shall be regarded as a symbolic `Tensor`. """ global _user_convertible_tensor_types + if cls not in _user_convertible_tensor_types: + keras_tensor.register_keras_tensor_specialization( + cls, keras_tensor.UserRegisteredTypeKerasTensor) _user_convertible_tensor_types.add(cls) @@ -474,10 +478,11 @@ def get_tensor_spec(t, dynamic_batch=False, name=None): dynamic_batch_spec = copy.deepcopy(spec) # RaggedTensorSpec only has a private _shape. - shape = dynamic_batch_spec._shape.as_list() - if shape: - shape[0] = None - dynamic_batch_spec._shape = tensor_shape.TensorShape(shape) + shape = dynamic_batch_spec._shape + if shape.rank is not None and shape.rank > 0: + shape_list = shape.as_list() + shape_list[0] = None + dynamic_batch_spec._shape = tensor_shape.TensorShape(shape_list) return dynamic_batch_spec # pylint: enable=protected-access diff --git a/tensorflow/python/keras/wrappers/BUILD b/tensorflow/python/keras/wrappers/BUILD index 1122bcbdac9..9786cff1463 100644 --- a/tensorflow/python/keras/wrappers/BUILD +++ b/tensorflow/python/keras/wrappers/BUILD @@ -8,8 +8,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - filegroup( name = "all_py_srcs", srcs = glob(["*.py"]), @@ -38,10 +36,11 @@ tf_py_test( srcs = ["scikit_learn_test.py"], python_version = "PY3", tags = ["notsan"], - tfrt_enabled = True, deps = [ ":wrappers", "//tensorflow/python:client_testlib", + "//tensorflow/python/keras:testing_utils", + "//tensorflow/python/keras/utils:np_utils", "//third_party/py/numpy", ], ) diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 50530c8de50..b105d1f987a 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -121,6 +121,7 @@ cuda_py_test( "noasan", # TODO(b/155406705): flaky ], deps = [ + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -290,6 +291,7 @@ tf_py_test( deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:collective_ops_gen", + "//tensorflow/python:framework_combinations", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python/compat:v2_compat", ], @@ -777,7 +779,6 @@ tf_py_test( srcs = ["matrix_exponential_op_test.py"], shard_count = 16, tags = ["no_windows_gpu"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -816,6 +817,7 @@ cuda_py_test( name = 
"matrix_solve_ls_op_test", size = "medium", srcs = ["matrix_solve_ls_op_test.py"], + tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -898,7 +900,6 @@ tf_py_test( name = "parse_single_example_op_test", size = "small", srcs = ["parse_single_example_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -1190,7 +1191,6 @@ tf_py_test( name = "string_split_op_test", size = "small", srcs = ["string_split_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1357,6 +1357,7 @@ tf_py_test( name = "template_test", size = "small", srcs = ["template_test.py"], + tfrt_enabled = True, deps = [ "//tensorflow/python:client", "//tensorflow/python:client_testlib", @@ -1395,7 +1396,6 @@ cuda_py_test( tags = [ "no_oss", # TODO(b/142818120): Re-enable. ], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -1631,7 +1631,7 @@ tf_py_test( name = "reader_ops_test", size = "small", srcs = ["reader_ops_test.py"], - data = ["//tensorflow/core:lmdb_testdata"], + data = ["//tensorflow/core/lib/lmdb:lmdb_testdata"], tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", @@ -1736,7 +1736,6 @@ cuda_py_test( size = "medium", srcs = ["batch_matmul_op_test.py"], shard_count = 20, - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -1825,6 +1824,7 @@ cuda_py_test( name = "check_ops_test", size = "small", srcs = ["check_ops_test.py"], + tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", @@ -1935,6 +1935,7 @@ tf_py_test( name = "control_flow_util_v2_test", size = "small", srcs = ["control_flow_util_v2_test.py"], + tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:cond_v2", @@ -2307,7 +2308,6 @@ cuda_py_test( size = "medium", srcs = ["matmul_op_test.py"], shard_count = 20, - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2483,7 +2483,6 @@ cuda_py_test( tags = [ "no_windows_gpu", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2610,7 +2609,6 @@ cuda_py_test( name = "shape_ops_test", size = "medium", srcs = ["shape_ops_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -2628,7 +2626,6 @@ cuda_py_test( name = "softmax_op_test", size = "medium", srcs = ["softmax_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -2643,7 +2640,6 @@ cuda_py_test( name = "softplus_op_test", size = "small", srcs = ["softplus_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", @@ -3149,7 +3145,6 @@ cuda_py_test( size = "medium", srcs = ["pooling_ops_test.py"], shard_count = 4, - tfrt_enabled = True, xla_enable_strict_auto_jit = False, # Flaky in XLA b/149568654 deps = [ "//tensorflow/python:array_ops", @@ -3200,7 +3195,6 @@ cuda_py_test( srcs = ["rnn_cell_test.py"], shard_count = 15, tags = ["no_windows"], # b/139739217 - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3419,6 +3413,7 @@ cuda_py_test( size = "medium", srcs = ["conv_ops_3d_test.py"], shard_count = 
30, + tags = ["no_cuda11"], tfrt_enabled = True, deps = [ "//tensorflow/python:client_testlib", @@ -3434,7 +3429,6 @@ cuda_py_test( srcs = ["cwise_ops_test.py"], shard_count = 50, tags = ["no_windows"], # b/163222163 - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3455,7 +3449,6 @@ cuda_py_test( size = "medium", srcs = ["cwise_ops_binary_test.py"], shard_count = 50, - tfrt_enabled = True, # b/140155647: Error just outside of tolerance xla_enable_strict_auto_jit = False, deps = [ @@ -3561,7 +3554,6 @@ tf_py_test( tags = [ "no_windows", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3582,7 +3574,6 @@ cuda_py_test( tags = [ "no_windows", ], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3619,6 +3610,7 @@ cuda_py_test( "no_oss", # b/117185141. "nomsan", # TODO(b/117236102): Re-enable in msan build. ], + tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -3876,7 +3868,6 @@ cuda_py_test( size = "medium", srcs = ["cond_v2_test.py"], grpc_enabled = True, - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/kernel_tests/array_ops/BUILD b/tensorflow/python/kernel_tests/array_ops/BUILD index bd06627a707..df17e5a3a39 100644 --- a/tensorflow/python/kernel_tests/array_ops/BUILD +++ b/tensorflow/python/kernel_tests/array_ops/BUILD @@ -54,7 +54,6 @@ cuda_py_test( name = "gather_op_test", size = "medium", srcs = ["gather_op_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -76,6 +75,7 @@ cuda_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:state_ops", "//tensorflow/python:variables", diff --git a/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py index c89c70d2bbd..b3c566b882f 100644 --- a/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops/scatter_nd_ops_test.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker_v2 +from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables @@ -909,6 +910,32 @@ class ScatterNdTensorTest(test.TestCase): "hello", "there", "hello", "there", "there", "hello", "hello", "12" ])) + @test_util.run_in_graph_and_eager_modes + def testUpdateRepeatedIndices1D(self): + if test_util.is_gpu_available(): + self.skipTest("Duplicate indices scatter is non-deterministic on GPU") + a = array_ops.zeros([10, 1]) + b = array_ops.tensor_scatter_update(a, [[5], [5]], [[4], [8]]) + self.assertAllEqual( + b, + constant_op.constant([[0.], [0.], [0.], [0.], [0.], [8.], [0.], [0.], + [0.], [0.]])) + + @test_util.run_in_graph_and_eager_modes + def testUpdateRepeatedIndices2D(self): + if test_util.is_gpu_available(): + self.skipTest("Duplicate indices scatter is non-deterministic on GPU") + a = array_ops.zeros([10, 10]) + b = 
array_ops.tensor_scatter_update( + a, [[5], [6], [6]], + [math_ops.range(10), + math_ops.range(11, 21), + math_ops.range(10, 20)]) + self.assertAllEqual( + b[6], + constant_op.constant( + [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.])) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index cbf25df03b0..8eb5af399b4 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -1610,7 +1610,7 @@ class QuantizeAndDequantizeTest(test_util.TensorFlowTestCase): expected = self._scale_per_slice(shape, axis, quant_values) unused_minmax_value = 0 if axis is None else [0] * shape[axis] fake_quantized = self.evaluate( - array_ops.quantize_and_dequantize( + array_ops.quantize_and_dequantize_v2( inputs, unused_minmax_value, unused_minmax_value, @@ -1620,7 +1620,7 @@ class QuantizeAndDequantizeTest(test_util.TensorFlowTestCase): self.assertAllEqual(fake_quantized, expected) if axis is not None: fake_quantized = self.evaluate( - array_ops.quantize_and_dequantize( + array_ops.quantize_and_dequantize_v2( inputs, unused_minmax_value, unused_minmax_value, @@ -1628,6 +1628,23 @@ class QuantizeAndDequantizeTest(test_util.TensorFlowTestCase): axis=(axis - 4))) self.assertAllClose(fake_quantized, expected) + def testQuantizeDequantizeGrad(self): + shape = (2, 2) + max_threshold = 0 + min_threshold = -10 + input_value = np.random.rand(2, 2) * 40.0 - 20.0 + input_tensor = constant_op.constant(input_value, shape=shape, + name="input_tensor") + with self.cached_session(): + def f(a): + return array_ops.quantize_and_dequantize_v2( + a, + input_min=min_threshold, + input_max=max_threshold, + range_given=True) + output_grad = gradient_checker_v2.compute_gradient(f, [input_tensor]) + self.assertAllClose(output_grad[0], np.zeros([1, 4, 4])) + @test_util.run_all_in_graph_and_eager_modes class SortedSearchTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/kernel_tests/bitcast_op_test.py b/tensorflow/python/kernel_tests/bitcast_op_test.py index ed6d7799c7e..5551469aa73 100644 --- a/tensorflow/python/kernel_tests/bitcast_op_test.py +++ b/tensorflow/python/kernel_tests/bitcast_op_test.py @@ -82,6 +82,7 @@ class BitcastTest(test.TestCase): datatype = dtypes.int8 array_ops.bitcast(x, datatype, None) + @test_util.disable_tfrt("b/169901260") def testQuantizedType(self): shape = [3, 4] x = np.zeros(shape, np.uint16) diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index acc5af03097..9cb9fb490bc 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -1903,6 +1903,174 @@ class AssertShapesTest(test.TestCase): sess.run(out, feed_dict=feed_dict) +class AssertShapesSparseTensorTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes + def test_assert_shapes_sparse_tensor_scalar_target_success(self): + sparse_float = sparse_tensor.SparseTensor( + constant_op.constant([[]], dtypes.int64), + constant_op.constant([42], dtypes.float32), + constant_op.constant([], dtypes.int64)) + assertion = check_ops.assert_shapes([(sparse_float, [])]) + with ops.control_dependencies([assertion]): + out = array_ops.identity(sparse_float) + self.evaluate(out) + + def test_assert_shapes_sparse_tensor_nonscalar_target_fail(self): + sparse_float = sparse_tensor.SparseTensor( + constant_op.constant([[]], dtypes.int64), + 
constant_op.constant([42], dtypes.float32), + constant_op.constant([], dtypes.int64)) + with self.assertRaisesRegexp(ValueError, + r"must have rank 2.*Received rank 0"): + assertion = check_ops.assert_shapes([(sparse_float, [None, None])]) + with ops.control_dependencies([assertion]): + out = array_ops.identity(sparse_float) + self.evaluate(out) + + @test_util.run_in_graph_and_eager_modes + def test_assert_shapes_sparse_tensor_fully_specified_target_success(self): + sparse_float = sparse_tensor.SparseTensor( + constant_op.constant([[111], [232]], dtypes.int64), + constant_op.constant([23.4, -43.2], dtypes.float32), + constant_op.constant([500], dtypes.int64)) + assertion = check_ops.assert_shapes([(sparse_float, [500])]) + with ops.control_dependencies([assertion]): + out = array_ops.identity(sparse_float) + self.evaluate(out) + + @test_util.run_in_graph_and_eager_modes + def test_assert_shapes_sparse_tensor_fully_specified_target_fail(self): + sparse_float = sparse_tensor.SparseTensor( + constant_op.constant([[111], [232]], dtypes.int64), + constant_op.constant([23.4, -43.2], dtypes.float32), + constant_op.constant([500], dtypes.int64)) + with self.assertRaisesRegexp(ValueError, r"dimension 0 must have size 499"): + assertion = check_ops.assert_shapes([(sparse_float, [499])]) + with ops.control_dependencies([assertion]): + out = array_ops.identity(sparse_float) + self.evaluate(out) + + @test_util.run_in_graph_and_eager_modes + def test_assert_shapes_sparse_tensor_partially_specified_target_success(self): + sparse_int = sparse_tensor.SparseTensor( + constant_op.constant([[5, 6], [7, 8]], dtypes.int64), + constant_op.constant([23, -43], dtypes.int32), + constant_op.constant([30, 40], dtypes.int64)) + assertion = check_ops.assert_shapes([(sparse_int, [None, 40])]) + with ops.control_dependencies([assertion]): + out = array_ops.identity(sparse_int) + self.evaluate(out) + + @test_util.run_in_graph_and_eager_modes + def test_assert_shapes_sparse_tensor_symbolic_match_success(self): + sparse_int = sparse_tensor.SparseTensor( + constant_op.constant([[5, 6, 7], [8, 9, 10]], dtypes.int64), + constant_op.constant([23, -43], dtypes.int32), + constant_op.constant([30, 30, 40], dtypes.int64)) + assertion = check_ops.assert_shapes([(sparse_int, ["N", "N", "D"])]) + with ops.control_dependencies([assertion]): + out = array_ops.identity(sparse_int) + self.evaluate(out) + + @test_util.run_in_graph_and_eager_modes + def test_assert_shapes_sparse_tensor_partially_specified_target_fail(self): + sparse_int = sparse_tensor.SparseTensor( + constant_op.constant([[5, 6], [7, 8]], dtypes.int64), + constant_op.constant([23, -43], dtypes.int32), + constant_op.constant([30, 40], dtypes.int64)) + with self.assertRaisesRegexp(ValueError, r"dimension 1 must have size 41"): + assertion = check_ops.assert_shapes([(sparse_int, [None, 41])]) + with ops.control_dependencies([assertion]): + out = array_ops.identity(sparse_int) + self.evaluate(out) + + @test_util.run_in_graph_and_eager_modes + def test_assert_shapes_sparse_tensor_wrong_rank_fail(self): + sparse_int = sparse_tensor.SparseTensor( + constant_op.constant([[5, 6], [7, 8]], dtypes.int64), + constant_op.constant([23, -43], dtypes.int32), + constant_op.constant([30, 40], dtypes.int64)) + with self.assertRaisesRegexp(ValueError, + r"must have rank 3\..* Received rank 2"): + assertion = check_ops.assert_shapes([(sparse_int, [None, None, 40])]) + with ops.control_dependencies([assertion]): + out = array_ops.identity(sparse_int) + self.evaluate(out) + + 
@test_util.run_in_graph_and_eager_modes + def test_assert_shapes_sparse_tensor_wrong_symbolic_match_fail(self): + sparse_int = sparse_tensor.SparseTensor( + constant_op.constant([[5, 6], [7, 8]], dtypes.int64), + constant_op.constant([23, -43], dtypes.int32), + constant_op.constant([30, 40], dtypes.int64)) + with self.assertRaisesRegexp(ValueError, r"dimension 1 must have size 30"): + assertion = check_ops.assert_shapes([(sparse_int, ["D", "D"])]) + with ops.control_dependencies([assertion]): + out = array_ops.identity(sparse_int) + self.evaluate(out) + + @test_util.run_in_graph_and_eager_modes + def test_assert_shapes_sparse_tensor_multiple_assertions_success(self): + sparse_scalar = sparse_tensor.SparseTensor( + constant_op.constant([[]], dtypes.int64), + constant_op.constant([42], dtypes.float32), + constant_op.constant([], dtypes.int64)) + sparse_2d = sparse_tensor.SparseTensor( + constant_op.constant([[5, 6], [7, 8]], dtypes.int64), + constant_op.constant([23, -43], dtypes.int32), + constant_op.constant([30, 30], dtypes.int64)) + assertion = check_ops.assert_shapes([(sparse_scalar, []), + (sparse_2d, ["N", "N"])]) + with ops.control_dependencies([assertion]): + out = array_ops.identity(sparse_2d) + self.evaluate(out) + + @test_util.run_in_graph_and_eager_modes + def test_assert_shapes_sparse_tensor_multiple_assertions_fail(self): + sparse_scalar = sparse_tensor.SparseTensor( + constant_op.constant([[]], dtypes.int64), + constant_op.constant([42], dtypes.float32), + constant_op.constant([], dtypes.int64)) + sparse_2d = sparse_tensor.SparseTensor( + constant_op.constant([[5, 6], [7, 8]], dtypes.int64), + constant_op.constant([23, -43], dtypes.int32), + constant_op.constant([30, 40], dtypes.int64)) + with self.assertRaisesRegexp(ValueError, r"dimension 1 must have size 30"): + assertion = check_ops.assert_shapes([(sparse_scalar, []), + (sparse_2d, ["N", "N"])]) + with ops.control_dependencies([assertion]): + out = array_ops.identity(sparse_2d) + self.evaluate(out) + + @test_util.run_in_graph_and_eager_modes + def test_assert_shapes_sparse_tensor_mixed_dense_and_sparse_success(self): + dense_scalar = constant_op.constant([42], dtypes.float32) + sparse_2d = sparse_tensor.SparseTensor( + constant_op.constant([[5, 6], [7, 8]], dtypes.int64), + constant_op.constant([23, -43], dtypes.int32), + constant_op.constant([30, 30], dtypes.int64)) + assertion = check_ops.assert_shapes([(dense_scalar, []), + (sparse_2d, ["N", "N"])]) + with ops.control_dependencies([assertion]): + out = array_ops.identity(sparse_2d) + self.evaluate(out) + + @test_util.run_in_graph_and_eager_modes + def test_assert_shapes_sparse_tensor_mixed_dense_and_sparse_fail(self): + dense_scalar = constant_op.constant([42], dtypes.float32) + sparse_2d = sparse_tensor.SparseTensor( + constant_op.constant([[5, 6], [7, 8]], dtypes.int64), + constant_op.constant([23, -43], dtypes.int32), + constant_op.constant([30, 40], dtypes.int64)) + with self.assertRaisesRegexp(ValueError, r"dimension 1 must have size 30"): + assertion = check_ops.assert_shapes([(dense_scalar, []), + (sparse_2d, ["N", "N"])]) + with ops.control_dependencies([assertion]): + out = array_ops.identity(sparse_2d) + self.evaluate(out) + + class IsStrictlyIncreasingTest(test.TestCase): @test_util.run_in_graph_and_eager_modes diff --git a/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py b/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py index 4385a20cd20..95143cc233a 100644 --- 
a/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py +++ b/tensorflow/python/kernel_tests/collective_ops_multi_worker_test.py @@ -64,7 +64,7 @@ class CollectiveOpTest(test.TestCase): except errors.UnavailableError: continue break - multi_process_runner.barrier().wait() + multi_process_runner.get_barrier().wait() cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2) mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec) diff --git a/tensorflow/python/kernel_tests/collective_ops_test.py b/tensorflow/python/kernel_tests/collective_ops_test.py index 583f1148bcf..0e3e16179a6 100644 --- a/tensorflow/python/kernel_tests/collective_ops_test.py +++ b/tensorflow/python/kernel_tests/collective_ops_test.py @@ -345,6 +345,139 @@ class AbortCollectiveOpsTest(test.TestCase, parameterized.TestCase): def_function.function(collective_fn)() +@combinations.generate( + combinations.combine( + collective_op=[ + combinations.NamedObject('all_reduce', _collective_ops.all_reduce), + combinations.NamedObject('all_reduce_v2', + _collective_ops.all_reduce_v2), + combinations.NamedObject('all_gather', _collective_ops.all_gather), + combinations.NamedObject('all_gather_v2', + _collective_ops.all_gather_v2), + ], + mode='eager', + communication=['ring'])) +class TimeoutTest(test.TestCase, parameterized.TestCase): + + def setUp(self): + _setup_context() + super().setUp() + + def testTimeout(self, collective_op, communication): + timeout = 4.5 + + @def_function.function + def run(group_size, reported_group_size=None): + group_key = 20 + instance_key = 30 + tensor = [1, 2, 3, 4] + results = [] + if reported_group_size is None: + reported_group_size = group_size + for i in range(group_size): + with ops.device('/CPU:{}'.format(i)): + input_data = constant_op.constant(tensor) + result = collective_op( + input_data, + group_size=reported_group_size, + group_key=group_key, + instance_key=instance_key, + communication_hint=communication, + timeout=timeout) + results.append(result) + return results + + run(2, 2) + + start_time = time.time() + with self.assertRaisesRegex(errors.DeadlineExceededError, + 'Collective has timed out during execution'): + run(1, 2) + elapsed = time.time() - start_time + self.assertAllGreaterEqual(elapsed, timeout) + + def testParamResolutionAfterTimeoutV2(self, collective_op, communication): + timeout = 1.5 + + group_key = 20 + instance_key = 30 + input_data = constant_op.constant([1, 2, 3, 4]) + + # This timeout comes from param resolution. + with self.assertRaisesRegex( + errors.DeadlineExceededError, + 'Collective has timed out waiting for other workers'): + with ops.device('CPU:0'): + collective_op( + input_data, + group_size=2, + group_key=group_key, + instance_key=instance_key, + communication_hint=communication, + timeout=timeout) + + # We launch the second device after the first device times out. This is to + # simulate the situation when other workers are slow and the timeout is + # short. Since CPU:0 times out in the param resolution phase, CPU:1 + # should time out as well, but in the execution phase.
+ with self.assertRaisesRegex(errors.DeadlineExceededError, + 'Collective has timed out during execution'): + with ops.device('CPU:1'): + collective_op( + input_data, + group_size=2, + group_key=group_key, + instance_key=instance_key, + communication_hint=communication, + timeout=timeout) + + def testExecutionAfterTimeoutV2(self, collective_op, communication): + timeout = 1.5 + group_key = 20 + instance_key = 30 + input_data = constant_op.constant([1, 2, 3, 4]) + + @def_function.function + def run(): + for device in ['CPU:0', 'CPU:1']: + with ops.device(device): + collective_op( + input_data, + group_size=2, + group_key=group_key, + instance_key=instance_key, + communication_hint=communication, + timeout=timeout) + + # Run a normal all-reduce to complete param resolution. + run() + + with self.assertRaisesRegex(errors.DeadlineExceededError, + 'Collective has timed out during execution'): + with ops.device('CPU:0'): + collective_op( + input_data, + group_size=2, + group_key=group_key, + instance_key=instance_key, + communication_hint=communication, + timeout=timeout) + + # We launch the second device after the first device times out. This is to + # simulate the situation when other workers are slow and the timeout is + # short. It should error immediately. + with self.assertRaisesRegex(errors.DeadlineExceededError, + 'Collective has timed out during execution'): + with ops.device('CPU:1'): + # No timeout. + collective_op( + input_data, + group_size=2, + group_key=group_key, + instance_key=instance_key, + communication_hint=communication) + + def _setup_context(): context._reset_context() cpus = config.list_physical_devices('CPU') diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py index 30b20c67fda..70d7b2530a9 100644 --- a/tensorflow/python/kernel_tests/cond_v2_test.py +++ b/tensorflow/python/kernel_tests/cond_v2_test.py @@ -940,7 +940,6 @@ class CondV2Test(test.TestCase): self.assertEqual(fn_output[0].op.type, "StatefulPartitionedCall") self.assertAllEqual(self.evaluate(fn_output), [2.0, 4.0]) - @test_util.disable_tfrt("GPU to host copy not implemented yet.") def testGradientTapeOfCondWithResourceVariableInFunction(self): with context.eager_mode(): v = variables.Variable(2.) diff --git a/tensorflow/python/kernel_tests/constant_op_eager_test.py b/tensorflow/python/kernel_tests/constant_op_eager_test.py index 06a0d02e48b..589540fd886 100644 --- a/tensorflow/python/kernel_tests/constant_op_eager_test.py +++ b/tensorflow/python/kernel_tests/constant_op_eager_test.py @@ -581,13 +581,11 @@ class FillTest(test.TestCase): tf_ans = array_ops.fill([2, 3], np_ans[0][0], name="fill").numpy() self.assertAllEqual(np_ans, tf_ans) - @test_util.disable_tfrt("support error code in TFRT.") def testFillNegative(self): for shape in (-1,), (2, -1), (-1, 2), (-2), (-3): with self.assertRaises(errors_impl.InvalidArgumentError): array_ops.fill(shape, 7) - @test_util.disable_tfrt("support error code in TFRT.") def testShapeFunctionEdgeCases(self): # Non-vector dimensions. 
with self.assertRaises(errors_impl.InvalidArgumentError): diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py index e35b62a4556..e965c52ee29 100644 --- a/tensorflow/python/kernel_tests/constant_op_test.py +++ b/tensorflow/python/kernel_tests/constant_op_test.py @@ -456,6 +456,7 @@ class ZerosTest(test.TestCase): self.assertFalse(np.any(z_value)) self.assertEqual((2, 3), z_value.shape) + @test_util.disable_tfrt("b/169901260") def testQint8Dtype(self): dtype = dtypes_lib.qint8 z = array_ops.zeros([2, 3], dtype=dtype) @@ -466,6 +467,7 @@ class ZerosTest(test.TestCase): z_value = self.evaluate(math_ops.cast(z, dtypes_lib.int32)) self.assertFalse(np.any(z_value)) + @test_util.disable_tfrt("b/169901260") def testQint16Dtype(self): dtype = dtypes_lib.qint16 z = array_ops.zeros([2, 3], dtype=dtype) @@ -650,6 +652,7 @@ class OnesTest(test.TestCase): self.assertEqual([2, 3], z.get_shape()) self.assertAllEqual(z, np.ones([2, 3])) + @test_util.disable_tfrt("b/169901260") def testQintDtype(self): @def_function.function(autograph=False) diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 68ddbe12f16..54bbd2b2e9e 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -34,6 +34,8 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python import tf2 from tensorflow.python.client import device_lib from tensorflow.python.client import session +from tensorflow.python.data.experimental.ops import cardinality +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import function as eager_function @@ -720,8 +722,9 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): # We expect that everything runs on CPU, even if GPU is available. self.assertEqual(len(run_metadata.partition_graphs), 1) - def _count_matching_switch_nodes_on_device(self, run_metadata, device_str): - # Returns the number of Switch nodes with type float32 placed on + def _count_matching_switch_nodes_on_device(self, run_metadata, device_str, + dtype): + # Returns the number of Switch nodes with type dtype placed on # `device_str`. device_graphs = [ g for g in run_metadata.partition_graphs @@ -729,14 +732,14 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): ] self.assertLen(device_graphs, 1) switch_nodes = [ - n for n in device_graphs[0].node if n.op == "Switch" and - n.attr["T"].type == dtypes.float32.as_datatype_enum + n for n in device_graphs[0].node + if n.op == "Switch" and n.attr["T"].type == dtype.as_datatype_enum ] return len(switch_nodes) @test_util.run_gpu_only @test_util.run_deprecated_v1 - def testCondSwitchColocatedWithInputWhenInputOnCPU(self): + def testCondSwitchColocatedWithInputWhenInputExplicitlyPlacedOnCPU(self): x = array_ops.placeholder(dtypes.float32) # `arg` is used in the cond then branch so a Switch node is created for it. @@ -756,12 +759,46 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): options = config_pb2.RunOptions(output_partition_graphs=True) sess.run( r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata) - self.assertEqual(len(run_metadata.partition_graphs), 2) + self.assertLen(run_metadata.partition_graphs, 2) # Check that the Switch for `arg` gets placed on CPU. 
self.assertEqual( - self._count_matching_switch_nodes_on_device(run_metadata, "CPU"), 1) + self._count_matching_switch_nodes_on_device(run_metadata, "CPU", + dtypes.float32), 1) self.assertEqual( - self._count_matching_switch_nodes_on_device(run_metadata, "GPU"), 0) + self._count_matching_switch_nodes_on_device(run_metadata, "GPU", + dtypes.float32), 0) + + @test_util.run_gpu_only + @test_util.run_deprecated_v1 + def testCondSwitchColocatedWithInputWhenInputPlacedOnCPU(self): + x = array_ops.placeholder(dtypes.float32) + + # `arg` is used in the cond then branch so a Switch node is created for it. + # We test that the Switch node gets placed on the same device as `arg`. + # Since arg is a dataset (and only has a CPU kernel), it gets placed on CPU + # by placer. + arg = dataset_ops.Dataset.range(8) + + def true_fn(): + return cardinality.cardinality(arg) + + r = control_flow_ops.cond( + constant_op.constant(True), true_fn, + lambda: constant_op.constant(0, dtypes.int64)) + + with session.Session() as sess: + run_metadata = config_pb2.RunMetadata() + options = config_pb2.RunOptions(output_partition_graphs=True) + sess.run( + r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata) + self.assertLen(run_metadata.partition_graphs, 2) + # Check that the Switch for `arg` gets placed on CPU. + self.assertEqual( + self._count_matching_switch_nodes_on_device(run_metadata, "CPU", + dtypes.variant), 1) + self.assertEqual( + self._count_matching_switch_nodes_on_device(run_metadata, "GPU", + dtypes.variant), 0) @test_util.run_gpu_only @test_util.run_deprecated_v1 @@ -787,9 +824,11 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): self.assertEqual(len(run_metadata.partition_graphs), 2) # Check that the Switch for `arg` gets placed on GPU. self.assertEqual( - self._count_matching_switch_nodes_on_device(run_metadata, "CPU"), 0) + self._count_matching_switch_nodes_on_device(run_metadata, "CPU", + dtypes.float32), 0) self.assertEqual( - self._count_matching_switch_nodes_on_device(run_metadata, "GPU"), 1) + self._count_matching_switch_nodes_on_device(run_metadata, "GPU", + dtypes.float32), 1) def testCondAccessTrueBranchTensorInFalseBranchRaises(self): diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index f7234398cdc..dd033121329 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -39,6 +39,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_impl from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops @@ -3266,6 +3267,173 @@ def GetInceptionBackFilterTest(input_size, filter_size, output_size, strides, return Test +class FusedConv2DTest(test.TestCase): + + def _CreateNumpyTensor(self, shape): + total_size = np.prod(shape) + return np.arange(1, total_size + 1, dtype=np.float32).reshape(shape) + + def _CreateConv2D(self, + input_values, + filters, + strides=[1, 1], + padding="SAME"): + return nn_ops.convolution( + input_values, filters, strides=strides, padding=padding) + + # Tests tensor forwarding of a fused Conv2D+BiasAdd+Add op when the input to + # Add has refcount 1. 
+ @test_util.run_in_graph_and_eager_modes(use_gpu=False) + def testAddWithRefCountOne(self): + expected_output = [ + 113377, 125570, 77305, 86738, 19433, 22226, 60681, 70722, 36291, 43718, + 7143, 9206, 9785, 12098, 4783, 6366, 779, 1134 + ] + tensor_in_sizes = [1, 3, 3, 2] + filter_in_sizes = [2, 2, 2, 2] + bias_in_sizes = [2] + + x = self._CreateNumpyTensor(tensor_in_sizes) + filter_in = self._CreateNumpyTensor(filter_in_sizes) + bias_in = self._CreateNumpyTensor(bias_in_sizes) + # To get different weights for filter + offset = 1 + + conv1 = self._CreateConv2D(x, filter_in) + conv2 = self._CreateConv2D(conv1, filter_in + offset) + + conv = self._CreateConv2D(conv1, filter_in - offset) + bias_add = nn_ops.bias_add(conv, bias_in) + add = math_ops.add_n([bias_add, conv2]) + + self.assertAllEqual( + np.rint(expected_output), + self.evaluate(add).reshape(-1)) + + # Tests tensor forwarding of a fused Conv2D+BiasAdd+Add op when the input to + # Add has a total refcount of 2, and Add is its last consumer. + @test_util.run_in_graph_and_eager_modes(use_gpu=False) + def testAddWithRefCountTwoAndRunAddLast(self): + expected_output = [ + 1.907175e+06, 2.253505e+06, 7.809210e+05, 9.537180e+05, 1.184170e+05, + 1.523070e+05, 5.367010e+05, 6.803700e+05, 1.867090e+05, 2.529460e+05, + 2.362300e+04, 3.522600e+04, 5.121700e+04, 7.168300e+04, 1.494300e+04, + 2.347400e+04, 1.558000e+03, 2.903000e+03 + ] + tensor_in_sizes = [1, 3, 3, 2] + filter_in_sizes = [2, 2, 2, 2] + bias_in_sizes = [2] + + x = self._CreateNumpyTensor(tensor_in_sizes) + filter_in = self._CreateNumpyTensor(filter_in_sizes) + bias_in = self._CreateNumpyTensor(bias_in_sizes) + # To get different weights for filter + offset = 1 + + conv1 = self._CreateConv2D(x, filter_in) + conv2 = self._CreateConv2D(conv1, filter_in + offset) + + conv = self._CreateConv2D(conv2, filter_in - offset) + bias_add = nn_ops.bias_add(conv, bias_in) + add = math_ops.add_n([bias_add, conv1]) + + self.assertAllEqual( + np.rint(expected_output), + self.evaluate(add).reshape(-1)) + + # Tests tensor forwarding of a fused Conv2D+BiasAdd+Add op when the input to + # Add has refcount 2 and Add (in the fused Conv2D op) is its first consumer. + @test_util.run_in_graph_and_eager_modes(use_gpu=False) + def testAddWithRefCountTwoAndRunAddFirst(self): + expected_output = [ + 176161, 194450, 120673, 134822, 30545, 34734, 96041, 111102, 58149, + 69289, 11745, 14839, 15833, 19302, 7965, 10339, 1345, 1877 + ] + tensor_in_sizes = [1, 3, 3, 2] + filter_in_sizes = [2, 2, 2, 2] + bias_in_sizes = [2] + + x = self._CreateNumpyTensor(tensor_in_sizes) + filter_in = self._CreateNumpyTensor(filter_in_sizes) + bias_in = self._CreateNumpyTensor(bias_in_sizes) + # To get different weights for filter + offset = 1 + + conv1 = self._CreateConv2D(x, filter_in) + conv2 = self._CreateConv2D(conv1, filter_in + offset) + + conv = self._CreateConv2D(conv1, filter_in - offset) + bias_add = nn_ops.bias_add(conv, bias_in) + add = math_ops.add_n([bias_add, conv2]) + + relu = nn_ops.relu(add) + output = math_ops.add_n([relu, conv2]) + + self.assertAllEqual( + np.rint(expected_output), + self.evaluate(output).reshape(-1)) + + # Tests tensor forwarding of a fused Conv2D+BiasAdd+Add op when the input to + # Add has refcount 2, and there is no dependency between its two consumers. 
+ @test_util.run_in_graph_and_eager_modes(use_gpu=False) + def testAddWithRefCountTwoAndNoDependence(self): + expected_output = [ + 176161, 194450, 120673, 134822, 30545, 34734, 96041, 111102, 58149, + 69289, 11745, 14839, 15833, 19302, 7965, 10339, 1345, 1877 + ] + tensor_in_sizes = [1, 3, 3, 2] + filter_in_sizes = [2, 2, 2, 2] + bias_in_sizes = [2] + + x = self._CreateNumpyTensor(tensor_in_sizes) + filter_in = self._CreateNumpyTensor(filter_in_sizes) + bias_in = self._CreateNumpyTensor(bias_in_sizes) + # To get different weights for filter + offset = 1 + + conv1 = self._CreateConv2D(x, filter_in) + conv2 = self._CreateConv2D(conv1, filter_in + offset) + + conv = self._CreateConv2D(conv1, filter_in - offset) + bias_add = nn_ops.bias_add(conv, bias_in) + add = math_ops.add_n([bias_add, conv2]) + + relu1 = nn_ops.relu(add) + relu2 = nn_ops.relu(conv2) + output = math_ops.add_n([relu1, relu2]) + + self.assertAllEqual( + np.rint(expected_output), + self.evaluate(output).reshape(-1)) + + # Tests tensor forwarding of a fused Conv2D+BiasAdd+Add op when the input to + # Add is the same as the input to the fused Conv2D op and needs a tensor + # buffer. + @test_util.run_in_graph_and_eager_modes(use_gpu=False) + def testAddWithSameSrcAndAddTensorBuffer(self): + expected_output = [ + 57157, 63298, 39249, 44026, 9971, 11402, 31193, 36306, 19126, 22948, + 3970, 5060, 5135, 6350, 2666, 3524, 461, 674 + ] + tensor_in_sizes = [1, 3, 3, 2] + filter_in_sizes = [2, 2, 2, 2] + bias_in_sizes = [2] + + x = self._CreateNumpyTensor(tensor_in_sizes) + filter_in = self._CreateNumpyTensor(filter_in_sizes) + bias_in = self._CreateNumpyTensor(bias_in_sizes) + + conv1 = self._CreateConv2D(x, filter_in) + + conv = self._CreateConv2D(conv1, filter_in) + bias_add = nn_ops.bias_add(conv, bias_in) + add = math_ops.add_n([bias_add, conv1]) + + self.assertAllEqual( + np.rint(expected_output), + self.evaluate(add).reshape(-1)) + + if __name__ == "__main__": for index, (input_size_, filter_size_, output_size_, stride_, padding_) in enumerate(GetShrunkInceptionShapes()): diff --git a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py index 98832dd9885..1f8f6ac6153 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py @@ -991,6 +991,7 @@ class ComparisonOpTest(test.TestCase): [[True, True, True, True, True], [False, False, False, False, False]], values) + @test_util.disable_tfrt("b/169901260") def testEqualQuantizeDType(self): dtypes = [ dtypes_lib.qint8, diff --git a/tensorflow/python/kernel_tests/distributions/BUILD b/tensorflow/python/kernel_tests/distributions/BUILD index ff530da88a6..4b0915b37e7 100644 --- a/tensorflow/python/kernel_tests/distributions/BUILD +++ b/tensorflow/python/kernel_tests/distributions/BUILD @@ -95,7 +95,6 @@ cuda_py_test( name = "categorical_test", size = "small", srcs = ["categorical_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -193,7 +192,6 @@ cuda_py_test( cuda_py_test( name = "multinomial_test", srcs = ["multinomial_test.py"], - tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py index 093fdb69dc3..b6ab771fe08 100644 --- a/tensorflow/python/kernel_tests/distributions/util_test.py +++ 
b/tensorflow/python/kernel_tests/distributions/util_test.py @@ -300,6 +300,7 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase): param) checked_param.eval(feed_dict={param: np.ones([int(2**11+1)])}) + @test_util.disable_tfrt("b/169901260") @test_util.run_in_graph_and_eager_modes def testUnsupportedDtype(self): param = ops.convert_to_tensor( diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py index 50d11a62793..5c4df4c6ac7 100644 --- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py +++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py @@ -62,6 +62,7 @@ class DynamicStitchTestBase(object): # length. self.assertEqual([None], stitched_t.get_shape().as_list()) + @test_util.disable_tfrt("b/169901260") def testSimpleOneDimensional(self): # Test various datatypes in the simple case to ensure that the op was # registered under those types. diff --git a/tensorflow/python/kernel_tests/fingerprint_op_test.py b/tensorflow/python/kernel_tests/fingerprint_op_test.py index 0af3f5182fa..0a5643f9d69 100644 --- a/tensorflow/python/kernel_tests/fingerprint_op_test.py +++ b/tensorflow/python/kernel_tests/fingerprint_op_test.py @@ -37,6 +37,11 @@ class FingerprintTest(test.TestCase): self.assertTupleEqual(fingerprint0.shape, fingerprint1.shape) self.assertTrue(np.any(fingerprint0 != fingerprint1)) + def test_empty(self): + f0 = self.evaluate(array_ops.fingerprint([])) + self.assertEqual(f0.ndim, 2) + self.assertEqual(f0.shape, (0, 8)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index f792cda6ea1..f2f6dc33b84 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -22,6 +22,7 @@ from __future__ import print_function from absl.testing import parameterized import numpy as np # pylint: disable=unused-import +from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import context @@ -40,6 +41,7 @@ from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import list_ops from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variable_scope as vs @@ -1600,9 +1602,18 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): def func(): t = constant_op.constant([1., 2., 3.]) l = list_ops.tensor_list_from_tensor(t, element_shape=[]) + handle_data = resource_variable_ops.get_eager_safe_handle_data(l) + self.assertTrue(handle_data.is_set) + self.assertEqual(types_pb2.ST_TENSOR_LIST, + handle_data.shape_and_type[0].specialized_type) return l tensor_list = func() + handle_data = resource_variable_ops.get_eager_safe_handle_data(tensor_list) + self.assertTrue(handle_data.is_set) + self.assertEqual(dtypes.float32, handle_data.shape_and_type[0].dtype) + self.assertEqual(types_pb2.ST_TENSOR_LIST, + handle_data.shape_and_type[0].specialized_type) element = list_ops.tensor_list_get_item( tensor_list, 0, element_dtype=dtypes.float32) self.assertAllEqual(element.shape.as_list(), []) diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py index 
e4810beec31..d564da12c27 100644 --- a/tensorflow/python/kernel_tests/lookup_ops_test.py +++ b/tensorflow/python/kernel_tests/lookup_ops_test.py @@ -195,6 +195,20 @@ class StaticHashTableTest(BaseLookupTableTest): result = self.evaluate(output) self.assertAllEqual([0, 1, -1], result) + def testStaticHashTableGetItem(self): + default_val = constant_op.constant(-1, dtypes.int64) + keys = constant_op.constant(["brain", "salad", "surgery"]) + values = constant_op.constant([0, 1, 2], dtypes.int64) + table = self.getHashTable()(lookup_ops.KeyValueTensorInitializer( + keys, values), default_val) + self.initialize_table(table) + + input_string = constant_op.constant(["brain", "salad", "tank"]) + output = table[input_string] + + result = self.evaluate(output) + self.assertAllEqual([0, 1, -1], result) + def testStaticHashTableWithSparseTensorInput(self): default_val = constant_op.constant(-1, dtypes.int64) keys = constant_op.constant(["brain", "salad", "surgery"]) @@ -473,6 +487,24 @@ class StaticHashTableTest(BaseLookupTableTest): self.evaluate(lookup_ops.tables_initializer()) self.assertAllEqual(grad, -10.) + def testExportShapeInference(self): + table = self.getHashTable()(lookup_ops.KeyValueTensorInitializer( + constant_op.constant([2, 5], dtype=dtypes.int64), + constant_op.constant([-10.0, 1], dtype=dtypes.float32)), -1) + actual_shapes = [t.shape for t in table.export()] + inferred_shapes = [] + + @def_function.function + def f(): + for t in table.export(): + inferred_shapes.append(t.shape) + + f() + self.assertLen(actual_shapes, 2) + self.assertLen(inferred_shapes, 2) + self.assertTrue(inferred_shapes[0].is_compatible_with(actual_shapes[0])) + self.assertTrue(inferred_shapes[1].is_compatible_with(actual_shapes[1])) + class KeyValueTensorInitializerTest(BaseLookupTableTest): @@ -954,6 +986,21 @@ class StaticVocabularyTableTest(BaseLookupTableTest): self.assertAllEqual([0, 1, 2, 3], self.evaluate(out)) self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size())) + def testStaticVocabularyTableGetItem(self): + vocab_file = self._createVocabFile("feat_to_id_1.txt") + vocab_size = 3 + oov_buckets = 1 + table = self.getVocabularyTable()(lookup_ops.TextFileIdTableInitializer( + vocab_file, vocab_size=vocab_size), oov_buckets) + + self.initialize_table(table) + + input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"]) + + out = table[input_string] + self.assertAllEqual([0, 1, 2, 3], self.evaluate(out)) + self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size())) + def testInt32StaticVocabularyTable(self): vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000")) vocab_size = 3 @@ -1226,71 +1273,85 @@ class StaticVocabularyTableTest(BaseLookupTableTest): class DenseHashTableOpTest(test.TestCase): def testBasic(self): - with self.cached_session(): + keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) + values = constant_op.constant([0, 1, 2, 3], dtypes.int64) + table = lookup_ops.DenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=0, + deleted_key=-1) + self.assertAllEqual(0, self.evaluate(table.size())) - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([0, 1, 2, 3], dtypes.int64) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(4, self.evaluate(table.size())) + + remove_string = constant_op.constant([12, 15], dtypes.int64) + self.evaluate(table.remove(remove_string)) + self.assertAllEqual(3, self.evaluate(table.size())) + + input_string = 
constant_op.constant([11, 12, 15], dtypes.int64) + output = table.lookup(input_string) + self.assertAllEqual([3], output.get_shape()) + + result = self.evaluate(output) + self.assertAllEqual([0, -1, -1], result) + + def testGetItem(self): + keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) + values = constant_op.constant([0, 1, 2, 3], dtypes.int64) + table = lookup_ops.DenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=0, + deleted_key=-1) + + self.evaluate(table.insert(keys, values)) + + input_string = constant_op.constant([11, 12, 15], dtypes.int64) + output = table[input_string] + self.assertAllEqual([3], output.get_shape()) + + result = self.evaluate(output) + self.assertAllEqual([0, 1, -1], result) + + def testBasicBool(self): + keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) + values = constant_op.constant([True, True, True, True], dtypes.bool) + table = lookup_ops.DenseHashTable( + dtypes.int64, + dtypes.bool, + default_value=False, + empty_key=0, + deleted_key=-1) + self.assertAllEqual(0, self.evaluate(table.size())) + + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(4, self.evaluate(table.size())) + + remove_string = constant_op.constant([11, 15], dtypes.int64) + self.evaluate(table.remove(remove_string)) + self.assertAllEqual(3, self.evaluate(table.size())) + + input_string = constant_op.constant([11, 12, 15], dtypes.int64) + output = table.lookup(input_string) + self.assertAllEqual([3], output.get_shape()) + + result = self.evaluate(output) + self.assertAllEqual([False, True, False], result) + + def testSameEmptyAndDeletedKey(self): + with self.assertRaisesRegex(errors_impl.InvalidArgumentError, + "Empty and deleted keys"): table = lookup_ops.DenseHashTable( dtypes.int64, dtypes.int64, default_value=-1, - empty_key=0, - deleted_key=-1) + empty_key=42, + deleted_key=42) self.assertAllEqual(0, self.evaluate(table.size())) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(4, self.evaluate(table.size())) - - remove_string = constant_op.constant([12, 15], dtypes.int64) - self.evaluate(table.remove(remove_string)) - self.assertAllEqual(3, self.evaluate(table.size())) - - input_string = constant_op.constant([11, 12, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) - - result = self.evaluate(output) - self.assertAllEqual([0, -1, -1], result) - - def testBasicBool(self): - with self.cached_session(): - - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([True, True, True, True], dtypes.bool) - table = lookup_ops.DenseHashTable( - dtypes.int64, - dtypes.bool, - default_value=False, - empty_key=0, - deleted_key=-1) - self.assertAllEqual(0, self.evaluate(table.size())) - - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(4, self.evaluate(table.size())) - - remove_string = constant_op.constant([11, 15], dtypes.int64) - self.evaluate(table.remove(remove_string)) - self.assertAllEqual(3, self.evaluate(table.size())) - - input_string = constant_op.constant([11, 12, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) - - result = self.evaluate(output) - self.assertAllEqual([False, True, False], result) - - def testSameEmptyAndDeletedKey(self): - with self.cached_session(): - with self.assertRaisesRegex(errors_impl.InvalidArgumentError, - "Empty and deleted keys"): - table = lookup_ops.DenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, 
- empty_key=42, - deleted_key=42) - self.assertAllEqual(0, self.evaluate(table.size())) - @test_util.run_v1_only("uses placeholders") def testLookupUnknownShape(self): with self.cached_session(): @@ -1313,212 +1374,203 @@ class DenseHashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1], result) def testMapStringToFloat(self): - with self.cached_session(): + keys = constant_op.constant(["a", "b", "c", "d"], dtypes.string) + values = constant_op.constant([0.0, 1.1, 2.2, 3.3], dtypes.float32) + default_value = constant_op.constant(-1.5, dtypes.float32) + table = lookup_ops.DenseHashTable( + dtypes.string, + dtypes.float32, + default_value=default_value, + empty_key="", + deleted_key="$") + self.assertAllEqual(0, self.evaluate(table.size())) - keys = constant_op.constant(["a", "b", "c", "d"], dtypes.string) - values = constant_op.constant([0.0, 1.1, 2.2, 3.3], dtypes.float32) - default_value = constant_op.constant(-1.5, dtypes.float32) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(4, self.evaluate(table.size())) + + remove_string = constant_op.constant(["b", "e"]) + self.evaluate(table.remove(remove_string)) + self.assertAllEqual(3, self.evaluate(table.size())) + + input_string = constant_op.constant(["a", "b", "d", "e"], dtypes.string) + output = table.lookup(input_string) + self.assertAllEqual([4], output.get_shape()) + + result = self.evaluate(output) + self.assertAllClose([0, -1.5, 3.3, -1.5], result) + + def testMapInt64ToFloat(self): + for float_dtype in [dtypes.float32, dtypes.float64]: + keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) + values = constant_op.constant([0.0, 1.1, 2.2, 3.3], float_dtype) + default_value = constant_op.constant(-1.5, float_dtype) table = lookup_ops.DenseHashTable( - dtypes.string, - dtypes.float32, + dtypes.int64, + float_dtype, default_value=default_value, - empty_key="", - deleted_key="$") + empty_key=0, + deleted_key=-1) self.assertAllEqual(0, self.evaluate(table.size())) self.evaluate(table.insert(keys, values)) self.assertAllEqual(4, self.evaluate(table.size())) - remove_string = constant_op.constant(["b", "e"]) + remove_string = constant_op.constant([12, 15], dtypes.int64) self.evaluate(table.remove(remove_string)) self.assertAllEqual(3, self.evaluate(table.size())) - input_string = constant_op.constant(["a", "b", "d", "e"], dtypes.string) + input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64) output = table.lookup(input_string) self.assertAllEqual([4], output.get_shape()) result = self.evaluate(output) self.assertAllClose([0, -1.5, 3.3, -1.5], result) - def testMapInt64ToFloat(self): - for float_dtype in [dtypes.float32, dtypes.float64]: - with self.cached_session(): - - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([0.0, 1.1, 2.2, 3.3], float_dtype) - default_value = constant_op.constant(-1.5, float_dtype) - table = lookup_ops.DenseHashTable( - dtypes.int64, - float_dtype, - default_value=default_value, - empty_key=0, - deleted_key=-1) - self.assertAllEqual(0, self.evaluate(table.size())) - - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(4, self.evaluate(table.size())) - - remove_string = constant_op.constant([12, 15], dtypes.int64) - self.evaluate(table.remove(remove_string)) - self.assertAllEqual(3, self.evaluate(table.size())) - - input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([4], output.get_shape()) - - result = self.evaluate(output) - 
self.assertAllClose([0, -1.5, 3.3, -1.5], result) - def testVectorValues(self): - with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]], - dtypes.int64) - default_value = constant_op.constant([-1, -2, -3, -4], dtypes.int64) - table = lookup_ops.DenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=0, - deleted_key=-1, - initial_num_buckets=4) - self.assertAllEqual(0, self.evaluate(table.size())) + keys = constant_op.constant([11, 12, 13], dtypes.int64) + values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]], + dtypes.int64) + default_value = constant_op.constant([-1, -2, -3, -4], dtypes.int64) + table = lookup_ops.DenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=default_value, + empty_key=0, + deleted_key=-1, + initial_num_buckets=4) + self.assertAllEqual(0, self.evaluate(table.size())) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table.size())) - self.assertAllEqual(4, len(self.evaluate(table.export()[0]))) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) + self.assertAllEqual(4, len(self.evaluate(table.export()[0]))) - self.evaluate( - table.insert( - constant_op.constant([14], dtypes.int64), - constant_op.constant([[2, 3, 4, 5]], dtypes.int64))) - self.assertAllEqual(4, self.evaluate(table.size())) - self.assertAllEqual(8, len(self.evaluate(table.export()[0]))) + self.evaluate( + table.insert( + constant_op.constant([14], dtypes.int64), + constant_op.constant([[2, 3, 4, 5]], dtypes.int64))) + self.assertAllEqual(4, self.evaluate(table.size())) + self.assertAllEqual(8, len(self.evaluate(table.export()[0]))) - remove_string = constant_op.constant([12, 16], dtypes.int64) - self.evaluate(table.remove(remove_string)) - self.assertAllEqual(3, self.evaluate(table.size())) - self.assertAllEqual(8, len(self.evaluate(table.export()[0]))) + remove_string = constant_op.constant([12, 16], dtypes.int64) + self.evaluate(table.remove(remove_string)) + self.assertAllEqual(3, self.evaluate(table.size())) + self.assertAllEqual(8, len(self.evaluate(table.export()[0]))) - input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([4, 4], - output.shape, - msg="Saw shape: %s" % output.shape) + input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64) + output = table.lookup(input_string) + self.assertAllEqual([4, 4], + output.shape, + msg="Saw shape: %s" % output.shape) - result = self.evaluate(output) - self.assertAllEqual( - [[0, 1, 2, 3], [-1, -2, -3, -4], [2, 3, 4, 5], [-1, -2, -3, -4]], - result) + result = self.evaluate(output) + self.assertAllEqual( + [[0, 1, 2, 3], [-1, -2, -3, -4], [2, 3, 4, 5], [-1, -2, -3, -4]], + result) def testVectorKeys(self): - with self.cached_session(): - keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64) - values = constant_op.constant([10, 11, 12], dtypes.int64) - empty_key = constant_op.constant([0, 3], dtypes.int64) - deleted_key = constant_op.constant([-1, -1], dtypes.int64) - default_value = constant_op.constant(-1, dtypes.int64) - table = lookup_ops.DenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=default_value, - empty_key=empty_key, - deleted_key=deleted_key, - initial_num_buckets=8) - self.assertAllEqual(0, self.evaluate(table.size())) + keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], 
dtypes.int64) + values = constant_op.constant([10, 11, 12], dtypes.int64) + empty_key = constant_op.constant([0, 3], dtypes.int64) + deleted_key = constant_op.constant([-1, -1], dtypes.int64) + default_value = constant_op.constant(-1, dtypes.int64) + table = lookup_ops.DenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=default_value, + empty_key=empty_key, + deleted_key=deleted_key, + initial_num_buckets=8) + self.assertAllEqual(0, self.evaluate(table.size())) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table.size())) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) - self.evaluate( - table.insert( - constant_op.constant([[0, 0]], dtypes.int64), - constant_op.constant([13], dtypes.int64))) - self.assertAllEqual(4, self.evaluate(table.size())) - self.assertAllEqual(8, len(self.evaluate(table.export()[0]))) + self.evaluate( + table.insert( + constant_op.constant([[0, 0]], dtypes.int64), + constant_op.constant([13], dtypes.int64))) + self.assertAllEqual(4, self.evaluate(table.size())) + self.assertAllEqual(8, len(self.evaluate(table.export()[0]))) - remove_string = constant_op.constant([[1, 2], [7, 8]], dtypes.int64) - self.evaluate(table.remove(remove_string)) - self.assertAllEqual(3, self.evaluate(table.size())) - self.assertAllEqual(8, len(self.evaluate(table.export()[0]))) + remove_string = constant_op.constant([[1, 2], [7, 8]], dtypes.int64) + self.evaluate(table.remove(remove_string)) + self.assertAllEqual(3, self.evaluate(table.size())) + self.assertAllEqual(8, len(self.evaluate(table.export()[0]))) - input_string = constant_op.constant([[0, 1], [1, 2], [1, 3], [0, 2]], - dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([4], output.get_shape()) + input_string = constant_op.constant([[0, 1], [1, 2], [1, 3], [0, 2]], + dtypes.int64) + output = table.lookup(input_string) + self.assertAllEqual([4], output.get_shape()) - result = self.evaluate(output) - self.assertAllEqual([10, -1, 12, -1], result) + result = self.evaluate(output) + self.assertAllEqual([10, -1, 12, -1], result) def testResize(self): - with self.cached_session(): - keys = constant_op.constant([11, 12, 13], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.DenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=0, - deleted_key=-1, - initial_num_buckets=4) - self.assertAllEqual(0, self.evaluate(table.size())) + keys = constant_op.constant([11, 12, 13], dtypes.int64) + values = constant_op.constant([0, 1, 2], dtypes.int64) + table = lookup_ops.DenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=0, + deleted_key=-1, + initial_num_buckets=4) + self.assertAllEqual(0, self.evaluate(table.size())) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table.size())) - self.assertAllEqual(4, len(self.evaluate(table.export()[0]))) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) + self.assertAllEqual(4, len(self.evaluate(table.export()[0]))) - keys2 = constant_op.constant([12, 99], dtypes.int64) - self.evaluate(table.remove(keys2)) - self.assertAllEqual(2, self.evaluate(table.size())) - self.assertAllEqual(4, len(self.evaluate(table.export()[0]))) + keys2 = constant_op.constant([12, 99], dtypes.int64) + self.evaluate(table.remove(keys2)) + self.assertAllEqual(2, self.evaluate(table.size())) + self.assertAllEqual(4, 
len(self.evaluate(table.export()[0]))) - keys3 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64) - values3 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64) + keys3 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64) + values3 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64) - self.evaluate(table.insert(keys3, values3)) - self.assertAllEqual(6, self.evaluate(table.size())) - self.assertAllEqual(16, len(self.evaluate(table.export()[0]))) + self.evaluate(table.insert(keys3, values3)) + self.assertAllEqual(6, self.evaluate(table.size())) + self.assertAllEqual(16, len(self.evaluate(table.export()[0]))) - keys4 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18], - dtypes.int64) - output = table.lookup(keys4) - self.assertAllEqual([-1, 0, -1, 3, 4, 5, 6, 7, -1], self.evaluate(output)) + keys4 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18], + dtypes.int64) + output = table.lookup(keys4) + self.assertAllEqual([-1, 0, -1, 3, 4, 5, 6, 7, -1], self.evaluate(output)) def testExport(self): - with self.cached_session(): + keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) + values = constant_op.constant([1, 2, 3, 4], dtypes.int64) + table = lookup_ops.DenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=100, + deleted_key=200, + initial_num_buckets=8) + self.assertAllEqual(0, self.evaluate(table.size())) - keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) - values = constant_op.constant([1, 2, 3, 4], dtypes.int64) - table = lookup_ops.DenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=100, - deleted_key=200, - initial_num_buckets=8) - self.assertAllEqual(0, self.evaluate(table.size())) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(4, self.evaluate(table.size())) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(4, self.evaluate(table.size())) + keys2 = constant_op.constant([12, 15], dtypes.int64) + self.evaluate(table.remove(keys2)) + self.assertAllEqual(3, self.evaluate(table.size())) - keys2 = constant_op.constant([12, 15], dtypes.int64) - self.evaluate(table.remove(keys2)) - self.assertAllEqual(3, self.evaluate(table.size())) + exported_keys, exported_values = table.export() - exported_keys, exported_values = table.export() + np_keys = self.evaluate(exported_keys) + np_values = self.evaluate(exported_values) - np_keys = self.evaluate(exported_keys) - np_values = self.evaluate(exported_values) + self.assertAllEqual(8, len(np_keys)) + self.assertAllEqual(8, len(np_values)) - self.assertAllEqual(8, len(np_keys)) - self.assertAllEqual(8, len(np_values)) - - # pair up keys and values, drop extra added dimension - pairs = np.dstack((np_keys.flatten(), np_values.flatten()))[0] - # sort by key - pairs = pairs[pairs[:, 0].argsort()] - self.assertAllEqual([[11, 1], [13, 3], [14, 4], [100, 0], [100, 0], - [100, 0], [100, 0], [200, 2]], pairs) + # pair up keys and values, drop extra added dimension + pairs = np.dstack((np_keys.flatten(), np_values.flatten()))[0] + # sort by key + pairs = pairs[pairs[:, 0].argsort()] + self.assertAllEqual([[11, 1], [13, 3], [14, 4], [100, 0], [100, 0], + [100, 0], [100, 0], [200, 2]], pairs) @test_util.run_v1_only("Saver V1 only") def testSaveRestore(self): @@ -1892,137 +1944,134 @@ class DenseHashTableOpTest(test.TestCase): self.assertAllEqual([0, 1, -1, 3, -1], output) def testReprobe(self): - with self.cached_session(): - # Insert 6 keys into a table with 8 buckets. 
- # The values are chosen to make sure collisions occur when using GCC STL - keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64) - values = constant_op.constant([51, 52, 53, 54, 55, 56], dtypes.int64) - table = lookup_ops.DenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=0, - deleted_key=-1, - initial_num_buckets=8) - self.assertAllEqual(0, self.evaluate(table.size())) + # Insert 6 keys into a table with 8 buckets. + # The values are chosen to make sure collisions occur when using GCC STL + keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64) + values = constant_op.constant([51, 52, 53, 54, 55, 56], dtypes.int64) + table = lookup_ops.DenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=0, + deleted_key=-1, + initial_num_buckets=8) + self.assertAllEqual(0, self.evaluate(table.size())) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(6, self.evaluate(table.size())) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(6, self.evaluate(table.size())) - input_string = constant_op.constant([10, 11, 12, 13, 14, 19, 20, 21, 22], - dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([9], output.get_shape()) + input_string = constant_op.constant([10, 11, 12, 13, 14, 19, 20, 21, 22], + dtypes.int64) + output = table.lookup(input_string) + self.assertAllEqual([9], output.get_shape()) - result = self.evaluate(output) - self.assertAllEqual([-1, 51, 52, 53, -1, 54, 55, 56, -1], result) + result = self.evaluate(output) + self.assertAllEqual([-1, 51, 52, 53, -1, 54, 55, 56, -1], result) def testCustomEmptyKey(self): - with self.cached_session(): - keys = constant_op.constant([11, 0, 13], dtypes.int64) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.DenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=12, - deleted_key=-1) - self.assertAllEqual(0, self.evaluate(table.size())) + keys = constant_op.constant([11, 0, 13], dtypes.int64) + values = constant_op.constant([0, 1, 2], dtypes.int64) + table = lookup_ops.DenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=12, + deleted_key=-1) + self.assertAllEqual(0, self.evaluate(table.size())) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table.size())) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) - input_string = constant_op.constant([11, 0, 15], dtypes.int64) - output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) + input_string = constant_op.constant([11, 0, 15], dtypes.int64) + output = table.lookup(input_string) + self.assertAllEqual([3], output.get_shape()) - result = self.evaluate(output) - self.assertAllEqual([0, 1, -1], result) + result = self.evaluate(output) + self.assertAllEqual([0, 1, -1], result) def testErrors(self): - with self.cached_session(): - table = lookup_ops.DenseHashTable( + table = lookup_ops.DenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=0, + deleted_key=-1) + + # Inserting the empty key returns an error + keys1 = constant_op.constant([11, 0], dtypes.int64) + values1 = constant_op.constant([0, 1], dtypes.int64) + with self.assertRaisesRegex(errors_impl.InvalidArgumentError, + "empty_key"): + self.evaluate(table.insert(keys1, values1)) + + # Looking up the empty key returns an error + with self.assertRaisesRegex(errors_impl.InvalidArgumentError, + "empty_key"): + 
self.evaluate(table.lookup(keys1)) + + # Inserting the deleted key returns an error + keys2 = constant_op.constant([11, -1], dtypes.int64) + values2 = constant_op.constant([0, 1], dtypes.int64) + with self.assertRaisesRegex(errors_impl.InvalidArgumentError, + "deleted_key"): + self.evaluate(table.insert(keys2, values2)) + + # Looking up the empty key returns an error + with self.assertRaisesRegex(errors_impl.InvalidArgumentError, + "deleted_key"): + self.evaluate(table.lookup(keys2)) + + # Arbitrary tensors of keys are not supported + keys = constant_op.constant([[11, 0], [12, 1]], dtypes.int64) + values = constant_op.constant([[11, 0], [12, 1]], dtypes.int64) + with self.assertRaisesRegex(errors_impl.InvalidArgumentError, + "Expected key shape"): + self.evaluate(table.lookup(keys)) + with self.assertRaisesRegex(errors_impl.InvalidArgumentError, + "Expected key shape"): + self.evaluate(table.insert(keys, values)) + + with self.assertRaisesRegex(errors_impl.InvalidArgumentError, + "Number of buckets must be"): + table2 = lookup_ops.DenseHashTable( dtypes.int64, dtypes.int64, default_value=-1, - empty_key=0, - deleted_key=-1) + empty_key=17, + deleted_key=-1, + initial_num_buckets=12) + self.assertAllEqual(0, self.evaluate(table2.size())) - # Inserting the empty key returns an error - keys1 = constant_op.constant([11, 0], dtypes.int64) - values1 = constant_op.constant([0, 1], dtypes.int64) - with self.assertRaisesRegex(errors_impl.InvalidArgumentError, - "empty_key"): - self.evaluate(table.insert(keys1, values1)) + with self.assertRaisesRegex( + errors_impl.InvalidArgumentError, + "Empty and deleted keys must have same shape"): + table3 = lookup_ops.DenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=42, + deleted_key=[1, 2]) + self.assertAllEqual(0, self.evaluate(table3.size())) - # Looking up the empty key returns an error - with self.assertRaisesRegex(errors_impl.InvalidArgumentError, - "empty_key"): - self.evaluate(table.lookup(keys1)) + with self.assertRaisesRegex(errors_impl.InvalidArgumentError, + "Empty and deleted keys cannot be equal"): + table4 = lookup_ops.DenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=42, + deleted_key=42) + self.assertAllEqual(0, self.evaluate(table4.size())) - # Inserting the deleted key returns an error - keys2 = constant_op.constant([11, -1], dtypes.int64) - values2 = constant_op.constant([0, 1], dtypes.int64) - with self.assertRaisesRegex(errors_impl.InvalidArgumentError, - "deleted_key"): - self.evaluate(table.insert(keys2, values2)) - - # Looking up the empty key returns an error - with self.assertRaisesRegex(errors_impl.InvalidArgumentError, - "deleted_key"): - self.evaluate(table.lookup(keys2)) - - # Arbitrary tensors of keys are not supported - keys = constant_op.constant([[11, 0], [12, 1]], dtypes.int64) - values = constant_op.constant([[11, 0], [12, 1]], dtypes.int64) - with self.assertRaisesRegex(errors_impl.InvalidArgumentError, - "Expected key shape"): - self.evaluate(table.lookup(keys)) - with self.assertRaisesRegex(errors_impl.InvalidArgumentError, - "Expected key shape"): - self.evaluate(table.insert(keys, values)) - - with self.assertRaisesRegex(errors_impl.InvalidArgumentError, - "Number of buckets must be"): - table2 = lookup_ops.DenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=17, - deleted_key=-1, - initial_num_buckets=12) - self.assertAllEqual(0, self.evaluate(table2.size())) - - with self.assertRaisesRegex( - errors_impl.InvalidArgumentError, - 
"Empty and deleted keys must have same shape"): - table3 = lookup_ops.DenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=42, - deleted_key=[1, 2]) - self.assertAllEqual(0, self.evaluate(table3.size())) - - with self.assertRaisesRegex(errors_impl.InvalidArgumentError, - "Empty and deleted keys cannot be equal"): - table4 = lookup_ops.DenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=42, - deleted_key=42) - self.assertAllEqual(0, self.evaluate(table4.size())) - - with self.assertRaisesRegex(errors_impl.InvalidArgumentError, - "Empty and deleted keys cannot be equal"): - table5 = lookup_ops.DenseHashTable( - dtypes.int64, - dtypes.int64, - default_value=-1, - empty_key=[1, 2, 3], - deleted_key=[1, 2, 3]) - self.assertAllEqual(0, self.evaluate(table5.size())) + with self.assertRaisesRegex(errors_impl.InvalidArgumentError, + "Empty and deleted keys cannot be equal"): + table5 = lookup_ops.DenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=-1, + empty_key=[1, 2, 3], + deleted_key=[1, 2, 3]) + self.assertAllEqual(0, self.evaluate(table5.size())) @test_util.run_in_graph_and_eager_modes def testStringToResource(self): @@ -2038,6 +2087,30 @@ class DenseHashTableOpTest(test.TestCase): table.insert("v1", v1.handle) self.assertEqual([], table.lookup("v1").shape) + def testExportShapeInference(self): + default_value = -1 + empty_key = 0 + deleted_key = -1 + table = lookup_ops.DenseHashTable( + dtypes.int64, + dtypes.int64, + default_value=default_value, + empty_key=empty_key, + deleted_key=deleted_key) + actual_shapes = [t.shape for t in table.export()] + inferred_shapes = [] + + @def_function.function + def f(): + for t in table.export(): + inferred_shapes.append(t.shape) + + f() + self.assertLen(actual_shapes, 2) + self.assertLen(inferred_shapes, 2) + self.assertTrue(inferred_shapes[0].is_compatible_with(actual_shapes[0])) + self.assertTrue(inferred_shapes[1].is_compatible_with(actual_shapes[1])) + class IndexTableFromFile(test.TestCase): @@ -2049,68 +2122,65 @@ class IndexTableFromFile(test.TestCase): def test_string_index_table_from_file(self): vocabulary_file = self._createVocabFile("f2i_vocab1.txt") - with self.cached_session(): - table = lookup_ops.index_table_from_file( - vocabulary_file=vocabulary_file, num_oov_buckets=1) - ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.OpError): - self.evaluate(ids) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((1, 2, 3), self.evaluate(ids)) + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, num_oov_buckets=1) + ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) + + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(ids) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((1, 2, 3), self.evaluate(ids)) def test_string_index_table_from_multicolumn_file(self): vocabulary_file = self._createVocabFile( "f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1")) - with self.cached_session(): - table = lookup_ops.index_table_from_file( - vocabulary_file=vocabulary_file, - num_oov_buckets=1, - key_column_index=0, - value_column_index=lookup_ops.TextFileIndex.LINE_NUMBER) - ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, + num_oov_buckets=1, + 
key_column_index=0, + value_column_index=lookup_ops.TextFileIndex.LINE_NUMBER) + ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.OpError): - self.evaluate(ids) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((1, 2, 3), self.evaluate(ids)) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(ids) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((1, 2, 3), self.evaluate(ids)) def test_string_index_table_from_multicolumn_file_custom_delimiter(self): vocabulary_file = self._createVocabFile( "f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1")) - with self.cached_session(): - table = lookup_ops.index_table_from_file( - vocabulary_file=vocabulary_file, - num_oov_buckets=1, - key_column_index=0, - value_column_index=lookup_ops.TextFileIndex.LINE_NUMBER, - delimiter=" ") - ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, + num_oov_buckets=1, + key_column_index=0, + value_column_index=lookup_ops.TextFileIndex.LINE_NUMBER, + delimiter=" ") + ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.OpError): - self.evaluate(ids) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((1, 2, 3), self.evaluate(ids)) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(ids) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((1, 2, 3), self.evaluate(ids)) def test_string_index_table_from_file_tensor_filename(self): vocabulary_file = self._createVocabFile("f2i_vocab1.txt") - with self.cached_session(): - vocabulary_file = constant_op.constant(vocabulary_file) - table = lookup_ops.index_table_from_file( - vocabulary_file=vocabulary_file, num_oov_buckets=1) - ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) + vocabulary_file = constant_op.constant(vocabulary_file) + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, num_oov_buckets=1) + ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.OpError): - self.evaluate(ids) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((1, 2, 3), self.evaluate(ids)) - if not context.executing_eagerly(): - self.assertEqual(1, - len(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS))) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(ids) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((1, 2, 3), self.evaluate(ids)) + if not context.executing_eagerly(): + self.assertEqual(1, + len(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS))) @test_util.run_v1_only("placeholder usage") def test_string_index_table_from_file_placeholder_filename(self): @@ -2133,70 +2203,64 @@ class IndexTableFromFile(test.TestCase): def test_int32_index_table_from_file(self): vocabulary_file = self._createVocabFile( "f2i_vocab2.txt", values=("42", "1", "-1000")) - with self.cached_session(): - table = lookup_ops.index_table_from_file( - vocabulary_file=vocabulary_file, - num_oov_buckets=1, - key_dtype=dtypes.int32) - ids = table.lookup( - constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) + 
table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, + num_oov_buckets=1, + key_dtype=dtypes.int32) + ids = table.lookup(constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.OpError): - self.evaluate(ids) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((1, 2, 3), self.evaluate(ids)) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(ids) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((1, 2, 3), self.evaluate(ids)) def test_int64_index_table_from_file(self): vocabulary_file = self._createVocabFile( "f2i_vocab3.txt", values=("42", "1", "-1000")) - with self.cached_session(): - table = lookup_ops.index_table_from_file( - vocabulary_file=vocabulary_file, - num_oov_buckets=1, - key_dtype=dtypes.int64) - ids = table.lookup( - constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, + num_oov_buckets=1, + key_dtype=dtypes.int64) + ids = table.lookup(constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.OpError): - self.evaluate(ids) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((1, 2, 3), self.evaluate(ids)) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(ids) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((1, 2, 3), self.evaluate(ids)) def test_index_table_from_file_with_default_value(self): default_value = -42 vocabulary_file = self._createVocabFile("f2i_vocab4.txt") - with self.cached_session(): - table = lookup_ops.index_table_from_file( - vocabulary_file=vocabulary_file, default_value=default_value) - ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, default_value=default_value) + ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.OpError): - self.evaluate(ids) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((1, 2, default_value), self.evaluate(ids)) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(ids) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((1, 2, default_value), self.evaluate(ids)) def test_index_table_from_file_with_oov_buckets(self): vocabulary_file = self._createVocabFile("f2i_vocab5.txt") - with self.cached_session(): - table = lookup_ops.index_table_from_file( - vocabulary_file=vocabulary_file, num_oov_buckets=1000) - ids = table.lookup( - constant_op.constant(["salad", "surgery", "tarkus", "toccata"])) + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, num_oov_buckets=1000) + ids = table.lookup( + constant_op.constant(["salad", "surgery", "tarkus", "toccata"])) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.OpError): - self.evaluate(ids) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual( - ( - 1, # From vocabulary file. - 2, # From vocabulary file. - 867, # 3 + fingerprint("tarkus") mod 300. - 860), # 3 + fingerprint("toccata") mod 300. 
- self.evaluate(ids)) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(ids) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual( + ( + 1, # From vocabulary file. + 2, # From vocabulary file. + 867, # 3 + fingerprint("tarkus") mod 300. + 860), # 3 + fingerprint("toccata") mod 300. + self.evaluate(ids)) def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self): self.assertRaises( @@ -2227,26 +2291,24 @@ class IndexTableFromFile(test.TestCase): def test_index_table_from_file_with_vocab_size_too_small(self): vocabulary_file = self._createVocabFile("f2i_vocab6.txt") - with self.cached_session(): - table = lookup_ops.index_table_from_file( - vocabulary_file=vocabulary_file, vocab_size=2) - ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, vocab_size=2) + ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.OpError): - self.evaluate(ids) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((1, -1, -1), self.evaluate(ids)) - self.assertEqual(2, self.evaluate(table.size())) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(ids) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((1, -1, -1), self.evaluate(ids)) + self.assertEqual(2, self.evaluate(table.size())) def test_index_table_from_file_with_vocab_size_too_large(self): vocabulary_file = self._createVocabFile("f2i_vocab7.txt") - with self.cached_session(): - with self.assertRaisesRegex(errors_impl.InvalidArgumentError, - "Invalid vocab_size"): - table = lookup_ops.index_table_from_file( - vocabulary_file=vocabulary_file, vocab_size=4) - self.evaluate(table.initializer) + with self.assertRaisesRegex(errors_impl.InvalidArgumentError, + "Invalid vocab_size"): + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, vocab_size=4) + self.evaluate(table.initializer) def test_index_table_from_file_with_vocab_size(self): vocabulary_file = self._createVocabFile("f2i_vocab8.txt") @@ -2257,50 +2319,46 @@ class IndexTableFromFile(test.TestCase): vocabulary_file=vocabulary_file, vocab_size=0) - with self.cached_session(): - table = lookup_ops.index_table_from_file( - vocabulary_file=vocabulary_file, vocab_size=3) - ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, vocab_size=3) + ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.OpError): - self.evaluate(ids) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((1, 2, -1), self.evaluate(ids)) - self.assertEqual(3, self.evaluate(table.size())) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(ids) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((1, 2, -1), self.evaluate(ids)) + self.assertEqual(3, self.evaluate(table.size())) def test_index_table_from_file_with_invalid_hashers(self): vocabulary_file = self._createVocabFile("invalid_hasher.txt") - with self.cached_session(): - with self.assertRaises(TypeError): - lookup_ops.index_table_from_file( - vocabulary_file=vocabulary_file, - vocab_size=3, - num_oov_buckets=1, - hasher_spec=1) - 
- table = lookup_ops.index_table_from_file( + with self.assertRaises(TypeError): + lookup_ops.index_table_from_file( vocabulary_file=vocabulary_file, vocab_size=3, num_oov_buckets=1, - hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None)) + hasher_spec=1) - self.assertRaises(ValueError, table.lookup, - constant_op.constant(["salad", "surgery", "tarkus"])) + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, + vocab_size=3, + num_oov_buckets=1, + hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None)) + + self.assertRaises(ValueError, table.lookup, + constant_op.constant(["salad", "surgery", "tarkus"])) def test_index_table_from_file_table_ref_with_oov_buckets(self): vocabulary_file = self._createVocabFile("f2i_vocab9.txt") - with self.cached_session(): - table = lookup_ops.index_table_from_file( - vocabulary_file=vocabulary_file, num_oov_buckets=1) - self.assertIsNotNone(table.resource_handle) + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, num_oov_buckets=1) + self.assertIsNotNone(table.resource_handle) def test_index_table_from_file_table_ref_without_oov_buckets(self): vocabulary_file = self._createVocabFile("f2i_vocab10.txt") - with self.cached_session(): - table = lookup_ops.index_table_from_file( - vocabulary_file=vocabulary_file, num_oov_buckets=0) - self.assertIsNotNone(table.resource_handle) + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, num_oov_buckets=0) + self.assertIsNotNone(table.resource_handle) class IndexTableFromTensor(test.TestCase): @@ -2323,75 +2381,67 @@ class IndexTableFromTensor(test.TestCase): self.assertAllEqual((1, 2, 3), self.evaluate(ids)) def test_int32_index_table_from_tensor_with_tensor_init(self): - with self.cached_session(): - table = lookup_ops.index_table_from_tensor( - vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32) - ids = table.lookup( - constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) + table = lookup_ops.index_table_from_tensor( + vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32) + ids = table.lookup(constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.FailedPreconditionError): - self.evaluate(ids) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((1, 2, 3), self.evaluate(ids)) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.FailedPreconditionError): + self.evaluate(ids) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((1, 2, 3), self.evaluate(ids)) def test_int64_index_table_from_tensor_with_tensor_init(self): - with self.cached_session(): - table = lookup_ops.index_table_from_tensor( - vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64) - ids = table.lookup( - constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) + table = lookup_ops.index_table_from_tensor( + vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64) + ids = table.lookup(constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.FailedPreconditionError): - self.evaluate(ids) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((1, 2, 3), self.evaluate(ids)) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.FailedPreconditionError): + self.evaluate(ids) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((1, 2, 3), 
self.evaluate(ids)) def test_index_table_from_tensor_with_default_value(self): default_value = -42 - with self.cached_session(): - table = lookup_ops.index_table_from_tensor( - vocabulary_list=["brain", "salad", "surgery"], - default_value=default_value) - ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) + table = lookup_ops.index_table_from_tensor( + vocabulary_list=["brain", "salad", "surgery"], + default_value=default_value) + ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.FailedPreconditionError): - self.evaluate(ids) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((1, 2, default_value), self.evaluate(ids)) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.FailedPreconditionError): + self.evaluate(ids) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((1, 2, default_value), self.evaluate(ids)) def test_index_table_from_tensor_missing_vocabulary_list(self): - with self.cached_session(): - with self.assertRaisesRegex(ValueError, - "vocabulary_list must be specified"): - lookup_ops.index_table_from_tensor( - vocabulary_list=None, num_oov_buckets=1) + with self.assertRaisesRegex(ValueError, + "vocabulary_list must be specified"): + lookup_ops.index_table_from_tensor( + vocabulary_list=None, num_oov_buckets=1) def test_index_table_from_tensor_empty_vocabulary_list(self): - with self.cached_session(): - with self.assertRaisesRegex(errors_impl.OpError, - "keys and values cannot be empty"): - _ = lookup_ops.index_table_from_tensor( - vocabulary_list=np.array([], dtype=np.str_), num_oov_buckets=1) - self.evaluate(lookup_ops.tables_initializer()) + with self.assertRaisesRegex(errors_impl.OpError, + "keys and values cannot be empty"): + _ = lookup_ops.index_table_from_tensor( + vocabulary_list=np.array([], dtype=np.str_), num_oov_buckets=1) + self.evaluate(lookup_ops.tables_initializer()) def test_index_table_from_tensor_with_invalid_hashers(self): - with self.cached_session(): - with self.assertRaises(TypeError): - lookup_ops.index_table_from_tensor( - vocabulary_list=["brain", "salad", "surgery"], - num_oov_buckets=1, - hasher_spec=1) - - table = lookup_ops.index_table_from_tensor( + with self.assertRaises(TypeError): + lookup_ops.index_table_from_tensor( vocabulary_list=["brain", "salad", "surgery"], num_oov_buckets=1, - hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None)) + hasher_spec=1) - self.assertRaises(ValueError, table.lookup, - constant_op.constant(["salad", "surgery", "tarkus"])) + table = lookup_ops.index_table_from_tensor( + vocabulary_list=["brain", "salad", "surgery"], + num_oov_buckets=1, + hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None)) + + self.assertRaises(ValueError, table.lookup, + constant_op.constant(["salad", "surgery", "tarkus"])) class IndexToStringTableFromFileTest(test.TestCase): @@ -2408,147 +2458,135 @@ class IndexToStringTableFromFileTest(test.TestCase): type_funcs = [str, constant_op.constant] for type_func in type_funcs: vocabulary_file = type_func(vocabulary_path) - with self.cached_session(): - table = lookup_ops.index_to_string_table_from_file( - vocabulary_file=vocabulary_file) - features = table.lookup( - constant_op.constant([0, 1, 2, 3], dtypes.int64)) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.OpError): - self.evaluate(features) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((b"brain", 
b"salad", b"surgery", b"UNK"), - self.evaluate(features)) + table = lookup_ops.index_to_string_table_from_file( + vocabulary_file=vocabulary_file) + features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(features) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), + self.evaluate(features)) def test_index_to_string_table_from_multicolumn_file(self): vocabulary_file = self._createVocabFile( "f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1")) - with self.cached_session(): - table = lookup_ops.index_to_string_table_from_file( - vocabulary_file=vocabulary_file, - key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER, - value_column_index=0) - features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.OpError): - self.evaluate(features) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), - self.evaluate(features)) + table = lookup_ops.index_to_string_table_from_file( + vocabulary_file=vocabulary_file, + key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER, + value_column_index=0) + features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(features) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), + self.evaluate(features)) def test_index_to_string_table_from_multicolumn_file_custom_delimiter(self): vocabulary_file = self._createVocabFile( "f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1")) - with self.cached_session(): - table = lookup_ops.index_to_string_table_from_file( - vocabulary_file=vocabulary_file, - key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER, - value_column_index=0, - delimiter=" ") - features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.OpError): - self.evaluate(features) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), - self.evaluate(features)) + table = lookup_ops.index_to_string_table_from_file( + vocabulary_file=vocabulary_file, + key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER, + value_column_index=0, + delimiter=" ") + features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(features) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), + self.evaluate(features)) def test_index_to_string_table_with_default_value(self): default_value = b"NONE" vocabulary_file = self._createVocabFile("f2i_vocab2.txt") - with self.cached_session(): - table = lookup_ops.index_to_string_table_from_file( - vocabulary_file=vocabulary_file, default_value=default_value) - features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.OpError): - self.evaluate(features) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((b"salad", b"surgery", default_value), - self.evaluate(features)) + table = 
lookup_ops.index_to_string_table_from_file( + vocabulary_file=vocabulary_file, default_value=default_value) + features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(features) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((b"salad", b"surgery", default_value), + self.evaluate(features)) def test_index_to_string_table_with_vocab_size_too_small(self): default_value = b"NONE" vocabulary_file = self._createVocabFile("f2i_vocab2.txt") - with self.cached_session(): - table = lookup_ops.index_to_string_table_from_file( - vocabulary_file=vocabulary_file, - vocab_size=2, - default_value=default_value) - features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.OpError): - self.evaluate(features) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((b"salad", default_value, default_value), - self.evaluate(features)) + table = lookup_ops.index_to_string_table_from_file( + vocabulary_file=vocabulary_file, + vocab_size=2, + default_value=default_value) + features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(features) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((b"salad", default_value, default_value), + self.evaluate(features)) def test_index_to_string_table_with_vocab_size_too_large(self): vocabulary_file = self._createVocabFile("f2i_vocab6.txt") - with self.cached_session(): - with self.assertRaisesRegex(errors_impl.InvalidArgumentError, - "Invalid vocab_size"): - _ = lookup_ops.index_to_string_table_from_file( - vocabulary_file=vocabulary_file, vocab_size=4) - self.evaluate(lookup_ops.tables_initializer()) + with self.assertRaisesRegex(errors_impl.InvalidArgumentError, + "Invalid vocab_size"): + _ = lookup_ops.index_to_string_table_from_file( + vocabulary_file=vocabulary_file, vocab_size=4) + self.evaluate(lookup_ops.tables_initializer()) def test_index_to_string_table_with_vocab_size(self): vocabulary_file = self._createVocabFile("f2i_vocab7.txt") - with self.cached_session(): - table = lookup_ops.index_to_string_table_from_file( - vocabulary_file=vocabulary_file, vocab_size=3) - features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) + table = lookup_ops.index_to_string_table_from_file( + vocabulary_file=vocabulary_file, vocab_size=3) + features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64)) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.OpError): - self.evaluate(features) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((b"salad", b"surgery", b"UNK"), - self.evaluate(features)) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(features) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((b"salad", b"surgery", b"UNK"), self.evaluate(features)) class IndexToStringTableFromTensorTest(test.TestCase): def test_index_to_string_table_from_tensor(self): - with self.cached_session(): - vocabulary_list = constant_op.constant(["brain", "salad", "surgery"]) - table = lookup_ops.index_to_string_table_from_tensor( - vocabulary_list=vocabulary_list) + vocabulary_list = constant_op.constant(["brain", "salad", "surgery"]) + table = lookup_ops.index_to_string_table_from_tensor( + 
vocabulary_list=vocabulary_list) - indices = constant_op.constant([0, 1, 2, 3], dtypes.int64) - features = table.lookup(indices) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.OpError): - self.evaluate(features) - self.evaluate(lookup_ops.tables_initializer()) + indices = constant_op.constant([0, 1, 2, 3], dtypes.int64) + features = table.lookup(indices) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(features) + self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), - self.evaluate(features)) + self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), + self.evaluate(features)) def test_duplicate_entries(self): - with self.cached_session(): - vocabulary_list = constant_op.constant(["hello", "hello"]) - table = lookup_ops.index_to_string_table_from_tensor( - vocabulary_list=vocabulary_list) - indices = constant_op.constant([0, 1, 4], dtypes.int64) - features = table.lookup(indices) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((b"hello", b"hello", b"UNK"), self.evaluate(features)) + vocabulary_list = constant_op.constant(["hello", "hello"]) + table = lookup_ops.index_to_string_table_from_tensor( + vocabulary_list=vocabulary_list) + indices = constant_op.constant([0, 1, 4], dtypes.int64) + features = table.lookup(indices) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((b"hello", b"hello", b"UNK"), self.evaluate(features)) def test_index_to_string_with_default_value(self): default_value = b"NONE" - with self.cached_session(): - vocabulary_list = constant_op.constant(["brain", "salad", "surgery"]) - table = lookup_ops.index_to_string_table_from_tensor( - vocabulary_list=vocabulary_list, default_value=default_value) - indices = constant_op.constant([1, 2, 4], dtypes.int64) - features = table.lookup(indices) - if not context.executing_eagerly(): - with self.assertRaises(errors_impl.OpError): - self.evaluate(features) - self.evaluate(lookup_ops.tables_initializer()) - self.assertAllEqual((b"salad", b"surgery", default_value), - self.evaluate(features)) + vocabulary_list = constant_op.constant(["brain", "salad", "surgery"]) + table = lookup_ops.index_to_string_table_from_tensor( + vocabulary_list=vocabulary_list, default_value=default_value) + indices = constant_op.constant([1, 2, 4], dtypes.int64) + features = table.lookup(indices) + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(features) + self.evaluate(lookup_ops.tables_initializer()) + self.assertAllEqual((b"salad", b"surgery", default_value), + self.evaluate(features)) class IdTableWithHashBucketsTest(test.TestCase): @@ -2711,45 +2749,40 @@ class IdTableWithHashBucketsTest(test.TestCase): def testIdTableWithHashBucketsInitializationAcrossSessions(self): vocab_file = self._createVocabFile("feat_to_id_5.txt") - with self.cached_session(): - default_value = -1 - vocab_size = 3 - oov_buckets = 1 - table1 = lookup_ops.IdTableWithHashBuckets( - lookup_ops.StaticHashTable( - lookup_ops.TextFileIdTableInitializer( - vocab_file, vocab_size=vocab_size), default_value), - oov_buckets) + default_value = -1 + vocab_size = 3 + oov_buckets = 1 + table1 = lookup_ops.IdTableWithHashBuckets( + lookup_ops.StaticHashTable( + lookup_ops.TextFileIdTableInitializer( + vocab_file, vocab_size=vocab_size), default_value), oov_buckets) - self.evaluate(table1.initializer) + self.evaluate(table1.initializer) - input_string_1 = 
constant_op.constant( - ["brain", "salad", "surgery", "UNK"]) + input_string_1 = constant_op.constant(["brain", "salad", "surgery", "UNK"]) - out1 = table1.lookup(input_string_1) + out1 = table1.lookup(input_string_1) - self.assertAllEqual([0, 1, 2, 3], self.evaluate(out1)) - self.assertEqual(vocab_size + oov_buckets, self.evaluate(table1.size())) + self.assertAllEqual([0, 1, 2, 3], self.evaluate(out1)) + self.assertEqual(vocab_size + oov_buckets, self.evaluate(table1.size())) - with self.cached_session(): - default_value = -1 - vocab_size = 3 - oov_buckets = 1 + default_value = -1 + vocab_size = 3 + oov_buckets = 1 - # Underlying lookup table already initialized in previous session. - # No need to call self.evaluate(table2.initializer) - table2 = lookup_ops.IdTableWithHashBuckets( - lookup_ops.StaticHashTable( - lookup_ops.TextFileIdTableInitializer( - vocab_file, vocab_size=vocab_size), default_value), - oov_buckets) + # Underlying lookup table already initialized in previous session. + # No need to call self.evaluate(table2.initializer) + table2 = lookup_ops.IdTableWithHashBuckets( + lookup_ops.StaticHashTable( + lookup_ops.TextFileIdTableInitializer( + vocab_file, vocab_size=vocab_size), default_value), oov_buckets) - input_string_2 = constant_op.constant(["fruit", "salad", "UNK"]) + input_string_2 = constant_op.constant(["fruit", "salad", "UNK"]) - out2 = table2.lookup(input_string_2) + out2 = table2.lookup(input_string_2) - self.assertAllEqual([3, 1, 3], self.evaluate(out2)) - self.assertEqual(vocab_size + oov_buckets, self.evaluate(table2.size())) + self.assertAllEqual([3, 1, 3], self.evaluate(out2)) + self.assertEqual(vocab_size + oov_buckets, self.evaluate(table2.size())) def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self): vocab_file = self._createVocabFile("feat_to_id_6.txt") @@ -2938,84 +2971,79 @@ class IdTableWithHashBucketsTest(test.TestCase): def testIdTableWithHashBucketsWithInvalidHashers(self): vocab_file = self._createVocabFile("feat_to_id_4.txt") - with self.cached_session(): - default_value = -1 - vocab_size = 3 - oov_buckets = 1 - lookup_table = lookup_ops.StaticHashTable( - lookup_ops.TextFileIdTableInitializer( - vocab_file, vocab_size=vocab_size), default_value) + default_value = -1 + vocab_size = 3 + oov_buckets = 1 + lookup_table = lookup_ops.StaticHashTable( + lookup_ops.TextFileIdTableInitializer( + vocab_file, vocab_size=vocab_size), default_value) - with self.assertRaises(TypeError): - lookup_ops.IdTableWithHashBuckets( - lookup_table, oov_buckets, hasher_spec=1) + with self.assertRaises(TypeError): + lookup_ops.IdTableWithHashBuckets( + lookup_table, oov_buckets, hasher_spec=1) + table = lookup_ops.IdTableWithHashBuckets( + lookup_table, + oov_buckets, + hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None)) + + input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"]) + + with self.assertRaises(ValueError): + table.lookup(input_string) + + with self.assertRaises(ValueError): + table = lookup_ops.IdTableWithHashBuckets( + lookup_table, oov_buckets, hasher_spec=lookup_ops.StrongHashSpec([])) + + with self.assertRaises(ValueError): table = lookup_ops.IdTableWithHashBuckets( lookup_table, oov_buckets, - hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None)) + hasher_spec=lookup_ops.StrongHashSpec([1, 2, 3])) - input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"]) - - with self.assertRaises(ValueError): - table.lookup(input_string) - - with self.assertRaises(ValueError): - table = 
lookup_ops.IdTableWithHashBuckets( - lookup_table, - oov_buckets, - hasher_spec=lookup_ops.StrongHashSpec([])) - - with self.assertRaises(ValueError): - table = lookup_ops.IdTableWithHashBuckets( - lookup_table, - oov_buckets, - hasher_spec=lookup_ops.StrongHashSpec([1, 2, 3])) - - with self.assertRaises(TypeError): - table = lookup_ops.IdTableWithHashBuckets( - lookup_table, - oov_buckets, - hasher_spec=lookup_ops.StrongHashSpec([None, 2])) + with self.assertRaises(TypeError): + table = lookup_ops.IdTableWithHashBuckets( + lookup_table, + oov_buckets, + hasher_spec=lookup_ops.StrongHashSpec([None, 2])) def testIdTableWithHashBucketsNoInnerTable(self): - with self.cached_session(): - table = lookup_ops.IdTableWithHashBuckets(None, num_oov_buckets=1) - self.assertIsNone(table.resource_handle) + table = lookup_ops.IdTableWithHashBuckets(None, num_oov_buckets=1) + self.assertIsNone(table.resource_handle) class MutableHashTableOpTest(test.TestCase): def testMutableHashTable(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"]) - values = constant_op.constant([0, 1, 2, 3], dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, self.evaluate(table.size())) + default_val = -1 + keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"]) + values = constant_op.constant([0, 1, 2, 3], dtypes.int64) + table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) + self.assertAllEqual(0, self.evaluate(table.size())) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(4, self.evaluate(table.size())) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(4, self.evaluate(table.size())) - remove_string = constant_op.constant(["tarkus", "tank"]) - self.evaluate(table.remove(remove_string)) - self.assertAllEqual(3, self.evaluate(table.size())) + remove_string = constant_op.constant(["tarkus", "tank"]) + self.evaluate(table.remove(remove_string)) + self.assertAllEqual(3, self.evaluate(table.size())) - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) - self.assertAllEqual([3], output.get_shape()) + input_string = constant_op.constant(["brain", "salad", "tank"]) + output = table.lookup(input_string) + self.assertAllEqual([3], output.get_shape()) - result = self.evaluate(output) - self.assertAllEqual([0, 1, -1], result) + result = self.evaluate(output) + self.assertAllEqual([0, 1, -1], result) - exported_keys, exported_values = table.export() + exported_keys, exported_values = table.export() - # exported data is in the order of the internal map, i.e. undefined - sorted_keys = np.sort(self.evaluate(exported_keys)) - sorted_values = np.sort(self.evaluate(exported_values)) - self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys) - self.assertAllEqual([0, 1, 2], sorted_values) + # exported data is in the order of the internal map, i.e. 
undefined + sorted_keys = np.sort(self.evaluate(exported_keys)) + sorted_values = np.sort(self.evaluate(exported_values)) + self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys) + self.assertAllEqual([0, 1, 2], sorted_values) @test_util.run_v1_only("SaverV1") def testSaveRestore(self): @@ -3214,370 +3242,374 @@ class MutableHashTableOpTest(test.TestCase): self.assertAllEqual([b"-", b"a", b"b"], output) def testMutableHashTableOfTensors(self): - with self.cached_session(): - default_val = constant_op.constant([-1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"]) - values = constant_op.constant([[0, 1], [2, 3], [4, 5], [6, 7]], - dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, self.evaluate(table.size())) + default_val = constant_op.constant([-1, -1], dtypes.int64) + keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"]) + values = constant_op.constant([[0, 1], [2, 3], [4, 5], [6, 7]], + dtypes.int64) + table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) + self.assertAllEqual(0, self.evaluate(table.size())) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(4, self.evaluate(table.size())) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(4, self.evaluate(table.size())) - remove_string = constant_op.constant(["tarkus", "tank"]) - self.evaluate(table.remove(remove_string)) - self.assertAllEqual(3, self.evaluate(table.size())) + remove_string = constant_op.constant(["tarkus", "tank"]) + self.evaluate(table.remove(remove_string)) + self.assertAllEqual(3, self.evaluate(table.size())) - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) - self.assertAllEqual([3, 2], output.get_shape()) + input_string = constant_op.constant(["brain", "salad", "tank"]) + output = table.lookup(input_string) + self.assertAllEqual([3, 2], output.get_shape()) - result = self.evaluate(output) - self.assertAllEqual([[0, 1], [2, 3], [-1, -1]], result) + result = self.evaluate(output) + self.assertAllEqual([[0, 1], [2, 3], [-1, -1]], result) - exported_keys, exported_values = table.export() - # exported data is in the order of the internal map, i.e. undefined - sorted_keys = np.sort(self.evaluate(exported_keys)) - sorted_values = np.sort(self.evaluate(exported_values), axis=0) - self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys) - sorted_expected_values = np.sort([[4, 5], [2, 3], [0, 1]], axis=0) - self.assertAllEqual(sorted_expected_values, sorted_values) + exported_keys, exported_values = table.export() + # exported data is in the order of the internal map, i.e. 
undefined + sorted_keys = np.sort(self.evaluate(exported_keys)) + sorted_values = np.sort(self.evaluate(exported_values), axis=0) + self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys) + sorted_expected_values = np.sort([[4, 5], [2, 3], [0, 1]], axis=0) + self.assertAllEqual(sorted_expected_values, sorted_values) def testMutableHashTableExportInsert(self): - with self.cached_session(): - default_val = constant_op.constant([-1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) - table1 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, self.evaluate(table1.size())) - self.evaluate(table1.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table1.size())) + default_val = constant_op.constant([-1, -1], dtypes.int64) + keys = constant_op.constant(["brain", "salad", "surgery"]) + values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) + table1 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) + self.assertAllEqual(0, self.evaluate(table1.size())) + self.evaluate(table1.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(table1.size())) - input_string = constant_op.constant(["brain", "salad", "tank"]) - expected_output = [[0, 1], [2, 3], [-1, -1]] - output1 = table1.lookup(input_string) - self.assertAllEqual(expected_output, self.evaluate(output1)) + input_string = constant_op.constant(["brain", "salad", "tank"]) + expected_output = [[0, 1], [2, 3], [-1, -1]] + output1 = table1.lookup(input_string) + self.assertAllEqual(expected_output, self.evaluate(output1)) - exported_keys, exported_values = table1.export() - self.assertAllEqual(3, self.evaluate(exported_keys).size) - self.assertAllEqual(6, self.evaluate(exported_values).size) + exported_keys, exported_values = table1.export() + self.assertAllEqual(3, self.evaluate(exported_keys).size) + self.assertAllEqual(6, self.evaluate(exported_values).size) - # Populate a second table from the exported data - table2 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, self.evaluate(table2.size())) - self.evaluate(table2.insert(exported_keys, exported_values)) - self.assertAllEqual(3, self.evaluate(table2.size())) + # Populate a second table from the exported data + table2 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) + self.assertAllEqual(0, self.evaluate(table2.size())) + self.evaluate(table2.insert(exported_keys, exported_values)) + self.assertAllEqual(3, self.evaluate(table2.size())) - # Verify lookup result is still the same - output2 = table2.lookup(input_string) - self.assertAllEqual(expected_output, self.evaluate(output2)) + # Verify lookup result is still the same + output2 = table2.lookup(input_string) + self.assertAllEqual(expected_output, self.evaluate(output2)) def testMutableHashTableOfTensorsInvalidShape(self): - with self.cached_session(): - default_val = constant_op.constant([-1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + default_val = constant_op.constant([-1, -1], dtypes.int64) + keys = constant_op.constant(["brain", "salad", "surgery"]) + table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) - # Shape [6] instead of [3, 2] - values = constant_op.constant([0, 1, 2, 3, 4, 5], dtypes.int64) - with 
self.assertRaisesOpError("Expected shape"): - self.evaluate(table.insert(keys, values)) - - # Shape [2,3] instead of [3, 2] - values = constant_op.constant([[0, 1, 2], [3, 4, 5]], dtypes.int64) - with self.assertRaisesOpError("Expected shape"): - self.evaluate(table.insert(keys, values)) - - # Shape [2, 2] instead of [3, 2] - values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) - with self.assertRaisesOpError("Expected shape"): - self.evaluate(table.insert(keys, values)) - - # Shape [3, 1] instead of [3, 2] - values = constant_op.constant([[0], [2], [4]], dtypes.int64) - with self.assertRaisesOpError("Expected shape"): - self.evaluate(table.insert(keys, values)) - - # Valid Insert - values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) + # Shape [6] instead of [3, 2] + values = constant_op.constant([0, 1, 2, 3, 4, 5], dtypes.int64) + with self.assertRaisesOpError("Expected shape"): self.evaluate(table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table.size())) + + # Shape [2,3] instead of [3, 2] + values = constant_op.constant([[0, 1, 2], [3, 4, 5]], dtypes.int64) + with self.assertRaisesOpError("Expected shape"): + self.evaluate(table.insert(keys, values)) + + # Shape [2, 2] instead of [3, 2] + values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) + with self.assertRaisesOpError("Expected shape"): + self.evaluate(table.insert(keys, values)) + + # Shape [3, 1] instead of [3, 2] + values = constant_op.constant([[0], [2], [4]], dtypes.int64) + with self.assertRaisesOpError("Expected shape"): + self.evaluate(table.insert(keys, values)) + + # Valid Insert + values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) def testMutableHashTableInvalidDefaultValue(self): - with self.cached_session(): - default_val = constant_op.constant([[-1, -1]], dtypes.int64) - with self.assertRaisesOpError("Default value must be a vector"): - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, self.evaluate(table.size())) + default_val = constant_op.constant([[-1, -1]], dtypes.int64) + with self.assertRaisesOpError("Default value must be a vector"): + table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) + self.assertAllEqual(0, self.evaluate(table.size())) def testMutableHashTableDuplicateInsert(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery", "brain"]) - values = constant_op.constant([0, 1, 2, 3], dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.assertAllEqual(0, self.evaluate(table.size())) + default_val = -1 + keys = constant_op.constant(["brain", "salad", "surgery", "brain"]) + values = constant_op.constant([0, 1, 2, 3], dtypes.int64) + table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) + self.assertAllEqual(0, self.evaluate(table.size())) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table.size())) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) + input_string = constant_op.constant(["brain", "salad", "tank"]) + output = table.lookup(input_string) - result = self.evaluate(output) - self.assertAllEqual([3, 1, -1], result) + result = 
self.evaluate(output) + self.assertAllEqual([3, 1, -1], result) def testMutableHashTableFindHighRank(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + default_val = -1 + keys = constant_op.constant(["brain", "salad", "surgery"]) + values = constant_op.constant([0, 1, 2], dtypes.int64) + table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table.size())) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) - input_string = constant_op.constant([["brain", "salad"], - ["tank", "tarkus"]]) - output = table.lookup(input_string) - self.assertAllEqual([2, 2], output.get_shape()) + input_string = constant_op.constant([["brain", "salad"], + ["tank", "tarkus"]]) + output = table.lookup(input_string) + self.assertAllEqual([2, 2], output.get_shape()) - result = self.evaluate(output) - self.assertAllEqual([[0, 1], [-1, -1]], result) + result = self.evaluate(output) + self.assertAllEqual([[0, 1], [-1, -1]], result) def testMutableHashTableInsertHighRank(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]]) - values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + default_val = -1 + keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]]) + values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) + table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(4, self.evaluate(table.size())) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(4, self.evaluate(table.size())) - input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"]) - output = table.lookup(input_string) + input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"]) + output = table.lookup(input_string) - result = self.evaluate(output) - self.assertAllEqual([0, 1, 3, -1], result) + result = self.evaluate(output) + self.assertAllEqual([0, 1, 3, -1], result) def testMutableHashTableRemoveHighRank(self): - with self.test_session(): - default_val = -1 - keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]]) - values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + default_val = -1 + keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]]) + values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) + table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(4, self.evaluate(table.size())) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(4, self.evaluate(table.size())) - remove_string = constant_op.constant(["salad", "tarkus"]) - self.evaluate(table.remove(remove_string)) - self.assertAllEqual(3, self.evaluate(table.size())) + remove_string = constant_op.constant(["salad", "tarkus"]) + self.evaluate(table.remove(remove_string)) + self.assertAllEqual(3, self.evaluate(table.size())) - input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"]) - 
output = table.lookup(input_string) + input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"]) + output = table.lookup(input_string) - result = self.evaluate(output) - self.assertAllEqual([0, -1, 3, -1], result) + result = self.evaluate(output) + self.assertAllEqual([0, -1, 3, -1], result) def testMutableHashTableOfTensorsFindHighRank(self): - with self.cached_session(): - default_val = constant_op.constant([-1, -1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], - dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + default_val = constant_op.constant([-1, -1, -1], dtypes.int64) + keys = constant_op.constant(["brain", "salad", "surgery"]) + values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], + dtypes.int64) + table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table.size())) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) - input_string = constant_op.constant([["brain", "salad"], - ["tank", "tarkus"]]) - output = table.lookup(input_string) - self.assertAllEqual([2, 2, 3], output.get_shape()) + input_string = constant_op.constant([["brain", "salad"], + ["tank", "tarkus"]]) + output = table.lookup(input_string) + self.assertAllEqual([2, 2, 3], output.get_shape()) - result = self.evaluate(output) - self.assertAllEqual( - [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result) + result = self.evaluate(output) + self.assertAllEqual( + [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result) def testMutableHashTableOfTensorsRemoveHighRank(self): - with self.test_session(): - default_val = constant_op.constant([-1, -1, -1], dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], - dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + default_val = constant_op.constant([-1, -1, -1], dtypes.int64) + keys = constant_op.constant(["brain", "salad", "surgery"]) + values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], + dtypes.int64) + table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table.size())) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) - remove_string = constant_op.constant([["brain", "tank"]]) - self.evaluate(table.remove(remove_string)) - self.assertAllEqual(2, self.evaluate(table.size())) + remove_string = constant_op.constant([["brain", "tank"]]) + self.evaluate(table.remove(remove_string)) + self.assertAllEqual(2, self.evaluate(table.size())) - input_string = constant_op.constant([["brain", "salad"], - ["surgery", "tank"]]) - output = table.lookup(input_string) - self.assertAllEqual([2, 2, 3], output.get_shape()) + input_string = constant_op.constant([["brain", "salad"], + ["surgery", "tank"]]) + output = table.lookup(input_string) + self.assertAllEqual([2, 2, 3], output.get_shape()) - result = self.evaluate(output) - self.assertAllEqual( - [[[-1, -1, -1], [2, 3, 4]], [[4, 5, 6], [-1, -1, -1]]], result) + result = self.evaluate(output) + self.assertAllEqual( + [[[-1, -1, -1], [2, 3, 4]], [[4, 5, 6], [-1, -1, -1]]], result) def 
testMultipleMutableHashTables(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) + default_val = -1 + keys = constant_op.constant(["brain", "salad", "surgery"]) + values = constant_op.constant([0, 1, 2], dtypes.int64) - table1 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - table2 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - table3 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) - self.evaluate(table1.insert(keys, values)) - self.evaluate(table2.insert(keys, values)) - self.evaluate(table3.insert(keys, values)) + table1 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) + table2 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) + table3 = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) + self.evaluate(table1.insert(keys, values)) + self.evaluate(table2.insert(keys, values)) + self.evaluate(table3.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table1.size())) - self.assertAllEqual(3, self.evaluate(table2.size())) - self.assertAllEqual(3, self.evaluate(table3.size())) + self.assertAllEqual(3, self.evaluate(table1.size())) + self.assertAllEqual(3, self.evaluate(table2.size())) + self.assertAllEqual(3, self.evaluate(table3.size())) - input_string = constant_op.constant(["brain", "salad", "tank"]) - output1 = table1.lookup(input_string) - output2 = table2.lookup(input_string) - output3 = table3.lookup(input_string) + input_string = constant_op.constant(["brain", "salad", "tank"]) + output1 = table1.lookup(input_string) + output2 = table2.lookup(input_string) + output3 = table3.lookup(input_string) - out1, out2, out3 = self.evaluate([output1, output2, output3]) - self.assertAllEqual([0, 1, -1], out1) - self.assertAllEqual([0, 1, -1], out2) - self.assertAllEqual([0, 1, -1], out3) + out1, out2, out3 = self.evaluate([output1, output2, output3]) + self.assertAllEqual([0, 1, -1], out1) + self.assertAllEqual([0, 1, -1], out2) + self.assertAllEqual([0, 1, -1], out3) def testMutableHashTableWithTensorDefault(self): - with self.cached_session(): - default_val = constant_op.constant(-1, dtypes.int64) - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + default_val = constant_op.constant(-1, dtypes.int64) + keys = constant_op.constant(["brain", "salad", "surgery"]) + values = constant_op.constant([0, 1, 2], dtypes.int64) + table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table.size())) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) + input_string = constant_op.constant(["brain", "salad", "tank"]) + output = table.lookup(input_string) - result = self.evaluate(output) - self.assertAllEqual([0, 1, -1], result) + result = self.evaluate(output) + self.assertAllEqual([0, 1, -1], result) def testSignatureMismatch(self): - with self.cached_session(): - default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = 
lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + default_val = -1 + keys = constant_op.constant(["brain", "salad", "surgery"]) + values = constant_op.constant([0, 1, 2], dtypes.int64) + table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) - # insert with keys of the wrong type - with self.assertRaises(ValueError): - self.evaluate(table.insert(constant_op.constant([4, 5, 6]), values)) + # insert with keys of the wrong type + with self.assertRaises(ValueError): + self.evaluate(table.insert(constant_op.constant([4, 5, 6]), values)) - # insert with values of the wrong type - with self.assertRaises(ValueError): - self.evaluate(table.insert(keys, constant_op.constant(["a", "b", "c"]))) + # insert with values of the wrong type + with self.assertRaises(ValueError): + self.evaluate(table.insert(keys, constant_op.constant(["a", "b", "c"]))) - self.assertAllEqual(0, self.evaluate(table.size())) + self.assertAllEqual(0, self.evaluate(table.size())) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table.size())) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) - input_string_ref = variables.Variable("brain") - input_int64_ref = variables.Variable(-1, dtype=dtypes.int64) - self.evaluate(variables.global_variables_initializer()) + input_string_ref = variables.Variable("brain") + input_int64_ref = variables.Variable(-1, dtype=dtypes.int64) + self.evaluate(variables.global_variables_initializer()) - # Ref types do not produce an insert signature mismatch. - self.evaluate(table.insert(input_string_ref, input_int64_ref)) - self.assertAllEqual(3, self.evaluate(table.size())) + # Ref types do not produce an insert signature mismatch. + self.evaluate(table.insert(input_string_ref, input_int64_ref)) + self.assertAllEqual(3, self.evaluate(table.size())) - # Ref types do not produce a lookup signature mismatch. - self.assertEqual(-1, self.evaluate(table.lookup(input_string_ref))) + # Ref types do not produce a lookup signature mismatch. 
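# Aside (illustration only, not part of this change): the MutableHashTable
# cases above now run without a cached_session wrapper. Below is a minimal
# eager-mode sketch of the API they exercise, reusing the internal modules the
# test file already imports; the keys and values are illustrative assumptions.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import lookup_ops

table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, default_value=-1)
table.insert(constant_op.constant(["brain", "salad"]),
             constant_op.constant([0, 1], dtypes.int64))
print(table.lookup(constant_op.constant(["salad", "tank"])).numpy())  # [ 1 -1]
# Inserting keys or values of a mismatched dtype is rejected with ValueError,
# which is what testSignatureMismatch asserts above.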
+ self.assertEqual(-1, self.evaluate(table.lookup(input_string_ref))) - # lookup with keys of the wrong type - input_string = constant_op.constant([1, 2, 3], dtypes.int64) - with self.assertRaises(ValueError): - self.evaluate(table.lookup(input_string)) + # lookup with keys of the wrong type + input_string = constant_op.constant([1, 2, 3], dtypes.int64) + with self.assertRaises(ValueError): + self.evaluate(table.lookup(input_string)) - # default value of the wrong type - with self.assertRaises(TypeError): - lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, "UNK") + # default value of the wrong type + with self.assertRaises(TypeError): + lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, "UNK") def testMutableHashTableStringFloat(self): - with self.cached_session(): - default_val = -1.5 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1.1, 2.2], dtypes.float32) - table = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32, - default_val) - self.assertAllEqual(0, self.evaluate(table.size())) + default_val = -1.5 + keys = constant_op.constant(["brain", "salad", "surgery"]) + values = constant_op.constant([0, 1.1, 2.2], dtypes.float32) + table = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32, + default_val) + self.assertAllEqual(0, self.evaluate(table.size())) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table.size())) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) - input_string = constant_op.constant(["brain", "salad", "tank"]) - output = table.lookup(input_string) + input_string = constant_op.constant(["brain", "salad", "tank"]) + output = table.lookup(input_string) - result = self.evaluate(output) - self.assertAllClose([0, 1.1, default_val], result) + result = self.evaluate(output) + self.assertAllClose([0, 1.1, default_val], result) def testMutableHashTableIntFloat(self): - with self.cached_session(): - default_val = -1.0 - keys = constant_op.constant([3, 7, 0], dtypes.int64) - values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32) - table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.float32, - default_val) - self.assertAllEqual(0, self.evaluate(table.size())) + default_val = -1.0 + keys = constant_op.constant([3, 7, 0], dtypes.int64) + values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32) + table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.float32, + default_val) + self.assertAllEqual(0, self.evaluate(table.size())) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table.size())) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) - input_string = constant_op.constant([7, 0, 11], dtypes.int64) - output = table.lookup(input_string) + input_string = constant_op.constant([7, 0, 11], dtypes.int64) + output = table.lookup(input_string) - result = self.evaluate(output) - self.assertAllClose([-1.2, 9.9, default_val], result) + result = self.evaluate(output) + self.assertAllClose([-1.2, 9.9, default_val], result) def testMutableHashTableInt64String(self): - with self.cached_session(): - default_val = "n/a" - keys = constant_op.constant([0, 1, 2], dtypes.int64) - values = constant_op.constant(["brain", "salad", "surgery"]) - table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.string, - default_val) - self.assertAllEqual(0, self.evaluate(table.size())) + default_val = "n/a" + keys = constant_op.constant([0, 1, 2], 
dtypes.int64) + values = constant_op.constant(["brain", "salad", "surgery"]) + table = lookup_ops.MutableHashTable(dtypes.int64, dtypes.string, + default_val) + self.assertAllEqual(0, self.evaluate(table.size())) - self.evaluate(table.insert(keys, values)) - self.assertAllEqual(3, self.evaluate(table.size())) + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) - input_string = constant_op.constant([0, 1, 3], dtypes.int64) - output = table.lookup(input_string) + input_string = constant_op.constant([0, 1, 3], dtypes.int64) + output = table.lookup(input_string) - result = self.evaluate(output) - self.assertAllEqual((b"brain", b"salad", b"n/a"), result) + result = self.evaluate(output) + self.assertAllEqual((b"brain", b"salad", b"n/a"), result) + + def testExportShapeInference(self): + default_value = -1 + table = lookup_ops.MutableHashTable( + dtypes.int64, + dtypes.int64, + default_value=default_value) + actual_shapes = [t.shape for t in table.export()] + inferred_shapes = [] + + @def_function.function + def f(): + for t in table.export(): + inferred_shapes.append(t.shape) + + f() + self.assertLen(actual_shapes, 2) + self.assertLen(inferred_shapes, 2) + self.assertTrue(inferred_shapes[0].is_compatible_with(actual_shapes[0])) + self.assertTrue(inferred_shapes[1].is_compatible_with(actual_shapes[1])) class MutableHashTableBenchmark(test.Benchmark): diff --git a/tensorflow/python/kernel_tests/proto/BUILD b/tensorflow/python/kernel_tests/proto/BUILD index cc6147a8fd2..e5c46c76e2e 100644 --- a/tensorflow/python/kernel_tests/proto/BUILD +++ b/tensorflow/python/kernel_tests/proto/BUILD @@ -28,7 +28,6 @@ tf_py_test( "no_pip", # TODO(b/78026780) "no_windows", # TODO(b/78028010) ], - tfrt_enabled = True, deps = [ ":decode_proto_op_test_base", ":py_test_deps", @@ -49,7 +48,6 @@ tf_py_test( "no_pip", # TODO(b/78026780) "no_windows", # TODO(b/78028010) ], - tfrt_enabled = True, deps = [ ":encode_proto_op_test_base", ":py_test_deps", diff --git a/tensorflow/python/kernel_tests/random/BUILD b/tensorflow/python/kernel_tests/random/BUILD index 6add4128371..35129a59b84 100644 --- a/tensorflow/python/kernel_tests/random/BUILD +++ b/tensorflow/python/kernel_tests/random/BUILD @@ -143,7 +143,6 @@ cuda_py_test( srcs = ["random_gamma_test.py"], shard_count = 4, tags = ["nozapfhahn"], - tfrt_enabled = True, deps = [ ":util", "//tensorflow/python:array_ops", diff --git a/tensorflow/python/kernel_tests/random/random_poisson_test.py b/tensorflow/python/kernel_tests/random/random_poisson_test.py index 51dd4cb47ca..eafa1d9382c 100644 --- a/tensorflow/python/kernel_tests/random/random_poisson_test.py +++ b/tensorflow/python/kernel_tests/random/random_poisson_test.py @@ -171,6 +171,11 @@ class RandomPoissonTest(test.TestCase): constant_op.constant([1], dtype=lam_dt), [10], dtype=out_dt).eval() + @test_util.run_deprecated_v1 + def testInfRate(self): + sample = random_ops.random_poisson(shape=[2], lam=np.inf) + self.assertAllEqual([np.inf, np.inf], self.evaluate(sample)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 839235233ff..44279a98a39 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -57,8 +57,6 @@ from tensorflow.python.training import training_util from tensorflow.python.util import compat -@test_util.disable_tfrt( - "Trying to assign variable 
with wrong dtype. b/156200342") @test_util.with_control_flow_v2 class ResourceVariableOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): @@ -105,6 +103,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, v0 = resource_variable_ops.ResourceVariable(1.0) self.assertAllEqual(v0.numpy(), 1.0) + @test_util.disable_tfrt("b/169375363: error code support") def testReadVariableDtypeMismatchEager(self): with context.eager_mode(): handle = resource_variable_ops.var_handle_op( @@ -200,6 +199,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, value, _ = sess.run([v, v.assign_add(1.0)]) self.assertAllEqual(value, 0.0) + @test_util.disable_tfrt("b/169375363: error code support") def testAssignVariableDtypeMismatchEager(self): with context.eager_mode(): handle = resource_variable_ops.var_handle_op( @@ -336,7 +336,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, g = gradients_impl.gradients(c, [b], unconnected_gradients="zero")[0] self.assertAllEqual(g.shape.as_list(), [1, 2]) - @test_util.disable_tfrt("Graph is not supported yet. b/156187905") @test_util.run_deprecated_v1 def testGradientCondInWhileLoop(self): v = resource_variable_ops.ResourceVariable(initial_value=1.0) @@ -751,6 +750,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, self.assertEqual(v.handle.op.colocation_groups(), v.initializer.inputs[1].op.colocation_groups()) + @test_util.disable_tfrt("b/169375363: error code support") def testCountUpTo(self): with context.eager_mode(): v = resource_variable_ops.ResourceVariable(0, name="upto") @@ -758,6 +758,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, with self.assertRaises(errors.OutOfRangeError): v.count_up_to(1) + @test_util.disable_tfrt("b/169375363: error code support") def testCountUpToFunction(self): with context.eager_mode(): v = resource_variable_ops.ResourceVariable(0, name="upto") @@ -856,6 +857,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, variable_def=other_v_def) self.assertIsNotNone(other_v_prime._cached_value) + @test_util.disable_tfrt("b/169375363: error code support") def testVariableDefInitializedInstances(self): with ops.Graph().as_default(), self.cached_session(): v_def = resource_variable_ops.ResourceVariable( @@ -977,6 +979,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, self.evaluate(assign_without_read) self.assertEqual(0.0, self.evaluate(v.value())) + @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_in_graph_and_eager_modes @test_util.run_v1_only("b/120545219") def testDestroyResource(self): @@ -1003,6 +1006,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, [assign], feed_dict={placeholder: np.zeros(shape=[2, 2], dtype=np.float32)}) + @test_util.disable_tfrt("b/169375363: error code support") def testAssignDifferentShapesEagerNotAllowed(self): with context.eager_mode(): with variable_scope.variable_scope("foo"): @@ -1013,7 +1017,6 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, assign = var.assign(np.zeros(shape=[2, 2])) self.evaluate(assign) - @test_util.disable_tfrt("Graph is not supported yet. 
b/156187905") @test_util.disable_xla("XLA doesn't allow changing shape at assignment, as " "dictated by tf2xla/xla_resource.cc:SetTypeAndShape") @test_util.run_in_graph_and_eager_modes @@ -1066,6 +1069,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, .batch_scatter_update(batch_slices2), [[1, 3], [2, 3]]) + @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_in_graph_and_eager_modes def testInitValueWrongShape(self): with self.assertRaisesWithPredicateMatch( @@ -1084,6 +1088,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, self.assertEqual(v.dtype, w.dtype) # TODO(alive): get caching to work in eager mode. + @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_deprecated_v1 def testCachingDevice(self): with ops.device("/job:server/task:1"): @@ -1100,6 +1105,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, with self.assertRaises(ValueError): _ = w.value().op.get_attr("_class") + @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_deprecated_v1 def testSharedName(self): with self.cached_session(): @@ -1158,6 +1164,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, v.initializer.run(feed_dict={v.initial_value: 3.0}) self.assertEqual(3.0, v.value().eval()) + @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_v1_only("b/120545219") def testControlFlowInitialization(self): """Expects an error if an initializer is in a control-flow scope.""" @@ -1245,6 +1252,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, self.assertEqual(1, v1.read_value().numpy()) self.assertEqual(2, v2.read_value().numpy()) + @test_util.disable_tfrt("b/169375363: error code support") def testDestruction(self): with context.eager_mode(): var = resource_variable_ops.ResourceVariable(initial_value=1.0, @@ -1332,6 +1340,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, state_ops.scatter_update(v, [1], [3]) self.assertAllEqual([1.0, 3.0], v.numpy()) + @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_in_graph_and_eager_modes def testScatterUpdateInvalidArgs(self): v = resource_variable_ops.ResourceVariable([0, 1, 2, 3], name="update") @@ -1341,6 +1350,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, with self.assertRaisesRegex(Exception, r"shape.*2.*3"): state_ops.scatter_update(v, [0, 1], [0, 1, 2]) + @test_util.disable_tfrt("b/169375363: error code support") @test_util.run_in_graph_and_eager_modes def testAssignIncompatibleShape(self): v = resource_variable_ops.ResourceVariable([0, 1, 2, 3]) @@ -1387,7 +1397,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, # TODO(ebrevdo): Add run_in_graph_and_eager_modes once we can create # EagerTensor constants with TensorProto inputs. - @test_util.disable_tfrt("Graph is not supported yet. 
b/156187905") + @test_util.disable_tfrt("Does not support tf.Const in lowering.") @test_util.run_in_graph_and_eager_modes() def testVariantInitializer(self): variant_shape_and_type_data = self.create_variant_shape_and_type_data() diff --git a/tensorflow/python/kernel_tests/spacetodepth_op_test.py b/tensorflow/python/kernel_tests/spacetodepth_op_test.py index 762a644b065..6b229ea80f7 100644 --- a/tensorflow/python/kernel_tests/spacetodepth_op_test.py +++ b/tensorflow/python/kernel_tests/spacetodepth_op_test.py @@ -309,6 +309,7 @@ class SpaceToDepthTest(test.TestCase): actual_vals, expected_vals = self.evaluate([actual, expected]) self.assertTrue(np.array_equal(actual_vals, expected_vals)) + @test_util.disable_tfrt("b/169901260") def testAgainstTranspose(self): self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", dtypes.float32, False) self.compareToTranspose(1, 2, 3, 2, 2, "NHWC", dtypes.float32, False) diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py index 79f1c488f35..8ec1756c154 100644 --- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py @@ -97,7 +97,6 @@ class SparseTensorDenseMatMulTest(test.TestCase): self._testMatmul(x, y, indices_dtype=indices_dtype) - @test_util.run_deprecated_v1 def testBasic(self): np.random.seed(127) # Repeatable results self._testBasic(np.int32) @@ -108,7 +107,6 @@ class SparseTensorDenseMatMulTest(test.TestCase): self._testBasic(np.int32, indices_dtype=np.int32) self._testBasic(np.float32, indices_dtype=np.int32) - @test_util.run_deprecated_v1 def testShapeInference(self): x = np.random.rand(10, 10) x[np.abs(x) < 0.5] = 0 # Make it sparse @@ -116,97 +114,91 @@ class SparseTensorDenseMatMulTest(test.TestCase): x_indices = np.vstack(np.where(x)).astype(np.int64).T x_values = x[np.where(x)] x_shape = x.shape - x_st = sparse_tensor.SparseTensor(x_indices, x_values, x_shape) - result = sparse_ops.sparse_tensor_dense_matmul(x_st, y) - self.assertEqual(result.get_shape(), (10, 20)) - x_shape_unknown = array_ops.placeholder(dtype=dtypes.int64, shape=None) - x_st_shape_unknown = sparse_tensor.SparseTensor(x_indices, x_values, - x_shape_unknown) - result_left_shape_unknown = sparse_ops.sparse_tensor_dense_matmul( - x_st_shape_unknown, y) - self.assertEqual(result_left_shape_unknown.get_shape().as_list(), - [None, 20]) + with ops.Graph().as_default(): + x_st = sparse_tensor.SparseTensor(x_indices, x_values, x_shape) + result = sparse_ops.sparse_tensor_dense_matmul(x_st, y) + self.assertEqual(result.get_shape(), (10, 20)) - x_shape_inconsistent = [10, 15] - x_st_shape_inconsistent = sparse_tensor.SparseTensor(x_indices, x_values, - x_shape_inconsistent) - with self.assertRaisesRegex(ValueError, "Dimensions must be equal"): - sparse_ops.sparse_tensor_dense_matmul(x_st_shape_inconsistent, y) + x_shape_unknown = array_ops.placeholder(dtype=dtypes.int64, shape=None) + x_st_shape_unknown = sparse_tensor.SparseTensor(x_indices, x_values, + x_shape_unknown) + result_left_shape_unknown = sparse_ops.sparse_tensor_dense_matmul( + x_st_shape_unknown, y) + self.assertEqual(result_left_shape_unknown.get_shape().as_list(), + [None, 20]) - @test_util.deprecated_graph_mode_only + x_shape_inconsistent = [10, 15] + x_st_shape_inconsistent = sparse_tensor.SparseTensor( + x_indices, x_values, x_shape_inconsistent) + with self.assertRaisesRegex(ValueError, "Dimensions must be equal"): + 
sparse_ops.sparse_tensor_dense_matmul(x_st_shape_inconsistent, y) + + @test_util.run_in_graph_and_eager_modes(use_gpu=False) def testInvalidIndicesForSparseTensorDenseMatmul(self): - # Note: use_gpu=False because nice errors are only returned from CPU kernel. - with self.session(use_gpu=False): - indices = np.matrix([[1, 10]]).astype(np.int64) - values = np.array([10]).astype(np.float32) - shape = [3, 2] - sparse_t = sparse_tensor.SparseTensor(indices, values, shape) + # TODO(b/169813429): Make GPU kernel return nice errors too. + indices = np.matrix([[1, 10]]).astype(np.int64) + values = np.array([10]).astype(np.float32) + shape = [3, 2] + sparse_t = sparse_tensor.SparseTensor(indices, values, shape) - # Test multiplying by both a small and large dense matrix, to hit - # both cases in the kernel. - dense_t = np.matrix([[1] * 5, [2] * 5], dtype=np.float32) - with self.assertRaisesOpError( - "k .10. from index.0,1. out of bounds .>=2."): - self.evaluate(sparse_ops.sparse_tensor_dense_matmul(sparse_t, dense_t)) - dense_t = np.matrix([[1] * 500, [2] * 500], dtype=np.float32) - with self.assertRaisesOpError( - "k .10. from index.0,1. out of bounds .>=2."): - self.evaluate(sparse_ops.sparse_tensor_dense_matmul(sparse_t, dense_t)) + # Test multiplying by both a small and large dense matrix, to hit + # both cases in the kernel. + dense_t = np.matrix([[1] * 5, [2] * 5], dtype=np.float32) + with self.assertRaisesOpError("k .10. from index.0,1. out of bounds .>=2."): + self.evaluate(sparse_ops.sparse_tensor_dense_matmul(sparse_t, dense_t)) + dense_t = np.matrix([[1] * 500, [2] * 500], dtype=np.float32) + with self.assertRaisesOpError("k .10. from index.0,1. out of bounds .>=2."): + self.evaluate(sparse_ops.sparse_tensor_dense_matmul(sparse_t, dense_t)) - # Repeat with adjoint_a, to get a different error. - dense_t = np.matrix([[1] * 5, [2] * 5, [3] * 5], dtype=np.float32) - with self.assertRaisesOpError( - "m .10. from index.0,1. out of bounds .>=2."): - self.evaluate( - sparse_ops.sparse_tensor_dense_matmul( - sparse_t, dense_t, adjoint_a=True)) - dense_t = np.matrix([[1] * 500, [2] * 500, [3] * 500], dtype=np.float32) - with self.assertRaisesOpError( - "m .10. from index.0,1. out of bounds .>=2."): - self.evaluate( - sparse_ops.sparse_tensor_dense_matmul( - sparse_t, dense_t, adjoint_a=True)) + # Repeat with adjoint_a, to get a different error. + dense_t = np.matrix([[1] * 5, [2] * 5, [3] * 5], dtype=np.float32) + with self.assertRaisesOpError("m .10. from index.0,1. out of bounds .>=2."): + self.evaluate( + sparse_ops.sparse_tensor_dense_matmul( + sparse_t, dense_t, adjoint_a=True)) + dense_t = np.matrix([[1] * 500, [2] * 500, [3] * 500], dtype=np.float32) + with self.assertRaisesOpError("m .10. from index.0,1. 
out of bounds .>=2."): + self.evaluate( + sparse_ops.sparse_tensor_dense_matmul( + sparse_t, dense_t, adjoint_a=True)) + @test_util.run_gpu_only def testInvalidIndicesForSparseTensorDenseMatmulOnGPU(self): - # Note: use_gpu=False because nice errors are only returned from CPU kerne - if not test.is_gpu_available(): - return - with self.session(use_gpu=True): - indices = np.array([[1, 10]]).astype(np.int64) - values = np.array([10]).astype(np.float32) - shape = [3, 2] - sparse_t = sparse_tensor.SparseTensor(indices, values, shape) + indices = np.array([[1, 10]]).astype(np.int64) + values = np.array([10]).astype(np.float32) + shape = [3, 2] + sparse_t = sparse_tensor.SparseTensor(indices, values, shape) - # Test multiplying by both a small and large dense matrix, to hit - # both cases in the kernel. - dense_t = np.matrix([[1] * 5, [2] * 5], dtype=np.float32) - expected_t = np.array([[0] * 5, [np.nan] * 5, [0] * 5], dtype=np.float32) - self.assertAllClose(expected_t, - sparse_ops.sparse_tensor_dense_matmul( - sparse_t, dense_t)) - dense_t = np.matrix([[1] * 500, [2] * 500], dtype=np.float32) - expected_t = np.array( - [[0] * 500, [np.nan] * 500, [0] * 500], dtype=np.float32) - self.assertAllClose(expected_t, - sparse_ops.sparse_tensor_dense_matmul( - sparse_t, dense_t)) + # Test multiplying by both a small and large dense matrix, to hit + # both cases in the kernel. + dense_t = np.matrix([[1] * 5, [2] * 5], dtype=np.float32) + expected_t = np.array([[0] * 5, [np.nan] * 5, [0] * 5], dtype=np.float32) + self.assertAllClose( + expected_t, sparse_ops.sparse_tensor_dense_matmul(sparse_t, dense_t)) + dense_t = np.matrix([[1] * 500, [2] * 500], dtype=np.float32) + expected_t = np.array([[0] * 500, [np.nan] * 500, [0] * 500], + dtype=np.float32) + self.assertAllClose( + expected_t, sparse_ops.sparse_tensor_dense_matmul(sparse_t, dense_t)) - # Repeat with adjoint_a, now the error is that the sparse index - # is OOO w.r.t. the output. The GPU kernel can't do much here, - # so it just doesn't accumulate. + # Repeat with adjoint_a, now the error is that the sparse index + # is OOO w.r.t. the output. The GPU kernel can't do much here, + # so it just doesn't accumulate. - dense_t = np.matrix([[1] * 5, [2] * 5, [3] * 5], dtype=np.float32) - expected_t = np.array([[0] * 5, [0] * 5], dtype=np.float32) - self.assertAllClose(expected_t, - sparse_ops.sparse_tensor_dense_matmul( - sparse_t, dense_t, adjoint_a=True)) + dense_t = np.matrix([[1] * 5, [2] * 5, [3] * 5], dtype=np.float32) + expected_t = np.array([[0] * 5, [0] * 5], dtype=np.float32) + self.assertAllClose( + expected_t, + sparse_ops.sparse_tensor_dense_matmul( + sparse_t, dense_t, adjoint_a=True)) - dense_t = np.matrix([[1] * 500, [2] * 500, [3] * 500], dtype=np.float32) - expected_t = np.array([[0] * 500, [0] * 500], dtype=np.float32) - self.assertAllClose(expected_t, - sparse_ops.sparse_tensor_dense_matmul( - sparse_t, dense_t, adjoint_a=True)) + dense_t = np.matrix([[1] * 500, [2] * 500, [3] * 500], dtype=np.float32) + expected_t = np.array([[0] * 500, [0] * 500], dtype=np.float32) + self.assertAllClose( + expected_t, + sparse_ops.sparse_tensor_dense_matmul( + sparse_t, dense_t, adjoint_a=True)) # Tests setting one dimension to be a high value. def _testLarge(self, np_dtype): @@ -235,7 +227,6 @@ class SparseTensorDenseMatMulTest(test.TestCase): self._testLarge(np.complex128) # Tests random sized matrices. 
- @test_util.run_deprecated_v1 def testFloatRandom(self): np.random.seed(127) # Repeatable results for _ in range(8): diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py index cf1337b493d..c53f196ecb9 100644 --- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py @@ -25,6 +25,8 @@ import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session +from tensorflow.python.eager import backprop as backprop_lib +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -32,7 +34,7 @@ from tensorflow.python.framework import ops as ops_lib from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops -from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import gradient_checker_v2 from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -79,33 +81,34 @@ class SparseXentTest(test.TestCase): self.assertAllClose([0.0, 0.0, 0.0], tf_loss) self.assertAllClose([[0.0], [0.0], [0.0]], tf_backprop) - @test_util.run_deprecated_v1 - @test_util.disable_xla("XLA cannot assert inside of a kernel.") - def testInvalidLabel(self): + @test_util.run_gpu_only() + def testInvalidLabelGPU(self): features = [[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 2., 3., 4.], [1., 2., 3., 4.]] labels = [4, 3, 0, -1] + loss, backprop = self.evaluate( + gen_nn_ops.sparse_softmax_cross_entropy_with_logits(features, labels)) + self.assertAllClose([[np.nan] * 4, [0.25, 0.25, 0.25, -0.75], + [-0.968, 0.087, 0.237, 0.6439], [np.nan] * 4], + backprop, + rtol=1e-3, + atol=1e-3) + self.assertAllClose([np.nan, 1.3862, 3.4420, np.nan], + loss, + rtol=1e-3, + atol=1e-3) - if test.is_built_with_gpu_support() and test.is_gpu_available(): - with self.session(use_gpu=True) as sess: - loss, backprop = ( - gen_nn_ops.sparse_softmax_cross_entropy_with_logits( - features, labels)) - tf_loss, tf_backprop = self.evaluate([loss, backprop]) - self.assertAllClose( - [[np.nan] * 4, [0.25, 0.25, 0.25, -0.75], - [-0.968, 0.087, 0.237, 0.6439], [np.nan] * 4], - tf_backprop, - rtol=1e-3, - atol=1e-3) - self.assertAllClose( - [np.nan, 1.3862, 3.4420, np.nan], tf_loss, rtol=1e-3, atol=1e-3) - - with self.session(use_gpu=False) as sess: - loss, backprop = ( + @test_util.run_in_graph_and_eager_modes(use_gpu=False) + @test_util.disable_xla("XLA cannot assert inside of a kernel.") + def testInvalidLabelCPU(self): + features = [[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 2., 3., 4.], + [1., 2., 3., 4.]] + labels = [4, 3, 0, -1] + with self.assertRaisesRegex( + (errors_impl.InvalidArgumentError, errors_impl.UnknownError), + "Received a label value of"): + self.evaluate( gen_nn_ops.sparse_softmax_cross_entropy_with_logits(features, labels)) - with self.assertRaisesOpError("Received a label value of"): - self.evaluate([loss, backprop]) def testNpXent(self): # We create 2 batches of logits for testing. 
@@ -153,9 +156,8 @@ class SparseXentTest(test.TestCase): nn_ops.sparse_softmax_cross_entropy_with_logits( labels=constant_op.constant(0), logits=constant_op.constant(1.0)) - @test_util.run_deprecated_v1 def testLabelsPlaceholderScalar(self): - with self.session(use_gpu=True): + with ops_lib.Graph().as_default(), self.session(use_gpu=True): labels = array_ops.placeholder(np.int32) y = nn_ops.sparse_softmax_cross_entropy_with_logits( labels=labels, logits=[[7.]]) @@ -189,7 +191,7 @@ class SparseXentTest(test.TestCase): def testEmpty(self): self._testXent(np.zeros((0, 3)), np.zeros((0,), dtype=np.int32)) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testGradient(self): with self.session(use_gpu=True) as sess: l = constant_op.constant([3, 0, 1], name="l") @@ -198,22 +200,28 @@ class SparseXentTest(test.TestCase): shape=[3, 4], dtype=dtypes.float64, name="f") - x = nn_ops.sparse_softmax_cross_entropy_with_logits( - labels=l, logits=f, name="xent") - err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3]) - # Check that no extra computation performed. When only first derivative is - # requested, second derivative must not be computed. So when there is no - # second derivative, there is no `BatchMatMul` op in the graph. - op_names = [ - op.op_def.name for op in sess.graph.get_operations() if op.op_def - ] - self.assertNotIn("BatchMatMul", op_names) - self.assertNotIn("BatchMatMulV2", op_names) + def xent(f): + # gradient_checker_v2.computee_gradient doesn't take int32/int64. + # labels must be of type int32/int64, so passing them separately here. + return nn_ops.sparse_softmax_cross_entropy_with_logits( + labels=l, logits=f, name="xent") - self.assertLess(err, 5e-8) + theoretical, numerical = gradient_checker_v2.compute_gradient(xent, [f]) + + if not context.executing_eagerly(): + # Check that no extra computation performed. When only first derivative + # is requested, second derivative must not be computed. So when there is + # no second derivative, there is no `BatchMatMul` op in the graph. + op_names = [ + op.op_def.name for op in sess.graph.get_operations() if op.op_def + ] + self.assertNotIn("BatchMatMul", op_names) + self.assertNotIn("BatchMatMulV2", op_names) + + tol = 5e-8 + self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol) - @test_util.run_deprecated_v1 def testSecondGradient(self): with self.session() as sess: l = constant_op.constant([3, 0, 1], name="l") @@ -222,51 +230,67 @@ class SparseXentTest(test.TestCase): shape=[3, 4], dtype=dtypes.float64, name="f") - x = nn_ops.sparse_softmax_cross_entropy_with_logits( - labels=l, logits=f, name="xent") - gradients = gradients_impl.gradients(x, [f])[0] - err = gradient_checker.compute_gradient_error(f, [3, 4], gradients, - [3, 4]) + def xent_grad(f): + if not context.executing_eagerly(): + return gradients_impl.gradients( + nn_ops.sparse_softmax_cross_entropy_with_logits( + labels=l, logits=f, name="xent"), [f])[0] + with backprop_lib.GradientTape() as tape: + tape.watch(f) + return tape.gradient( + nn_ops.sparse_softmax_cross_entropy_with_logits( + labels=l, logits=f, name="xent"), [f])[0] - # Check that second derivative is calculated. 
- # (it is equivalent to being `BatchMatMul` op in the graph because of - # implementation of xentropy grad) - op_names = [ - op.op_def.name for op in sess.graph.get_operations() if op.op_def - ] - self.assertIn("BatchMatMulV2", op_names) + theoretical, numerical = gradient_checker_v2.compute_gradient( + xent_grad, [f]) - self.assertLess(err, 5e-8) + if not context.executing_eagerly(): + # Check that second derivative is calculated. + # (it is equivalent to being `BatchMatMul` op in the graph because of + # implementation of xentropy grad) + op_names = [ + op.op_def.name for op in sess.graph.get_operations() if op.op_def + ] + self.assertIn("BatchMatMulV2", op_names) + tol = 5e-8 + self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def _testHighDim(self, features, labels): np_loss, np_backprop = self._npXent(np.array(features), np.array(labels)) # manually reshape loss np_loss = np.reshape(np_loss, np.array(labels).shape) - with self.cached_session(use_gpu=True) as sess: - loss = nn_ops.sparse_softmax_cross_entropy_with_logits( - labels=labels, logits=features) - backprop = loss.op.inputs[0].op.outputs[1] - tf_loss, tf_backprop = self.evaluate([loss, backprop]) + tf_loss = nn_ops.sparse_softmax_cross_entropy_with_logits( + labels=labels, logits=features) + if not context.executing_eagerly(): + tf_backprop = tf_loss.op.inputs[0].op.outputs[1] + else: + with backprop_lib.GradientTape() as tape: + features = constant_op.constant(features) + tape.watch(features) + tf_backprop = tape.gradient( + nn_ops.sparse_softmax_cross_entropy_with_logits( + labels=labels, logits=features), [features])[0] + tf_backprop = array_ops.reshape(tf_backprop, np_backprop.shape) + self.assertAllCloseAccordingToType(np_loss, tf_loss) self.assertAllCloseAccordingToType(np_backprop, tf_backprop) - @test_util.run_deprecated_v1 def testHighDim(self): features = [[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]] labels = [[3], [0]] self._testHighDim(features, labels) - @test_util.run_deprecated_v1 def testHighDim2(self): features = [[[1., 1., 1., 1.], [2., 2., 2., 2.]], [[1., 2., 3., 4.], [5., 6., 7., 8.]]] labels = [[3, 2], [0, 3]] self._testHighDim(features, labels) - @test_util.run_deprecated_v1 def testScalarHandling(self): - with self.session(use_gpu=False) as sess: + with ops_lib.Graph().as_default(), self.session(use_gpu=False) as sess: with self.assertRaisesRegex(errors_impl.InvalidArgumentError, ".*labels must be 1-D.*"): labels = array_ops.placeholder(dtypes.int32, shape=[None, 1]) diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 0d3bbb5144d..e04bf9798eb 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -173,7 +173,6 @@ class VariablesTestCase(test.TestCase, parameterized.TestCase): with self.assertRaisesRegex(ValueError, "shape.*and.*are incompatible"): var.assign(np.zeros(shape=[2, 2])) - @test_util.disable_tfrt("Graph is not supported yet. b/156187905") @test_util.run_in_graph_and_eager_modes def testAssignDifferentShapesAllowed(self): var = variables.Variable(np.zeros(shape=[1, 1]), @@ -183,8 +182,6 @@ class VariablesTestCase(test.TestCase, parameterized.TestCase): self.evaluate(var.assign(np.zeros(shape=[2, 2]))) self.assertAllEqual(np.zeros(shape=[2, 2]), var.read_value()) - @test_util.disable_tfrt("GetHostSize() is not expected to be called with " - "string type. 
b/156761465") def testZeroSizeStringAssign(self): with self.cached_session() as sess: array = variables.VariableV1( diff --git a/tensorflow/python/lib/io/file_io.py b/tensorflow/python/lib/io/file_io.py index fb4e19da902..a4c9b613068 100644 --- a/tensorflow/python/lib/io/file_io.py +++ b/tensorflow/python/lib/io/file_io.py @@ -347,6 +347,7 @@ def get_matching_files(filename): Raises: * errors.OpError: If there are filesystem / directory listing errors. + * errors.NotFoundError: If pattern to be matched is an invalid directory. """ return get_matching_files_v2(filename) @@ -401,6 +402,7 @@ def get_matching_files_v2(pattern): Raises: errors.OpError: If there are filesystem / directory listing errors. + errors.NotFoundError: If pattern to be matched is an invalid directory. """ if isinstance(pattern, six.string_types): return [ diff --git a/tensorflow/python/lite/toco_python_api_wrapper.cc b/tensorflow/python/lite/toco_python_api_wrapper.cc index b77200a3bee..c5a5f63b2ac 100644 --- a/tensorflow/python/lite/toco_python_api_wrapper.cc +++ b/tensorflow/python/lite/toco_python_api_wrapper.cc @@ -77,4 +77,14 @@ PYBIND11_MODULE(_pywrap_toco_api, m) { R"pbdoc( Returns a sparsified model. )pbdoc"); + m.def( + "RegisterCustomOpdefs", + [](py::object custom_opdefs_txt_raw) { + return tensorflow::PyoOrThrow( + toco::RegisterCustomOpdefs(custom_opdefs_txt_raw.ptr())); + }, + py::arg("custom_opdefs_txt_raw"), + R"pbdoc( + Registers the given custom opdefs to the TensorFlow global op registry. + )pbdoc"); } diff --git a/tensorflow/python/modules_with_exports.py b/tensorflow/python/modules_with_exports.py index 75e67dbb1c6..8d03133d266 100644 --- a/tensorflow/python/modules_with_exports.py +++ b/tensorflow/python/modules_with_exports.py @@ -54,6 +54,10 @@ from tensorflow.python.ops import initializers_ns as initializers from tensorflow.python.util.tf_export import tf_export # _internal APIs +from tensorflow.python.distribute.combinations import generate +from tensorflow.python.distribute.multi_process_runner import * +from tensorflow.python.distribute.multi_worker_test_base import * +from tensorflow.python.distribute.strategy_combinations import * from tensorflow.python.framework.combinations import * from tensorflow.python.framework.composite_tensor import * from tensorflow.python.framework.test_combinations import * diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 7cd227d5b92..67a7c7a0e94 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -3758,6 +3758,23 @@ def _FakeQuantWithMinMaxVarsPerChannelGradient(op, grad): narrow_range=op.get_attr("narrow_range")) +@ops.RegisterGradient("QuantizeAndDequantizeV4") +def _QuantizeAndDequantizeV4Grad(op, grad): + """Gradient for QuantizeAndDequantizeV4 op.""" + return quantize_and_dequantize_v4_grad( + grad, + op.inputs[0], + op.inputs[1], + op.inputs[2], + axis=op.get_attr("axis")) + + +@ops.RegisterGradient("QuantizeAndDequantizeV4Grad") +def _QuantizeAndDequantizeV4GradGrad(op, grad): + """Gradient for QuantizeAndDequantizeV4Grad op.""" + return _QuantizeAndDequantizeV4Grad(op, grad) + + @tf_export("required_space_to_batch_paddings") def required_space_to_batch_paddings(input_shape, block_shape, @@ -5208,6 +5225,293 @@ def batch_gather_nd(params, indices, batch_dims, name=None): return out +@deprecation.deprecated_endpoints("tensor_scatter_update") +@tf_export( + "tensor_scatter_nd_update", + v1=["tensor_scatter_nd_update", "tensor_scatter_update"]) +@dispatch.add_dispatch_support 
+def tensor_scatter_nd_update(tensor, indices, updates, name=None):
+  """Scatter `updates` into an existing tensor according to `indices`.
+
+  This operation creates a new tensor by applying sparse `updates` to the
+  input `tensor`. This is similar to an index assignment.
+
+  ```
+  # Not implemented: tensors cannot be updated inplace.
+  tensor[indices] = updates
+  ```
+
+  If an out of bound index is found on CPU, an error is returned.
+
+  > **WARNING**: There are some GPU specific semantics for this operation.
+  >
+  > - If an out of bound index is found, the index is ignored.
+  > - The order in which updates are applied is nondeterministic, so the output
+  >   will be nondeterministic if `indices` contains duplicates.
+
+  This operation is very similar to `tf.scatter_nd`, except that the updates
+  are scattered onto an existing tensor (as opposed to a zero-tensor). If the
+  memory for the existing tensor cannot be re-used, a copy is made and updated.
+
+  In general:
+
+  * `indices` is an integer tensor - the indices to update in `tensor`.
+  * `indices` has **at least two** axes, the last axis is the depth of the
+    index vectors.
+  * For each index vector in `indices` there is a corresponding entry in
+    `updates`.
+  * If the length of the index vectors matches the rank of the `tensor`, then
+    the index vectors each point to scalars in `tensor` and each update is a
+    scalar.
+  * If the length of the index vectors is less than the rank of `tensor`, then
+    the index vectors each point to slices of `tensor` and the shape of the
+    updates must match that slice.
+
+  Overall this leads to the following shape constraints:
+
+  ```
+  assert tf.rank(indices) >= 2
+  index_depth = indices.shape[-1]
+  batch_shape = indices.shape[:-1]
+  assert index_depth <= tf.rank(tensor)
+  outer_shape = tensor.shape[:index_depth]
+  inner_shape = tensor.shape[index_depth:]
+  assert updates.shape == batch_shape + inner_shape
+  ```
+
+  Typical usage is often much simpler than this general form, and it
+  can be better understood starting with simple examples:
+
+  ### Scalar updates
+
+  The simplest usage inserts scalar elements into a tensor by index.
+  In this case, the `index_depth` must equal the rank of the
+  input `tensor`, since each column of `indices` is an index into an axis of
+  the input `tensor`.
+
+  In this simplest case the shape constraints are:
+
+  ```
+  num_updates, index_depth = indices.shape.as_list()
+  assert updates.shape == [num_updates]
+  assert index_depth == tf.rank(tensor)
+  ```
+
+  For example, to insert 4 scattered elements in a rank-1 tensor with
+  8 elements.
+
+ +
+
+  This scatter operation would look like this:
+
+  >>> tensor = [0, 0, 0, 0, 0, 0, 0, 0]    # tf.rank(tensor) == 1
+  >>> indices = [[1], [3], [4], [7]]       # num_updates == 4, index_depth == 1
+  >>> updates = [9, 10, 11, 12]            # num_updates == 4
+  >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
+  tf.Tensor([ 0  9  0 10 11  0  0 12], shape=(8,), dtype=int32)
+
+  The length (first axis) of `updates` must equal the length of the `indices`:
+  `num_updates`. This is the number of updates being inserted. Each
+  scalar update is inserted into `tensor` at the indexed location.
+
+  For a higher rank input `tensor`, scalar updates can be inserted by using an
+  `index_depth` that matches `tf.rank(tensor)`:
+
+  >>> tensor = [[1, 1], [1, 1], [1, 1]]    # tf.rank(tensor) == 2
+  >>> indices = [[0, 1], [2, 0]]           # num_updates == 2, index_depth == 2
+  >>> updates = [5, 10]                    # num_updates == 2
+  >>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
+  tf.Tensor(
+  [[ 1  5]
+   [ 1  1]
+   [10  1]], shape=(3, 2), dtype=int32)
+
+  ### Slice updates
+
+  When the input `tensor` has more than one axis, scatter can be used to update
+  entire slices.
+
+  In this case it's helpful to think of the input `tensor` as being a two level
+  array-of-arrays. The shape of this two level array is split into the
+  `outer_shape` and the `inner_shape`.
+
+  `indices` indexes into the outer level of the input tensor (`outer_shape`)
+  and replaces the sub-array at that location with the corresponding item from
+  the `updates` list. The shape of each update is `inner_shape`.
+
+  When updating a list of slices the shape constraints are:
+
+  ```
+  num_updates, index_depth = indices.shape.as_list()
+  outer_shape = tensor.shape[:index_depth]
+  inner_shape = tensor.shape[index_depth:]
+  assert updates.shape == [num_updates, inner_shape]
+  ```
+
+  For example, to update rows of a `(6, 3)` `tensor`:
+
+  >>> tensor = tf.zeros([6, 3], dtype=tf.int32)
+
+  Use an index depth of one.
+
+  >>> indices = tf.constant([[2], [4]])    # num_updates == 2, index_depth == 1
+  >>> num_updates, index_depth = indices.shape.as_list()
+
+  The `outer_shape` is `6`, the inner shape is `3`:
+
+  >>> outer_shape = tensor.shape[:index_depth]
+  >>> inner_shape = tensor.shape[index_depth:]
+
+  2 rows are being indexed so 2 `updates` must be supplied.
+  Each update must be shaped to match the `inner_shape`.
+
+  >>> # num_updates == 2, inner_shape == 3
+  >>> updates = tf.constant([[1, 2, 3],
+  ...                        [4, 5, 6]])
+
+  Altogether this gives:
+
+  >>> tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
+  array([[0, 0, 0],
+         [0, 0, 0],
+         [1, 2, 3],
+         [0, 0, 0],
+         [4, 5, 6],
+         [0, 0, 0]], dtype=int32)
+
+  #### More slice update examples
+
+  A tensor representing a batch of uniformly sized video clips naturally has 5
+  axes: `[batch_size, time, width, height, channels]`.
+
+  For example:
+
+  >>> batch_size, time, width, height, channels = 13, 11, 7, 5, 3
+  >>> video_batch = tf.zeros([batch_size, time, width, height, channels])
+
+  To replace a selection of video clips:
+
+  * Use an `index_depth` of 1 (indexing the `outer_shape`: `[batch_size]`)
+  * Provide updates each with a shape matching the `inner_shape`:
+    `[time, width, height, channels]`.
+
+  To replace the first two clips with ones:
+
+  >>> indices = [[0], [1]]
+  >>> new_clips = tf.ones([2, time, width, height, channels])
+  >>> tf.tensor_scatter_nd_update(video_batch, indices, new_clips)
+
+  To replace a selection of frames in the videos:
+
+  * `indices` must have an `index_depth` of 2 for the `outer_shape`:
+    `[batch_size, time]`.
+  * `updates` must be shaped like a list of images. Each update must have a
+    shape matching the `inner_shape`: `[width, height, channels]`.
+
+  To replace the first frame of the first three video clips:
+
+  >>> indices = [[0, 0], [1, 0], [2, 0]]   # num_updates=3, index_depth=2
+  >>> new_images = tf.ones([
+  ...   # num_updates=3, inner_shape=(width, height, channels)
+  ...   3, width, height, channels])
+  >>> tf.tensor_scatter_nd_update(video_batch, indices, new_images)
+
+  ### Folded indices
+
+  In simple cases it's convenient to think of `indices` and `updates` as
+  lists, but this is not a strict requirement. Instead of a flat `num_updates`,
+  the `indices` and `updates` can be folded into a `batch_shape`. This
+  `batch_shape` is all axes of the `indices`, except for the innermost
+  `index_depth` axis.
+
+  ```
+  index_depth = indices.shape[-1]
+  batch_shape = indices.shape[:-1]
+  ```
+
+  Note: The one exception is that the `batch_shape` cannot be `[]`. You can't
+  update a single index by passing indices with shape `[index_depth]`.
+
+  `updates` must have a matching `batch_shape` (the axes before `inner_shape`).
+
+  ```
+  assert updates.shape == batch_shape + inner_shape
+  ```
+
+  Note: The result is equivalent to flattening the `batch_shape` axes of
+  `indices` and `updates`. This generalization just avoids the need
+  for reshapes when it is more natural to construct "folded" indices and
+  updates.
+
+  With this generalization the full shape constraints are:
+
+  ```
+  assert tf.rank(indices) >= 2
+  index_depth = indices.shape[-1]
+  batch_shape = indices.shape[:-1]
+  assert index_depth <= tf.rank(tensor)
+  outer_shape = tensor.shape[:index_depth]
+  inner_shape = tensor.shape[index_depth:]
+  assert updates.shape == batch_shape + inner_shape
+  ```
+
+  For example, to draw an `X` on a `(5, 5)` matrix start with these indices:
+
+  >>> tensor = tf.zeros([5, 5])
+  >>> indices = tf.constant([
+  ...     [[0, 0],
+  ...      [1, 1],
+  ...      [2, 2],
+  ...      [3, 3],
+  ...      [4, 4]],
+  ...     [[0, 4],
+  ...      [1, 3],
+  ...      [2, 2],
+  ...      [3, 1],
+  ...      [4, 0]],
+  ... ])
+  >>> indices.shape.as_list()  # batch_shape == [2, 5], index_depth == 2
+  [2, 5, 2]
+
+  Here the `indices` do not have a shape of `[num_updates, index_depth]`, but a
+  shape of `batch_shape + [index_depth]`.
+
+  Since the `index_depth` is equal to the rank of `tensor`:
+
+  * `outer_shape` is `(5, 5)`
+  * `inner_shape` is `()` - each update is scalar
+  * `updates.shape` is `batch_shape + inner_shape == (2, 5) + ()`
+
+  >>> updates = [
+  ...     [1, 1, 1, 1, 1],
+  ...     [1, 1, 1, 1, 1],
+  ... ]
+
+  Putting this together gives:
+
+  >>> tf.tensor_scatter_nd_update(tensor, indices, updates).numpy()
+  array([[1., 0., 0., 0., 1.],
+         [0., 1., 0., 1., 0.],
+         [0., 0., 1., 0., 0.],
+         [0., 1., 0., 1., 0.],
+         [1., 0., 0., 0., 1.]], dtype=float32)
+
+  Args:
+    tensor: Tensor to copy/update.
+    indices: Indices to update.
+    updates: Updates to apply at the indices.
+    name: Optional name for the operation.
+
+  Returns:
+    A new tensor with the given shape and updates applied according to the
+    indices.
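+
+  The folded form also works when the index vectors select slices rather than
+  scalars. A minimal sketch (the tensor and values here are illustrative
+  only): with `batch_shape == (2, 1)`, `index_depth == 2` and a rank-3
+  `tensor`, each update is a slice of shape `inner_shape == (2,)`:
+
+  >>> tensor = tf.zeros([2, 3, 2], dtype=tf.int32)
+  >>> indices = tf.constant([[[0, 0]], [[1, 2]]])  # batch_shape + [index_depth]
+  >>> updates = tf.constant([[[1, 1]], [[2, 2]]])  # batch_shape + inner_shape
+  >>> tf.tensor_scatter_nd_update(tensor, indices, updates).numpy().tolist()
+  [[[1, 1], [0, 0], [0, 0]], [[0, 0], [0, 0], [2, 2]]]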
+  """
+  return gen_array_ops.tensor_scatter_update(
+      tensor=tensor, indices=indices, updates=updates, name=name)
+
+
 # Define quantize_v2 here in order to make name the second-to-last attribute,
 # because round_mode was added later.
 # (And also now because of 'axis' processing).
@@ -5343,6 +5647,13 @@ dequantize.__doc__ = gen_array_ops.dequantize.__doc__
 @tf_export("quantization.quantize_and_dequantize")
 @dispatch.add_dispatch_support
+@deprecation.deprecated(None,
+                        "This Op has been deprecated. Use " +
+                        "`quantize_and_dequantize_v2` instead. " +
+                        "To simulate the V1 behavior of " +
+                        "tf.quantization.quantize_and_dequantize(...) use " +
+                        "tf.grad_pass_through(" +
+                        "tf.quantization.quantize_and_dequantize_v2)(...).")
 def quantize_and_dequantize(
     input,  # pylint: disable=redefined-builtin
     input_min,
     input_max,
@@ -5401,6 +5712,93 @@ def quantize_and_dequantize(
       name=name)
+
+@tf_export("quantization.quantize_and_dequantize_v2")
+@dispatch.add_dispatch_support
+def quantize_and_dequantize_v2(
+    input,  # pylint: disable=redefined-builtin
+    input_min,
+    input_max,
+    signed_input=True,
+    num_bits=8,
+    range_given=False,
+    round_mode="HALF_TO_EVEN",
+    name=None,
+    narrow_range=False,
+    axis=None):
+  """Quantizes then dequantizes a tensor.
+
+  Updates the gradient definition for quantization that is outside the range to
+  be 0. To simulate the V1 behavior of
+  tf.quantization.quantize_and_dequantize(...) use
+  tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...).
+
+  Example usage:
+
+  ```python
+  def getQuantizeOp(input):
+    input_tensor = tf.placeholder(tf.float32, shape=[4, 4])
+    net = tf.quantization.quantize_and_dequantize(input,
+                                                  input_min=min_threshold,
+                                                  input_max=max_threshold,
+                                                  range_given=True)
+
+  To simulate v1 behavior:
+
+  def testDecomposeQuantizeDequantize(self):
+    def f(input_tensor):
+      return tf.quantization.quantize_and_dequantize_v2(input_tensor,
+                                                        input_min=-10.0,
+                                                        input_max=5.0,
+                                                        range_given=True)
+    input_tensor = tf.placeholder(tf.float32, shape=[4, 4])
+    net = tf.grad_pass_through(f)(input_tensor)
+  ```
+
+  Args:
+    input: A `Tensor` to quantize and dequantize.
+    input_min: If range_given=True, the minimum input value that needs to be
+      represented in the quantized representation. If axis is specified, this
+      should be a vector of minimum values for each slice along axis.
+    input_max: If range_given=True, the maximum input value that needs to be
+      represented in the quantized representation. If axis is specified, this
+      should be a vector of maximum values for each slice along axis.
+    signed_input: True if the quantization is signed, False if it is unsigned.
+    num_bits: The bitwidth of the quantization.
+    range_given: If true use `input_min` and `input_max` for the range of the
+      input, otherwise determine min and max from the input `Tensor`.
+    round_mode: Rounding mode when rounding from float values to quantized
+      ones; one of ['HALF_TO_EVEN', 'HALF_UP'].
+    name: Optional name for the operation.
+    narrow_range: If true, then the absolute value of the quantized minimum
+      value is the same as the quantized maximum value, instead of 1 greater.
+      i.e. for 8 bit quantization, the minimum value is -127 instead of -128.
+    axis: Integer. If specified, refers to a dimension of the input tensor,
+      such that quantization will be per slice along that dimension.
+
+  Returns:
+    A `Tensor`. Each element is the result of quantizing and dequantizing the
+    corresponding element of `input`.
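+
+  For per-channel quantization, pass one `input_min`/`input_max` entry per
+  slice along `axis`. A minimal sketch (the ranges below are assumed for
+  illustration, not computed from data):
+
+  ```python
+  x = tf.constant([[0.0, 1.0],
+                   [2.0, 4.0]])
+  per_channel = tf.quantization.quantize_and_dequantize_v2(
+      x,
+      input_min=[0.0, 0.0],   # one minimum per slice along `axis`
+      input_max=[2.0, 4.0],   # one maximum per slice along `axis`
+      range_given=True,
+      axis=1)
+  ```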
+ """ + if axis is None: + axis = -1 + elif axis < 0: + if input.shape.ndims is None: + raise ValueError("input should have known rank to use negative axis.") + axis %= input.shape.ndims + + return gen_array_ops.quantize_and_dequantize_v4( + input, + input_min=input_min, + input_max=input_max, + signed_input=signed_input, + num_bits=num_bits, + range_given=range_given, + round_mode=round_mode, + narrow_range=narrow_range, + axis=axis, + name=name) + + @tf_export("searchsorted") @dispatch.add_dispatch_support def searchsorted(sorted_sequence, @@ -5888,7 +6286,7 @@ def _with_nonzero_rank(data): @dispatch.add_dispatch_support def repeat(input, repeats, axis=None, name=None): # pylint: disable=redefined-builtin """Repeat elements of `input`. - + See also `tf.concat`, `tf.stack`, `tf.tile`. Args: diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 9bc638ac5a2..f920092fd7f 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -1151,14 +1151,15 @@ def assert_rank(x, rank, data=None, summarize=None, message=None, name=None): ValueError: If static checks determine `x` has wrong rank. """ with ops.name_scope(name, 'assert_rank', (x, rank) + tuple(data or [])): - x = ops.convert_to_tensor(x, name='x') + if not isinstance(x, sparse_tensor.SparseTensor): + x = ops.convert_to_tensor(x, name='x') rank = ops.convert_to_tensor(rank, name='rank') message = message or '' static_condition = lambda actual_rank, given_rank: actual_rank == given_rank dynamic_condition = math_ops.equal - if context.executing_eagerly(): + if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor): name = '' else: name = x.name @@ -1418,11 +1419,12 @@ def assert_rank_in( """ with ops.name_scope( name, 'assert_rank_in', (x,) + tuple(ranks) + tuple(data or [])): - x = ops.convert_to_tensor(x, name='x') + if not isinstance(x, sparse_tensor.SparseTensor): + x = ops.convert_to_tensor(x, name='x') ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks]) message = message or '' - if context.executing_eagerly(): + if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor): name = '' else: name = x.name @@ -1582,7 +1584,7 @@ def _dimension_sizes(x): rank = x.get_shape().rank rank_is_known = rank is not None if rank_is_known and rank == 0: - return tuple([1]) + return (1,) if rank_is_known and rank > 0: static_shape = x.get_shape().as_list() sizes = [ @@ -1787,14 +1789,14 @@ def assert_shapes(shapes, data=None, summarize=None, message=None, name=None): message = message or '' with ops.name_scope(name, 'assert_shapes', [shapes, data]): # Shape specified as None implies no constraint - shape_constraints = [ - (ops.convert_to_tensor(x), s) for x, s in shapes if s is not None - ] + shape_constraints = [(x if isinstance(x, sparse_tensor.SparseTensor) else + ops.convert_to_tensor(x), s) + for x, s in shapes if s is not None] executing_eagerly = context.executing_eagerly() def tensor_name(x): - if executing_eagerly: + if executing_eagerly or isinstance(x, sparse_tensor.SparseTensor): return _shape_and_dtype_str(x) return x.name diff --git a/tensorflow/python/ops/collective_ops.py b/tensorflow/python/ops/collective_ops.py index 57869157dd8..6afe923d795 100644 --- a/tensorflow/python/ops/collective_ops.py +++ b/tensorflow/python/ops/collective_ops.py @@ -77,7 +77,8 @@ def all_reduce_v2(t, instance_key, merge_op='Add', final_op='Id', - communication_hint='auto'): + communication_hint='auto', + timeout=0): """Reduces tensors 
collectively, across devices. Args: @@ -94,6 +95,9 @@ def all_reduce_v2(t, communication_hint: preferred collective communication. The implementation may fall back to another mechanism. Options include `auto`, `ring`, and `nccl`. + timeout: a float. If set to a non zero, set a completion timeout to detect + staleness. If the timer goes off, a DeadlineExceededError is raised. The + timeout value in seconds. This feature is experimental. Returns: An Op implementing the distributed reduction. @@ -105,7 +109,8 @@ def all_reduce_v2(t, instance_key=instance_key, merge_op=merge_op, final_op=final_op, - communication_hint=communication_hint.lower()) + communication_hint=communication_hint.lower(), + timeout_seconds=timeout) def all_gather(t, diff --git a/tensorflow/python/ops/collective_ops_test.py b/tensorflow/python/ops/collective_ops_test.py index d3d110f1db2..d3988d0806d 100644 --- a/tensorflow/python/ops/collective_ops_test.py +++ b/tensorflow/python/ops/collective_ops_test.py @@ -166,151 +166,6 @@ class CollectiveOpTest(test.TestCase): elapsed = time.time() - start_time self.assertAllGreaterEqual(elapsed, timeout) - @test_util.run_v2_only - def testCollectiveTimeoutV2(self): - timeout = 4.5 - cpus = config.list_physical_devices('CPU') - self.assertEqual(len(cpus), 1) - config.set_logical_device_configuration(cpus[0], [ - context.LogicalDeviceConfiguration(), - context.LogicalDeviceConfiguration() - ]) - context.ensure_initialized() - - @def_function.function - def run_all_reduce(group_size, reported_group_size=None): - group_key = 20 - instance_key = 30 - tensor = [1, 2, 3, 4] - results = [] - if reported_group_size is None: - reported_group_size = group_size - for i in range(group_size): - with ops.device('/CPU:{}'.format(i)): - input_data = constant_op.constant(tensor) - collective_op = collective_ops.all_reduce( - input_data, - group_size=reported_group_size, - group_key=group_key, - instance_key=instance_key, - merge_op='Add', - final_op='Id', - timeout=timeout) - results.append(collective_op) - return results - - run_all_reduce(2, 2) - - start_time = time.time() - with self.assertRaisesRegex(errors.DeadlineExceededError, - 'Collective has timed out during execution'): - run_all_reduce(1, 2) - elapsed = time.time() - start_time - self.assertAllGreaterEqual(elapsed, timeout) - - @test_util.run_v2_only - def testParamResolutionAfterTimeoutV2(self): - timeout = 1.5 - cpus = config.list_physical_devices('CPU') - self.assertEqual(len(cpus), 1) - config.set_logical_device_configuration(cpus[0], [ - context.LogicalDeviceConfiguration(), - context.LogicalDeviceConfiguration() - ]) - context.ensure_initialized() - - group_key = 20 - instance_key = 30 - input_data = constant_op.constant([1, 2, 3, 4]) - - # This timeout comes from param solution. - with self.assertRaisesRegex( - errors.DeadlineExceededError, - 'Collective has timed out waiting for other workers'): - with ops.device('CPU:0'): - collective_ops.all_reduce( - input_data, - group_size=2, - group_key=group_key, - instance_key=instance_key, - merge_op='Add', - final_op='Id', - timeout=timeout) - - # We launch the second device after the first device times out. This is to - # simulate the situation when other workers are slow and the timeout is - # short. Since the CPU:0 times out in the param resolution phase, CPU:1 - # should times out as well, but in the execute phase. 
- with self.assertRaisesRegex(errors.DeadlineExceededError, - 'Collective has timed out during execution'): - with ops.device('CPU:1'): - collective_ops.all_reduce( - input_data, - group_size=2, - group_key=group_key, - instance_key=instance_key, - merge_op='Add', - final_op='Id', - timeout=timeout) - - @test_util.run_v2_only - def testExecutionAfterTimeoutV2(self): - timeout = 1.5 - cpus = config.list_physical_devices('CPU') - self.assertEqual(len(cpus), 1) - config.set_logical_device_configuration(cpus[0], [ - context.LogicalDeviceConfiguration(), - context.LogicalDeviceConfiguration() - ]) - context.ensure_initialized() - - group_key = 20 - instance_key = 30 - input_data = constant_op.constant([1, 2, 3, 4]) - - @def_function.function - def run_all_reduce(): - for device in ['CPU:0', 'CPU:1']: - with ops.device(device): - collective_ops.all_reduce( - input_data, - group_size=2, - group_key=group_key, - instance_key=instance_key, - merge_op='Add', - final_op='Id', - timeout=timeout) - - # Run a normal all-reduce to complete param resolution. - run_all_reduce() - - with self.assertRaisesRegex(errors.DeadlineExceededError, - 'Collective has timed out during execution'): - with ops.device('CPU:0'): - collective_ops.all_reduce( - input_data, - group_size=2, - group_key=group_key, - instance_key=instance_key, - merge_op='Add', - final_op='Id', - timeout=timeout) - - # We launch the second device after the first device times out. This is to - # simulate the situation when other workers are slow and the timeout is - # short. It should error immediately. - with self.assertRaisesRegex(errors.DeadlineExceededError, - 'Collective has timed out during execution'): - with ops.device('CPU:1'): - # No timeout. - collective_ops.all_reduce( - input_data, - group_size=2, - group_key=group_key, - merge_op='Add', - final_op='Id', - instance_key=instance_key) - def testNcclHintFallbackToRingReduce(self): """Tests that setting `communication_hint=nccl` works on non-GPU builds.""" if kernels.get_registered_kernels_for_op('NcclAllReduce'): diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py index 163f0fb7077..5bdd2494e91 100644 --- a/tensorflow/python/ops/cond_v2.py +++ b/tensorflow/python/ops/cond_v2.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import func_graph as func_graph_module from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_util @@ -42,6 +43,7 @@ from tensorflow.python.ops import default_gradient from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_functional_ops from tensorflow.python.ops import gradients_util +from tensorflow.python.ops import handle_data_util from tensorflow.python.ops import math_ops from tensorflow.python.util import nest @@ -286,6 +288,7 @@ def _build_cond(pred, # Prevent fetching since the variant outputs can't be fetched directly. if_op.graph.prevent_fetching(if_op) + _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs) # Return identities for each output of the If op, rather than the output of # the If op directly. 
This makes pruning work if the output of cond() is # fetched: the lowering pass converts the If outputs into IdentityN outputs, @@ -813,6 +816,32 @@ def _get_output_shapes(*branch_graph_outputs): return output_shapes +def _copy_handle_data(external_tensors, *branch_graph_outputs): + """Combines shapes in handle data and sets metadata on `external_tensors`.""" + for tensors in zip(external_tensors, *branch_graph_outputs): + external = tensors[0] + internal = tensors[1:] + internal_handle_data = [] + for tensor in internal: + handle_data = handle_data_util.get_resource_handle_data(tensor) + # NOTE: Assumes handle data has only one ShapeAndType entry. It's + # unclear how to combine different lengths across branches. + if not handle_data.is_set or len(handle_data.shape_and_type) != 1: + break + internal_handle_data.append(handle_data) + else: # There is handle data, so we need to combine it. + combined_shape = tensor_shape.TensorShape(None) + for handle_data in internal_handle_data: + handle_shape = tensor_shape.TensorShape( + handle_data.shape_and_type[0].shape) + combined_shape = combined_shape.most_specific_compatible_shape( + handle_shape) + combined_handle_data = internal_handle_data[0] + combined_handle_data.shape_and_type[0].shape.CopyFrom( + combined_shape.as_proto()) + handle_data_util.set_handle_data(external, combined_handle_data) + + def verify_captures(op_type, branch_graphs): """Verify that a branch's tensor is not accessed in another branch fn.""" # Note: It is technically not possible for lower-branch_index branches to @@ -1143,6 +1172,7 @@ def _build_case(branch_index, # Prevent fetching since the variant outputs can't be fetched directly. case_op.graph.prevent_fetching(case_op) + _copy_handle_data(tensors, *[g.outputs for g in branch_graphs]) # Return identities for each output of the Case op, rather than the output of # the Case op directly. This makes pruning work if the output of switch_case() # is fetched: the lowering pass converts the Case outputs into IdentityN diff --git a/tensorflow/python/ops/control_flow_util_v2.py b/tensorflow/python/ops/control_flow_util_v2.py index 4fc464c545c..0bbee785641 100644 --- a/tensorflow/python/ops/control_flow_util_v2.py +++ b/tensorflow/python/ops/control_flow_util_v2.py @@ -244,8 +244,9 @@ def _is_tpu_strategy(strategy): def _register_keras_layer_context_function(func): global _KERAS_LAYER_CONTEXT_FUNCTION - if _KERAS_LAYER_CONTEXT_FUNCTION is None: - _KERAS_LAYER_CONTEXT_FUNCTION = func + # TODO(scottzhu): Disable duplicated inject once keras is moved to + # third_party/py/keras. 
+ _KERAS_LAYER_CONTEXT_FUNCTION = func def _is_building_keras_layer(): diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py index f081f036b58..3e38f68a0f7 100644 --- a/tensorflow/python/ops/custom_gradient.py +++ b/tensorflow/python/ops/custom_gradient.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.client import pywrap_tf_session from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import tape as tape_lib @@ -25,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import handle_data_util from tensorflow.python.ops import op_selector from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope @@ -42,47 +42,8 @@ VAR_OP_TYPES = [ ] -def copy_handle_data(source_t, target_t): - """Copies HandleData for variant and resource type tensors if available. - - The CppShapeInferenceResult::HandleData proto contains information about the - shapes and types of the element tensors of resource/variant type tensors. - We need to copy this across function boundaries, i.e., when capturing a - placeholder or when returning a function tensor as output. If we don't do this - the element tensors will have unknown shapes, e.g., if a TensorList variant - tensor is captured as a placeholder, elements popped from that list would have - unknown shape. - - Args: - source_t: The tensor to copy HandleData from. - target_t: The tensor to copy HandleData to. - """ - if (target_t.dtype == dtypes.resource or - target_t.dtype == dtypes.variant): - if isinstance(source_t, ops.EagerTensor): - handle_data = source_t._handle_data # pylint: disable=protected-access - else: - handle_data = resource_variable_ops.get_resource_handle_data(source_t) - if (handle_data is not None - and handle_data.is_set - and handle_data.shape_and_type): - # pylint: disable=protected-access - pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph, - target_t._as_tf_output(), - handle_data.SerializeToString()) - # pylint: enable=protected-access - # Ensure that shapes and dtypes are propagated. - shapes, types = zip(*[(pair.shape, pair.dtype) - for pair in handle_data.shape_and_type]) - ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes] - shapes = [[d.size for d in s.dim] # pylint: disable=g-complex-comprehension - if not s.unknown_rank else None for s in shapes] - pywrap_tf_session.TF_GraphSetOutputHandleShapesAndTypes_wrapper( - target_t._op._graph._c_graph, # pylint: disable=protected-access - target_t._as_tf_output(), # pylint: disable=protected-access - shapes, - ranks, - types) +# TODO(allenl): Remove this alias and migrate callers. 
+copy_handle_data = handle_data_util.copy_handle_data @tf_export("custom_gradient") diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py index 9784650d2d9..4d4df0ffa48 100644 --- a/tensorflow/python/ops/gradients_util.py +++ b/tensorflow/python/ops/gradients_util.py @@ -333,7 +333,7 @@ def _MaybeCompile(scope, op, func, grad_fn): "_XlaSeparateCompiledGradients") xla_scope = op.get_attr("_XlaScope").decode() except ValueError: - return grad_fn() # Exit early + xla_compile = False if not xla_compile: return grad_fn() # Exit early diff --git a/tensorflow/python/ops/handle_data_util.py b/tensorflow/python/ops/handle_data_util.py new file mode 100644 index 00000000000..d83bea3cb18 --- /dev/null +++ b/tensorflow/python/ops/handle_data_util.py @@ -0,0 +1,73 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Decorator to overrides the gradient for a function.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.client import pywrap_tf_session +from tensorflow.python.framework import cpp_shape_inference_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.util import compat + + +def get_resource_handle_data(graph_op): + assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck + + handle_data = pywrap_tf_session.GetHandleShapeAndType( + graph_op.graph._c_graph, graph_op._as_tf_output()) # pylint: disable=protected-access + + return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString( + compat.as_bytes(handle_data)) + + +def copy_handle_data(source_t, target_t): + """Copies HandleData for variant and resource type tensors if available. + + The CppShapeInferenceResult::HandleData proto contains information about the + shapes and types of the element tensors of resource/variant type tensors. + We need to copy this across function boundaries, i.e., when capturing a + placeholder or when returning a function tensor as output. If we don't do this + the element tensors will have unknown shapes, e.g., if a TensorList variant + tensor is captured as a placeholder, elements popped from that list would have + unknown shape. + + Args: + source_t: The tensor to copy HandleData from. + target_t: The tensor to copy HandleData to. 
+  """
+  if (target_t.dtype == dtypes.resource or
+      target_t.dtype == dtypes.variant):
+    if isinstance(source_t, ops.EagerTensor):
+      handle_data = source_t._handle_data  # pylint: disable=protected-access
+    else:
+      handle_data = get_resource_handle_data(source_t)
+    if (handle_data is not None
+        and handle_data.is_set
+        and handle_data.shape_and_type):
+      set_handle_data(target_t, handle_data)
+
+
+def set_handle_data(target_t, handle_data):
+  # pylint: disable=protected-access
+  if isinstance(target_t, ops.EagerTensor):
+    target_t._handle_data = handle_data
+    return
+  pywrap_tf_session.SetHandleShapeAndType(target_t.graph._c_graph,
+                                          target_t._as_tf_output(),
+                                          handle_data.SerializeToString())
+  # pylint: enable=protected-access
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 53676b29fce..8a767597f00 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -5536,9 +5536,31 @@ def generate_bounding_box_proposals(scores,
                                     name=None):
   """Generate bounding box proposals from encoded bounding boxes.
 
+  Args:
+    scores: A 4-D float `Tensor` of shape
+      `[num_images, height, width, num_anchors]` containing scores of
+      the boxes for given anchors; can be unsorted.
+    bbox_deltas: A 4-D float `Tensor` of shape
+      `[num_images, height, width, 4 x num_anchors]` encoding boxes
+      with respect to each anchor. Coordinates are given
+      in the form `[dy, dx, dh, dw]`.
+    image_info: A 2-D float `Tensor` of shape `[num_images, 5]`
+      containing image information (height, width, scale).
+    anchors: A 2-D float `Tensor` of shape `[num_anchors, 4]`
+      describing the anchor boxes.
+      Boxes are formatted in the form `[y1, x1, y2, x2]`.
+    nms_threshold: A scalar float `Tensor` for the non-maximum suppression
+      threshold. Defaults to 0.7.
+    pre_nms_topn: A scalar int `Tensor` for the number of
+      top scoring boxes to be used as input. Defaults to 6000.
+    min_size: A scalar float `Tensor`. Any box that has a smaller size
+      than min_size will be discarded. Defaults to 16.
+    post_nms_topn: An integer. Maximum number of rois in the output.
+    name: A name for this operation (optional).
+
   Returns:
     rois: Region of interest boxes sorted by their scores.
-    roi_probabilities: scores of the ROI boxes in the ROIs' tensor.
+    roi_probabilities: scores of the ROI boxes in the ROIs' `Tensor`.
""" return gen_image_ops.generate_bounding_box_proposals( scores=scores, diff --git a/tensorflow/python/ops/linalg/registrations_util.py b/tensorflow/python/ops/linalg/registrations_util.py index c707a67d43c..d19094132e5 100644 --- a/tensorflow/python/ops/linalg/registrations_util.py +++ b/tensorflow/python/ops/linalg/registrations_util.py @@ -56,7 +56,7 @@ def is_square(operator_a, operator_b): return m == l if (operator_a.is_square != operator_b.is_square) and ( - operator_a.is_square is not None and operator_a.is_square is not None): + operator_a.is_square is not None and operator_b.is_square is not None): return False return None diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py index 40f03491e6d..847d144bde3 100644 --- a/tensorflow/python/ops/linalg_grad.py +++ b/tensorflow/python/ops/linalg_grad.py @@ -524,6 +524,7 @@ def _QrGrad(op, dq, dr): return _QrGradSquareAndDeepMatrices(q, r, dq, dr) # Partition a = [x, y], r = [u, v] and reduce to the square case + # The methodology is explained in detail in https://arxiv.org/abs/2009.10071 a = op.inputs[0] y = a[..., :, num_rows:] u = r[..., :, :num_rows] diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py index 966706a854f..90b4abbdaf7 100644 --- a/tensorflow/python/ops/linalg_ops.py +++ b/tensorflow/python/ops/linalg_ops.py @@ -151,6 +151,9 @@ def matrix_triangular_solve(matrix, rhs, lower=True, adjoint=False, name=None): def cholesky_solve(chol, rhs, name=None): """Solves systems of linear eqns `A X = RHS`, given Cholesky factorizations. + Specifically, returns `X` from `A X = RHS`, where `A = L L^T`, `L` is the + `chol` arg and `RHS` is the `rhs` arg. + ```python # Solve 10 separate 2x2 linear systems: A = ... # shape 10 x 2 x 2 diff --git a/tensorflow/python/ops/list_ops.py b/tensorflow/python/ops/list_ops.py index 3e7c116ec97..8379a26a260 100644 --- a/tensorflow/python/ops/list_ops.py +++ b/tensorflow/python/ops/list_ops.py @@ -19,11 +19,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.core.framework import types_pb2 +from tensorflow.python.framework import cpp_shape_inference_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_list_ops +from tensorflow.python.ops import handle_data_util # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_list_ops import * @@ -56,19 +60,45 @@ def empty_tensor_list(element_shape, name=name) +def _set_handle_data(list_handle, element_shape, element_dtype): + """Sets type information on `list_handle` for consistency with graphs.""" + # TODO(b/169968286): It would be better if we had a consistent story for + # creating handle data from eager operations (shared with VarHandleOp). 
+ if isinstance(list_handle, ops.EagerTensor): + if tensor_util.is_tensor(element_shape): + element_shape = tensor_shape.TensorShape(None) + elif not isinstance(element_shape, tensor_shape.TensorShape): + element_shape = tensor_shape.TensorShape(element_shape) + handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData() + handle_data.is_set = True + handle_data.shape_and_type.append( + cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType( + shape=element_shape.as_proto(), + dtype=element_dtype.as_datatype_enum, + specialized_type=types_pb2.ST_TENSOR_LIST)) + list_handle._handle_data = handle_data # pylint: disable=protected-access + + def tensor_list_reserve(element_shape, num_elements, element_dtype, name=None): - return gen_list_ops.tensor_list_reserve( + result = gen_list_ops.tensor_list_reserve( element_shape=_build_element_shape(element_shape), num_elements=num_elements, element_dtype=element_dtype, name=name) + # TODO(b/169968286): gen_ops needs to ensure the metadata is properly + # populated for eager operations. + _set_handle_data(result, element_shape, element_dtype) + return result def tensor_list_from_tensor(tensor, element_shape, name=None): - return gen_list_ops.tensor_list_from_tensor( + tensor = ops.convert_to_tensor(tensor) + result = gen_list_ops.tensor_list_from_tensor( tensor=tensor, element_shape=_build_element_shape(element_shape), name=name) + _set_handle_data(result, tensor.shape, tensor.dtype) + return result def tensor_list_get_item(input_handle, index, element_dtype, element_shape=None, @@ -107,16 +137,22 @@ def tensor_list_scatter(tensor, element_shape=None, input_handle=None, name=None): + """Returns a TensorList created or updated by scattering `tensor`.""" + tensor = ops.convert_to_tensor(tensor) if input_handle is not None: - return gen_list_ops.tensor_list_scatter_into_existing_list( + output_handle = gen_list_ops.tensor_list_scatter_into_existing_list( input_handle=input_handle, tensor=tensor, indices=indices, name=name) + handle_data_util.copy_handle_data(input_handle, output_handle) + return output_handle else: - return gen_list_ops.tensor_list_scatter_v2( + output_handle = gen_list_ops.tensor_list_scatter_v2( tensor=tensor, indices=indices, element_shape=_build_element_shape(element_shape), num_elements=-1, name=name) + _set_handle_data(output_handle, element_shape, tensor.dtype) + return output_handle def tensor_list_stack(input_handle, @@ -167,8 +203,10 @@ def tensor_list_set_item(input_handle, lambda: gen_list_ops.tensor_list_resize( # pylint: disable=g-long-lambda input_handle, index + 1), lambda: input_handle) - return gen_list_ops.tensor_list_set_item( + output_handle = gen_list_ops.tensor_list_set_item( input_handle=input_handle, index=index, item=item, name=name) + handle_data_util.copy_handle_data(input_handle, output_handle) + return output_handle @ops.RegisterGradient("TensorListPushBack") diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py index 02fce277690..96e24d04362 100644 --- a/tensorflow/python/ops/logging_ops.py +++ b/tensorflow/python/ops/logging_ops.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections as py_collections import os import pprint import random @@ -305,8 +306,21 @@ def print_v2(*inputs, **kwargs): # printed input. 
templates = [] tensors = [] + # If an input to the print function is of type `OrderedDict`, sort its + # elements by the keys for consistency with the ordering of `nest.flatten`. + # This is not needed for `dict` types because `pprint.pformat()` takes care + # of printing the template in a sorted fashion. + inputs_ordered_dicts_sorted = [] + for input_ in inputs: + if isinstance(input_, py_collections.OrderedDict): + inputs_ordered_dicts_sorted.append( + py_collections.OrderedDict(sorted(input_.items()))) + else: + inputs_ordered_dicts_sorted.append(input_) tensor_free_structure = nest.map_structure( - lambda x: "" if tensor_util.is_tensor(x) else x, inputs) + lambda x: "" if tensor_util.is_tensor(x) else x, + inputs_ordered_dicts_sorted) + tensor_free_template = " ".join( pprint.pformat(x) for x in tensor_free_structure) placeholder = _generate_placeholder_string(tensor_free_template) diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index e53629250c9..f99102fee52 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -146,6 +146,10 @@ class LookupInterface(trackable.TrackableResource): """Looks up `keys` in a table, outputs the corresponding values.""" raise NotImplementedError + def __getitem__(self, keys): + """Looks up `keys` in a table, outputs the corresponding values.""" + return self.lookup(keys) + class InitializableLookupTableBase(LookupInterface): """Initializable lookup table interface. @@ -255,14 +259,28 @@ class StaticHashTable(InitializableLookupTableBase): Example usage: - ```python - keys_tensor = tf.constant([1, 2]) - vals_tensor = tf.constant([3, 4]) - input_tensor = tf.constant([1, 5]) - table = tf.lookup.StaticHashTable( - tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1) - print(table.lookup(input_tensor)) - ``` + >>> keys_tensor = tf.constant(['a', 'b', 'c']) + >>> vals_tensor = tf.constant([7, 8, 9]) + >>> input_tensor = tf.constant(['a', 'f']) + >>> table = tf.lookup.StaticHashTable( + ... tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), + ... default_value=-1) + >>> table.lookup(input_tensor).numpy() + array([ 7, -1], dtype=int32) + + Or for more pythonic code: + + >>> table[input_tensor].numpy() + array([ 7, -1], dtype=int32) + + The result of a lookup operation has the same shape as the argument: + + >>> input_tensor = tf.constant([['a', 'b'], ['c', 'd']]) + >>> table[input_tensor].numpy() + array([[ 7, 8], + [ 9, -1]], dtype=int32) + + """ def __init__(self, initializer, default_value, name=None): @@ -422,16 +440,15 @@ class DatasetInitializer(TableInitializerBase): """Creates a table initializer from a `tf.data.Dataset`. Sample usage: - ```python - keys = tf.data.Dataset.range(100) - values = tf.data.Dataset.range(100).map( - lambda x: string_ops.as_string(x * 2)) - ds = tf.data.Dataset.zip((keys, values)) - init = tf.lookup.experimental.DatasetInitializer(ds) - table = tf.lookup.StaticHashTable(init, "") - output = table.lookup([0, 1, 2]) - assertEquals(outputs, ["0", "2", "4"]) - ``` + + >>> keys = tf.data.Dataset.range(100) + >>> values = tf.data.Dataset.range(100).map( + ... 
lambda x: string_ops.as_string(x * 2)) + >>> ds = tf.data.Dataset.zip((keys, values)) + >>> init = tf.lookup.experimental.DatasetInitializer(ds) + >>> table = tf.lookup.StaticHashTable(init, "") + >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy() + array([b'0', b'2', b'4'], dtype=object) Attributes: dataset: A `tf.data.Dataset` object that produces tuples of scalars. The @@ -479,7 +496,19 @@ class DatasetInitializer(TableInitializerBase): @tf_export("lookup.KeyValueTensorInitializer") class KeyValueTensorInitializer(TableInitializerBase): - """Table initializers given `keys` and `values` tensors.""" + """Table initializers given `keys` and `values` tensors. + + >>> keys_tensor = tf.constant(['a', 'b', 'c']) + >>> vals_tensor = tf.constant([7, 8, 9]) + >>> input_tensor = tf.constant(['a', 'f']) + >>> init = tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor) + >>> table = tf.lookup.StaticHashTable( + ... init, + ... default_value=-1) + >>> table.lookup(input_tensor).numpy() + array([ 7, -1], dtype=int32) + + """ def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None): """Constructs a table initializer object based on keys and values tensors. @@ -537,7 +566,7 @@ class KeyValueTensorInitializer(TableInitializerBase): class TextFileIndex(object): """The key and value content to get from each line. - This class defines the key and value used for tf.lookup.TextFileInitializer. + This class defines the key and value used for `tf.lookup.TextFileInitializer`. The key and value content to get from each line is specified either by the following, or a value `>=0`. @@ -555,7 +584,7 @@ class TextFileIndex(object): @tf_export("lookup.TextFileInitializer") class TextFileInitializer(TableInitializerBase): - """Table initializers from a text file. + r"""Table initializers from a text file. This initializer assigns one entry in the table for each line in the file. @@ -574,11 +603,11 @@ class TextFileInitializer(TableInitializerBase): For example if we have a file with the following content: - ``` - emerson 10 - lake 20 - palmer 30 - ``` + >>> import tempfile + >>> f = tempfile.NamedTemporaryFile(delete=False) + >>> content='\n'.join(["emerson 10", "lake 20", "palmer 30",]) + >>> f.file.write(content.encode('utf-8')) + >>> f.file.close() The following snippet initializes a table with the first column as keys and second column as values: @@ -587,12 +616,13 @@ class TextFileInitializer(TableInitializerBase): * `lake -> 20` * `palmer -> 30` - ```python - table = tf.lookup.StaticHashTable(tf.lookup.TextFileInitializer( - "test.txt", tf.string, 0, tf.int64, 1, delimiter=" "), -1) - ... - table.init.run() - ``` + >>> init= tf.lookup.TextFileInitializer( + ... filename=f.name, + ... key_dtype=tf.string, key_index=0, + ... value_dtype=tf.int64, value_index=1, + ... delimiter=" ") + >>> table = tf.lookup.StaticHashTable(init, default_value=-1) + >>> table.lookup(tf.constant(['palmer','lake','tarkus'])).numpy() Similarly to initialize the whole line as keys and the line number as values. @@ -600,13 +630,14 @@ class TextFileInitializer(TableInitializerBase): * `lake 20 -> 1` * `palmer 30 -> 2` - ```python - table = tf.lookup.StaticHashTable(tf.lookup.TextFileInitializer( - "test.txt", tf.string, tf.lookup.TextFileIndex.WHOLE_LINE, - tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, delimiter=" "), -1) - ... - table.init.run() - ``` + >>> init = tf.lookup.TextFileInitializer( + ... filename=f.name, + ... key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE, + ... 
value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER, + ... delimiter=" ") + >>> table = tf.lookup.StaticHashTable(init, -1) + >>> table.lookup(tf.constant('palmer 30')).numpy() + 2 """ def __init__(self, @@ -1106,45 +1137,53 @@ class IdTableWithHashBuckets(LookupInterface): @tf_export("lookup.StaticVocabularyTable", v1=[]) class StaticVocabularyTable(LookupInterface): - r"""String to Id table wrapper that assigns out-of-vocabulary keys to buckets. + r"""String to Id table that assigns out-of-vocabulary keys to hash buckets. For example, if an instance of `StaticVocabularyTable` is initialized with a string-to-id initializer that maps: - * `emerson -> 0` - * `lake -> 1` - * `palmer -> 2` + >>> init = tf.lookup.KeyValueTensorInitializer( + ... keys=tf.constant(['emerson', 'lake', 'palmer']), + ... values=tf.constant([0, 1, 2], dtype=tf.int64)) + >>> table = tf.lookup.StaticVocabularyTable( + ... init, + ... num_oov_buckets=5) The `Vocabulary` object will performs the following mapping: * `emerson -> 0` * `lake -> 1` * `palmer -> 2` - * ` -> bucket_id`, where bucket_id will be between `3` and - `3 + num_oov_buckets - 1`, calculated by: + * ` -> bucket_id`, where `bucket_id` will be between `3` and + `3 + num_oov_buckets - 1 = 7`, calculated by: `hash() % num_oov_buckets + vocab_size` - If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`, - the lookup result is `[0, 1, 2, 4, 7]`. + If input_tensor is: + + >>> input_tensor = tf.constant(["emerson", "lake", "palmer", + ... "king", "crimson"]) + >>> table[input_tensor].numpy() + array([0, 1, 2, 6, 7]) If `initializer` is None, only out-of-vocabulary buckets are used. Example usage: - ```python - num_oov_buckets = 3 - input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimnson"]) - table = tf.lookup.StaticVocabularyTable( - tf.lookup.TextFileInitializer( - filename, - key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE, - value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER, - delimiter="\t"), - num_oov_buckets) - out = table.lookup(input_tensor). - table.init.run() - print(out.eval()) - ``` + >>> num_oov_buckets = 3 + >>> vocab = ["emerson", "lake", "palmer", "crimnson"] + >>> import tempfile + >>> f = tempfile.NamedTemporaryFile(delete=False) + >>> f.write('\n'.join(vocab).encode('utf-8')) + >>> f.close() + + >>> init = tf.lookup.TextFileInitializer( + ... f.name, + ... key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE, + ... value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER) + >>> table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets) + >>> table.lookup(tf.constant(["palmer", "crimnson" , "king", + ... "tarkus", "black", "moon"])).numpy() + array([2, 3, 5, 6, 6, 4]) The hash function used for generating out-of-vocabulary buckets ID is Fingerprint64. @@ -1158,8 +1197,8 @@ class StaticVocabularyTable(LookupInterface): """Construct a `StaticVocabularyTable` object. Args: - initializer: A TableInitializerBase object that contains the data used to - initialize the table. If None, then we only use out-of-vocab buckets. + initializer: A `TableInitializerBase` object that contains the data used + to initialize the table. If None, then we only use out-of-vocab buckets. num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. Must be greater than zero. lookup_key_dtype: Data type of keys passed to `lookup`. 
Defaults to @@ -1926,17 +1965,18 @@ class DenseHashTable(LookupInterface): Example usage: - ```python - table = tf.lookup.DenseHashTable(key_dtype=tf.int64, - value_dtype=tf.int64, - default_value=-1, - empty_key=0, - deleted_key=-1) - - sess.run(table.insert(keys, values)) - out = table.lookup(query_keys) - print(out.eval()) - ``` + >>> table = tf.lookup.experimental.DenseHashTable( + ... key_dtype=tf.string, + ... value_dtype=tf.int64, + ... default_value=-1, + ... empty_key='', + ... deleted_key='$') + >>> keys = tf.constant(['a', 'b', 'c']) + >>> values = tf.constant([0, 1, 2], dtype=tf.int64) + >>> table.insert(keys, values) + >>> table.remove(tf.constant(['c'])) + >>> table.lookup(tf.constant(['a', 'b', 'c','d'])).numpy() + array([ 0, 1, -1, -1]) """ # TODO(andreasst): consider extracting common code with MutableHashTable into diff --git a/tensorflow/python/ops/losses/loss_reduction.py b/tensorflow/python/ops/losses/loss_reduction.py index 829bc2f811e..789a6561bfb 100644 --- a/tensorflow/python/ops/losses/loss_reduction.py +++ b/tensorflow/python/ops/losses/loss_reduction.py @@ -44,7 +44,7 @@ class ReductionV2(object): loss_obj = tf.keras.losses.CategoricalCrossentropy( reduction=tf.keras.losses.Reduction.NONE) .... - loss = tf.reduce_sum(loss_object(labels, predictions)) * + loss = tf.reduce_sum(loss_obj(labels, predictions)) * (1. / global_batch_size) ``` diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index a4b43f11bc5..119d4e95ccf 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -3718,6 +3718,29 @@ def log_sigmoid(x, name=None): Returns: A Tensor with the same type as `x`. + + Usage Example: + + If a positive number is large, then its log_sigmoid will approach to 0 since + the formula will be `y = log( / (1 + ) )` which + approximates to `log (1)` which is 0. + + >>> x = tf.constant([0.0, 1.0, 50.0, 100.0]) + >>> tf.math.log_sigmoid(x) + + + If a negative number is large, its log_sigmoid will approach to the number + itself since the formula will be `y = log( 1 / (1 + ) )` which is + `log (1) - log ( (1 + ) )` which approximates to `- ` + that is the number itself. + + >>> x = tf.constant([-100.0, -50.0, -1.0, 0.0]) + >>> tf.math.log_sigmoid(x) + """ with ops.name_scope(name, "LogSigmoid", [x]) as name: x = ops.convert_to_tensor(x, name="x") diff --git a/tensorflow/python/ops/nn_fused_batchnorm_test.py b/tensorflow/python/ops/nn_fused_batchnorm_test.py index 1742a919216..0421829bff3 100644 --- a/tensorflow/python/ops/nn_fused_batchnorm_test.py +++ b/tensorflow/python/ops/nn_fused_batchnorm_test.py @@ -43,14 +43,18 @@ class BatchNormalizationTest(test.TestCase): return math_ops.cast(y, x.dtype) def _inference_ref(self, x, scale, offset, mean, var, epsilon, data_format): - if data_format not in ['NHWC', 'NCHW']: - raise ValueError('data_format must be NCHW or NHWC, ' - 'got %s.' % data_format) + if data_format not in ['NHWC', 'NCHW', 'NDHWC', 'NCDHW']: + raise ValueError('data_format must be NCHW or NHWC for 4D tensors or' + 'NCDHW or NDHWC for 5D tensors, got %s.' 
% data_format) if data_format == 'NCHW': x = array_ops.transpose(x, [0, 2, 3, 1]) + elif data_format == 'NCDHW': + x = array_ops.transpose(x, [0, 2, 3, 4, 1]) y = self._batch_norm(x, mean, var, offset, scale, epsilon) if data_format == 'NCHW': y = array_ops.transpose(y, [0, 3, 1, 2]) + elif data_format == 'NCDHW': + y = array_ops.transpose(y, [0, 4, 1, 2, 3]) return self.evaluate(y) def _test_inference(self, @@ -102,17 +106,24 @@ class BatchNormalizationTest(test.TestCase): def _training_ref(self, x, scale, offset, old_mean, old_var, exponential_avg_factor, epsilon, data_format): - if data_format not in ['NHWC', 'NCHW']: - raise ValueError('data_format must be NCHW or NHWC, ' - 'got %s.' % data_format) + if data_format not in ['NHWC', 'NCHW', 'NDHWC', 'NCDHW']: + raise ValueError('data_format must be NCHW or NHWC for 4D tensors or' + 'NCDHW or NDHWC for 5D tensors, got %s.' % data_format) + use_4d_tensor = (x.shape.ndims == 4) if data_format == 'NCHW': x = array_ops.transpose(x, [0, 2, 3, 1]) + elif data_format == 'NCDHW': + x = array_ops.transpose(x, [0, 2, 3, 4, 1]) + + mean_axis = [0, 1, 2] if use_4d_tensor else [0, 1, 2, 3] batch_mean, batch_var = nn_impl.moments( - math_ops.cast(x, scale.dtype), [0, 1, 2], keep_dims=False) + math_ops.cast(x, scale.dtype), mean_axis, keep_dims=False) y = self._batch_norm(x, batch_mean, batch_var, offset, scale, epsilon) if data_format == 'NCHW': y = array_ops.transpose(y, [0, 3, 1, 2]) + elif data_format == 'NCDHW': + y = array_ops.transpose(y, [0, 4, 1, 2, 3]) # This is for Bessel's correction. tf.nn.moments uses n, instead of n-1, as # the denominator in the formula to calculate variance, while @@ -377,14 +388,18 @@ class BatchNormalizationTest(test.TestCase): def _runtests(self, x_shape, is_training, gradient_test=False, cpu_only=False): + if len(x_shape) == 4: + data_format_list = ['NHWC', 'NCHW'] + else: + data_format_list = ['NCDHW', 'NDHWC'] use_gpu_vals = [False] if test.is_gpu_available(cuda_only=True) and not cpu_only: use_gpu_vals += [True] factors = [1.0, 0.6] for dtype in [np.float16, np.float32]: for use_gpu in use_gpu_vals: - for data_format in ['NHWC', 'NCHW']: - if data_format == 'NHWC': + for data_format in data_format_list: + if data_format == 'NHWC' or data_format == 'NDHWC': scale_shape = x_shape[-1:] else: scale_shape = x_shape[1:2] @@ -444,6 +459,10 @@ class BatchNormalizationTest(test.TestCase): # GPU kernel doesn't properly handle case where non-channel dimensions are 1 self._runtests(x_shape, False, cpu_only=True) + def testInferenceShape7(self): + x_shape = [1, 2, 6, 1, 3] + self._runtests(x_shape, False) + def testTrainingShape1(self): x_shape = [1, 1, 6, 1] self._runtests(x_shape, True) @@ -465,11 +484,16 @@ class BatchNormalizationTest(test.TestCase): x_shape = [0, 131, 127, 6] self._runtests(x_shape, True) + @test_util.run_deprecated_v1 def testTrainingShape6(self): x_shape = [1, 1, 1, 1] # GPU kernel doesn't properly handle case where non-channel dimensions are 1 self._runtests(x_shape, True, cpu_only=True) + def testTrainingShape7(self): + x_shape = [1, 2, 6, 1, 3] + self._runtests(x_shape, True) + @test_util.run_deprecated_v1 def testBatchNormGradInferenceShape1(self): x_shape = [1, 1, 6, 1] @@ -503,6 +527,11 @@ class BatchNormalizationTest(test.TestCase): self._runtests(x_shape, is_training=False, gradient_test=True, cpu_only=True) + @test_util.run_deprecated_v1 + def testBatchNormGradInferenceShape7(self): + x_shape = [1, 2, 6, 1, 3] + self._runtests(x_shape, is_training=False, gradient_test=True) + 
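The new Shape7 cases above exercise 5-D inputs. As an illustrative sketch only (not part of the patch; it assumes TF 2.x eager execution with the 5-D support from this change available), the public entry point accepts the new layouts directly:

>>> x = tf.random.normal([2, 3, 4, 4, 8])   # NDHWC, 5-D input
>>> scale, offset = tf.ones([8]), tf.zeros([8])
>>> y, running_mean, running_var = tf.compat.v1.nn.fused_batch_norm(
...     x, scale, offset, data_format='NDHWC', is_training=True)
>>> y.shape
TensorShape([2, 3, 4, 4, 8])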
@test_util.run_deprecated_v1 def testBatchNormGradTrainingShape1(self): x_shape = [1, 1, 6, 1] @@ -535,42 +564,54 @@ class BatchNormalizationTest(test.TestCase): # GPU kernel doesn't properly handle case where non-channel dimensions are 1 self._runtests(x_shape, is_training=True, gradient_test=True, cpu_only=True) + @test_util.run_deprecated_v1 + def testBatchNormGradTrainingShape7(self): + x_shape = [1, 2, 6, 1, 3] + self._runtests(x_shape, is_training=True, gradient_test=True) + def _testBatchNormGradGrad(self, config): shape = config['shape'] err_tolerance = config['err_tolerance'] dtype = config['dtype'] + rank = len(shape) + if rank == 4: + data_format_nhwc, features_nhwc = 'NHWC', shape[3] + data_format_nchw, features_nchw = 'NCHW', shape[1] + else: + data_format_nhwc, features_nhwc = 'NDHWC', shape[4] + data_format_nchw, features_nchw = 'NCDHW', shape[1] for is_training in [True, False]: if test.is_gpu_available(cuda_only=True): self._test_grad_grad( shape, - dtype, [shape[3]], + dtype, [features_nhwc], np.float32, use_gpu=True, - data_format='NHWC', + data_format=data_format_nhwc, is_training=is_training, err_tolerance=err_tolerance) self._test_grad_grad( shape, - dtype, [shape[1]], + dtype, [features_nchw], np.float32, use_gpu=True, - data_format='NCHW', + data_format=data_format_nchw, is_training=is_training, err_tolerance=err_tolerance) self._test_grad_grad( shape, - dtype, [shape[3]], + dtype, [features_nhwc], np.float32, use_gpu=False, - data_format='NHWC', + data_format=data_format_nhwc, is_training=is_training, err_tolerance=err_tolerance) self._test_grad_grad( shape, - dtype, [shape[1]], + dtype, [features_nchw], np.float32, use_gpu=False, - data_format='NCHW', + data_format=data_format_nchw, is_training=is_training, err_tolerance=err_tolerance) @@ -610,6 +651,24 @@ class BatchNormalizationTest(test.TestCase): } self._testBatchNormGradGrad(config) + @test_util.run_deprecated_v1 + def testBatchNormGradGradConfig5(self): + config = { + 'shape': [2, 3, 2, 2, 2], + 'err_tolerance': 2e-3, + 'dtype': np.float32, + } + self._testBatchNormGradGrad(config) + + @test_util.run_deprecated_v1 + def testBatchNormGradGradConfig6(self): + config = { + 'shape': [2, 3, 2, 2, 2], + 'err_tolerance': 3e-3, + 'dtype': np.float16, + } + self._testBatchNormGradGrad(config) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index 58dd1852cc5..a02e31f80a5 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -897,6 +897,11 @@ def _BaseFusedBatchNormGrad(op, version, *grad): if data_format == b"NCHW": x = array_ops.transpose(x, [0, 2, 3, 1]) grad_y = array_ops.transpose(grad_y, [0, 2, 3, 1]) + elif data_format == b"NCDHW": + x = array_ops.transpose(x, [0, 2, 3, 4, 1]) + grad_y = array_ops.transpose(grad_y, [0, 2, 3, 4, 1]) + target_data_format = ("NHWC" if data_format in (b"NCHW", + b"NHWC") else "NDHWC") args = { "y_backprop": grad_y, "x": x, @@ -904,7 +909,7 @@ def _BaseFusedBatchNormGrad(op, version, *grad): "reserve_space_1": pop_mean, "reserve_space_2": pop_var, "epsilon": epsilon, - "data_format": "NHWC", + "data_format": target_data_format, "is_training": is_training } if version == 2: @@ -912,6 +917,8 @@ def _BaseFusedBatchNormGrad(op, version, *grad): dx, dscale, doffset, _, _ = grad_fun(**args) if data_format == b"NCHW": dx = array_ops.transpose(dx, [0, 3, 1, 2]) + elif data_format == b"NCDHW": + dx = array_ops.transpose(dx, [0, 4, 1, 2, 3]) return dx, dscale, doffset, None, None @@ 
-941,8 +948,8 @@ def _BatchNormGrad(grad_y, """Returns the gradients for the 3 inputs of BatchNorm. Args: - grad_y: A `Tensor` of 4 dimensions for gradient for y. - x: A `Tensor` of 4 dimensions for x. + grad_y: A `Tensor` of 4 or 5 dimensions for gradient for y. + x: A `Tensor` of 4 or 5 dimensions for x. scale: A `Tensor` of 1 dimension for scaling. pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when is_training=False. @@ -968,11 +975,19 @@ def _BatchNormGrad(grad_y, if data_format == b"NHWC": keepdims = False reduce_axis = [0, 1, 2] - else: + elif data_format == b"NDHWC": + keepdims = False + reduce_axis = [0, 1, 2, 3] + elif data_format == b"NCHW": keepdims = True reduce_axis = [0, 2, 3] shape = [1, array_ops.size(scale), 1, 1] scale = array_ops.reshape(scale, shape) + else: + keepdims = True + reduce_axis = [0, 2, 3, 4] + shape = [1, array_ops.size(scale), 1, 1, 1] + scale = array_ops.reshape(scale, shape) mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims) mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims) var_x = math_ops.reduce_mean( @@ -987,19 +1002,27 @@ def _BatchNormGrad(grad_y, grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset) grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum( grad_y * x_offset, axis=reduce_axis, keepdims=keepdims) - if data_format == b"NCHW": + if data_format == b"NCHW" or data_format == b"NCDHW": grad_scale = array_ops.squeeze(grad_scale) grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis) return math_ops.cast(grad_x, x_dtype), grad_scale, grad_offset else: if data_format == b"NHWC": reduce_axis = [0, 1, 2] - else: + elif data_format == b"NDHWC": + reduce_axis = [0, 1, 2, 3] + elif data_format == b"NCHW": reduce_axis = [0, 2, 3] shape = [1, array_ops.size(pop_mean), 1, 1] pop_mean = array_ops.reshape(pop_mean, shape) pop_var = array_ops.reshape(pop_var, shape) scale = array_ops.reshape(scale, shape) + else: + reduce_axis = [0, 2, 3, 4] + shape = [1, array_ops.size(pop_mean), 1, 1, 1] + pop_mean = array_ops.reshape(pop_mean, shape) + pop_var = array_ops.reshape(pop_var, shape) + scale = array_ops.reshape(scale, shape) grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis) var_rsqrt = math_ops.rsqrt(pop_var + epsilon) diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 89174b29336..d22fbf3fa4e 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -1585,7 +1585,7 @@ def fused_batch_norm( (http://arxiv.org/abs/1502.03167). Args: - x: Input `Tensor` of 4 dimensions. + x: Input `Tensor` of 4 or 5 dimensions. scale: A `Tensor` of 1 dimension for scaling. offset: A `Tensor` of 1 dimension for bias. mean: A `Tensor` of 1 dimension for population mean. The shape and meaning @@ -1611,7 +1611,8 @@ def fused_batch_norm( Variance must be a `Tensor` of the same shape as scale containing the exponential running variance. epsilon: A small float number added to the variance of x. - data_format: The data format for x. Either "NHWC" (default) or "NCHW". + data_format: The data format for x. Support "NHWC" (default) or "NCHW" for + 4D tenors and "NDHWC" or "NCDHW" for 5D tensors. is_training: A bool value to specify if the operation is used for training or inference. name: A name for this operation (optional). @@ -1622,7 +1623,7 @@ def fused_batch_norm( returned. Returns: - y: A 4D Tensor for the normalized, scaled, offsetted x. + y: A 4D or 5D Tensor for the normalized, scaled, offsetted x. 
running_mean: A 1D Tensor for the exponential running mean of x. The output value is (1 - exponential_avg_factor) * mean + exponential_avg_factor * batch_mean), where batch_mean diff --git a/tensorflow/python/ops/nn_loss_scaling_utilities_test.py b/tensorflow/python/ops/nn_loss_scaling_utilities_test.py index 4f96f9ba6a3..7f150b34bb3 100644 --- a/tensorflow/python/ops/nn_loss_scaling_utilities_test.py +++ b/tensorflow/python/ops/nn_loss_scaling_utilities_test.py @@ -22,6 +22,7 @@ from absl.testing import parameterized from tensorflow.python.distribute import combinations from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute import test_util from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -35,7 +36,7 @@ from tensorflow.python.platform import test as test_lib class LossUtilitiesTest(test_lib.TestCase, parameterized.TestCase): def setUp(self): - strategy_combinations.set_virtual_cpus_to_at_least(3) + test_util.set_logical_devices_to_at_least("CPU", 3) super(LossUtilitiesTest, self).setUp() def testComputeAverageLossGlobalBatchSize(self): diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 9cb23ba72cb..2477fa1e920 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -5402,6 +5402,19 @@ def fractional_max_pool_v2(value, [Graham, 2015](https://arxiv.org/abs/1412.6071) ([pdf](https://arxiv.org/pdf/1412.6071.pdf)) """ + if (isinstance(pooling_ratio, (list, tuple))): + if (pooling_ratio[0] != 1.0 or pooling_ratio[-1] != 1.0): + raise ValueError( + "The first and last elements of pooling ratio must be 1.0.") + for element in pooling_ratio: + if element < 1.0: + raise ValueError("pooling_ratio should be >= 1.0.") + elif (isinstance(pooling_ratio, (int, float))): + if pooling_ratio < 1.0: + raise ValueError("pooling_ratio should be >= 1.0.") + else: + raise ValueError("pooling_ratio should be an int or a list of ints.") + pooling_ratio = _get_sequence(pooling_ratio, 2, 3, "pooling_ratio") if seed == 0: @@ -5603,7 +5616,6 @@ def erosion2d(value, kernel, strides, rates, padding, name=None): Returns: A `Tensor`. Has the same type as `value`. 4-D with shape `[batch, out_height, out_width, depth]`. - Raises: ValueError: If the `value` depth does not match `kernel`' shape, or if padding is other than `'VALID'` or `'SAME'`. 
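To illustrate the pooling_ratio validation added to fractional_max_pool_v2 above (a sketch only, not part of the patch; the input shape is arbitrary): the first and last ratios must be 1.0, and no ratio may drop below 1.0.

>>> x = tf.random.normal([1, 8, 8, 3])
>>> out, rows, cols = tf.nn.fractional_max_pool(
...     x, pooling_ratio=[1.0, 1.44, 1.73, 1.0])   # valid
>>> tf.nn.fractional_max_pool(x, pooling_ratio=[1.0, 0.5, 0.5, 1.0])
Traceback (most recent call last):
  ...
ValueError: pooling_ratio should be >= 1.0.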
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 956f42c5744..851bfcb66de 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -166,8 +166,6 @@ class LogPoissonLossTest(test_lib.TestCase): return lpl @test_util.run_in_graph_and_eager_modes - @test_util.disable_tfrt("b/163084901: does not support convert ScalarHost " - "tensor to RuntimeFallbackTensor.") def testLogPoissonLoss(self): x_shape = [5, 10] x_np = np.random.randn(*x_shape).astype(np.float32) diff --git a/tensorflow/python/ops/numpy_ops/g3doc/TensorFlow_Numpy_Distributed_Image_Classification.ipynb b/tensorflow/python/ops/numpy_ops/g3doc/TensorFlow_Numpy_Distributed_Image_Classification.ipynb index 35665b8e667..92e05a1f267 100644 --- a/tensorflow/python/ops/numpy_ops/g3doc/TensorFlow_Numpy_Distributed_Image_Classification.ipynb +++ b/tensorflow/python/ops/numpy_ops/g3doc/TensorFlow_Numpy_Distributed_Image_Classification.ipynb @@ -3,7 +3,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "KQALG9h23b0R" }, "source": [ @@ -15,8 +14,6 @@ "execution_count": null, "metadata": { "cellView": "both", - "colab": {}, - "colab_type": "code", "id": "U34SJW0W3dg_" }, "outputs": [], @@ -37,7 +34,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "VIX1XZHJ3gFo" }, "source": [ @@ -47,7 +43,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "f7NApJ7R3ndN" }, "source": [ @@ -61,7 +56,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "IYDdfih63rSG" }, "source": [ @@ -74,8 +68,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "3IlLM-YlTMv5" }, "outputs": [], @@ -88,8 +80,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "U13hRXHKTcsE" }, "outputs": [], @@ -98,6 +88,7 @@ "import functools\n", "import matplotlib.pyplot as plt\n", "import os\n", + "import tempfile\n", "import tensorflow as tf\n", "import tensorflow.experimental.numpy as tnp\n", "import tensorflow_datasets as tfds\n", @@ -121,7 +112,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "AxNuZSqZKcdM" }, "source": [ @@ -136,8 +126,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "yKf9Tm5OjwGK" }, "outputs": [], @@ -172,7 +160,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "ZDJQp4i00qaJ" }, "source": [ @@ -185,13 +172,11 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "44yzAmBFreyg" }, "outputs": [], "source": [ - "class Dense(object):\n", + "class Dense(tf.Module):\n", "\n", " def __init__(self, units, use_relu=True):\n", " self.wt = None\n", @@ -230,7 +215,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "wfKpg3adUCy9" }, "source": [ @@ -242,13 +226,11 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "NdrdxKB7SenC" }, "outputs": [], "source": [ - "class Model(object):\n", + "class Model(tf.Module):\n", " \"\"\"A three layer neural network.\"\"\"\n", "\n", " def __init__(self):\n", @@ -269,7 +251,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "Hoxh5Z7E_9Pv" }, "source": [ @@ -282,8 +263,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "hOxqjE7rZPdr" }, 
"outputs": [], @@ -325,7 +304,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "8t70b5d6XCs7" }, "source": [ @@ -336,8 +314,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "HrhS_M6kALeP" }, "outputs": [], @@ -371,7 +347,48 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", + "id": "Dw7RwQmKcYK9" + }, + "source": [ + "#### Saving the models to disk" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rmk2xLQLcXkl" + }, + "outputs": [], + "source": [ + "# A temporary directory to save our models into.\n", + "dir = tempfile.TemporaryDirectory()\n", + "\n", + "# We take our model, and create a wrapper for it.\n", + "class SaveableModel(Model):\n", + " @tf.function\n", + " def __call__(self, inputs):\n", + " return super().__call__(inputs)\n", + "\n", + "saveable_model = SaveableModel()\n", + "\n", + "# This saves a concrete function that we care about.\n", + "outputs = saveable_model(x_test)\n", + "\n", + "# This saves the model to disk.\n", + "tf.saved_model.save(saveable_model, dir.name)\n", + "\n", + "loaded = tf.saved_model.load(dir.name)\n", + "outputs_loaded = loaded(x_test)\n", + "\n", + "# Ensure that the loaded model preserves the weights\n", + "# of the saved model.\n", + "assert tnp.allclose(outputs, outputs_loaded)" + ] + }, + { + "cell_type": "markdown", + "metadata": { "id": "ak_hCOkGXXfl" }, "source": [ @@ -385,7 +402,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "ujbeT5p6Xm7k" }, "source": [ @@ -398,8 +414,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "MZ6hivj-ZIRo" }, "outputs": [], @@ -458,7 +472,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "1ZiN1rpJYHLu" }, "source": [ @@ -469,8 +482,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "A6ZHYmLapunm" }, "outputs": [], diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops.py b/tensorflow/python/ops/numpy_ops/np_array_ops.py index 16417a34f62..369297bd85a 100644 --- a/tensorflow/python/ops/numpy_ops/np_array_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_array_ops.py @@ -1435,7 +1435,7 @@ def take_along_axis(arr, indices, axis): # pylint: disable=missing-docstring indices = indices.data rank = array_ops.rank(arr) - axis = array_ops.where_v2(axis < 0, axis + rank, axis) + axis = axis + rank if axis < 0 else axis # Broadcast shapes to match, ensure that the axis of interest is not # broadcast. @@ -1461,7 +1461,7 @@ def take_along_axis(arr, indices, axis): # pylint: disable=missing-docstring swapaxes_ = lambda t: swapaxes(np_utils.tensor_to_ndarray(t), axis, -1).data - dont_move_axis_to_end = math_ops.equal(axis, rank - 1) + dont_move_axis_to_end = math_ops.equal(axis, np_utils.subtract(rank, 1)) arr = np_utils.cond(dont_move_axis_to_end, lambda: arr, lambda: swapaxes_(arr)) indices = np_utils.cond(dont_move_axis_to_end, lambda: indices, @@ -1730,6 +1730,10 @@ def _slice_helper(tensor, slice_spec, update_method=None, updates=None): if updates is None: return array_ops.gather_nd(tensor, stacked_indices) else: + # We only need to move-axis `updates` in the contiguous case becausce + # only in this case the result dimensions of advanced indexing are in + # the middle of `updates`. In the non-contiguous case, those dimensions + # are always at the front. 
if dims_contiguous: # TODO(wangpeng): Support unknown rank (e.g. by partially flattening # `updates`) @@ -1773,10 +1777,11 @@ def _slice_helper(tensor, slice_spec, update_method=None, updates=None): # do a gather instead. rank = np_utils._maybe_static(array_ops.rank(tensor)) # pylint: disable=protected-access dims = [(x + rank if x < 0 else x) for x in dims] - shape_tensor = array_ops.shape(tensor, out_type=stacked_indices.dtype) + shape_tensor = array_ops.shape(tensor) dim_sizes = array_ops.gather(shape_tensor, dims) if len(dims) == 1: stacked_indices = indices[0] + stacked_indices = math_ops.cast(stacked_indices, dtypes.int32) stacked_indices = array_ops.where_v2(stacked_indices < 0, stacked_indices + dim_sizes, stacked_indices) @@ -1784,8 +1789,12 @@ def _slice_helper(tensor, slice_spec, update_method=None, updates=None): if len(dims) > 1: index_scaling = math_ops.cumprod( dim_sizes, reverse=True, exclusive=True) - stacked_indices = math_ops.tensordot( - stacked_indices, index_scaling, axes=1) + def _tensordot(a, b): + # TODO(b/168657656): This function should be replaced by + # tensordot(axis=1) once MatMul has int32 XLA kernel. + b = array_ops.broadcast_to(b, array_ops.shape(a)) + return math_ops.reduce_sum(a * b, axis=-1) + stacked_indices = _tensordot(stacked_indices, index_scaling) flat_shape = array_ops.concat( [shape_tensor[:axis], [-1], shape_tensor[axis + len(dims):]], axis=0) diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD index b189ac57bb9..119a944e867 100644 --- a/tensorflow/python/ops/parallel_for/BUILD +++ b/tensorflow/python/ops/parallel_for/BUILD @@ -22,6 +22,7 @@ py_library( ":gradients", ":test_util", "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:constant_op", diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py index b60bc210e9b..504574385c6 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py @@ -450,6 +450,13 @@ def vectorized_map(fn, elems, fallback_to_while_loop=True): results of applying fn to tensors unpacked from elems along the first dimension, from first to last. + Although they are less common as user-visible inputs and outputs, note that + tensors of type `tf.variant` which represent tensor lists (for example from + `tf.raw_ops.TensorListFromTensor`) are vectorized by stacking the list + contents rather than the variant itself, and so the container tensor will + have a scalar shape when returned rather than the usual stacked shape. This + improves the performance of control flow gradient vectorization. + Raises: ValueError: If vectorization fails and fallback_to_while_loop is False. 
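The docstring addition above describes how `tf.variant` TensorList inputs and outputs are vectorized. For the common dense case, a minimal illustration (a sketch only, not part of the patch; assumes eager TF 2.x):

>>> elems = tf.random.normal([4, 3])
>>> tf.vectorized_map(lambda v: tf.reduce_sum(v ** 2), elems).shape
TensorShape([4])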
""" diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py index 96d41f3b359..3a0c6cf1a14 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -44,8 +44,10 @@ from tensorflow.python.ops import cond_v2 from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_v2_toggles from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_list_ops from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import gradient_checker_v2 from tensorflow.python.ops import gradients as gradient_ops from tensorflow.python.ops import image_ops from tensorflow.python.ops import list_ops @@ -1073,6 +1075,29 @@ class TensorListTest(PForTestCase): if not v2_enabled: control_flow_v2_toggles.disable_control_flow_v2() + def test_tensor_list_addn_already_stacked(self): + + def loop_fn(i): + l1 = list_ops.tensor_list_reserve([], 2, dtypes.int32) + l1 = list_ops.tensor_list_set_item(l1, 0, i) + l2 = list_ops.tensor_list_reserve([], 2, dtypes.int32) + l2 = list_ops.tensor_list_set_item(l2, 1, i) + return list_ops.tensor_list_stack(math_ops.add_n([l1, l2]), dtypes.int32) + + self._test_loop_fn(loop_fn, 2) + + def test_tensor_list_addn_stacking_required(self): + l1 = list_ops.tensor_list_reserve([], 2, dtypes.int32) + l1 = list_ops.tensor_list_set_item(l1, 1, 1) + + def loop_fn(i): + l2 = list_ops.tensor_list_reserve([], 2, dtypes.int32) + l2 = list_ops.tensor_list_set_item(l2, 1, i) + return list_ops.tensor_list_stack( + math_ops.add_n([l1, l2]), dtypes.int32) + + self._test_loop_fn(loop_fn, 2) + class StackTest(PForTestCase): @@ -1542,6 +1567,23 @@ class WhileV2Test(PForTestCase): y = constant_op.constant(np.random.uniform(size=(3, 3))) self.assertAllClose(_f(x, y, True), _f(x, y, False)) + def test_scan(self): + np.random.seed(seed=42) + data = np.random.randn(3).astype(np.float32) + + def log_prob(x): + return math_ops.reduce_sum(functional_ops.scan_v2( + lambda _, yi: (x - yi)**2, + elems=data, + initializer=constant_op.constant(0.))) + + x = variables.Variable(array_ops.ones([2])) + self.evaluate(x.initializer) + v_log_prob = lambda x: pfor_control_flow_ops.vectorized_map(log_prob, x) + theoretical, numerical = gradient_checker_v2.compute_gradient( + v_log_prob, (x,), delta=1e-3) + self.assertAllClose(theoretical, numerical, rtol=1e-2) + @test_util.run_all_in_graph_and_eager_modes class NestedControlFlowTest(PForTestCase): diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index 3c6b9c0d756..7e460176c61 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -28,6 +28,7 @@ import numpy as np import six from tensorflow.compiler.tf2xla.python import xla +from tensorflow.core.framework import types_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import execute @@ -42,6 +43,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import 
gen_image_ops @@ -59,6 +61,7 @@ from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import special_math_ops from tensorflow.python.ops import tensor_array_ops @@ -75,6 +78,29 @@ flags.DEFINE_bool( "DEPRECATED: Flag is ignored.") +def _variant_handle_data(t): + """Fetches handle data for a variant tensor `t`, or None if unavailable.""" + handle_data = resource_variable_ops.get_eager_safe_handle_data(t) + if not handle_data.is_set: + return None + if len(handle_data.shape_and_type) != 1: + raise ValueError("Expected handle data of length 1, got {!r} of length {}" + .format(handle_data, len(handle_data.shape_and_type))) + return handle_data.shape_and_type[0] + + +def _is_tensor_list(t): + """True if `t` is a TensorList, False if it isn't, None if unknown.""" + if t.dtype != dtypes.variant: + return False + shape_and_type = _variant_handle_data(t) + if shape_and_type is None: + # TODO(b/169968286): Identify all variant tensors (e.g. optionals) and we + # can make this an error instead of assuming TensorLists have handle data. + return None # Presumed not a TensorList + return shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST + + def _stack(t, length): """stacks `t` `length` times.""" # Note that this stacking may currently be triggered, for example, when a @@ -82,14 +108,24 @@ def _stack(t, length): # produces a loop dependent output. Simply stacking the variants may not be # suitable since operations on stacked handles may expect a vectorized version # of the variant. - # Given that variant types are generic, we are currently unable to figure out - # which particular variant type is being considered here and hence it may not - # be safe to allow stacking it. if t.dtype == dtypes.variant: - raise NotImplementedError( - "Vectorization tried to stack variant tensor %s. " - "This is likely because vectorization of that variant " - "is not fully supported yet." % t) + shape_and_type = _variant_handle_data(t) + if shape_and_type is None: + raise ValueError("Required handle data not set for {!r}".format(t)) + if shape_and_type.specialized_type == types_pb2.ST_TENSOR_LIST: + return wrap( + _stack_tensor_list(t, shape_and_type.dtype, length), + True) + else: + if shape_and_type.specialized_type != types_pb2.ST_INVALID: + raise ValueError( + ("Attempted to stack an unhandled variant-dtype tensor of " + "type {!r} ({!r})").format( + shape_and_type.specialized_type, t)) + else: + raise ValueError( + "Attempted to stack a variant-dtype tensor with no type set ({!r})" + .format(t)) ones = array_ops.ones_like(array_ops.shape(t)) ones = array_ops.reshape(ones, [-1]) length = array_ops.reshape(length, [-1]) @@ -735,20 +771,30 @@ class _PforInput(object): self._op = op self._inputs = inputs - def stack_inputs(self, stack_indices=None): + def stack_inputs(self, stack_indices=None, tile_variants=False): """Stacks unstacked inputs at `stack_indices`. Args: stack_indices: indices of inputs at which stacking is done. If None, stacking is done at all indices. + tile_variants: If True, affected indices which have a variant dtype will + be tiled after this operation to match the expected shape of a + vectorized tensor. Variants generally need to be un-tiled when they are + inputs to operations and tiled when returned. 
""" if stack_indices is None: stack_indices = range(len(self._inputs)) length = self.pfor.loop_len_vector for i in stack_indices: inp = self._inputs[i] + is_variant = inp.t.dtype == dtypes.variant if not inp.is_stacked: self._inputs[i] = _stack(inp.t, length) + if tile_variants and is_variant: + self._inputs[i] = wrap( + _tile_variant_with_length(self._inputs[i].t, length), True) + elif not tile_variants and is_variant: + self._inputs[i] = wrap(_untile_variant(self._inputs[i].t), True) def expanddim_inputs_for_broadcast(self): """Reshapes stacked inputs to prepare them for broadcast. @@ -994,6 +1040,12 @@ def wrap(tensor, is_stacked=True, is_sparse_stacked=False): return WrappedTensor(tensor, is_stacked, is_sparse_stacked) +def _wrap_and_tile_variants(tensor, length): + if tensor.dtype == dtypes.variant: + tensor = _tile_variant_with_length(tensor, length) + return wrap(tensor) + + def _fallback_converter(pfor_input, warn=True): if warn: logging.warn("Using a while_loop for converting %s", pfor_input.op_type) @@ -1496,7 +1548,10 @@ class PFor(object): if y is y_op: new_outputs = new_op else: - new_outputs = [wrap(x, False) for x in new_op.outputs] + new_outputs = [] + for old_output, new_output in zip(y_op.outputs, new_op.outputs): + custom_gradient.copy_handle_data(old_output, new_output) + new_outputs.append(wrap(new_output, False)) else: # Either some inputs are not loop invariant or op is stateful. if hasattr(y_op, "pfor_converter"): @@ -1570,7 +1625,10 @@ class PFor(object): else: batch_dim = tensor_shape.TensorShape(loop_len) output_shape = batch_dim.concatenate(output_shape) - new_output.t.set_shape(output_shape) + if _is_tensor_list(new_output.t): + new_output.t.set_shape([]) + else: + new_output.t.set_shape(output_shape) self._add_conversion(old_output, new_output) stack.pop(0) @@ -2888,8 +2946,10 @@ def _convert_rank(pfor_input): @RegisterPFor("AddN") def _convert_addn(pfor_input): # AddN does not support broadcasting. - pfor_input.stack_inputs() - return wrap(math_ops.add_n([x.t for x in pfor_input.inputs]), True) + pfor_input.stack_inputs(tile_variants=False) + return _wrap_and_tile_variants( + math_ops.add_n([x.t for x in pfor_input.inputs]), + pfor_input.pfor.loop_len_vector) @RegisterPFor("Cross") @@ -3518,8 +3578,7 @@ def _convert_tensor_array_grad_v3(pfor_input): return [wrap(grad_handle, False), wrap(flow_out, True)] -def _stack_tensor_list_shape(shape, pfor_input): - first_dim = pfor_input.pfor.loop_len_vector +def _stack_tensor_list_shape(shape, first_dim): shape_value = tensor_util.constant_value(shape) # Note that negative values in the shape are used to signify unknown shapes # and are handled in a special way. @@ -3537,15 +3596,35 @@ def _stack_tensor_list_shape(shape, pfor_input): lambda: array_ops.concat([first_dim, shape], axis=0)) -def _tile_variant(t, pfor_input): +def _tile_variant_with_length(t, length): """stacks `t` `length` times.""" + if _is_tensor_list(t): + # The content of TensorLists is vectorized, not the variant itself. + return t + original_tensor = t t.set_shape([]) t = array_ops.reshape(t, [-1]) with ops.device("CPU:0"): - return array_ops.tile(t, pfor_input.pfor.loop_len_vector) + result = array_ops.tile(t, length) + # TODO(b/169968286): Should regular shape functions do handle data + # propagation here? 
+ custom_gradient.copy_handle_data(original_tensor, result) + return result + + +def _tile_variant(t, pfor_input): + """stacks `t` according to its loop context.""" + return _tile_variant_with_length(t, pfor_input.pfor.loop_len_vector) def _untile_variant(t): + if _is_tensor_list(t): + # The content of TensorLists is vectorized, not the variant itself. + if not t.shape.is_compatible_with([]): + raise AssertionError( + "Unexpectedly saw a TensorList with non-scalar shape: {!r}" + .format(t)) + return t return array_ops.gather(t, 0) @@ -3556,7 +3635,8 @@ def _convert_tensor_list_reserve(pfor_input): element_dtype = pfor_input.get_attr("element_dtype") # Prepend a dimension to element_shape. - element_shape = _stack_tensor_list_shape(element_shape, pfor_input) + element_shape = _stack_tensor_list_shape(element_shape, + pfor_input.pfor.loop_len_vector) handle = list_ops.tensor_list_reserve( element_shape, num_elements, element_dtype=element_dtype) @@ -3579,16 +3659,16 @@ def _convert_tensor_list_length(pfor_input): return wrap(list_ops.tensor_list_length(handle), False) -def _stack_tensor_list(handle, dtype, pfor_input, element_shape=None): +def _stack_tensor_list(handle, dtype, loop_len_vector, element_shape=None): if element_shape is None: element_shape = list_ops.tensor_list_element_shape(handle, dtypes.int32) length = list_ops.tensor_list_length(handle) new_handle = list_ops.tensor_list_reserve( - _stack_tensor_list_shape(element_shape, pfor_input), length, dtype) + _stack_tensor_list_shape(element_shape, loop_len_vector), length, dtype) def _body_fn(i, h): elem = list_ops.tensor_list_get_item(handle, i, dtype, element_shape) - elem = _stack(elem, pfor_input.pfor.loop_len_vector).t + elem = _stack(elem, loop_len_vector).t return i + 1, list_ops.tensor_list_set_item(h, i, elem) return control_flow_ops.while_loop(lambda i, _: i < length, _body_fn, @@ -3604,7 +3684,8 @@ def _convert_tensor_list_get_item(pfor_input): if handle_stacked: handle = _untile_variant(handle) - element_shape = _stack_tensor_list_shape(element_shape, pfor_input) + element_shape = _stack_tensor_list_shape(element_shape, + pfor_input.pfor.loop_len_vector) if index_stacked: # We use a sequential loop since that may be more efficient than first # gathering and concatenating all the element corresponding to `index`, @@ -3650,7 +3731,8 @@ def _convert_tensor_array_set_item(pfor_input): return wrap( list_ops.tensor_list_scatter(item, index, input_handle=handle), False) else: - handle = _stack_tensor_list(handle, item.dtype, pfor_input) + handle = _stack_tensor_list(handle, item.dtype, + pfor_input.pfor.loop_len_vector) else: handle = _untile_variant(handle) @@ -3714,7 +3796,8 @@ def _convert_tensor_list_stack(pfor_input): num_elements = pfor_input.get_attr("num_elements") handle = _untile_variant(handle) - input_shape = _stack_tensor_list_shape(input_shape, pfor_input) + input_shape = _stack_tensor_list_shape(input_shape, + pfor_input.pfor.loop_len_vector) output = list_ops.tensor_list_stack( handle, element_dtype, @@ -3733,7 +3816,8 @@ def _convert_tensor_list_gather(pfor_input): if handle_stacked: handle = _untile_variant(handle) - element_shape = _stack_tensor_list_shape(element_shape, pfor_input) + element_shape = _stack_tensor_list_shape(element_shape, + pfor_input.pfor.loop_len_vector) if index_stacked: # We use a sequential loop since that may be more efficient than first # gathering and concatenating all the element corresponding to `index`, @@ -3776,7 +3860,8 @@ def _convert_tensor_list_scatter(pfor_input): if 
handle_stacked: handle = _untile_variant(handle) else: - handle = _stack_tensor_list(handle, item.dtype, pfor_input) + handle = _stack_tensor_list(handle, item.dtype, + pfor_input.pfor.loop_len_vector) item = _transpose_first_two_dims(item) handle = list_ops.tensor_list_scatter(item, indices, input_handle=handle) @@ -3788,7 +3873,8 @@ def _convert_tensor_list_from_tensor(pfor_input): tensor = pfor_input.stacked_input(0) element_shape = pfor_input.unstacked_input(1) tensor = _transpose_first_two_dims(tensor) - element_shape = _stack_tensor_list_shape(element_shape, pfor_input) + element_shape = _stack_tensor_list_shape(element_shape, + pfor_input.pfor.loop_len_vector) handle = list_ops.tensor_list_from_tensor(tensor, element_shape) return wrap(_tile_variant(handle, pfor_input), True) @@ -4147,8 +4233,12 @@ class WhileV2(object): shapes = [tensor_shape.TensorShape(shape) for shape in shapes] for i, shape in enumerate(shapes): shape = shape.merge_with(output_shapes[i]) - if self._pfor_input.input(i).is_stacked: - shape = tensor_shape.TensorShape([None]).concatenate(shape) + pfor_input = self._pfor_input.input(i) + if pfor_input.is_stacked: + if _is_tensor_list(pfor_input.t): + shape = tensor_shape.TensorShape([]).concatenate(shape) + else: + shape = tensor_shape.TensorShape([None]).concatenate(shape) output_shapes[i] = shape assert len(output_shapes) == self._pfor_input.num_inputs return output_shapes @@ -4258,7 +4348,12 @@ class WhileV2(object): if out.is_stacked != inp.is_stacked: stacking_mismatch = True mismatching_stacked_indices.append(i) - wrapped_inputs[i] = _stack(inp.t, [array_ops.size(new_indices)]) + stacked = _stack(inp.t, [array_ops.size(new_indices)]) + if inp.t.dtype == dtypes.variant: + stacked = wrap( + _tile_variant_with_length(stacked.t, + [array_ops.size(new_indices)])) + wrapped_inputs[i] = stacked if not stacking_mismatch: if mismatching_stacked_indices: # We needed to stack some inputs. This code will be abandoned and @@ -4413,7 +4508,8 @@ class WhileV2(object): _ = while_fn.get_concrete_function() if indices_to_stack: # Need to abandon the current conversion, stack some inputs and restart. - self._pfor_input.stack_inputs(stack_indices=indices_to_stack) + self._pfor_input.stack_inputs( + stack_indices=indices_to_stack, tile_variants=True) # Note that this call will recurse at most one time. 
The first call will # do the required stacking, based on the iterative procedure in # _process_body, and the next invocation to __call__ should not need to do diff --git a/tensorflow/python/ops/ragged/ragged_dispatch.py b/tensorflow/python/ops/ragged/ragged_dispatch.py index 5ec307cd0ed..312a5bada35 100644 --- a/tensorflow/python/ops/ragged/ragged_dispatch.py +++ b/tensorflow/python/ops/ragged/ragged_dispatch.py @@ -538,8 +538,9 @@ def register_dispatchers(): _, undecorated_op = tf_decorator.unwrap(op) if not hasattr(undecorated_op, tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names): - raise AssertionError('Expected %s to be an exported symbol ' - '(while adding a RaggedTensor dispatcher)') + raise AssertionError('Expected %r to be an exported symbol ' + '(while adding a RaggedTensor dispatcher)' + % (undecorated_op,)) for op in _UNARY_ELEMENTWISE_OPS: UnaryRaggedElementwiseDispatcher(op).register(op) diff --git a/tensorflow/python/ops/ragged/ragged_math_ops.py b/tensorflow/python/ops/ragged/ragged_math_ops.py index 73a53583ada..60d608b381e 100644 --- a/tensorflow/python/ops/ragged/ragged_math_ops.py +++ b/tensorflow/python/ops/ragged/ragged_math_ops.py @@ -40,8 +40,12 @@ from tensorflow.python.util.tf_export import tf_export # pylint: disable=redefined-builtin @tf_export('ragged.range') @dispatch.add_dispatch_support -def range(starts, limits=None, deltas=1, dtype=None, - name=None, row_splits_dtype=dtypes.int64): +def range(starts, + limits=None, + deltas=1, + dtype=None, + name=None, + row_splits_dtype=dtypes.int64): """Returns a `RaggedTensor` containing the specified sequences of numbers. Each row of the returned `RaggedTensor` contains a single sequence: @@ -104,9 +108,8 @@ def range(starts, limits=None, deltas=1, dtype=None, result = gen_ragged_math_ops.ragged_range( starts, limits, deltas, Tsplits=row_splits_dtype, name=name) - return ragged_tensor.RaggedTensor.from_row_splits(result.rt_dense_values, - result.rt_nested_splits, - validate=False) + return ragged_tensor.RaggedTensor.from_row_splits( + result.rt_dense_values, result.rt_nested_splits, validate=False) def _infer_matching_dtype(tensors, dtype_hierarchy): @@ -118,7 +121,6 @@ def _infer_matching_dtype(tensors, dtype_hierarchy): ops.no_gradient('RaggedRange') - #=============================================================================== # ragged_segment_ #=============================================================================== @@ -181,8 +183,8 @@ def _ragged_segment_aggregate(unsorted_segment_op, `int32`. `segment_ids.shape` must be a prefix of `data.shape`. `segment_ids` is not required to be sorted. num_segments: An `int32` or `int64` scalar. - separator: An optional string. Defaults to None. The separator to - use when joining. Only used for string types. + separator: An optional string. Defaults to None. The separator to use when + joining. Only used for string types. name: A name prefix for the returned tensor (optional). 
Returns: @@ -261,38 +263,42 @@ def _ragged_segment_aggregate(unsorted_segment_op, def segment_sum(data, segment_ids, num_segments, name=None): # For docs, see: _RAGGED_SEGMENT_DOCSTRING - return _ragged_segment_aggregate(math_ops.unsorted_segment_sum, - data=data, - segment_ids=segment_ids, - num_segments=num_segments, - name=(name or'RaggedSegmentSum')) + return _ragged_segment_aggregate( + math_ops.unsorted_segment_sum, + data=data, + segment_ids=segment_ids, + num_segments=num_segments, + name=(name or 'RaggedSegmentSum')) def segment_prod(data, segment_ids, num_segments, name=None): # For docs, see: _RAGGED_SEGMENT_DOCSTRING - return _ragged_segment_aggregate(math_ops.unsorted_segment_prod, - data=data, - segment_ids=segment_ids, - num_segments=num_segments, - name=(name or 'RaggedSegmentProd')) + return _ragged_segment_aggregate( + math_ops.unsorted_segment_prod, + data=data, + segment_ids=segment_ids, + num_segments=num_segments, + name=(name or 'RaggedSegmentProd')) def segment_min(data, segment_ids, num_segments, name=None): # For docs, see: _RAGGED_SEGMENT_DOCSTRING - return _ragged_segment_aggregate(math_ops.unsorted_segment_min, - data=data, - segment_ids=segment_ids, - num_segments=num_segments, - name=(name or 'RaggedSegmentMin')) + return _ragged_segment_aggregate( + math_ops.unsorted_segment_min, + data=data, + segment_ids=segment_ids, + num_segments=num_segments, + name=(name or 'RaggedSegmentMin')) def segment_max(data, segment_ids, num_segments, name=None): # For docs, see: _RAGGED_SEGMENT_DOCSTRING - return _ragged_segment_aggregate(math_ops.unsorted_segment_max, - data=data, - segment_ids=segment_ids, - num_segments=num_segments, - name=(name or 'RaggedSegmentMax')) + return _ragged_segment_aggregate( + math_ops.unsorted_segment_max, + data=data, + segment_ids=segment_ids, + num_segments=num_segments, + name=(name or 'RaggedSegmentMax')) def segment_mean(data, segment_ids, num_segments, name=None): @@ -301,7 +307,8 @@ def segment_mean(data, segment_ids, num_segments, name=None): [data, segment_ids, num_segments]): total = segment_sum(data, segment_ids, num_segments) ones = ragged_tensor.RaggedTensor.from_nested_row_splits( - array_ops.ones_like(data.flat_values), data.nested_row_splits, + array_ops.ones_like(data.flat_values), + data.nested_row_splits, validate=False) count = segment_sum(ones, segment_ids, num_segments) if ragged_tensor.is_ragged(total): @@ -316,12 +323,13 @@ def segment_sqrt_n(data, segment_ids, num_segments, name=None): [data, segment_ids, num_segments]): total = segment_sum(data, segment_ids, num_segments) ones = ragged_tensor.RaggedTensor.from_nested_row_splits( - array_ops.ones_like(data.flat_values), data.nested_row_splits, + array_ops.ones_like(data.flat_values), + data.nested_row_splits, validate=False) count = segment_sum(ones, segment_ids, num_segments) if ragged_tensor.is_ragged(total): - return total.with_flat_values( - total.flat_values / math_ops.sqrt(count.flat_values)) + return total.with_flat_values(total.flat_values / + math_ops.sqrt(count.flat_values)) else: return total / math_ops.sqrt(count) @@ -455,8 +463,8 @@ def ragged_reduce_aggregate(reduce_op, range `[0, rt_input.rank)`. keepdims: If true, retains reduced dimensions with length 1. separator: An optional string. Defaults to None. The separator to use when - joining. The separator must not be set for non-string data types. (i.e. - if separator is not None then it uses string ops) + joining. The separator must not be set for non-string data types. (i.e. 
if + separator is not None then it uses string ops) name: A name prefix for the returned tensor (optional). Returns: @@ -470,14 +478,12 @@ def ragged_reduce_aggregate(reduce_op, """ if not ragged_tensor.is_ragged(rt_input): if separator is None: - return reduce_op(rt_input, axis, name=name) + return reduce_op(rt_input, axis, keepdims=keepdims, name=name) else: # When separator is not None, We infer that dtype is string and # reduce_join will be called. - return reduce_op(rt_input, axis, name=name, separator=separator) - - if keepdims: - raise ValueError('keepdims=True is not supported for RaggedTensors.') + return reduce_op( + rt_input, axis, keepdims=keepdims, name=name, separator=separator) if isinstance(axis, ops.Tensor): axis = tensor_util.constant_value(axis) @@ -488,7 +494,12 @@ def ragged_reduce_aggregate(reduce_op, # When reducing all axes, just ignore splits & reduce the inner values. if axis is None: - return reduce_op(rt_input.flat_values, None, name=name) + result = reduce_op(rt_input.flat_values, None, keepdims=keepdims, name=name) + if keepdims: + # Expand the result to the input number of dimensions. + for _ in rt_input.shape[1:]: + result = array_ops.expand_dims(result, axis=0) + return result with ops.name_scope(name, 'RaggedReduce', [rt_input, axis]): if isinstance(axis, (tuple, list)): @@ -529,15 +540,21 @@ def ragged_reduce_aggregate(reduce_op, row_lengths = rt_input.row_splits[1:] - rt_input.row_splits[:-1] num_segments = math_ops.maximum(math_ops.reduce_max(row_lengths), 0) segment_ids = range(row_lengths).values - return _ragged_segment_aggregate(unsorted_segment_op, rt_input.values, - segment_ids, num_segments, separator) + result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values, + segment_ids, num_segments, separator) + if keepdims: + result = array_ops.expand_dims(result, axis=0) + return result elif axis == 1: # out[i_0, i_1, i_2, ..., i_N] = sum_{j} rt_input[i_0, j, i_2, ..., i_N] num_segments = array_ops.shape(rt_input.row_splits)[0] - 1 segment_ids = segment_id_ops.row_splits_to_segment_ids( rt_input.row_splits) - return _ragged_segment_aggregate(unsorted_segment_op, rt_input.values, - segment_ids, num_segments, separator) + result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values, + segment_ids, num_segments, separator) + if keepdims: + result = array_ops.expand_dims(result, axis=1) + return result else: # out[i_0, ..., i_[axis-1], i_axis+1], ..., i_N] = # sum_{j} rt_input [i_0, ..., i_[axis-1], j, i_axis+1], ..., i_N] @@ -554,7 +571,8 @@ def reduce_sum(input_tensor, axis=None, keepdims=None, name=None): reduce_op=math_ops.reduce_sum, unsorted_segment_op=math_ops.unsorted_segment_sum, rt_input=input_tensor, - axis=axis, keepdims=keepdims, + axis=axis, + keepdims=keepdims, name=(name or 'RaggedReduceSum')) @@ -598,13 +616,15 @@ def reduce_mean(input_tensor, axis=None, keepdims=None, name=None): if ragged_tensor.is_ragged(input_tensor): ones = ragged_tensor.RaggedTensor.from_nested_row_splits( array_ops.ones_like(input_tensor.flat_values), - input_tensor.nested_row_splits, validate=False) + input_tensor.nested_row_splits, + validate=False) else: ones = array_ops.ones_like(input_tensor) count = reduce_sum(ones, axis, keepdims) if ragged_tensor.is_ragged(total): return ragged_tensor.RaggedTensor.from_nested_row_splits( - total.flat_values / count.flat_values, total.nested_row_splits, + total.flat_values / count.flat_values, + total.nested_row_splits, validate=False) else: return total / count diff --git 
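The `keepdims` handling added to `ragged_reduce_aggregate` above replaces the old `ValueError('keepdims=True is not supported for RaggedTensors.')` with re-expansion of the reduced axis. A minimal sketch of the resulting behavior, not part of the diff, assuming `tf.reduce_sum` dispatches to the ragged implementation registered in ragged_dispatch.py; the expected values mirror the test cases added in the next hunk:

```python
import tensorflow as tf

rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])

# Without keepdims, the reduced axis is dropped.
print(tf.reduce_sum(rt, axis=1))                     # [8, 6, 9, 8]
# With keepdims=True, the reduced axis is kept with length 1.
print(tf.reduce_sum(rt, axis=1, keepdims=True))      # [[8], [6], [9], [8]]
print(tf.reduce_sum(rt, axis=0, keepdims=True))      # [[15, 12, 4]]
# Reducing all axes with keepdims re-expands to the input rank.
print(tf.reduce_sum(rt, axis=None, keepdims=True))   # [[31]]
```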
a/tensorflow/python/ops/ragged/ragged_reduce_op_test.py b/tensorflow/python/ops/ragged/ragged_reduce_op_test.py index a39090fa3a2..2afe0867ad7 100644 --- a/tensorflow/python/ops/ragged/ragged_reduce_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_reduce_op_test.py @@ -40,8 +40,7 @@ def mean(*values): @test_util.run_all_in_graph_and_eager_modes -class RaggedReduceOpsTest(test_util.TensorFlowTestCase, - parameterized.TestCase): +class RaggedReduceOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): @parameterized.parameters( #========================================================================= @@ -51,92 +50,212 @@ class RaggedReduceOpsTest(test_util.TensorFlowTestCase, # [9, ], # [2, 6 ]] #========================================================================= + # keepdims=True dict( ragged_reduce_op=ragged_math_ops.reduce_sum, rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], axis=0, + keepdims=False, expected=[15, 12, 4] # = [3+1+9+2, 1+5+6, 4] ), dict( ragged_reduce_op=ragged_math_ops.reduce_sum, rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], axis=-2, + keepdims=False, expected=[15, 12, 4] # = [3+1+9+2, 1+5+6, 4] ), dict( ragged_reduce_op=ragged_math_ops.reduce_sum, rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], axis=1, + keepdims=False, expected=[8, 6, 9, 8] # = [3+1+4, 1+5, 9, 2+6] ), dict( ragged_reduce_op=ragged_math_ops.reduce_sum, rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], axis=-1, + keepdims=False, expected=[8, 6, 9, 8] # = [3+1+4, 1+5, 9, 2+6] ), dict( ragged_reduce_op=ragged_math_ops.reduce_prod, rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], axis=0, + keepdims=False, expected=[54, 30, 4] # = [3*1*9*2, 1*5*6, 4] ), dict( ragged_reduce_op=ragged_math_ops.reduce_prod, rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], axis=1, + keepdims=False, expected=[12, 5, 9, 12] # = [3*1*4, 1*5, 9, 2*6] ), dict( ragged_reduce_op=ragged_math_ops.reduce_min, rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], axis=0, + keepdims=False, expected=[1, 1, 4] # = [min(3, 1, 9, 2), min(1, 5, 6), 4] ), dict( ragged_reduce_op=ragged_math_ops.reduce_min, rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], axis=1, + keepdims=False, expected=[1, 1, 9, 2] # = [min(3, 1, 4), min(1, 5), 9, min(2, 6)] ), dict( ragged_reduce_op=ragged_math_ops.reduce_max, rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], axis=0, + keepdims=False, expected=[9, 6, 4] # = [max(3, 1, 9, 2), max(1, 5, 6), 4] ), dict( ragged_reduce_op=ragged_math_ops.reduce_max, rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], axis=1, + keepdims=False, expected=[4, 5, 9, 6] # = [max(3, 1, 4), max(1, 5), 9, max(2, 6)] ), dict( ragged_reduce_op=ragged_math_ops.reduce_mean, rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], axis=0, + keepdims=False, expected=[3.75, 4, 4] # = [mean(3, 1, 9, 2), mean(1, 5, 6), 4] ), dict( ragged_reduce_op=ragged_math_ops.reduce_any, rt_input=[[True, True], [True, True, False, True], [False, True]], axis=0, + keepdims=False, expected=[True, True, False, True]), dict( ragged_reduce_op=ragged_math_ops.reduce_any, rt_input=[[True, True], [True, True, False, True], [False, True]], axis=1, + keepdims=False, expected=[True, True, True]), dict( ragged_reduce_op=ragged_math_ops.reduce_all, rt_input=[[True, True], [True, True, False, True], [False, True]], axis=0, + keepdims=False, expected=[False, True, False, True]), dict( ragged_reduce_op=ragged_math_ops.reduce_all, rt_input=[[True, True], [True, True, False, True], [False, True]], axis=1, + keepdims=False, expected=[True, False, False]), + # keepdims=True + dict( + ragged_reduce_op=ragged_math_ops.reduce_sum, + 
rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], + axis=0, + keepdims=True, + expected=[[15, 12, 4]] # = [[3+1+9+2, 1+5+6, 4]] + ), + dict( + ragged_reduce_op=ragged_math_ops.reduce_sum, + rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], + axis=-2, + keepdims=True, + expected=[[15, 12, 4]] # = [[3+1+9+2, 1+5+6, 4]] + ), + dict( + ragged_reduce_op=ragged_math_ops.reduce_sum, + rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], + axis=1, + keepdims=True, + expected=[[8], [6], [9], [8]] # = [[3+1+4], [1+5], [9], [2+6]] + ), + dict( + ragged_reduce_op=ragged_math_ops.reduce_sum, + rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], + axis=-1, + keepdims=True, + expected=[[8], [6], [9], [8]] # = [[3+1+4], [1+5], [9], [2+6]] + ), + dict( + ragged_reduce_op=ragged_math_ops.reduce_prod, + rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], + axis=0, + keepdims=True, + expected=[[54, 30, 4]] # = [[3*1*9*2, 1*5*6, 4]] + ), + dict( + ragged_reduce_op=ragged_math_ops.reduce_prod, + rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], + axis=1, + keepdims=True, + expected=[[12], [5], [9], [12]] # = [[3*1*4], [1*5], [9], [2*6]] + ), + dict( + ragged_reduce_op=ragged_math_ops.reduce_min, + rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], + axis=0, + keepdims=True, + expected=[[1, 1, 4]] # = [[min(3, 1, 9, 2), min(1, 5, 6), 4]] + ), + dict( + ragged_reduce_op=ragged_math_ops.reduce_min, + rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], + axis=1, + keepdims=True, + expected=[[1], [1], [9], + [2]] # = [[min(3, 1, 4)], [min(1, 5)], [9], [min(2, 6)]] + ), + dict( + ragged_reduce_op=ragged_math_ops.reduce_max, + rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], + axis=0, + keepdims=True, + expected=[[9, 6, 4]] # = [[max(3, 1, 9, 2), max(1, 5, 6), 4]] + ), + dict( + ragged_reduce_op=ragged_math_ops.reduce_max, + rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], + axis=1, + keepdims=True, + expected=[[4], [5], [9], + [6]] # = [[max(3, 1, 4)], [max(1, 5)], [9], [max(2, 6)]] + ), + dict( + ragged_reduce_op=ragged_math_ops.reduce_mean, + rt_input=[[3, 1, 4], [1, 5], [9], [2, 6]], + axis=0, + keepdims=True, + expected=[[3.75, 4, 4]] # = [[mean(3, 1, 9, 2), mean(1, 5, 6), 4]] + ), + dict( + ragged_reduce_op=ragged_math_ops.reduce_any, + rt_input=[[True, True], [True, True, False, True], [False, True]], + axis=0, + keepdims=True, + expected=[[True, True, False, True]]), + dict( + ragged_reduce_op=ragged_math_ops.reduce_any, + rt_input=[[True, True], [True, True, False, True], [False, True]], + axis=1, + keepdims=True, + expected=[[True], [True], [True]]), + dict( + ragged_reduce_op=ragged_math_ops.reduce_all, + rt_input=[[True, True], [True, True, False, True], [False, True]], + axis=0, + keepdims=True, + expected=[[False, True, False, True]]), + dict( + ragged_reduce_op=ragged_math_ops.reduce_all, + rt_input=[[True, True], [True, True, False, True], [False, True]], + axis=1, + keepdims=True, + expected=[[True], [False], [False]]), #========================================================================= # Examples with the following RaggedTensor (ragged_rank=1): @@ -153,52 +272,62 @@ class RaggedReduceOpsTest(test_util.TensorFlowTestCase, ragged_reduce_op=ragged_math_ops.reduce_sum, rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]], axis=None, + keepdims=False, expected=0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9), dict( ragged_reduce_op=ragged_math_ops.reduce_prod, rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]], axis=None, + keepdims=False, expected=0 * 1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9), dict( ragged_reduce_op=ragged_math_ops.reduce_min, rt_input=[[0, 1, 2, 3], [4], [], [5, 6], 
[7], [8, 9]], axis=None, + keepdims=False, expected=min(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)), dict( ragged_reduce_op=ragged_math_ops.reduce_max, rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]], axis=None, + keepdims=False, expected=max(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)), dict( ragged_reduce_op=ragged_math_ops.reduce_mean, rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]], axis=None, + keepdims=False, expected=mean(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)), # axis=0 dict( ragged_reduce_op=ragged_math_ops.reduce_sum, rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]], axis=0, + keepdims=False, expected=[0 + 4 + 5 + 7 + 8, 1 + 6 + 9, 2, 3]), dict( ragged_reduce_op=ragged_math_ops.reduce_prod, rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]], axis=0, + keepdims=False, expected=[0 * 4 * 5 * 7 * 8, 1 * 6 * 9, 2, 3]), dict( ragged_reduce_op=ragged_math_ops.reduce_min, rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]], axis=0, + keepdims=False, expected=[min(0, 4, 5, 7, 8), min(1, 6, 9), 2, 3]), dict( ragged_reduce_op=ragged_math_ops.reduce_max, rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]], axis=0, + keepdims=False, expected=[max(0, 4, 5, 7, 8), max(1, 6, 9), 2, 3]), dict( ragged_reduce_op=ragged_math_ops.reduce_mean, rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]], axis=0, + keepdims=False, expected=[mean(0, 4, 5, 7, 8), mean(1, 6, 9), 2, 3]), # axis=1 @@ -208,16 +337,19 @@ class RaggedReduceOpsTest(test_util.TensorFlowTestCase, ragged_reduce_op=ragged_math_ops.reduce_sum, rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]], axis=1, + keepdims=False, expected=[0 + 1 + 2 + 3, 4, 0, 5 + 6, 7, 8 + 9]), dict( ragged_reduce_op=ragged_math_ops.reduce_prod, rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]], axis=1, + keepdims=False, expected=[0 * 1 * 2 * 3, 4, 1, 5 * 6, 7, 8 * 9]), dict( ragged_reduce_op=ragged_math_ops.reduce_min, rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]], axis=1, + keepdims=False, expected=[min(0, 1, 2, 3), 4, _MAX_INT32, min(5, 6), 7, min(8, 9)]), @@ -225,6 +357,7 @@ class RaggedReduceOpsTest(test_util.TensorFlowTestCase, ragged_reduce_op=ragged_math_ops.reduce_max, rt_input=[[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]], axis=1, + keepdims=False, expected=[max(0, 1, 2, 3), 4, _MIN_INT32, max(5, 6), 7, max(8, 9)]), @@ -236,51 +369,117 @@ class RaggedReduceOpsTest(test_util.TensorFlowTestCase, # [ ], # [[9 ] ]] #========================================================================= + # keepdims=False dict( ragged_reduce_op=ragged_math_ops.reduce_sum, rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], axis=[], + keepdims=False, expected=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]]), dict( ragged_reduce_op=ragged_math_ops.reduce_sum, rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], axis=None, + keepdims=False, expected=sum([1, 2, 3, 4, 5, 6, 7, 8, 9])), dict( ragged_reduce_op=ragged_math_ops.reduce_sum, rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], axis=0, + keepdims=False, expected=[[1 + 6 + 9, 2 + 7], [], [3 + 8, 4, 5]]), dict( ragged_reduce_op=ragged_math_ops.reduce_sum, rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], axis=1, + keepdims=False, expected=[[1 + 3, 2 + 4, 5], [6 + 8, 7], [], [9]]), dict( ragged_reduce_op=ragged_math_ops.reduce_sum, rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], axis=2, + keepdims=False, expected=[[1 + 2, 0, 3 + 4 + 5], [6 + 7, 0, 8], [], [9]]), dict( ragged_reduce_op=ragged_math_ops.reduce_sum, rt_input=[[[1, 2], [], [3, 
4, 5]], [[6, 7], [], [8]], [], [[9]]], axis=[0, 1], + keepdims=False, expected=[1 + 3 + 6 + 8 + 9, 2 + 4 + 7, 5]), dict( ragged_reduce_op=ragged_math_ops.reduce_sum, rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], axis=[0, 2], + keepdims=False, expected=[1 + 6 + 9 + 2 + 7, 0, 3 + 8 + 4 + 5]), dict( ragged_reduce_op=ragged_math_ops.reduce_sum, rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], axis=[1, 2], + keepdims=False, expected=[1 + 2 + 3 + 4 + 5, 6 + 7 + 8, 0, 9]), dict( ragged_reduce_op=ragged_math_ops.reduce_sum, rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], axis=[0, 1, 2], + keepdims=False, expected=sum([1, 2, 3, 4, 5, 6, 7, 8, 9])), + # keepdims=True + dict( + ragged_reduce_op=ragged_math_ops.reduce_sum, + rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], + axis=[], + keepdims=True, + expected=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]]), + dict( + ragged_reduce_op=ragged_math_ops.reduce_sum, + rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], + axis=None, + keepdims=True, + expected=[[[sum([1, 2, 3, 4, 5, 6, 7, 8, 9])]]]), + dict( + ragged_reduce_op=ragged_math_ops.reduce_sum, + rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], + axis=0, + keepdims=True, + expected=[[[1 + 6 + 9, 2 + 7], [], [3 + 8, 4, 5]]]), + dict( + ragged_reduce_op=ragged_math_ops.reduce_sum, + rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], + axis=1, + keepdims=True, + expected=[[[1 + 3, 2 + 4, 5]], [[6 + 8, 7]], [[]], [[9]]]), + dict( + ragged_reduce_op=ragged_math_ops.reduce_sum, + rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], + axis=2, + keepdims=True, + expected=[[[1 + 2], [0], [3 + 4 + 5]], [[6 + 7], [0], [8]], [], + [[9]]]), + dict( + ragged_reduce_op=ragged_math_ops.reduce_sum, + rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], + axis=[0, 1], + keepdims=True, + expected=[[[1 + 3 + 6 + 8 + 9, 2 + 4 + 7, 5]]]), + dict( + ragged_reduce_op=ragged_math_ops.reduce_sum, + rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], + axis=[0, 2], + keepdims=True, + expected=[[[1 + 6 + 9 + 2 + 7], [0], [3 + 8 + 4 + 5]]]), + dict( + ragged_reduce_op=ragged_math_ops.reduce_sum, + rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], + axis=[1, 2], + keepdims=True, + expected=[[[1 + 2 + 3 + 4 + 5]], [[6 + 7 + 8]], [[0]], [[9]]]), + dict( + ragged_reduce_op=ragged_math_ops.reduce_sum, + rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], + axis=[0, 1, 2], + keepdims=True, + expected=[[[sum([1, 2, 3, 4, 5, 6, 7, 8, 9])]]]), #========================================================================= # Examples for ragged_reduce_mean ragged_rank=2: @@ -292,16 +491,19 @@ class RaggedReduceOpsTest(test_util.TensorFlowTestCase, ragged_reduce_op=ragged_math_ops.reduce_mean, rt_input=[[[1, 2], [3, 4, 5]], [[6, 7], [8]], [[9]]], axis=0, + keepdims=False, expected=[[mean(1, 6, 9), mean(2, 7)], [mean(3, 8), 4, 5]]), dict( ragged_reduce_op=ragged_math_ops.reduce_mean, rt_input=[[[1, 2], [3, 4, 5]], [[6, 7], [8]], [[9]]], axis=1, + keepdims=False, expected=[[mean(1, 3), mean(2, 4), 5], [mean(6, 8), 7], [9]]), dict( ragged_reduce_op=ragged_math_ops.reduce_mean, rt_input=[[[1, 2], [3, 4, 5]], [[6, 7], [8]], [[9]]], axis=2, + keepdims=False, expected=[[mean(1, 2), mean(3, 4, 5)], [mean(6, 7), 8], [9]]), # Test case for GitHub issue 27497, multiple negative axes. 
@@ -309,16 +511,18 @@ class RaggedReduceOpsTest(test_util.TensorFlowTestCase, ragged_reduce_op=ragged_math_ops.reduce_sum, rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], axis=[-2, -1], + keepdims=False, expected=[1 + 2 + 3 + 4 + 5, 6 + 7 + 8, 0, 9]), dict( ragged_reduce_op=ragged_math_ops.reduce_sum, rt_input=[[[1, 2], [], [3, 4, 5]], [[6, 7], [], [8]], [], [[9]]], axis=[-3, -2, -1], + keepdims=False, expected=sum([1, 2, 3, 4, 5, 6, 7, 8, 9])), ) - def testReduce(self, ragged_reduce_op, rt_input, axis, expected): + def testReduce(self, ragged_reduce_op, rt_input, axis, keepdims, expected): rt_input = ragged_factory_ops.constant(rt_input) - reduced = ragged_reduce_op(rt_input, axis) + reduced = ragged_reduce_op(rt_input, axis, keepdims=keepdims) self.assertAllEqual(reduced, expected) def testReduceKeepsInnerDimensionShape(self): @@ -336,8 +540,8 @@ class RaggedReduceOpsTest(test_util.TensorFlowTestCase, def testMeanNan(self): rt_as_list = [[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]] expected = ( - np.array([0 + 1 + 2 + 3, 4, 0, 5 + 6, 7, 8 + 9]) / np.array( - [4, 1, 0, 2, 1, 2])) + np.array([0 + 1 + 2 + 3, 4, 0, 5 + 6, 7, 8 + 9]) / + np.array([4, 1, 0, 2, 1, 2])) rt_input = ragged_factory_ops.constant(rt_as_list) reduced = ragged_math_ops.reduce_mean(rt_input, axis=1) self.assertEqualWithNan(self.evaluate(reduced), expected) diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 4806617c1af..6cda36d556e 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -35,14 +35,15 @@ from tensorflow.python.framework import auto_control_deps_utils as acd from tensorflow.python.framework import constant_op from tensorflow.python.framework import cpp_shape_inference_pb2 from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops -from tensorflow.python.ops import gen_logging_ops from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import gen_state_ops +from tensorflow.python.ops import handle_data_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables @@ -62,14 +63,8 @@ acd.register_read_only_resource_op("ResourceGatherNd") acd.register_read_only_resource_op("_ReadVariablesOp") -def get_resource_handle_data(graph_op): - assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck - - handle_data = pywrap_tf_session.GetHandleShapeAndType( - graph_op.graph._c_graph, graph_op._as_tf_output()) # pylint: disable=protected-access - - return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString( - compat.as_bytes(handle_data)) +# TODO(allenl): Remove this alias and migrate callers. 
+get_resource_handle_data = handle_data_util.get_resource_handle_data def get_eager_safe_handle_data(handle): @@ -156,6 +151,12 @@ def _variable_handle_from_shape_and_dtype(shape, container = "" shape = tensor_shape.as_shape(shape) dtype = dtypes.as_dtype(dtype) + if not graph_mode: + if shared_name is not None: + raise errors.InternalError( + "Using an explicit shared_name is not supported executing eagerly.") + shared_name = context.shared_name() + handle = gen_resource_variable_ops.var_handle_op( shape=shape, dtype=dtype, @@ -169,19 +170,6 @@ def _variable_handle_from_shape_and_dtype(shape, _set_handle_shapes_and_types(handle, full_handle_data, graph_mode) return handle else: - # We do not want two distinct ResourceVariable objects for the same - # underlying resource in the runtime. - # When in eager mode, explicitly ensure so here. When in graph mode, it's - # ensured by always generating different variable names. - exists = gen_resource_variable_ops.var_is_initialized_op(handle) - - # We create an assert Op instead of checking right away in order to be - # compatible with ASYNC execution mode. Further, since not all devices - # support string tensors, we encode the assertion string in the Op name - gen_logging_ops._assert( # pylint: disable=protected-access - math_ops.logical_not(exists), [exists], - name="EagerVariableNameReuse") - handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData() handle_data.is_set = True handle_data.shape_and_type.append( @@ -800,7 +788,13 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor): Returns: A `Tensor` of type `bool`. """ - return gen_resource_variable_ops.var_is_initialized_op(self.handle, name) + # TODO(b/169792703): The current device placement logic never overrides an + # explicit placement with a custom device, causing `v.is_initialized()` to + # fail under a non-custom device context if `v` is in a custom device. The + # explicit placement below makes this work, but should not be necessary once + # the logic is updated to handle cases like this. + with ops.device(self.device): + return gen_resource_variable_ops.var_is_initialized_op(self.handle, name) def assign_sub(self, delta, use_locking=None, name=None, read_value=True): """Subtracts a value from this variable. @@ -1703,7 +1697,7 @@ class ResourceVariable(BaseResourceVariable): # When in eager mode use a uid for the shared_name, to prevent # accidental sharing. unique_id = "%s_%d" % (handle_name, ops.uid()) - shared_name = context.shared_name() + shared_name = None # Never shared # Use attr_scope and device(None) to simulate the behavior of # colocate_with when the variable we want to colocate with doesn't # yet exist. @@ -1954,7 +1948,7 @@ class UninitializedVariable(BaseResourceVariable): unique_id = shared_name else: unique_id = "%s_%d" % (handle_name, ops.uid()) - shared_name = context.shared_name() + shared_name = None # Never shared handle = _variable_handle_from_shape_and_dtype( shape=shape, dtype=dtype, diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py index 23c24476934..556360ce640 100644 --- a/tensorflow/python/ops/while_v2.py +++ b/tensorflow/python/ops/while_v2.py @@ -161,6 +161,10 @@ def while_loop(cond, Returns: A list of tensors the same length as args. """ + # The function was created with a signature rather than tensors, so + # internal placeholders were created without handle data.
+ _copy_handle_data(nest.flatten(loop_vars[2:], expand_composites=True), + nest.flatten(args, expand_composites=True)) # Capture the tensors already captured in cond_graph so that they appear # in the same order in body_graph.external_captures. for t in cond_graph.external_captures: @@ -372,7 +376,7 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name while_op._add_while_inputs(new_inputs) while_op._add_outputs([t.dtype for t in new_outputs], [t.shape for t in new_outputs]) - _copy_handle_data(new_outputs, op.outputs[orig_num_params:]) + _copy_handle_data(new_outputs, while_op.outputs[orig_num_params:]) # Do not ignore grads wrt extra outputs when computing higher order # derivatives. diff --git a/tensorflow/python/platform/remote_utils.py b/tensorflow/python/platform/remote_utils.py index 9ec2e5e5ef8..aa430f5d697 100644 --- a/tensorflow/python/platform/remote_utils.py +++ b/tensorflow/python/platform/remote_utils.py @@ -20,3 +20,11 @@ from __future__ import print_function def get_default_communication_protocol(): return 'grpc' + + +def is_remote_path(_): + return False + + +def get_appendable_file_encoding(): + return '' diff --git a/tensorflow/python/profiler/integration_test/BUILD b/tensorflow/python/profiler/integration_test/BUILD index a7b92cd714a..b7034ad1ddf 100644 --- a/tensorflow/python/profiler/integration_test/BUILD +++ b/tensorflow/python/profiler/integration_test/BUILD @@ -19,6 +19,7 @@ cuda_py_test( srcs = ["profiler_api_test.py"], python_version = "PY3", tags = [ + "external", # So that test suite reruns unconditionally. "no_pip", "no_rocm", ], diff --git a/tensorflow/python/profiler/integration_test/profiler_api_test.py b/tensorflow/python/profiler/integration_test/profiler_api_test.py index e7785dbc697..a6134000a87 100644 --- a/tensorflow/python/profiler/integration_test/profiler_api_test.py +++ b/tensorflow/python/profiler/integration_test/profiler_api_test.py @@ -67,10 +67,15 @@ class ProfilerApiTest(test_util.TensorFlowTestCase): 'kernel_stats.pb', ] for file in expected_files: - path = os.path.join(logdir, 'plugins/profile/*/*{}'.format(file)) + path = os.path.join(logdir, 'plugins', 'profile', '*', '*{}'.format(file)) self.assertEqual(1, len(glob.glob(path)), 'Expected one path match: ' + path) + def _check_xspace_pb_exist(self, logdir): + path = os.path.join(logdir, 'plugins', 'profile', '*', '*.xplane.pb') + self.assertEqual(1, len(glob.glob(path)), + 'Expected one path match: ' + path) + def test_single_worker_no_profiling(self): """Test single worker without profiling.""" @@ -87,23 +92,28 @@ class ProfilerApiTest(test_util.TensorFlowTestCase): _, steps, train_ds, model = _model_setup() model.fit(x=train_ds, epochs=2, steps_per_epoch=steps) - port = portpicker.pick_unused_port() - thread = threading.Thread(target=on_worker, args=(port,)) - thread.start() - # Request for 3 seconds of profile. - duration_ms = 3000 + def on_profile(port, logdir): + # Request for 30 milliseconds of profile. 
+ duration_ms = 30 + + options = profiler.ProfilerOptions( + host_tracer_level=2, + python_tracer_level=0, + device_tracer_level=1, + ) + + profiler_client.trace('localhost:{}'.format(port), logdir, duration_ms, + '', 100, options) + logdir = self.get_temp_dir() - - options = profiler.ProfilerOptions( - host_tracer_level=2, - python_tracer_level=0, - device_tracer_level=1, - ) - - profiler_client.trace('localhost:{}'.format(port), logdir, duration_ms, '', - 3, options) - thread.join(30) - self._check_tools_pb_exist(logdir) + port = portpicker.pick_unused_port() + thread_profiler = threading.Thread(target=on_profile, args=(port, logdir)) + thread_worker = threading.Thread(target=on_worker, args=(port,)) + thread_worker.start() + thread_profiler.start() + thread_profiler.join() + thread_worker.join(120) + self._check_xspace_pb_exist(logdir) def test_single_worker_programmatic_mode(self): """Test single worker programmatic mode.""" diff --git a/tensorflow/python/profiler/internal/BUILD b/tensorflow/python/profiler/internal/BUILD index 9b24e17b5c2..5adb6d0a4b1 100644 --- a/tensorflow/python/profiler/internal/BUILD +++ b/tensorflow/python/profiler/internal/BUILD @@ -1,12 +1,8 @@ -# buildifier: disable=same-origin-load -load("//tensorflow:tensorflow.bzl", "cuda_py_test") - -# buildifier: disable=same-origin-load -load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud") - -# buildifier: disable=same-origin-load -load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") # buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud") # buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow/core/profiler/builds:build_config.bzl", "tf_profiler_copts") package( default_visibility = ["//tensorflow/python/profiler:__subpackages__"], @@ -85,6 +81,7 @@ cuda_py_test( tf_python_pybind_extension( name = "_pywrap_traceme", srcs = ["traceme_wrapper.cc"], + copts = tf_profiler_copts(), module_name = "_pywrap_traceme", visibility = [ "//perftools/accelerators/xprof/xprofilez/integration_tests:__pkg__", @@ -99,13 +96,13 @@ tf_python_pybind_extension( cc_library( name = "traceme_wrapper", hdrs = ["traceme_wrapper.h"], - features = ["-layering_check"], + copts = tf_profiler_copts(), visibility = [ "//tensorflow/compiler/xla/python:__pkg__", ], deps = [ "//tensorflow/core:lib", - "//tensorflow/core/profiler/lib:traceme_headers", + "//tensorflow/core/profiler/lib:traceme_for_pybind", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@pybind11", @@ -115,7 +112,7 @@ cc_library( tf_python_pybind_extension( name = "_pywrap_profiler", srcs = ["profiler_wrapper.cc"], - features = ["-layering_check"], + copts = tf_profiler_copts(), module_name = "_pywrap_profiler", visibility = [ "//tensorflow/python/eager:__pkg__", @@ -123,19 +120,17 @@ tf_python_pybind_extension( ], deps = [ "//tensorflow/core:lib", - "//tensorflow/core/profiler/convert:op_stats_to_tf_stats", - "//tensorflow/core/profiler/convert:xplane_to_op_stats", + "//tensorflow/core/profiler/convert:xplane_to_tools_data", "//tensorflow/core/profiler/convert:xplane_to_trace_events", - "//tensorflow/core/profiler/lib:profiler_session_headers", - "//tensorflow/core/profiler/protobuf:op_stats_proto_cc", + "//tensorflow/core/profiler/lib:profiler_session_for_pybind", 
"//tensorflow/core/profiler/protobuf:xplane_proto_cc", - "//tensorflow/core/profiler/rpc:profiler_server_headers", + "//tensorflow/core/profiler/rpc:profiler_server_for_pybind", "//tensorflow/core/profiler/rpc/client:capture_profile", "//tensorflow/core/profiler/rpc/client:save_profile", "//tensorflow/python:pybind11_status", - "@com_github_grpc_grpc//:grpc++_public_hdrs", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@pybind11", ], ) @@ -145,7 +140,7 @@ cc_library( srcs = ["python_hooks.cc"], hdrs = ["python_hooks.h"], compatible_with = get_compatible_with_cloud(), - copts = ["-fexceptions"], + copts = tf_profiler_copts() + ["-fexceptions"], features = ["-use_header_modules"], # Incompatible with -fexceptions. visibility = ["//visibility:public"], deps = [ diff --git a/tensorflow/python/profiler/internal/profiler_wrapper.cc b/tensorflow/python/profiler/internal/profiler_wrapper.cc index 6b0a095abea..f0c289afe01 100644 --- a/tensorflow/python/profiler/internal/profiler_wrapper.cc +++ b/tensorflow/python/profiler/internal/profiler_wrapper.cc @@ -14,27 +14,26 @@ limitations under the License. ==============================================================================*/ #include +#include +#include #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" -#include "pybind11/cast.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "pybind11/pybind11.h" #include "pybind11/pytypes.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h" -#include "tensorflow/core/profiler/convert/op_stats_to_overview_page.h" -#include "tensorflow/core/profiler/convert/op_stats_to_tf_stats.h" -#include "tensorflow/core/profiler/convert/xplane_to_memory_profile.h" -#include "tensorflow/core/profiler/convert/xplane_to_op_stats.h" +#include "tensorflow/core/profiler/convert/xplane_to_tools_data.h" #include "tensorflow/core/profiler/convert/xplane_to_trace_events.h" #include "tensorflow/core/profiler/lib/profiler_session.h" -#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h" -#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" -#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/rpc/client/capture_profile.h" #include "tensorflow/core/profiler/rpc/client/save_profile.h" @@ -45,7 +44,12 @@ namespace py = ::pybind11; namespace { -tensorflow::Status ValidateHostPortPair(const std::string& host_port) { +using ::tensorflow::RemoteProfilerSessionManagerOptions; + +// Profiler gives grace after profiling duration to terminate. +constexpr absl::Duration kSessionGraceTime = absl::Seconds(5); + +tensorflow::Status ValidateHostPortPair(absl::string_view host_port) { tensorflow::uint32 port; std::vector parts = absl::StrSplit(host_port, ':'); // Must be host:port, port must be a number, host must not contain a '/', @@ -58,34 +62,156 @@ tensorflow::Status ValidateHostPortPair(const std::string& host_port) { return tensorflow::Status::OK(); } -// Takes profiler options in a py::dict and returns a ProfileOptions. 
+tensorflow::Status ValidateOptions( + const RemoteProfilerSessionManagerOptions& options) { + if (options.service_addresses().empty()) { + return tensorflow::errors::InvalidArgument("No service address provided."); + } + + if (options.profiler_options().duration_ms() == 0) { + return tensorflow::errors::InvalidArgument( + "duration_ms must be greater than zero."); + } + + for (absl::string_view host_port : options.service_addresses()) { + TF_RETURN_IF_ERROR(ValidateHostPortPair(host_port)); + } + + if (options.max_session_duration_ms() < + options.profiler_options().duration_ms()) { + return tensorflow::errors::InvalidArgument( + "The maximum profiling session duration must be greater than or equal " + "to the local profiler duration."); + } + + return tensorflow::Status::OK(); +} + +// Receives a comma delimited list of service_addresses and adds them to +// RemoteProfilerSessionManagerOptions::service_addresses. +void AddServiceAddresses(absl::string_view service_addresses, + RemoteProfilerSessionManagerOptions* options) { + for (absl::string_view server : absl::StrSplit(service_addresses, ',')) { + options->add_service_addresses(server.data(), server.size()); + } +} + +// Sets gRPC deadline to a grace period based on the profiling duration. +void UpdateMaxSessionDuration(RemoteProfilerSessionManagerOptions& options) { + auto local_profiler_duration = options.profiler_options().duration_ms(); + auto session_creation_ts = options.session_creation_timestamp_ns(); + auto requested_start_ts = options.profiler_options().start_timestamp_ns(); + // User only needs to set maximal session duration if the profiling duration + // is bounded. + DCHECK_GT(local_profiler_duration, 0); + VLOG(3) << "duration_ms was given as " << local_profiler_duration; + // Max session duration includes the profiling session and grace time. + auto profile_duration = + absl::Milliseconds(local_profiler_duration) + kSessionGraceTime; + absl::Duration delay_duration; + // When requested start timestamp is 0, profiling starts immediately. + if (requested_start_ts > 0) { + delay_duration = + absl::Nanoseconds(requested_start_ts - session_creation_ts); + } + + auto max_session_duration = profile_duration + delay_duration; + options.set_max_session_duration_ms( + absl::ToInt64Milliseconds(max_session_duration)); + VLOG(1) << "max_session_duration set to " << max_session_duration; +} + +// Takes profiler options in a py::dict and returns a +// RemoteProfilerSessionManagerOptions. // This must be called under GIL because it reads Python objects. Reading Python // objects require GIL because the objects can be mutated by other Python // threads. In addition, Python objects are reference counted; reading py::dict // will increase its reference count. -tensorflow::ProfileOptions GetOptionsLocked(const py::dict& opts) { - tensorflow::ProfileOptions options = +RemoteProfilerSessionManagerOptions GetOptionsLocked(absl::string_view logdir, + const py::dict& opts) { + RemoteProfilerSessionManagerOptions options; + *options.mutable_profiler_options() = tensorflow::ProfilerSession::DefaultOptions(); + // Store a timestamp of when this session was created. This will be the basis + // of gRPC deadline afterwards. + auto now = absl::Now(); + options.set_session_creation_timestamp_ns(absl::ToUnixNanos(now)); + VLOG(2) << "set_session_creation_timestamp_ns set to " + << options.session_creation_timestamp_ns() << " [" << now << "]"; + + // Set the path of where to store XSpaces. 
+ options.mutable_profiler_options()->set_repository_path(logdir.data(), + logdir.size()); + VLOG(2) << "repository_path set to " + << options.profiler_options().repository_path(); + for (const auto& kw : opts) { std::string key = py::cast(kw.first); if (key == "host_tracer_level") { - options.set_host_tracer_level(py::cast(kw.second)); - VLOG(1) << "host_tracer_level set to " << options.host_tracer_level(); + auto value = py::cast(kw.second); + options.mutable_profiler_options()->set_host_tracer_level(value); + VLOG(1) << "host_tracer_level set to " << value; } else if (key == "device_tracer_level") { - options.set_device_tracer_level(py::cast(kw.second)); - VLOG(1) << "device_tracer_level set to " << options.device_tracer_level(); + auto value = py::cast(kw.second); + options.mutable_profiler_options()->set_device_tracer_level(value); + VLOG(1) << "device_tracer_level set to " << value; } else if (key == "python_tracer_level") { - options.set_python_tracer_level(py::cast(kw.second)); - VLOG(1) << "python_tracer_level set to " << options.python_tracer_level(); + auto value = py::cast(kw.second); + options.mutable_profiler_options()->set_python_tracer_level(value); + VLOG(1) << "python_tracer_level set to " << value; + } else { + LOG(WARNING) << "Unrecognised key: " << key; } } + + return options; +} + +RemoteProfilerSessionManagerOptions GetOptionsLocked( + absl::string_view service_addresses, absl::string_view logdir, + absl::string_view worker_list, bool include_dataset_ops, + tensorflow::int32 duration_ms, py::dict opts, bool* is_cloud_tpu_session) { + RemoteProfilerSessionManagerOptions options = GetOptionsLocked(logdir, opts); + + // Remote profiling does not support any use cases where the following options + // are set by `py::dict opts`. e.g. `opts['service_addrs']` will not happen. + DCHECK(options.service_addresses().empty()); + // In remote profiling, duration is always passed by value explicitly and not + // set in py::dict opts. + DCHECK_EQ(options.profiler_options().duration_ms(), 0); + // Because duration_ms is not set from py::dict opts, it follows that + // max_session_duration_ms must be unset as well. + DCHECK_EQ(options.max_session_duration_ms(), 0); + + // Worker_list is only used for TensorBoard TPU capture cases. For a TPU + // cluster, service_address is the Master, which can already be found in the + // list of workers. These sessions will be used with the ProfileAnalysis + // service. + *is_cloud_tpu_session = !worker_list.empty(); + AddServiceAddresses(*is_cloud_tpu_session ? worker_list : service_addresses, + &options); + + // Set local profiler duration and profiler session durations. 
+ options.mutable_profiler_options()->set_include_dataset_ops( + include_dataset_ops); + options.mutable_profiler_options()->set_duration_ms(duration_ms); + UpdateMaxSessionDuration(options); + + for (int idx = 0; idx < options.service_addresses_size(); ++idx) { + VLOG(1) << "service_addr " << idx << " set to " + << options.service_addresses(idx); + } + VLOG(1) << "include_dataset_ops set to " << include_dataset_ops; + VLOG(1) << "duration_ms set to " << duration_ms; + return options; } class ProfilerSessionWrapper { public: void Start(const char* logdir, const py::dict& options) { - session_ = tensorflow::ProfilerSession::Create(GetOptionsLocked(options)); + auto opts = GetOptionsLocked(logdir, options); + session_ = tensorflow::ProfilerSession::Create(opts.profiler_options()); logdir_ = logdir; tensorflow::MaybeRaiseRegisteredFromStatus(session_->Status()); } @@ -118,15 +244,6 @@ class ProfilerSessionWrapper { tensorflow::string logdir_; }; -// Converts a pybind list of XSpace paths to a cpp vector. -std::vector GetXSpacePaths(const py::list& python_paths) { - std::vector cpp_paths; - for (py::handle obj : python_paths) { - cpp_paths.push_back(std::string(py::cast(obj))); - } - return cpp_paths; -} - } // namespace PYBIND11_MODULE(_pywrap_profiler, m) { @@ -146,26 +263,28 @@ PYBIND11_MODULE(_pywrap_profiler, m) { profiler_server.release(); }); - m.def("trace", - [](const char* service_addr, const char* logdir, - const char* worker_list, bool include_dataset_ops, int duration_ms, - int num_tracing_attempts, py::dict options) { - // Normalize py::dict into a well defined proto. - tensorflow::ProfileOptions opts = GetOptionsLocked(options); + m.def("trace", [](const char* service_addr, const char* logdir, + const char* worker_list, bool include_dataset_ops, + int duration_ms, int num_tracing_attempts, + py::dict options) { + // TPU capture is true if the user sets worker_list. + bool is_cloud_tpu_session = false; + // Normalize py::dict into a well defined and validated proto. + tensorflow::RemoteProfilerSessionManagerOptions opts = + GetOptionsLocked(service_addr, logdir, worker_list, include_dataset_ops, + duration_ms, options, &is_cloud_tpu_session); + tensorflow::Status status = ValidateOptions(opts); + tensorflow::MaybeRaiseRegisteredFromStatus(status); - tensorflow::Status status = ValidateHostPortPair(service_addr); - tensorflow::MaybeRaiseRegisteredFromStatus(status); - opts.set_include_dataset_ops(include_dataset_ops); - { - // Release the lock to keep the lock scope to a minimum, and allow - // other threads to proceed. - py::gil_scoped_release release; - status = tensorflow::profiler::Trace(service_addr, logdir, - worker_list, duration_ms, - num_tracing_attempts, opts); - } - tensorflow::MaybeRaiseRegisteredFromStatus(status); - }); + { + // Release the lock to keep the lock scope to a minimum, and allow + // other threads to proceed. 
+ py::gil_scoped_release release; + status = tensorflow::profiler::Trace(logdir, num_tracing_attempts, opts, + is_cloud_tpu_session); + } + tensorflow::MaybeRaiseRegisteredFromStatus(status); + }); m.def("monitor", [](const char* service_addr, int duration_ms, int monitoring_level, bool display_timestamp) { @@ -184,121 +303,17 @@ PYBIND11_MODULE(_pywrap_profiler, m) { return content; }); - m.def("xspace_to_trace_events", [](const py::list& xspace_path_list) { - std::vector xspace_paths = GetXSpacePaths(xspace_path_list); - if (xspace_paths.size() != 1) { - LOG(WARNING) << "Trace events tool expects only 1 XSpace path but gets " - << xspace_paths.size(); - return py::make_tuple(py::bytes(), py::bool_(false)); - } - tensorflow::profiler::XSpace xspace; - tensorflow::Status status = tensorflow::ReadBinaryProto( - tensorflow::Env::Default(), xspace_paths[0], &xspace); - if (!status.ok()) { - LOG(WARNING) << "Could not read XSpace for trace events: " - << xspace_paths[0]; - return py::make_tuple(py::bytes(), py::bool_(false)); - } - tensorflow::string content; - tensorflow::profiler::ConvertXSpaceToTraceEventsString(xspace, &content); - return py::make_tuple(py::bytes(content), py::bool_(true)); - }); - - m.def("xspace_to_overview_page", [](const py::list& xspace_path_list) { - std::vector xspace_paths = GetXSpacePaths(xspace_path_list); - tensorflow::profiler::OpStatsOptions options; - options.generate_kernel_stats_db = true; - options.generate_op_metrics_db = true; - options.generate_step_db = true; - tensorflow::profiler::OpStats combined_op_stats; - tensorflow::Status status = ConvertMultiXSpacesToCombinedOpStats( - xspace_paths, options, &combined_op_stats); - if (!status.ok()) { - LOG(WARNING) << "Could not generate OpStats for overview page. Error: " - << status.error_message(); - return py::make_tuple(py::bytes(), py::bool_(false)); - } - // TODO(profiler): xspace should tell whether this is sampling mode. - tensorflow::profiler::OverviewPage overview_page = - tensorflow::profiler::ConvertOpStatsToOverviewPage(combined_op_stats); - return py::make_tuple(py::bytes(overview_page.SerializeAsString()), - py::bool_(true)); - }); - - m.def("xspace_to_input_pipeline", [](const py::list& xspace_path_list) { - std::vector xspace_paths = GetXSpacePaths(xspace_path_list); - tensorflow::profiler::OpStatsOptions options; - options.generate_op_metrics_db = true; - options.generate_step_db = true; - tensorflow::profiler::OpStats combined_op_stats; - tensorflow::Status status = ConvertMultiXSpacesToCombinedOpStats( - xspace_paths, options, &combined_op_stats); - if (!status.ok()) { - LOG(WARNING) << "Could not generate OpStats for input pipeline. Error: " - << status.error_message(); - return py::make_tuple(py::bytes(), py::bool_(false)); - } - tensorflow::profiler::InputPipelineAnalysisResult input_pipeline = - tensorflow::profiler::ConvertOpStatsToInputPipelineAnalysis( - combined_op_stats); - return py::make_tuple(py::bytes(input_pipeline.SerializeAsString()), - py::bool_(true)); - }); - - m.def("xspace_to_tf_stats", [](const py::list& xspace_path_list) { - std::vector xspace_paths = GetXSpacePaths(xspace_path_list); - tensorflow::profiler::OpStatsOptions options; - options.generate_op_metrics_db = true; - options.generate_kernel_stats_db = true; - tensorflow::profiler::OpStats combined_op_stats; - tensorflow::Status status = ConvertMultiXSpacesToCombinedOpStats( - xspace_paths, options, &combined_op_stats); - if (!status.ok()) { - LOG(WARNING) << "Could not generate OpStats for tensorflow stats. 
Error: " - << status.error_message(); - return py::make_tuple(py::bytes(), py::bool_(false)); - } - tensorflow::profiler::TfStatsDatabase tf_stats_db = - tensorflow::profiler::ConvertOpStatsToTfStats(combined_op_stats); - return py::make_tuple(py::bytes(tf_stats_db.SerializeAsString()), - py::bool_(true)); - }); - - m.def("xspace_to_kernel_stats", [](const py::list& xspace_path_list) { - std::vector xspace_paths = GetXSpacePaths(xspace_path_list); - tensorflow::profiler::OpStatsOptions options; - options.generate_kernel_stats_db = true; - tensorflow::profiler::OpStats combined_op_stats; - tensorflow::Status status = ConvertMultiXSpacesToCombinedOpStats( - xspace_paths, options, &combined_op_stats); - if (!status.ok()) { - LOG(WARNING) << "Could not generate OpStats for kernel stats. Error: " - << status.error_message(); - return py::make_tuple(py::bytes(), py::bool_(false)); - } - return py::make_tuple( - py::bytes(combined_op_stats.kernel_stats_db().SerializeAsString()), - py::bool_(true)); - }); - - m.def("xspace_to_memory_profile", [](const py::list& xspace_path_list) { - std::vector xspace_paths = GetXSpacePaths(xspace_path_list); - if (xspace_paths.size() != 1) { - LOG(WARNING) << "Memory profile tool expects only 1 XSpace path but gets " - << xspace_paths.size(); - return py::make_tuple(py::bytes(), py::bool_(false)); - } - tensorflow::profiler::XSpace xspace; - tensorflow::Status status = tensorflow::ReadBinaryProto( - tensorflow::Env::Default(), xspace_paths[0], &xspace); - if (!status.ok()) { - LOG(WARNING) << "Could not read XSpace for memory profile: " - << xspace_paths[0]; - return py::make_tuple(py::bytes(), py::bool_(false)); - } - std::string json_output; - tensorflow::profiler::ConvertXSpaceToMemoryProfileJson(xspace, - &json_output); - return py::make_tuple(py::bytes(json_output), py::bool_(true)); - }); + m.def("xspace_to_tools_data", + [](const py::list& xspace_path_list, const py::str& py_tool_name) { + std::vector xspace_paths; + for (py::handle obj : xspace_path_list) { + xspace_paths.push_back(std::string(py::cast(obj))); + } + std::string tool_name = std::string(py_tool_name); + auto tool_data_and_success = + tensorflow::profiler::ConvertMultiXSpacesToToolData(xspace_paths, + tool_name); + return py::make_tuple(py::bytes(tool_data_and_success.first), + py::bool_(tool_data_and_success.second)); + }); }; diff --git a/tensorflow/python/profiler/profiler_client.py b/tensorflow/python/profiler/profiler_client.py index 1380b6578fc..0b1eb0db495 100644 --- a/tensorflow/python/profiler/profiler_client.py +++ b/tensorflow/python/profiler/profiler_client.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import errors from tensorflow.python.profiler.internal import _pywrap_profiler from tensorflow.python.util.tf_export import tf_export @@ -32,38 +33,57 @@ def trace(service_addr, worker_list='', num_tracing_attempts=3, options=None): - """Sends grpc requests to profiler server to perform on-demand profiling. + """Sends gRPC requests to one or more profiler servers to perform on-demand profiling. - This method will block caller thread until it receives tracing result. This - method supports CPU, GPU, and Cloud TPU. This method supports profiling a - single host for CPU, GPU, TPU, as well as multiple TPU workers. - The profiled results will be saved to your specified TensorBoard log - directory (e.g. the directory you save your model checkpoints). 
Use the + This method will block the calling thread until it receives responses from all + servers or until deadline expiration. Both single host and multiple host + profiling are supported on CPU, GPU, and TPU. + The profiled results will be saved by each server to the specified TensorBoard + log directory (i.e. the directory you save your model checkpoints). Use the TensorBoard profile plugin to view the visualization and analysis results. Args: - service_addr: gRPC address of profiler service e.g. grpc://localhost:6009. - logdir: Path of TensorBoard log directory e.g. /tmp/tb_log. - duration_ms: Duration of tracing or monitoring in ms. - worker_list: Optional. The list of workers that we are about to profile in - the current session (TPU only). + service_addr: A comma delimited string of gRPC addresses of the workers to + profile. + e.g. service_addr='grpc://localhost:6009' + service_addr='grpc://10.0.0.2:8466,grpc://10.0.0.3:8466' + service_addr='grpc://localhost:12345,grpc://localhost:23456' + logdir: Path to save profile data to, typically a TensorBoard log directory. + This path must be accessible to both the client and server. + e.g. logdir='gs://your_tb_dir' + duration_ms: Duration of tracing or monitoring in milliseconds. Must be + greater than zero. + worker_list: An optional TPU only configuration. The list of workers to + profile in the current session. num_tracing_attempts: Optional. Automatically retry N times when no trace event is collected (default 3). options: profiler.experimental.ProfilerOptions namedtuple for miscellaneous profiler options. Raises: - UnavailableError: If no trace event is collected. + InvalidArgumentError: For when arguments fail validation checks. + UnavailableError: If no trace event was collected. Example usage (CPU/GPU): # Start a profiler server before your model runs. ```python tf.profiler.experimental.server.start(6009) - # your model code. + # (Model code goes here). # Send gRPC request to the profiler server to collect a trace of your model. ```python tf.profiler.experimental.client.trace('grpc://localhost:6009', - '/tmp/tb_log', 2000) + '/nfs/tb_log', 2000) + + Example usage (Multiple GPUs): + # E.g. your worker IP addresses are 10.0.0.2, 10.0.0.3, 10.0.0.4, and you + # would like to schedule start of profiling 1 second from now, for a duration + # of 2 seconds. + options['delay_ms'] = 1000 + tf.profiler.experimental.client.trace( + 'grpc://10.0.0.2:8466,grpc://10.0.0.3:8466,grpc://10.0.0.4:8466', + 'gs://your_tb_dir', + 2000, + options=options) Example usage (TPU): # Send gRPC request to a TPU worker to collect a trace of your model. A @@ -82,16 +102,19 @@ def trace(service_addr, # profile for 2 seconds. tf.profiler.experimental.client.trace('grpc://10.0.0.2:8466', 'gs://your_tb_dir', - 2000, '10.0.0.3,10.0.0.4') - + 2000, '10.0.0.2,10.0.0.3,10.0.0.4') Launch TensorBoard and point it to the same logdir you provided to this API. $ tensorboard --logdir=/tmp/tb_log (or gs://your_tb_dir in the above examples) Open your browser and go to localhost:6006/#profile to view profiling results.
""" + if duration_ms <= 0: + raise errors.InvalidArgumentError(None, None, + 'duration_ms must be greater than zero.') + opts = dict(options._asdict()) if options is not None else {} _pywrap_profiler.trace( - _strip_prefix(service_addr, _GRPC_PREFIX), logdir, worker_list, True, + _strip_addresses(service_addr, _GRPC_PREFIX), logdir, worker_list, True, duration_ms, num_tracing_attempts, opts) @@ -127,3 +150,7 @@ def monitor(service_addr, duration_ms, level=1): def _strip_prefix(s, prefix): return s[len(prefix):] if s.startswith(prefix) else s + + +def _strip_addresses(addresses, prefix): + return ','.join([_strip_prefix(s, prefix) for s in addresses.split(',')]) diff --git a/tensorflow/python/profiler/profiler_client_test.py b/tensorflow/python/profiler/profiler_client_test.py index 343f09834fd..85042be5409 100644 --- a/tensorflow/python/profiler/profiler_client_test.py +++ b/tensorflow/python/profiler/profiler_client_test.py @@ -38,7 +38,7 @@ class ProfilerClientTest(test_util.TensorFlowTestCase): with self.assertRaises(errors.UnavailableError) as error: profiler_client.trace( 'localhost:' + str(test_port), self.get_temp_dir(), duration_ms=10) - self.assertEqual('No trace event is collected', str(error.exception)) + self.assertStartsWith(str(error.exception), 'No trace event was collected') def testTrace_ProfileIdleServerWithOptions(self): test_port = portpicker.pick_unused_port() @@ -54,7 +54,7 @@ class ProfilerClientTest(test_util.TensorFlowTestCase): self.get_temp_dir(), duration_ms=10, options=options) - self.assertEqual('No trace event is collected', str(error.exception)) + self.assertStartsWith(str(error.exception), 'No trace event was collected') def testMonitor_ProcessInvalidAddress(self): # Monitor is only supported in cloud TPU. Test invalid address instead. diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py index b3e6159cb72..092e4177f44 100644 --- a/tensorflow/python/saved_model/function_deserialization.py +++ b/tensorflow/python/saved_model/function_deserialization.py @@ -163,8 +163,6 @@ def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder): def setup_bare_concrete_function(saved_bare_concrete_function, concrete_functions): """Makes a restored bare concrete function callable.""" - # Bare concrete functions accept only flat lists of Tensors with unique - # names. 
concrete_function = concrete_functions[ saved_bare_concrete_function.concrete_function_name] # pylint: disable=protected-access @@ -172,6 +170,12 @@ def setup_bare_concrete_function(saved_bare_concrete_function, saved_bare_concrete_function.argument_keywords) concrete_function._num_positional_args = ( saved_bare_concrete_function.allowed_positional_arguments) + if saved_bare_concrete_function.HasField("function_spec"): + coder = nested_structure_coder.StructureCoder() + function_spec = _deserialize_function_spec_as_nonmethod( + saved_bare_concrete_function.function_spec, + coder) + concrete_function._set_function_spec(function_spec) # pylint: enable=protected-access concrete_function.add_to_graph() return concrete_function @@ -338,10 +342,11 @@ def load_function_def_library(library, load_shared_name_suffix=None): for dep in _list_function_deps(fdef, library_function_names): functions[dep].add_to_graph(func_graph) - # We do not initialize the new ConcreteFunction's function_spec or + # We do not initialize the new ConcreteFunction's function_spec and/or # arg_keywords here (which are used to parse the structured and flat - # signatures, respectively). function_spec is set up later by - # recreate_function(); and arg_keywords by setup_bare_concrete_function(). + # signatures, respectively). A ConcreteFunction that is part of a saved + # function is set up later by recreate_function(); a bare ConcreteFunction + # is set up by setup_bare_concrete_function(). func = function_lib.ConcreteFunction(func_graph) func.add_to_graph(graph) diff --git a/tensorflow/python/saved_model/function_serialization.py b/tensorflow/python/saved_model/function_serialization.py index 07a80539a33..ad18e8f5d2a 100644 --- a/tensorflow/python/saved_model/function_serialization.py +++ b/tensorflow/python/saved_model/function_serialization.py @@ -85,17 +85,19 @@ def serialize_concrete_function(concrete_function, node_ids, coder): def serialize_bare_concrete_function(concrete_function, name_map): """Build a SavedBareConcreteFunction.""" - # TODO(edloper): Currently, bare concrete functions don't have access to a - # function_spec, so they can't be called with the structured signature. - # Update the serialization to include a function_spec. - # pylint: disable=protected-access name = name_map.get(compat.as_text(concrete_function.name), concrete_function.name) - return saved_object_graph_pb2.SavedBareConcreteFunction( + proto = saved_object_graph_pb2.SavedBareConcreteFunction( concrete_function_name=name, allowed_positional_arguments=concrete_function._num_positional_args, argument_keywords=concrete_function._arg_keywords) + if concrete_function._pre_initialized_function_spec is not None: + coder = nested_structure_coder.StructureCoder() + proto.function_spec.CopyFrom( + _serialize_function_spec( + concrete_function._pre_initialized_function_spec, coder)) + return proto # pylint: enable=protected-access diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index db5b7e449d3..a8244850308 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -480,6 +480,34 @@ class LoadTest(test.TestCase, parameterized.TestCase): self.assertEqual(31, imported.f(input1).numpy()) self.assertEqual(32, imported.f(input3).numpy()) + def test_structured_inputs_bare_concrete_function(self, cycles): + + def func(x, training=True): + # x is a nested structure, we care about one particular tensor.
+ _, (a, b) = x + if training: + return 2 * a["a"] + b + else: + return 7 + + x = constant_op.constant(10) + y = constant_op.constant(11) + + input1 = [6, ({"a": x}, y)] + input2 = [7, ({"a": x}, y)] # Not compatible with input1 signature. + input3 = [6, ({"a": y}, x)] # Compatible with input1 signature. + + root = tracking.AutoTrackable() + root.f = def_function.function(func).get_concrete_function(input1) + + imported = cycle(root, cycles) + + with self.assertRaises(TypeError): + imported.f(input2) + + self.assertEqual(31, imported.f(input1).numpy()) + self.assertEqual(32, imported.f(input3).numpy()) + def test_structured_output(self, cycles): # Use fields with non-alphabetical order @@ -509,6 +537,31 @@ class LoadTest(test.TestCase, parameterized.TestCase): self.assertEqual(5, result[1].numpy()) self.assertEqual(0.5, result[2]["x"].numpy()) + def test_pretty_print_signature(self, cycles): + + named_tuple_type = collections.namedtuple("NamedTupleHello", ["b", "a"]) + + def func(input1, input2): + named_tuple = named_tuple_type(a=input1 + input2, b=input1 * input2) + return [named_tuple, input2, {"x": 0.5}] + + root = tracking.AutoTrackable() + root.f = def_function.function(func).get_concrete_function( + constant_op.constant(2), constant_op.constant(3)) + + imported = cycle(root, cycles) + self.assertEqual( + imported.f.pretty_printed_signature(), """func(input1, input2) + Args: + input1: int32 Tensor, shape=() + input2: int32 Tensor, shape=() + Returns: + [NamedTupleHello(b=<1>, a=<2>), <3>, {'x': <4>}] + <1>: int32 Tensor, shape=() + <2>: int32 Tensor, shape=() + <3>: int32 Tensor, shape=() + <4>: float32 Tensor, shape=()""") + def test_positional_arguments(self, cycles): def func(x, training=False, abc=7.1, defg=7.7): del abc diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py index 361883adc22..d03ae64a480 100644 --- a/tensorflow/python/saved_model/save.py +++ b/tensorflow/python/saved_model/save.py @@ -786,16 +786,26 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map): # pylint:enable=protected-access proto.user_object.CopyFrom(registered_type_proto) + # Give the object a chance to modify the SavedObject proto. + # This is currently used by MirroredVariables to optionally write their + # component variables to the proto. + # + # This is not yet an official Trackable method, the only current use case + # being MirroredVariables. See the method implementation there for more + # documentation. + if hasattr(obj, "_write_object_proto"): + obj._write_object_proto(proto, options) # pylint: disable=protected-access -def _export_debug_info(exported_graph): - """Exports debug information from a graph. + +def _export_debug_info(exported_graph, export_dir): + """Exports debug information from graph to file. + + Creates and writes GraphDebugInfo with traces for ops in all functions of the + exported_graph. Args: exported_graph: A Graph that has been created by tracing a saveable view. - - Returns: - Corresponding GraphDebugInfo with traces for ops in all functions of the - exported_graph. + export_dir: SavedModel directory in which to write the debug info. 
""" exported_operations = [] for fn_name in exported_graph._functions: # pylint: disable=protected-access @@ -806,7 +816,14 @@ def _export_debug_info(exported_graph): fn_graph = fn.graph for fn_op in fn_graph.get_operations(): exported_operations.append((fn_name, fn_op)) - return error_interpolation.create_graph_debug_info_def(exported_operations) + + graph_debug_info = error_interpolation.create_graph_debug_info_def( + exported_operations) + file_io.atomic_write_string_to_file( + os.path.join( + utils_impl.get_or_create_debug_dir(export_dir), + constants.DEBUG_INFO_FILENAME_PB), + graph_debug_info.SerializeToString(deterministic=True)) @tf_export( @@ -991,10 +1008,6 @@ def save(obj, export_dir, signatures=None, options=None): @end_compatibility """ options = options or save_options.SaveOptions() - if options.experimental_variable_policy._expand_distributed_variables(): # pylint:disable=protected-access - raise NotImplementedError( - "The VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES option is " - "not implemented in saved_model.save.") # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x # compatible (no sessions) and share it with this export API rather than # making a SavedModel proto and writing it directly. @@ -1002,7 +1015,7 @@ def save(obj, export_dir, signatures=None, options=None): meta_graph_def = saved_model.meta_graphs.add() _, exported_graph, object_saver, asset_info = _build_meta_graph( - obj, export_dir, signatures, options, meta_graph_def) + obj, signatures, options, meta_graph_def) saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION # Write the checkpoint, copy assets into the assets directory, and write out @@ -1033,6 +1046,9 @@ def save(obj, export_dir, signatures=None, options=None): compat.as_str(constants.SAVED_MODEL_FILENAME_PB)) file_io.atomic_write_string_to_file( path, saved_model.SerializeToString(deterministic=True)) + # Save debug info, if requested. + if options.save_debug_info: + _export_debug_info(exported_graph, export_dir) # Clean reference cycles so repeated export()s don't make work for the garbage # collector. Before this point, we need to keep references to captured @@ -1041,7 +1057,7 @@ def save(obj, export_dir, signatures=None, options=None): def export_meta_graph(obj, filename, signatures=None, options=None): - """Exports the MetaGraph proto to a file. + """Exports the MetaGraph proto of the `obj` to a file. This function goes through the same procedures saved_model.save goes to produce the given object's MetaGraph, then saves it to the given file. It @@ -1066,11 +1082,15 @@ def export_meta_graph(obj, filename, signatures=None, options=None): options = options or save_options.SaveOptions() export_dir = os.path.dirname(filename) meta_graph_def, exported_graph, _, _ = _build_meta_graph( - obj, export_dir, signatures, options) + obj, signatures, options) file_io.atomic_write_string_to_file( filename, meta_graph_def.SerializeToString(deterministic=True)) + # Save debug info, if requested. + if options.save_debug_info: + _export_debug_info(exported_graph, export_dir) + # Clean reference cycles so repeated export()s don't make work for the garbage # collector. Before this point, we need to keep references to captured # constants in the saved graph. 
@@ -1078,16 +1098,14 @@ def export_meta_graph(obj, filename, signatures=None, options=None): def _build_meta_graph_impl(obj, - export_dir, signatures, options, meta_graph_def=None): """Creates a MetaGraph containing the resources and functions of an object.""" if ops.inside_function(): raise AssertionError( - "tf.saved_model.save is not supported inside a traced " - "@tf.function. Move the call to the outer eagerly-executed " - "context.") + "tf.saved_model.save is not supported inside a traced @tf.function. " + "Move the call to the outer eagerly-executed context.") # pylint: enable=line-too-long if not isinstance(obj, base.Trackable): raise ValueError( @@ -1130,24 +1148,36 @@ def _build_meta_graph_impl(obj, asset_info.asset_index) meta_graph_def.object_graph_def.CopyFrom(object_graph_proto) - # Save debug info, if requested. - if options.save_debug_info: - graph_debug_info = _export_debug_info(exported_graph) - file_io.atomic_write_string_to_file( - os.path.join( - utils_impl.get_or_create_debug_dir(export_dir), - constants.DEBUG_INFO_FILENAME_PB), - graph_debug_info.SerializeToString(deterministic=True)) - return meta_graph_def, exported_graph, object_saver, asset_info def _build_meta_graph(obj, - export_dir, signatures, options, meta_graph_def=None): - """Creates a MetaGraph under a SaveContext.""" + """Creates a MetaGraph under a save context. + + Args: + obj: A trackable object to build the MetaGraph from. + signatures: Can be a `tf.function` with an input signature specified or the + result of `f.get_concrete_function` on a `@tf.function`-decorated function + `f`. `signatures` may also be a dictionary, in which case it maps from + signature keys to `tf.function` instances. If None, finds signature to + export from the `@tf.function`-decorated methods in `obj`. + options: `tf.saved_model.SaveOptions` object that specifies options for + saving. + meta_graph_def: Optional, the MetaGraphDef proto fill. + + Raises: + AssertionError: If `export_meta_graph` is executing inside a `tf.function`. + ValueError: If `obj` is not trackable. + + Returns: + meta_graph_def: Filled MetaGraphDef proto + exported_graph: `tf.Graph` object generated from `obj`. + object_saver: `util.TrackableSaver` of the `obj` and its dependencies. + asset_info: `_AssetInfo` tuple containing external assets in the `obj`. + """ + with save_context.save_context(options): - return _build_meta_graph_impl(obj, export_dir, signatures, options, - meta_graph_def) + return _build_meta_graph_impl(obj, signatures, options, meta_graph_def) diff --git a/tensorflow/python/saved_model/save_options.py b/tensorflow/python/saved_model/save_options.py index ae4421bf022..f6330848441 100644 --- a/tensorflow/python/saved_model/save_options.py +++ b/tensorflow/python/saved_model/save_options.py @@ -56,13 +56,11 @@ class VariablePolicy(enum.Enum): Distributed variables are still saved as one variable under this policy. EXPAND_DISTRIBUTED_VARIABLES - Distributed variables will be explicitly expanded into their respective - distributed replicas, and their assigned devices will be saved. This is - useful when one wants to use the model for training in environments where - the original distribution strategy is not available. Checkpoints are - currently incompatible with this option, so it is not implemented in - `saved_model.save` (only the internal `saved_model.export_meta_graph` API - supports it for now). + Distributed variables will be saved with information about their components, + allowing for their restoration on load. 
Also, the saved graph will contain + references to those variables. This is useful when one wants to use the + model for training in environments where the original distribution strategy + is not available. """ NONE = None diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py index f3d78881429..de828f7ea1b 100644 --- a/tensorflow/python/saved_model/save_test.py +++ b/tensorflow/python/saved_model/save_test.py @@ -552,10 +552,16 @@ class SaveTest(test.TestCase, parameterized.TestCase): self.assertEmpty(v1.device) @parameterized.named_parameters( - ("_ExpandDistributedVariables", - save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES), - ("_DiscardDistributedVariables", save_options.VariablePolicy.NONE)) - def test_expand_distributed_variables(self, expand_strategy): + ("_ExpandDistributedVariablesWithPolicy", + save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, True), + ("_ExpandDistributedVariablesWithoutPolicy", + save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, False), + ("_DiscardDistributedVariablesWithPolicy", + save_options.VariablePolicy.NONE, True), + ("_DiscardDistributedVariablesWithoutPolicy", + save_options.VariablePolicy.NONE, False)) + def test_expand_distributed_variables(self, expand_strategy, policy): + # 1. Create a context with both CPU:0 and CPU:1. context._reset_context() cpus = context.context().list_physical_devices("CPU") if len(cpus) == 1: @@ -566,8 +572,11 @@ class SaveTest(test.TestCase, parameterized.TestCase): ]) context.ensure_initialized() + # 2. Create and save a model under a mirrored strategy. file_name = os.path.join(self.get_temp_dir(), "saved_model.pb") - with mirrored_strategy.MirroredStrategy(["CPU:0", "CPU:1"]).scope(): + strategy = mirrored_strategy.MirroredStrategy(["CPU:0", "CPU:1"]) + strategy.extended._use_var_policy = policy + with strategy.scope(): root = tracking.AutoTrackable() root.v = variables.Variable([1., 1.], name="v") @@ -582,36 +591,35 @@ class SaveTest(test.TestCase, parameterized.TestCase): filename=file_name, options=save_options.SaveOptions( experimental_variable_policy=expand_strategy)) - graph_def = meta_graph.read_meta_graph_file(file_name).graph_def - v0 = next((n for n in graph_def.node if n.name == "v"), None) - v1 = next((n for n in graph_def.node if n.name == "v/replica_1"), None) - self.assertIsNotNone(v0) + + # 3. Read the output file and test behavior. + meta_graph_def = meta_graph.read_meta_graph_file(file_name) + object_graph = meta_graph_def.object_graph_def + graph_def = meta_graph_def.graph_def + v = next((n.variable + for n in object_graph.nodes + if n.HasField("variable") and n.variable.name == "v"), None) saved_function = next((f for f in graph_def.library.function if "inference_f_" in f.signature.name), None) self.assertIsNotNone(saved_function) if (expand_strategy == save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES): - self.assertIsNotNone(v1) # experimental_save_variable_devices should have been automatically set. 
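A hedged sketch of what the reworked test exercises through the public API; it assumes two (possibly logical) CPU devices are configured, as the test arranges, and that the policy enum is exposed as `tf.saved_model.experimental.VariablePolicy`:

```
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(['CPU:0', 'CPU:1'])
with strategy.scope():
  root = tf.Module()
  root.v = tf.Variable([1.0, 1.0], name='v')

# save() no longer raises NotImplementedError for this policy; the component
# variables and their devices are recorded on the object-graph node for v.
tf.saved_model.save(
    root, '/tmp/mirrored_model',
    options=tf.saved_model.SaveOptions(
        experimental_variable_policy=tf.saved_model.experimental
        .VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES))
```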
+ self.assertIn("CPU:0", v.device) + components = v.experimental_distributed_variable_components + self.assertLen(components, 2) + v0 = next((x for x in components if x.name == "v"), None) + v1 = next((x for x in components if x.name == "v/replica_1"), None) + self.assertIsNotNone(v0) + self.assertIsNotNone(v1) self.assertIn("CPU:0", v0.device) self.assertIn("CPU:1", v1.device) self.assertLen(saved_function.signature.input_arg, 2) else: - self.assertIsNone(v1) - self.assertEmpty(v0.device) + self.assertEmpty(v.device) + self.assertEmpty(v.experimental_distributed_variable_components) self.assertLen(saved_function.signature.input_arg, 1) - def test_expand_distributed_variables_not_allowed(self): - root = tracking.AutoTrackable() - with self.assertRaisesRegex(NotImplementedError, - "not implemented in saved_model.save"): - save.save( - obj=root, - export_dir="", - options=save_options.SaveOptions( - experimental_variable_policy=save_options.VariablePolicy - .EXPAND_DISTRIBUTED_VARIABLES)) - def test_save_uninitialized_variable(self): root = tracking.AutoTrackable() root.uninitialized_variable = resource_variable_ops.UninitializedVariable( @@ -744,7 +752,6 @@ class SavingOptionsTest(test.TestCase): root.f = def_function.function( lambda x: 2. * x, input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) - root.f(constant_op.constant(1.)) save_dir = os.path.join(self.get_temp_dir(), "saved_model") options = save_options.SaveOptions(function_aliases={ "my_func": root.f, diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index fb89f3f54be..36165deeaad 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -314,19 +314,10 @@ static std::string TFE_GetCompilerIr(py::handle& ctx, TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(inputs); - std::vector input_tensors; + std::vector input_handles; for (TFE_TensorHandle* tensor_handle : handles) { AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle); - TensorHandle* th = TensorHandleFromInterface(abstract_tensor_handle); - - const Tensor* t; - Status st = th->Tensor(&t); - if (!st.ok()) { - ThrowValueError( - absl::StrFormat("Could not resolve tensor: '%s'", st.error_message()) - .c_str()); - } - input_tensors.push_back(t); + input_handles.push_back(TensorHandleFromInterface(abstract_tensor_handle)); } DeviceNameUtils::ParsedName input_device_name; @@ -347,7 +338,7 @@ static std::string TFE_GetCompilerIr(py::handle& ctx, xla::StatusOr hlo_text = GetCompilerIr(selected_stage, context->pflr(), concrete_function_name, - *selected_device, input_tensors); + *selected_device, context, input_handles); if (!hlo_text.ok()) { ThrowValueError(absl::StrFormat("Failed getting HLO text: '%s'", diff --git a/tensorflow/python/tools/api/generator/BUILD b/tensorflow/python/tools/api/generator/BUILD index 5e6172edd39..c8d6eded2d4 100644 --- a/tensorflow/python/tools/api/generator/BUILD +++ b/tensorflow/python/tools/api/generator/BUILD @@ -87,6 +87,9 @@ py_test( "//tensorflow/python:framework_combinations", "//tensorflow/python:modules_with_exports", "//tensorflow/python:no_contrib", + "//tensorflow/python/distribute:combinations", + "//tensorflow/python/distribute:multi_process_runner", + "//tensorflow/python/distribute:multi_worker_test_base", "//tensorflow/python/tools/api/generator:create_python_api", ], ) diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl index 48c30ba1c1d..c04ab6900e9 100644 
--- a/tensorflow/python/tools/api/generator/api_init_files.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files.bzl @@ -6,6 +6,9 @@ TENSORFLOW_API_INIT_FILES = [ "__init__.py", "__internal__/__init__.py", "__internal__/decorator/__init__.py", + "__internal__/distribute/__init__.py", + "__internal__/distribute/combinations/__init__.py", + "__internal__/distribute/multi_process_runner/__init__.py", "__internal__/test/__init__.py", "__internal__/test/combinations/__init__.py", "__internal__/tracking/__init__.py", diff --git a/tensorflow/python/tools/api/generator/output_init_files_test.py b/tensorflow/python/tools/api/generator/output_init_files_test.py index fc2ad26dc9c..25035e0567a 100644 --- a/tensorflow/python/tools/api/generator/output_init_files_test.py +++ b/tensorflow/python/tools/api/generator/output_init_files_test.py @@ -25,6 +25,8 @@ import sys from tensorflow import python as _tf_for_api_traversal from tensorflow.lite.python import lite as _tflite_for_api_traversal from tensorflow.python import modules_with_exports +from tensorflow.python.distribute import multi_process_runner +from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.framework import combinations from tensorflow.python.framework import test_combinations # pylint: enable=unused-import diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index bdbdd3499ad..0c8b8f5576b 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -590,6 +590,9 @@ def _create_example_string(example_dict): example.features.feature[feature_name].float_list.value.extend( feature_list) elif isinstance(feature_list[0], str): + example.features.feature[feature_name].bytes_list.value.extend( + [f.encode('utf8') for f in feature_list]) + elif isinstance(feature_list[0], bytes): example.features.feature[feature_name].bytes_list.value.extend( feature_list) elif isinstance(feature_list[0], six.integer_types): diff --git a/tensorflow/python/tools/saved_model_cli_test.py b/tensorflow/python/tools/saved_model_cli_test.py index 84283ec7dd7..2580cbd73ca 100644 --- a/tensorflow/python/tools/saved_model_cli_test.py +++ b/tensorflow/python/tools/saved_model_cli_test.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for SavedModelCLI tool. - -""" +"""Tests for SavedModelCLI tool.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -29,6 +27,7 @@ from absl.testing import parameterized import numpy as np from six import StringIO +from tensorflow.core.example import example_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.debug.wrappers import local_cli_wrapper @@ -177,8 +176,7 @@ signature_def['serving_default']: saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model') dummy_model = DummyModel() # Call with specific values to create new polymorphic function traces. 
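Illustrating the `_create_example_string` change above (grounded in the new test further below): unicode strings are now UTF-8 encoded into the `bytes_list` feature, while `bytes` values continue to pass through unchanged.

```
from tensorflow.python.tools import saved_model_cli

input_examples = 'inputs=[{"text":["foo"], "bytes":[b"bar"]}]'
input_dict = saved_model_cli.preprocess_input_examples_arg_string(
    input_examples)
# input_dict['inputs'][0] is a serialized tf.train.Example in which both the
# "text" (str) and "bytes" (bytes) features end up as bytes_list values.
```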
- dummy_model.func1( - constant_op.constant(5), constant_op.constant(9), True) + dummy_model.func1(constant_op.constant(5), constant_op.constant(9), True) dummy_model(constant_op.constant(5)) save.save(dummy_model, saved_model_dir) self.parser = saved_model_cli.create_parser() @@ -391,6 +389,33 @@ Defined Functions: self.assertTrue(len(input_dict) == 2) self.assertTrue(len(input_expr_dict) == 2) + def testInputPreProcessExamplesWithStrAndBytes(self): + input_examples_str = 'inputs=[{"text":["foo"], "bytes":[b"bar"]}]' + input_dict = saved_model_cli.preprocess_input_examples_arg_string( + input_examples_str) + feature = example_pb2.Example.FromString(input_dict['inputs'][0]) + self.assertProtoEquals( + """ + features { + feature { + key: "bytes" + value { + bytes_list { + value: "bar" + } + } + } + feature { + key: "text" + value { + bytes_list { + value: "foo" + } + } + } + } + """, feature) + def testInputPreProcessFileNames(self): input_str = (r'inputx=C:\Program Files\data.npz[v:0];' r'input:0=c:\PROGRA~1\data.npy') @@ -680,10 +705,11 @@ Defined Functions: def fake_wrapper_session(sess): return sess - with test.mock.patch.object(local_cli_wrapper, - 'LocalCLIDebugWrapperSession', - side_effect=fake_wrapper_session, - autospec=True) as fake: + with test.mock.patch.object( + local_cli_wrapper, + 'LocalCLIDebugWrapperSession', + side_effect=fake_wrapper_session, + autospec=True) as fake: saved_model_cli.run(args) fake.assert_called_with(test.mock.ANY) @@ -720,11 +746,11 @@ Defined Functions: self.parser = saved_model_cli.create_parser() base_path = test.test_src_dir_path(SAVED_MODEL_PATH) output_dir = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir') - args = self.parser.parse_args( - ['aot_compile_cpu', '--dir', base_path, '--tag_set', 'serve', - '--output_prefix', output_dir, - '--cpp_class', 'Compiled', - '--signature_def_key', 'MISSING']) + args = self.parser.parse_args([ + 'aot_compile_cpu', '--dir', base_path, '--tag_set', 'serve', + '--output_prefix', output_dir, '--cpp_class', 'Compiled', + '--signature_def_key', 'MISSING' + ]) with self.assertRaisesRegex(ValueError, 'Unable to find signature_def'): saved_model_cli.aot_compile_cpu(args) @@ -785,9 +811,8 @@ Defined Functions: output_prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir/out') args = self.parser.parse_args([ 'aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve', - '--signature_def_key', 'func', - '--output_prefix', output_prefix, '--variables_to_feed', - variables_to_feed, '--cpp_class', 'Generated' + '--signature_def_key', 'func', '--output_prefix', output_prefix, + '--variables_to_feed', variables_to_feed, '--cpp_class', 'Generated' ]) # Use the default seving signature_key. with test.mock.patch.object(logging, 'warn') as captured_warn: saved_model_cli.aot_compile_cpu(args) @@ -812,8 +837,8 @@ Defined Functions: if func == dummy_model.func_write: # Writeable variables setters do not preserve constness. 
self.assertIn('set_var_param_write_var_data(float', header_contents) - self.assertNotIn( - 'set_var_param_write_var_data(const float', header_contents) + self.assertNotIn('set_var_param_write_var_data(const float', + header_contents) makefile_contents = file_io.read_file_to_string( '{}_makefile.inc'.format(output_prefix)) diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD index 32973ac08b9..0afe2086d93 100644 --- a/tensorflow/python/tpu/BUILD +++ b/tensorflow/python/tpu/BUILD @@ -226,6 +226,7 @@ pytype_library( "//tensorflow/python:dtypes", "//tensorflow/python:framework", "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", "//tensorflow/python:platform_analytics", "//tensorflow/python:tensor_shape", "//tensorflow/python:tpu_ops_gen", @@ -508,6 +509,7 @@ pytype_library( srcs = ["tpu_embedding_v2_utils.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:init_ops_v2", "//tensorflow/python:variable_scope", "//tensorflow/python/distribute:device_util", "//tensorflow/python/distribute:sharded_variable", diff --git a/tensorflow/python/tpu/tensor_tracer.py b/tensorflow/python/tpu/tensor_tracer.py index 0942ce814f0..41da6251ba6 100644 --- a/tensorflow/python/tpu/tensor_tracer.py +++ b/tensorflow/python/tpu/tensor_tracer.py @@ -49,6 +49,7 @@ from tensorflow.python.ops import summary_ops_v2 as summary from tensorflow.python.ops import variable_scope from tensorflow.python.platform import analytics from tensorflow.python.platform import gfile +from tensorflow.python.platform import remote_utils from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary_iterator from tensorflow.python.tpu import tensor_tracer_flags @@ -344,24 +345,6 @@ def keras_layer_tracepoint(layer, checkpoint_name): return layer -def _trace_files_need_precreated(output_dir): - """Return True if trace files must be pre-created by users.""" - - if not output_dir.startswith('/'): - return False - if len(output_dir) < 5: - return False - if output_dir[2] != 'n': - return False - if output_dir[3] != 's': - return False - if output_dir[1] != 'c': - return False - if output_dir[4] != '/': - return False - return True - - class TensorTracer(object): """A software construct for tracing tensor values in a TF graph. @@ -927,7 +910,9 @@ class TensorTracer(object): msg = '"%s"' % tensor_name if self._parameters.trace_dir: - output_path = os.path.join(self._parameters.trace_dir, _TRACE_FILE_NAME) + output_path = os.path.join( + self._parameters.trace_dir, + _TRACE_FILE_NAME + self._get_outfile_suffix()) output_stream = _OUTPUT_STREAM_ESCAPE + output_path else: output_stream = sys.stderr @@ -1229,20 +1214,10 @@ class TensorTracer(object): # Output files are handled by tf.summary operations, no need to precreate # them. 
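With `_trace_files_need_precreated` removed, the remaining trace-directory handling reduces to the logic sketched below; `_ensure_trace_dir` is a hypothetical helper name used here only to condense the in-place code.

```
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import gfile


def _ensure_trace_dir(trace_dir):
  # Per-replica trace files no longer need to be pre-created; the directory
  # is created on demand, and failure to create it is an error.
  if not gfile.Exists(trace_dir):
    file_io.recursive_create_dir(trace_dir)
    if not gfile.Exists(trace_dir):
      raise RuntimeError('Failed to create %s' % trace_dir)
```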
return - if _trace_files_need_precreated(self._parameters.trace_dir): - for replica_id in range(0, self._tt_config.num_replicas): - trace_file_path = os.path.join( - self._parameters.trace_dir, - _COMPACT_TRACE_FILE_PREFIX) + '%d'%replica_id - if not gfile.Exists(trace_file_path): - raise RuntimeError( - '%s must be pre-created with the ' - 'appropriate properties.'%trace_file_path) - else: + if not gfile.Exists(self._parameters.trace_dir): + file_io.recursive_create_dir(self._parameters.trace_dir) if not gfile.Exists(self._parameters.trace_dir): - file_io.recursive_create_dir(self._parameters.trace_dir) - if not gfile.Exists(self._parameters.trace_dir): - raise RuntimeError('Failed to create %s'%self._parameters.trace_dir) + raise RuntimeError('Failed to create %s'%self._parameters.trace_dir) def _create_temp_cache(self, num_traced_tensors, num_signatures): """Creates a temporary cache with the given dimensions. @@ -1317,18 +1292,89 @@ class TensorTracer(object): tensor_tracer_flags.TRACE_MODE_SUMMARY, tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY) - def _generate_flush_cache_op(self, num_replicas, on_tpu): + def _inspect_summary_cache(self, cache, replica_id, step_num, output_stream, + tensor_trace_order): + """Generates a print operation to print trace inspection. + + Args: + cache: Tensor storing the trace results for the step. + replica_id: Tensor storing the replica id of the running core. + step_num: Step number. + output_stream: Where to print the outputs, e.g., file path, or sys.stderr. + tensor_trace_order: TensorTraceOrder object holding tensorname to id map. + + Returns: + The Op to flush the cache to file. + """ + def _inspect_tensor(tensor): + """Returns the text to be printed for inspection output.""" + if (self._parameters.trace_mode == + tensor_tracer_flags.TRACE_MODE_NAN_INF): + return control_flow_ops.cond( + math_ops.greater(tensor, 0.0), + lambda: 'has NaNs/Infs!', + lambda: 'has no NaNs or Infs.') + else: + return tensor + + # Check if the cache includes any nan or inf + if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF: + # Cache has 1s or 0s if the mode is NaN_INF + step_has_nan_or_inf = math_ops.greater(math_ops.reduce_sum(cache), 0.0) + else: + # Cache has the actual numerics for other modes. + step_has_nan_or_inf = math_ops.reduce_any( + gen_math_ops.logical_or( + gen_math_ops.is_nan(cache), gen_math_ops.is_inf(cache))) + + # Summarizing message for each step. + step_error_message = control_flow_ops.cond( + step_has_nan_or_inf, + lambda: 'NaNs or Infs in the step!', + lambda: 'No numerical issues have been found for the step.') + + # No need to print core numbers if the cache is merged already. + if self._parameters.collect_summary_per_core: + stats = ['\n\n', 'core:', replica_id, ',', 'step:', step_num, '-->', + step_error_message, + 'Printing tensors for mode:%s...' % self._parameters.trace_mode] + else: + stats = ['\n\n', 'step:', step_num, '-->', step_error_message, + 'Printing tensors for mode:%s...' 
% self._parameters.trace_mode] + + for tensor_name, cache_idx in sorted( + tensor_trace_order.tensorname_to_cache_idx.items(), + key=lambda item: item[1]): + if self._parameters.collect_summary_per_core: + stats.extend([ + '\n', 'core:', replica_id, ',', 'step:', step_num, ',', + tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])]) + else: + stats.extend([ + '\n', 'step:', step_num, ',', + tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])]) + return logging_ops.print_v2(*stats, summarize=-1, + output_stream=output_stream) + + def _get_outfile_suffix(self): + if remote_utils.is_remote_path(self._parameters.trace_dir): + return remote_utils.get_appendable_file_encoding() + else: + return '' + + def _generate_flush_cache_op(self, num_replicas, on_tpu, tensor_trace_order): """Generates an Op that will flush the cache to file. Args: num_replicas: total number of replicas. on_tpu: if the graph is executed on TPU. + tensor_trace_order: TensorTraceOrder object holding tensorname to id map. Returns: The Op to flush the cache to file. """ - def _flush_fun(cache, replica_id): + def _flush_fun(cache, replica_id, step_num): """Flushes the cache to a file corresponding to replica_id.""" def _f(file_index): @@ -1339,19 +1385,27 @@ class TensorTracer(object): if self._parameters.trace_dir: output_path = (os.path.join(self._parameters.trace_dir, _COMPACT_TRACE_FILE_PREFIX) - + replica_str) + + replica_str + self._get_outfile_suffix()) output_stream = _OUTPUT_STREAM_ESCAPE + output_path else: output_stream = sys.stderr new_step_line = _REPLICA_ID_TAG + replica_str print_ops = [] - for i in range(self._num_signature_dimensions()): - print_ops.append(logging_ops.print_v2( - new_step_line, '\n', - cache[:, i], '\n', - summarize=-1, - output_stream=output_stream)) + if self._parameters.inspect_trace: + if self._num_signature_dimensions() > 1: + raise ValueError('Inspecting multi signatures are not supported.') + print_ops.append(self._inspect_summary_cache( + cache=cache, replica_id=replica_id, step_num=step_num, + output_stream=output_stream, + tensor_trace_order=tensor_trace_order)) + else: + for i in range(self._num_signature_dimensions()): + print_ops.append(logging_ops.print_v2( + new_step_line, '\n', + cache[:, i], '\n', + summarize=-1, + output_stream=output_stream)) with ops.control_dependencies(print_ops): return constant_op.constant(0).op return _print_cache @@ -1388,10 +1442,12 @@ class TensorTracer(object): cache_val = self.merge_caches_on_tpu(cache_val) cache_val = self.aggregate_global_cache(cache_val)[0] - flush_op = tpu.outside_compilation(_flush_fun, - cache_val, self._replica_id) + flush_op = tpu.outside_compilation( + _flush_fun, cache_val, self._replica_id, + array_ops.identity(training_util.get_or_create_global_step())) else: - flush_op = _flush_fun(cache_val, self._replica_id) + flush_op = _flush_fun(cache_val, self._replica_id, + training_util.get_or_create_global_step()) if self._use_temp_cache(): with ops.control_dependencies([flush_op]): return constant_op.constant(0).op @@ -1405,13 +1461,15 @@ class TensorTracer(object): with ops.control_dependencies([assign_op]): return constant_op.constant(0).op - def _flush_tensor_values_cache(self, tensor_fetches, op_fetches, on_tpu): + def _flush_tensor_values_cache(self, tensor_fetches, op_fetches, on_tpu, + tensor_trace_order): """Flushes the intermediate tensor values in the graph to the cache. Args: tensor_fetches: list of tensor results returned by the model_fn. 
op_fetches: list of ops that are returned by the model_fn, e.g., train_op. on_tpu: if the graph is executed on TPU. + tensor_trace_order: TensorTraceOrder object holding tensorname to id map. Returns: An identical copy of tensor_fetches. @@ -1421,7 +1479,7 @@ class TensorTracer(object): with ops.control_dependencies(op_fetches + [tensor.op for tensor in tensor_fetches]): flush_cache_op = self._generate_flush_cache_op( - self._tt_config.num_replicas, on_tpu) + self._tt_config.num_replicas, on_tpu, tensor_trace_order) return control_flow_ops.tuple(tensor_fetches, control_inputs=[flush_cache_op]) @@ -1837,7 +1895,8 @@ class TensorTracer(object): del self._host_call_fn[_TT_HOSTCALL_KEY] else: processed_t_fetches = self._flush_tensor_values_cache( - processed_t_fetches, op_fetches, on_tpu=on_tpu) + processed_t_fetches, op_fetches, on_tpu=on_tpu, + tensor_trace_order=tensor_trace_order) # processed_t_fetches is a list at this point. Convert it to the same # format as given in tensor_fetches. diff --git a/tensorflow/python/tpu/tensor_tracer_flags.py b/tensorflow/python/tpu/tensor_tracer_flags.py index 6d84314185b..ba375737866 100644 --- a/tensorflow/python/tpu/tensor_tracer_flags.py +++ b/tensorflow/python/tpu/tensor_tracer_flags.py @@ -69,6 +69,7 @@ FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS = 'dump_graphs' FLAG_NAME_SUMMARY_SIGNATURES = 'signatures' FLAG_NAME_SUMMARY_PER_CORE = 'collect_summary_per_core' FLAG_NAME_TEMP_CACHE_VAR = 'use_temp_cache' +FLAG_NAME_INSPECT_TRACE = 'inspect_trace' FLAG_NAME_FINGERPRINT_DIR = 'use_fingerprint_subdirectory' _OP_RANGE_PAT = re.compile(r'(\d+):(\d+)') @@ -125,6 +126,7 @@ class TTParameters(object): TRACE_MODE_MAX_ABS, TRACE_MODE_SUMMARY) self.use_temp_cache_var = self.is_flag_on(FLAG_NAME_TEMP_CACHE_VAR) + self.inspect_trace = self.is_flag_on(FLAG_NAME_INSPECT_TRACE) self.use_fingerprint_subdir = self.is_flag_on(FLAG_NAME_FINGERPRINT_DIR) _, self.graph_dump_path = self.get_flag_value( @@ -250,7 +252,8 @@ class TTParameters(object): FLAG_NAME_OP_RANGE, FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS, FLAG_NAME_TRACE_LEVEL, FLAG_NAME_SUMMARY_SIGNATURES, FLAG_NAME_SUMMARY_PER_CORE, - FLAG_NAME_TEMP_CACHE_VAR, FLAG_NAME_FINGERPRINT_DIR + FLAG_NAME_TEMP_CACHE_VAR, FLAG_NAME_FINGERPRINT_DIR, + FLAG_NAME_INSPECT_TRACE ] tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR) if not tensor_tracer_flags: diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py index 9a0b03da744..3a45a579631 100644 --- a/tensorflow/python/tpu/tpu_embedding.py +++ b/tensorflow/python/tpu/tpu_embedding.py @@ -23,6 +23,8 @@ import copy import math import re +from typing import Optional + import six from tensorflow.core.protobuf.tpu import optimization_parameters_pb2 @@ -314,6 +316,12 @@ AdamSlotVariableNames = collections.namedtuple( AdagradSlotVariableName = collections.namedtuple( 'AdagradSlotVariableName', ['accumulator']) +MomentumSlotVariableName = collections.namedtuple('MomentumSlotVariableName', + ['momenta']) + +RMSPropSlotVariableNames = collections.namedtuple('RMSPropSlotVariableNames', + ['ms', 'mom']) + ProximalAdagradSlotVariableName = collections.namedtuple( 'ProximalAdagradSlotVariableName', ['accumulator']) @@ -326,6 +334,12 @@ ProximalYogiSlotVariableNames = collections.namedtuple( AdamSlotVariables = collections.namedtuple( 'AdamSlotVariables', ['m', 'v']) +MomentumSlotVariable = collections.namedtuple('MomentumSlotVariable', + ['momenta']) + +RMSPropSlotVariables = collections.namedtuple('RMSPropSlotVariables', + ['ms', 'mom']) + AdagradSlotVariable = 
collections.namedtuple( 'AdagradSlotVariable', ['accumulator']) @@ -348,9 +362,15 @@ VariablesAndOps = collections.namedtuple( class _OptimizationParameters(object): """Parameters common to all optimizations.""" - def __init__(self, learning_rate, use_gradient_accumulation, clip_weight_min, - clip_weight_max, weight_decay_factor, - multiply_weight_decay_factor_by_learning_rate): + def __init__( + self, + learning_rate: float, + use_gradient_accumulation: bool, + clip_weight_min: Optional[float], + clip_weight_max: Optional[float], + weight_decay_factor: Optional[float], + multiply_weight_decay_factor_by_learning_rate: Optional[bool], + ): self.learning_rate = learning_rate self.use_gradient_accumulation = use_gradient_accumulation self.clip_weight_min = clip_weight_min @@ -380,14 +400,16 @@ class AdagradParameters(_OptimizationParameters): """ - def __init__(self, - learning_rate, - initial_accumulator=0.1, - use_gradient_accumulation=True, - clip_weight_min=None, - clip_weight_max=None, - weight_decay_factor=None, - multiply_weight_decay_factor_by_learning_rate=None): + def __init__( + self, + learning_rate: float, + initial_accumulator: float = 0.1, + use_gradient_accumulation: bool = True, + clip_weight_min: Optional[float] = None, + clip_weight_max: Optional[float] = None, + weight_decay_factor: Optional[float] = None, + multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, + ): """Optimization parameters for Adagrad. Args: @@ -422,16 +444,18 @@ class ProximalAdagradParameters(_OptimizationParameters): for more details. """ - def __init__(self, - learning_rate, - initial_accumulator=0.1, - l1_regularization_strength=0.0, - l2_regularization_strength=0.0, - use_gradient_accumulation=True, - clip_weight_min=None, - clip_weight_max=None, - weight_decay_factor=None, - multiply_weight_decay_factor_by_learning_rate=None): + def __init__( + self, + learning_rate: float, + initial_accumulator: float = 0.1, + l1_regularization_strength: float = 0.0, + l2_regularization_strength: float = 0.0, + use_gradient_accumulation: bool = True, + clip_weight_min: Optional[float] = None, + clip_weight_max: Optional[float] = None, + weight_decay_factor: Optional[float] = None, + multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, + ): """Optimization parameters for Adagrad. Args: @@ -490,18 +514,20 @@ class AdamParameters(_OptimizationParameters): """ - def __init__(self, - learning_rate, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - lazy_adam=True, - sum_inside_sqrt=True, - use_gradient_accumulation=True, - clip_weight_min=None, - clip_weight_max=None, - weight_decay_factor=None, - multiply_weight_decay_factor_by_learning_rate=None): + def __init__( + self, + learning_rate: float, + beta1: float = 0.9, + beta2: float = 0.999, + epsilon: float = 1e-08, + lazy_adam: bool = True, + sum_inside_sqrt: bool = True, + use_gradient_accumulation: bool = True, + clip_weight_min: Optional[float] = None, + clip_weight_max: Optional[float] = None, + weight_decay_factor: Optional[float] = None, + multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, + ): """Optimization parameters for Adam. 
Args: @@ -567,20 +593,22 @@ class FtrlParameters(_OptimizationParameters): """ - def __init__(self, - learning_rate, - learning_rate_power=-0.5, - initial_accumulator_value=0.1, - l1_regularization_strength=0.0, - l2_regularization_strength=0.0, - use_gradient_accumulation=True, - clip_weight_min=None, - clip_weight_max=None, - weight_decay_factor=None, - multiply_weight_decay_factor_by_learning_rate=None, - multiply_linear_by_learning_rate=False, - beta=0, - allow_zero_accumulator=False): + def __init__( + self, + learning_rate: float, + learning_rate_power: float = -0.5, + initial_accumulator_value: float = 0.1, + l1_regularization_strength: float = 0.0, + l2_regularization_strength: float = 0.0, + use_gradient_accumulation: bool = True, + clip_weight_min: Optional[float] = None, + clip_weight_max: Optional[float] = None, + weight_decay_factor: Optional[float] = None, + multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, + multiply_linear_by_learning_rate: bool = False, + beta: float = 0, + allow_zero_accumulator: bool = False, + ): """Optimization parameters for Ftrl. Implements FTRL as described in the following [paper]( @@ -661,19 +689,21 @@ class ProximalYogiParameters(_OptimizationParameters): """ # pylint: enable=line-too-long - def __init__(self, - learning_rate=0.01, - beta1=0.9, - beta2=0.999, - epsilon=1e-3, - l1_regularization_strength=0.0, - l2_regularization_strength=0.0, - initial_accumulator_value=1e-6, - use_gradient_accumulation=True, - clip_weight_min=None, - clip_weight_max=None, - weight_decay_factor=None, - multiply_weight_decay_factor_by_learning_rate=None): + def __init__( + self, + learning_rate: float = 0.01, + beta1: float = 0.9, + beta2: float = 0.999, + epsilon: float = 1e-3, + l1_regularization_strength: float = 0.0, + l2_regularization_strength: float = 0.0, + initial_accumulator_value: float = 1e-6, + use_gradient_accumulation: bool = True, + clip_weight_min: Optional[float] = None, + clip_weight_max: Optional[float] = None, + weight_decay_factor: Optional[float] = None, + multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, + ): """Optimization parameters for Proximal Yogi. Args: @@ -724,6 +754,135 @@ class ProximalYogiParameters(_OptimizationParameters): self.initial_accumulator_value = initial_accumulator_value +class MomentumParameters(_OptimizationParameters): + """Optimization parameters for Momentum with TPU embeddings. + + Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the + `optimization_parameters` argument to set the optimizer and its parameters. + See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` + for more details. + + ``` + estimator = tf.estimator.tpu.TPUEstimator( + ... + embedding_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec( + ... + optimization_parameters=tf.tpu.experimental.MomentumParameters(0.1), + ...)) + ``` + + """ + + def __init__( + self, + learning_rate: float, + momentum: float, + use_nesterov: bool = False, + use_gradient_accumulation: bool = True, + clip_weight_min: Optional[float] = None, + clip_weight_max: Optional[float] = None, + weight_decay_factor: Optional[float] = None, + multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, + ): + """Optimization parameters for momentum. + + Args: + learning_rate: a floating point value. The learning rate. + momentum: A `Tensor` or a floating point value. The momentum. + use_nesterov: If `True` use Nesterov Momentum. See (Sutskever et al., + 2013). 
This implementation always computes gradients at the value of the + variable(s) passed to the optimizer. Using Nesterov Momentum makes the + variable(s) track the values called `theta_t + mu*v_t` in the paper. + This implementation is an approximation of the original formula, valid + for high values of momentum. It will compute the "adjusted gradient" in + NAG by assuming that the new gradient will be estimated by the current + average gradient plus the product of momentum and the change in the + average gradient. + use_gradient_accumulation: setting this to `False` makes embedding + gradients calculation less accurate but faster. Please see + `optimization_parameters.proto` for details. + clip_weight_min: the minimum value to clip by; None means -infinity. + clip_weight_max: the maximum value to clip by; None means +infinity. + weight_decay_factor: amount of weight decay to apply; None means that the + weights are not decayed. + multiply_weight_decay_factor_by_learning_rate: if true, + `weight_decay_factor` is multiplied by the current learning rate. + """ + super(MomentumParameters, self).__init__( + learning_rate=learning_rate, + use_gradient_accumulation=use_gradient_accumulation, + clip_weight_min=clip_weight_min, + clip_weight_max=clip_weight_max, + weight_decay_factor=weight_decay_factor, + multiply_weight_decay_factor_by_learning_rate=( + multiply_weight_decay_factor_by_learning_rate), + ) + self.momentum = momentum + self.use_nesterov = use_nesterov + + +class RMSPropParameters(_OptimizationParameters): + """Optimization parameters for RMSProp with TPU embeddings. + + Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the + `optimization_parameters` argument to set the optimizer and its parameters. + See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` + for more details. + + ``` + estimator = tf.estimator.tpu.TPUEstimator( + ... + embedding_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec( + ... + optimization_parameters=tf.tpu.experimental.MomentumParameters(0.1), + ...)) + ``` + + """ + + def __init__( + self, + learning_rate: float, + rho: float, + momentum: float, + epsilon: float, + use_gradient_accumulation: bool = True, + clip_weight_min: Optional[float] = None, + clip_weight_max: Optional[float] = None, + weight_decay_factor: Optional[float] = None, + multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, + ): + """Optimization parameters for RMS prop. + + Args: + learning_rate: a floating point value. The learning rate. + rho: Discounting factor for the history/coming gradient + momentum: A scalar tensor. + epsilon: Small value to avoid zero denominator. + use_gradient_accumulation: setting this to `False` makes embedding + gradients calculation less accurate but faster. Please see + `optimization_parameters.proto` for details. for details. + clip_weight_min: the minimum value to clip by; None means -infinity. + clip_weight_max: the maximum value to clip by; None means +infinity. + weight_decay_factor: amount of weight decay to apply; None means that the + weights are not decayed. + multiply_weight_decay_factor_by_learning_rate: if true, + `weight_decay_factor` is multiplied by the current learning rate. 
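Putting the two new v1 TPU-embedding optimizer classes together, an illustrative construction (parameter values are arbitrary):

```
from tensorflow.python.tpu import tpu_embedding

momentum_params = tpu_embedding.MomentumParameters(
    learning_rate=0.1, momentum=0.9, use_nesterov=False)
rms_prop_params = tpu_embedding.RMSPropParameters(
    learning_rate=0.1, rho=0.9, momentum=0.0, epsilon=1e-7)
# Either object can be supplied as a table's optimization_parameters;
# _get_optimization_handler (further below) dispatches them to the new
# _MomentumHandler and _RMSPropHandler respectively.
```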
+ """ + super(RMSPropParameters, self).__init__( + learning_rate=learning_rate, + use_gradient_accumulation=use_gradient_accumulation, + clip_weight_min=clip_weight_min, + clip_weight_max=clip_weight_max, + weight_decay_factor=weight_decay_factor, + multiply_weight_decay_factor_by_learning_rate=( + multiply_weight_decay_factor_by_learning_rate), + ) + self.rho = rho + self.momentum = momentum + self.epsilon = epsilon + + @tf_export(v1=['tpu.experimental.StochasticGradientDescentParameters']) class StochasticGradientDescentParameters(_OptimizationParameters): """Optimization parameters for stochastic gradient descent for TPU embeddings. @@ -744,12 +903,14 @@ class StochasticGradientDescentParameters(_OptimizationParameters): """ - def __init__(self, - learning_rate, - clip_weight_min=None, - clip_weight_max=None, - weight_decay_factor=None, - multiply_weight_decay_factor_by_learning_rate=None): + def __init__( + self, + learning_rate: float, + clip_weight_min: Optional[float] = None, + clip_weight_max: Optional[float] = None, + weight_decay_factor: Optional[float] = None, + multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, + ): """Optimization parameters for stochastic gradient descent. Args: @@ -1636,7 +1797,7 @@ def _validate_optimization_parameters(optimization_parameters, else: # Missing global optimization_parameters. if tbl_optimizer_missing: - ValueError('`optimization_parameters` is missing.') + raise ValueError('`optimization_parameters` is missing.') class _OptimizerHandler(object): @@ -2099,6 +2260,173 @@ class _ProximalYogiHandler(_OptimizerHandler): return slot_variables, load_ops_fn, retrieve_ops_fn +class _MomentumHandler(_OptimizerHandler): + """Handles Momentum specific logic.""" + + def set_optimization_parameters(self, table_descriptor): + (table_descriptor.optimization_parameters.momentum.SetInParent()) + table_descriptor.optimization_parameters.momentum.momentum = ( + self._optimization_parameters.momentum) + table_descriptor.optimization_parameters.momentum.use_nesterov = ( + self._optimization_parameters.use_nesterov) + + def get_default_slot_variable_names(self, table): + return MomentumSlotVariableName('{}/{}'.format(table, 'Momentum')) + + def create_variables_and_ops(self, table, slot_variable_names, num_hosts, + table_config, table_variables, config_proto): + + momenta_initializer = init_ops.zeros_initializer() + momenta_variables = _create_partitioned_variables( + name=slot_variable_names.momenta, + num_hosts=num_hosts, + vocabulary_size=table_config.vocabulary_size, + embedding_dimension=table_config.dimension, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + initializer=momenta_initializer) + slot_variables = MomentumSlotVariable(momenta_variables) + + def load_ops_fn(): + """Returns the retrieve ops for Momentum embedding tables. + + Returns: + A list of ops to load embedding and slot variables from CPU to TPU. + """ + load_op_list = [] + config = config_proto + for host_id, table_variable, momenta_variable in (zip( + range(num_hosts), table_variables, momenta_variables)): + with ops.colocate_with(table_variable): + load_parameters_op = tpu_ops.load_tpu_embedding_momentum_parameters( + parameters=table_variable, + momenta=momenta_variable, + table_name=table, + num_shards=num_hosts, + shard_id=host_id, + config=config, + ) + config = None + load_op_list.append(load_parameters_op) + return load_op_list + + def retrieve_ops_fn(): + """Returns the retrieve ops for Momentum embedding tables. 
+ + Returns: + A list of ops to retrieve embedding and slot variables from TPU to CPU. + """ + retrieve_op_list = [] + config = config_proto + for host_id, table_variable, momenta_variable in (zip( + range(num_hosts), table_variables, momenta_variables)): + with ops.colocate_with(table_variable): + retrieved_table, retrieved_momenta = ( + tpu_ops.retrieve_tpu_embedding_momentum_parameters( + table_name=table, + num_shards=num_hosts, + shard_id=host_id, + config=config, + )) + retrieve_parameters_op = control_flow_ops.group( + state_ops.assign(table_variable, retrieved_table), + state_ops.assign(momenta_variable, retrieved_momenta)) + config = None + retrieve_op_list.append(retrieve_parameters_op) + return retrieve_op_list + + return slot_variables, load_ops_fn, retrieve_ops_fn + + +class _RMSPropHandler(_OptimizerHandler): + """Handles RMS prop specific logic.""" + + def set_optimization_parameters(self, table_descriptor): + (table_descriptor.optimization_parameters.rms_prop.SetInParent()) + table_descriptor.optimization_parameters.rms_prop.rho = ( + self._optimization_parameters.rho) + table_descriptor.optimization_parameters.rms_prop.epsilon = ( + self._optimization_parameters.epsilon) + table_descriptor.optimization_parameters.rms_prop.momentum = ( + self._optimization_parameters.momentum) + + def get_default_slot_variable_names(self, table): + return RMSPropSlotVariableNames('{}/{}/ms'.format(table, 'RMSProp'), + '{}/{}/mom'.format(table, 'RMSProp')) + + def create_variables_and_ops(self, table, slot_variable_names, num_hosts, + table_config, table_variables, config_proto): + + ms_variables = _create_partitioned_variables( + name=slot_variable_names.ms, + num_hosts=num_hosts, + vocabulary_size=table_config.vocabulary_size, + embedding_dimension=table_config.dimension, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + initializer=init_ops.zeros_initializer(), + ) + mom_variables = _create_partitioned_variables( + name=slot_variable_names.mom, + num_hosts=num_hosts, + vocabulary_size=table_config.vocabulary_size, + embedding_dimension=table_config.dimension, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + initializer=init_ops.zeros_initializer(), + ) + slot_variables = RMSPropSlotVariables(ms_variables, mom_variables) + + def load_ops_fn(): + """Returns the retrieve ops for RMS Prop embedding tables. + + Returns: + A list of ops to load embedding and slot variables from CPU to TPU. + """ + load_op_list = [] + config = config_proto + for host_id, table_variable, ms_variable, mom_variable in (zip( + range(num_hosts), table_variables, ms_variables, mom_variables)): + with ops.colocate_with(table_variable): + load_parameters_op = tpu_ops.load_tpu_embedding_rms_prop_parameters( + parameters=table_variable, + ms=ms_variable, + mom=mom_variable, + table_name=table, + num_shards=num_hosts, + shard_id=host_id, + config=config, + ) + config = None + load_op_list.append(load_parameters_op) + return load_op_list + + def retrieve_ops_fn(): + """Returns the retrieve ops for RMS Prop embedding tables. + + Returns: + A list of ops to retrieve embedding and slot variables from TPU to CPU. 
+ """ + retrieve_op_list = [] + config = config_proto + for host_id, table_variable, ms_variable, mom_variable in (zip( + range(num_hosts), table_variables, ms_variables, mom_variables)): + with ops.colocate_with(table_variable): + retrieved_table, retrieved_ms, retrieved_mom = ( + tpu_ops.retrieve_tpu_embedding_rms_prop_parameters( + table_name=table, + num_shards=num_hosts, + shard_id=host_id, + config=config, + )) + retrieve_parameters_op = control_flow_ops.group( + state_ops.assign(table_variable, retrieved_table), + state_ops.assign(ms_variable, retrieved_ms), + state_ops.assign(mom_variable, retrieved_mom)) + config = None + retrieve_op_list.append(retrieve_parameters_op) + return retrieve_op_list + + return slot_variables, load_ops_fn, retrieve_ops_fn + + class _StochasticGradientDescentHandler(_OptimizerHandler): """Handles stochastic gradient descent specific logic.""" @@ -2176,6 +2504,10 @@ def _get_optimization_handler(optimization_parameters): return _ProximalYogiHandler(optimization_parameters) elif isinstance(optimization_parameters, StochasticGradientDescentParameters): return _StochasticGradientDescentHandler(optimization_parameters) + elif isinstance(optimization_parameters, MomentumParameters): + return _MomentumHandler(optimization_parameters) + elif isinstance(optimization_parameters, RMSPropParameters): + return _RMSPropHandler(optimization_parameters) return NotImplementedError() diff --git a/tensorflow/python/tpu/tpu_embedding_v2_utils.py b/tensorflow/python/tpu/tpu_embedding_v2_utils.py index 8487581346b..e04f1f0281a 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2_utils.py +++ b/tensorflow/python/tpu/tpu_embedding_v2_utils.py @@ -21,6 +21,7 @@ from __future__ import unicode_literals import abc import math +import typing from typing import Any, Dict, Callable, List, Optional, Text, Tuple, TypeVar, Union import six @@ -620,6 +621,29 @@ class TableConfig(object): self.combiner = combiner self.name = name + def __repr__(self): + # If using the default initializer, just print "None" for clarity. + initializer = self.initializer + + if isinstance(initializer, init_ops_v2.TruncatedNormal): + # PY2 type checking can't infer type of initializer even after if. 
+ initializer = typing.cast(init_ops_v2.TruncatedNormal, initializer) + if (initializer.mean == 0.0 + and math.isclose(initializer.stddev, 1/math.sqrt(self.dim))): # pytype: disable=module-attr (math.isclose not in PY2) + initializer = None + + return ( + "TableConfig(vocabulary_size={vocabulary_size!r}, dim={dim!r}, " + "initializer={initializer!r}, optimizer={optimizer!r}, " + "combiner={combiner!r}, name={name!r})".format( + vocabulary_size=self.vocabulary_size, + dim=self.dim, + initializer=initializer, + optimizer=self.optimizer, + combiner=self.combiner, + name=self.name,) + ) + @tf_export("tpu.experimental.embedding.FeatureConfig") class FeatureConfig(object): @@ -697,3 +721,13 @@ class FeatureConfig(object): self.table = table self.max_sequence_length = max_sequence_length self.name = name + + def __repr__(self): + return ( + "FeatureConfig(table={table!r}, " + "max_sequence_length={max_sequence_length!r}, name={name!r})" + .format( + table=self.table, + max_sequence_length=self.max_sequence_length, + name=self.name) + ) diff --git a/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py b/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py index 14dfb32e075..48797b00009 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py +++ b/tensorflow/python/tpu/tpu_embedding_v2_utils_test.py @@ -60,6 +60,34 @@ class TPUEmbeddingOptimizerTest(parameterized.TestCase, test.TestCase): self.assertEqual(1., opt.clip_gradient_max) +class ConfigTest(test.TestCase): + + def test_table_config_repr(self): + table = tpu_embedding_v2_utils.TableConfig( + vocabulary_size=2, dim=4, initializer=None, + combiner='sum', name='table') + + self.assertEqual( + repr(table), + 'TableConfig(vocabulary_size=2, dim=4, initializer=None, ' + 'optimizer=None, combiner=\'sum\', name=\'table\')') + + def test_feature_config_repr(self): + table = tpu_embedding_v2_utils.TableConfig( + vocabulary_size=2, dim=4, initializer=None, + combiner='sum', name='table') + + feature_config = tpu_embedding_v2_utils.FeatureConfig( + table=table, name='feature') + + self.assertEqual( + repr(feature_config), + 'FeatureConfig(table=TableConfig(vocabulary_size=2, dim=4, ' + 'initializer=None, optimizer=None, combiner=\'sum\', name=\'table\'), ' + 'max_sequence_length=0, name=\'feature\')' + ) + + if __name__ == '__main__': v2_compat.enable_v2_behavior() test.main() diff --git a/tensorflow/python/tpu/tpu_outside_compilation_test.py b/tensorflow/python/tpu/tpu_outside_compilation_test.py index 2c68dee61ed..4eb6429f3c8 100644 --- a/tensorflow/python/tpu/tpu_outside_compilation_test.py +++ b/tensorflow/python/tpu/tpu_outside_compilation_test.py @@ -24,6 +24,8 @@ import tempfile from absl.testing import parameterized import numpy as np +from tensorboard.plugins.histogram import summary_v2 as histogram_summary_v2 +from tensorboard.plugins.scalar import summary_v2 as scalar_summary_v2 from tensorflow.core.util import event_pb2 from tensorflow.python.distribute import tpu_strategy as tpu_lib from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver @@ -291,7 +293,7 @@ class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase): strategy = get_tpu_strategy() def host_computation(x): - summary.scalar("x", x, step=0) + scalar_summary_v2.scalar("x", x, step=0) return x * 2.0 @def_function.function @@ -317,7 +319,7 @@ class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase): strategy = get_tpu_strategy() def host_computation(x): - summary.scalar("x", x, step=0) + scalar_summary_v2.scalar("x", x, 
step=0) return x * 2.0 @def_function.function @@ -352,7 +354,7 @@ class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase): strategy = get_tpu_strategy() def host_computation(x): - summary.scalar("x", x, step=0) + scalar_summary_v2.scalar("x", x, step=0) return x * 2.0 @def_function.function @@ -490,7 +492,36 @@ class OutsideCompilationOnUnsupportedOpTest(test.TestCase, strategy = get_tpu_strategy() def host_computation(x): - summary.scalar("x", x, step=0) + scalar_summary_v2.scalar("x", x, step=0) + return x * 2.0 + + @def_function.function + def step(): + + def computation(x): + x = x + 1.0 + y = host_computation(x) + return y + 1.0 + + return strategy.run(computation, args=(2.0,)) + + logdir = tempfile.mkdtemp() + summary_writer = summary.create_file_writer(logdir, flush_millis=10000) + with summary_writer.as_default(), summary.always_record_summaries(): + self.assertAllEqual( + strategy.experimental_local_results(step()), + constant_op.constant(7., shape=(strategy.num_replicas_in_sync))) + events = _events_from_logdir(self, logdir) + # There will be 2 entries: 1 summary file header entry, and 1 entry + # written by host. + self.assertLen(events, 2) + self.assertEqual(events[1].summary.value[0].tag, "x") + + def testHistogramSummaryWithAutoOutsideCompilation(self): + strategy = get_tpu_strategy() + + def host_computation(x): + histogram_summary_v2.histogram("x", x, step=0) return x * 2.0 @def_function.function @@ -514,7 +545,6 @@ class OutsideCompilationOnUnsupportedOpTest(test.TestCase, # written by host. self.assertLen(events, 2) self.assertEqual(events[1].summary.value[0].tag, "x") - self.assertEqual(events[1].summary.value[0].simple_value, 3.0) @parameterized.parameters((True), (False)) def testSummaryControlFlowIfWithAutoOutsideCompilation( @@ -527,7 +557,7 @@ class OutsideCompilationOnUnsupportedOpTest(test.TestCase, def computation(x): x = x + 1.0 if x < 5: - summary.scalar("x", x, step=0) + scalar_summary_v2.scalar("x", x, step=0) x = x * 2.0 return x + 1.0 @@ -553,7 +583,6 @@ class OutsideCompilationOnUnsupportedOpTest(test.TestCase, # self.assertLen(events, 2) self.assertEqual(events[1].summary.value[0].tag, "cond/x") - self.assertEqual(events[1].summary.value[0].simple_value, 3.0) @test_util.disable_mlir_bridge( "TODO(b/168493455): Reenable this test once deadlock resolved." diff --git a/tensorflow/python/tpu/tpu_strategy_util.py b/tensorflow/python/tpu/tpu_strategy_util.py index d4ba15ee2db..b52d22c60dc 100644 --- a/tensorflow/python/tpu/tpu_strategy_util.py +++ b/tensorflow/python/tpu/tpu_strategy_util.py @@ -45,7 +45,8 @@ def initialize_tpu_system(cluster_resolver=None): cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver, which provides information about the TPU cluster. Returns: - The tf.tpu.Topology object for the topology of the TPU cluster. + The tf.tpu.Topology object for the topology of the TPU cluster. If called + inside tf.function, it returns the serialized topology object instead. Raises: RuntimeError: If running inside a tf.function. @@ -72,17 +73,17 @@ def initialize_tpu_system(cluster_resolver=None): logging.info("Initializing the TPU system: %s", tpu_name) - if context.executing_eagerly(): - # This function looks as it is for the following non-intuitive reasons. - # tpu.initialize_system creates a dummy op whose sole purpose is to trigger - # DistributedTPURewritePass. This pass actually adds real ops that - # initialize the TPU system. Thus, we can't simply run tpu.initialize_system - # eagerly. 
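This `tpu_strategy_util.py` change lets `initialize_tpu_system` run inside a `tf.function`, returning the serialized topology proto instead of raising; the `tf.tpu.Topology` object itself can still only be built eagerly. A minimal usage sketch of that documented behavior, assuming a reachable TPU worker (the resolver name `"my-tpu"` is a placeholder and real TPU hardware is required to run it):

```python
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="my-tpu")

# Eager call: returns a constructed tf.tpu.experimental.Topology object.
topology = tf.tpu.experimental.initialize_tpu_system(resolver)

# Inside tf.function: per this change, the call no longer raises; it returns
# the serialized topology tensor, to be deserialized later in eager mode.
@tf.function
def init_in_graph():
  return tf.tpu.experimental.initialize_tpu_system(resolver)
```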
We need to wrap it in defun and trigger the rewrite passes on it. - if tpu_name not in _LOCAL_MASTERS: - # Explicitly place the tpu.initialize_system in the first worker to - # avoid the output node match multiple devices error. - job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name()) + # This function looks as it is for the following non-intuitive reasons. + # tpu.initialize_system creates a dummy op whose sole purpose is to trigger + # DistributedTPURewritePass. This pass actually adds real ops that + # initialize the TPU system. Thus, we can't simply run tpu.initialize_system + # eagerly. We need to wrap it in defun and trigger the rewrite passes on it. + if tpu_name not in _LOCAL_MASTERS: + # Explicitly place the tpu.initialize_system in the first worker to + # avoid the output node match multiple devices error. + job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name()) + if context.executing_eagerly(): @function.defun def _tpu_init_fn(): # In TF1, we usually close chips when compilation fails to clear the data @@ -121,8 +122,13 @@ def initialize_tpu_system(cluster_resolver=None): with session_lib.Session(config=session_config, target=master) as sess: serialized_topology = sess.run(tpu.initialize_system()) else: - raise RuntimeError("initialize_tpu_system is not supported within " - "tf.functions.") + with ops.device(tpu._tpu_system_device_name(job)): # pylint: disable=protected-access + serialized_topology = tpu.initialize_system( + job=job, compilation_failure_closes_chips=False) + # If initialize_tpu_system is called inside tf.function, we only return + # the serialized topology object as the tf.tpu.Topology object has to be + # constructed in eager mode. + return serialized_topology logging.info("Finished initializing TPU system.") tpu_topology = topology.Topology(serialized=serialized_topology) diff --git a/tensorflow/python/training/BUILD b/tensorflow/python/training/BUILD index 0cc1c45af4f..cf2d89b0d1f 100644 --- a/tensorflow/python/training/BUILD +++ b/tensorflow/python/training/BUILD @@ -1236,7 +1236,6 @@ cuda_py_test( grpc_enabled = True, main = "session_manager_test.py", python_version = "PY3", - tfrt_enabled = True, deps = [ ":checkpoint_management", ":saver", diff --git a/tensorflow/python/training/ftrl.py b/tensorflow/python/training/ftrl.py index 6c8a6ceadc5..21cc0d15ad5 100644 --- a/tensorflow/python/training/ftrl.py +++ b/tensorflow/python/training/ftrl.py @@ -149,7 +149,7 @@ class FtrlOptimizer(optimizer.Optimizer): # TensorFlow ops do not need to include that parameter. self._adjusted_l2_regularization_strength_tensor = ops.convert_to_tensor( self._l2_regularization_strength + self._beta / - (2. * self._learning_rate), + (2. 
* math_ops.maximum(self._learning_rate, 1e-36)), name="adjusted_l2_regularization_strength") assert self._adjusted_l2_regularization_strength_tensor is not None self._beta_tensor = ops.convert_to_tensor(self._beta, name="beta") diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 9f28264f52d..65f6ab67915 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -26,7 +26,6 @@ from __future__ import print_function import collections import os.path import time -import uuid import numpy as np from tensorflow.core.protobuf import meta_graph_pb2 @@ -255,7 +254,7 @@ class BaseSaverBuilder(object): _SHARDED_SUFFIX = array_ops.where( string_ops.regex_full_match(checkpoint_prefix, "^s3://.*"), constant_op.constant(".part"), - constant_op.constant("_temp_%s/part" % uuid.uuid4().hex)) + constant_op.constant("_temp/part")) tmp_checkpoint_prefix = string_ops.string_join( [checkpoint_prefix, _SHARDED_SUFFIX]) diff --git a/tensorflow/python/training/saving/functional_saver.py b/tensorflow/python/training/saving/functional_saver.py index 1fa0da0f33c..62b8b72ce2a 100644 --- a/tensorflow/python/training/saving/functional_saver.py +++ b/tensorflow/python/training/saving/functional_saver.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import uuid - from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import def_function @@ -252,7 +250,7 @@ class MultiDeviceSaver(object): sharded_suffix = array_ops.where( string_ops.regex_full_match(file_prefix, "^s3://.*"), constant_op.constant(".part"), - constant_op.constant("_temp_%s/part" % uuid.uuid4().hex)) + constant_op.constant("_temp/part")) tmp_checkpoint_prefix = string_ops.string_join( [file_prefix, sharded_suffix]) diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index ff2082cf819..1a547b89d35 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -587,6 +587,21 @@ def map_structure(func, *structure, **kwargs): `structure[i]`. All structures in `structure` must have the same arity, and the return value will contain results with the same structure layout. + Examples: + + 1. A single Python dict: + + >>> a = {"hello": 24, "world": 76} + >>> tf.nest.map_structure(lambda p: p * 2, a) + {'hello': 48, 'world': 152} + + 2. Multiple Python dictionaries: + + >>> d1 = {"hello": 24, "world": 76} + >>> d2 = {"hello": 36, "world": 14} + >>> tf.nest.map_structure(lambda p1, p2: p1 + p2, d1, d2) + {'hello': 60, 'world': 90} + Args: func: A callable that accepts as many arguments as there are structures. 
*structure: scalar, or tuple or dict or list of constructed scalars and/or diff --git a/tensorflow/stream_executor/BUILD b/tensorflow/stream_executor/BUILD index 33ef97f6712..84f592439ea 100644 --- a/tensorflow/stream_executor/BUILD +++ b/tensorflow/stream_executor/BUILD @@ -681,8 +681,8 @@ cc_library( ":device_memory_allocator", ":platform", ":stream_executor_headers", - "//tensorflow/core:allocator", "//tensorflow/core:lib", + "//tensorflow/core/framework:allocator", "//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/platform", "@com_google_absl//absl/base:core_headers", diff --git a/tensorflow/stream_executor/cuda/BUILD b/tensorflow/stream_executor/cuda/BUILD index ea65d7aee5c..caab6223fda 100644 --- a/tensorflow/stream_executor/cuda/BUILD +++ b/tensorflow/stream_executor/cuda/BUILD @@ -106,6 +106,10 @@ cc_library( # an intermediate target. cc_library(name = "ptxas_wrapper") +# Buildozer can not remove dependencies inside select guards, so we have to use +# an intermediate target. +cc_library(name = "fatbinary_wrapper") + cc_library( name = "cuda_driver", srcs = if_cuda_is_configured(["cuda_driver.cc"]), @@ -632,9 +636,9 @@ tf_cuda_cc_test( deps = [ ":cuda_activation", ":cuda_gpu_executor", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", "//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor:event", diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index ab5b04b809d..a7ed5cedb4f 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -318,6 +318,16 @@ port::Status CudnnSupport::Init() { return port::Status(port::error::INTERNAL, error); } + // Preload sub libs for cudnn 8.0.4+ +#if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4) + cudnnOpsInferVersionCheck(); + cudnnOpsTrainVersionCheck(); + cudnnCnnInferVersionCheck(); + cudnnCnnTrainVersionCheck(); + cudnnAdvInferVersionCheck(); + cudnnAdvTrainVersionCheck(); +#endif + cudnn_.reset(new CudnnAccess(cudnn_handle)); return port::Status::OK(); } diff --git a/tensorflow/stream_executor/cuda/cudnn_8_0.inc b/tensorflow/stream_executor/cuda/cudnn_8_0.inc index e6d7eb25408..9161dbc8cf9 100644 --- a/tensorflow/stream_executor/cuda/cudnn_8_0.inc +++ b/tensorflow/stream_executor/cuda/cudnn_8_0.inc @@ -55,6 +55,23 @@ cudnnStatus_t CUDNNWINAPI cudnnDestroy(cudnnHandle_t handle) { return func_ptr(handle); } + +#if CUDNN_MAJOR>=8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4) +cudnnStatus_t CUDNNWINAPI cudnnCnnInferVersionCheck(void) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(); + static auto func_ptr = LoadSymbol("cudnnCnnInferVersionCheck"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +cudnnStatus_t CUDNNWINAPI cudnnCnnTrainVersionCheck(void) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(); + static auto func_ptr = LoadSymbol("cudnnCnnTrainVersionCheck"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} +#endif + cudnnStatus_t CUDNNWINAPI cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId) { using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t); diff --git a/tensorflow/stream_executor/gpu/BUILD b/tensorflow/stream_executor/gpu/BUILD index 6328fa44a74..258142f5d83 100644 --- a/tensorflow/stream_executor/gpu/BUILD 
+++ b/tensorflow/stream_executor/gpu/BUILD @@ -10,7 +10,7 @@ load( "//tensorflow/core/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) -load("//tensorflow:tensorflow.bzl", "if_tpu", "tf_copts") +load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load( "//tensorflow/core/platform:rules_cc.bzl", @@ -70,7 +70,7 @@ cc_library( "//tensorflow/stream_executor:device_options", "//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/platform", - ] + if_tpu( + ] + if_libtpu( if_false = ["@local_config_cuda//cuda:cuda_headers"], if_true = [], ), @@ -240,8 +240,8 @@ cc_library( ":gpu_driver_header", ":gpu_helpers_header", "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", - "//tensorflow/core:cuda_libdevice_path", + "//tensorflow/core/platform:regexp", + "//tensorflow/core/platform:cuda_libdevice_path", "//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/platform", "@com_google_absl//absl/strings:str_format", @@ -251,6 +251,7 @@ cc_library( ]) + if_cuda_is_configured([ "//tensorflow/stream_executor/cuda:cuda_driver", "//tensorflow/stream_executor/cuda:ptxas_wrapper", + "//tensorflow/stream_executor/cuda:fatbinary_wrapper", ]), ) @@ -271,9 +272,9 @@ cc_library( "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", - "//tensorflow/core:allocator", + "//tensorflow/core/framework:allocator", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory", "//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor:stream_executor_headers", diff --git a/tensorflow/stream_executor/gpu/asm_compiler.cc b/tensorflow/stream_executor/gpu/asm_compiler.cc index 0f6fd4de910..53f76503f2a 100644 --- a/tensorflow/stream_executor/gpu/asm_compiler.cc +++ b/tensorflow/stream_executor/gpu/asm_compiler.cc @@ -140,34 +140,44 @@ port::StatusOr> CompileGpuAsm(int device_ordinal, return CompileGpuAsm(cc_major, cc_minor, ptx_contents, options); } -port::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, - const char* ptx_contents, - GpuAsmOpts options) { - std::string ptxas_path; - auto env = tensorflow::Env::Default(); - std::string ptxas_binary_name = "ptxas"; +static std::string findCudaExecutable(const std::string binary_name, + const std::string preferred_cuda_dir) { #if defined(PLATFORM_WINDOWS) - ptxas_binary_name += ".exe"; + const std::string binary_filename = binary_name + ".exe"; +#else + const std::string& binary_filename = binary_name; #endif + // Search in cuda root candidates. + auto env = tensorflow::Env::Default(); + std::string binary_path; for (const std::string& cuda_root : - tensorflow::CandidateCudaRoots(options.preferred_cuda_dir)) { - ptxas_path = tensorflow::io::JoinPath(cuda_root, "bin", ptxas_binary_name); - VLOG(2) << "Looking for ptxas at " << ptxas_path; - if (env->FileExists(ptxas_path).ok()) { + tensorflow::CandidateCudaRoots(preferred_cuda_dir)) { + binary_path = tensorflow::io::JoinPath(cuda_root, "bin", binary_filename); + VLOG(2) << "Looking for " << binary_filename << " at " << binary_path; + if (env->FileExists(binary_path).ok()) { break; } } - if (!env->FileExists(ptxas_path).ok()) { + if (!env->FileExists(binary_path).ok()) { // Rely on subprocess invocation to find the correct binary. 
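The new `findCudaExecutable` helper factors the ptxas lookup into a reusable search: append `.exe` on Windows, probe `bin/<binary>` under each candidate CUDA root (seeded by `preferred_cuda_dir`), and otherwise fall back to the bare name so the later subprocess invocation resolves it from `PATH`. A rough Python rendering of that lookup, not the TensorFlow API itself; `candidate_cuda_roots` is a stand-in for `tensorflow::CandidateCudaRoots`:

```python
import os
import sys

def find_cuda_executable(binary_name, preferred_cuda_dir, candidate_cuda_roots):
    # Windows builds look for "<name>.exe"; other platforms use the bare name.
    filename = binary_name + ".exe" if sys.platform == "win32" else binary_name
    for cuda_root in candidate_cuda_roots(preferred_cuda_dir):
        path = os.path.join(cuda_root, "bin", filename)
        if os.path.exists(path):
            return path
    # Not found under any candidate root: return the bare name and rely on
    # the subprocess invocation to find the correct binary on PATH.
    return filename
```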
- ptxas_path = ptxas_binary_name; + binary_path = binary_filename; } - VLOG(2) << "Using ptxas at " << ptxas_path; + VLOG(2) << "Using " << binary_filename << " at " << binary_path; + return binary_path; +} + +port::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, + const char* ptx_contents, + GpuAsmOpts options) { + std::string ptxas_path = + findCudaExecutable("ptxas", options.preferred_cuda_dir); WarnIfBadPtxasVersion(ptxas_path); // Write ptx into a temporary file. std::string ptx_path; + auto env = tensorflow::Env::Default(); if (!env->LocalTempFilename(&ptx_path)) { return port::InternalError("couldn't get temp PTX file name"); } @@ -232,4 +242,78 @@ port::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, return cubin_vector; } +port::StatusOr> BundleGpuAsm( + std::vector images, const std::string preferred_cuda_dir) { + std::string fatbinary_path = + findCudaExecutable("fatbinary", preferred_cuda_dir); + + // Write images to temporary files. + std::vector image_paths; + auto env = tensorflow::Env::Default(); + for (const CubinOrPTXImage& img : images) { + std::string img_path; + if (!env->LocalTempFilename(&img_path)) { + return port::InternalError( + "Could not get temporary filenames for images."); + } + TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile( + env, img_path, std::string(img.bytes.begin(), img.bytes.end()))); + VLOG(2) << "image written to " << img_path; + image_paths.push_back(std::move(img_path)); + } + auto image_files_cleaner = tensorflow::gtl::MakeCleanup([&image_paths] { + for (const auto& path : image_paths) { + TF_CHECK_OK(tensorflow::Env::Default()->DeleteFile(path)); + } + }); + + // Prepare temporary result file. + std::string result_path; + if (!env->LocalTempFilename(&result_path)) { + return port::InternalError( + "Could not get temporary filename for fatbin result."); + } + auto result_file_cleaner = tensorflow::gtl::MakeCleanup([&result_path] { + // This file may never be created, so the failure to delete it should not + // propagate to TF. + tensorflow::Env::Default()->DeleteFile(result_path).IgnoreError(); + }); + + // Invoke fatbinary and collect its output. + tensorflow::SubProcess fatbinary; + std::vector fatbinary_args = { + fatbinary_path, "--64", "--cmdline=--compile-only", + "--link", "--compress-all", absl::StrCat("--create=", result_path)}; + assert(images.size() == image_paths.size()); + for (int i = 0; i < images.size(); i++) { + fatbinary_args.push_back(absl::StrFormat( + "--image=profile=%s,file=%s", images[i].profile, image_paths[i])); + } + if (VLOG_IS_ON(3)) { + VLOG(3) << absl::StrJoin(fatbinary_args, " "); + } + fatbinary.SetProgram(fatbinary_path, fatbinary_args); + fatbinary.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE); + if (!fatbinary.Start()) { + return port::InternalError("Failed to launch fatbinary."); + } + std::string stderr_output; + int exit_status = fatbinary.Communicate( + /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output); + if (exit_status != 0) { + return port::InternalError(absl::StrFormat( + "fatbinary exited with non-zero error code %d, output: %s", exit_status, + stderr_output)); + } + if (!stderr_output.empty()) { + VLOG(2) << stderr_output; + } + + // Read in the result and return it as a byte vector. 
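For context, the `BundleGpuAsm` routine above simply shells out to the CUDA `fatbinary` tool with the argument list it assembles. A hedged sketch of an equivalent invocation from Python; the flags mirror the ones in the hunk, while `fatbinary_path`, `result_path`, and the `(profile, file)` pairs are illustrative inputs:

```python
import subprocess

def bundle_images(fatbinary_path, images, result_path):
    """images: list of (profile, file_path) pairs already written to disk."""
    args = [fatbinary_path, "--64", "--cmdline=--compile-only",
            "--link", "--compress-all", "--create=" + result_path]
    args += ["--image=profile=%s,file=%s" % (profile, path)
             for profile, path in images]
    # fatbinary writes the bundled fatbin to result_path; a non-zero exit
    # status raises CalledProcessError with the captured stderr attached.
    subprocess.run(args, check=True, capture_output=True)
    with open(result_path, "rb") as f:
        return f.read()
```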
+ std::string result_blob; + TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(), + result_path, &result_blob)); + return std::vector(result_blob.begin(), result_blob.end()); +} + } // namespace stream_executor diff --git a/tensorflow/stream_executor/gpu/asm_compiler.h b/tensorflow/stream_executor/gpu/asm_compiler.h index e5f67a71242..513ac6ca867 100644 --- a/tensorflow/stream_executor/gpu/asm_compiler.h +++ b/tensorflow/stream_executor/gpu/asm_compiler.h @@ -52,6 +52,16 @@ port::StatusOr> CompileGpuAsm(int cc_major, int cc_minor, port::StatusOr> CompileGpuAsmOrGetCached( int device_ordinal, const char* ptx, GpuAsmOpts compilation_options); +struct CubinOrPTXImage { + std::string profile; + std::vector bytes; +}; + +// Bundles the GPU machine code (cubins) and PTX if requested and returns the +// resulting binary (i.e. a fatbin) as a byte array. +port::StatusOr> BundleGpuAsm( + std::vector images, const std::string preferred_cuda_dir); + } // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_GPU_ASM_COMPILER_H_ diff --git a/tensorflow/stream_executor/platform/default/dso_loader.cc b/tensorflow/stream_executor/platform/default/dso_loader.cc index 84293b7767a..a78c738f32c 100644 --- a/tensorflow/stream_executor/platform/default/dso_loader.cc +++ b/tensorflow/stream_executor/platform/default/dso_loader.cc @@ -31,6 +31,7 @@ namespace internal { namespace { string GetCudaVersion() { return TF_CUDA_VERSION; } +string GetCudaRtVersion() { return TF_CUDART_VERSION; } string GetCudnnVersion() { return TF_CUDNN_VERSION; } string GetCublasVersion() { return TF_CUBLAS_VERSION; } string GetCusolverVersion() { return TF_CUSOLVER_VERSION; } @@ -77,7 +78,7 @@ port::StatusOr GetCudaDriverDsoHandle() { } port::StatusOr GetCudaRuntimeDsoHandle() { - return GetDsoHandle("cudart", GetCudaVersion()); + return GetDsoHandle("cudart", GetCudaRtVersion()); } port::StatusOr GetCublasDsoHandle() { diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index 89743ab5116..c32793e3e83 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -744,6 +744,7 @@ bool StreamExecutor::AllocateStream(Stream *stream) { return false; } + RegisterStream(stream); return true; } @@ -751,6 +752,7 @@ void StreamExecutor::DeallocateStream(Stream *stream) { implementation_->DeallocateStream(stream); CHECK_GE(live_stream_count_.fetch_sub(1), 0) << "live stream count should not dip below zero"; + UnregisterStream(stream); } bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) { diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index 61a29c2c093..f19c76c3790 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -513,6 +513,24 @@ class StreamExecutor { // allocation. StreamExecutorMemoryAllocator *GetAllocator() { return &allocator_; } + // Block host until all streams associated with this stream executor have + // finished all of enqueued work. 
+ port::Status BlockHostUntilAllStreamsAreDone() { + std::vector streams; + { + absl::MutexLock lock(&mu_); + for (Stream *stream : streams_) { + streams.push_back(stream); + } + } + + for (Stream *stream : streams) { + TF_RETURN_IF_ERROR(BlockHostUntilDone(stream)); + } + + return port::Status::OK(); + } + private: template @@ -642,6 +660,16 @@ class StreamExecutor { template void SubmitTrace(TraceCallT trace_call, ArgsT &&...args); + void RegisterStream(Stream *stream) { + absl::MutexLock lock(&mu_); + streams_.insert(stream); + } + + void UnregisterStream(Stream *stream) { + absl::MutexLock lock(&mu_); + streams_.erase(stream); + } + // Reader/writer lock for class-static StreamExecutor members. static absl::Mutex static_mu_; @@ -732,6 +760,9 @@ class StreamExecutor { StreamExecutorMemoryAllocator allocator_; + // Set of streams associated with this stream executor. + std::set streams_ TF_GUARDED_BY(mu_); + SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor); }; diff --git a/tensorflow/stream_executor/tpu/BUILD b/tensorflow/stream_executor/tpu/BUILD index df0867042f2..98d59726b60 100644 --- a/tensorflow/stream_executor/tpu/BUILD +++ b/tensorflow/stream_executor/tpu/BUILD @@ -3,7 +3,10 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( - default_visibility = ["//tensorflow/core/tpu:__subpackages__"], + default_visibility = [ + "//learning/brain/experimental/dtensor:__subpackages__", + "//tensorflow/core/tpu:__subpackages__", + ], licenses = ["notice"], # Apache 2.0 ) @@ -222,6 +225,7 @@ cc_library( cc_library( name = "tpu_transfer_manager", srcs = ["tpu_transfer_manager_registration.cc"], + visibility = ["//visibility:public"], deps = [ ":tpu_executor", ":tpu_transfer_manager_base", @@ -256,6 +260,7 @@ cc_library( name = "tpu_computation_placer", srcs = ["tpu_computation_placer.cc"], hdrs = ["tpu_computation_placer.h"], + visibility = ["//visibility:public"], deps = [ ":status_helper", ":tpu_executor", diff --git a/tensorflow/stream_executor/tpu/c_api_conversions.cc b/tensorflow/stream_executor/tpu/c_api_conversions.cc index 674a1fdfb68..0a7801f45fc 100644 --- a/tensorflow/stream_executor/tpu/c_api_conversions.cc +++ b/tensorflow/stream_executor/tpu/c_api_conversions.cc @@ -23,7 +23,6 @@ limitations under the License. 
namespace ApiConverter { xla::ShapedBuffer FromC(XLA_ShapedBuffer* c_buffer) { - xla::Shape xla_on_host_shape = ApiConverter::FromC(&c_buffer->on_host_shape); xla::Shape xla_on_device_shape = ApiConverter::FromC(&c_buffer->on_device_shape); @@ -36,7 +35,7 @@ xla::ShapedBuffer FromC(XLA_ShapedBuffer* c_buffer) { } xla::ShapedBuffer xla_shaped_buffer( - xla_on_host_shape, xla_on_device_shape, + xla_on_device_shape, tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform(), c_buffer->device_ordinal); xla_shaped_buffer.set_buffers(xla_shape_tree); @@ -199,7 +198,6 @@ xla::MutableBorrowingLiteral FromC(XLA_Literal* c_literal) { } void ToC(const xla::ShapedBuffer& buffer, XLA_ShapedBuffer* c_device_buffer) { - ApiConverter::ToC(buffer.on_host_shape(), &c_device_buffer->on_host_shape); ApiConverter::ToC(buffer.on_device_shape(), &c_device_buffer->on_device_shape); c_device_buffer->device_ordinal = buffer.device_ordinal(); @@ -226,7 +224,6 @@ void Free(XLA_Literal* c_literal) { void Free(XLA_ShapedBuffer* c_buffer) { ApiConverter::Free(&c_buffer->on_device_shape); - ApiConverter::Free(&c_buffer->on_host_shape); delete[] c_buffer->bases; } diff --git a/tensorflow/stream_executor/tpu/c_api_decl.h b/tensorflow/stream_executor/tpu/c_api_decl.h index 7953670dec7..dcb53823e0c 100644 --- a/tensorflow/stream_executor/tpu/c_api_decl.h +++ b/tensorflow/stream_executor/tpu/c_api_decl.h @@ -177,7 +177,6 @@ typedef struct XLA_Shape { // Represents a leaf node for a XLA shaped buffer. typedef struct XLA_ShapedBuffer { - XLA_Shape on_host_shape; XLA_Shape on_device_shape; int device_ordinal; @@ -208,7 +207,6 @@ typedef struct SE_ExecutionInput { XLA_ShapeIndex* unowned_indices; int unowned_indices_size; XLA_Shape dynamic_shape; - XLA_Shape host_shape; } SE_ExecutionInput; typedef struct SE_ExecutionOutput { diff --git a/tensorflow/stream_executor/tpu/tpu_executable_interface.cc b/tensorflow/stream_executor/tpu/tpu_executable_interface.cc index af29f2e2b06..84ce8444420 100644 --- a/tensorflow/stream_executor/tpu/tpu_executable_interface.cc +++ b/tensorflow/stream_executor/tpu/tpu_executable_interface.cc @@ -90,8 +90,7 @@ TpuExecutableInterface::AllocateOutputMemoryWithInputReuse( } } - ExecutionOutput result(host_shape, std::move(device_shape), allocator, - device_ordinal); + ExecutionOutput result(std::move(device_shape), allocator, device_ordinal); // Iterate through and allocate a buffer for each shape index, checking for // possible input buffer reuse. 
int64 reused_buffer_bytes = 0; diff --git a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h index ed8c04af00d..193730567e7 100644 --- a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h +++ b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h @@ -246,6 +246,9 @@ int TpuTopology_LogicalDevicesPerHost(SE_TpuTopology* tpu_topology, TpuCoreTypeEnum tpu_core_type); int TpuTopology_LogicalDevicesPerChip(SE_TpuTopology* tpu_topology, TpuCoreTypeEnum tpu_core_type); +int TpuTopology_HostCount(SE_TpuTopology* tpu_topology); +int TpuTopology_ChipsPerHost(SE_TpuTopology* tpu_topology); + int TpuTopology_ChipBounds_X(SE_TpuTopology* tpu_topology); int TpuTopology_ChipBounds_Y(SE_TpuTopology* tpu_topology); int TpuTopology_ChipBounds_Z(SE_TpuTopology* tpu_topology); @@ -436,6 +439,9 @@ struct TfTpu_ExecutorApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuTopology_LogicalDevicesPerHost); TFTPU_ADD_FN_IN_STRUCT(TpuTopology_LogicalDevicesPerChip); + TFTPU_ADD_FN_IN_STRUCT(TpuTopology_HostCount); + TFTPU_ADD_FN_IN_STRUCT(TpuTopology_ChipsPerHost); + TFTPU_ADD_FN_IN_STRUCT(TpuTopology_ChipBounds_X); TFTPU_ADD_FN_IN_STRUCT(TpuTopology_ChipBounds_Y); TFTPU_ADD_FN_IN_STRUCT(TpuTopology_ChipBounds_Z); diff --git a/tensorflow/stream_executor/tpu/tpu_platform_interface.h b/tensorflow/stream_executor/tpu/tpu_platform_interface.h index fee9d92b42d..240148977e3 100644 --- a/tensorflow/stream_executor/tpu/tpu_platform_interface.h +++ b/tensorflow/stream_executor/tpu/tpu_platform_interface.h @@ -56,6 +56,10 @@ class TpuPlatformInterface : public stream_executor::Platform { virtual const TpuTopologyPtr GetTopologyPtr() = 0; virtual const TpuHostLocationExternal GetTpuHostLocation() const = 0; + + TpuTopologyExternal topology() { + return TpuTopologyExternal(GetTopologyPtr()); + } }; } // namespace tpu diff --git a/tensorflow/stream_executor/tpu/tpu_topology.cc b/tensorflow/stream_executor/tpu/tpu_topology.cc index 2638cf2d93a..909f5bd9dac 100644 --- a/tensorflow/stream_executor/tpu/tpu_topology.cc +++ b/tensorflow/stream_executor/tpu/tpu_topology.cc @@ -73,6 +73,14 @@ int32 TpuTopologyExternal::LogicalDevicesPerChip( core_type); } +int32 TpuTopologyExternal::HostCount() const { + return tpu::ExecutorApiFn()->TpuTopology_HostCountFn(topology_); +} + +int32 TpuTopologyExternal::ChipsPerHost() const { + return tpu::ExecutorApiFn()->TpuTopology_ChipsPerHostFn(topology_); +} + TpuTopologyChipBoundsExternal TpuTopologyExternal::chip_bounds() const { return {tpu::ExecutorApiFn()->TpuTopology_ChipBounds_XFn(topology_), tpu::ExecutorApiFn()->TpuTopology_ChipBounds_YFn(topology_), diff --git a/tensorflow/stream_executor/tpu/tpu_topology.h b/tensorflow/stream_executor/tpu/tpu_topology.h index 7a92353993b..84e13a142b6 100644 --- a/tensorflow/stream_executor/tpu/tpu_topology.h +++ b/tensorflow/stream_executor/tpu/tpu_topology.h @@ -71,6 +71,8 @@ class TpuTopologyExternal { : topology_(topology) {} int32 LogicalDevicesPerHost(TpuCoreTypeEnum core_type) const; int32 LogicalDevicesPerChip(TpuCoreTypeEnum core_type) const; + int32 HostCount() const; + int32 ChipsPerHost() const; TpuTopologyChipBoundsExternal chip_bounds() const; bool HasChip(int x, int y, int z) const; TpuCoreLocationExternal Core(int x, int y, int z, TpuCoreTypeEnum core_type, diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 8ed12136c55..075730742c6 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -4,7 +4,6 @@ load( "//tensorflow/core/platform:build_config_root.bzl", 
"if_dynamic_kernels", "if_static", - "register_extension_info", "tf_additional_grpc_deps_py", "tf_additional_xla_deps_py", "tf_exec_properties", @@ -261,8 +260,8 @@ def if_nccl(if_true, if_false = []): "//conditions:default": if_true, }) -def if_tpu(if_true, if_false = []): - """Shorthand for select()ing whether to build for TPUs.""" +def if_libtpu(if_true, if_false = []): + """Shorthand for select()ing whether to build support for using TPUs via libtpu.so""" return select({ str(Label("//tensorflow:with_tpu_support")): if_true, "//conditions:default": if_false, @@ -328,7 +327,7 @@ def tf_copts( (if_not_windows(["-fno-exceptions"]) if not allow_exceptions else []) + if_cuda(["-DGOOGLE_CUDA=1"]) + if_nvcc(["-DTENSORFLOW_USE_NVCC=1"]) + - if_tpu(["-DLIBTFTPU"]) + + if_libtpu(["-DLIBTPU_ON_GCE"], []) + if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"]) + if_mkl(["-DINTEL_MKL=1", "-DENABLE_MKLDNN_V1", "-DENABLE_INTEL_MKL_BFLOAT16"]) + @@ -399,7 +398,12 @@ def tf_defines_nortti_if_lite_protos(): # Given a list of "op_lib_names" (a list of files in the ops directory # without their .cc extensions), generate a library for that file. -def tf_gen_op_libs(op_lib_names, deps = None, is_external = True, compatible_with = None): +def tf_gen_op_libs( + op_lib_names, + sub_directory = "ops/", + deps = None, + is_external = True, + compatible_with = None): # Make library out of each op so it can also be used to generate wrappers # for various languages. if not deps: @@ -408,7 +412,7 @@ def tf_gen_op_libs(op_lib_names, deps = None, is_external = True, compatible_wit cc_library( name = n + "_op_lib", copts = tf_copts(is_external = is_external), - srcs = ["ops/" + n + ".cc"], + srcs = [sub_directory + n + ".cc"], deps = deps + [clean_dep("//tensorflow/core:framework")], compatible_with = compatible_with, visibility = ["//visibility:public"], @@ -644,11 +648,6 @@ def tf_cc_shared_object( visibility = visibility, ) -register_extension_info( - extension_name = "tf_cc_shared_object", - label_regex_for_dep = "{extension_name}", -) - # Links in the framework shared object # (//third_party/tensorflow:libtensorflow_framework.so) when not building # statically. Also adds linker options (rpaths) so that the framework shared @@ -704,11 +703,6 @@ def tf_cc_binary( visibility = visibility, ) -register_extension_info( - extension_name = "tf_cc_binary", - label_regex_for_dep = "{extension_name}.*", -) - # A simple wrap around native.cc_binary rule. # When using this rule, you should realize it doesn't link to any tensorflow # dependencies by default. @@ -733,11 +727,6 @@ def tf_native_cc_binary( **kwargs ) -register_extension_info( - extension_name = "tf_native_cc_binary", - label_regex_for_dep = "{extension_name}.*", -) - def tf_gen_op_wrapper_cc( name, out_ops_file, @@ -1087,11 +1076,6 @@ def tf_cc_test( **kwargs ) -register_extension_info( - extension_name = "tf_cc_test", - label_regex_for_dep = "{extension_name}.*", -) - # Part of the testing workflow requires a distinguishable name for the build # rules that involve a GPU, even if otherwise identical to the base rule. 
def tf_cc_test_gpu( @@ -1116,11 +1100,6 @@ def tf_cc_test_gpu( tags = tags, ) -register_extension_info( - extension_name = "tf_cc_test_gpu", - label_regex_for_dep = "{extension_name}", -) - def tf_gpu_cc_test( name, srcs = [], @@ -1171,20 +1150,10 @@ def tf_gpu_cc_test( ]), ) -register_extension_info( - extension_name = "tf_gpu_cc_test", - label_regex_for_dep = "{extension_name}", -) - # terminology changes: saving tf_cuda_* definition for compatibility def tf_cuda_cc_test(*args, **kwargs): tf_gpu_cc_test(*args, **kwargs) -register_extension_info( - extension_name = "tf_cuda_cc_test", - label_regex_for_dep = "{extension_name}", -) - def tf_gpu_only_cc_test( name, srcs = [], @@ -1224,20 +1193,10 @@ def tf_gpu_only_cc_test( exec_properties = tf_exec_properties({"tags": tags}), ) -register_extension_info( - extension_name = "tf_gpu_only_cc_test", - label_regex_for_dep = "{extension_name}_gpu", -) - # terminology changes: saving tf_cuda_* definition for compatibility def tf_cuda_only_cc_test(*args, **kwargs): tf_gpu_only_cc_test(*args, **kwargs) -register_extension_info( - extension_name = "tf_cuda_only_cc_test", - label_regex_for_dep = "{extension_name}_gpu", -) - # Create a cc_test for each of the tensorflow tests listed in "tests", along # with a test suite of the given name, if provided. def tf_cc_tests( @@ -1369,11 +1328,6 @@ def tf_java_test( **kwargs ) -register_extension_info( - extension_name = "tf_java_test", - label_regex_for_dep = "{extension_name}", -) - def _cuda_copts(opts = []): """Gets the appropriate set of copts for (maybe) CUDA compilation. @@ -1425,11 +1379,6 @@ def tf_gpu_kernel_library( **kwargs ) -register_extension_info( - extension_name = "tf_gpu_kernel_library", - label_regex_for_dep = "{extension_name}", -) - def tf_gpu_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs): """Generate a cc_library with a conditional set of CUDA dependencies. @@ -1465,20 +1414,10 @@ def tf_gpu_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs): **kwargs ) -register_extension_info( - extension_name = "tf_gpu_library", - label_regex_for_dep = "{extension_name}", -) - # terminology changes: saving tf_cuda_* definition for compatibility def tf_cuda_library(*args, **kwargs): tf_gpu_library(*args, **kwargs) -register_extension_info( - extension_name = "tf_cuda_library", - label_regex_for_dep = "{extension_name}", -) - def tf_kernel_library( name, prefix = None, @@ -1591,11 +1530,6 @@ def tf_kernel_library( deps = deps, ) -register_extension_info( - extension_name = "tf_kernel_library", - label_regex_for_dep = "({extension_name}(_gpu)?|libtfkernel_{extension_name}\\.so)", -) - def tf_mkl_kernel_library( name, prefix = None, @@ -1634,11 +1568,6 @@ def tf_mkl_kernel_library( features = disable_header_modules, ) -register_extension_info( - extension_name = "tf_mkl_kernel_library", - label_regex_for_dep = "{extension_name}", -) - def _get_transitive_headers(hdrs, deps): """Obtain the header files for a target and its transitive dependencies. @@ -1906,11 +1835,6 @@ def tf_custom_op_library(name, srcs = [], gpu_srcs = [], deps = [], linkopts = [ **kwargs ) -register_extension_info( - extension_name = "tf_custom_op_library", - label_regex_for_dep = "{extension_name}", -) - # Placeholder to use until bazel supports py_strict_binary. 
def py_strict_binary(name, **kwargs): native.py_binary(name = name, **kwargs) @@ -1941,11 +1865,6 @@ def tf_custom_op_py_library( deps = deps, ) -register_extension_info( - extension_name = "tf_custom_op_py_library", - label_regex_for_dep = "{extension_name}", -) - # In tf_py_wrap_cc_opensource generated libraries # module init functions are not exported unless # they contain one of the keywords in the version file @@ -2167,11 +2086,6 @@ def py_test(deps = [], data = [], kernels = [], exec_properties = None, **kwargs **kwargs ) -register_extension_info( - extension_name = "py_test", - label_regex_for_dep = "{extension_name}", -) - # Similar to py_test above, this macro is used to exclude dependencies for some py_binary # targets in order to reduce the size of //tensorflow/tools/pip_package:simple_console_windows. # See https://github.com/tensorflow/tensorflow/issues/22390 @@ -2192,11 +2106,6 @@ def py_binary(name, deps = [], **kwargs): **kwargs ) -register_extension_info( - extension_name = "py_binary", - label_regex_for_dep = "{extension_name}", -) - def pytype_library(**kwargs): # Types not enforced in OSS. native.py_library(**kwargs) @@ -2293,11 +2202,6 @@ def tf_py_test( **kwargs ) -register_extension_info( - extension_name = "tf_py_test", - label_regex_map = {"deps": "deps:{extension_name}"}, -) - def gpu_py_test( name, srcs, @@ -2358,20 +2262,10 @@ def gpu_py_test( **kwargs ) -register_extension_info( - extension_name = "gpu_py_test", - label_regex_map = {"deps": "deps:{extension_name}"}, -) - # terminology changes: saving cuda_* definition for compatibility def cuda_py_test(*args, **kwargs): gpu_py_test(*args, **kwargs) -register_extension_info( - extension_name = "cuda_py_test", - label_regex_map = {"deps": "deps:{extension_name}"}, -) - def py_tests( name, srcs, @@ -2616,11 +2510,6 @@ def cc_library_with_android_deps( deps = if_not_android(deps) + if_android(android_deps) + common_deps cc_library(deps = deps, copts = copts, **kwargs) -register_extension_info( - extension_name = "cc_library_with_android_deps", - label_regex_for_dep = "{extension_name}", -) - def tensorflow_opensource_extra_deps(): return [] @@ -2767,7 +2656,8 @@ def tf_python_pybind_extension( deps = [], defines = [], visibility = None, - testonly = None): + testonly = None, + compatible_with = None): """A wrapper macro for pybind_extension that is used in tensorflow/python/BUILD. Please do not use it anywhere else as it may behave unexpectedly. b/146445820 @@ -2787,6 +2677,7 @@ def tf_python_pybind_extension( visibility = visibility, link_in_framework = True, testonly = testonly, + compatible_with = compatible_with, ) def tf_pybind_cc_library_wrapper(name, deps, visibility = None, **kwargs): diff --git a/tensorflow/tools/android/inference_interface/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/tools/android/inference_interface/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java index 618c772e92d..976563ff199 100644 --- a/tensorflow/tools/android/inference_interface/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java +++ b/tensorflow/tools/android/inference_interface/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -43,8 +43,8 @@ import org.tensorflow.types.UInt8; * Wrapper over the TensorFlow API ({@link Graph}, {@link Session}) providing a smaller API surface * for inference. * - *
<p>
See tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java for an - * example usage. + *
<p>
See tensorflow/tools/android/test/src/org/tensorflow/demo/TensorFlowImageClassifier.java for + * an example usage. */ public class TensorFlowInferenceInterface { private static final String TAG = "TensorFlowInferenceInterface"; diff --git a/tensorflow/tools/android/test/.gitignore b/tensorflow/tools/android/test/.gitignore new file mode 100644 index 00000000000..fbd0a2dc7d7 --- /dev/null +++ b/tensorflow/tools/android/test/.gitignore @@ -0,0 +1,30 @@ +# This file is based on https://github.com/github/gitignore/blob/master/Android.gitignore +*.iml +.idea/compiler.xml +.idea/copyright +.idea/dictionaries +.idea/gradle.xml +.idea/libraries +.idea/inspectionProfiles +.idea/misc.xml +.idea/modules.xml +.idea/runConfigurations.xml +.idea/tasks.xml +.idea/workspace.xml +.gradle +local.properties +.DS_Store +build/ +gradleBuild/ +*.apk +*.ap_ +*.dex +*.class +bin/ +gen/ +out/ +*.log +.navigation/ +/captures +.externalNativeBuild +.cxx diff --git a/tensorflow/examples/android/AndroidManifest.xml b/tensorflow/tools/android/test/AndroidManifest.xml similarity index 100% rename from tensorflow/examples/android/AndroidManifest.xml rename to tensorflow/tools/android/test/AndroidManifest.xml diff --git a/tensorflow/examples/android/BUILD b/tensorflow/tools/android/test/BUILD similarity index 95% rename from tensorflow/examples/android/BUILD rename to tensorflow/tools/android/test/BUILD index 00d9a901395..387924a2e65 100644 --- a/tensorflow/examples/android/BUILD +++ b/tensorflow/tools/android/test/BUILD @@ -69,7 +69,7 @@ android_binary( # Package assets from assets dir as well as all model targets. Remove undesired models # (and corresponding Activities in source) to reduce APK size. assets = [ - "//tensorflow/examples/android/assets:asset_files", + "//tensorflow/tools/android/test/assets:asset_files", ":external_assets", ], assets_dir = "", @@ -96,7 +96,7 @@ filegroup( "@stylize//:model_files", ], ) -# LINT.ThenChange(//tensorflow/examples/android/download-models.gradle) +# LINT.ThenChange(//tensorflow/tools/android/test/download-models.gradle) filegroup( name = "java_files", diff --git a/tensorflow/examples/android/README.md b/tensorflow/tools/android/test/README.md similarity index 93% rename from tensorflow/examples/android/README.md rename to tensorflow/tools/android/test/README.md index af93adfb1e4..977f9305498 100644 --- a/tensorflow/examples/android/README.md +++ b/tensorflow/tools/android/test/README.md @@ -1,5 +1,7 @@ # TensorFlow Android Camera Demo +DEPRECATED: These examples are deprecated. + This folder contains an example application utilizing TensorFlow for Android devices. @@ -20,22 +22,22 @@ on API >= 14 devices. ## Current samples: -1. [TF Classify](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java): +1. [TF Classify](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/android/test/src/org/tensorflow/demo/ClassifierActivity.java): Uses the [Google Inception](https://arxiv.org/abs/1409.4842) model to classify camera frames in real-time, displaying the top results in an overlay on the camera image. -2. [TF Detect](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java): +2. 
[TF Detect](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/android/test/src/org/tensorflow/demo/DetectorActivity.java): Demonstrates an SSD-Mobilenet model trained using the [Tensorflow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection/) introduced in [Speed/accuracy trade-offs for modern convolutional object detectors](https://arxiv.org/abs/1611.10012) to localize and track objects (from 80 categories) in the camera preview in real-time. -3. [TF Stylize](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java): +3. [TF Stylize](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/android/test/src/org/tensorflow/demo/StylizeActivity.java): Uses a model based on [A Learned Representation For Artistic Style](https://arxiv.org/abs/1610.07629) to restyle the camera preview image to that of a number of different artists. 4. [TF - Speech](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java): + Speech](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/android/test/src/org/tensorflow/demo/SpeechActivity.java): Runs a simple speech recognition model built by the [audio training tutorial](https://www.tensorflow.org/versions/master/tutorials/audio_recognition). Listens for a small set of words, and highlights them in the UI when they are @@ -165,7 +167,7 @@ BASE_URL=https://storage.googleapis.com/download.tensorflow.org/models for MODEL_ZIP in inception5h.zip ssd_mobilenet_v1_android_export.zip stylize_v1.zip do curl -L ${BASE_URL}/${MODEL_ZIP} -o /tmp/${MODEL_ZIP} - unzip /tmp/${MODEL_ZIP} -d tensorflow/examples/android/assets/ + unzip /tmp/${MODEL_ZIP} -d tensorflow/tools/android/test/assets/ done ``` @@ -182,7 +184,7 @@ After editing your WORKSPACE file to update the SDK/NDK configuration, you may build the APK. Run this from your workspace root: ```bash -bazel build --cxxopt='--std=c++11' -c opt //tensorflow/examples/android:tensorflow_demo +bazel build --cxxopt='--std=c++11' -c opt //tensorflow/tools/android/test:tensorflow_demo ``` ##### Install @@ -192,7 +194,7 @@ device, then after building use the following command from your workspace root to install the APK: ```bash -adb install -r bazel-bin/tensorflow/examples/android/tensorflow_demo.apk +adb install -r bazel-bin/tensorflow/tools/android/test/tensorflow_demo.apk ``` ### Android Studio with Bazel @@ -202,7 +204,7 @@ make sure that you can build with Bazel following the above directions. Then, look at [build.gradle](build.gradle) and make sure that the path to Bazel matches that of your system. -At this point you can add the tensorflow/examples/android directory as a new +At this point you can add the tensorflow/tools/android/test directory as a new Android Studio project. 
Click through installing all the Gradle extensions it requests, and you should be able to have Android Studio build the demo like any other application (it will call out to Bazel to build the native code with the diff --git a/tensorflow/examples/android/__init__.py b/tensorflow/tools/android/test/__init__.py similarity index 100% rename from tensorflow/examples/android/__init__.py rename to tensorflow/tools/android/test/__init__.py diff --git a/tensorflow/examples/android/assets/BUILD b/tensorflow/tools/android/test/assets/BUILD similarity index 100% rename from tensorflow/examples/android/assets/BUILD rename to tensorflow/tools/android/test/assets/BUILD diff --git a/tensorflow/examples/android/bin/AndroidManifest.xml b/tensorflow/tools/android/test/bin/AndroidManifest.xml similarity index 100% rename from tensorflow/examples/android/bin/AndroidManifest.xml rename to tensorflow/tools/android/test/bin/AndroidManifest.xml diff --git a/tensorflow/examples/android/build.gradle b/tensorflow/tools/android/test/build.gradle similarity index 97% rename from tensorflow/examples/android/build.gradle rename to tensorflow/tools/android/test/build.gradle index 17fda55d693..4f7bea6b19d 100644 --- a/tensorflow/examples/android/build.gradle +++ b/tensorflow/tools/android/test/build.gradle @@ -56,7 +56,7 @@ def nativeOutDir = 'libs/' + cpuType // Default to building with Bazel and override with make if requested. def nativeBuildRule = 'buildNativeBazel' -def demoLibPath = '../../../bazel-bin/tensorflow/examples/android/libtensorflow_demo.so' +def demoLibPath = '../../../bazel-bin/tensorflow/tools/android/test/libtensorflow_demo.so' def inferenceLibPath = '../../../bazel-bin/tensorflow/tools/android/inference_interface/libtensorflow_inference.so' // Override for Makefile builds. @@ -149,7 +149,7 @@ android { task buildNativeBazel(type: Exec) { workingDir '../../..' 
commandLine bazelLocation, 'build', '-c', 'opt', \ - 'tensorflow/examples/android:tensorflow_native_libs', \ + 'tensorflow/tools/android/test:tensorflow_native_libs', \ '--crosstool_top=//external:android/crosstool', \ '--cpu=' + cpuType, \ '--host_crosstool_top=@bazel_tools//tools/cpp:toolchain' diff --git a/tensorflow/examples/android/download-models.gradle b/tensorflow/tools/android/test/download-models.gradle similarity index 96% rename from tensorflow/examples/android/download-models.gradle rename to tensorflow/tools/android/test/download-models.gradle index 727ef2cc850..f581e58a04e 100644 --- a/tensorflow/examples/android/download-models.gradle +++ b/tensorflow/tools/android/test/download-models.gradle @@ -13,7 +13,7 @@ def models = ['inception_v1.zip', 'object_detection/ssd_mobilenet_v1_android_export.zip', 'stylize_v1.zip', 'speech_commands_conv_actions.zip'] -// LINT.ThenChange(//tensorflow/examples/android/BUILD) +// LINT.ThenChange(//tensorflow/tools/android/test/BUILD) // Root URL for model archives def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models' diff --git a/tensorflow/examples/android/gradlew b/tensorflow/tools/android/test/gradlew similarity index 100% rename from tensorflow/examples/android/gradlew rename to tensorflow/tools/android/test/gradlew diff --git a/tensorflow/examples/android/gradlew.bat b/tensorflow/tools/android/test/gradlew.bat similarity index 100% rename from tensorflow/examples/android/gradlew.bat rename to tensorflow/tools/android/test/gradlew.bat diff --git a/tensorflow/examples/android/jni/CMakeLists.txt b/tensorflow/tools/android/test/jni/CMakeLists.txt similarity index 100% rename from tensorflow/examples/android/jni/CMakeLists.txt rename to tensorflow/tools/android/test/jni/CMakeLists.txt diff --git a/tensorflow/examples/android/jni/__init__.py b/tensorflow/tools/android/test/jni/__init__.py similarity index 100% rename from tensorflow/examples/android/jni/__init__.py rename to tensorflow/tools/android/test/jni/__init__.py diff --git a/tensorflow/examples/android/jni/imageutils_jni.cc b/tensorflow/tools/android/test/jni/imageutils_jni.cc similarity index 98% rename from tensorflow/examples/android/jni/imageutils_jni.cc rename to tensorflow/tools/android/test/jni/imageutils_jni.cc index a7b39bdb316..18e834f7085 100644 --- a/tensorflow/examples/android/jni/imageutils_jni.cc +++ b/tensorflow/tools/android/test/jni/imageutils_jni.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include -#include "tensorflow/examples/android/jni/rgb2yuv.h" -#include "tensorflow/examples/android/jni/yuv2rgb.h" +#include "tensorflow/tools/android/test/jni/rgb2yuv.h" +#include "tensorflow/tools/android/test/jni/yuv2rgb.h" #define IMAGEUTILS_METHOD(METHOD_NAME) \ Java_org_tensorflow_demo_env_ImageUtils_##METHOD_NAME // NOLINT diff --git a/tensorflow/examples/android/jni/object_tracking/config.h b/tensorflow/tools/android/test/jni/object_tracking/config.h similarity index 99% rename from tensorflow/examples/android/jni/object_tracking/config.h rename to tensorflow/tools/android/test/jni/object_tracking/config.h index 47de2d2c15b..830c64976a2 100644 --- a/tensorflow/examples/android/jni/object_tracking/config.h +++ b/tensorflow/tools/android/test/jni/object_tracking/config.h @@ -18,7 +18,7 @@ limitations under the License. 
#include -#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/tools/android/test/jni/object_tracking/geom.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/flow_cache.h b/tensorflow/tools/android/test/jni/object_tracking/flow_cache.h similarity index 97% rename from tensorflow/examples/android/jni/object_tracking/flow_cache.h rename to tensorflow/tools/android/test/jni/object_tracking/flow_cache.h index b62e334ecd7..b388addb63a 100644 --- a/tensorflow/examples/android/jni/object_tracking/flow_cache.h +++ b/tensorflow/tools/android/test/jni/object_tracking/flow_cache.h @@ -16,11 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ -#include "tensorflow/examples/android/jni/object_tracking/geom.h" -#include "tensorflow/examples/android/jni/object_tracking/utils.h" - -#include "tensorflow/examples/android/jni/object_tracking/config.h" -#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" +#include "tensorflow/tools/android/test/jni/object_tracking/config.h" +#include "tensorflow/tools/android/test/jni/object_tracking/geom.h" +#include "tensorflow/tools/android/test/jni/object_tracking/optical_flow.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/frame_pair.cc b/tensorflow/tools/android/test/jni/object_tracking/frame_pair.cc similarity index 98% rename from tensorflow/examples/android/jni/object_tracking/frame_pair.cc rename to tensorflow/tools/android/test/jni/object_tracking/frame_pair.cc index 66e422e87b6..770e2b145aa 100644 --- a/tensorflow/examples/android/jni/object_tracking/frame_pair.cc +++ b/tensorflow/tools/android/test/jni/object_tracking/frame_pair.cc @@ -13,10 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/tools/android/test/jni/object_tracking/frame_pair.h" + #include -#include "tensorflow/examples/android/jni/object_tracking/config.h" -#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h" +#include "tensorflow/tools/android/test/jni/object_tracking/config.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/frame_pair.h b/tensorflow/tools/android/test/jni/object_tracking/frame_pair.h similarity index 98% rename from tensorflow/examples/android/jni/object_tracking/frame_pair.h rename to tensorflow/tools/android/test/jni/object_tracking/frame_pair.h index 6c8ac9be981..7365a04b8c1 100644 --- a/tensorflow/examples/android/jni/object_tracking/frame_pair.h +++ b/tensorflow/tools/android/test/jni/object_tracking/frame_pair.h @@ -16,7 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FRAME_PAIR_H_ -#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" +#include "tensorflow/tools/android/test/jni/object_tracking/keypoint.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/geom.h b/tensorflow/tools/android/test/jni/object_tracking/geom.h similarity index 98% rename from tensorflow/examples/android/jni/object_tracking/geom.h rename to tensorflow/tools/android/test/jni/object_tracking/geom.h index c975e40144b..f46924db7e3 100644 --- a/tensorflow/examples/android/jni/object_tracking/geom.h +++ b/tensorflow/tools/android/test/jni/object_tracking/geom.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_ #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_GEOM_H_ -#include "tensorflow/examples/android/jni/object_tracking/logging.h" -#include "tensorflow/examples/android/jni/object_tracking/utils.h" +#include "tensorflow/tools/android/test/jni/object_tracking/logging.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/gl_utils.h b/tensorflow/tools/android/test/jni/object_tracking/gl_utils.h similarity index 96% rename from tensorflow/examples/android/jni/object_tracking/gl_utils.h rename to tensorflow/tools/android/test/jni/object_tracking/gl_utils.h index a29e677d3c5..f6e2a445807 100755 --- a/tensorflow/examples/android/jni/object_tracking/gl_utils.h +++ b/tensorflow/tools/android/test/jni/object_tracking/gl_utils.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/examples/android/jni/object_tracking/geom.h" +#include "tensorflow/tools/android/test/jni/object_tracking/geom.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/image-inl.h b/tensorflow/tools/android/test/jni/object_tracking/image-inl.h similarity index 99% rename from tensorflow/examples/android/jni/object_tracking/image-inl.h rename to tensorflow/tools/android/test/jni/object_tracking/image-inl.h index 61d69908b55..8382e56bb0f 100644 --- a/tensorflow/examples/android/jni/object_tracking/image-inl.h +++ b/tensorflow/tools/android/test/jni/object_tracking/image-inl.h @@ -18,9 +18,9 @@ limitations under the License. #include -#include "tensorflow/examples/android/jni/object_tracking/geom.h" -#include "tensorflow/examples/android/jni/object_tracking/image.h" -#include "tensorflow/examples/android/jni/object_tracking/utils.h" +#include "tensorflow/tools/android/test/jni/object_tracking/geom.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/image.h b/tensorflow/tools/android/test/jni/object_tracking/image.h similarity index 98% rename from tensorflow/examples/android/jni/object_tracking/image.h rename to tensorflow/tools/android/test/jni/object_tracking/image.h index a436f0e0a13..982cb0e715c 100644 --- a/tensorflow/examples/android/jni/object_tracking/image.h +++ b/tensorflow/tools/android/test/jni/object_tracking/image.h @@ -18,8 +18,8 @@ limitations under the License. 
#include -#include "tensorflow/examples/android/jni/object_tracking/geom.h" -#include "tensorflow/examples/android/jni/object_tracking/utils.h" +#include "tensorflow/tools/android/test/jni/object_tracking/geom.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" // TODO(andrewharp): Make this a cast to uint32_t if/when we go unsigned for // operations. diff --git a/tensorflow/examples/android/jni/object_tracking/image_data.h b/tensorflow/tools/android/test/jni/object_tracking/image_data.h similarity index 94% rename from tensorflow/examples/android/jni/object_tracking/image_data.h rename to tensorflow/tools/android/test/jni/object_tracking/image_data.h index c4f91d8cbd8..f89967c8d6f 100644 --- a/tensorflow/examples/android/jni/object_tracking/image_data.h +++ b/tensorflow/tools/android/test/jni/object_tracking/image_data.h @@ -17,16 +17,16 @@ limitations under the License. #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_IMAGE_DATA_H_ #include + #include -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" -#include "tensorflow/examples/android/jni/object_tracking/image.h" -#include "tensorflow/examples/android/jni/object_tracking/image_utils.h" -#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" -#include "tensorflow/examples/android/jni/object_tracking/time_log.h" -#include "tensorflow/examples/android/jni/object_tracking/utils.h" - -#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/tools/android/test/jni/object_tracking/config.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image_utils.h" +#include "tensorflow/tools/android/test/jni/object_tracking/integral_image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/time_log.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/image_neon.cc b/tensorflow/tools/android/test/jni/object_tracking/image_neon.cc similarity index 96% rename from tensorflow/examples/android/jni/object_tracking/image_neon.cc rename to tensorflow/tools/android/test/jni/object_tracking/image_neon.cc index 5f1023cbed3..2bd31f5618b 100644 --- a/tensorflow/examples/android/jni/object_tracking/image_neon.cc +++ b/tensorflow/tools/android/test/jni/object_tracking/image_neon.cc @@ -19,13 +19,12 @@ limitations under the License. 
#ifdef __ARM_NEON #include - #include -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" -#include "tensorflow/examples/android/jni/object_tracking/image.h" -#include "tensorflow/examples/android/jni/object_tracking/image_utils.h" -#include "tensorflow/examples/android/jni/object_tracking/utils.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image_utils.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/image_utils.h b/tensorflow/tools/android/test/jni/object_tracking/image_utils.h similarity index 97% rename from tensorflow/examples/android/jni/object_tracking/image_utils.h rename to tensorflow/tools/android/test/jni/object_tracking/image_utils.h index b4ad7000b33..9e301e29a0e 100644 --- a/tensorflow/examples/android/jni/object_tracking/image_utils.h +++ b/tensorflow/tools/android/test/jni/object_tracking/image_utils.h @@ -18,11 +18,10 @@ limitations under the License. #include -#include "tensorflow/examples/android/jni/object_tracking/geom.h" -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" -#include "tensorflow/examples/android/jni/object_tracking/image.h" -#include "tensorflow/examples/android/jni/object_tracking/utils.h" - +#include "tensorflow/tools/android/test/jni/object_tracking/geom.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/integral_image.h b/tensorflow/tools/android/test/jni/object_tracking/integral_image.h similarity index 95% rename from tensorflow/examples/android/jni/object_tracking/integral_image.h rename to tensorflow/tools/android/test/jni/object_tracking/integral_image.h index caf9b7d2ab8..9306b2aa3d4 100755 --- a/tensorflow/examples/android/jni/object_tracking/integral_image.h +++ b/tensorflow/tools/android/test/jni/object_tracking/integral_image.h @@ -16,10 +16,10 @@ limitations under the License. 
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_ #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_INTEGRAL_IMAGE_H_ -#include "tensorflow/examples/android/jni/object_tracking/geom.h" -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" -#include "tensorflow/examples/android/jni/object_tracking/image.h" -#include "tensorflow/examples/android/jni/object_tracking/utils.h" +#include "tensorflow/tools/android/test/jni/object_tracking/geom.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/jni_utils.h b/tensorflow/tools/android/test/jni/object_tracking/jni_utils.h similarity index 97% rename from tensorflow/examples/android/jni/object_tracking/jni_utils.h rename to tensorflow/tools/android/test/jni/object_tracking/jni_utils.h index 5f622a2e65f..015e6c09b41 100644 --- a/tensorflow/examples/android/jni/object_tracking/jni_utils.h +++ b/tensorflow/tools/android/test/jni/object_tracking/jni_utils.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/examples/android/jni/object_tracking/utils.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" // The JniLongField class is used to access Java fields from native code. This // technique of hiding pointers to native objects in opaque Java fields is how diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint.h b/tensorflow/tools/android/test/jni/object_tracking/keypoint.h similarity index 73% rename from tensorflow/examples/android/jni/object_tracking/keypoint.h rename to tensorflow/tools/android/test/jni/object_tracking/keypoint.h index 93405a5b2a8..fb6cd3162b3 100644 --- a/tensorflow/examples/android/jni/object_tracking/keypoint.h +++ b/tensorflow/tools/android/test/jni/object_tracking/keypoint.h @@ -16,14 +16,13 @@ limitations under the License. 
#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_ #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_H_ -#include "tensorflow/examples/android/jni/object_tracking/geom.h" -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" -#include "tensorflow/examples/android/jni/object_tracking/image.h" -#include "tensorflow/examples/android/jni/object_tracking/logging.h" -#include "tensorflow/examples/android/jni/object_tracking/time_log.h" -#include "tensorflow/examples/android/jni/object_tracking/utils.h" - -#include "tensorflow/examples/android/jni/object_tracking/config.h" +#include "tensorflow/tools/android/test/jni/object_tracking/config.h" +#include "tensorflow/tools/android/test/jni/object_tracking/geom.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/logging.h" +#include "tensorflow/tools/android/test/jni/object_tracking/time_log.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint_detector.cc b/tensorflow/tools/android/test/jni/object_tracking/keypoint_detector.cc similarity index 97% rename from tensorflow/examples/android/jni/object_tracking/keypoint_detector.cc rename to tensorflow/tools/android/test/jni/object_tracking/keypoint_detector.cc index fc60d2a5ca9..c1919ccd452 100644 --- a/tensorflow/examples/android/jni/object_tracking/keypoint_detector.cc +++ b/tensorflow/tools/android/test/jni/object_tracking/keypoint_detector.cc @@ -15,16 +15,16 @@ limitations under the License. // Various keypoint detecting functions. +#include "tensorflow/tools/android/test/jni/object_tracking/keypoint_detector.h" + #include -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" -#include "tensorflow/examples/android/jni/object_tracking/image.h" -#include "tensorflow/examples/android/jni/object_tracking/time_log.h" -#include "tensorflow/examples/android/jni/object_tracking/utils.h" - -#include "tensorflow/examples/android/jni/object_tracking/config.h" -#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" -#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h" +#include "tensorflow/tools/android/test/jni/object_tracking/config.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/keypoint.h" +#include "tensorflow/tools/android/test/jni/object_tracking/time_log.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h b/tensorflow/tools/android/test/jni/object_tracking/keypoint_detector.h similarity index 94% rename from tensorflow/examples/android/jni/object_tracking/keypoint_detector.h rename to tensorflow/tools/android/test/jni/object_tracking/keypoint_detector.h index 2e85b835a70..3b96d1be4e4 100644 --- a/tensorflow/examples/android/jni/object_tracking/keypoint_detector.h +++ b/tensorflow/tools/android/test/jni/object_tracking/keypoint_detector.h @@ -17,12 +17,13 @@ limitations under the License. 
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_KEYPOINT_DETECTOR_H_ #include + #include -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" -#include "tensorflow/examples/android/jni/object_tracking/image.h" -#include "tensorflow/examples/android/jni/object_tracking/image_data.h" -#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image_data.h" +#include "tensorflow/tools/android/test/jni/object_tracking/optical_flow.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/logging.cc b/tensorflow/tools/android/test/jni/object_tracking/logging.cc similarity index 98% rename from tensorflow/examples/android/jni/object_tracking/logging.cc rename to tensorflow/tools/android/test/jni/object_tracking/logging.cc index 0ba8a541ab7..dc82cce0620 100644 --- a/tensorflow/examples/android/jni/object_tracking/logging.cc +++ b/tensorflow/tools/android/test/jni/object_tracking/logging.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/examples/android/jni/object_tracking/logging.h" +#include "tensorflow/tools/android/test/jni/object_tracking/logging.h" #ifdef STANDALONE_DEMO_LIB diff --git a/tensorflow/examples/android/jni/object_tracking/logging.h b/tensorflow/tools/android/test/jni/object_tracking/logging.h similarity index 100% rename from tensorflow/examples/android/jni/object_tracking/logging.h rename to tensorflow/tools/android/test/jni/object_tracking/logging.h diff --git a/tensorflow/examples/android/jni/object_tracking/object_detector.cc b/tensorflow/tools/android/test/jni/object_tracking/object_detector.cc similarity index 93% rename from tensorflow/examples/android/jni/object_tracking/object_detector.cc rename to tensorflow/tools/android/test/jni/object_tracking/object_detector.cc index 7f65716fdf1..11eef293c2f 100644 --- a/tensorflow/examples/android/jni/object_tracking/object_detector.cc +++ b/tensorflow/tools/android/test/jni/object_tracking/object_detector.cc @@ -17,7 +17,7 @@ limitations under the License. // in this directory. This class remains mainly for historical reasons. // Detection in the TF demo is done through TensorFlowMultiBoxDetector.java. -#include "tensorflow/examples/android/jni/object_tracking/object_detector.h" +#include "tensorflow/tools/android/test/jni/object_tracking/object_detector.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/object_detector.h b/tensorflow/tools/android/test/jni/object_tracking/object_detector.h similarity index 92% rename from tensorflow/examples/android/jni/object_tracking/object_detector.h rename to tensorflow/tools/android/test/jni/object_tracking/object_detector.h index a65c7b0db70..20e614b25df 100644 --- a/tensorflow/examples/android/jni/object_tracking/object_detector.h +++ b/tensorflow/tools/android/test/jni/object_tracking/object_detector.h @@ -24,24 +24,24 @@ limitations under the License. 
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_DETECTOR_H_ #include + #include #include #include #include #include -#include "tensorflow/examples/android/jni/object_tracking/geom.h" -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" -#include "tensorflow/examples/android/jni/object_tracking/image.h" -#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/geom.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/integral_image.h" #ifdef __RENDER_OPENGL__ -#include "tensorflow/examples/android/jni/object_tracking/sprite.h" +#include "tensorflow/tools/android/test/jni/object_tracking/sprite.h" #endif -#include "tensorflow/examples/android/jni/object_tracking/utils.h" - -#include "tensorflow/examples/android/jni/object_tracking/config.h" -#include "tensorflow/examples/android/jni/object_tracking/image_data.h" -#include "tensorflow/examples/android/jni/object_tracking/object_model.h" +#include "tensorflow/tools/android/test/jni/object_tracking/config.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image_data.h" +#include "tensorflow/tools/android/test/jni/object_tracking/object_model.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/object_model.h b/tensorflow/tools/android/test/jni/object_tracking/object_model.h similarity index 82% rename from tensorflow/examples/android/jni/object_tracking/object_model.h rename to tensorflow/tools/android/test/jni/object_tracking/object_model.h index 4bc4d5bc9eb..9a2ace01f0e 100644 --- a/tensorflow/examples/android/jni/object_tracking/object_model.h +++ b/tensorflow/tools/android/test/jni/object_tracking/object_model.h @@ -29,18 +29,17 @@ limitations under the License. 
#include -#include "tensorflow/examples/android/jni/object_tracking/geom.h" -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" -#include "tensorflow/examples/android/jni/object_tracking/image.h" -#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/geom.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/integral_image.h" #ifdef __RENDER_OPENGL__ -#include "tensorflow/examples/android/jni/object_tracking/sprite.h" +#include "tensorflow/tools/android/test/jni/object_tracking/sprite.h" #endif -#include "tensorflow/examples/android/jni/object_tracking/utils.h" - -#include "tensorflow/examples/android/jni/object_tracking/config.h" -#include "tensorflow/examples/android/jni/object_tracking/image_data.h" -#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" +#include "tensorflow/tools/android/test/jni/object_tracking/config.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image_data.h" +#include "tensorflow/tools/android/test/jni/object_tracking/keypoint.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker.cc b/tensorflow/tools/android/test/jni/object_tracking/object_tracker.cc similarity index 95% rename from tensorflow/examples/android/jni/object_tracking/object_tracker.cc rename to tensorflow/tools/android/test/jni/object_tracking/object_tracker.cc index 832dee5656f..a5345742c05 100644 --- a/tensorflow/examples/android/jni/object_tracking/object_tracker.cc +++ b/tensorflow/tools/android/test/jni/object_tracking/object_tracker.cc @@ -22,19 +22,19 @@ limitations under the License. 
#include #include -#include "tensorflow/examples/android/jni/object_tracking/config.h" -#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h" -#include "tensorflow/examples/android/jni/object_tracking/geom.h" -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" -#include "tensorflow/examples/android/jni/object_tracking/image.h" -#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" -#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h" -#include "tensorflow/examples/android/jni/object_tracking/logging.h" -#include "tensorflow/examples/android/jni/object_tracking/object_detector.h" -#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h" -#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" -#include "tensorflow/examples/android/jni/object_tracking/time_log.h" -#include "tensorflow/examples/android/jni/object_tracking/utils.h" +#include "tensorflow/tools/android/test/jni/object_tracking/config.h" +#include "tensorflow/tools/android/test/jni/object_tracking/flow_cache.h" +#include "tensorflow/tools/android/test/jni/object_tracking/geom.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/integral_image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/keypoint_detector.h" +#include "tensorflow/tools/android/test/jni/object_tracking/logging.h" +#include "tensorflow/tools/android/test/jni/object_tracking/object_detector.h" +#include "tensorflow/tools/android/test/jni/object_tracking/object_tracker.h" +#include "tensorflow/tools/android/test/jni/object_tracking/optical_flow.h" +#include "tensorflow/tools/android/test/jni/object_tracking/time_log.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker.h b/tensorflow/tools/android/test/jni/object_tracking/object_tracker.h similarity index 92% rename from tensorflow/examples/android/jni/object_tracking/object_tracker.h rename to tensorflow/tools/android/test/jni/object_tracking/object_tracker.h index 20c7627fc5f..522f90e8470 100644 --- a/tensorflow/examples/android/jni/object_tracking/object_tracker.h +++ b/tensorflow/tools/android/test/jni/object_tracking/object_tracker.h @@ -19,18 +19,17 @@ limitations under the License. 
#include #include -#include "tensorflow/examples/android/jni/object_tracking/geom.h" -#include "tensorflow/examples/android/jni/object_tracking/integral_image.h" -#include "tensorflow/examples/android/jni/object_tracking/logging.h" -#include "tensorflow/examples/android/jni/object_tracking/time_log.h" -#include "tensorflow/examples/android/jni/object_tracking/utils.h" - -#include "tensorflow/examples/android/jni/object_tracking/config.h" -#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h" -#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h" -#include "tensorflow/examples/android/jni/object_tracking/object_model.h" -#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" -#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h" +#include "tensorflow/tools/android/test/jni/object_tracking/config.h" +#include "tensorflow/tools/android/test/jni/object_tracking/flow_cache.h" +#include "tensorflow/tools/android/test/jni/object_tracking/geom.h" +#include "tensorflow/tools/android/test/jni/object_tracking/integral_image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/keypoint_detector.h" +#include "tensorflow/tools/android/test/jni/object_tracking/logging.h" +#include "tensorflow/tools/android/test/jni/object_tracking/object_model.h" +#include "tensorflow/tools/android/test/jni/object_tracking/optical_flow.h" +#include "tensorflow/tools/android/test/jni/object_tracking/time_log.h" +#include "tensorflow/tools/android/test/jni/object_tracking/tracked_object.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/object_tracker_jni.cc b/tensorflow/tools/android/test/jni/object_tracking/object_tracker_jni.cc similarity index 97% rename from tensorflow/examples/android/jni/object_tracking/object_tracker_jni.cc rename to tensorflow/tools/android/test/jni/object_tracking/object_tracker_jni.cc index 216903802eb..33db7d52eb3 100644 --- a/tensorflow/examples/android/jni/object_tracking/object_tracker_jni.cc +++ b/tensorflow/tools/android/test/jni/object_tracking/object_tracker_jni.cc @@ -18,15 +18,15 @@ limitations under the License. 
#include #include #include + #include -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" -#include "tensorflow/examples/android/jni/object_tracking/image.h" -#include "tensorflow/examples/android/jni/object_tracking/jni_utils.h" -#include "tensorflow/examples/android/jni/object_tracking/time_log.h" - -#include "tensorflow/examples/android/jni/object_tracking/config.h" -#include "tensorflow/examples/android/jni/object_tracking/object_tracker.h" +#include "tensorflow/tools/android/test/jni/object_tracking/config.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/jni_utils.h" +#include "tensorflow/tools/android/test/jni/object_tracking/object_tracker.h" +#include "tensorflow/tools/android/test/jni/object_tracking/time_log.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/optical_flow.cc b/tensorflow/tools/android/test/jni/object_tracking/optical_flow.cc similarity index 95% rename from tensorflow/examples/android/jni/object_tracking/optical_flow.cc rename to tensorflow/tools/android/test/jni/object_tracking/optical_flow.cc index 53c7b7328a4..35476cad9da 100644 --- a/tensorflow/examples/android/jni/object_tracking/optical_flow.cc +++ b/tensorflow/tools/android/test/jni/object_tracking/optical_flow.cc @@ -13,21 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/tools/android/test/jni/object_tracking/optical_flow.h" + #include -#include "tensorflow/examples/android/jni/object_tracking/geom.h" -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" -#include "tensorflow/examples/android/jni/object_tracking/image.h" -#include "tensorflow/examples/android/jni/object_tracking/time_log.h" -#include "tensorflow/examples/android/jni/object_tracking/utils.h" - -#include "tensorflow/examples/android/jni/object_tracking/config.h" -#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h" -#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h" -#include "tensorflow/examples/android/jni/object_tracking/image_data.h" -#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" -#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h" -#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h" +#include "tensorflow/tools/android/test/jni/object_tracking/config.h" +#include "tensorflow/tools/android/test/jni/object_tracking/flow_cache.h" +#include "tensorflow/tools/android/test/jni/object_tracking/frame_pair.h" +#include "tensorflow/tools/android/test/jni/object_tracking/geom.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image_data.h" +#include "tensorflow/tools/android/test/jni/object_tracking/keypoint.h" +#include "tensorflow/tools/android/test/jni/object_tracking/keypoint_detector.h" +#include "tensorflow/tools/android/test/jni/object_tracking/time_log.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/optical_flow.h 
b/tensorflow/tools/android/test/jni/object_tracking/optical_flow.h similarity index 86% rename from tensorflow/examples/android/jni/object_tracking/optical_flow.h rename to tensorflow/tools/android/test/jni/object_tracking/optical_flow.h index f98ae22bd64..c33c15a897e 100644 --- a/tensorflow/examples/android/jni/object_tracking/optical_flow.h +++ b/tensorflow/tools/android/test/jni/object_tracking/optical_flow.h @@ -16,15 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_ #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OPTICAL_FLOW_H_ -#include "tensorflow/examples/android/jni/object_tracking/geom.h" -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" -#include "tensorflow/examples/android/jni/object_tracking/image.h" -#include "tensorflow/examples/android/jni/object_tracking/utils.h" - -#include "tensorflow/examples/android/jni/object_tracking/config.h" -#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h" -#include "tensorflow/examples/android/jni/object_tracking/image_data.h" -#include "tensorflow/examples/android/jni/object_tracking/keypoint.h" +#include "tensorflow/tools/android/test/jni/object_tracking/config.h" +#include "tensorflow/tools/android/test/jni/object_tracking/frame_pair.h" +#include "tensorflow/tools/android/test/jni/object_tracking/geom.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image_data.h" +#include "tensorflow/tools/android/test/jni/object_tracking/keypoint.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/sprite.h b/tensorflow/tools/android/test/jni/object_tracking/sprite.h similarity index 97% rename from tensorflow/examples/android/jni/object_tracking/sprite.h rename to tensorflow/tools/android/test/jni/object_tracking/sprite.h index 964f1c30bfa..f12200303bb 100755 --- a/tensorflow/examples/android/jni/object_tracking/sprite.h +++ b/tensorflow/tools/android/test/jni/object_tracking/sprite.h @@ -21,8 +21,8 @@ limitations under the License. #include #include -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" -#include "tensorflow/examples/android/jni/object_tracking/image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/time_log.cc b/tensorflow/tools/android/test/jni/object_tracking/time_log.cc similarity index 92% rename from tensorflow/examples/android/jni/object_tracking/time_log.cc rename to tensorflow/tools/android/test/jni/object_tracking/time_log.cc index 8fce78f915f..56416bb2c20 100644 --- a/tensorflow/examples/android/jni/object_tracking/time_log.cc +++ b/tensorflow/tools/android/test/jni/object_tracking/time_log.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/examples/android/jni/object_tracking/time_log.h" +#include "tensorflow/tools/android/test/jni/object_tracking/time_log.h" #ifdef LOG_TIME // Storage for logging functionality. 
diff --git a/tensorflow/examples/android/jni/object_tracking/time_log.h b/tensorflow/tools/android/test/jni/object_tracking/time_log.h similarity index 96% rename from tensorflow/examples/android/jni/object_tracking/time_log.h rename to tensorflow/tools/android/test/jni/object_tracking/time_log.h index 0073e115963..5f9011c28b3 100644 --- a/tensorflow/examples/android/jni/object_tracking/time_log.h +++ b/tensorflow/tools/android/test/jni/object_tracking/time_log.h @@ -20,8 +20,8 @@ limitations under the License. #include -#include "tensorflow/examples/android/jni/object_tracking/logging.h" -#include "tensorflow/examples/android/jni/object_tracking/utils.h" +#include "tensorflow/tools/android/test/jni/object_tracking/logging.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" #ifdef LOG_TIME diff --git a/tensorflow/examples/android/jni/object_tracking/tracked_object.cc b/tensorflow/tools/android/test/jni/object_tracking/tracked_object.cc similarity index 98% rename from tensorflow/examples/android/jni/object_tracking/tracked_object.cc rename to tensorflow/tools/android/test/jni/object_tracking/tracked_object.cc index b243b84ef79..28e5158f68b 100644 --- a/tensorflow/examples/android/jni/object_tracking/tracked_object.cc +++ b/tensorflow/tools/android/test/jni/object_tracking/tracked_object.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/examples/android/jni/object_tracking/tracked_object.h" +#include "tensorflow/tools/android/test/jni/object_tracking/tracked_object.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/tracked_object.h b/tensorflow/tools/android/test/jni/object_tracking/tracked_object.h similarity index 97% rename from tensorflow/examples/android/jni/object_tracking/tracked_object.h rename to tensorflow/tools/android/test/jni/object_tracking/tracked_object.h index 6a85449c1e1..31846d3d3a4 100644 --- a/tensorflow/examples/android/jni/object_tracking/tracked_object.h +++ b/tensorflow/tools/android/test/jni/object_tracking/tracked_object.h @@ -17,9 +17,9 @@ limitations under the License. #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_TRACKED_OBJECT_H_ #ifdef __RENDER_OPENGL__ -#include "tensorflow/examples/android/jni/object_tracking/gl_utils.h" +#include "tensorflow/tools/android/test/jni/object_tracking/gl_utils.h" #endif -#include "tensorflow/examples/android/jni/object_tracking/object_detector.h" +#include "tensorflow/tools/android/test/jni/object_tracking/object_detector.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/object_tracking/utils.h b/tensorflow/tools/android/test/jni/object_tracking/utils.h similarity index 99% rename from tensorflow/examples/android/jni/object_tracking/utils.h rename to tensorflow/tools/android/test/jni/object_tracking/utils.h index 2e98734ec4e..09d06ecb014 100644 --- a/tensorflow/examples/android/jni/object_tracking/utils.h +++ b/tensorflow/tools/android/test/jni/object_tracking/utils.h @@ -28,7 +28,7 @@ limitations under the License. #include #endif // ifdef HAVE_CLOCK_GETTIME -#include "tensorflow/examples/android/jni/object_tracking/logging.h" +#include "tensorflow/tools/android/test/jni/object_tracking/logging.h" // TODO(andrewharp): clean up these macros to use the codebase statndard. 
diff --git a/tensorflow/examples/android/jni/object_tracking/utils_neon.cc b/tensorflow/tools/android/test/jni/object_tracking/utils_neon.cc similarity index 94% rename from tensorflow/examples/android/jni/object_tracking/utils_neon.cc rename to tensorflow/tools/android/test/jni/object_tracking/utils_neon.cc index 5a5250e32ea..1f1dff53127 100755 --- a/tensorflow/examples/android/jni/object_tracking/utils_neon.cc +++ b/tensorflow/tools/android/test/jni/object_tracking/utils_neon.cc @@ -20,10 +20,10 @@ limitations under the License. #include -#include "tensorflow/examples/android/jni/object_tracking/geom.h" -#include "tensorflow/examples/android/jni/object_tracking/image-inl.h" -#include "tensorflow/examples/android/jni/object_tracking/image.h" -#include "tensorflow/examples/android/jni/object_tracking/utils.h" +#include "tensorflow/tools/android/test/jni/object_tracking/geom.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h" +#include "tensorflow/tools/android/test/jni/object_tracking/image.h" +#include "tensorflow/tools/android/test/jni/object_tracking/utils.h" namespace tf_tracking { diff --git a/tensorflow/examples/android/jni/rgb2yuv.cc b/tensorflow/tools/android/test/jni/rgb2yuv.cc similarity index 98% rename from tensorflow/examples/android/jni/rgb2yuv.cc rename to tensorflow/tools/android/test/jni/rgb2yuv.cc index b20ce580a27..6bb9e4847ea 100755 --- a/tensorflow/examples/android/jni/rgb2yuv.cc +++ b/tensorflow/tools/android/test/jni/rgb2yuv.cc @@ -15,7 +15,7 @@ limitations under the License. // These utility functions allow for the conversion of RGB data to YUV data. -#include "tensorflow/examples/android/jni/rgb2yuv.h" +#include "tensorflow/tools/android/test/jni/rgb2yuv.h" static inline void WriteYUV(const int x, const int y, const int width, const int r8, const int g8, const int b8, diff --git a/tensorflow/examples/android/jni/rgb2yuv.h b/tensorflow/tools/android/test/jni/rgb2yuv.h similarity index 100% rename from tensorflow/examples/android/jni/rgb2yuv.h rename to tensorflow/tools/android/test/jni/rgb2yuv.h diff --git a/tensorflow/examples/android/jni/version_script.lds b/tensorflow/tools/android/test/jni/version_script.lds similarity index 100% rename from tensorflow/examples/android/jni/version_script.lds rename to tensorflow/tools/android/test/jni/version_script.lds diff --git a/tensorflow/examples/android/jni/yuv2rgb.cc b/tensorflow/tools/android/test/jni/yuv2rgb.cc similarity index 99% rename from tensorflow/examples/android/jni/yuv2rgb.cc rename to tensorflow/tools/android/test/jni/yuv2rgb.cc index 96d9632e572..91c163eec94 100644 --- a/tensorflow/examples/android/jni/yuv2rgb.cc +++ b/tensorflow/tools/android/test/jni/yuv2rgb.cc @@ -16,7 +16,7 @@ limitations under the License. // This is a collection of routines which converts various YUV image formats // to ARGB. -#include "tensorflow/examples/android/jni/yuv2rgb.h" +#include "tensorflow/tools/android/test/jni/yuv2rgb.h" #ifndef MAX #define MAX(a, b) ({__typeof__(a) _a = (a); __typeof__(b) _b = (b); _a > _b ? 
_a : _b; }) diff --git a/tensorflow/examples/android/jni/yuv2rgb.h b/tensorflow/tools/android/test/jni/yuv2rgb.h similarity index 100% rename from tensorflow/examples/android/jni/yuv2rgb.h rename to tensorflow/tools/android/test/jni/yuv2rgb.h diff --git a/tensorflow/examples/android/res/animator/color_animation.xml b/tensorflow/tools/android/test/res/animator/color_animation.xml similarity index 100% rename from tensorflow/examples/android/res/animator/color_animation.xml rename to tensorflow/tools/android/test/res/animator/color_animation.xml diff --git a/tensorflow/tools/android/test/res/drawable-hdpi/ic_action_info.png b/tensorflow/tools/android/test/res/drawable-hdpi/ic_action_info.png new file mode 100644 index 00000000000..cd333b6f080 Binary files /dev/null and b/tensorflow/tools/android/test/res/drawable-hdpi/ic_action_info.png differ diff --git a/tensorflow/tools/android/test/res/drawable-hdpi/ic_launcher.png b/tensorflow/tools/android/test/res/drawable-hdpi/ic_launcher.png new file mode 100644 index 00000000000..52cf2ab9529 Binary files /dev/null and b/tensorflow/tools/android/test/res/drawable-hdpi/ic_launcher.png differ diff --git a/tensorflow/tools/android/test/res/drawable-hdpi/tile.9.png b/tensorflow/tools/android/test/res/drawable-hdpi/tile.9.png new file mode 100644 index 00000000000..a84e3ef52c6 Binary files /dev/null and b/tensorflow/tools/android/test/res/drawable-hdpi/tile.9.png differ diff --git a/tensorflow/tools/android/test/res/drawable-mdpi/ic_action_info.png b/tensorflow/tools/android/test/res/drawable-mdpi/ic_action_info.png new file mode 100644 index 00000000000..b08a3e0f54b Binary files /dev/null and b/tensorflow/tools/android/test/res/drawable-mdpi/ic_action_info.png differ diff --git a/tensorflow/tools/android/test/res/drawable-mdpi/ic_launcher.png b/tensorflow/tools/android/test/res/drawable-mdpi/ic_launcher.png new file mode 100644 index 00000000000..b75f892c462 Binary files /dev/null and b/tensorflow/tools/android/test/res/drawable-mdpi/ic_launcher.png differ diff --git a/tensorflow/tools/android/test/res/drawable-xhdpi/ic_action_info.png b/tensorflow/tools/android/test/res/drawable-xhdpi/ic_action_info.png new file mode 100644 index 00000000000..eb174e17169 Binary files /dev/null and b/tensorflow/tools/android/test/res/drawable-xhdpi/ic_action_info.png differ diff --git a/tensorflow/tools/android/test/res/drawable-xhdpi/ic_launcher.png b/tensorflow/tools/android/test/res/drawable-xhdpi/ic_launcher.png new file mode 100644 index 00000000000..36e14c48d14 Binary files /dev/null and b/tensorflow/tools/android/test/res/drawable-xhdpi/ic_launcher.png differ diff --git a/tensorflow/tools/android/test/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/tools/android/test/res/drawable-xxhdpi/ic_action_info.png new file mode 100644 index 00000000000..9654fd3b416 Binary files /dev/null and b/tensorflow/tools/android/test/res/drawable-xxhdpi/ic_action_info.png differ diff --git a/tensorflow/tools/android/test/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/tools/android/test/res/drawable-xxhdpi/ic_launcher.png new file mode 100644 index 00000000000..06dd2a740ec Binary files /dev/null and b/tensorflow/tools/android/test/res/drawable-xxhdpi/ic_launcher.png differ diff --git a/tensorflow/examples/android/res/drawable/border.xml b/tensorflow/tools/android/test/res/drawable/border.xml similarity index 100% rename from tensorflow/examples/android/res/drawable/border.xml rename to tensorflow/tools/android/test/res/drawable/border.xml diff --git 
a/tensorflow/examples/android/res/layout/activity_camera.xml b/tensorflow/tools/android/test/res/layout/activity_camera.xml similarity index 100% rename from tensorflow/examples/android/res/layout/activity_camera.xml rename to tensorflow/tools/android/test/res/layout/activity_camera.xml diff --git a/tensorflow/examples/android/res/layout/activity_speech.xml b/tensorflow/tools/android/test/res/layout/activity_speech.xml similarity index 100% rename from tensorflow/examples/android/res/layout/activity_speech.xml rename to tensorflow/tools/android/test/res/layout/activity_speech.xml diff --git a/tensorflow/examples/android/res/layout/camera_connection_fragment.xml b/tensorflow/tools/android/test/res/layout/camera_connection_fragment.xml similarity index 100% rename from tensorflow/examples/android/res/layout/camera_connection_fragment.xml rename to tensorflow/tools/android/test/res/layout/camera_connection_fragment.xml diff --git a/tensorflow/examples/android/res/layout/camera_connection_fragment_stylize.xml b/tensorflow/tools/android/test/res/layout/camera_connection_fragment_stylize.xml similarity index 100% rename from tensorflow/examples/android/res/layout/camera_connection_fragment_stylize.xml rename to tensorflow/tools/android/test/res/layout/camera_connection_fragment_stylize.xml diff --git a/tensorflow/examples/android/res/layout/camera_connection_fragment_tracking.xml b/tensorflow/tools/android/test/res/layout/camera_connection_fragment_tracking.xml similarity index 100% rename from tensorflow/examples/android/res/layout/camera_connection_fragment_tracking.xml rename to tensorflow/tools/android/test/res/layout/camera_connection_fragment_tracking.xml diff --git a/tensorflow/examples/android/res/layout/list_text_item.xml b/tensorflow/tools/android/test/res/layout/list_text_item.xml similarity index 100% rename from tensorflow/examples/android/res/layout/list_text_item.xml rename to tensorflow/tools/android/test/res/layout/list_text_item.xml diff --git a/tensorflow/examples/android/res/values-sw600dp/template-dimens.xml b/tensorflow/tools/android/test/res/values-sw600dp/template-dimens.xml similarity index 100% rename from tensorflow/examples/android/res/values-sw600dp/template-dimens.xml rename to tensorflow/tools/android/test/res/values-sw600dp/template-dimens.xml diff --git a/tensorflow/examples/android/res/values-sw600dp/template-styles.xml b/tensorflow/tools/android/test/res/values-sw600dp/template-styles.xml similarity index 100% rename from tensorflow/examples/android/res/values-sw600dp/template-styles.xml rename to tensorflow/tools/android/test/res/values-sw600dp/template-styles.xml diff --git a/tensorflow/examples/android/res/values-v11/styles.xml b/tensorflow/tools/android/test/res/values-v11/styles.xml similarity index 100% rename from tensorflow/examples/android/res/values-v11/styles.xml rename to tensorflow/tools/android/test/res/values-v11/styles.xml diff --git a/tensorflow/examples/android/res/values-v11/template-styles.xml b/tensorflow/tools/android/test/res/values-v11/template-styles.xml similarity index 100% rename from tensorflow/examples/android/res/values-v11/template-styles.xml rename to tensorflow/tools/android/test/res/values-v11/template-styles.xml diff --git a/tensorflow/examples/android/res/values-v14/styles.xml b/tensorflow/tools/android/test/res/values-v14/styles.xml similarity index 100% rename from tensorflow/examples/android/res/values-v14/styles.xml rename to tensorflow/tools/android/test/res/values-v14/styles.xml diff --git 
a/tensorflow/examples/android/res/values-v21/base-colors.xml b/tensorflow/tools/android/test/res/values-v21/base-colors.xml similarity index 100% rename from tensorflow/examples/android/res/values-v21/base-colors.xml rename to tensorflow/tools/android/test/res/values-v21/base-colors.xml diff --git a/tensorflow/examples/android/res/values-v21/base-template-styles.xml b/tensorflow/tools/android/test/res/values-v21/base-template-styles.xml similarity index 100% rename from tensorflow/examples/android/res/values-v21/base-template-styles.xml rename to tensorflow/tools/android/test/res/values-v21/base-template-styles.xml diff --git a/tensorflow/examples/android/res/values/attrs.xml b/tensorflow/tools/android/test/res/values/attrs.xml similarity index 100% rename from tensorflow/examples/android/res/values/attrs.xml rename to tensorflow/tools/android/test/res/values/attrs.xml diff --git a/tensorflow/examples/android/res/values/base-strings.xml b/tensorflow/tools/android/test/res/values/base-strings.xml similarity index 100% rename from tensorflow/examples/android/res/values/base-strings.xml rename to tensorflow/tools/android/test/res/values/base-strings.xml diff --git a/tensorflow/examples/android/res/values/colors.xml b/tensorflow/tools/android/test/res/values/colors.xml similarity index 100% rename from tensorflow/examples/android/res/values/colors.xml rename to tensorflow/tools/android/test/res/values/colors.xml diff --git a/tensorflow/examples/android/res/values/strings.xml b/tensorflow/tools/android/test/res/values/strings.xml similarity index 100% rename from tensorflow/examples/android/res/values/strings.xml rename to tensorflow/tools/android/test/res/values/strings.xml diff --git a/tensorflow/examples/android/res/values/styles.xml b/tensorflow/tools/android/test/res/values/styles.xml similarity index 100% rename from tensorflow/examples/android/res/values/styles.xml rename to tensorflow/tools/android/test/res/values/styles.xml diff --git a/tensorflow/examples/android/res/values/template-dimens.xml b/tensorflow/tools/android/test/res/values/template-dimens.xml similarity index 100% rename from tensorflow/examples/android/res/values/template-dimens.xml rename to tensorflow/tools/android/test/res/values/template-dimens.xml diff --git a/tensorflow/examples/android/res/values/template-styles.xml b/tensorflow/tools/android/test/res/values/template-styles.xml similarity index 100% rename from tensorflow/examples/android/res/values/template-styles.xml rename to tensorflow/tools/android/test/res/values/template-styles.xml diff --git a/tensorflow/examples/android/sample_images/classify1.jpg b/tensorflow/tools/android/test/sample_images/classify1.jpg similarity index 100% rename from tensorflow/examples/android/sample_images/classify1.jpg rename to tensorflow/tools/android/test/sample_images/classify1.jpg diff --git a/tensorflow/examples/android/sample_images/detect1.jpg b/tensorflow/tools/android/test/sample_images/detect1.jpg similarity index 100% rename from tensorflow/examples/android/sample_images/detect1.jpg rename to tensorflow/tools/android/test/sample_images/detect1.jpg diff --git a/tensorflow/examples/android/sample_images/stylize1.jpg b/tensorflow/tools/android/test/sample_images/stylize1.jpg similarity index 100% rename from tensorflow/examples/android/sample_images/stylize1.jpg rename to tensorflow/tools/android/test/sample_images/stylize1.jpg diff --git a/tensorflow/examples/android/settings.gradle b/tensorflow/tools/android/test/settings.gradle similarity index 100% rename from 
tensorflow/examples/android/settings.gradle rename to tensorflow/tools/android/test/settings.gradle diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/AutoFitTextureView.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/AutoFitTextureView.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/AutoFitTextureView.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/AutoFitTextureView.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/CameraActivity.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/CameraActivity.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/CameraActivity.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/CameraActivity.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/CameraConnectionFragment.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/CameraConnectionFragment.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/Classifier.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/Classifier.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/ClassifierActivity.java similarity index 99% rename from tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/ClassifierActivity.java index e2c394dde92..69595674ad4 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/ClassifierActivity.java +++ b/tensorflow/tools/android/test/src/org/tensorflow/demo/ClassifierActivity.java @@ -26,8 +26,6 @@ import android.media.ImageReader.OnImageAvailableListener; import android.os.SystemClock; import android.util.Size; import android.util.TypedValue; -import android.view.Display; -import android.view.Surface; import java.util.List; import java.util.Vector; import org.tensorflow.demo.OverlayView.DrawCallback; diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/DetectorActivity.java similarity index 99% rename from tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/DetectorActivity.java index f778f3de425..5a7893d7011 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java +++ b/tensorflow/tools/android/test/src/org/tensorflow/demo/DetectorActivity.java @@ -29,8 +29,6 @@ import android.media.ImageReader.OnImageAvailableListener; import android.os.SystemClock; import android.util.Size; import android.util.TypedValue; -import android.view.Display; -import android.view.Surface; import android.widget.Toast; import java.io.IOException; import java.util.LinkedList; diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java similarity index 100% rename from 
tensorflow/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/OverlayView.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/OverlayView.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/OverlayView.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/OverlayView.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/RecognitionScoreView.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/RecognitionScoreView.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/RecognizeCommands.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/RecognizeCommands.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/RecognizeCommands.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/RecognizeCommands.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/ResultsView.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/ResultsView.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/ResultsView.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/ResultsView.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/SpeechActivity.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/SpeechActivity.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/StylizeActivity.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/StylizeActivity.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/TensorFlowImageClassifier.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/TensorFlowImageClassifier.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java diff --git 
a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/TensorFlowYoloDetector.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/TensorFlowYoloDetector.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/env/BorderedText.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/env/BorderedText.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/env/BorderedText.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/env/BorderedText.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/env/ImageUtils.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/env/ImageUtils.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/env/Logger.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/env/Logger.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/env/Logger.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/env/Logger.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/env/Size.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/env/Size.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/env/Size.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/env/Size.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/env/SplitTimer.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/env/SplitTimer.java similarity index 100% rename from tensorflow/examples/android/src/org/tensorflow/demo/env/SplitTimer.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/env/SplitTimer.java diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/tracking/MultiBoxTracker.java similarity index 99% rename from tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/tracking/MultiBoxTracker.java index af6af2bc8f5..b74c5797604 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java +++ b/tensorflow/tools/android/test/src/org/tensorflow/demo/tracking/MultiBoxTracker.java @@ -215,7 +215,7 @@ public class MultiBoxTracker { if (objectTracker == null) { String message = "Object tracking support not found. 
" - + "See tensorflow/examples/android/README.md for details."; + + "See tensorflow/tools/android/test/README.md for details."; Toast.makeText(context, message, Toast.LENGTH_LONG).show(); logger.e(message); } diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java b/tensorflow/tools/android/test/src/org/tensorflow/demo/tracking/ObjectTracker.java similarity index 99% rename from tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java rename to tensorflow/tools/android/test/src/org/tensorflow/demo/tracking/ObjectTracker.java index 7ad4647d40c..1652d9090a9 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java +++ b/tensorflow/tools/android/test/src/org/tensorflow/demo/tracking/ObjectTracker.java @@ -208,7 +208,7 @@ public class ObjectTracker { if (!libraryFound) { LOGGER.e( "Native object tracking support not found. " - + "See tensorflow/examples/android/README.md for details."); + + "See tensorflow/tools/android/test/README.md for details."); return null; } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt index 3d8187ca752..88fd63d9693 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt @@ -75,6 +75,13 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_BOOL } + field { + name: "mlir_bridge_rollout" + number: 17 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.ConfigProto.Experimental.MlirBridgeRollout" + } field { name: "enable_mlir_graph_optimization" number: 16 @@ -93,6 +100,21 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_INT64 } + enum_type { + name: "MlirBridgeRollout" + value: { + name: "MLIR_BRIDGE_ROLLOUT_UNSPECIFIED" + number: 0 + } + value: { + name: "MLIR_BRIDGE_ROLLOUT_ENABLED" + number: 1 + } + value: { + name: "MLIR_BRIDGE_ROLLOUT_DISABLED" + number: 2 + } + } reserved_range { start: 2 end: 3 diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt index c32fdee5af0..a598071b970 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt @@ -204,6 +204,13 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_BOOL } + field { + name: "mlir_bridge_rollout" + number: 17 + label: LABEL_OPTIONAL + type: TYPE_ENUM + type_name: ".tensorflow.ConfigProto.Experimental.MlirBridgeRollout" + } field { name: "enable_mlir_graph_optimization" number: 16 @@ -222,6 +229,21 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_INT64 } + enum_type { + name: "MlirBridgeRollout" + value: { + name: "MLIR_BRIDGE_ROLLOUT_UNSPECIFIED" + number: 0 + } + value: { + name: "MLIR_BRIDGE_ROLLOUT_ENABLED" + number: 1 + } + value: { + name: "MLIR_BRIDGE_ROLLOUT_DISABLED" + number: 2 + } + } reserved_range { start: 2 end: 3 diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.-t-p-u-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.-t-p-u-config.pbtxt index e329045123e..16f39f9fe2e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.-t-p-u-config.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.-t-p-u-config.pbtxt @@ -11,6 +11,10 @@ tf_class { name: "experimental_allow_per_host_v2_parallel_get_next" mtype: "" } + member { + name: 
"experimental_feed_hook" + mtype: "" + } member { name: "experimental_host_call_every_n_steps" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-additive-attention.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-additive-attention.pbtxt index 96b809486a7..0c85d31934a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-additive-attention.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-additive-attention.pbtxt @@ -150,7 +150,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-attention.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-attention.pbtxt index ae2cb7f7e20..b2c3156cf7a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-attention.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-attention.pbtxt @@ -150,7 +150,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multi-head-attention.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multi-head-attention.pbtxt index 070ee20ab30..89fbb32194a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multi-head-attention.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multi-head-attention.pbtxt @@ -149,7 +149,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " + argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt index 57876312213..b74d27bc159 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt @@ -220,6 +220,14 @@ tf_module { name: "kullback_leibler_divergence" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "log_cosh" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "logcosh" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "mae" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt 
b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt index e3c8c7e8a65..3c016d331de 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt @@ -21,18 +21,10 @@ tf_class { name: "iterations" mtype: "" } - member { - name: "learning_rate" - mtype: "" - } member { name: "loss_scale" mtype: "" } - member { - name: "lr" - mtype: "" - } member { name: "weights" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index ba64d009908..abd33957365 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1824,6 +1824,10 @@ tf_module { name: "quantize" argspec: "args=[\'input\', \'min_range\', \'max_range\', \'T\', \'mode\', \'round_mode\', \'name\', \'narrow_range\', \'axis\', \'ensure_minimum_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'HALF_AWAY_FROM_ZERO\', \'None\', \'False\', \'None\', \'0.01\'], " } + member_method { + name: "quantize_and_dequantize_v4" + argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'False\', \'-1\', \'None\'], " + } member_method { name: "quantize_v2" argspec: "args=[\'input\', \'min_range\', \'max_range\', \'T\', \'mode\', \'name\', \'round_mode\', \'narrow_range\', \'axis\', \'ensure_minimum_range\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\', \'HALF_AWAY_FROM_ZERO\', \'False\', \'None\', \'0.01\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt index 047fb4deda7..269873c63ed 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt @@ -36,6 +36,10 @@ tf_module { name: "quantize_and_dequantize" argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'name\', \'narrow_range\', \'axis\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'None\', \'False\', \'None\'], " } + member_method { + name: "quantize_and_dequantize_v2" + argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'name\', \'narrow_range\', \'axis\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'None\', \'False\', \'None\'], " + } member_method { name: "quantized_concat" argspec: "args=[\'concat_dim\', \'values\', \'input_mins\', \'input_maxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 2a2c3102862..2efd289c259 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -774,7 +774,7 @@ tf_module { } member_method { name: "CollectiveReduceV2" - argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'communication_hint\', \'name\'], varargs=None, keywords=None, 
defaults=[\'auto\', \'None\'], " + argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], " } member_method { name: "CombinedNonMaxSuppression" @@ -2940,6 +2940,14 @@ tf_module { name: "QuantizeAndDequantizeV3" argspec: "args=[\'input\', \'input_min\', \'input_max\', \'num_bits\', \'signed_input\', \'range_given\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'False\', \'-1\', \'None\'], " } + member_method { + name: "QuantizeAndDequantizeV4" + argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'False\', \'-1\', \'None\'], " + } + member_method { + name: "QuantizeAndDequantizeV4Grad" + argspec: "args=[\'gradients\', \'input\', \'input_min\', \'input_max\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " + } member_method { name: "QuantizeDownAndShrinkRange" argspec: "args=[\'input\', \'input_min\', \'input_max\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -4110,7 +4118,7 @@ tf_module { } member_method { name: "SnapshotDatasetV2" - argspec: "args=[\'input_dataset\', \'path\', \'reader_func_other_args\', \'shard_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'shard_func\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], " + argspec: "args=[\'input_dataset\', \'path\', \'reader_func_other_args\', \'shard_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'shard_func\', \'compression\', \'reader_prefix\', \'writer_prefix\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'None\'], " } member_method { name: "SobolSample" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.distribute.combinations.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.distribute.combinations.pbtxt new file mode 100644 index 00000000000..c31afabba60 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.distribute.combinations.pbtxt @@ -0,0 +1,75 @@ +path: "tensorflow.__internal__.distribute.combinations" +tf_module { + member { + name: "central_storage_strategy_with_gpu_and_cpu" + mtype: "" + } + member { + name: "central_storage_strategy_with_two_gpus" + mtype: "" + } + member { + name: "cloud_tpu_strategy" + mtype: "" + } + member { + name: "default_strategy" + mtype: "" + } + member { + name: "mirrored_strategy_with_cpu_1_and_2" + mtype: "" + } + member { + name: "mirrored_strategy_with_gpu_and_cpu" + mtype: "" + } + member { + name: "mirrored_strategy_with_one_cpu" + mtype: "" + } + member { + name: "mirrored_strategy_with_one_gpu" + mtype: "" + } + member { + name: "mirrored_strategy_with_two_gpus" + mtype: "" + } + member { + name: "multi_worker_mirrored_2x1_cpu" + mtype: "" + } + member { + name: "multi_worker_mirrored_2x1_gpu" + mtype: "" + } + member { + name: "multi_worker_mirrored_2x2_gpu" + mtype: "" + } + member { + name: "one_device_strategy" + mtype: "" + } + member { + name: "one_device_strategy_gpu" + mtype: "" + } + member { + name: "tpu_strategy" + mtype: "" + } + member { + name: "tpu_strategy_one_core" + mtype: "" + } + member { + name: "tpu_strategy_packed_var" 
+ mtype: "" + } + member_method { + name: "generate" + argspec: "args=[\'combinations\', \'test_combinations\'], varargs=None, keywords=None, defaults=[\'()\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.distribute.multi_process_runner.-not-initialized-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.distribute.multi_process_runner.-not-initialized-error.pbtxt new file mode 100644 index 00000000000..c6d5ab61840 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.distribute.multi_process_runner.-not-initialized-error.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.__internal__.distribute.multi_process_runner.NotInitializedError" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member_method { + name: "__init__" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.distribute.multi_process_runner.-subprocess-timeout-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.distribute.multi_process_runner.-subprocess-timeout-error.pbtxt new file mode 100644 index 00000000000..6251a8c03f0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.distribute.multi_process_runner.-subprocess-timeout-error.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.__internal__.distribute.multi_process_runner.SubprocessTimeoutError" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'msg\', \'mpr_result\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.distribute.multi_process_runner.-unexpected-subprocess-exit-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.distribute.multi_process_runner.-unexpected-subprocess-exit-error.pbtxt new file mode 100644 index 00000000000..2344e31ec48 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.distribute.multi_process_runner.-unexpected-subprocess-exit-error.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.__internal__.distribute.multi_process_runner.UnexpectedSubprocessExitError" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "args" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'msg\', \'mpr_result\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.distribute.multi_process_runner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.distribute.multi_process_runner.pbtxt new file mode 100644 index 00000000000..b297d573993 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.distribute.multi_process_runner.pbtxt @@ -0,0 +1,31 @@ +path: "tensorflow.__internal__.distribute.multi_process_runner" +tf_module { + member { + name: "NotInitializedError" + mtype: "" + } + member { + name: "SubprocessTimeoutError" + mtype: "" + } + member { + name: "UnexpectedSubprocessExitError" + mtype: "" + } + member_method { + name: "create_cluster_spec" + argspec: "args=[\'has_chief\', \'num_workers\', \'num_ps\', \'has_eval\'], varargs=None, keywords=None, defaults=[\'False\', \'1\', \'0\', \'False\'], " + } + member_method { + name: "get_barrier" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "run" + argspec: "args=[\'fn\', \'cluster_spec\', \'rpc_layer\', \'max_run_time\', \'return_output\', \'timeout\', \'args\', 
\'kwargs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'200\', \'None\', \'None\'], " + } + member_method { + name: "test_main" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.distribute.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.distribute.pbtxt new file mode 100644 index 00000000000..2a019dff514 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.distribute.pbtxt @@ -0,0 +1,11 @@ +path: "tensorflow.__internal__.distribute" +tf_module { + member { + name: "combinations" + mtype: "" + } + member { + name: "multi_process_runner" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt index ea35bf3e1cc..effcae38787 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt @@ -8,6 +8,10 @@ tf_module { name: "decorator" mtype: "" } + member { + name: "distribute" + mtype: "" + } member { name: "test" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt index 598c01c6da0..60c0e8d7663 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt @@ -48,10 +48,6 @@ tf_class { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "experimental_make_numpy_dataset" - argspec: "args=[\'self\', \'numpy_input\'], varargs=None, keywords=None, defaults=None" - } member_method { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt index 5b044576d14..17af12cf279 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt @@ -48,10 +48,6 @@ tf_class { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "experimental_make_numpy_dataset" - argspec: "args=[\'self\', \'numpy_input\'], varargs=None, keywords=None, defaults=None" - } member_method { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt index a8ebaa87590..702cdc98e88 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt @@ -47,10 +47,6 @@ tf_class { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "experimental_make_numpy_dataset" - argspec: "args=[\'self\', \'numpy_input\'], varargs=None, keywords=None, defaults=None" - } member_method { name: "experimental_run" argspec: "args=[\'self\', \'fn\', 
\'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-t-p-u-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-t-p-u-strategy.pbtxt index 3a94e8b9fa4..0cb06ec3b01 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-t-p-u-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-t-p-u-strategy.pbtxt @@ -52,10 +52,6 @@ tf_class { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "experimental_make_numpy_dataset" - argspec: "args=[\'self\', \'numpy_input\'], varargs=None, keywords=None, defaults=None" - } member_method { name: "experimental_replicate_to_logical_devices" argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt index 290d834305a..e1b92b44b73 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt @@ -48,10 +48,6 @@ tf_class { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "experimental_make_numpy_dataset" - argspec: "args=[\'self\', \'numpy_input\'], varargs=None, keywords=None, defaults=None" - } member_method { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt index 2245eb4b122..fda255eac17 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt @@ -48,10 +48,6 @@ tf_class { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "experimental_make_numpy_dataset" - argspec: "args=[\'self\', \'numpy_input\'], varargs=None, keywords=None, defaults=None" - } member_method { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt index 707d7281a5c..e445d7c4dab 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt @@ -48,10 +48,6 @@ tf_class { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "experimental_make_numpy_dataset" - argspec: "args=[\'self\', \'numpy_input\'], varargs=None, keywords=None, defaults=None" - } member_method { name: "experimental_run" argspec: 
"args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt index a2adcea87e6..6536eefe414 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt @@ -48,10 +48,6 @@ tf_class { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "experimental_make_numpy_dataset" - argspec: "args=[\'self\', \'numpy_input\'], varargs=None, keywords=None, defaults=None" - } member_method { name: "experimental_run" argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-additive-attention.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-additive-attention.pbtxt index 96b809486a7..0c85d31934a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-additive-attention.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-additive-attention.pbtxt @@ -150,7 +150,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-attention.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-attention.pbtxt index ae2cb7f7e20..b2c3156cf7a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-attention.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-attention.pbtxt @@ -150,7 +150,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multi-head-attention.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multi-head-attention.pbtxt index 070ee20ab30..89fbb32194a 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multi-head-attention.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multi-head-attention.pbtxt @@ -149,7 +149,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " + argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt index a58ffc1c2a5..c66604af334 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-normalization.pbtxt @@ -119,7 +119,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'axis\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'-1\', \'None\'], " + argspec: "args=[\'self\', \'axis\', \'dtype\', \'mean\', \'variance\'], varargs=None, keywords=kwargs, defaults=[\'-1\', \'None\', \'None\', \'None\'], " } member_method { name: "adapt" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt index 4d5a28fc8b4..9fc7c410480 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-text-vectorization.pbtxt @@ -119,7 +119,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'max_tokens\', \'standardize\', \'split\', \'ngrams\', \'output_mode\', \'output_sequence_length\', \'pad_to_max_tokens\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'lower_and_strip_punctuation\', \'whitespace\', \'None\', \'int\', \'None\', \'True\'], " + argspec: "args=[\'self\', \'max_tokens\', \'standardize\', \'split\', \'ngrams\', \'output_mode\', \'output_sequence_length\', \'pad_to_max_tokens\', \'vocabulary\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'lower_and_strip_punctuation\', \'whitespace\', \'None\', \'int\', \'None\', \'True\', \'None\'], " } member_method { name: "adapt" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt index 17768aeafbe..cffbe8a2dde 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt @@ -212,6 +212,14 @@ tf_module { name: "kullback_leibler_divergence" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "log_cosh" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "logcosh" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "mae" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt index e3c8c7e8a65..3c016d331de 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt @@ -21,18 +21,10 @@ tf_class { name: "iterations" mtype: "" } - member { - name: "learning_rate" - mtype: "" - } member { name: "loss_scale" mtype: "" } - member { - name: "lr" - mtype: "" - } member { name: "weights" mtype: "" diff --git 
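The preprocessing-layer goldens above show `Normalization` accepting precomputed `mean`/`variance` and `TextVectorization` accepting a `vocabulary` at construction time. A small sketch with made-up values:

```python
import tensorflow as tf
from tensorflow.keras.layers.experimental import preprocessing

# Supplying statistics up front avoids a separate adapt() pass.
norm = preprocessing.Normalization(mean=[3.0], variance=[4.0])
print(norm(tf.constant([[1.0], [5.0]])))  # standardized values

# Likewise, a fixed vocabulary can be passed directly.
vectorize = preprocessing.TextVectorization(vocabulary=["cat", "dog", "bird"])
print(vectorize(tf.constant([["cat dog"]])))  # integer token ids
```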
a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt index c37ee31da6b..e3435a32bef 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt @@ -24,7 +24,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'name\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'USE_DEFAULT\'], " + argspec: "args=[\'self\', \'name\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'auto\'], " } member_method { name: "from_config" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.pbtxt index b3c87d67d2b..494ebc98058 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.pbtxt @@ -212,6 +212,14 @@ tf_module { name: "kullback_leibler_divergence" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "log_cosh" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "logcosh" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "mae" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 83baba1b1ce..0aa5e8924a3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -880,6 +880,10 @@ tf_module { name: "py_function" argspec: "args=[\'func\', \'inp\', \'Tout\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "quantize_and_dequantize_v4" + argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'False\', \'-1\', \'None\'], " + } member_method { name: "range" argspec: "args=[\'start\', \'limit\', \'delta\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'range\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt index 047fb4deda7..269873c63ed 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt @@ -36,6 +36,10 @@ tf_module { name: "quantize_and_dequantize" argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'name\', \'narrow_range\', \'axis\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'None\', \'False\', \'None\'], " } + member_method { + name: "quantize_and_dequantize_v2" + argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'name\', \'narrow_range\', \'axis\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'None\', \'False\', \'None\'], " + } member_method { name: "quantized_concat" argspec: 
"args=[\'concat_dim\', \'values\', \'input_mins\', \'input_maxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 2a2c3102862..2efd289c259 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -774,7 +774,7 @@ tf_module { } member_method { name: "CollectiveReduceV2" - argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'communication_hint\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'None\'], " + argspec: "args=[\'input\', \'group_size\', \'group_key\', \'instance_key\', \'merge_op\', \'final_op\', \'communication_hint\', \'timeout_seconds\', \'name\'], varargs=None, keywords=None, defaults=[\'auto\', \'0\', \'None\'], " } member_method { name: "CombinedNonMaxSuppression" @@ -2940,6 +2940,14 @@ tf_module { name: "QuantizeAndDequantizeV3" argspec: "args=[\'input\', \'input_min\', \'input_max\', \'num_bits\', \'signed_input\', \'range_given\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'False\', \'-1\', \'None\'], " } + member_method { + name: "QuantizeAndDequantizeV4" + argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'narrow_range\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'False\', \'-1\', \'None\'], " + } + member_method { + name: "QuantizeAndDequantizeV4Grad" + argspec: "args=[\'gradients\', \'input\', \'input_min\', \'input_max\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " + } member_method { name: "QuantizeDownAndShrinkRange" argspec: "args=[\'input\', \'input_min\', \'input_max\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -4110,7 +4118,7 @@ tf_module { } member_method { name: "SnapshotDatasetV2" - argspec: "args=[\'input_dataset\', \'path\', \'reader_func_other_args\', \'shard_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'shard_func\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], " + argspec: "args=[\'input_dataset\', \'path\', \'reader_func_other_args\', \'shard_func_other_args\', \'output_types\', \'output_shapes\', \'reader_func\', \'shard_func\', \'compression\', \'reader_prefix\', \'writer_prefix\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'None\'], " } member_method { name: "SobolSample" diff --git a/tensorflow/tools/benchmark/BUILD b/tensorflow/tools/benchmark/BUILD index 674133431f1..cdd0e0d9452 100644 --- a/tensorflow/tools/benchmark/BUILD +++ b/tensorflow/tools/benchmark/BUILD @@ -62,7 +62,7 @@ tf_cc_test( # This binary may be built for either desktop or Android. 
# A typical Android build command will look like the following: -# bazel build tensorflow/core:android_tensorflow_lib \ +# bazel build tensorflow/core:portable_tensorflow_lib \ # --crosstool_top=//external:android/crosstool \ # --cpu=armeabi-v7a \ # --host_crosstool_top=@bazel_tools//tools/cpp:toolchain diff --git a/tensorflow/tools/benchmark/README.md b/tensorflow/tools/benchmark/README.md index dee1a20f3f0..758e01c8072 100644 --- a/tensorflow/tools/benchmark/README.md +++ b/tensorflow/tools/benchmark/README.md @@ -9,9 +9,11 @@ both on desktop machines and on Android. ### On Android: -(0) Refer to https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android to edit the `WORKSPACE` to configure the android NDK/SDK. +(0) Refer to https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android +to edit the `WORKSPACE` to configure the android NDK/SDK. (1) build for your specific platform, e.g.: + ``` bazel build -c opt \ --crosstool_top=//external:android/crosstool \ @@ -23,14 +25,19 @@ bazel build -c opt \ (2) Connect your phone. Push the binary to your phone with adb push (make the directory if required): + ``` adb push bazel-bin/tensorflow/tools/benchmark/benchmark_model /data/local/tmp ``` (3) Push the compute graph that you need to test. For example: - adb push tensorflow_inception_graph.pb /data/local/tmp + +``` +adb push tensorflow_inception_graph.pb /data/local/tmp +``` (4) Run the benchmark. For example: + ``` adb shell /data/local/tmp/benchmark_model \ --graph=/data/local/tmp/tensorflow_inception_graph.pb \ @@ -39,14 +46,17 @@ adb shell /data/local/tmp/benchmark_model \ --input_layer_type="float" \ --output_layer="output:0" ``` + ### On desktop: (1) build the binary + ``` bazel build -c opt tensorflow/tools/benchmark:benchmark_model ``` -(2) Run on your compute graph, similar to the Android case but without the need of adb shell. -For example: +(2) Run on your compute graph, similar to the Android case but without the need +of adb shell. For example: + ``` bazel-bin/tensorflow/tools/benchmark/benchmark_model \ --graph=tensorflow_inception_graph.pb \ diff --git a/tensorflow/tools/ci_build/Dockerfile.android b/tensorflow/tools/ci_build/Dockerfile.android index 80949ac64eb..03edba478af 100644 --- a/tensorflow/tools/ci_build/Dockerfile.android +++ b/tensorflow/tools/ci_build/Dockerfile.android @@ -42,7 +42,7 @@ RUN cd ${ANDROID_DEV_HOME} && \ echo y | android update sdk --no-ui -a --filter tools,platform-tools,android-${ANDROID_API_LEVEL},build-tools-${ANDROID_BUILD_TOOLS_VERSION} # Install Android NDK. 
-ENV ANDROID_NDK_FILENAME android-ndk-r17c-linux-x86_64.zip +ENV ANDROID_NDK_FILENAME android-ndk-r18b-linux-x86_64.zip ENV ANDROID_NDK_URL https://dl.google.com/android/repository/${ANDROID_NDK_FILENAME} ENV ANDROID_NDK_HOME ${ANDROID_DEV_HOME}/ndk ENV PATH ${PATH}:${ANDROID_NDK_HOME} diff --git a/tensorflow/tools/ci_build/Dockerfile.rocm b/tensorflow/tools/ci_build/Dockerfile.rocm index d209173258a..a72915504be 100644 --- a/tensorflow/tools/ci_build/Dockerfile.rocm +++ b/tensorflow/tools/ci_build/Dockerfile.rocm @@ -25,6 +25,7 @@ RUN bin/bash -c 'if [[ $ROCM_DEB_REPO == http://repo.radeon.com/rocm/* ]] ; the # Install misc pkgs RUN apt-get update --allow-insecure-repositories && DEBIAN_FRONTEND=noninteractive apt-get install -y \ build-essential \ + bsdmainutils \ clang-6.0 \ clang-format-6.0 \ clang-tidy-6.0 \ diff --git a/tensorflow/tools/ci_build/builds/android_full.sh b/tensorflow/tools/ci_build/builds/android_full.sh index f76d1588e53..e0a0d65bf99 100755 --- a/tensorflow/tools/ci_build/builds/android_full.sh +++ b/tensorflow/tools/ci_build/builds/android_full.sh @@ -28,7 +28,7 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" source "${SCRIPT_DIR}/builds_common.sh" configure_android_workspace -CPUS=armeabi-v7a,arm64-v8a,x86,x86_64 +CPUS=armeabi-v7a,arm64-v8a,x86 OUT_DIR="$(pwd)/out/" AAR_LIB_TMP="$(pwd)/aar_libs" @@ -44,13 +44,13 @@ do --compilation_mode=opt --cxxopt=-std=c++14 \ --crosstool_top=//external:android/crosstool \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ - //tensorflow/core:android_tensorflow_lib \ + //tensorflow/core:portable_tensorflow_lib \ //tensorflow/tools/android/inference_interface:libtensorflow_inference.so \ - //tensorflow/examples/android:libtensorflow_demo.so \ + //tensorflow/tools/android/test:libtensorflow_demo.so \ //tensorflow/tools/benchmark:benchmark_model copy_lib bazel-bin/tensorflow/tools/android/inference_interface/libtensorflow_inference.so - copy_lib bazel-bin/tensorflow/examples/android/libtensorflow_demo.so + copy_lib bazel-bin/tensorflow/tools/android/test/libtensorflow_demo.so copy_lib bazel-bin/tensorflow/tools/benchmark/benchmark_model mkdir -p ${AAR_LIB_TMP}/jni/${CPU} @@ -68,10 +68,10 @@ bazel --bazelrc=/dev/null build --config=monolithic --fat_apk_cpu=${CPUS} \ --spawn_strategy=sandboxed --genrule_strategy=sandboxed \ //tensorflow/tools/android/inference_interface:android_tensorflow_inference_java \ //tensorflow/tools/android/inference_interface:android_tensorflow_inference_java.aar \ - //tensorflow/examples/android:tensorflow_demo + //tensorflow/tools/android/test:tensorflow_demo echo "Copying demo, AAR and Jar to ${OUT_DIR}" -cp bazel-bin/tensorflow/examples/android/tensorflow_demo.apk \ +cp bazel-bin/tensorflow/tools/android/test/tensorflow_demo.apk \ bazel-bin/tensorflow/tools/android/inference_interface/libandroid_tensorflow_inference_java.jar ${OUT_DIR} cp bazel-bin/tensorflow/tools/android/inference_interface/android_tensorflow_inference_java.aar \ diff --git a/tensorflow/tools/ci_build/builds/builds_common.sh b/tensorflow/tools/ci_build/builds/builds_common.sh index c5698f1068e..461aa81d260 100644 --- a/tensorflow/tools/ci_build/builds/builds_common.sh +++ b/tensorflow/tools/ci_build/builds/builds_common.sh @@ -222,7 +222,7 @@ configure_android_workspace() { echo "ERROR: Your WORKSPACE file does not seems to have proper android" echo " configuration and not all the environment variables expected" echo " inside ci_build android docker container are set." - echo " Please configure it manually. 
See: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android/README.md" + echo " Please configure it manually. See: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/android/test/README.md" else cat << EOF >> WORKSPACE android_sdk_repository( diff --git a/tensorflow/tools/ci_build/nightly_release/macos/cpu_py35.sh b/tensorflow/tools/ci_build/nightly_release/macos/cpu_py35.sh new file mode 100644 index 00000000000..7da3b0ea9be --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/macos/cpu_py35.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh +install_bazelisk + +# Pick a version of xcode +export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer +sudo xcode-select -s "${DEVELOPER_DIR}" + +sudo pip3.5 install --upgrade pip + +install_macos_pip_deps sudo pip3.5 + +# For python3 path on Mac +export PATH=$PATH:/usr/local/bin + +sudo pip install twine + +./tensorflow/tools/ci_build/update_version.py --nightly + +# Run configure. +export CC_OPT_FLAGS='-mavx' +export PYTHON_BIN_PATH=$(which python3.5) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Build the pip package +bazel build --config=release_cpu_macos tensorflow/tools/pip_package:build_pip_package +mkdir pip_pkg +./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --cpu --nightly_flag + +# Copy and rename to tf_nightly +for f in $(ls pip_pkg/tf_nightly_cpu-*dev*macosx*.whl); do + copy_to_new_project_name "${f}" tf_nightly +done + +# Upload the built packages to pypi. +for f in $(ls pip_pkg/tf_nightly*dev*macosx*.whl); do + # test the whl pip package + chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh + ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${f} + RETVAL=$? + + # Upload the PIP package if whl test passes. + if [ ${RETVAL} -eq 0 ]; then + echo "Basic PIP test PASSED, Uploading package: ${f}" + twine upload -r pypi-warehouse "${f}" || echo + else + echo "Basic PIP test FAILED, will not upload ${f} package" + return 1 + fi +done diff --git a/tensorflow/tools/ci_build/nightly_release/macos/cpu_py36.sh b/tensorflow/tools/ci_build/nightly_release/macos/cpu_py36.sh new file mode 100644 index 00000000000..33e1491dd86 --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/macos/cpu_py36.sh @@ -0,0 +1,66 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh +install_bazelisk + +# Pick a version of xcode +export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer +sudo xcode-select -s "${DEVELOPER_DIR}" + +install_macos_pip_deps sudo pip3.6 + +# For python3 path on Mac +export PATH=$PATH:/usr/local/bin + +sudo pip install twine + +./tensorflow/tools/ci_build/update_version.py --nightly + +# Run configure. +export CC_OPT_FLAGS='-mavx' +export PYTHON_BIN_PATH=$(which python3.6) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Build the pip package +bazel build --config=release_cpu_macos tensorflow/tools/pip_package:build_pip_package +mkdir pip_pkg +./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --cpu --nightly_flag + +# Copy and rename to tf_nightly +for f in $(ls pip_pkg/tf_nightly_cpu-*dev*macosx*.whl); do + copy_to_new_project_name "${f}" tf_nightly +done + +# Upload the built packages to pypi. +for f in $(ls pip_pkg/tf_nightly*dev*macosx*.whl); do + + # test the whl pip package + chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh + ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${f} + RETVAL=$? + + # Upload the PIP package if whl test passes. + if [ ${RETVAL} -eq 0 ]; then + echo "Basic PIP test PASSED, Uploading package: ${f}" + twine upload -r pypi-warehouse "${f}" || echo + else + echo "Basic PIP test FAILED, will not upload ${f} package" + return 1 + fi +done diff --git a/tensorflow/tools/ci_build/nightly_release/macos/cpu_py37.sh b/tensorflow/tools/ci_build/nightly_release/macos/cpu_py37.sh new file mode 100644 index 00000000000..631aea318bd --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/macos/cpu_py37.sh @@ -0,0 +1,66 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh +install_bazelisk + +# Pick a version of xcode +export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer +sudo xcode-select -s "${DEVELOPER_DIR}" + +install_macos_pip_deps sudo pip3.7 + +# For python3 path on Mac +export PATH=$PATH:/usr/local/bin + +sudo pip install twine + +./tensorflow/tools/ci_build/update_version.py --nightly + +# Run configure. 
+export CC_OPT_FLAGS='-mavx' +export PYTHON_BIN_PATH=$(which python3.7) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Build the pip package +bazel build --config=release_cpu_macos tensorflow/tools/pip_package:build_pip_package +mkdir pip_pkg +./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --cpu --nightly_flag + +# Copy and rename to tf_nightly +for f in $(ls pip_pkg/tf_nightly_cpu-*dev*macosx*.whl); do + copy_to_new_project_name "${f}" tf_nightly +done + +# Upload the built packages to pypi. +for f in $(ls pip_pkg/tf_nightly*dev*macosx*.whl); do + + # test the whl pip package + chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh + ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${f} + RETVAL=$? + + # Upload the PIP package if whl test passes. + if [ ${RETVAL} -eq 0 ]; then + echo "Basic PIP test PASSED, Uploading package: ${f}" + twine upload -r pypi-warehouse "${f}" || echo + else + echo "Basic PIP test FAILED, will not upload ${f} package" + return 1 + fi +done diff --git a/tensorflow/tools/ci_build/nightly_release/macos/cpu_py38.sh b/tensorflow/tools/ci_build/nightly_release/macos/cpu_py38.sh new file mode 100644 index 00000000000..5ffef89188c --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/macos/cpu_py38.sh @@ -0,0 +1,66 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh +install_bazelisk + +# Pick a version of xcode +export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer +sudo xcode-select -s "${DEVELOPER_DIR}" + +install_macos_pip_deps sudo pip3.8 + +# For python3 path on Mac +export PATH=$PATH:/usr/local/bin + +sudo pip install twine + +./tensorflow/tools/ci_build/update_version.py --nightly + +# Run configure. +export CC_OPT_FLAGS='-mavx' +export PYTHON_BIN_PATH=$(which python3.8) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Build the pip package +bazel build --config=release_cpu_macos tensorflow/tools/pip_package:build_pip_package +mkdir pip_pkg +./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --cpu --nightly_flag + +# Copy and rename to tf_nightly +for f in $(ls pip_pkg/tf_nightly_cpu-*dev*macosx*.whl); do + copy_to_new_project_name "${f}" tf_nightly +done + +# Upload the built packages to pypi. +for f in $(ls pip_pkg/tf_nightly*dev*macosx*.whl); do + + # test the whl pip package + chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh + ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${f} + RETVAL=$? + + # Upload the PIP package if whl test passes. 
+ if [ ${RETVAL} -eq 0 ]; then + echo "Basic PIP test PASSED, Uploading package: ${f}" + twine upload -r pypi-warehouse "${f}" || echo + else + echo "Basic PIP test FAILED, will not upload ${f} package" + return 1 + fi +done diff --git a/tensorflow/tools/ci_build/nightly_release/ubuntu/cpu_py35.sh b/tensorflow/tools/ci_build/nightly_release/ubuntu/cpu_py35.sh new file mode 100644 index 00000000000..4af77739c55 --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/ubuntu/cpu_py35.sh @@ -0,0 +1,58 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_pip_deps pip3.5 + +install_bazelisk + +python2.7 tensorflow/tools/ci_build/update_version.py --nightly + +# Run configure. +export CC_OPT_FLAGS='-mavx' +export PYTHON_BIN_PATH=$(which python3.5) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Build the pip package +bazel build --config=release_cpu_linux tensorflow/tools/pip_package:build_pip_package + +./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --cpu --nightly_flag + +# Upload the built packages to pypi. +for WHL_PATH in $(ls pip_pkg/tf_nightly_cpu-*dev*.whl); do + + WHL_DIR=$(dirname "${WHL_PATH}") + WHL_BASE_NAME=$(basename "${WHL_PATH}") + AUDITED_WHL_NAME="${WHL_DIR}"/$(echo "${WHL_BASE_NAME//linux/manylinux2010}") + auditwheel repair --plat manylinux2010_x86_64 -w "${WHL_DIR}" "${WHL_PATH}" + + # test the whl pip package + chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh + ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${AUDITED_WHL_NAME} + RETVAL=$? + + # Upload the PIP package if whl test passes. + if [ ${RETVAL} -eq 0 ]; then + echo "Basic PIP test PASSED, Uploading package: ${AUDITED_WHL_NAME}" + twine upload -r pypi-warehouse "${AUDITED_WHL_NAME}" + else + echo "Basic PIP test FAILED, will not upload ${AUDITED_WHL_NAME} package" + return 1 + fi +done diff --git a/tensorflow/tools/ci_build/nightly_release/ubuntu/cpu_py36.sh b/tensorflow/tools/ci_build/nightly_release/ubuntu/cpu_py36.sh new file mode 100644 index 00000000000..9cca17e5517 --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/ubuntu/cpu_py36.sh @@ -0,0 +1,58 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_pip_deps pip3.6 + +install_bazelisk + +python2.7 tensorflow/tools/ci_build/update_version.py --nightly + +# Run configure. +export CC_OPT_FLAGS='-mavx' +export PYTHON_BIN_PATH=$(which python3.6) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Build the pip package +bazel build --config=release_cpu_linux tensorflow/tools/pip_package:build_pip_package + +./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --cpu --nightly_flag + +# Upload the built packages to pypi. +for WHL_PATH in $(ls pip_pkg/tf_nightly_cpu-*dev*.whl); do + + WHL_DIR=$(dirname "${WHL_PATH}") + WHL_BASE_NAME=$(basename "${WHL_PATH}") + AUDITED_WHL_NAME="${WHL_DIR}"/$(echo "${WHL_BASE_NAME//linux/manylinux2010}") + auditwheel repair --plat manylinux2010_x86_64 -w "${WHL_DIR}" "${WHL_PATH}" + + # test the whl pip package + chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh + ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${AUDITED_WHL_NAME} + RETVAL=$? + + # Upload the PIP package if whl test passes. + if [ ${RETVAL} -eq 0 ]; then + echo "Basic PIP test PASSED, Uploading package: ${AUDITED_WHL_NAME}" + twine upload -r pypi-warehouse "${AUDITED_WHL_NAME}" + else + echo "Basic PIP test FAILED, will not upload ${AUDITED_WHL_NAME} package" + return 1 + fi +done diff --git a/tensorflow/tools/ci_build/nightly_release/ubuntu/cpu_py37.sh b/tensorflow/tools/ci_build/nightly_release/ubuntu/cpu_py37.sh new file mode 100644 index 00000000000..29fe8f4c351 --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/ubuntu/cpu_py37.sh @@ -0,0 +1,58 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_pip_deps pip3.7 + +install_bazelisk + +python2.7 tensorflow/tools/ci_build/update_version.py --nightly + +# Run configure. +export CC_OPT_FLAGS='-mavx' +export PYTHON_BIN_PATH=$(which python3.7) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Build the pip package +bazel build --config=release_cpu_linux tensorflow/tools/pip_package:build_pip_package + +./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --cpu --nightly_flag + +# Upload the built packages to pypi. +for WHL_PATH in $(ls pip_pkg/tf_nightly_cpu-*dev*.whl); do + + WHL_DIR=$(dirname "${WHL_PATH}") + WHL_BASE_NAME=$(basename "${WHL_PATH}") + AUDITED_WHL_NAME="${WHL_DIR}"/$(echo "${WHL_BASE_NAME//linux/manylinux2010}") + auditwheel repair --plat manylinux2010_x86_64 -w "${WHL_DIR}" "${WHL_PATH}" + + # test the whl pip package + chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh + ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${AUDITED_WHL_NAME} + RETVAL=$? 
+ + # Upload the PIP package if whl test passes. + if [ ${RETVAL} -eq 0 ]; then + echo "Basic PIP test PASSED, Uploading package: ${AUDITED_WHL_NAME}" + twine upload -r pypi-warehouse "${AUDITED_WHL_NAME}" + else + echo "Basic PIP test FAILED, will not upload ${AUDITED_WHL_NAME} package" + return 1 + fi +done diff --git a/tensorflow/tools/ci_build/nightly_release/ubuntu/cpu_py38.sh b/tensorflow/tools/ci_build/nightly_release/ubuntu/cpu_py38.sh new file mode 100644 index 00000000000..442d6a4cc76 --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/ubuntu/cpu_py38.sh @@ -0,0 +1,58 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_pip_deps pip3.8 + +install_bazelisk + +python2.7 tensorflow/tools/ci_build/update_version.py --nightly + +# Run configure. +export CC_OPT_FLAGS='-mavx' +export PYTHON_BIN_PATH=$(which python3.8) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Build the pip package +bazel build --config=release_cpu_linux tensorflow/tools/pip_package:build_pip_package + +./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --cpu --nightly_flag + +# Upload the built packages to pypi. +for WHL_PATH in $(ls pip_pkg/tf_nightly_cpu-*dev*.whl); do + + WHL_DIR=$(dirname "${WHL_PATH}") + WHL_BASE_NAME=$(basename "${WHL_PATH}") + AUDITED_WHL_NAME="${WHL_DIR}"/$(echo "${WHL_BASE_NAME//linux/manylinux2010}") + auditwheel repair --plat manylinux2010_x86_64 -w "${WHL_DIR}" "${WHL_PATH}" + + # test the whl pip package + chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh + ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${AUDITED_WHL_NAME} + RETVAL=$? + + # Upload the PIP package if whl test passes. + if [ ${RETVAL} -eq 0 ]; then + echo "Basic PIP test PASSED, Uploading package: ${AUDITED_WHL_NAME}" + twine upload -r pypi-warehouse "${AUDITED_WHL_NAME}" + else + echo "Basic PIP test FAILED, will not upload ${AUDITED_WHL_NAME} package" + return 1 + fi +done diff --git a/tensorflow/tools/ci_build/nightly_release/ubuntu/gpu_py35.sh b/tensorflow/tools/ci_build/nightly_release/ubuntu/gpu_py35.sh new file mode 100644 index 00000000000..aac88b57fa7 --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/ubuntu/gpu_py35.sh @@ -0,0 +1,63 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_pip_deps pip3.5 + +install_bazelisk + +python2.7 tensorflow/tools/ci_build/update_version.py --nightly + +# Run configure. +export CC_OPT_FLAGS='-mavx' +export PYTHON_BIN_PATH=$(which python3.5) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Build the pip package +bazel build --config=release_gpu_linux tensorflow/tools/pip_package:build_pip_package + +./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --nightly_flag +./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --gpu --nightly_flag + +# Upload the built packages to pypi. +for WHL_PATH in $(ls pip_pkg/tf_nightly*dev*.whl); do + + WHL_DIR=$(dirname "${WHL_PATH}") + WHL_BASE_NAME=$(basename "${WHL_PATH}") + AUDITED_WHL_NAME="${WHL_DIR}"/$(echo "${WHL_BASE_NAME//linux/manylinux2010}") + + # Copy and rename for gpu manylinux as we do not want auditwheel to package in libcudart.so + WHL_PATH=${AUDITED_WHL_NAME} + cp "${WHL_DIR}"/"${WHL_BASE_NAME}" "${WHL_PATH}" + echo "Copied manylinux2010 wheel file at: ${WHL_PATH}" + + # test the whl pip package + chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh + ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${AUDITED_WHL_NAME} + RETVAL=$? + + # Upload the PIP package if whl test passes. + if [ ${RETVAL} -eq 0 ]; then + echo "Basic PIP test PASSED, Uploading package: ${AUDITED_WHL_NAME}" + twine upload -r pypi-warehouse "${AUDITED_WHL_NAME}" + else + echo "Basic PIP test FAILED, will not upload ${AUDITED_WHL_NAME} package" + return 1 + fi +done diff --git a/tensorflow/tools/ci_build/nightly_release/ubuntu/gpu_py36.sh b/tensorflow/tools/ci_build/nightly_release/ubuntu/gpu_py36.sh new file mode 100644 index 00000000000..600b4b0be8e --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/ubuntu/gpu_py36.sh @@ -0,0 +1,63 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_pip_deps pip3.6 + +install_bazelisk + +python2.7 tensorflow/tools/ci_build/update_version.py --nightly + +# Run configure. +export CC_OPT_FLAGS='-mavx' +export PYTHON_BIN_PATH=$(which python3.6) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Build the pip package +bazel build --config=release_gpu_linux tensorflow/tools/pip_package:build_pip_package + +./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --nightly_flag +./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --gpu --nightly_flag + +# Upload the built packages to pypi. 
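Unlike the CPU nightly scripts earlier in this patch, which run auditwheel repair, the per-wheel loop that follows only copies each GPU wheel under a manylinux2010 name; as the hunk's own comment notes, this keeps auditwheel from bundling libcudart.so into the package. A minimal bash sketch of just that rename step, with a hypothetical wheel filename standing in for the real build output:

    # Hypothetical input wheel; the point is the linux -> manylinux2010 substitution.
    WHL_PATH=pip_pkg/tf_nightly_gpu-2.4.0.dev20201001-cp36-cp36m-linux_x86_64.whl
    WHL_DIR=$(dirname "${WHL_PATH}")
    WHL_BASE_NAME=$(basename "${WHL_PATH}")
    # Copy rather than repair, so no extra shared libraries are vendored into the wheel.
    cp "${WHL_PATH}" "${WHL_DIR}/${WHL_BASE_NAME//linux/manylinux2010}"

The renamed file is what the smoke test and twine upload in the loop below operate on.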
+for WHL_PATH in $(ls pip_pkg/tf_nightly*dev*.whl); do + + WHL_DIR=$(dirname "${WHL_PATH}") + WHL_BASE_NAME=$(basename "${WHL_PATH}") + AUDITED_WHL_NAME="${WHL_DIR}"/$(echo "${WHL_BASE_NAME//linux/manylinux2010}") + + # Copy and rename for gpu manylinux as we do not want auditwheel to package in libcudart.so + WHL_PATH=${AUDITED_WHL_NAME} + cp "${WHL_DIR}"/"${WHL_BASE_NAME}" "${WHL_PATH}" + echo "Copied manylinux2010 wheel file at: ${WHL_PATH}" + + # test the whl pip package + chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh + ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${AUDITED_WHL_NAME} + RETVAL=$? + + # Upload the PIP package if whl test passes. + if [ ${RETVAL} -eq 0 ]; then + echo "Basic PIP test PASSED, Uploading package: ${AUDITED_WHL_NAME}" + twine upload -r pypi-warehouse "${AUDITED_WHL_NAME}" + else + echo "Basic PIP test FAILED, will not upload ${AUDITED_WHL_NAME} package" + return 1 + fi +done diff --git a/tensorflow/tools/ci_build/nightly_release/ubuntu/gpu_py37.sh b/tensorflow/tools/ci_build/nightly_release/ubuntu/gpu_py37.sh new file mode 100644 index 00000000000..a9e51461715 --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/ubuntu/gpu_py37.sh @@ -0,0 +1,63 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_pip_deps pip3.7 + +install_bazelisk + +python2.7 tensorflow/tools/ci_build/update_version.py --nightly + +# Run configure. +export CC_OPT_FLAGS='-mavx' +export PYTHON_BIN_PATH=$(which python3.7) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Build the pip package +bazel build --config=release_gpu_linux tensorflow/tools/pip_package:build_pip_package + +./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --nightly_flag +./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --gpu --nightly_flag + +# Upload the built packages to pypi. +for WHL_PATH in $(ls pip_pkg/tf_nightly*dev*.whl); do + + WHL_DIR=$(dirname "${WHL_PATH}") + WHL_BASE_NAME=$(basename "${WHL_PATH}") + AUDITED_WHL_NAME="${WHL_DIR}"/$(echo "${WHL_BASE_NAME//linux/manylinux2010}") + + # Copy and rename for gpu manylinux as we do not want auditwheel to package in libcudart.so + WHL_PATH=${AUDITED_WHL_NAME} + cp "${WHL_DIR}"/"${WHL_BASE_NAME}" "${WHL_PATH}" + echo "Copied manylinux2010 wheel file at: ${WHL_PATH}" + + # test the whl pip package + chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh + ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${AUDITED_WHL_NAME} + RETVAL=$? + + # Upload the PIP package if whl test passes. 
+ if [ ${RETVAL} -eq 0 ]; then + echo "Basic PIP test PASSED, Uploading package: ${AUDITED_WHL_NAME}" + twine upload -r pypi-warehouse "${AUDITED_WHL_NAME}" + else + echo "Basic PIP test FAILED, will not upload ${AUDITED_WHL_NAME} package" + return 1 + fi +done diff --git a/tensorflow/tools/ci_build/nightly_release/ubuntu/gpu_py38.sh b/tensorflow/tools/ci_build/nightly_release/ubuntu/gpu_py38.sh new file mode 100644 index 00000000000..0b8fd1380f2 --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/ubuntu/gpu_py38.sh @@ -0,0 +1,65 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +install_ubuntu_16_pip_deps pip3.8 + +pip3.7 install --upgrade auditwheel --user + +update_bazel_linux + +python2.7 tensorflow/tools/ci_build/update_version.py --nightly + +# Run configure. +export CC_OPT_FLAGS='-mavx' +export PYTHON_BIN_PATH=$(which python3.8) +yes "" | "$PYTHON_BIN_PATH" configure.py + +# Build the pip package +bazel build --config=release_gpu_linux tensorflow/tools/pip_package:build_pip_package + +./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --nightly_flag +./bazel-bin/tensorflow/tools/pip_package/build_pip_package pip_pkg --gpu --nightly_flag + +# Upload the built packages to pypi. +for WHL_PATH in $(ls pip_pkg/tf_nightly*dev*.whl); do + + WHL_DIR=$(dirname "${WHL_PATH}") + WHL_BASE_NAME=$(basename "${WHL_PATH}") + AUDITED_WHL_NAME="${WHL_DIR}"/$(echo "${WHL_BASE_NAME//linux/manylinux2010}") + + # Copy and rename for gpu manylinux as we do not want auditwheel to package in libcudart.so + WHL_PATH=${AUDITED_WHL_NAME} + cp "${WHL_DIR}"/"${WHL_BASE_NAME}" "${WHL_PATH}" + echo "Copied manylinux2010 wheel file at: ${WHL_PATH}" + + # test the whl pip package + chmod +x tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh + ./tensorflow/tools/ci_build/builds/nightly_release_smoke_test.sh ${AUDITED_WHL_NAME} + RETVAL=$? + + # Upload the PIP package if whl test passes. + if [ ${RETVAL} -eq 0 ]; then + echo "Basic PIP test PASSED, Uploading package: ${AUDITED_WHL_NAME}" + twine upload -r pypi-warehouse "${AUDITED_WHL_NAME}" + else + echo "Basic PIP test FAILED, will not upload ${AUDITED_WHL_NAME} package" + return 1 + fi +done diff --git a/tensorflow/tools/ci_build/nightly_release/windows/cpu_py35.bat b/tensorflow/tools/ci_build/nightly_release/windows/cpu_py35.bat new file mode 100644 index 00000000000..6ed1088893f --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/windows/cpu_py35.bat @@ -0,0 +1,20 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. 
+:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +SET PYTHON_DIRECTORY=Python35 + +CALL tensorflow\tools\ci_build\release\common_win.bat + +call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --tf_nightly --project_name "tf_nightly_cpu" diff --git a/tensorflow/tools/ci_build/nightly_release/windows/cpu_py36.bat b/tensorflow/tools/ci_build/nightly_release/windows/cpu_py36.bat new file mode 100644 index 00000000000..3af98dddeae --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/windows/cpu_py36.bat @@ -0,0 +1,20 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +SET PYTHON_DIRECTORY=Python36 + +CALL tensorflow\tools\ci_build\release\common_win.bat + +call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --tf_nightly --project_name "tf_nightly_cpu" diff --git a/tensorflow/tools/ci_build/nightly_release/windows/cpu_py37.bat b/tensorflow/tools/ci_build/nightly_release/windows/cpu_py37.bat new file mode 100644 index 00000000000..850c21ee962 --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/windows/cpu_py37.bat @@ -0,0 +1,20 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +SET PYTHON_DIRECTORY=Python37 + +CALL tensorflow\tools\ci_build\release\common_win.bat + +call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --tf_nightly --project_name "tf_nightly_cpu" diff --git a/tensorflow/tools/ci_build/nightly_release/windows/cpu_py38.bat b/tensorflow/tools/ci_build/nightly_release/windows/cpu_py38.bat new file mode 100644 index 00000000000..2456b1e26bb --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/windows/cpu_py38.bat @@ -0,0 +1,20 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +SET PYTHON_DIRECTORY=Python38 + +CALL tensorflow\tools\ci_build\release\common_win.bat + +call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --tf_nightly --project_name "tf_nightly_cpu" diff --git a/tensorflow/tools/ci_build/nightly_release/windows/gpu_py35.bat b/tensorflow/tools/ci_build/nightly_release/windows/gpu_py35.bat new file mode 100644 index 00000000000..43e6414a74b --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/windows/gpu_py35.bat @@ -0,0 +1,20 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +SET PYTHON_DIRECTORY=Python35 + +CALL tensorflow\tools\ci_build\release\common_win.bat + +call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --tf_nightly diff --git a/tensorflow/tools/ci_build/nightly_release/windows/gpu_py36.bat b/tensorflow/tools/ci_build/nightly_release/windows/gpu_py36.bat new file mode 100644 index 00000000000..15ec83c054e --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/windows/gpu_py36.bat @@ -0,0 +1,20 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +SET PYTHON_DIRECTORY=Python36 + +CALL tensorflow\tools\ci_build\release\common_win.bat + +call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --tf_nightly diff --git a/tensorflow/tools/ci_build/nightly_release/windows/gpu_py37.bat b/tensorflow/tools/ci_build/nightly_release/windows/gpu_py37.bat new file mode 100644 index 00000000000..1eb65d8a284 --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/windows/gpu_py37.bat @@ -0,0 +1,20 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +SET PYTHON_DIRECTORY=Python37 + +CALL tensorflow\tools\ci_build\release\common_win.bat + +call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --tf_nightly diff --git a/tensorflow/tools/ci_build/nightly_release/windows/gpu_py38.bat b/tensorflow/tools/ci_build/nightly_release/windows/gpu_py38.bat new file mode 100644 index 00000000000..670793340e8 --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/windows/gpu_py38.bat @@ -0,0 +1,20 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +SET PYTHON_DIRECTORY=Python38 + +CALL tensorflow\tools\ci_build\release\common_win.bat + +call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --tf_nightly diff --git a/tensorflow/tools/ci_build/nightly_release/windows/upload_nightly_pip.sh b/tensorflow/tools/ci_build/nightly_release/windows/upload_nightly_pip.sh new file mode 100644 index 00000000000..31d21c46816 --- /dev/null +++ b/tensorflow/tools/ci_build/nightly_release/windows/upload_nightly_pip.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +source tensorflow/tools/ci_build/release/common.sh + +sudo pip install --upgrade twine + +# Copy and rename to tf_nightly +for f in $(ls "${KOKORO_GFILE_DIR}"/tf_nightly_gpu*dev*cp3*-cp3*-win_amd64.whl); do + copy_to_new_project_name "${f}" tf_nightly +done + +# Upload the built packages to pypi. 
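upload_nightly_pip.sh above runs under set -e, and the per-wheel loop that follows appends '|| echo' to each twine upload so that a single failed upload (for example, a file already present on the index) does not abort the remaining ones. A minimal sketch of the pattern, with a hypothetical wheel directory standing in for ${KOKORO_GFILE_DIR}:

    #!/bin/bash
    set -e
    # Under set -e a failing command would normally end the script;
    # 'cmd || echo ...' forces a zero exit status, so the loop keeps going.
    for f in dist/tf_nightly*win_amd64.whl; do   # hypothetical wheel location
      twine upload -r pypi-warehouse "$f" || echo "upload of $f failed, continuing"
    done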
+for f in $(ls "${KOKORO_GFILE_DIR}"/tf_nightly*dev*cp3*-cp3*-win_amd64.whl); do + twine upload -r pypi-warehouse "$f" || echo +done diff --git a/tensorflow/tools/ci_build/presubmit/ubuntu_16/android/build.sh b/tensorflow/tools/ci_build/presubmit/ubuntu_16/android/build.sh index c20a0b14677..81afa226d3f 100644 --- a/tensorflow/tools/ci_build/presubmit/ubuntu_16/android/build.sh +++ b/tensorflow/tools/ci_build/presubmit/ubuntu_16/android/build.sh @@ -25,7 +25,7 @@ set +u set -x function run_build () { - export ANDROID_NDK_HOME="/opt/android-ndk-r17c" + export ANDROID_NDK_HOME="/opt/android-ndk-r18b" export NDK_HOME=$ANDROID_NDK_HOME export ANDROID_SDK_HOME="/opt/android-sdk/current" export ANDROID_API_LEVEL="23" diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_libtensorflow.sh b/tensorflow/tools/ci_build/rel/macos/cpu_libtensorflow.sh index 3dfab5a2aaa..148ab16de6c 100644 --- a/tensorflow/tools/ci_build/rel/macos/cpu_libtensorflow.sh +++ b/tensorflow/tools/ci_build/rel/macos/cpu_libtensorflow.sh @@ -13,21 +13,31 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -echo "chmod go+w lib_package/*" >> tensorflow/tools/ci_build/linux/libtensorflow.sh -echo "bazel clean --expunge" >> tensorflow/tools/ci_build/linux/libtensorflow.sh -# Install latest bazel -source tensorflow/tools/ci_build/release/common.sh -install_bazelisk +if [[ "$IS_NIGHTLY" -eq 1 ]]; then + echo "chmod go+w lib_package/*" >> tensorflow/tools/ci_build/linux/libtensorflow.sh + echo "bazel clean --expunge" >> tensorflow/tools/ci_build/linux/libtensorflow.sh -# Pick a version of xcode -export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer -sudo xcode-select -s "${DEVELOPER_DIR}" + # Install latest bazel + source tensorflow/tools/ci_build/release/common.sh + install_bazelisk -# Update the version string to nightly -./tensorflow/tools/ci_build/update_version.py --nightly + # Pick a version of xcode + export DEVELOPER_DIR=/Applications/Xcode_10.3.app/Contents/Developer + sudo xcode-select -s "${DEVELOPER_DIR}" -tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh + # Update the version string to nightly + ./tensorflow/tools/ci_build/update_version.py --nightly -# Copy the nightly version update script -cp tensorflow/tools/ci_build/builds/libtensorflow_nightly_symlink.sh lib_package + tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh + + # Copy the nightly version update script + cp tensorflow/tools/ci_build/builds/libtensorflow_nightly_symlink.sh lib_package + +else + set -ex + # Install latest bazel + source tensorflow/tools/ci_build/release/common.sh + install_bazelisk + tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh +fi diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py35_nonpip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py35_nonpip.sh index 7e85779a207..dc1491d65d7 100644 --- a/tensorflow/tools/ci_build/rel/macos/cpu_py35_nonpip.sh +++ b/tensorflow/tools/ci_build/rel/macos/cpu_py35_nonpip.sh @@ -41,11 +41,9 @@ tag_filters="-no_oss,-oss_serial,-nomac,-no_mac,-no_oss_py35,-v1only,-gpu,-tpu,- source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh # Run tests -set +e bazel test --test_output=errors --config=opt \ --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ --build_tag_filters="${tag_filters}" \ --test_tag_filters="${tag_filters}" -- \ ${DEFAULT_BAZEL_TARGETS} \ -//tensorflow/lite/... 
-test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py36_nonpip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py36_nonpip.sh index 07d4f7957af..eb245cb1d04 100644 --- a/tensorflow/tools/ci_build/rel/macos/cpu_py36_nonpip.sh +++ b/tensorflow/tools/ci_build/rel/macos/cpu_py36_nonpip.sh @@ -41,11 +41,9 @@ tag_filters="-no_oss,-oss_serial,-nomac,-no_mac,-no_oss_py36,-v1only,-gpu,-tpu,- source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh # Run tests -set +e bazel test --test_output=errors --config=opt \ --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ --build_tag_filters="${tag_filters}" \ --test_tag_filters="${tag_filters}" -- \ ${DEFAULT_BAZEL_TARGETS} \ -//tensorflow/lite/... -test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py37_nonpip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py37_nonpip.sh index a23ca47a038..b9e6c3c9cf0 100644 --- a/tensorflow/tools/ci_build/rel/macos/cpu_py37_nonpip.sh +++ b/tensorflow/tools/ci_build/rel/macos/cpu_py37_nonpip.sh @@ -41,11 +41,9 @@ tag_filters="-no_oss,-oss_serial,-nomac,-no_mac$(maybe_skip_v1),-gpu,-tpu,-bench source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh # Run tests -set +e bazel test --test_output=errors --config=opt \ --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ --build_tag_filters="${tag_filters}" \ --test_tag_filters="${tag_filters}" -- \ ${DEFAULT_BAZEL_TARGETS} \ -//tensorflow/lite/... -test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/rel/macos/cpu_py38_nonpip.sh b/tensorflow/tools/ci_build/rel/macos/cpu_py38_nonpip.sh index 179ecdf97ca..a90d59ff492 100644 --- a/tensorflow/tools/ci_build/rel/macos/cpu_py38_nonpip.sh +++ b/tensorflow/tools/ci_build/rel/macos/cpu_py38_nonpip.sh @@ -41,11 +41,9 @@ tag_filters="-no_oss,-oss_serial,-nomac,-no_mac$(maybe_skip_v1),-gpu,-tpu,-bench source tensorflow/tools/ci_build/build_scripts/DEFAULT_TEST_TARGETS.sh # Run tests -set +e bazel test --test_output=errors --config=opt \ --action_env=TF2_BEHAVIOR="${TF2_BEHAVIOR}" \ --build_tag_filters="${tag_filters}" \ --test_tag_filters="${tag_filters}" -- \ ${DEFAULT_BAZEL_TARGETS} \ -//tensorflow/lite/... 
-test_xml_summary_exit diff --git a/tensorflow/tools/ci_build/rel/windows/cpu_py35.bat b/tensorflow/tools/ci_build/rel/windows/cpu_py35.bat index 175917d7cad..4122756ef40 100644 --- a/tensorflow/tools/ci_build/rel/windows/cpu_py35.bat +++ b/tensorflow/tools/ci_build/rel/windows/cpu_py35.bat @@ -17,4 +17,8 @@ SET PYTHON_DIRECTORY=Python35 CALL tensorflow\tools\ci_build\release\common_win.bat -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu" +if "%IS_NIGHTLY%" == "1" ( + call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" +) else ( + call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu" +) diff --git a/tensorflow/tools/ci_build/rel/windows/cpu_py36.bat b/tensorflow/tools/ci_build/rel/windows/cpu_py36.bat index 85b75053eff..fde52ca24a5 100644 --- a/tensorflow/tools/ci_build/rel/windows/cpu_py36.bat +++ b/tensorflow/tools/ci_build/rel/windows/cpu_py36.bat @@ -17,4 +17,8 @@ SET PYTHON_DIRECTORY=Python36 CALL tensorflow\tools\ci_build\release\common_win.bat -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu" +if "%IS_NIGHTLY%" == "1" ( + call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" +) else ( + call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu" +) diff --git a/tensorflow/tools/ci_build/rel/windows/cpu_py37.bat b/tensorflow/tools/ci_build/rel/windows/cpu_py37.bat index d8a6673ba4c..4b696bb744e 100644 --- a/tensorflow/tools/ci_build/rel/windows/cpu_py37.bat +++ b/tensorflow/tools/ci_build/rel/windows/cpu_py37.bat @@ -17,4 +17,8 @@ SET PYTHON_DIRECTORY=Python37 CALL tensorflow\tools\ci_build\release\common_win.bat -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu" +if "%IS_NIGHTLY%" == "1" ( + call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" +) else ( + call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu" +) diff --git a/tensorflow/tools/ci_build/rel/windows/cpu_py38.bat b/tensorflow/tools/ci_build/rel/windows/cpu_py38.bat index 86adcda0bb9..a1657b077cb 100644 --- a/tensorflow/tools/ci_build/rel/windows/cpu_py38.bat +++ b/tensorflow/tools/ci_build/rel/windows/cpu_py38.bat @@ -17,5 +17,8 @@ SET PYTHON_DIRECTORY=Python38 CALL tensorflow\tools\ci_build\release\common_win.bat -call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu" - +if "%IS_NIGHTLY%" == "1" ( + call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" +) else ( + call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu" +) diff --git a/tensorflow/tools/ci_build/rel/windows/gpu_py35.bat b/tensorflow/tools/ci_build/rel/windows/gpu_py35.bat index 86c118b2f83..7a8eb53d1e1 100644 --- a/tensorflow/tools/ci_build/rel/windows/gpu_py35.bat +++ 
b/tensorflow/tools/ci_build/rel/windows/gpu_py35.bat @@ -17,7 +17,10 @@ SET PYTHON_DIRECTORY=Python35 CALL tensorflow\tools\ci_build\release\common_win.bat -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow" - -for %%a in ("%~dp0\.") do set "PARENT_DIR=%%~nxa" -bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh +if "%IS_NIGHTLY%" == "1" ( + call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" +) else ( + call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow" + for %%a in ("%~dp0\.") do set "PARENT_DIR=%%~nxa" + bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh +) diff --git a/tensorflow/tools/ci_build/rel/windows/gpu_py36.bat b/tensorflow/tools/ci_build/rel/windows/gpu_py36.bat index cc4f84afbee..6737261ce69 100644 --- a/tensorflow/tools/ci_build/rel/windows/gpu_py36.bat +++ b/tensorflow/tools/ci_build/rel/windows/gpu_py36.bat @@ -17,7 +17,10 @@ SET PYTHON_DIRECTORY=Python36 CALL tensorflow\tools\ci_build\release\common_win.bat -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow" - -for %%a in ("%~dp0\.") do set "PARENT_DIR=%%~nxa" -bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh \ No newline at end of file +if "%IS_NIGHTLY%" == "1" ( + call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" +) else ( + call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow" + for %%a in ("%~dp0\.") do set "PARENT_DIR=%%~nxa" + bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh +) diff --git a/tensorflow/tools/ci_build/rel/windows/gpu_py37.bat b/tensorflow/tools/ci_build/rel/windows/gpu_py37.bat index 5fa798e3eb8..7ecfd83927f 100644 --- a/tensorflow/tools/ci_build/rel/windows/gpu_py37.bat +++ b/tensorflow/tools/ci_build/rel/windows/gpu_py37.bat @@ -17,7 +17,10 @@ SET PYTHON_DIRECTORY=Python37 CALL tensorflow\tools\ci_build\release\common_win.bat -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow" - -for %%a in ("%~dp0\.") do set "PARENT_DIR=%%~nxa" -bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh \ No newline at end of file +if "%IS_NIGHTLY%" == "1" ( + call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" +) else ( + call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow" + for %%a in ("%~dp0\.") do set "PARENT_DIR=%%~nxa" + bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh +) diff --git a/tensorflow/tools/ci_build/rel/windows/gpu_py38.bat b/tensorflow/tools/ci_build/rel/windows/gpu_py38.bat index fa1fc131145..8d152b10a2e 100644 --- a/tensorflow/tools/ci_build/rel/windows/gpu_py38.bat +++ b/tensorflow/tools/ci_build/rel/windows/gpu_py38.bat @@ -17,7 +17,10 @@ SET PYTHON_DIRECTORY=Python38 CALL tensorflow\tools\ci_build\release\common_win.bat -call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build 
--extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow" - -for %%a in ("%~dp0\.") do set "PARENT_DIR=%%~nxa" -bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh +if "%IS_NIGHTLY%" == "1" ( + call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --extra_test_flags "--test_env=TF2_BEHAVIOR=1" +) else ( + call tensorflow\tools\ci_build\windows\gpu\pip\run.bat --release_build --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow" + for %%a in ("%~dp0\.") do set "PARENT_DIR=%%~nxa" + bash -l tensorflow\tools\ci_build\release\windows\%PARENT_DIR%\release_pip_rename.sh +) diff --git a/tensorflow/tools/common/BUILD b/tensorflow/tools/common/BUILD index 5cc9fa67e6d..80b1e59f1a4 100644 --- a/tensorflow/tools/common/BUILD +++ b/tensorflow/tools/common/BUILD @@ -8,8 +8,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - py_library( name = "public_api", srcs = ["public_api.py"], diff --git a/tensorflow/tools/compatibility/BUILD b/tensorflow/tools/compatibility/BUILD index 18a8d72a0b0..a15014f447b 100644 --- a/tensorflow/tools/compatibility/BUILD +++ b/tensorflow/tools/compatibility/BUILD @@ -4,6 +4,7 @@ load( "tf_cc_test", # @unused "tf_copts", # @unused ) +load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") package( default_visibility = ["//tensorflow:internal"], @@ -80,6 +81,7 @@ py_test( py_library( name = "renames_v2", srcs = ["renames_v2.py"], + compatible_with = get_compatible_with_portable(), srcs_version = "PY2AND3", ) @@ -92,6 +94,7 @@ py_library( py_library( name = "all_renames_v2", srcs = ["all_renames_v2.py"], + compatible_with = get_compatible_with_portable(), srcs_version = "PY2AND3", visibility = [ "//tensorflow:__pkg__", diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py index a8d922cf7a9..e2aa438cb1a 100644 --- a/tensorflow/tools/compatibility/renames_v2.py +++ b/tensorflow/tools/compatibility/renames_v2.py @@ -282,6 +282,8 @@ renames = { 'tf.compat.v1.disable_v2_behavior', 'tf.disable_v2_tensorshape': 'tf.compat.v1.disable_v2_tensorshape', + 'tf.distribute.experimental.ParameterServerStrategy': + 'tf.compat.v1.distribute.experimental.ParameterServerStrategy', 'tf.distribute.get_loss_reduction': 'tf.compat.v1.distribute.get_loss_reduction', 'tf.distributions.Bernoulli': diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py index bf149170977..5d795de68c5 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -895,13 +895,13 @@ class TFAPIChangeSpec(ast_edits.NoUpdateSpec): "changes in constructor. " + distribute_strategy_api_changes) contrib_ps_strategy_warning = ( - ast_edits.ERROR, - "(Manual edit required) " + ast_edits.ERROR, "(Manual edit required) " "tf.contrib.distribute.ParameterServerStrategy has " "been migrated to " - "tf.distribute.experimental.ParameterServerStrategy (multi machine) " - " and tf.distribute.experimental.CentralStorageStrategy (one machine). " - "Note the changes in constructors. " + distribute_strategy_api_changes) + "tf.compat.v1.distribute.experimental.ParameterServerStrategy (multi " + "machine) and tf.distribute.experimental.CentralStorageStrategy (one " + "machine). Note the changes in constructors. 
" + + distribute_strategy_api_changes) keras_experimental_export_comment = ( ast_edits.WARNING, diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index 12fba05c622..d5cb6b6bc38 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -123,6 +123,7 @@ toco::TocoConvert toco::TocoGetPotentiallySupportedOps toco::MlirQuantizeModel toco::MlirSparsifyModel +toco::RegisterCustomOpdefs [transform_graph_lib] # transform_graph tensorflow::graph_transforms::TransformGraph @@ -331,13 +332,9 @@ tensorflow::grappler::GetNumAvailableLogicalCPUCores [traceme_recorder_impl] # profiler tensorflow::profiler::TraceMeRecorder::Record -[annotation_stack_impl] # profiler -tensorflow::profiler::AnnotationStack::ThreadAnnotationStack - [profiler_session_impl] # profiler tensorflow::ProfilerSession::Create tensorflow::ProfilerSession::CollectData -tensorflow::ProfilerSession::SerializeToString tensorflow::ProfilerSession::Status tensorflow::ProfilerSession::~ProfilerSession @@ -349,6 +346,10 @@ tensorflow::profiler::ProfilerServer::~ProfilerServer tensorflow::profiler::ProfileGrpc tensorflow::profiler::NewSessionGrpc tensorflow::profiler::MonitorGrpc +tensorflow::profiler::RemoteProfilerSession::Create +tensorflow::profiler::RemoteProfilerSession::GetServiceAddress +tensorflow::profiler::RemoteProfilerSession::WaitForCompletion +tensorflow::profiler::RemoteProfilerSession::~RemoteProfilerSession [status_macros] # tfcompile xla::status_macros::MakeErrorStream::Impl::Impl diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile index 83e01bdfd16..23a4e33548a 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile @@ -28,7 +28,7 @@ FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base # (but their default value is retained if set previously) ARG ARCH ARG CUDA -ARG CUDNN=8.0.2.39-1 +ARG CUDNN=8.0.4.30-1 ARG CUDNN_MAJOR_VERSION=8 ARG LIB_DIR_PREFIX=x86_64 ARG LIBNVINFER=7.1.3-1 diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile index 60a3e57c294..a4770817596 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile @@ -28,7 +28,7 @@ FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base # (but their default value is retained if set previously) ARG ARCH ARG CUDA -ARG CUDNN=8.0.2.39-1 +ARG CUDNN=8.0.4.30-1 ARG CUDNN_MAJOR_VERSION=8 ARG LIB_DIR_PREFIX=x86_64 ARG LIBNVINFER=7.1.3-1 diff --git a/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile index 911678b2ce3..7dda72e533b 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile @@ -28,7 +28,7 @@ FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base # (but their default value is retained if set previously) ARG ARCH ARG CUDA -ARG CUDNN=8.0.2.39-1 +ARG CUDNN=8.0.4.30-1 ARG CUDNN_MAJOR_VERSION=8 ARG LIB_DIR_PREFIX=x86_64 ARG LIBNVINFER=7.1.3-1 diff --git a/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile 
b/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile index 228513d6736..279a7906c3d 100644 --- a/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile +++ b/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile @@ -28,7 +28,7 @@ FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base # (but their default value is retained if set previously) ARG ARCH ARG CUDA -ARG CUDNN=8.0.2.39-1 +ARG CUDNN=8.0.4.30-1 ARG CUDNN_MAJOR_VERSION=8 ARG LIB_DIR_PREFIX=x86_64 ARG LIBNVINFER=7.1.3-1 diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile index ed310f39ecf..36187c2c9b7 100644 --- a/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile +++ b/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile @@ -5,7 +5,7 @@ FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base # (but their default value is retained if set previously) ARG ARCH ARG CUDA -ARG CUDNN=8.0.2.39-1 +ARG CUDNN=8.0.4.30-1 ARG CUDNN_MAJOR_VERSION=8 ARG LIB_DIR_PREFIX=x86_64 ARG LIBNVINFER=7.1.3-1 diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile index b2a7b46a7cb..ca9d38e51c7 100644 --- a/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile +++ b/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile @@ -5,7 +5,7 @@ FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base # (but their default value is retained if set previously) ARG ARCH ARG CUDA -ARG CUDNN=8.0.2.39-1 +ARG CUDNN=8.0.4.30-1 ARG CUDNN_MAJOR_VERSION=8 ARG LIB_DIR_PREFIX=x86_64 ARG LIBNVINFER=7.1.3-1 diff --git a/tensorflow/tools/dockerfiles/tflite-android.Dockerfile b/tensorflow/tools/dockerfiles/tflite-android.Dockerfile index 71286c69642..4501054b77a 100644 --- a/tensorflow/tools/dockerfiles/tflite-android.Dockerfile +++ b/tensorflow/tools/dockerfiles/tflite-android.Dockerfile @@ -19,7 +19,7 @@ RUN cd ${ANDROID_DEV_HOME} && \ bash -c "ln -s ${ANDROID_DEV_HOME}/android-sdk-* ${ANDROID_SDK_HOME}" # Install Android NDK. 
-ENV ANDROID_NDK_FILENAME android-ndk-r17c-linux-x86_64.zip +ENV ANDROID_NDK_FILENAME android-ndk-r18b-linux-x86_64.zip ENV ANDROID_NDK_URL https://dl.google.com/android/repository/${ANDROID_NDK_FILENAME} ENV ANDROID_NDK_HOME ${ANDROID_DEV_HOME}/ndk ENV PATH ${PATH}:${ANDROID_NDK_HOME} diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index 3668da2645d..1998e47a3ad 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -13,8 +13,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - tpu_module = [ "tpu.", "distribute.tpu_strategy", diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index 12dfe038e8b..2f5eceb5e08 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -16,8 +16,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - cc_library( name = "transform_utils", srcs = [ diff --git a/tensorflow/tools/mlpbtxt/BUILD b/tensorflow/tools/mlpbtxt/BUILD index 68966404730..0786fc26512 100644 --- a/tensorflow/tools/mlpbtxt/BUILD +++ b/tensorflow/tools/mlpbtxt/BUILD @@ -9,10 +9,7 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files([ - "LICENSE", - "placeholder.txt", -]) +exports_files(["placeholder.txt"]) tf_cc_binary( name = "tomlpbtxt", diff --git a/tensorflow/tools/optimization/BUILD b/tensorflow/tools/optimization/BUILD index 14db4f77232..08df199bf4d 100644 --- a/tensorflow/tools/optimization/BUILD +++ b/tensorflow/tools/optimization/BUILD @@ -13,8 +13,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["LICENSE"]) - tf_cuda_library( name = "optimization_pass_runner_lib", srcs = ["optimization_pass_runner.cc"], diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index cace6af08c6..c9d2932bcee 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -32,7 +32,7 @@ transitive_hdrs( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:stream_executor", + "//tensorflow/core/platform:stream_executor", "//tensorflow/cc/saved_model:loader", "//tensorflow/cc/saved_model:reader", "//tensorflow/cc/saved_model:bundle_v2", @@ -98,7 +98,6 @@ COMMON_PIP_DEPS = [ "//tensorflow/compiler/tf2xla:xla_compiled_cpu_runtime_srcs", "//tensorflow/compiler/mlir/tensorflow:gen_mlir_passthrough_op_py", "//tensorflow/core:protos_all_proto_srcs", - "//tensorflow/examples/saved_model/integration_tests:mnist_util", "//tensorflow/lite/python/testdata:interpreter_test_data", "//tensorflow/lite/python:tflite_convert", "//tensorflow/lite/toco/python:toco_from_protos", diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index c62a7beb0a7..e4a69d22a6f 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -63,6 +63,7 @@ REQUIRED_PACKAGES = [ # https://github.com/numpy/numpy/pull/15355 'numpy >= 1.16.0, < 1.19.0', 'opt_einsum >= 2.3.2', + 'portpicker >= 1.3.0', # Needed by tf.__internal__.distribute.combinations. 
'protobuf >= 3.9.2', 'tensorboard >= 2.3.0, < 3', 'tensorflow_estimator >= 2.3.0, < 2.4.0', diff --git a/tensorflow/tools/proto_text/BUILD b/tensorflow/tools/proto_text/BUILD index 105151c18cf..19ed935b38f 100644 --- a/tensorflow/tools/proto_text/BUILD +++ b/tensorflow/tools/proto_text/BUILD @@ -26,10 +26,7 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files([ - "LICENSE", - "placeholder.txt", -]) +exports_files(["placeholder.txt"]) cc_binary( name = "gen_proto_text_functions", diff --git a/tensorflow/tools/test/BUILD b/tensorflow/tools/test/BUILD index 30c8d991b5f..f314dcfff11 100644 --- a/tensorflow/tools/test/BUILD +++ b/tensorflow/tools/test/BUILD @@ -12,10 +12,7 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files([ - "LICENSE", - "run_and_gather_logs_lib.py", -]) +exports_files(["run_and_gather_logs_lib.py"]) py_library( name = "system_info_lib", diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index eafc52ad549..e3b9a3557a4 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -209,11 +209,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "mkl_dnn_v1", build_file = clean_dep("//third_party/mkl_dnn:mkldnn_v1.BUILD"), - sha256 = "aef4d2a726f76f5b98902491a1a4ac69954039aa8e5a1d67ef6ce58ed00e23a6", - strip_prefix = "oneDNN-1.5.1", + sha256 = "5369f7b2f0b52b40890da50c0632c3a5d1082d98325d0f2bff125d19d0dcaa1d", + strip_prefix = "oneDNN-1.6.4", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/oneapi-src/oneDNN/archive/v1.5.1.tar.gz", - "https://github.com/oneapi-src/oneDNN/archive/v1.5.1.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/oneapi-src/oneDNN/archive/v1.6.4.tar.gz", + "https://github.com/oneapi-src/oneDNN/archive/v1.6.4.tar.gz", ], ) @@ -235,11 +235,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): name = "eigen_archive", build_file = clean_dep("//third_party:eigen.BUILD"), patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"), - sha256 = "a3c10a8c14f55e9f09f98b0a0ac6874c21bda91f65b7469d9b1f6925990e867b", # SHARED_EIGEN_SHA - strip_prefix = "eigen-d10b27fe37736d2944630ecd7557cefa95cf87c9", + sha256 = "00ff67c15f8e8faf14495482e7396cc1d99cdfaaa2151f4aafef92bc754e634d", # SHARED_EIGEN_SHA + strip_prefix = "eigen-22c971a225dbb567cd1a45f6006d16c4aa618551", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/d10b27fe37736d2944630ecd7557cefa95cf87c9/eigen-d10b27fe37736d2944630ecd7557cefa95cf87c9.tar.gz", - "https://gitlab.com/libeigen/eigen/-/archive/d10b27fe37736d2944630ecd7557cefa95cf87c9/eigen-d10b27fe37736d2944630ecd7557cefa95cf87c9.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/22c971a225dbb567cd1a45f6006d16c4aa618551/eigen-22c971a225dbb567cd1a45f6006d16c4aa618551.tar.gz", + "https://gitlab.com/libeigen/eigen/-/archive/22c971a225dbb567cd1a45f6006d16c4aa618551/eigen-22c971a225dbb567cd1a45f6006d16c4aa618551.tar.gz", ], ) @@ -712,8 +712,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. 
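Each archive bump in this hunk (oneDNN 1.6.4, the new Eigen snapshot, and the LLVM commit pinned just below) pairs its download URLs with a sha256. A small shell sketch for recomputing such a digest before editing workspace.bzl, using the oneDNN URL and checksum from the hunk above:

    URL=https://github.com/oneapi-src/oneDNN/archive/v1.6.4.tar.gz
    EXPECTED=5369f7b2f0b52b40890da50c0632c3a5d1082d98325d0f2bff125d19d0dcaa1d
    # Recompute the digest of the fetched archive; it should match the pinned value.
    ACTUAL=$(curl -sL "${URL}" | sha256sum | awk '{print $1}')
    [ "${ACTUAL}" = "${EXPECTED}" ] && echo "sha256 OK" || echo "sha256 mismatch: ${ACTUAL}"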
- LLVM_COMMIT = "6caf3fb8178699ac14fb94fef99aaf1cf297264f" - LLVM_SHA256 = "c938dc8e71b3ed2fb29042fdc03c005e5334fa2be048a5b5e62dd406871c9c59" + LLVM_COMMIT = "6713332fddb796f5b14fcb6a7e5d36979676e4ab" + LLVM_SHA256 = "d83051da693c4165a8bd9a2703c806820d5146188b1036fd94300fafa09aae50" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), @@ -1185,6 +1185,17 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ], ) + tf_http_archive( + name = "tf_toolchains", + sha256 = "eb175afa73e5a33d2b5d2aabcfde6c8c3395fd7001eb5ba765a5cd98cce714ba", + strip_prefix = "toolchains-0.0.2", + build_file = clean_dep("//third_party:tf_toolchains.BUILD"), + urls = [ + "http://mirror.tensorflow.org/github.com/tensorflow/toolchains/archive/v0.0.2.tar.gz", + "https://github.com/tensorflow/toolchains/archive/v0.0.2.tar.gz", + ], + ) + def tf_bind(): """Bind targets for some external repositories""" ############################################################################## diff --git a/third_party/flatbuffers/BUILD.bazel b/third_party/flatbuffers/BUILD.bazel index 1ee46f05235..108c0cd8e3b 100644 --- a/third_party/flatbuffers/BUILD.bazel +++ b/third_party/flatbuffers/BUILD.bazel @@ -1,4 +1,5 @@ load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load(":build_defs.bzl", "flatbuffer_py_strip_prefix_srcs") package(default_visibility = ["//visibility:public"]) @@ -102,8 +103,8 @@ cc_library( visibility = ["//visibility:public"], ) -filegroup( - name = "runtime_py_srcs", +flatbuffer_py_strip_prefix_srcs( + name = "flatbuffer_py_strip_prefix", srcs = [ "python/flatbuffers/__init__.py", "python/flatbuffers/builder.py", @@ -114,6 +115,21 @@ filegroup( "python/flatbuffers/table.py", "python/flatbuffers/util.py", ], + strip_prefix = "python/flatbuffers/", +) + +filegroup( + name = "runtime_py_srcs", + srcs = [ + "__init__.py", + "builder.py", + "compat.py", + "encode.py", + "number_types.py", + "packer.py", + "table.py", + "util.py", + ], ) py_library( diff --git a/third_party/flatbuffers/build_defs.bzl b/third_party/flatbuffers/build_defs.bzl index d7099e58717..4fe9629b9d1 100644 --- a/third_party/flatbuffers/build_defs.bzl +++ b/third_party/flatbuffers/build_defs.bzl @@ -361,18 +361,30 @@ _gen_flatbuffer_srcs = rule( output_to_genfiles = True, ) +def flatbuffer_py_strip_prefix_srcs(name, srcs = [], strip_prefix = ""): + """Strips path prefix. + + Args: + name: Rule name. (required) + srcs: Source .py files. (required) + strip_prefix: Path that needs to be stripped from the srcs filepaths. (required) + """ + for src in srcs: + native.genrule( + name = name + "_" + src.replace(".", "_").replace("/", "_"), + srcs = [src], + outs = [src.replace(strip_prefix, "")], + cmd = "cp $< $@", + ) + def _concat_flatbuffer_py_srcs_impl(ctx): - # Merge all generated python files. The files are concatenated and the - # import statements are removed. Finally we import the flatbuffer runtime - # library. + # Merge all generated python files. 
+ command = "find '%s' -name '*.py' -exec cat {} + | sed '/import flatbuffers/d'" + command += " | sed '1s/^/import flatbuffers\\'$'\\n/' > %s" ctx.actions.run_shell( inputs = ctx.attr.deps[0].files, outputs = [ctx.outputs.out], - command = ( - "find '%s' -name '*.py' -exec cat {} + |" + - "sed 's/from flatbuffers.compat import import_numpy/import numpy as np' |" + - "sed '/np = import_numpy()/d' > %s" - ) % ( + command = command % ( ctx.attr.deps[0].files.to_list()[0].path, ctx.outputs.out.path, ), diff --git a/third_party/gpus/cuda/cuda_config.h.tpl b/third_party/gpus/cuda/cuda_config.h.tpl index b59889938b1..ab26686ccb8 100644 --- a/third_party/gpus/cuda/cuda_config.h.tpl +++ b/third_party/gpus/cuda/cuda_config.h.tpl @@ -17,6 +17,7 @@ limitations under the License. #define CUDA_CUDA_CONFIG_H_ #define TF_CUDA_VERSION "%{cuda_version}" +#define TF_CUDART_VERSION "%{cudart_version}" #define TF_CUBLAS_VERSION "%{cublas_version}" #define TF_CUSOLVER_VERSION "%{cusolver_version}" #define TF_CURAND_VERSION "%{curand_version}" diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index ea33963fe19..704003b7f63 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -534,14 +534,14 @@ def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config): "cudart", cpu_value, cuda_config.config["cuda_library_dir"], - cuda_config.cuda_version, + cuda_config.cudart_version, static = False, ), "cudart_static": _check_cuda_lib_params( "cudart_static", cpu_value, cuda_config.config["cuda_library_dir"], - cuda_config.cuda_version, + cuda_config.cudart_version, static = True, ), "cublas": _check_cuda_lib_params( @@ -651,6 +651,7 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script): cuda_toolkit_path: The CUDA toolkit installation directory. cudnn_install_basedir: The cuDNN installation directory. cuda_version: The version of CUDA on the system. + cudart_version: The CUDA runtime version on the system. cudnn_version: The version of cuDNN on the system. compute_capabilities: A list of the system's CUDA compute capabilities. cpu_value: The name of the host operating system. @@ -668,6 +669,11 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script): cudnn_version = ("64_%s" if is_windows else "%s") % config["cudnn_version"] if int(cuda_major) >= 11: + # The libcudart soname in CUDA 11.x is versioned as 11.0 for backward compatability. + if int(cuda_major) == 11: + cudart_version = "64_110" if is_windows else "11.0" + else: + cudart_version = ("64_%s" if is_windows else "%s") % cuda_major cublas_version = ("64_%s" if is_windows else "%s") % config["cublas_version"].split(".")[0] cusolver_version = ("64_%s" if is_windows else "%s") % config["cusolver_version"].split(".")[0] curand_version = ("64_%s" if is_windows else "%s") % config["curand_version"].split(".")[0] @@ -677,12 +683,14 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script): # cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc. # It changed from 'x.y' to just 'x' in CUDA 10.1. 
cuda_lib_version = ("64_%s" if is_windows else "%s") % cuda_major + cudart_version = cuda_version cublas_version = cuda_lib_version cusolver_version = cuda_lib_version curand_version = cuda_lib_version cufft_version = cuda_lib_version cusparse_version = cuda_lib_version else: + cudart_version = cuda_version cublas_version = cuda_version cusolver_version = cuda_version curand_version = cuda_version @@ -693,6 +701,7 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script): cuda_toolkit_path = toolkit_path, cuda_version = cuda_version, cuda_version_major = cuda_major, + cudart_version = cudart_version, cublas_version = cublas_version, cusolver_version = cusolver_version, curand_version = curand_version, @@ -816,6 +825,7 @@ filegroup(name="cudnn-include") "cuda:cuda_config.h", { "%{cuda_version}": "", + "%{cudart_version}": "", "%{cublas_version}": "", "%{cusolver_version}": "", "%{curand_version}": "", @@ -1281,6 +1291,7 @@ def _create_local_cuda_repository(repository_ctx): tpl_paths["cuda:cuda_config.h"], { "%{cuda_version}": cuda_config.cuda_version, + "%{cudart_version}": cuda_config.cudart_version, "%{cublas_version}": cuda_config.cublas_version, "%{cusolver_version}": cuda_config.cusolver_version, "%{curand_version}": cuda_config.curand_version, diff --git a/third_party/llvm/llvm.autogenerated.BUILD b/third_party/llvm/llvm.autogenerated.BUILD index 2fa578c3f4e..579005926a4 100644 --- a/third_party/llvm/llvm.autogenerated.BUILD +++ b/third_party/llvm/llvm.autogenerated.BUILD @@ -420,7 +420,7 @@ cc_binary( ) cc_library( - name = "filecheck-lib", + name = "FileCheckLib", srcs = glob([ "lib/FileCheck/*.cpp", "lib/FileCheck/*.h", @@ -443,8 +443,8 @@ cc_binary( linkopts = llvm_linkopts, stamp = 0, deps = [ + ":FileCheckLib", ":Support", - ":filecheck-lib", ], ) @@ -1772,6 +1772,50 @@ cc_library( ], ) +cc_library( + name = "CSKYCodeGen", + srcs = glob([ + "lib/Target/CSKY/*.c", + "lib/Target/CSKY/*.cpp", + "lib/Target/CSKY/*.inc", + ]), + hdrs = glob([ + "include/llvm/Target/CSKY/*.h", + "include/llvm/Target/CSKY/*.def", + "include/llvm/Target/CSKY/*.inc", + "lib/Target/CSKY/*.h", + ]), + copts = llvm_copts + ["-Iexternal/llvm-project/llvm/lib/Target/CSKY"], + deps = [ + ":CSKYInfo", + ":CodeGen", + ":Core", + ":Support", + ":Target", + ":config", + ], +) + +cc_library( + name = "CSKYInfo", + srcs = glob([ + "lib/Target/CSKY/TargetInfo/*.c", + "lib/Target/CSKY/TargetInfo/*.cpp", + "lib/Target/CSKY/TargetInfo/*.inc", + ]), + hdrs = glob([ + "include/llvm/Target/CSKY/TargetInfo/*.h", + "include/llvm/Target/CSKY/TargetInfo/*.def", + "include/llvm/Target/CSKY/TargetInfo/*.inc", + "lib/Target/CSKY/TargetInfo/*.h", + ]), + copts = llvm_copts + ["-Iexternal/llvm-project/llvm/lib/Target/CSKY"], + deps = [ + ":Support", + ":config", + ], +) + cc_library( name = "CodeGen", srcs = glob([ diff --git a/third_party/mkl/BUILD b/third_party/mkl/BUILD index 66a2bf8ceb9..c1c2c450e34 100644 --- a/third_party/mkl/BUILD +++ b/third_party/mkl/BUILD @@ -21,6 +21,14 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "build_with_mkl_aarch64", + define_values = { + "build_with_mkl_aarch64": "true", + }, + visibility = ["//visibility:public"], +) + config_setting( name = "enable_mkl", define_values = { diff --git a/third_party/mkl/build_defs.bzl b/third_party/mkl/build_defs.bzl index 28bd262e61e..b3efa4d9ca7 100644 --- a/third_party/mkl/build_defs.bzl +++ b/third_party/mkl/build_defs.bzl @@ -91,6 +91,7 @@ def mkl_deps(): """ return select({ 
"@org_tensorflow//third_party/mkl:build_with_mkl": ["@mkl_dnn_v1//:mkl_dnn"], + "@org_tensorflow//third_party/mkl:build_with_mkl_aarch64": ["@mkl_dnn_v1//:mkl_dnn_aarch64"], "//conditions:default": [], }) diff --git a/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/mkl_dnn/mkldnn_v1.BUILD index 0e6acc2fadd..8e7a3d61564 100644 --- a/third_party/mkl_dnn/mkldnn_v1.BUILD +++ b/third_party/mkl_dnn/mkldnn_v1.BUILD @@ -58,8 +58,8 @@ template_rule( out = "include/dnnl_version.h", substitutions = { "@DNNL_VERSION_MAJOR@": "1", - "@DNNL_VERSION_MINOR@": "5", - "@DNNL_VERSION_PATCH@": "1", + "@DNNL_VERSION_MINOR@": "6", + "@DNNL_VERSION_PATCH@": "4", "@DNNL_VERSION_HASH@": "N/A", }, ) @@ -135,3 +135,36 @@ cc_library( ], visibility = ["//visibility:public"], ) + +cc_library( + name = "mkl_dnn_aarch64", + srcs = glob([ + "src/common/*.cpp", + "src/common/*.hpp", + "src/cpu/*.cpp", + "src/cpu/*.hpp", + "src/cpu/rnn/*.cpp", + "src/cpu/rnn/*.hpp", + "src/cpu/matmul/*.cpp", + "src/cpu/matmul/*.hpp", + "src/cpu/gemm/**/*", + ]) + [ + ":dnnl_config_h", + ":dnnl_version_h", + ], + hdrs = glob(["include/*"]), + copts = [ + "-fexceptions", + "-UUSE_MKL", + "-UUSE_CBLAS", + ], + includes = [ + "include", + "src", + "src/common", + "src/cpu", + "src/cpu/gemm", + ], + linkopts = ["-lgomp"], + visibility = ["//visibility:public"], +) diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index 1c750e61bb4..31a839d81fa 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -121,18 +121,23 @@ cc_library( name = "CAPIIR", srcs = [ "lib/CAPI/IR/AffineMap.cpp", + "lib/CAPI/IR/Diagnostics.cpp", "lib/CAPI/IR/IR.cpp", "lib/CAPI/IR/StandardAttributes.cpp", "lib/CAPI/IR/StandardTypes.cpp", "lib/CAPI/IR/Support.cpp", + "lib/CAPI/Standard/StandardDialect.cpp", ], hdrs = [ "include/mlir-c/AffineMap.h", + "include/mlir-c/Diagnostics.h", "include/mlir-c/IR.h", "include/mlir-c/StandardAttributes.h", + "include/mlir-c/StandardDialect.h", "include/mlir-c/StandardTypes.h", "include/mlir-c/Support.h", "include/mlir/CAPI/AffineMap.h", + "include/mlir/CAPI/Diagnostics.h", "include/mlir/CAPI/IR.h", "include/mlir/CAPI/Support.h", "include/mlir/CAPI/Utils.h", @@ -142,11 +147,23 @@ cc_library( deps = [ ":IR", ":Parser", + ":StandardOps", ":Support", "@llvm-project//llvm:Support", ], ) +cc_library( + name = "MLIRBindingsPythonExtension", + hdrs = [ + "include/mlir-c/Bindings/Python/Interop.h", + ], + deps = [ + ":CAPIIR", + "//third_party/python_runtime:headers", + ], +) + cc_library( name = "CAPIRegistration", srcs = [ @@ -165,7 +182,6 @@ cc_library( filegroup( name = "OpBaseTdFiles", srcs = [ - "include/mlir/Dialect/Affine/IR/AffineOpsBase.td", "include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td", "include/mlir/IR/OpBase.td", ], @@ -187,7 +203,6 @@ filegroup( srcs = [ "include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td", "include/mlir/Dialect/Affine/IR/AffineOps.td", - "include/mlir/Dialect/Affine/IR/AffineOpsBase.td", "include/mlir/Interfaces/ControlFlowInterfaces.td", "include/mlir/Interfaces/LoopLikeInterface.td", "include/mlir/Interfaces/SideEffectInterfaces.td", @@ -239,6 +254,44 @@ gentbl( ], ) +##---------------------------------------------------------------------------## +# Async dialect. 
+##---------------------------------------------------------------------------## + +filegroup( + name = "AsyncOpsTdFiles", + srcs = [ + "include/mlir/Dialect/Async/IR/AsyncBase.td", + "include/mlir/Dialect/Async/IR/AsyncOps.td", + "include/mlir/Interfaces/SideEffectInterfaces.td", + ":OpBaseTdFiles", + ], +) + +gentbl( + name = "AsyncOpsIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-op-decls", + "include/mlir/Dialect/Async/IR/AsyncOps.h.inc", + ), + ( + "-gen-op-defs", + "include/mlir/Dialect/Async/IR/AsyncOps.cpp.inc", + ), + ( + "-gen-dialect-decls", + "include/mlir/Dialect/Async/IR/AsyncOpsDialect.h.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Async/IR/AsyncOps.td", + td_srcs = [ + ":AsyncOpsTdFiles", + ], +) + ##---------------------------------------------------------------------------## # AVX512 dialect. ##---------------------------------------------------------------------------## @@ -508,6 +561,26 @@ cc_library( ], ) +cc_library( + name = "Async", + srcs = glob([ + "lib/Dialect/Async/IR/*.cpp", + ]), + hdrs = glob([ + "include/mlir/Dialect/Async/IR/*.h", + ]), + includes = ["include"], + deps = [ + ":AsyncOpsIncGen", + ":Dialect", + ":IR", + ":SideEffectInterfaces", + ":StandardOps", + ":Support", + "@llvm-project//llvm:Support", + ], +) + cc_library( name = "AffineUtils", srcs = glob( @@ -610,6 +683,7 @@ cc_library( ":VectorToLLVM", ":VectorToROCDL", ":VectorToSCF", + ":VectorToSPIRV", ], ) @@ -938,6 +1012,7 @@ cc_library( ":Support", ":VectorInterfaces", ":VectorOpsIncGen", + ":ViewLikeInterface", "@llvm-project//llvm:Support", ], ) @@ -1196,6 +1271,30 @@ gentbl( ], ) +gentbl( + name = "GPUBaseIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-dialect-decls -dialect=gpu", + "include/mlir/Dialect/GPU/GPUOpsDialect.h.inc", + ), + ( + "-gen-op-interface-decls", + "include/mlir/Dialect/GPU/GPUOpInterfaces.h.inc", + ), + ( + "-gen-op-interface-defs", + "include/mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/GPU/GPUBase.td", + td_srcs = [ + ":GPUOpsTdFiles", + ], +) + gentbl( name = "GPUOpsIncGen", strip_include_prefix = "include", @@ -1208,10 +1307,6 @@ gentbl( "-gen-op-defs", "include/mlir/Dialect/GPU/GPUOps.cpp.inc", ), - ( - "-gen-dialect-decls -dialect=gpu", - "include/mlir/Dialect/GPU/GPUOpsDialect.h.inc", - ), ], tblgen = ":mlir-tblgen", td_file = "include/mlir/Dialect/GPU/GPUOps.td", @@ -1233,12 +1328,14 @@ cc_library( ]), includes = ["include"], deps = [ + ":GPUBaseIncGen", ":GPUOpsIncGen", ":IR", ":LLVMDialect", ":SideEffectInterfaces", ":StandardOps", ":Support", + "@llvm-project//llvm:Support", ], ) @@ -1385,6 +1482,27 @@ cc_library( ], ) +cc_library( + name = "VectorToSPIRV", + srcs = [ + "lib/Conversion/PassDetail.h", + "lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp", + ], + hdrs = [ + "include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h", + "include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h", + ], + includes = ["include"], + deps = [ + ":ConversionPassIncGen", + ":Pass", + ":SPIRVDialect", + ":SPIRVLowering", + ":Transforms", + ":VectorOps", + ], +) + gentbl( name = "GPUToROCDLTGen", strip_include_prefix = "lib/Conversion/GPUToROCDL", @@ -1523,6 +1641,7 @@ cc_library( ":StandardToSPIRVTransforms", ":Support", ":Transforms", + ":VectorToSPIRV", ], ) @@ -2952,6 +3071,7 @@ cc_library( ":AffinePassIncGen", ":AffineToStandard", ":AffineTransforms", + ":Async", ":ConversionPasses", ":GPUDialect", 
":GPUPassIncGen", @@ -3006,6 +3126,7 @@ cc_library( ":VectorToLLVM", ":VectorToROCDL", ":VectorToSCF", + ":VectorToSPIRV", ], ) @@ -3071,6 +3192,7 @@ cc_library( name = "mlir_c_runner_utils", srcs = [ "lib/ExecutionEngine/CRunnerUtils.cpp", + "lib/ExecutionEngine/SparseUtils.cpp", ], hdrs = [ "include/mlir/ExecutionEngine/CRunnerUtils.h", @@ -3104,10 +3226,10 @@ cc_binary( ], ) -cc_binary( - name = "tools/libcuda-runtime-wrappers.so", +cc_library( + name = "tools/libcuda-runtime-wrappers", srcs = ["tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp"], - linkshared = True, + compatible_with = ["//buildenv/target:prod"], deps = [ ":mlir_c_runner_utils", "//third_party/gpus/cuda:cuda_headers", @@ -3117,6 +3239,12 @@ cc_binary( ], ) +cc_binary( + name = "tools/libcuda-runtime-wrappers.so", + linkshared = True, + deps = [":tools/libcuda-runtime-wrappers"], +) + cc_library( name = "VulkanRuntime", srcs = [ @@ -3693,6 +3821,7 @@ cc_library( ":ConversionPassIncGen", ":IR", ":LinalgOps", + ":LinalgTransforms", ":Pass", ":SCFDialect", ":StandardOps", @@ -3826,6 +3955,7 @@ filegroup( srcs = [ "include/mlir/Dialect/Vector/VectorOps.td", "include/mlir/Interfaces/VectorInterfaces.td", + "include/mlir/Interfaces/ViewLikeInterface.td", ":AffineOpsTdFiles", ":OpBaseTdFiles", ], diff --git a/third_party/ngraph/ngraph.BUILD b/third_party/ngraph/ngraph.BUILD index dfdd891b228..715148d38f6 100644 --- a/third_party/ngraph/ngraph.BUILD +++ b/third_party/ngraph/ngraph.BUILD @@ -118,6 +118,7 @@ cc_library( ":ngraph_headers", "@eigen_archive//:eigen", "@mkl_dnn_v1//:mkl_dnn", + "@mkl_dnn_v1//:mkl_dnn_aarch64", "@nlohmann_json_lib", "@tbb", ], diff --git a/third_party/tf_toolchains.BUILD b/third_party/tf_toolchains.BUILD new file mode 100644 index 00000000000..ffd0fb0cdc5 --- /dev/null +++ b/third_party/tf_toolchains.BUILD @@ -0,0 +1 @@ +package(default_visibility = ["//visibility:public"]) diff --git a/third_party/toolchains/clang6/CROSSTOOL.tpl b/third_party/toolchains/clang6/CROSSTOOL.tpl index f98bef7a09b..16e36a1cd1f 100644 --- a/third_party/toolchains/clang6/CROSSTOOL.tpl +++ b/third_party/toolchains/clang6/CROSSTOOL.tpl @@ -531,7 +531,6 @@ toolchain { linker_flag: "-L/usr/lib/gcc/x86_64-linux-gnu/4.8" # This is the minimum x86 architecture TensorFlow supports. - compiler_flag: "-DARCH_K8" compiler_flag: "-m64" # These are for Linux. diff --git a/third_party/toolchains/preconfig/generate/containers.bzl b/third_party/toolchains/preconfig/generate/containers.bzl index 46d89150306..8efa702bf1c 100644 --- a/third_party/toolchains/preconfig/generate/containers.bzl +++ b/third_party/toolchains/preconfig/generate/containers.bzl @@ -11,7 +11,7 @@ container_digests = { "cuda10.1-cudnn7-ubuntu16.04-manylinux2010": "sha256:5e6d21c8ef226316eb6df5e2e6015244c16a8e5d936b52a09820442d2f8a919f", "cuda10.1-cudnn7-ubuntu16.04-manylinux2010-multipython": "sha256:3f890a951c81a201d60d0161a56ce628a90323be0c7f795550caa37f6f41a85c", "cuda10.1-cudnn7-ubuntu18.04-manylinux2010-multipython": "sha256:bd7666d1ef49b2b2e2a64981f1c9234deeccdb0d5198b30ff4289c3dfcffedbf", - "cuda11.0-cudnn8-ubuntu18.04-manylinux2010-multipython": "sha256:b52edb4e35c780334ba417b008927722ae668847715a1624e9b2984e99c05338", + "cuda11.0-cudnn8-ubuntu18.04-manylinux2010-multipython": "sha256:f436545b7e14b014393b42975923dcd01f408496b1399abb5a35608f888ca140", "rocm-ubuntu18.04-manylinux2010-multipython": "sha256:8c6ba5a831c23906716cc9e9c201081f2b5632e3bf3cbc0207da0ddbef18d525", "windows-1803": "sha256:f109576c7c0c8a1783ff22b666e8923b52dbbe7933f69a1c7a7275202c304a12", }