Merge branch 'master' into fix_pack
This commit is contained in:
commit
1cc444cb0f
3
.bazelrc
3
.bazelrc
@ -258,6 +258,9 @@ build:windows --host_cxxopt=/std:c++14
|
||||
# On windows, we still link everything into a single DLL.
|
||||
build:windows --config=monolithic
|
||||
|
||||
# On linux, we dynamically link small amount of kernels
|
||||
build:linux --config=dynamic_kernels
|
||||
|
||||
# Make sure to include as little of windows.h as possible
|
||||
build:windows --copt=-DWIN32_LEAN_AND_MEAN
|
||||
build:windows --host_copt=-DWIN32_LEAN_AND_MEAN
|
||||
|
19
.github/ISSUE_TEMPLATE/70-tflite-micro-issue.md
vendored
Normal file
19
.github/ISSUE_TEMPLATE/70-tflite-micro-issue.md
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
---
|
||||
name: TensorFlow Lite for Microcontrollers Issue
|
||||
about: Use this template for reporting issues with TensorFlow Lite for microcontrollers
|
||||
labels: 'comp:micro'
|
||||
|
||||
---
|
||||
|
||||
@tensorflow/micro
|
||||
|
||||
**System information**
|
||||
- Host OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
|
||||
- TensorFlow installed from (source or binary):
|
||||
- Tensorflow version (commit SHA if source):
|
||||
- Target platform (e.g. Arm Mbed OS, Arduino Nano 33 etc.):
|
||||
|
||||
**Describe the problem**
|
||||
|
||||
**Please provide the exact sequence of commands/steps when you ran into the problem**
|
||||
|
18
README.md
18
README.md
@ -29,6 +29,20 @@ to
|
||||
[announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce).
|
||||
See all the [mailing lists](https://www.tensorflow.org/community/forums).
|
||||
|
||||
## Feature Prioritization Survey
|
||||
|
||||
The TensorFlow team is working on building/improving features, and understands
|
||||
that it is very important to prioritize these efforts based on what TF users
|
||||
need.
|
||||
|
||||
The goal of this short, < 5 minute
|
||||
[survey](https://google.qualtrics.com/jfe/form/SV_d5nqhCEbkDkQ7ad), is to help
|
||||
the TensorFlow team better understand what features to prioritize based on your
|
||||
feedback. Participation is of course optional.
|
||||
|
||||
Take the survey
|
||||
[HERE](https://google.qualtrics.com/jfe/form/SV_d5nqhCEbkDkQ7ad).
|
||||
|
||||
## Install
|
||||
|
||||
See the [TensorFlow install guide](https://www.tensorflow.org/install) for the
|
||||
@ -51,6 +65,9 @@ Windows)*:
|
||||
$ pip install tensorflow-gpu
|
||||
```
|
||||
|
||||
To update TensorFlow to the latest version, add `--upgrade` flag to the above
|
||||
commands.
|
||||
|
||||
*Nightly binaries are available for testing using the
|
||||
[tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
|
||||
[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) packages on PyPi.*
|
||||
@ -147,3 +164,4 @@ Learn more about the
|
||||
## License
|
||||
|
||||
[Apache License 2.0](LICENSE)
|
||||
|
||||
|
608
RELEASE.md
608
RELEASE.md
@ -201,239 +201,387 @@ If you experience any snags when using TF 2.0, please let us know at the [TF 2.0
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
* `tf.contrib`:
|
||||
* Expose `tf.contrib.proto.*` ops in `tf.io` (they will exist in TF2)
|
||||
|
||||
* `tf.data`:
|
||||
* Add support for TensorArrays to `tf.data Dataset`.
|
||||
* Integrate Ragged Tensors with `tf.data`.
|
||||
* All core and experimental tf.data transformations that input user-defined functions can span multiple devices now.
|
||||
* Extending the TF 2.0 support for `shuffle(..., reshuffle_each_iteration=True)` and `cache()` to work across different Python iterators for the same dataset.
|
||||
* Removing the `experimental_numa_aware` option from `tf.data.Options`.
|
||||
* Add `num_parallel_reads` and passing in a Dataset containing filenames into `TextLineDataset` and `FixedLengthRecordDataset`.
|
||||
* Add support for defaulting the value of `cycle_length` argument of `tf.data.Dataset.interleave` to the number of schedulable CPU cores.
|
||||
* Promoting `tf.data.experimental.enumerate_dataset` to core as `tf.data.Dataset.enumerate`.
|
||||
* Promoting `tf.data.experimental.unbatch` to core as `tf.data.Dataset.unbatch`.
|
||||
* Adds option for introducing slack in the pipeline to reduce CPU contention, via `tf.data.Options().experimental_slack = True`
|
||||
* Added experimental support for parallel batching to `batch()` and `padded_batch()`. This functionality can be enabled through `tf.data.Options()`.
|
||||
* Support cancellation of long-running `reduce`.
|
||||
* Now we use `dataset` node name as prefix instead of the op name, to identify the component correctly in metrics, for pipelines with repeated components.
|
||||
* Improve the performance of datasets using `from_tensors()`.
|
||||
* Promoting `unbatch` from experimental to core API.
|
||||
* Adding support for datasets as inputs to `from_tensors` and `from_tensor_slices` and batching and unbatching of nested datasets.
|
||||
* `tf.contrib`:
|
||||
|
||||
* `tf.distribute`:
|
||||
* Enable `tf.distribute.experimental.MultiWorkerMirroredStrategy` working in eager mode.
|
||||
* Callbacks are supported in `MultiWorkerMirroredStrategy`.
|
||||
* Disable `run_eagerly` and distribution strategy if there are symbolic tensors added to the model using `add_metric` or `add_loss`.
|
||||
* Loss and gradients should now more reliably be correctly scaled w.r.t. the global batch size when using a `tf.distribute.Strategy`.
|
||||
* Set default loss reduction as `AUTO` for improving reliability of loss scaling with distribution strategy and custom training loops. `AUTO` indicates that the reduction option will be determined by the usage context. For almost all cases this defaults to `SUM_OVER_BATCH_SIZE`. When used in distribution strategy scope, outside of built-in training loops such as `tf.keras` `compile` and `fit`, we expect reduction value to be 'None' or 'SUM'. Using other values will raise an error.
|
||||
* Support for multi-host `ncclAllReduce` in Distribution Strategy.
|
||||
* Expose `tf.contrib.proto.*` ops in `tf.io` (they will exist in TF2)
|
||||
|
||||
* `tf.estimator`:
|
||||
* Replace `tf.contrib.estimator.add_metrics` with `tf.estimator.add_metrics`
|
||||
* Use `tf.compat.v1.estimator.inputs` instead of `tf.estimator.inputs`
|
||||
* Replace contrib references with `tf.estimator.experimental.*` for apis in early_s in Estimator
|
||||
* Canned Estimators will now use keras optimizers by default. An error will be raised if tf.train.Optimizers are used, and you will have to switch to tf.keras.optimizers or tf.compat.v1 canned Estimators.
|
||||
* A checkpoint converter for canned Estimators has been provided to transition canned Estimators that are warm started from `tf.train.Optimizers` to `tf.keras.optimizers`.
|
||||
* Losses are scaled in canned estimator v2 and not in the optimizers anymore. If you are using Estimator + distribution strategy + optimikzer v1 then the behavior does not change. This implies that if you are using custom estimator with optimizer v2, you have to scale losses. We have new utilities to help scale losses `tf.nn.compute_average_loss`, `tf.nn.scale_regularization_loss`.
|
||||
* `tf.data`:
|
||||
|
||||
* `tf.keras`:
|
||||
* Premade models (including Linear and WideDeep) have been introduced for the purpose of replacing Premade estimators.
|
||||
* Model saving changes
|
||||
* `model.save` and `tf.saved_model.save` may now save to the TensorFlow SavedModel format. The model can be restored using `tf.keras.models.load_model`. HDF5 files are still supported, and may be used by specifying `save_format="h5"` when saving.
|
||||
* Raw TensorFlow functions can now be used in conjunction with the Keras Functional API during model creation. This obviates the need for users to create Lambda layers in most cases when using the Functional API. Like Lambda layers, TensorFlow functions that result in Variable creation or assign ops are not supported.
|
||||
* Add support for passing list of lists to the `metrics` argument in Keras `compile`.
|
||||
* Add `tf.keras.layers.AbstractRNNCell` as the preferred implementation for RNN cells in TF v2. User can use it to implement RNN cells with custom behavior.
|
||||
* Keras training and validation curves are shown on the same plot when using the TensorBoard callback.
|
||||
* Switched Keras `fit/evaluate/predict` execution to use only a single unified path by default unless eager execution has been explicitly disabled, regardless of input type. This unified path places an eager-friendly training step inside of a `tf.function`. With this
|
||||
1. All input types are converted to `Dataset`.
|
||||
2. The path assumes there is always a distribution strategy. when distribution strategy is not specified the path uses a no-op distribution strategy.
|
||||
3. The training step is wrapped in `tf.function` unless `run_eagerly=True` is set in compile. The single path execution code does not yet support all use cases. We fallback to the existing v1 execution paths if your model contains the following:
|
||||
1. `sample_weight_mode` in compile
|
||||
2. `weighted_metrics` in compile
|
||||
3. v1 optimizer
|
||||
4. target tensors in compile
|
||||
If you are experiencing any issues because of this change, please inform us (file an issue) about your use case and you can unblock yourself by setting `experimental_run_tf_function=False` in compile meanwhile. We have seen couple of use cases where the model usage pattern is not as expected and would not work with this change.
|
||||
1. output tensors of one layer is used in the constructor of another.
|
||||
2. symbolic tensors outside the scope of the model are used in custom loss functions.
|
||||
The flag can be disabled for these cases and ideally the usage pattern will need to be fixed.
|
||||
* Mark Keras `set_session` as `compat.v1` only.
|
||||
* `tf.keras.estimator.model_to_estimator` now supports exporting to `tf.train.Checkpoint format`, which allows the saved checkpoints to be compatible with `model.load_weights`.
|
||||
* `keras.backend.resize_images` (and consequently, `keras.layers.Upsampling2D`) behavior has changed, a bug in the resizing implementation was fixed.
|
||||
* Add an `implementation=3` mode for `tf.keras.layers.LocallyConnected2D` and `tf.keras.layers.LocallyConnected1D` layers using `tf.SparseTensor` to store weights, allowing a dramatic speedup for large sparse models.
|
||||
* Raise error if `batch_size` argument is used when input is dataset/generator/keras sequence.
|
||||
* Update TF 2.0 `keras.backend.name_scope` to use TF 2.0 `name_scope`.
|
||||
* Add v2 module aliases for losses, metrics, initializers and optimizers: `tf.losses = tf.keras.losses` & `tf.metrics = tf.keras.metrics` & `tf.initializers = tf.keras.initializers` & `tf.optimizers = tf.keras.optimizers`.
|
||||
* Updates binary cross entropy logic in Keras when input is probabilities. Instead of converting probabilities to logits, we are using the cross entropy formula for probabilities.
|
||||
* Added public APIs for `cumsum` and `cumprod` keras backend functions.
|
||||
* Add support for temporal sample weight mode in subclassed models.
|
||||
* Raise `ValueError` if an integer is passed to the training APIs.
|
||||
* Added fault-tolerance support for training Keras model via `model.fit()` with `MultiWorkerMirroredStrategy`, tutorial available.
|
||||
* Custom Callback tutorial is now available.
|
||||
* To train with `tf.distribute`, Keras API is recommended over estimator.
|
||||
* `steps_per_epoch` and `steps` arguments are supported with numpy arrays.
|
||||
* New error message when unexpected keys are used in sample_weight/class_weight dictionaries
|
||||
* Losses are scaled in Keras compile/fit and not in the optimizers anymore. If you are using custom training loop, we have new utilities to help scale losses `tf.nn.compute_average_loss`, `tf.nn.scale_regularization_loss`.
|
||||
* `Layer` apply and add_variable APIs are deprecated.
|
||||
* Added support for channels first data format in cross entropy losses with logits and support for tensors with unknown ranks.
|
||||
* Error messages will be raised if `add_update`, `add_metric`, `add_loss`, activity regularizers are used inside of a control flow branch.
|
||||
* New loss reduction types:
|
||||
1. `AUTO`: Indicates that the reduction option will be determined by the usage context. For almost all cases this defaults to `SUM_OVER_BATCH_SIZE`. When used with `tf.distribute.Strategy`, outside of built-in training loops such as `tf.keras` `compile` and `fit`, we expect reduction value to be `SUM` or `NONE`. Using `AUTO` in that case will raise an error.
|
||||
2. `NONE`: Weighted losses with one dimension reduced (axis=-1, or axis specified by loss function). When this reduction type used with built-in Keras training loops like `fit`/`evaluate`, the unreduced vector loss is passed to the optimizer but the reported loss will be a scalar value.
|
||||
3. `SUM`: Scalar sum of weighted losses. 4. `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses. This reduction type is not supported when used with `tf.distribute.Strategy` outside of built-in training loops like `tf.keras` `compile`/`fit`.
|
||||
* Wraps losses passed to the `compile` API (strings and v1 losses) which are not instances of v2 `Loss` class in `LossWrapper` class. => All losses will now use `SUM_OVER_BATCH_SIZE` reduction as default.
|
||||
* `model.add_loss(symbolic_tensor)` should work in ambient eager.
|
||||
* Update metric name to always reflect what the user has given in compile. Affects following cases
|
||||
1. When name is given as 'accuracy'/'crossentropy'
|
||||
2. When an aliased function name is used eg. 'mse'
|
||||
3. Removing the `weighted` prefix from weighted metric names.
|
||||
* Allow non-Tensors through v2 losses.
|
||||
* Add v2 sparse categorical crossentropy metric.
|
||||
* Add v2 APIs for `AUCCurve` and `AUCSummationMethod` enums.
|
||||
* `add_update` can now be passed a zero-arg callable in order to support turning off the update when setting `trainable=False` on a Layer of a Model compiled with `run_eagerly=True`.
|
||||
* Standardize the LayerNormalization API by replacing the args `norm_axis` and `params_axis` with `axis`.
|
||||
* Fixed critical bugs that help with DenseFeatures usability in TF2
|
||||
* Add support for TensorArrays to `tf.data Dataset`.
|
||||
* Integrate Ragged Tensors with `tf.data`.
|
||||
* All core and experimental tf.data transformations that input
|
||||
user-defined functions can span multiple devices now.
|
||||
* Extending the TF 2.0 support for `shuffle(...,
|
||||
reshuffle_each_iteration=True)` and `cache()` to work across different
|
||||
Python iterators for the same dataset.
|
||||
* Removing the `experimental_numa_aware` option from `tf.data.Options`.
|
||||
* Add `num_parallel_reads` and passing in a Dataset containing filenames
|
||||
into `TextLineDataset` and `FixedLengthRecordDataset`.
|
||||
* Add support for defaulting the value of `cycle_length` argument of
|
||||
`tf.data.Dataset.interleave` to the number of schedulable CPU cores.
|
||||
* Promoting `tf.data.experimental.enumerate_dataset` to core as
|
||||
`tf.data.Dataset.enumerate`.
|
||||
* Promoting `tf.data.experimental.unbatch` to core as
|
||||
`tf.data.Dataset.unbatch`.
|
||||
* Adds option for introducing slack in the pipeline to reduce CPU
|
||||
contention, via `tf.data.Options().experimental_slack = True`
|
||||
* Added experimental support for parallel batching to `batch()` and
|
||||
`padded_batch()`. This functionality can be enabled through
|
||||
`tf.data.Options()`.
|
||||
* Support cancellation of long-running `reduce`.
|
||||
* Now we use `dataset` node name as prefix instead of the op name, to
|
||||
identify the component correctly in metrics, for pipelines with repeated
|
||||
components.
|
||||
* Improve the performance of datasets using `from_tensors()`.
|
||||
* Promoting `unbatch` from experimental to core API.
|
||||
* Adding support for datasets as inputs to `from_tensors` and
|
||||
`from_tensor_slices` and batching and unbatching of nested datasets.
|
||||
|
||||
* `tf.lite`:
|
||||
* Added evaluation script for `COCO` minival
|
||||
* Add delegate support for `QUANTIZE`.
|
||||
* Add `GATHER` support to NN API delegate.
|
||||
* Added support for TFLiteConverter Python API in 2.0. Contains functions from_saved_model, from_keras_file, and from_concrete_functions.
|
||||
* Add `EXPAND_DIMS` support to NN API delegate TEST.
|
||||
* Add `narrow_range` attribute to QuantizeAndDequantizeV2 and V3.
|
||||
* Added support for `tflite_convert` command line tool in 2.0.
|
||||
* Post-training quantization tool supports quantizing weights shared by multiple operations. The models made with versions of this tool will use INT8 types for weights and will only be executable interpreters from this version onwards.
|
||||
* Post-training quantization tool supports fp16 weights and GPU delegate acceleration for fp16.
|
||||
* Add delegate support for `QUANTIZED_16BIT_LSTM`.
|
||||
* Extracts `NNAPIDelegateKernel` from nnapi_delegate.cc
|
||||
* `tf.distribute`:
|
||||
|
||||
* TensorRT
|
||||
* Add TensorFlow 2.0-compatible `TrtGraphConverterV2` API for TensorRT conversion.
|
||||
TensorRT initialization arguments are now passed wrapped in a named-tuple,
|
||||
`TrtConversionParams`, rather than as separate arguments as in `TrtGraphConverter`.
|
||||
* Changed API to optimize TensorRT enginges during graph optimization. This is now
|
||||
done by calling `converter.build()` where previously `is_dynamic_op=False` would
|
||||
be set.
|
||||
* `converter.convert()` no longer returns a `tf.function`. Now the funtion must be
|
||||
accessed from the saved model.
|
||||
* The `converter.calibrate()` method has been removed. To trigger calibration, a
|
||||
`calibration_input_fn` should be provided to `converter.convert()`.
|
||||
* Enable `tf.distribute.experimental.MultiWorkerMirroredStrategy` working
|
||||
in eager mode.
|
||||
* Callbacks are supported in `MultiWorkerMirroredStrategy`.
|
||||
* Disable `run_eagerly` and distribution strategy if there are symbolic
|
||||
tensors added to the model using `add_metric` or `add_loss`.
|
||||
* Loss and gradients should now more reliably be correctly scaled w.r.t.
|
||||
the global batch size when using a `tf.distribute.Strategy`.
|
||||
* Set default loss reduction as `AUTO` for improving reliability of loss
|
||||
scaling with distribution strategy and custom training loops. `AUTO`
|
||||
indicates that the reduction option will be determined by the usage
|
||||
context. For almost all cases this defaults to `SUM_OVER_BATCH_SIZE`.
|
||||
When used in distribution strategy scope, outside of built-in training
|
||||
loops such as `tf.keras` `compile` and `fit`, we expect reduction value
|
||||
to be 'None' or 'SUM'. Using other values will raise an error.
|
||||
* Support for multi-host `ncclAllReduce` in Distribution Strategy.
|
||||
|
||||
* Other:
|
||||
* Fix accidental quadratic graph construction cost in graph-mode `tf.gradients()`.
|
||||
* ResourceVariable's gather op supports batch dimensions.
|
||||
* ResourceVariable support for `gather_nd`.
|
||||
* `ResourceVariable` and `Variable` no longer accepts `constraint` in the constructor, nor expose it as a @property.
|
||||
* Added gradient for `SparseToDense` op.
|
||||
* Expose a flag that allows the number of threads to vary across Python benchmarks.
|
||||
* `image.resize` in 2.0 now supports gradients for the new resize kernels.
|
||||
* `image.resize` now considers proper pixel centers and has new kernels (incl. anti-aliasing).
|
||||
* Renamed `tf.image` functions to remove duplicate "image" where it is redundant.
|
||||
* Variadic reduce is supported on CPU Variadic reduce is supported on CPU
|
||||
* Remove unused `StringViewVariantWrapper`.
|
||||
* Delete unused `Fingerprint64Map` op registration
|
||||
* Add broadcasting support to `tf.matmul`.
|
||||
* Add C++ Gradient for `BatchMatMulV2`.
|
||||
* Add `tf.math.cumulative_logsumexp` operation.
|
||||
* Add ellipsis (...) support for `tf.einsum()`.
|
||||
* Add expand_composites argument to all `nest.*` methods.
|
||||
* Added `strings.byte_split`.
|
||||
* Add a new "result_type" parameter to `tf.strings.split`.
|
||||
* Add name argument to `tf.string_split` and `tf.strings_split`.
|
||||
* Extend `tf.strings.split` to support inputs with any rank.
|
||||
* Added `tf.random.binomial`.
|
||||
* Added `key` and `skip` methods to `random.experimental.Generator`.
|
||||
* Extend `tf.function` with basic support for CompositeTensors arguments (such as `SparseTensor` and `RaggedTensor`).
|
||||
* `parallel_for.pfor`: add converters for Softmax, LogSoftmax, IsNaN, All, Any, and MatrixSetDiag.
|
||||
* `parallel_for`: add converters for LowerTriangularSolve and Cholesky.
|
||||
* `parallel_for`: add converters for `LogMatrixDeterminant` and `MatrixBandPart`.
|
||||
* `parallel_for`: Add converter for `MatrixDiag`.
|
||||
* `parallel_for`: Add converters for `OneHot`, `LowerBound`, `UpperBound`.
|
||||
* `parallel_for`: add converter for `BroadcastTo`.
|
||||
* Add `pfor` converter for `Squeeze`.
|
||||
* Add `RaggedTensor.placeholder()`.
|
||||
* Add ragged tensor support to `tf.squeeze`.
|
||||
* Update RaggedTensors to support int32 row_splits.
|
||||
* Allow `LinearOperator.solve` to take a `LinearOperator`.
|
||||
* Allow all dtypes for `LinearOperatorCirculant`.
|
||||
* Introduce MaxParallelism method
|
||||
* Add `LinearOperatorHouseholder`.
|
||||
* Adds Philox support to new stateful RNG's XLA path.
|
||||
* Added `TensorSpec` support for CompositeTensors.
|
||||
* Added `tf.linalg.tridiagonal_solve` op.
|
||||
* Added partial_pivoting input parameter to `tf.linalg.tridiagonal_solve`.
|
||||
* Added gradient to `tf.linalg.tridiagonal_solve`.
|
||||
* Added `tf.linalg.tridiagonal_mul op`.
|
||||
* Added GPU implementation of `tf.linalg.tridiagonal_matmul`.
|
||||
* Added `LinearOperatorToeplitz`.
|
||||
* Upgraded LIBXSMM to version 1.11.
|
||||
* Uniform processing of quantized embeddings by Gather and EmbeddingLookup Ops.
|
||||
* Correct a misstatement in the documentation of the sparse softmax cross entropy logit parameter.
|
||||
* Add `tf.ragged.boolean_mask`.
|
||||
* `tf.switch_case` added, which selects a branch_fn based on a branch_index.
|
||||
* The C++ kernel of gather op supports batch dimensions.
|
||||
* Fixed default value and documentation for `trainable` arg of tf.Variable.
|
||||
* `EagerTensor` now supports numpy buffer interface for tensors.
|
||||
* This change bumps the version number of the `FullyConnected` Op to 5.
|
||||
* Added new op: `tf.strings.unsorted_segment_join`.
|
||||
* Added HW acceleration support for `topK_v2`.
|
||||
* CloudBigtable version updated to v0.10.0 BEGIN_PUBLIC CloudBigtable version updated to v0.10.0.
|
||||
* Expose `Head` as public API.
|
||||
* Added `tf.sparse.from_dense` utility function.
|
||||
* Improved ragged tensor support in `TensorFlowTestCase`.
|
||||
* Added a function `nested_value_rowids` for ragged tensors.
|
||||
* Added `tf.ragged.stack`.
|
||||
* Makes the a-normal form transformation in Pyct configurable as to which nodes are converted to variables and which are not.
|
||||
* `ResizeInputTensor` now works for all delegates.
|
||||
* `tf.cond` emits a StatelessIf op if the branch functions are stateless and do not touch any resources.
|
||||
* Add support of local soft device placement for eager op.
|
||||
* Pass partial_pivoting to the `_TridiagonalSolveGrad`.
|
||||
* Add HW acceleration support for `LogSoftMax`.
|
||||
* Add guard to avoid acceleration of L2 Normalization with input rank != 4
|
||||
* Fix memory allocation problem when calling `AddNewInputConstantTensor`.
|
||||
* Delegate application failure leaves interpreter in valid state
|
||||
* `tf.while_loop` emits a StatelessWhile op if the cond and body functions are stateless and do not touch any resources.
|
||||
* `tf.cond`, `tf.while` and if and while in AutoGraph now accept a nonscalar predicate if has a single element. This does not affect non-V2 control flow.
|
||||
* Fix potential security vulnerability where decoding variant tensors from proto could result in heap out of bounds memory access.
|
||||
* Only create a GCS directory object if the object does not already exist.
|
||||
* Introduce `dynamic` constructor argument in Layer and Model, which should be set to `True` when using imperative control flow in the `call` method.
|
||||
* Begin adding Go wrapper for C Eager API.
|
||||
* XLA HLO graphs can be inspected with interactive_graphviz tool now.
|
||||
* Add dataset ops to the graph (or create kernels in Eager execution) during the python Dataset object creation instead doing it during Iterator creation time.
|
||||
* Add `batch_dims` argument to `tf.gather`.
|
||||
* The behavior of `tf.gather` is now correct when `axis=None` and `batch_dims<0`.
|
||||
* Update docstring for gather to properly describe the non-empty `batch_dims` case.
|
||||
* Removing of dtype in the constructor of initializers and partition_info in call.
|
||||
* Add `tf.math.nextafter` op.
|
||||
* Turn on MKL-DNN contraction kernels by default. MKL-DNN dynamically dispatches the best kernel implementation based on CPU vector architecture. To disable them, build with `--define=tensorflow_mkldnn_contraction_kernel=0`.
|
||||
* `tf.linspace(start, stop, num)` now always uses "stop" as last value (for num > 1)
|
||||
* Added top-k to precision and recall to keras metrics.
|
||||
* Add a ragged size op and register it to the op dispatcher
|
||||
* Transitive dependencies on :`pooling_ops` were removed. Some users may need to add explicit dependencies on :`pooling_ops` if they reference the operators from that library.
|
||||
* Add `CompositeTensor` base class.
|
||||
* Malformed gif images could result in an access out of bounds in the color palette of the frame. This has been fixed now
|
||||
* Add templates and interfaces for creating lookup tables
|
||||
* `Tensor::UnsafeCopyFromInternal` deprecated in favor `Tensor::BitcastFrom`.
|
||||
* In `map_vectorization` optimization, reduce the degree of parallelism in the vectorized map node.
|
||||
* Add variant wrapper for `absl::string_view`.
|
||||
* Add OpKernels for some stateless maps.
|
||||
* DType is no longer convertible to an int. Use `dtype.as_datatype_enum` instead of `int(dtype)` to get the same result.
|
||||
* Support both binary and -1/1 label input in v2 hinge and squared hinge losses.
|
||||
* Added `LinearOperator.adjoint` and `LinearOperator.H` (alias).
|
||||
* Expose CriticalSection in core as `tf.CriticalSection`.
|
||||
* Enhanced graphviz output.
|
||||
* Add opkernel templates for common table operations.
|
||||
* Fix callbacks do not log values in eager mode when a deferred build model is used.
|
||||
* `SignatureDef` util functions have been deprecated.
|
||||
* Update `Fingerprint64Map` to use aliases
|
||||
* Add legacy string flat hash map op kernels.
|
||||
* Add support for `add_metric` in the graph function mode.
|
||||
* Updating cosine similarity loss - removed the negate sign from cosine similarity.
|
||||
* Changed default for gradient accumulation for TPU embeddings to true.
|
||||
* Adds summary trace API for collecting graph and profile information.
|
||||
* The `precision_mode` argument to `TrtGraphConverter` is now case insensitive.
|
||||
* `tf.estimator`:
|
||||
|
||||
* Replace `tf.contrib.estimator.add_metrics` with
|
||||
`tf.estimator.add_metrics`
|
||||
* Use `tf.compat.v1.estimator.inputs` instead of `tf.estimator.inputs`
|
||||
* Replace contrib references with `tf.estimator.experimental.*` for apis
|
||||
in early_s in Estimator
|
||||
* Canned Estimators will now use keras optimizers by default. An error
|
||||
will be raised if tf.train.Optimizers are used, and you will have to
|
||||
switch to tf.keras.optimizers or tf.compat.v1 canned Estimators.
|
||||
* A checkpoint converter for canned Estimators has been provided to
|
||||
transition canned Estimators that are warm started from
|
||||
`tf.train.Optimizers` to `tf.keras.optimizers`.
|
||||
* Losses are scaled in canned estimator v2 and not in the optimizers
|
||||
anymore. If you are using Estimator + distribution strategy + optimikzer
|
||||
v1 then the behavior does not change. This implies that if you are using
|
||||
custom estimator with optimizer v2, you have to scale losses. We have
|
||||
new utilities to help scale losses `tf.nn.compute_average_loss`,
|
||||
`tf.nn.scale_regularization_loss`.
|
||||
|
||||
* `tf.keras`:
|
||||
|
||||
* Premade models (including Linear and WideDeep) have been introduced for
|
||||
the purpose of replacing Premade estimators.
|
||||
* Model saving changes
|
||||
* `model.save` and `tf.saved_model.save` may now save to the TensorFlow
|
||||
SavedModel format. The model can be restored using
|
||||
`tf.keras.models.load_model`. HDF5 files are still supported, and may be
|
||||
used by specifying `save_format="h5"` when saving.
|
||||
* Raw TensorFlow functions can now be used in conjunction with the Keras
|
||||
Functional API during model creation. This obviates the need for users
|
||||
to create Lambda layers in most cases when using the Functional API.
|
||||
Like Lambda layers, TensorFlow functions that result in Variable
|
||||
creation or assign ops are not supported.
|
||||
* Add support for passing list of lists to the `metrics` argument in Keras
|
||||
`compile`.
|
||||
* Add `tf.keras.layers.AbstractRNNCell` as the preferred implementation
|
||||
for RNN cells in TF v2. User can use it to implement RNN cells with
|
||||
custom behavior.
|
||||
* Keras training and validation curves are shown on the same plot when
|
||||
using the TensorBoard callback.
|
||||
* Switched Keras `fit/evaluate/predict` execution to use only a single
|
||||
unified path by default unless eager execution has been explicitly
|
||||
disabled, regardless of input type. This unified path places an
|
||||
eager-friendly training step inside of a `tf.function`. With this
|
||||
* All input types are converted to `Dataset`.
|
||||
* The path assumes there is always a distribution strategy. when
|
||||
distribution strategy is not specified the path uses a no-op
|
||||
distribution strategy.
|
||||
* The training step is wrapped in `tf.function` unless `run_eagerly=True`
|
||||
is set in compile. The single path execution code does not yet support
|
||||
all use cases. We fallback to the existing v1 execution paths if your
|
||||
model contains the following:
|
||||
1. `sample_weight_mode` in compile
|
||||
2. `weighted_metrics` in compile
|
||||
3. v1 optimizer
|
||||
4. target tensors in compile If you are experiencing any issues because
|
||||
of this change, please inform us (file an issue) about your use case
|
||||
and you can unblock yourself by setting
|
||||
`experimental_run_tf_function=False` in compile meanwhile. We have
|
||||
seen couple of use cases where the model usage pattern is not as
|
||||
expected and would not work with this change.
|
||||
* output tensors of one layer is used in the constructor of another.
|
||||
* symbolic tensors outside the scope of the model are used in custom loss
|
||||
functions. The flag can be disabled for these cases and ideally the
|
||||
usage pattern will need to be fixed.
|
||||
* Mark Keras `set_session` as `compat.v1` only.
|
||||
* `tf.keras.estimator.model_to_estimator` now supports exporting to
|
||||
`tf.train.Checkpoint format`, which allows the saved checkpoints to be
|
||||
compatible with `model.load_weights`.
|
||||
* `keras.backend.resize_images` (and consequently,
|
||||
`keras.layers.Upsampling2D`) behavior has changed, a bug in the resizing
|
||||
implementation was fixed.
|
||||
* Add an `implementation=3` mode for `tf.keras.layers.LocallyConnected2D`
|
||||
and `tf.keras.layers.LocallyConnected1D` layers using `tf.SparseTensor`
|
||||
to store weights, allowing a dramatic speedup for large sparse models.
|
||||
* Raise error if `batch_size` argument is used when input is
|
||||
dataset/generator/keras sequence.
|
||||
* Update TF 2.0 `keras.backend.name_scope` to use TF 2.0 `name_scope`.
|
||||
* Add v2 module aliases for losses, metrics, initializers and optimizers:
|
||||
`tf.losses = tf.keras.losses` & `tf.metrics = tf.keras.metrics` &
|
||||
`tf.initializers = tf.keras.initializers` & `tf.optimizers =
|
||||
tf.keras.optimizers`.
|
||||
* Updates binary cross entropy logic in Keras when input is probabilities.
|
||||
Instead of converting probabilities to logits, we are using the cross
|
||||
entropy formula for probabilities.
|
||||
* Added public APIs for `cumsum` and `cumprod` keras backend functions.
|
||||
* Add support for temporal sample weight mode in subclassed models.
|
||||
* Raise `ValueError` if an integer is passed to the training APIs.
|
||||
* Added fault-tolerance support for training Keras model via `model.fit()`
|
||||
with `MultiWorkerMirroredStrategy`, tutorial available.
|
||||
* Custom Callback tutorial is now available.
|
||||
* To train with `tf.distribute`, Keras API is recommended over estimator.
|
||||
* `steps_per_epoch` and `steps` arguments are supported with numpy arrays.
|
||||
* New error message when unexpected keys are used in
|
||||
sample_weight/class_weight dictionaries
|
||||
* Losses are scaled in Keras compile/fit and not in the optimizers
|
||||
anymore. If you are using custom training loop, we have new utilities to
|
||||
help scale losses `tf.nn.compute_average_loss`,
|
||||
`tf.nn.scale_regularization_loss`.
|
||||
* `Layer` apply and add_variable APIs are deprecated.
|
||||
* Added support for channels first data format in cross entropy losses
|
||||
with logits and support for tensors with unknown ranks.
|
||||
* Error messages will be raised if `add_update`, `add_metric`, `add_loss`,
|
||||
activity regularizers are used inside of a control flow branch.
|
||||
* New loss reduction types:
|
||||
* `AUTO`: Indicates that the reduction option will be determined by the
|
||||
usage context. For almost all cases this defaults to
|
||||
`SUM_OVER_BATCH_SIZE`. When used with `tf.distribute.Strategy`, outside
|
||||
of built-in training loops such as `tf.keras` `compile` and `fit`, we
|
||||
expect reduction value to be `SUM` or `NONE`. Using `AUTO` in that case
|
||||
will raise an error.
|
||||
* `NONE`: Weighted losses with one dimension reduced (axis=-1, or axis
|
||||
specified by loss function). When this reduction type used with built-in
|
||||
Keras training loops like `fit`/`evaluate`, the unreduced vector loss is
|
||||
passed to the optimizer but the reported loss will be a scalar value.
|
||||
* `SUM`: Scalar sum of weighted losses. 4. `SUM_OVER_BATCH_SIZE`: Scalar
|
||||
`SUM` divided by number of elements in losses. This reduction type is
|
||||
not supported when used with `tf.distribute.Strategy` outside of
|
||||
built-in training loops like `tf.keras` `compile`/`fit`.
|
||||
* Wraps losses passed to the `compile` API (strings and v1 losses) which
|
||||
are not instances of v2 `Loss` class in `LossWrapper` class. => All
|
||||
losses will now use `SUM_OVER_BATCH_SIZE` reduction as default.
|
||||
* `model.add_loss(symbolic_tensor)` should work in ambient eager.
|
||||
* Update metric name to always reflect what the user has given in compile.
|
||||
Affects following cases
|
||||
* When name is given as 'accuracy'/'crossentropy'
|
||||
* When an aliased function name is used eg. 'mse'
|
||||
* Removing the `weighted` prefix from weighted metric names.
|
||||
* Allow non-Tensors through v2 losses.
|
||||
* Add v2 sparse categorical crossentropy metric.
|
||||
* Add v2 APIs for `AUCCurve` and `AUCSummationMethod` enums.
|
||||
* `add_update` can now be passed a zero-arg callable in order to support
|
||||
turning off the update when setting `trainable=False` on a Layer of a
|
||||
Model compiled with `run_eagerly=True`.
|
||||
* Standardize the LayerNormalization API by replacing the args `norm_axis`
|
||||
and `params_axis` with `axis`.
|
||||
* Fixed critical bugs that help with DenseFeatures usability in TF2
|
||||
|
||||
* `tf.lite`:
|
||||
|
||||
* Added evaluation script for `COCO` minival
|
||||
* Add delegate support for `QUANTIZE`.
|
||||
* Add `GATHER` support to NN API delegate.
|
||||
* Added support for TFLiteConverter Python API in 2.0. Contains functions
|
||||
from_saved_model, from_keras_file, and from_concrete_functions.
|
||||
* Add `EXPAND_DIMS` support to NN API delegate TEST.
|
||||
* Add `narrow_range` attribute to QuantizeAndDequantizeV2 and V3.
|
||||
* Added support for `tflite_convert` command line tool in 2.0.
|
||||
* Post-training quantization tool supports quantizing weights shared by
|
||||
multiple operations. The models made with versions of this tool will use
|
||||
INT8 types for weights and will only be executable interpreters from
|
||||
this version onwards.
|
||||
* Post-training quantization tool supports fp16 weights and GPU delegate
|
||||
acceleration for fp16.
|
||||
* Add delegate support for `QUANTIZED_16BIT_LSTM`.
|
||||
* Extracts `NNAPIDelegateKernel` from nnapi_delegate.cc
|
||||
|
||||
* TensorRT
|
||||
|
||||
* Add TensorFlow 2.0-compatible `TrtGraphConverterV2` API for TensorRT
|
||||
conversion. TensorRT initialization arguments are now passed wrapped in
|
||||
a named-tuple, `TrtConversionParams`, rather than as separate arguments
|
||||
as in `TrtGraphConverter`.
|
||||
* Changed API to optimize TensorRT enginges during graph optimization.
|
||||
This is now done by calling `converter.build()` where previously
|
||||
`is_dynamic_op=False` would be set.
|
||||
* `converter.convert()` no longer returns a `tf.function`. Now the
|
||||
function must be accessed from the saved model.
|
||||
* The `converter.calibrate()` method has been removed. To trigger
|
||||
calibration, a `calibration_input_fn` should be provided to
|
||||
`converter.convert()`.
|
||||
|
||||
* Other:
|
||||
|
||||
* Fix accidental quadratic graph construction cost in graph-mode
|
||||
`tf.gradients()`.
|
||||
* ResourceVariable's gather op supports batch dimensions.
|
||||
* ResourceVariable support for `gather_nd`.
|
||||
* `ResourceVariable` and `Variable` no longer accepts `constraint` in the
|
||||
constructor, nor expose it as a @property.
|
||||
* Added gradient for `SparseToDense` op.
|
||||
* Expose a flag that allows the number of threads to vary across Python
|
||||
benchmarks.
|
||||
* `image.resize` in 2.0 now supports gradients for the new resize kernels.
|
||||
* `image.resize` now considers proper pixel centers and has new kernels
|
||||
(incl. anti-aliasing).
|
||||
* Renamed `tf.image` functions to remove duplicate "image" where it is
|
||||
redundant.
|
||||
* Variadic reduce is supported on CPU Variadic reduce is supported on CPU
|
||||
* Remove unused `StringViewVariantWrapper`.
|
||||
* Delete unused `Fingerprint64Map` op registration
|
||||
* Add broadcasting support to `tf.matmul`.
|
||||
* Add C++ Gradient for `BatchMatMulV2`.
|
||||
* Add `tf.math.cumulative_logsumexp` operation.
|
||||
* Add ellipsis (...) support for `tf.einsum()`.
|
||||
* Add expand_composites argument to all `nest.*` methods.
|
||||
* Added `strings.byte_split`.
|
||||
* Add a new "result_type" parameter to `tf.strings.split`.
|
||||
* Add name argument to `tf.string_split` and `tf.strings_split`.
|
||||
* Extend `tf.strings.split` to support inputs with any rank.
|
||||
* Added `tf.random.binomial`.
|
||||
* Added `key` and `skip` methods to `random.experimental.Generator`.
|
||||
* Extend `tf.function` with basic support for CompositeTensors arguments
|
||||
(such as `SparseTensor` and `RaggedTensor`).
|
||||
* `parallel_for.pfor`: add converters for Softmax, LogSoftmax, IsNaN, All,
|
||||
Any, and MatrixSetDiag.
|
||||
* `parallel_for`: add converters for LowerTriangularSolve and Cholesky.
|
||||
* `parallel_for`: add converters for `LogMatrixDeterminant` and
|
||||
`MatrixBandPart`.
|
||||
* `parallel_for`: Add converter for `MatrixDiag`.
|
||||
* `parallel_for`: Add converters for `OneHot`, `LowerBound`, `UpperBound`.
|
||||
* `parallel_for`: add converter for `BroadcastTo`.
|
||||
* Add `pfor` converter for `Squeeze`.
|
||||
* Add `RaggedTensor.placeholder()`.
|
||||
* Add ragged tensor support to `tf.squeeze`.
|
||||
* Update RaggedTensors to support int32 row_splits.
|
||||
* Allow `LinearOperator.solve` to take a `LinearOperator`.
|
||||
* Allow all dtypes for `LinearOperatorCirculant`.
|
||||
* Introduce MaxParallelism method
|
||||
* Add `LinearOperatorHouseholder`.
|
||||
* Adds Philox support to new stateful RNG's XLA path.
|
||||
* Added `TensorSpec` support for CompositeTensors.
|
||||
* Added `tf.linalg.tridiagonal_solve` op.
|
||||
* Added partial_pivoting input parameter to `tf.linalg.tridiagonal_solve`.
|
||||
* Added gradient to `tf.linalg.tridiagonal_solve`.
|
||||
* Added `tf.linalg.tridiagonal_mul op`.
|
||||
* Added GPU implementation of `tf.linalg.tridiagonal_matmul`.
|
||||
* Added `LinearOperatorToeplitz`.
|
||||
* Upgraded LIBXSMM to version 1.11.
|
||||
* Uniform processing of quantized embeddings by Gather and EmbeddingLookup
|
||||
Ops.
|
||||
* Correct a misstatement in the documentation of the sparse softmax cross
|
||||
entropy logit parameter.
|
||||
* Add `tf.ragged.boolean_mask`.
|
||||
* `tf.switch_case` added, which selects a branch_fn based on a
|
||||
branch_index.
|
||||
* The C++ kernel of gather op supports batch dimensions.
|
||||
* Fixed default value and documentation for `trainable` arg of
|
||||
tf.Variable.
|
||||
* `EagerTensor` now supports numpy buffer interface for tensors.
|
||||
* This change bumps the version number of the `FullyConnected` Op to 5.
|
||||
* Added new op: `tf.strings.unsorted_segment_join`.
|
||||
* Added HW acceleration support for `topK_v2`.
|
||||
* CloudBigtable version updated to v0.10.0 BEGIN_PUBLIC CloudBigtable
|
||||
version updated to v0.10.0.
|
||||
* Expose `Head` as public API.
|
||||
* Added `tf.sparse.from_dense` utility function.
|
||||
* Improved ragged tensor support in `TensorFlowTestCase`.
|
||||
* Added a function `nested_value_rowids` for ragged tensors.
|
||||
* Added `tf.ragged.stack`.
|
||||
* Makes the a-normal form transformation in Pyct configurable as to which
|
||||
nodes are converted to variables and which are not.
|
||||
* `ResizeInputTensor` now works for all delegates.
|
||||
* `tf.cond` emits a StatelessIf op if the branch functions are stateless
|
||||
and do not touch any resources.
|
||||
* Add support of local soft device placement for eager op.
|
||||
* Pass partial_pivoting to the `_TridiagonalSolveGrad`.
|
||||
* Add HW acceleration support for `LogSoftMax`.
|
||||
* Add guard to avoid acceleration of L2 Normalization with input rank != 4
|
||||
* Fix memory allocation problem when calling `AddNewInputConstantTensor`.
|
||||
* Delegate application failure leaves interpreter in valid state
|
||||
* `tf.while_loop` emits a StatelessWhile op if the cond and body functions
|
||||
are stateless and do not touch any resources.
|
||||
* `tf.cond`, `tf.while` and if and while in AutoGraph now accept a
|
||||
nonscalar predicate if has a single element. This does not affect non-V2
|
||||
control flow.
|
||||
* Fix potential security vulnerability where decoding variant tensors from
|
||||
proto could result in heap out of bounds memory access.
|
||||
* Only create a GCS directory object if the object does not already exist.
|
||||
* Introduce `dynamic` constructor argument in Layer and Model, which
|
||||
should be set to `True` when using imperative control flow in the `call`
|
||||
method.
|
||||
* Begin adding Go wrapper for C Eager API.
|
||||
* XLA HLO graphs can be inspected with interactive_graphviz tool now.
|
||||
* Add dataset ops to the graph (or create kernels in Eager execution)
|
||||
during the python Dataset object creation instead doing it during
|
||||
Iterator creation time.
|
||||
* Add `batch_dims` argument to `tf.gather`.
|
||||
* The behavior of `tf.gather` is now correct when `axis=None` and
|
||||
`batch_dims<0`.
|
||||
* Update docstring for gather to properly describe the non-empty
|
||||
`batch_dims` case.
|
||||
* Removing of dtype in the constructor of initializers and partition_info
|
||||
in call.
|
||||
* Add `tf.math.nextafter` op.
|
||||
* Turn on MKL-DNN contraction kernels by default. MKL-DNN dynamically
|
||||
dispatches the best kernel implementation based on CPU vector
|
||||
architecture. To disable them, build with
|
||||
`--define=tensorflow_mkldnn_contraction_kernel=0`.
|
||||
* `tf.linspace(start, stop, num)` now always uses "stop" as last value
|
||||
(for num > 1)
|
||||
* Added top-k to precision and recall to keras metrics.
|
||||
* Add a ragged size op and register it to the op dispatcher
|
||||
* Transitive dependencies on :`pooling_ops` were removed. Some users may
|
||||
need to add explicit dependencies on :`pooling_ops` if they reference
|
||||
the operators from that library.
|
||||
* Add `CompositeTensor` base class.
|
||||
* Malformed gif images could result in an access out of bounds in the
|
||||
color palette of the frame. This has been fixed now
|
||||
* Add templates and interfaces for creating lookup tables
|
||||
* `Tensor::UnsafeCopyFromInternal` deprecated in favor
|
||||
`Tensor::BitcastFrom`.
|
||||
* In `map_vectorization` optimization, reduce the degree of parallelism in
|
||||
the vectorized map node.
|
||||
* Add variant wrapper for `absl::string_view`.
|
||||
* Add OpKernels for some stateless maps.
|
||||
* DType is no longer convertible to an int. Use `dtype.as_datatype_enum`
|
||||
instead of `int(dtype)` to get the same result.
|
||||
* Support both binary and -1/1 label input in v2 hinge and squared hinge
|
||||
losses.
|
||||
* Added `LinearOperator.adjoint` and `LinearOperator.H` (alias).
|
||||
* Expose CriticalSection in core as `tf.CriticalSection`.
|
||||
* Enhanced graphviz output.
|
||||
* Add opkernel templates for common table operations.
|
||||
* Fix callbacks do not log values in eager mode when a deferred build
|
||||
model is used.
|
||||
* `SignatureDef` util functions have been deprecated.
|
||||
* Update `Fingerprint64Map` to use aliases
|
||||
* Add legacy string flat hash map op kernels.
|
||||
* Add support for `add_metric` in the graph function mode.
|
||||
* Updating cosine similarity loss - removed the negate sign from cosine
|
||||
similarity.
|
||||
* Changed default for gradient accumulation for TPU embeddings to true.
|
||||
* Adds summary trace API for collecting graph and profile information.
|
||||
* The `precision_mode` argument to `TrtGraphConverter` is now case
|
||||
insensitive.
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
||||
@ -715,7 +863,7 @@ Weweler, Zantares, zjjott, 卜居, 王振华 (Wang Zhenhua), 黄鑫
|
||||
|
||||
* Updates `png_archive` dependency to 1.6.37 to not be affected by
|
||||
CVE-2019-7317, CVE-2018-13785, and CVE-2018-14048.
|
||||
* Updates `sqlite` depenency to 3.28.0 to not be affected by CVE-2018-20506,
|
||||
* Updates `sqlite` dependency to 3.28.0 to not be affected by CVE-2018-20506,
|
||||
CVE-2018-20346, and CVE-2018-20505.
|
||||
|
||||
# Release 1.12.2
|
||||
@ -901,9 +1049,9 @@ Weweler, Zantares, zjjott, 卜居, 王振华 (Wang Zhenhua), 黄鑫
|
||||
compilation as a second return argument.
|
||||
* XLA HLO graphs can now be rendered as SVG/HTML.
|
||||
* Estimator
|
||||
* Replace all occurences of `tf.contrib.estimator.BaselineEstimator` with
|
||||
* Replace all occurrences of `tf.contrib.estimator.BaselineEstimator` with
|
||||
`tf.estimator.BaselineEstimator`
|
||||
* Replace all occurences of
|
||||
* Replace all occurrences of
|
||||
`tf.contrib.estimator.DNNLinearCombinedEstimator` with
|
||||
`tf.estimator.DNNLinearCombinedEstimator`
|
||||
* Replace all occurrences of `tf.contrib.estimator.DNNEstimator` with
|
||||
@ -915,7 +1063,7 @@ Weweler, Zantares, zjjott, 卜居, 王振华 (Wang Zhenhua), 黄鑫
|
||||
`tf.estimator.Estimator.experimental_export_all_saved_models`.
|
||||
* Update `regression_head` to the new Head API for Canned Estimator V2.
|
||||
* Switch `multi_class_head` to Head API for Canned Estimator V2.
|
||||
* Replace all occurences of `tf.contrib.estimator.InMemoryEvaluatorHook`
|
||||
* Replace all occurrences of `tf.contrib.estimator.InMemoryEvaluatorHook`
|
||||
and `tf.contrib.estimator.make_stop_at_checkpoint_step_hook` with
|
||||
`tf.estimator.experimental.InMemoryEvaluatorHook` and
|
||||
`tf.estimator.experimental.make_stop_at_checkpoint_step_hook`
|
||||
|
@ -33,7 +33,7 @@ except ImportError:
|
||||
from distutils.spawn import find_executable as which
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
_DEFAULT_CUDA_VERSION = '10.1'
|
||||
_DEFAULT_CUDA_VERSION = '10'
|
||||
_DEFAULT_CUDNN_VERSION = '7'
|
||||
_DEFAULT_TENSORRT_VERSION = '6'
|
||||
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0'
|
||||
|
@ -478,7 +478,7 @@ bzl_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core/platform:build_config_root_bzl",
|
||||
"//tensorflow/core/platform:cuda_build_defs_bzl",
|
||||
"//tensorflow/core/platform/default:cuda_build_defs_bzl",
|
||||
"//third_party/mkl:build_defs_bzl",
|
||||
"//third_party/mkl_dnn:build_defs_bzl",
|
||||
"//third_party/ngraph:build_defs_bzl",
|
||||
|
@ -108,6 +108,7 @@ tf_cuda_library(
|
||||
":tf_attrtype",
|
||||
":tf_status_internal",
|
||||
":tf_file_statistics",
|
||||
":tf_tensor_internal",
|
||||
] + select({
|
||||
"//tensorflow:with_xla_support": [
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
@ -251,6 +252,7 @@ tf_cuda_library(
|
||||
"tf_tensor.h",
|
||||
"tf_tensor_internal.h",
|
||||
],
|
||||
visibility = ["//tensorflow/c:__subpackages__"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
@ -259,6 +261,7 @@ tf_cuda_library(
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
@ -37,9 +37,11 @@ tf_cuda_library(
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:fixed_array",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
@ -53,6 +55,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
],
|
||||
|
@ -26,10 +26,12 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
// clang-format on
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/fixed_array.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
@ -38,6 +40,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/platform.h" // NOLINT
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
@ -1007,9 +1010,105 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||
}
|
||||
}
|
||||
|
||||
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"The passed in handle is a nullptr");
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::TensorHandle* handle = h->handle;
|
||||
|
||||
if (handle->IsRemote()) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TFE_TensorHandleDevicePointer may not be called on a remote tensor "
|
||||
"handle.");
|
||||
return nullptr;
|
||||
}
|
||||
if (handle->device() != nullptr) {
|
||||
status->status = handle->device()->Sync();
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
const tensorflow::Tensor* tensor;
|
||||
status->status = handle->Tensor(&tensor);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return const_cast<void*>(
|
||||
static_cast<const void*>(tensor->tensor_data().data()));
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
TFE_Context* ctx, const char* device_name, TF_DataType dtype,
|
||||
const int64_t* dims, int num_dims, void* data, size_t len,
|
||||
void (*deallocator)(void* data, size_t len, void* arg),
|
||||
void* deallocator_arg, TF_Status* status) {
|
||||
tensorflow::Device* device;
|
||||
status->status = ctx->context->FindDeviceFromName(device_name, &device);
|
||||
if (!status->status.ok()) {
|
||||
deallocator(data, len, deallocator_arg);
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<tensorflow::int64> dimvec(num_dims);
|
||||
for (int i = 0; i < num_dims; ++i) {
|
||||
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
||||
}
|
||||
|
||||
if (dtype == TF_STRING || dtype == TF_RESOURCE ||
|
||||
!tensorflow::DataTypeCanUseMemcpy(
|
||||
static_cast<tensorflow::DataType>(dtype))) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Trying to create a tensor with a pointer to non-pod memory.");
|
||||
deallocator(data, len, deallocator_arg);
|
||||
return nullptr;
|
||||
}
|
||||
// TODO(apassos) do we need to wrap the deallocator here to make sure to sync
|
||||
// the device?
|
||||
TF_ManagedBuffer* buf =
|
||||
new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
|
||||
|
||||
tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
|
||||
tensorflow::TensorShape(dimvec), buf);
|
||||
buf->Unref();
|
||||
tensorflow::TensorHandle* ret_handle;
|
||||
status->status = tensorflow::TensorHandle::CreateLocalHandle(
|
||||
t, device, ctx->context, &ret_handle);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TFE_TensorHandle(ret_handle);
|
||||
}
|
||||
|
||||
// This function will block till the operation that produces `h` has
|
||||
// completed. This is only valid on local TFE_TensorHandles. Returns the size in
|
||||
// bytes of the memory pointed to by the device pointer returned above.
|
||||
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
||||
TF_Status* status) {
|
||||
if (h == nullptr || h->handle == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"The passed in handle is a nullptr");
|
||||
return 0;
|
||||
}
|
||||
tensorflow::TensorHandle* handle = h->handle;
|
||||
|
||||
if (handle->IsRemote()) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"TFE_TensorHandleDeviceMemorySize may not be called on a remote tensor "
|
||||
"handle.");
|
||||
return 0;
|
||||
}
|
||||
const tensorflow::Tensor* tensor;
|
||||
status->status = handle->Tensor(&tensor);
|
||||
if (!status->status.ok()) {
|
||||
return 0;
|
||||
}
|
||||
return tensor->TotalBytes();
|
||||
}
|
||||
|
||||
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status) {
|
||||
return NewOrResetOp(ctx, op_or_function_name, status,
|
||||
return NewOrResetOp(ctx, op_or_function_name, nullptr, status,
|
||||
/* op_to_reset= */ nullptr);
|
||||
}
|
||||
|
||||
|
@ -29,9 +29,11 @@ limitations under the License.
|
||||
using tensorflow::string;
|
||||
|
||||
void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status, TFE_Op* op_to_reset) {
|
||||
const char* raw_device_name, TF_Status* status,
|
||||
TFE_Op* op_to_reset) {
|
||||
if (op_to_reset) {
|
||||
NewOrResetOp(ctx, op_or_function_name, status, op_to_reset);
|
||||
NewOrResetOp(ctx, op_or_function_name, raw_device_name, status,
|
||||
op_to_reset);
|
||||
} else {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"op_to_reset should not be nullptr");
|
||||
|
@ -22,8 +22,16 @@ limitations under the License.
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This
|
||||
// is for performance optimization by reusing an exiting unused op rather than
|
||||
// creating a new op every time. If `raw_device_name` is `NULL` or empty, it
|
||||
// does not set the device name. If it's not `NULL`, then it attempts to parse
|
||||
// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
|
||||
// than seperately calling it because if the existing op has the same
|
||||
// `raw_device_name`, it skips parsing and just leave as it is.
|
||||
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx,
|
||||
const char* op_or_function_name,
|
||||
const char* raw_device_name,
|
||||
TF_Status* status, TFE_Op* op_to_reset);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
|
||||
@ -426,6 +434,30 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
const char* worker_name,
|
||||
TF_Status* status);
|
||||
|
||||
// This function will block till the operation that produces `h` has
|
||||
// completed. This is only valid on local TFE_TensorHandles. The pointer
|
||||
// returned will be on the device in which the TFE_TensorHandle resides (so e.g.
|
||||
// for a GPU tensor this will return a pointer to GPU memory). The pointer is
|
||||
// only guaranteed to be valid until TFE_DeleteTensorHandle is called on this
|
||||
// TensorHandle. Only supports POD data types.
|
||||
TF_CAPI_EXPORT extern void* TFE_TensorHandleDevicePointer(TFE_TensorHandle*,
|
||||
TF_Status*);
|
||||
|
||||
// This function will block till the operation that produces `h` has
|
||||
// completed. This is only valid on local TFE_TensorHandles. Returns the size in
|
||||
// bytes of the memory pointed to by the device pointer returned above.
|
||||
TF_CAPI_EXPORT extern size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle*,
|
||||
TF_Status*);
|
||||
|
||||
// Creates a new TensorHandle from memory residing in device_name. Takes
|
||||
// ownership of the memory, and will call deleter to release it after TF
|
||||
// no longer needs it or in case of error.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
TFE_Context* ctx, const char* device_name, TF_DataType, const int64_t* dims,
|
||||
int num_dims, void* data, size_t len,
|
||||
void (*deallocator)(void* data, size_t len, void* arg),
|
||||
void* deallocator_arg, TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -495,5 +495,54 @@ void Executor_MatMul_CPU(bool async) {
|
||||
TEST(CAPI, Executor_MatMul_CPU) { Executor_MatMul_CPU(false); }
|
||||
TEST(CAPI, Executor_MatMul_CPUAsync) { Executor_MatMul_CPU(true); }
|
||||
|
||||
void Deleter(void* data, size_t unused, void* tensor_handle) {
|
||||
TFE_DeleteTensorHandle(static_cast<TFE_TensorHandle*>(tensor_handle));
|
||||
}
|
||||
|
||||
TEST(CAPI, TensorHandleOnDeviceMemory) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TF_Tensor* m_data = TFE_TensorHandleResolve(m, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
float* m_float = static_cast<float*>(TF_TensorData(m_data));
|
||||
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
int num_devices = TF_DeviceListCount(devices);
|
||||
for (int d = 0; d < num_devices; ++d) {
|
||||
const char* name = TF_DeviceListName(devices, d, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandle* copy = TFE_TensorHandleCopyToDevice(m, ctx, name, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
void* data = TFE_TensorHandleDevicePointer(copy, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
size_t size = TFE_TensorHandleDeviceMemorySize(copy, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
int64_t dims[] = {2, 2};
|
||||
TFE_TensorHandle* copy_aliased = TFE_NewTensorHandleFromDeviceMemory(
|
||||
ctx, name, TF_FLOAT, dims, 2, data, size, &Deleter, copy, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandle* on_host =
|
||||
TFE_TensorHandleCopyToDevice(copy_aliased, ctx, "CPU:0", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_Tensor* resolved = TFE_TensorHandleResolve(on_host, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
const float* resolved_data =
|
||||
static_cast<const float*>(TF_TensorData(resolved));
|
||||
EXPECT_EQ(0, memcmp(m_float, resolved_data, 4 * sizeof(float)));
|
||||
TF_DeleteTensor(resolved);
|
||||
TFE_DeleteTensorHandle(copy_aliased); // Note that this will delete copy.
|
||||
TFE_DeleteTensorHandle(on_host);
|
||||
}
|
||||
TF_DeleteTensor(m_data);
|
||||
TFE_DeleteTensorHandle(m);
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -17,7 +17,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/host_info.h"
|
||||
|
||||
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status, TFE_Op* op_to_reset) {
|
||||
const char* raw_device_name, TF_Status* status,
|
||||
TFE_Op* op_to_reset) {
|
||||
const char* name = op_or_function_name; // Shorthand
|
||||
const tensorflow::AttrTypeMap* types;
|
||||
bool is_function = false;
|
||||
@ -25,14 +26,17 @@ TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto create_or_reset = [&op_to_reset, &ctx, &name, &types](
|
||||
bool is_function,
|
||||
TFE_OpInferenceContext* inference_ctx) -> TFE_Op* {
|
||||
auto create_or_reset =
|
||||
[&op_to_reset, &ctx, &name, &types, &raw_device_name, &status](
|
||||
bool is_function, TFE_OpInferenceContext* inference_ctx) -> TFE_Op* {
|
||||
if (op_to_reset) {
|
||||
op_to_reset->Reset(ctx, name, is_function, types, inference_ctx);
|
||||
status->status = op_to_reset->Reset(ctx, name, is_function, types,
|
||||
raw_device_name, inference_ctx);
|
||||
return op_to_reset;
|
||||
} else {
|
||||
return new TFE_Op(ctx, name, is_function, types, inference_ctx);
|
||||
TFE_Op* new_op = new TFE_Op(ctx, name, is_function, types, inference_ctx);
|
||||
status->status = new_op->operation.SetDeviceName(raw_device_name);
|
||||
return new_op;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -58,7 +58,7 @@ struct TFE_ContextOptions {
|
||||
TFE_DEVICE_PLACEMENT_SILENT};
|
||||
TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE};
|
||||
// If true, lazily copy the remote inputs of a function to the target devices.
|
||||
bool lazy_remote_inputs_copy = false;
|
||||
bool lazy_remote_inputs_copy = true;
|
||||
};
|
||||
|
||||
struct TFE_Context {
|
||||
@ -134,11 +134,13 @@ struct TFE_Op {
|
||||
inference_ctx.reset();
|
||||
}
|
||||
|
||||
void Reset(TFE_Context* ctx, const char* op, bool is_function,
|
||||
const tensorflow::AttrTypeMap* t,
|
||||
TFE_OpInferenceContext* infer_ctx) {
|
||||
operation.Reset(ctx->context, op, is_function, t, nullptr);
|
||||
tensorflow::Status Reset(TFE_Context* ctx, const char* op, bool is_function,
|
||||
const tensorflow::AttrTypeMap* t,
|
||||
const char* raw_device_name,
|
||||
TFE_OpInferenceContext* infer_ctx) {
|
||||
inference_ctx.reset(infer_ctx);
|
||||
return operation.Reset(ctx->context, op, is_function, t, raw_device_name,
|
||||
nullptr);
|
||||
}
|
||||
|
||||
tensorflow::EagerOperation operation;
|
||||
@ -146,7 +148,8 @@ struct TFE_Op {
|
||||
};
|
||||
|
||||
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status, TFE_Op* op_to_reset = nullptr);
|
||||
const char* raw_device_name, TF_Status* status,
|
||||
TFE_Op* op_to_reset = nullptr);
|
||||
|
||||
struct TFE_Profiler {
|
||||
explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }
|
||||
|
@ -215,9 +215,24 @@ Status ModularFileSystem::DeleteFile(const std::string& fname) {
|
||||
Status ModularFileSystem::DeleteRecursively(const std::string& dirname,
|
||||
int64* undeleted_files,
|
||||
int64* undeleted_dirs) {
|
||||
// TODO(mihaimaruseac): Implementation to come in a new change
|
||||
return Status(error::UNIMPLEMENTED,
|
||||
"Modular filesystem stub not implemented yet");
|
||||
if (undeleted_files == nullptr || undeleted_dirs == nullptr)
|
||||
return errors::FailedPrecondition(
|
||||
"DeleteRecursively must not be called with `undeleted_files` or "
|
||||
"`undeleted_dirs` set to NULL");
|
||||
|
||||
if (ops_->delete_recursively == nullptr)
|
||||
return FileSystem::DeleteRecursively(dirname, undeleted_files,
|
||||
undeleted_dirs);
|
||||
|
||||
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
|
||||
std::string translated_name = TranslateName(dirname);
|
||||
uint64_t plugin_undeleted_files, plugin_undeleted_dirs;
|
||||
ops_->delete_recursively(filesystem_.get(), translated_name.c_str(),
|
||||
&plugin_undeleted_files, &plugin_undeleted_dirs,
|
||||
plugin_status.get());
|
||||
*undeleted_files = plugin_undeleted_files;
|
||||
*undeleted_dirs = plugin_undeleted_dirs;
|
||||
return StatusFromTF_Status(plugin_status.get());
|
||||
}
|
||||
|
||||
Status ModularFileSystem::DeleteDir(const std::string& dirname) {
|
||||
@ -233,9 +248,14 @@ Status ModularFileSystem::DeleteDir(const std::string& dirname) {
|
||||
}
|
||||
|
||||
Status ModularFileSystem::RecursivelyCreateDir(const std::string& dirname) {
|
||||
// TODO(mihaimaruseac): Implementation to come in a new change
|
||||
return Status(error::UNIMPLEMENTED,
|
||||
"Modular filesystem stub not implemented yet");
|
||||
if (ops_->recursively_create_dir == nullptr)
|
||||
return FileSystem::RecursivelyCreateDir(dirname);
|
||||
|
||||
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
|
||||
std::string translated_name = TranslateName(dirname);
|
||||
ops_->recursively_create_dir(filesystem_.get(), translated_name.c_str(),
|
||||
plugin_status.get());
|
||||
return StatusFromTF_Status(plugin_status.get());
|
||||
}
|
||||
|
||||
Status ModularFileSystem::CreateDir(const std::string& dirname) {
|
||||
@ -324,8 +344,8 @@ Status ModularFileSystem::CopyFile(const std::string& src,
|
||||
if (ops_->copy_file == nullptr) return FileSystem::CopyFile(src, target);
|
||||
|
||||
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
|
||||
const std::string& translated_src = TranslateName(src);
|
||||
const std::string& translated_target = TranslateName(target);
|
||||
std::string translated_src = TranslateName(src);
|
||||
std::string translated_target = TranslateName(target);
|
||||
ops_->copy_file(filesystem_.get(), translated_src.c_str(),
|
||||
translated_target.c_str(), plugin_status.get());
|
||||
return StatusFromTF_Status(plugin_status.get());
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -62,39 +62,6 @@ void deallocate_buffer(void* data, size_t len, void* arg) {
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
||||
namespace {
|
||||
class TF_ManagedBuffer : public TensorBuffer {
|
||||
public:
|
||||
TF_ManagedBuffer(void* data, size_t len,
|
||||
void (*deallocator)(void* data, size_t len, void* arg),
|
||||
void* deallocator_arg)
|
||||
: TensorBuffer(data),
|
||||
len_(len),
|
||||
deallocator_(deallocator),
|
||||
deallocator_arg_(deallocator_arg) {}
|
||||
|
||||
const size_t len_;
|
||||
void (*const deallocator_)(void* data, size_t len, void* arg);
|
||||
void* const deallocator_arg_;
|
||||
|
||||
~TF_ManagedBuffer() override {
|
||||
(*deallocator_)(data(), len_, deallocator_arg_);
|
||||
}
|
||||
|
||||
size_t size() const override { return len_; }
|
||||
TensorBuffer* root_buffer() override { return this; }
|
||||
void FillAllocationDescription(
|
||||
tensorflow::AllocationDescription* proto) const override {
|
||||
tensorflow::int64 rb = size();
|
||||
proto->set_requested_bytes(rb);
|
||||
proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
|
||||
}
|
||||
|
||||
// Prevents input forwarding from mutating this buffer.
|
||||
bool OwnsMemory() const override { return false; }
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
|
||||
int num_dims, size_t len) {
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_C_TF_TENSOR_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
@ -30,6 +31,38 @@ typedef struct TF_Tensor {
|
||||
::tensorflow::Tensor tensor;
|
||||
} TF_Tensor;
|
||||
|
||||
class TF_ManagedBuffer : public tensorflow::TensorBuffer {
|
||||
public:
|
||||
TF_ManagedBuffer(void* data, size_t len,
|
||||
void (*deallocator)(void* data, size_t len, void* arg),
|
||||
void* deallocator_arg)
|
||||
: TensorBuffer(data),
|
||||
len_(len),
|
||||
deallocator_(deallocator),
|
||||
deallocator_arg_(deallocator_arg) {}
|
||||
|
||||
~TF_ManagedBuffer() override {
|
||||
(*deallocator_)(data(), len_, deallocator_arg_);
|
||||
}
|
||||
|
||||
size_t size() const override { return len_; }
|
||||
TensorBuffer* root_buffer() override { return this; }
|
||||
void FillAllocationDescription(
|
||||
tensorflow::AllocationDescription* proto) const override {
|
||||
tensorflow::int64 rb = size();
|
||||
proto->set_requested_bytes(rb);
|
||||
proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
|
||||
}
|
||||
|
||||
// Prevents input forwarding from mutating this buffer.
|
||||
bool OwnsMemory() const override { return false; }
|
||||
|
||||
private:
|
||||
const size_t len_;
|
||||
void (*const deallocator_)(void* data, size_t len, void* arg);
|
||||
void* const deallocator_arg_;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorCApi {
|
||||
|
@ -39,7 +39,7 @@ def tf_library(
|
||||
enable_xla_hlo_profiling = False,
|
||||
mlir_components = None,
|
||||
deps = None,
|
||||
tags = None):
|
||||
tags = []):
|
||||
"""Runs tfcompile to compile a TensorFlow graph into executable code.
|
||||
|
||||
Given an invocation of tf_library(name="foo", ...), generates the following
|
||||
|
@ -27,6 +27,17 @@ package_group(
|
||||
],
|
||||
)
|
||||
|
||||
# defs.cc/h only contains string constants, and can be included in mobile
|
||||
# builds.
|
||||
filegroup(
|
||||
name = "mobile_srcs_no_runtime",
|
||||
srcs = [
|
||||
"defs.cc",
|
||||
"defs.h",
|
||||
],
|
||||
visibility = [":friends"],
|
||||
)
|
||||
|
||||
# Target that bundles up the XLA CPU and GPU JIT devices.
|
||||
cc_library(
|
||||
name = "jit",
|
||||
@ -71,6 +82,19 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_mlir_gpu_jit",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = if_cuda_or_rocm([
|
||||
":jit_compilation_passes",
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||
"//tensorflow/compiler/xla/service:mlir_gpu_plugin",
|
||||
]),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_cpu_device",
|
||||
srcs = ["xla_cpu_device.cc"],
|
||||
|
@ -509,10 +509,10 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
|
||||
auto it = uncompilable_nodes->find(function_identifier);
|
||||
if (it == uncompilable_nodes->end()) {
|
||||
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
|
||||
uncompileable_node_info{std::move(node_info)};
|
||||
uncompilable_node_info{std::move(node_info)};
|
||||
uncompilable_nodes->emplace(
|
||||
std::move(function_identifier),
|
||||
std::make_pair(function, std::move(uncompileable_node_info)));
|
||||
std::make_pair(function, std::move(uncompilable_node_info)));
|
||||
} else {
|
||||
it->second.second.emplace_back(std::move(node_info));
|
||||
}
|
||||
|
@ -96,7 +96,7 @@ limitations under the License.
|
||||
// Symbolic > NonSymbolic. The lattice has height = 2 so two iterations are
|
||||
// sufficient to converge.
|
||||
//
|
||||
// We first do an optimisitc analysis and, if it does not converge, we then fall
|
||||
// We first do an optimistic analysis and, if it does not converge, we then fall
|
||||
// back to a pessimistic analysis. The optimistic analysis assigns the same
|
||||
// symbolic predicate to all the merge nodes whose preceding enter nodes have
|
||||
// the same frame name on the first iteration. On the second iteration, if all
|
||||
@ -1255,7 +1255,7 @@ Status DeadnessAnalysisImpl::GetFrameBasedTopologicalOrder(
|
||||
} else if (IsRootExit(node)) {
|
||||
++num_exits_for_frame[cf.frame_name];
|
||||
}
|
||||
// Edge NextIteration->Merge is counted before starting the traveral to
|
||||
// Edge NextIteration->Merge is counted before starting the traversal to
|
||||
// break the backedges.
|
||||
if (IsMerge(node)) {
|
||||
for (const Edge* e : node->in_edges()) {
|
||||
@ -1458,11 +1458,11 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> topo,
|
||||
|
||||
for (Node* n : topo) {
|
||||
// The nodes added to should_revisit in the previous loop need to be
|
||||
// revisited now. Reprocesing these initial nodes may add *their* consumers
|
||||
// to should_revisit, and these newly added nodes will also be processed by
|
||||
// this very same loop. Since we're traversing the graph in topological
|
||||
// order (producers before consumers) and HandleNode(n) can only ever add
|
||||
// n's consumers to should_revisit, we won't "miss" an addition to
|
||||
// revisited now. Reprocessing these initial nodes may add *their*
|
||||
// consumers to should_revisit, and these newly added nodes will also be
|
||||
// processed by this very same loop. Since we're traversing the graph in
|
||||
// topological order (producers before consumers) and HandleNode(n) can only
|
||||
// ever add n's consumers to should_revisit, we won't "miss" an addition to
|
||||
// should_revisit.
|
||||
if (should_revisit[n->id()]) {
|
||||
VLOG(4) << "Revisiting " << n->name();
|
||||
|
@ -95,7 +95,7 @@ extern const char* const kXlaNumResourceArgsAttr;
|
||||
extern const char* const kXlaHasReferenceVarsAttr;
|
||||
|
||||
// Sorts each node's control inputs by their names. This guarantees that for two
|
||||
// structually equivalent GraphDefs, we get the same traversal ordering on
|
||||
// structurally equivalent GraphDefs, we get the same traversal ordering on
|
||||
// node's control input fields.
|
||||
// TODO(hpucha): Move the utilities to a more appropriate place.
|
||||
void SortControlInputs(GraphDef* gdef);
|
||||
|
@ -72,7 +72,7 @@ extern const char kXlaLiftedArgOutsideCompilationAttrName[];
|
||||
|
||||
// Attribute indicating that this is an IdentityN node receiving inputs for a
|
||||
// outside compilation Placeholder node (the original outside compilation node
|
||||
// is moved out of TPU comutation, and we left a Placeholder node there).
|
||||
// is moved out of TPU computation, and we left a Placeholder node there).
|
||||
// Attribute value will be a string, which is the outside compilation cluster
|
||||
// name for the outside compilation Placeholder node.
|
||||
extern const char kXlaOutsideCompilationInputsAttrName[];
|
||||
|
@ -941,7 +941,7 @@ TEST_F(ExtractOutsideCompilationForFunctionTest,
|
||||
// "const0"
|
||||
// "identity0" = "const0" (outside compilation cluster "0")
|
||||
// "identity1" = "const0" "^identity0" (outside compilation cluster "1",
|
||||
// control depdent on cluster "0")
|
||||
// control dependent on cluster "0")
|
||||
// "identity2" = "identity1"
|
||||
FunctionDefLibrary fdl;
|
||||
{
|
||||
|
@ -48,6 +48,15 @@ bool SetterForXlaAutoJitFlag(const string& value) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (value == "fusible") {
|
||||
mark_for_compilation_flags->xla_auto_jit_flag
|
||||
.optimization_level_single_gpu = 1;
|
||||
mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general =
|
||||
1;
|
||||
mark_for_compilation_flags->tf_xla_ops_to_cluster = "FUSIBLE";
|
||||
return true;
|
||||
}
|
||||
|
||||
absl::string_view value_sv(value);
|
||||
if (!absl::ConsumePrefix(&value_sv, "single-gpu(") ||
|
||||
!absl::ConsumeSuffix(&value_sv, ")") ||
|
||||
@ -65,7 +74,9 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
|
||||
Flag("tf_xla_auto_jit", SetterForXlaAutoJitFlag, "0",
|
||||
"Control compilation of operators into XLA computations on CPU and "
|
||||
"GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
|
||||
"things very likely to be improved; 2 = on for everything. "
|
||||
"things very likely to be improved; 2 = on for everything; "
|
||||
"(experimental) fusible = only for Tensorflow operations that XLA "
|
||||
"knows how to fuse. "
|
||||
"If set to single-gpu(<N>) then this resolves to <N> for single-GPU "
|
||||
"graphs (graphs that have at least one node placed on a GPU and no "
|
||||
"more than one GPU is in use through the entire graph) and 0 "
|
||||
@ -78,6 +89,23 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
|
||||
Flag("tf_xla_max_cluster_size",
|
||||
&mark_for_compilation_flags->tf_xla_max_cluster_size,
|
||||
"Maximum number of operators in an XLA compilation."),
|
||||
Flag(
|
||||
"tf_xla_ops_to_cluster",
|
||||
&mark_for_compilation_flags->tf_xla_ops_to_cluster,
|
||||
"(experimental) "
|
||||
"Limit the operations clustered by XLA to these operations. "
|
||||
"If multiple, separate them with commas. Shortcuts: "
|
||||
" PW: All point-wise operations."
|
||||
" RED: All reduction operations."
|
||||
" MISC: Mixed operations."
|
||||
" PWRED: TF operations that get converted to PW+RED operation in XLA."
|
||||
" REDUCEWINDOW: TF operations like MaxPool/AvgPool that get "
|
||||
"converted to ReduceWindow in XLA."
|
||||
" REDUCEWINDOWPW: Operation that get converted to ReduceWindow + PW "
|
||||
"(LRN, LRNGrad)."
|
||||
" BN: TF FusedBatchNorm* operations."
|
||||
" FUSIBLE: All TF operations that XLA can fuse (All the above). "
|
||||
"You can also put any TF operation name, e.g. 'FUSIBLE,Matmul'."),
|
||||
Flag("tf_xla_clustering_debug",
|
||||
&mark_for_compilation_flags->tf_xla_clustering_debug,
|
||||
"Dump graphs during XLA compilation."),
|
||||
|
@ -55,6 +55,9 @@ struct MarkForCompilationPassFlags {
|
||||
// Maximum number of operators in an XLA compilation.
|
||||
int32 tf_xla_max_cluster_size;
|
||||
|
||||
// If non-empty, limit XLA clustering to the following TF operations.
|
||||
string tf_xla_ops_to_cluster;
|
||||
|
||||
// Dump graphs during XLA compilation.
|
||||
bool tf_xla_clustering_debug;
|
||||
|
||||
|
@ -123,7 +123,7 @@ class GraphCycles {
|
||||
absl::Span<const int32> Successors(int32 node) const;
|
||||
absl::Span<const int32> Predecessors(int32 node) const;
|
||||
|
||||
// Return a copy of the sucessors set. This is needed for code using the
|
||||
// Return a copy of the successors set. This is needed for code using the
|
||||
// collection while modifying the GraphCycles.
|
||||
std::vector<int32> SuccessorsCopy(int32 node) const;
|
||||
// Return a copy of the predecessors set. This is needed for code using the
|
||||
|
@ -1076,6 +1076,35 @@ StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
|
||||
return true;
|
||||
}
|
||||
|
||||
absl::flat_hash_set<string> GetOrCreateWhitelist() {
|
||||
absl::flat_hash_map<string, std::vector<string>>* whitelist_table =
|
||||
tensorflow::GetWhitelistTable();
|
||||
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
|
||||
absl::flat_hash_set<string> whitelist;
|
||||
|
||||
for (auto s : absl::StrSplit(flags->tf_xla_ops_to_cluster, ',')) {
|
||||
if (s == "FUSIBLE") {
|
||||
for (auto pair : *whitelist_table) {
|
||||
whitelist.insert(pair.second.begin(), pair.second.end());
|
||||
}
|
||||
} else if (whitelist_table->contains(s)) {
|
||||
auto v = whitelist_table->at(s);
|
||||
whitelist.insert(v.begin(), v.end());
|
||||
} else if (!s.empty()) {
|
||||
// Should be a user provided TF operation.
|
||||
whitelist.insert(string(s));
|
||||
}
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(2) && !whitelist.empty()) {
|
||||
std::vector<string> vwhitelist(whitelist.begin(), whitelist.end());
|
||||
absl::c_sort(vwhitelist);
|
||||
VLOG(2) << "XLA clustering will only consider the following TF operations: "
|
||||
<< absl::StrJoin(vwhitelist, " ");
|
||||
}
|
||||
return whitelist;
|
||||
}
|
||||
|
||||
Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
OptimizerOptions opts;
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
|
||||
@ -1087,9 +1116,8 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
|
||||
*graph_, /*compile_time_const_arg_indices=*/nullptr,
|
||||
&compile_time_const_nodes, lib_runtime));
|
||||
|
||||
// Iterate over nodes in sorted order so that compiler fuel is deterministic.
|
||||
// We can't simply pass op_nodes().begin() and op_nodes().end to the
|
||||
// We can't simply pass op_nodes().begin() and op_nodes().end() to the
|
||||
// std::vector constructor because they're not proper iterators, with
|
||||
// iterator_traits defined and so on.
|
||||
std::vector<Node*> sorted_nodes;
|
||||
@ -1108,6 +1136,19 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
|
||||
VLOG(2) << "sorted_nodes.size() = " << sorted_nodes.size();
|
||||
|
||||
auto whitelist = GetOrCreateWhitelist();
|
||||
|
||||
std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
|
||||
absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
|
||||
// Check that user's provided TF operation really exists.
|
||||
for (auto s : whitelist) {
|
||||
if (!all_ops.contains(string(s))) {
|
||||
return errors::InvalidArgument(
|
||||
"The operation '", s,
|
||||
"' passed to --tf_xla_ops_to_cluster is not supported by XLA.");
|
||||
}
|
||||
}
|
||||
|
||||
for (Node* node : sorted_nodes) {
|
||||
if (*debug_options_.fuel <= 0) {
|
||||
VLOG(1)
|
||||
@ -1145,6 +1186,12 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!whitelist.empty() && !whitelist.contains(node->def().op())) {
|
||||
VLOG(1) << "Rejecting " << node->name()
|
||||
<< " as it is not listed in --tf_xla_ops_to_cluster.";
|
||||
continue;
|
||||
}
|
||||
|
||||
if (compile_time_const_nodes[node->id()]) {
|
||||
const OpDef* op_def;
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -1366,7 +1413,7 @@ Status MarkForCompilationPassImpl::Run() {
|
||||
void MarkForCompilationPassImpl::DumpPostClusteringGraphs() {
|
||||
DumpGraphToFile("mark_for_compilation", *graph_, flib_def_);
|
||||
|
||||
// We also dump out an annoated version of the TF graph where the nodes
|
||||
// We also dump out an annotated version of the TF graph where the nodes
|
||||
// names are prefixed with the cluster names. This can help visualizing the
|
||||
// clustering decisions on TensorBoard.
|
||||
Graph new_graph(graph_->op_registry());
|
||||
@ -1714,7 +1761,301 @@ Status MarkForCompilationPass::RunForTest(
|
||||
return MarkForCompilation(options, debug_options);
|
||||
}
|
||||
|
||||
absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
|
||||
// Table format: category name: {list of TF operations in that category}
|
||||
static absl::flat_hash_map<string, std::vector<string>>* result =
|
||||
new absl::flat_hash_map<string, std::vector<string>>{
|
||||
// Unary
|
||||
{"PW",
|
||||
{"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
|
||||
"Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", "Expm1",
|
||||
"Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", "Log",
|
||||
"Log1p", "Invert", "LogicalNot", "Ndtri", "Neg", "Rint", "Round",
|
||||
"Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
|
||||
"Square", "Tan", "Tanh", "Real", "Imag", "Erf", "Erfc", "Erfinv",
|
||||
"Lgamma", "Digamma",
|
||||
// Binary
|
||||
"Add", "AddV2", "Sub", "Mul", "Div", "Atan2", "Complex", "DivNoNan",
|
||||
"MulNoNan", "FloorDiv", "Xlogy", "Xdivy", "FloorMod", "BitwiseAnd",
|
||||
"BitwiseOr", "BitwiseXor", "LeftShift", "RightShift", "LogicalAnd",
|
||||
"LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
|
||||
"ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "TruncateDiv",
|
||||
"TruncateMod", "Equal", "NotEqual", "Greater", "GreaterEqual",
|
||||
"Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", "SoftsignGrad",
|
||||
"TanhGrad", "Pow", "SquaredDifference", "ApproximateEqual",
|
||||
// Others
|
||||
"AddN", "Bitcast", "Cast", "ClipByValue", "Const", "Empty",
|
||||
"Identity", "IdentityN", "Relu", "Relu6", "ReluGrad", "Relu6Grad",
|
||||
"LeakyReluGrad", "Elu", "EluGrad", "Selu", "SeluGrad", "Select",
|
||||
"SelectV2", "Transpose", "ConjugateTranspose",
|
||||
"_UnaryOpsComposition",
|
||||
// The following 4 operations are converted to identity
|
||||
"PlaceholderWithDefault", "PreventGradient", "StopGradient",
|
||||
"Snapshot"}},
|
||||
// clang-format off
|
||||
{"RED",
|
||||
{"All", "Any", "Min", "Max", "Mean", "Prod", "Sum"}},
|
||||
// clang-format on
|
||||
{"PWRED",
|
||||
{"ArgMax", "ArgMin", "DiagPart", "Softmax",
|
||||
"SparseSoftmaxCrossEntropyWithLogits", "LogSoftmax"}},
|
||||
{"REDUCEWINDOW",
|
||||
{"ArgMax", "ArgMin", "DiagPart", "Softmax",
|
||||
"SparseSoftmaxCrossEntropyWithLogits", "LogSoftmax"}},
|
||||
{"REDUCEWINDOWPW", {"BiasAddGrad", "LRN", "LRNGrad"}},
|
||||
{"BN",
|
||||
{"FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3",
|
||||
"_FusedBatchNormEx", "FusedBatchNormGrad", "FusedBatchNormGradV2",
|
||||
"FusedBatchNormGradV3"}},
|
||||
{"SORT", {"TopKV2"}}, // XLA version much faster then TF version.
|
||||
{"MISC",
|
||||
// clang-format off
|
||||
{"BroadcastTo", "ExpandDims", "Fill", "NoOp",
|
||||
"Range", "Rank", "Reshape", "Shape", "ShapeN", "Size", "Squeeze",
|
||||
"Transpose", "ZerosLike", "OnesLike", "BiasAdd" /*PW + Broadcast*/,
|
||||
"BroadcastArgs", "BroadcastGradientArgs", "OneHot", "Concat", "ConcatV2",
|
||||
"ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse",
|
||||
"ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV",
|
||||
"StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
|
||||
"Tile", "Transpose", "InvertPermutation", "Unpack"}}};
|
||||
// clang-format on
|
||||
return result;
|
||||
}
|
||||
|
||||
namespace testing {
|
||||
void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
|
||||
|
||||
absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
absl::flat_hash_set<string> result{"AdjustContrastv2",
|
||||
"AdjustHue",
|
||||
"AdjustSaturation",
|
||||
"Asinh",
|
||||
"Assert",
|
||||
"AssignAddVariableOp",
|
||||
"AssignSubVariableOp",
|
||||
"AssignVariableOp",
|
||||
"AvgPool",
|
||||
"AvgPool3D",
|
||||
"AvgPool3DGrad",
|
||||
"AvgPoolGrad",
|
||||
"BatchMatMul",
|
||||
"BatchMatMulV2",
|
||||
"BatchToSpace",
|
||||
"BatchToSpaceND",
|
||||
"BesselI0e",
|
||||
"BesselI1e",
|
||||
"Betainc",
|
||||
"BiasAddV1",
|
||||
"Bucketize",
|
||||
"Case",
|
||||
"CheckNumerics",
|
||||
"Cholesky",
|
||||
"ControlTrigger",
|
||||
"Conv2D",
|
||||
"Conv2DBackpropFilter",
|
||||
"Conv2DBackpropInput",
|
||||
"Conv3D",
|
||||
"Conv3DBackpropFilterV2",
|
||||
"Conv3DBackpropInputV2",
|
||||
"Cross",
|
||||
"Cumprod",
|
||||
"Cumsum",
|
||||
"DataFormatDimMap",
|
||||
"DataFormatVecPermute",
|
||||
"DepthToSpace",
|
||||
"DepthwiseConv2dNative",
|
||||
"DepthwiseConv2dNativeBackpropFilter",
|
||||
"DepthwiseConv2dNativeBackpropInput",
|
||||
"Dequantize",
|
||||
"Diag",
|
||||
"DynamicStitch",
|
||||
"Einsum",
|
||||
"EmptyTensorList",
|
||||
"ExtractImagePatches",
|
||||
"FFT",
|
||||
"FFT2D",
|
||||
"FFT3D",
|
||||
"FakeParam",
|
||||
"FakeQuantWithMinMaxArgs",
|
||||
"FakeQuantWithMinMaxArgsGradient",
|
||||
"FakeQuantWithMinMaxVars",
|
||||
"FakeQuantWithMinMaxVarsGradient",
|
||||
"Gather",
|
||||
"GatherNd",
|
||||
"GatherV2",
|
||||
"HSVToRGB",
|
||||
"IFFT",
|
||||
"IFFT2D",
|
||||
"IFFT3D",
|
||||
"IRFFT",
|
||||
"IRFFT2D",
|
||||
"IRFFT3D",
|
||||
"If",
|
||||
"InTopKV2",
|
||||
"L2Loss",
|
||||
"LeakyRelu",
|
||||
"LinSpace",
|
||||
"ListDiff",
|
||||
"LogMatrixDeterminant",
|
||||
"MatMul",
|
||||
"MatrixBandPart",
|
||||
"MatrixDiag",
|
||||
"MatrixDiagPart",
|
||||
"MatrixDiagPartV2",
|
||||
"MatrixDiagPartV3",
|
||||
"MatrixDiagV2",
|
||||
"MatrixDiagV3",
|
||||
"MatrixInverse",
|
||||
"MatrixSetDiag",
|
||||
"MatrixSetDiagV2",
|
||||
"MatrixSetDiagV3",
|
||||
"MatrixSolve",
|
||||
"MatrixTriangularSolve",
|
||||
"MaxPool",
|
||||
"MaxPool3D",
|
||||
"MaxPool3DGrad",
|
||||
"MaxPool3DGradGrad",
|
||||
"MaxPoolGrad",
|
||||
"MaxPoolGradGrad",
|
||||
"MaxPoolGradGradV2",
|
||||
"MaxPoolGradV2",
|
||||
"MaxPoolV2",
|
||||
"Multinomial",
|
||||
"NextAfter",
|
||||
"NonMaxSuppressionV4",
|
||||
"ParallelDynamicStitch",
|
||||
"ParameterizedTruncatedNormal",
|
||||
"PartitionedCall",
|
||||
"Qr",
|
||||
"QuantizeAndDequantizeV2",
|
||||
"QuantizeAndDequantizeV3",
|
||||
"RFFT",
|
||||
"RFFT2D",
|
||||
"RFFT3D",
|
||||
"RGBToHSV",
|
||||
"RandomShuffle",
|
||||
"RandomStandardNormal",
|
||||
"RandomUniform",
|
||||
"RandomUniformInt",
|
||||
"ReadVariableOp",
|
||||
"ResizeBilinear",
|
||||
"ResizeBilinearGrad",
|
||||
"ResizeNearestNeighbor",
|
||||
"ResourceApplyAdaMax",
|
||||
"ResourceApplyAdadelta",
|
||||
"ResourceApplyAdagrad",
|
||||
"ResourceApplyAdagradDA",
|
||||
"ResourceApplyAdagradV2",
|
||||
"ResourceApplyAdam",
|
||||
"ResourceApplyAddSign",
|
||||
"ResourceApplyCenteredRMSProp",
|
||||
"ResourceApplyFtrl",
|
||||
"ResourceApplyFtrlV2",
|
||||
"ResourceApplyGradientDescent",
|
||||
"ResourceApplyKerasMomentum",
|
||||
"ResourceApplyMomentum",
|
||||
"ResourceApplyPowerSign",
|
||||
"ResourceApplyProximalAdagrad",
|
||||
"ResourceApplyProximalGradientDescent",
|
||||
"ResourceApplyRMSProp",
|
||||
"ResourceGather",
|
||||
"ResourceScatterAdd",
|
||||
"ResourceScatterDiv",
|
||||
"ResourceScatterMax",
|
||||
"ResourceScatterMin",
|
||||
"ResourceScatterMul",
|
||||
"ResourceScatterNdAdd",
|
||||
"ResourceScatterNdSub",
|
||||
"ResourceScatterNdUpdate",
|
||||
"ResourceScatterSub",
|
||||
"ResourceScatterUpdate",
|
||||
"Roll",
|
||||
"ScatterNd",
|
||||
"SelfAdjointEigV2",
|
||||
"SoftmaxCrossEntropyWithLogits",
|
||||
"SpaceToBatch",
|
||||
"SpaceToBatchND",
|
||||
"SpaceToDepth",
|
||||
"SparseMatMul",
|
||||
"SparseToDense",
|
||||
"StackCloseV2",
|
||||
"StackPopV2",
|
||||
"StackPushV2",
|
||||
"StackV2",
|
||||
"StatefulPartitionedCall",
|
||||
"StatefulStandardNormalV2",
|
||||
"StatefulTruncatedNormal",
|
||||
"StatefulUniform",
|
||||
"StatefulUniformFullInt",
|
||||
"StatefulUniformInt",
|
||||
"StatelessIf",
|
||||
"StatelessMultinomial",
|
||||
"StatelessRandomNormal",
|
||||
"StatelessRandomUniform",
|
||||
"StatelessRandomUniformInt",
|
||||
"StatelessTruncatedNormal",
|
||||
"StatelessWhile",
|
||||
"Svd",
|
||||
"SymbolicGradient",
|
||||
"TensorArrayCloseV3",
|
||||
"TensorArrayConcatV3",
|
||||
"TensorArrayGatherV3",
|
||||
"TensorArrayGradV3",
|
||||
"TensorArrayReadV3",
|
||||
"TensorArrayScatterV3",
|
||||
"TensorArraySizeV3",
|
||||
"TensorArraySplitV3",
|
||||
"TensorArrayV3",
|
||||
"TensorArrayWriteV3",
|
||||
"TensorListElementShape",
|
||||
"TensorListFromTensor",
|
||||
"TensorListGather",
|
||||
"TensorListGetItem",
|
||||
"TensorListLength",
|
||||
"TensorListPopBack",
|
||||
"TensorListPushBack",
|
||||
"TensorListReserve",
|
||||
"TensorListSetItem",
|
||||
"TensorListStack",
|
||||
"TensorScatterAdd",
|
||||
"TensorScatterSub",
|
||||
"TensorScatterUpdate",
|
||||
"TridiagonalSolve",
|
||||
"TruncatedNormal",
|
||||
"UnsortedSegmentMax",
|
||||
"UnsortedSegmentMin",
|
||||
"UnsortedSegmentProd",
|
||||
"UnsortedSegmentSum",
|
||||
"VarIsInitializedOp",
|
||||
"VariableShape",
|
||||
"While",
|
||||
"XlaBroadcastHelper",
|
||||
"XlaConv",
|
||||
"XlaDequantize",
|
||||
"XlaDot",
|
||||
"XlaDynamicSlice",
|
||||
"XlaDynamicUpdateSlice",
|
||||
"XlaEinsum",
|
||||
"XlaIf",
|
||||
"XlaKeyValueSort",
|
||||
"XlaPad",
|
||||
"XlaRecv",
|
||||
"XlaReduce",
|
||||
"XlaReduceWindow",
|
||||
"XlaReplicaId",
|
||||
"XlaSelectAndScatter",
|
||||
"XlaSelfAdjointEig",
|
||||
"XlaSend",
|
||||
"XlaSharding",
|
||||
"XlaSort",
|
||||
"XlaSvd",
|
||||
"XlaWhile",
|
||||
"_Arg",
|
||||
"_ArrayToList",
|
||||
"_ListToArray",
|
||||
"_Retval"};
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace testing
|
||||
} // namespace tensorflow
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/compiler/jit/compilability_check_util.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
|
||||
@ -56,11 +57,16 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap*
|
||||
uncompilable_node_info = nullptr);
|
||||
|
||||
absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable();
|
||||
|
||||
namespace testing {
|
||||
// DO NOT USE IN PRODUCTION.
|
||||
//
|
||||
// Resets some internal state to let us write reliable unit tests.
|
||||
void ResetClusterSequenceNumber();
|
||||
|
||||
// Return a list of operation that we choose not to put into the whitelist.
|
||||
absl::flat_hash_set<string> GetKnownXLAWhitelistOp();
|
||||
} // namespace testing
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -1803,6 +1803,35 @@ TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) {
|
||||
EXPECT_NE(clusters["relu0"], clusters["relu1"]);
|
||||
}
|
||||
}
|
||||
TEST(XlaCompilationTest, XLALiteWhitelist) {
|
||||
auto* whitelist_table = tensorflow::GetWhitelistTable();
|
||||
absl::flat_hash_set<string> hwhitelist;
|
||||
std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
|
||||
absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
|
||||
|
||||
// Check that all the operations in the table are existing TF operations
|
||||
for (auto pair : *whitelist_table) {
|
||||
hwhitelist.insert(pair.second.begin(), pair.second.end());
|
||||
for (auto op : pair.second) {
|
||||
ASSERT_TRUE(all_ops.contains(op));
|
||||
}
|
||||
}
|
||||
|
||||
// Check that all registered XLA operation are in the whitelist
|
||||
// table or are known to not be in it.
|
||||
|
||||
absl::flat_hash_set<string> known_not_in_list =
|
||||
tensorflow::testing::GetKnownXLAWhitelistOp();
|
||||
std::vector<string> unknow_op;
|
||||
for (string op : vall_ops) {
|
||||
if (!hwhitelist.contains(op) && !known_not_in_list.contains(op)) {
|
||||
unknow_op.push_back(op);
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(unknow_op.empty())
|
||||
<< "Someone added support for a new TF opeations inside XLA. They must "
|
||||
"be included in the XLALite whitelist or blacklist:\n"
|
||||
<< absl::StrJoin(unknow_op, "\n");
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -187,7 +187,7 @@ impl::NodeMatcherProperties Op(string op);
|
||||
// Matches a node with assigned device `assigned_device`.
|
||||
impl::NodeMatcherProperties AssignedDevice(string assigned_device);
|
||||
|
||||
// Matches a node with a boolean typed attrbute named `name` and with value
|
||||
// Matches a node with a boolean typed attribute named `name` and with value
|
||||
// `value`.
|
||||
template <typename ValueTy>
|
||||
impl::NodeMatcherProperties Attr(const string& name, ValueTy value) {
|
||||
|
@ -125,7 +125,7 @@ TEST(NodeMatchers, CheckControlDependence) {
|
||||
"is any node");
|
||||
}
|
||||
|
||||
TEST(NodeMatchers, ConstVaulue) {
|
||||
TEST(NodeMatchers, ConstValue) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Output placeholder =
|
||||
ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
|
||||
|
@ -110,7 +110,7 @@ Merges the outputs from the PartitionedCall node and the _XlaRun node.
|
||||
Unlike the TensorFlow Merge op, which requires inputs of some types to be
|
||||
placed on the host, the _XlaMerge op can merge inputs of all types when
|
||||
placed on the device. This prevents the need for copy operations, in
|
||||
particluar when an XLA cluster has int32 outputs. The _XlaMerge up does not
|
||||
particular when an XLA cluster has int32 outputs. The _XlaMerge up does not
|
||||
have a value_index output that identifies the chosen input.
|
||||
)");
|
||||
|
||||
|
@ -262,7 +262,7 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||
<< xla_tensor->shaped_buffer().ToString();
|
||||
// For devices don't allow sync on completion, the device execution is
|
||||
// deferred. We check the execution stream status here to avoid wrong
|
||||
// results from a failed stream being propogated to following
|
||||
// results from a failed stream being propagated to following
|
||||
// host-side ops.
|
||||
if (!device_allows_sync_on_completion) {
|
||||
done_status.Update(xla_tensor->RefreshStatusOfStreams());
|
||||
|
@ -109,6 +109,14 @@ Status XlaGpuDeviceFactory::CreateDevices(
|
||||
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
auto iter = session_options.config.device_count().find("GPU");
|
||||
if (iter != session_options.config.device_count().end() &&
|
||||
iter->second == 0) {
|
||||
// Device count for GPU is 0.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
string allowed_gpus =
|
||||
session_options.config.gpu_options().visible_device_list();
|
||||
absl::optional<std::set<int>> gpu_ids =
|
||||
|
@ -222,7 +222,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
// using xla::ComputationDataHandle, which is just a symbolic handle that
|
||||
// xla::ComputationBuilder assigns. How does this handle gets assigned for
|
||||
// constant arguments? Even constant arguments get an _Arg node in the graph
|
||||
// instatiated for Function compilation. The tf2xla kernel for constant _Arg
|
||||
// instantiated for Function compilation. The tf2xla kernel for constant _Arg
|
||||
// nodes takes the constant value, converts it to XlaLiteral, and feeds it
|
||||
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
|
||||
// constant XlaLiteral is included in the HLO graph, and subsequently, in
|
||||
|
@ -84,9 +84,9 @@ VariableInfo::~VariableInfo() {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a vector of VaribleInfo instances for the resource variable inputs to
|
||||
// the kernel with context `ctx`. The input indices for the resource variable
|
||||
// inputs are in `variable_indices`.
|
||||
// Returns a vector of VariableInfo instances for the resource variable inputs
|
||||
// to the kernel with context `ctx`. The input indices for the resource
|
||||
// variable inputs are in `variable_indices`.
|
||||
static Status GetVariableInfosFromCtxInputs(
|
||||
OpKernelContext* ctx, absl::Span<const int> variable_indices,
|
||||
std::vector<VariableInfo>* result) {
|
||||
|
@ -73,6 +73,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/xla:lhlo",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_linalg",
|
||||
"//tensorflow/compiler/mlir/xla:xla_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
|
||||
|
@ -282,8 +282,13 @@ cc_library(
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:Analysis",
|
||||
|
@ -98,6 +98,17 @@ using xla::StatusOr;
|
||||
namespace errors = tensorflow::errors;
|
||||
namespace tfl = mlir::TFL;
|
||||
|
||||
using llvm::cl::opt;
|
||||
|
||||
// Commandline flag to enable the control of flatbuffer import.
|
||||
bool use_external_constant;
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static opt<bool, true> use_external_constant_flag(
|
||||
"use-external-constant",
|
||||
llvm::cl::desc("Use external constant during flatbuffer import"),
|
||||
llvm::cl::location(use_external_constant), llvm::cl::init(false));
|
||||
|
||||
namespace {
|
||||
bool IsScalar(const TensorT& tensor) {
|
||||
// TODO(b/138222071) We can't distinguish scalars and unranked tensors
|
||||
@ -391,6 +402,21 @@ StatusOr<mlir::ElementsAttr> ConvertIntBuffer(
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<Operation*> BuildExternalConstOp(const tflite::TensorT& tensor,
|
||||
int32_t buffer_index,
|
||||
OpBuilder builder, Location loc) {
|
||||
TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
|
||||
/*shapeless_are_scalars=*/true,
|
||||
/*is_constant=*/true));
|
||||
auto shaped_type = type.dyn_cast<mlir::RankedTensorType>();
|
||||
if (!shaped_type) {
|
||||
return errors::Internal("Constant doesn't have a shape");
|
||||
}
|
||||
auto op = builder.create<tfl::ExternalConstOp>(
|
||||
loc, shaped_type, builder.getI32IntegerAttr(buffer_index));
|
||||
return op.getOperation();
|
||||
}
|
||||
|
||||
StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
|
||||
const std::vector<uint8_t>& buffer,
|
||||
OpBuilder builder, Location loc) {
|
||||
@ -608,8 +634,8 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
const std::vector<std::string>& func_names,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
|
||||
Location base_loc, Builder builder,
|
||||
const std::vector<std::string>& ordered_output_arrays,
|
||||
bool is_entry_point) {
|
||||
const std::vector<std::string>& ordered_output_arrays, bool is_entry_point,
|
||||
bool use_external_constant) {
|
||||
llvm::SmallVector<mlir::Type, 2> ret_types;
|
||||
llvm::SmallVector<mlir::Type, 4> input_types;
|
||||
|
||||
@ -723,8 +749,11 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
auto& const_tensor = *subgraph.tensors[input_num];
|
||||
auto const_loc = TensorLoc(const_tensor, builder, base_loc);
|
||||
auto op_or_err =
|
||||
BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
|
||||
op_builder, const_loc);
|
||||
use_external_constant
|
||||
? BuildExternalConstOp(const_tensor, const_tensor.buffer,
|
||||
op_builder, const_loc)
|
||||
: BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
|
||||
op_builder, const_loc);
|
||||
if (!op_or_err.ok()) {
|
||||
return emitError(const_loc, op_or_err.status().ToString()),
|
||||
op_or_err.status();
|
||||
@ -768,8 +797,11 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
auto& const_tensor = *subgraph.tensors[index];
|
||||
auto const_loc = TensorLoc(const_tensor, builder, base_loc);
|
||||
auto op_or_err =
|
||||
BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
|
||||
op_builder, const_loc);
|
||||
use_external_constant
|
||||
? BuildExternalConstOp(const_tensor, const_tensor.buffer,
|
||||
op_builder, const_loc)
|
||||
: BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
|
||||
op_builder, const_loc);
|
||||
if (!op_or_err.ok()) {
|
||||
return emitError(const_loc, op_or_err.status().ToString()),
|
||||
op_or_err.status();
|
||||
@ -804,7 +836,8 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
|
||||
|
||||
OwningModuleRef tflite::FlatBufferToMlir(
|
||||
absl::string_view buffer, MLIRContext* context, Location base_loc,
|
||||
const std::vector<std::string>& ordered_output_arrays) {
|
||||
const std::vector<std::string>& ordered_output_arrays,
|
||||
bool use_external_constant) {
|
||||
auto model_ptr =
|
||||
FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
|
||||
if (nullptr == model_ptr) {
|
||||
@ -858,7 +891,8 @@ OwningModuleRef tflite::FlatBufferToMlir(
|
||||
// Only the entry point needs pseudo_input_ops
|
||||
// TODO(b/131175224,b/132239787) Support multiple entry points
|
||||
builder, ordered_output_arrays,
|
||||
/*is_entry_point=*/e.index() == 0);
|
||||
/*is_entry_point=*/e.index() == 0,
|
||||
/*use_external_constant=*/use_external_constant);
|
||||
if (!func_or_error.ok()) {
|
||||
return emitError(base_loc, "could not translate function ")
|
||||
<< subgraph->name,
|
||||
@ -872,7 +906,8 @@ OwningModuleRef tflite::FlatBufferToMlir(
|
||||
}
|
||||
|
||||
static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
|
||||
MLIRContext* context) {
|
||||
MLIRContext* context,
|
||||
bool use_external_constant) {
|
||||
const llvm::MemoryBuffer* input =
|
||||
source_mgr->getMemoryBuffer(source_mgr->getMainFileID());
|
||||
std::string error;
|
||||
@ -889,11 +924,12 @@ static OwningModuleRef FlatBufferFileToMlirTrans(llvm::SourceMgr* source_mgr,
|
||||
|
||||
return tflite::FlatBufferToMlir(
|
||||
absl::string_view(input->getBufferStart(), input->getBufferSize()),
|
||||
context, loc, outputs);
|
||||
context, loc, outputs, use_external_constant);
|
||||
}
|
||||
|
||||
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
|
||||
"tflite-flatbuffer-to-mlir",
|
||||
[](llvm::SourceMgr& source_mgr, MLIRContext* context) {
|
||||
return FlatBufferFileToMlirTrans(&source_mgr, context);
|
||||
return FlatBufferFileToMlirTrans(&source_mgr, context,
|
||||
use_external_constant);
|
||||
});
|
||||
|
@ -29,10 +29,13 @@ namespace tflite {
|
||||
// If ordered_output_arrays is not empty, then the imported mlir function will
|
||||
// only return nodes in ordered_output_arrays in the same order. Returns nullptr
|
||||
// on failure, and more specific errors will be emitted via the context.
|
||||
// If `use_external_constant` is true, it will create `tfl.external_const`
|
||||
// instead of `tfl.const`.
|
||||
mlir::OwningModuleRef FlatBufferToMlir(
|
||||
absl::string_view buffer, mlir::MLIRContext* context,
|
||||
mlir::Location base_loc,
|
||||
const std::vector<std::string>& ordered_output_arrays);
|
||||
const std::vector<std::string>& ordered_output_arrays,
|
||||
bool use_external_constant = false);
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_
|
||||
|
@ -374,6 +374,10 @@ class Translator {
|
||||
mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results);
|
||||
|
||||
BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
|
||||
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results);
|
||||
|
||||
Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
|
||||
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);
|
||||
|
||||
@ -414,8 +418,8 @@ class Translator {
|
||||
// is marked as a stateful operand.
|
||||
bool IsStatefulOperand(mlir::Operation* op, int operand_index);
|
||||
|
||||
// Returns a unique name for `op`.
|
||||
std::string UniqueName(mlir::Operation* op);
|
||||
// Returns a unique name for `val`.
|
||||
std::string UniqueName(mlir::Value* val);
|
||||
|
||||
ModuleOp module_;
|
||||
|
||||
@ -445,8 +449,8 @@ class Translator {
|
||||
std::vector<std::string> failed_custom_ops_;
|
||||
};
|
||||
|
||||
std::string Translator::UniqueName(mlir::Operation* op) {
|
||||
return name_mapper_.GetUniqueName(op);
|
||||
std::string Translator::UniqueName(mlir::Value* val) {
|
||||
return name_mapper_.GetUniqueName(val);
|
||||
}
|
||||
|
||||
Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
|
||||
@ -610,6 +614,21 @@ BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
|
||||
builtin_options);
|
||||
}
|
||||
|
||||
BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
|
||||
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
|
||||
const std::vector<int32_t>& results) {
|
||||
float tolerance = op.tolerance().convertToFloat();
|
||||
std::vector<uint8_t> custom_options(sizeof(float));
|
||||
memcpy(custom_options.data(), &tolerance, sizeof(float));
|
||||
auto opcode_index =
|
||||
GetOpcodeIndex("NumericVerify", tflite::BuiltinOperator_CUSTOM);
|
||||
return tflite::CreateOperator(
|
||||
builder_, opcode_index, builder_.CreateVector(operands),
|
||||
builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
|
||||
/*builtin_options=*/0, builder_.CreateVector<uint8_t>(custom_options),
|
||||
tflite::CustomOptionsFormat_FLEXBUFFERS);
|
||||
}
|
||||
|
||||
Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
|
||||
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
|
||||
std::string node_def_str;
|
||||
@ -736,6 +755,9 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
|
||||
auto builtin_code = GetBuiltinOpCode(inst);
|
||||
if (!builtin_code) {
|
||||
if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) {
|
||||
return BuildNumericVerifyOperator(verify_op, operands, results);
|
||||
}
|
||||
inst->emitOpError("is not a supported TFLite op");
|
||||
return llvm::None;
|
||||
}
|
||||
@ -864,17 +886,7 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
|
||||
return;
|
||||
}
|
||||
for (const auto& it : llvm::enumerate(term->getOperands())) {
|
||||
// TODO(jpienaar): If this isn't due to an op, then we'd need to either
|
||||
// ensure the name that will be assigned to the buffer is the same, or
|
||||
// insert an op so that we can have a buffer named such. This cannot
|
||||
// currently happen due to pseudo_input nodes.
|
||||
if (auto op = it.value()->getDefiningOp()) {
|
||||
name_mapper_.InitOpName(op, output_names[it.index()].trim());
|
||||
} else {
|
||||
fn.emitWarning() << "output is not due to an op and '"
|
||||
<< output_names[it.index()]
|
||||
<< "' may not be a named output";
|
||||
}
|
||||
name_mapper_.InitOpName(it.value(), output_names[it.index()].trim());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -941,16 +953,9 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
|
||||
for (auto& inst : bb) {
|
||||
if (inst.isKnownTerminator()) break;
|
||||
|
||||
std::string name = UniqueName(&inst);
|
||||
for (size_t i = 0, e = inst.getNumResults(); i < e; ++i) {
|
||||
// Tensors are named by adding result index to name for the particular
|
||||
// operation such as name:0, name:1, name:2 etc. Default port is zero so
|
||||
// the first result can be specified without the port. This is based on
|
||||
// TensorFlow's naming scheme for inputs in the NodeDef proto.
|
||||
std::string suffix = i > 0 ? absl::StrCat(":", i) : "";
|
||||
if (!build_tensor_and_buffer(inst.getResult(i), name + suffix)) {
|
||||
return llvm::None;
|
||||
}
|
||||
for (auto val : inst.getResults()) {
|
||||
std::string name = UniqueName(val);
|
||||
if (!build_tensor_and_buffer(val, name)) return llvm::None;
|
||||
}
|
||||
|
||||
// Skip constant ops as they don't represent a TFLite operator.
|
||||
|
@ -577,6 +577,19 @@ def TFL_ConstOp : Op<TFL_Dialect, "pseudo_const", [NoSideEffect,
|
||||
];
|
||||
}
|
||||
|
||||
def TFL_ExternalConstOp : Op<TFL_Dialect, "external_const", [NoSideEffect]> {
|
||||
let summary = "External const op.";
|
||||
|
||||
let description = [{
|
||||
External const op holds a `buffer_index` which points to a constant
|
||||
in the flatbuffer.
|
||||
}];
|
||||
|
||||
let arguments = (ins I32Attr:$buffer_index);
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
}
|
||||
|
||||
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0>;
|
||||
|
||||
def TFL_CosOp: TFL_Op<"cos", [
|
||||
@ -3113,6 +3126,27 @@ the output tensor can vary depending on how many true values there are in
|
||||
);
|
||||
}
|
||||
|
||||
def TFL_NumericVerifyOp : Op<TFL_Dialect, "NumericVerify", [
|
||||
SameOperandsShape]> {
|
||||
|
||||
let summary = "Verifies the numericals of the two operands";
|
||||
|
||||
let description = [{
|
||||
The NumericVerify op is a debugging op to verify the numericals of the two
|
||||
activations. It is a custom op in TFLite.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[QI8, QUI8, QI16, QUI16]>:$input,
|
||||
TensorOf<[F32]>:$ref,
|
||||
|
||||
// Attributes
|
||||
DefaultValuedAttr<F32Attr, "0.1">:$tolerance
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def SVDFResultConstraint: PredOpTrait<
|
||||
"the input and result tensor elemental types must be same",
|
||||
TCresVTEtIsSameAsOp<0, 0>>;
|
||||
|
@ -57,8 +57,10 @@ const char kDetectionPostProcessOp[] =
|
||||
"type: 'int'} attr : { name: 'max_detections' type: 'int'} attr : { "
|
||||
"name: 'nms_iou_threshold' type: 'float'} attr : { name: "
|
||||
"'nms_score_threshold' type: 'float'} attr : { name: 'num_classes' type: "
|
||||
"'int'} attr : { name: 'w_scale' type: 'int'} attr : { name: 'x_scale' "
|
||||
"type: 'int'} attr : { name: 'y_scale' type: 'int'}";
|
||||
"'int'} attr : { name: 'w_scale' type: 'float'} attr : { name: 'x_scale' "
|
||||
"type: 'float'} attr : { name: 'y_scale' type: 'float'} attr { name: "
|
||||
"'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
|
||||
"name: 'use_regular_nms' type: 'bool' default_value { b : false }}";
|
||||
|
||||
// Converts the toco::IODataType to tensorflow::DataType. Only contains the
|
||||
// conversion mapping for constants defined in TFLite Python API.
|
||||
|
@ -458,13 +458,15 @@ void QuantizationDriver::QuantizeValue(Value *value, QuantParams params,
|
||||
void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
|
||||
RequantizeState *state) {
|
||||
if (state->pos == RequantizeState::NO_REQUANTIZE) return;
|
||||
builder_.setInsertionPoint(op->getBlock(), ++Block::iterator(op));
|
||||
builder_.setInsertionPointAfter(op);
|
||||
Value *value = op->getResult(index);
|
||||
if (state->pos == RequantizeState::ON_OUTPUT) {
|
||||
Operation *op = value->getUses().begin().getUser(); // `quantize` op
|
||||
// The requantize op is inserted between `quantize` and `dequantize` ops.
|
||||
value = op->getResult(0);
|
||||
builder_.setInsertionPoint(op->getBlock(), ++Block::iterator(op));
|
||||
Operation *user = value->getUses().begin().getUser();
|
||||
if (llvm::isa<TFL::QuantizeOp>(user)) {
|
||||
// The requantize op is inserted between `quantize` and `dequantize` ops.
|
||||
value = user->getResult(0);
|
||||
builder_.setInsertionPointAfter(user);
|
||||
}
|
||||
}
|
||||
RequantizeValue(value, state, op->getLoc());
|
||||
}
|
||||
|
@ -416,7 +416,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
|
||||
if (res->hasOneUse()) {
|
||||
if (auto next_stats = llvm::dyn_cast<quant::StatisticsOp>(
|
||||
*res->getUsers().begin())) {
|
||||
// quantization parameters can be propgated to next_stats
|
||||
// quantization parameters can be propagated to next_stats
|
||||
redundant_stats_ops.insert(next_stats);
|
||||
// add next_stats to the work list so propagation can
|
||||
// continue.
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Matchers.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
@ -144,12 +145,16 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
|
||||
//
|
||||
// Full integer quantization disallows "hybrid" operands or results.
|
||||
// Weight quantization allows "hybrid" operands and results.
|
||||
template <typename ConcretTy, typename Q, typename DQ>
|
||||
template <typename ConcretTy, typename Q, typename DQ, typename VERIFIER>
|
||||
struct QuantizationPattern : public RewritePattern {
|
||||
using BaseType = QuantizationPattern<ConcretTy, Q, DQ>;
|
||||
using BaseType = QuantizationPattern<ConcretTy, Q, DQ, VERIFIER>;
|
||||
|
||||
explicit QuantizationPattern(MLIRContext* context)
|
||||
: RewritePattern(DQ::getOperationName(), 1, context) {}
|
||||
explicit QuantizationPattern(MLIRContext* context, bool enable_verify,
|
||||
float error_tolerance, bool single_layer_verify)
|
||||
: RewritePattern(DQ::getOperationName(), 1, context),
|
||||
enable_verify(enable_verify),
|
||||
error_tolerance(error_tolerance),
|
||||
single_layer_verify(single_layer_verify) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation* op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
@ -230,7 +235,7 @@ struct QuantizationPattern : public RewritePattern {
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(quantized_op);
|
||||
rewriter.setInsertionPointAfter(quantized_op);
|
||||
OperationState new_state(quantized_op->getLoc(),
|
||||
quantized_op->getName().getStringRef(), inputs,
|
||||
output_types, quantized_op->getAttrs());
|
||||
@ -239,9 +244,64 @@ struct QuantizationPattern : public RewritePattern {
|
||||
output.getFirst()->replaceAllUsesWith(
|
||||
new_op->getResult(output.getSecond()));
|
||||
}
|
||||
|
||||
// To verify the numericals, the original floating-point ops are
|
||||
// preserved in the graph. The result of these floating-point ops are sent
|
||||
// to a numeric verifier op as the reference.
|
||||
if (enable_verify) {
|
||||
// For constant operands, the floating-point constant is duplicated in
|
||||
// case it is quantized.
|
||||
for (int i = 0, e = new_op->getNumOperands(); i != e; ++i) {
|
||||
auto def = new_op->getOperand(i)->getDefiningOp();
|
||||
if (auto q = llvm::dyn_cast_or_null<Q>(def)) {
|
||||
DenseFPElementsAttr attr;
|
||||
if (!matchPattern(q.input(), m_Constant(&attr))) {
|
||||
continue;
|
||||
}
|
||||
auto cst = rewriter.create<ConstantOp>(new_op->getLoc(), attr);
|
||||
quantized_op->setOperand(i, cst.getResult());
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0, e = new_op->getNumResults(); i != e; ++i) {
|
||||
if (!quantized_op->getResult(i)
|
||||
->getType()
|
||||
.cast<ShapedType>()
|
||||
.getElementType()
|
||||
.isa<FloatType>()) {
|
||||
continue;
|
||||
}
|
||||
rewriter.setInsertionPointAfter(new_op);
|
||||
FloatAttr tolerance = rewriter.getF32FloatAttr(error_tolerance);
|
||||
// Verify the quantized value by sending the result to the verifier.
|
||||
rewriter.create<VERIFIER>(quantized_op->getLoc(),
|
||||
new_op->getResult(i),
|
||||
quantized_op->getResult(i), tolerance);
|
||||
|
||||
if (single_layer_verify) continue;
|
||||
|
||||
// Find the Dequantize/Dequantize users of the new op results, and
|
||||
// replace the usage. Then all the floating-point ops are connected.
|
||||
// N.B. the return op will use this floating-point result.
|
||||
for (auto user : new_op->getResult(i)->getUsers()) {
|
||||
// Skip the Requantize op, and we know it has a single user.
|
||||
if (llvm::isa<Q>(user)) {
|
||||
user = *user->getResult(0)->getUsers().begin();
|
||||
}
|
||||
if (auto dequantize = llvm::dyn_cast<DQ>(user)) {
|
||||
dequantize.getResult()->replaceAllUsesWith(
|
||||
quantized_op->getResult(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
bool enable_verify;
|
||||
float error_tolerance;
|
||||
bool single_layer_verify;
|
||||
};
|
||||
|
||||
// Converts quantize ops with unsigned quantized types to these with signed
|
||||
@ -342,14 +402,14 @@ ElementsAttr Quantize(Attribute real_value, Type tensor_type);
|
||||
// parameters in this type is based on the min and max element of the
|
||||
// attribute. When the elements in the `attr` are not in floating-point, or
|
||||
// the value range isn't straddling zero, an empty type is returned. The min/max
|
||||
// are ajusted to be symmetric if `symmetric` flag is set to True. And
|
||||
// are adjusted to be symmetric if `symmetric` flag is set to True. And
|
||||
// `symmetric` can only be set to true when it is signed and narrow_range.
|
||||
Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, bool symmetric,
|
||||
unsigned num_bits, bool is_sign,
|
||||
bool narrow_range);
|
||||
|
||||
// Returns the per channel quantized type for an element attribute.
|
||||
// `quant_dim` defines the quantization axis. The channel min/max are ajusted
|
||||
// `quant_dim` defines the quantization axis. The channel min/max are adjusted
|
||||
// to be symmetric if `symmetric` flag is set to True. And `symmetric` can only
|
||||
// be set to true when it is signed and narrow_range.
|
||||
Type GetUniformQuantizedPerAxisTypeForWeight(ElementsAttr attr, int quant_dim,
|
||||
|
@ -32,6 +32,7 @@ filegroup(
|
||||
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
|
||||
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
|
||||
"@llvm//:FileCheck",
|
||||
"@llvm//:not",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -36,7 +36,7 @@ class TestGraphDebugInfo(object):
|
||||
@tf.function(
|
||||
input_signature=[tf.TensorSpec(shape=[3, 3], dtype=tf.float32)])
|
||||
def model(x):
|
||||
y = tf.math.reciprocal(x) # Not supported
|
||||
y = tf.math.betainc(x, 0.5, 1.0) # Not supported
|
||||
return y + y
|
||||
|
||||
func = model.get_concrete_function()
|
||||
@ -47,14 +47,14 @@ class TestGraphDebugInfo(object):
|
||||
# pylint: disable=line-too-long
|
||||
|
||||
# CHECK-LABEL: testConcreteFunctionDebugInfo
|
||||
# CHECK: error: 'tf.Reciprocal' op is neither a custom op nor a flex op
|
||||
# CHECK: error: 'tf.Betainc' op is neither a custom op nor a flex op
|
||||
# CHECK: attrs=attr_protos, op_def=op_def)
|
||||
# CHECK: ^
|
||||
# CHECK: {{.*tensorflow/python/ops/gen_math_ops.py:[0-9]+:[0-9]+: note: called from}}
|
||||
# CHECK: "Reciprocal", x=x, name=name)
|
||||
# CHECK: "Betainc", a=a, b=b, x=x, name=name)
|
||||
# CHECK: ^
|
||||
# CHECK: {{.*tensorflow/compiler/mlir/lite/tests/debuginfo/concrete_function_error.py:[0-9]+:[0-9]+: note: called from}}
|
||||
# CHECK: y = tf.math.reciprocal(x) # Not supported
|
||||
# CHECK: y = tf.math.betainc(x, 0.5, 1.0) # Not supported
|
||||
# CHECK: ^
|
||||
# CHECK: <unknown>:0: error: failed while converting: 'main'
|
||||
|
||||
|
@ -33,7 +33,7 @@ class TestModule(tf.Module):
|
||||
|
||||
@tf.function(input_signature=[tf.TensorSpec(shape=[3, 3], dtype=tf.float32)])
|
||||
def model(self, x):
|
||||
y = tf.math.reciprocal(x) # Not supported
|
||||
y = tf.math.betainc(x, 0.5, 1.0) # Not supported
|
||||
return y + y
|
||||
|
||||
|
||||
@ -56,14 +56,14 @@ class TestGraphDebugInfo(object):
|
||||
# pylint: disable=line-too-long
|
||||
|
||||
# CHECK-LABEL: testSavedModelDebugInfo
|
||||
# CHECK: error: 'tf.Reciprocal' op is neither a custom op nor a flex op
|
||||
# CHECK: error: 'tf.Betainc' op is neither a custom op nor a flex op
|
||||
# CHECK: attrs=attr_protos, op_def=op_def)
|
||||
# CHECK: ^
|
||||
# CHECK: {{.*tensorflow/python/ops/gen_math_ops.py:[0-9]+:[0-9]+: note: called from}}
|
||||
# CHECK: "Reciprocal", x=x, name=name)
|
||||
# CHECK: "Betainc", a=a, b=b, x=x, name=name)
|
||||
# CHECK: ^
|
||||
# CHECK: {{.*tensorflow/compiler/mlir/lite/tests/debuginfo/saved_model_error.py:[0-9]+:[0-9]+: note: called from}}
|
||||
# CHECK: y = tf.math.reciprocal(x) # Not supported
|
||||
# CHECK: y = tf.math.betainc(x, 0.5, 1.0) # Not supported
|
||||
# CHECK: ^
|
||||
# CHECK: <unknown>:0: error: failed while converting: 'main'
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# RUN: tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
|
||||
# RUN: not tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s
|
||||
|
||||
# CHECK: fake/user/code/file_C.py:27:1: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format
|
||||
# CHECK: fake/user/code/file_C.py: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format
|
||||
|
||||
node {
|
||||
name: "input"
|
||||
|
@ -57,7 +57,7 @@ files: "fake/user/code/file_x.py"
|
||||
files: "fake/user/code/file_y.py"
|
||||
files: "fake/user/code/file_z.py"
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/beta"
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/beta@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 33
|
||||
@ -66,7 +66,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/beta/read"
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/beta/read@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
@ -75,7 +75,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/gamma"
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/gamma@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 38
|
||||
@ -84,7 +84,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/gamma/read"
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/gamma/read@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
@ -93,7 +93,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean"
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 44
|
||||
@ -102,7 +102,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean/read"
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean/read@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
@ -111,7 +111,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance"
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 44
|
||||
@ -120,7 +120,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance/read"
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance/read@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
@ -129,7 +129,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/weights"
|
||||
key: "MobilenetV1/Conv2d_0/weights@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 54
|
||||
@ -138,7 +138,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/weights/read"
|
||||
key: "MobilenetV1/Conv2d_0/weights/read@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
@ -147,7 +147,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm"
|
||||
key: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 5
|
||||
@ -156,7 +156,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/MobilenetV1/Conv2d_0/Conv2D"
|
||||
key: "MobilenetV1/MobilenetV1/Conv2d_0/Conv2D@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 2
|
||||
@ -165,7 +165,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "input"
|
||||
key: "input@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 40
|
||||
|
@ -1,9 +1,9 @@
|
||||
# RUN: tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
|
||||
# RUN: not tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s
|
||||
|
||||
# CHECK: fake/user/code/file_C.py:27:1: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format
|
||||
# CHECK: fake/user/code/file_D.py:28:1: note: called from
|
||||
# CHECK: fake/user/code/file_E.py:29:1: note: called from
|
||||
# CHECK: fake/user/code/file_F.py:30:1: note: called from
|
||||
# CHECK: fake/user/code/file_C.py: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format
|
||||
# CHECK: fake/user/code/file_D.py: note: called from
|
||||
# CHECK: fake/user/code/file_E.py: note: called from
|
||||
# CHECK: fake/user/code/file_F.py: note: called from
|
||||
|
||||
node {
|
||||
name: "input"
|
||||
|
@ -57,7 +57,7 @@ files: "fake/user/code/file_x.py"
|
||||
files: "fake/user/code/file_y.py"
|
||||
files: "fake/user/code/file_z.py"
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/beta"
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/beta@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 33
|
||||
@ -66,7 +66,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/beta/read"
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/beta/read@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
@ -75,7 +75,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/gamma"
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/gamma@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 38
|
||||
@ -84,7 +84,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/gamma/read"
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/gamma/read@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
@ -93,7 +93,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean"
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 44
|
||||
@ -102,7 +102,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean/read"
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_mean/read@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
@ -111,7 +111,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance"
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 44
|
||||
@ -120,7 +120,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance/read"
|
||||
key: "MobilenetV1/Conv2d_0/BatchNorm/moving_variance/read@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
@ -129,7 +129,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/weights"
|
||||
key: "MobilenetV1/Conv2d_0/weights@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 54
|
||||
@ -138,7 +138,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/Conv2d_0/weights/read"
|
||||
key: "MobilenetV1/Conv2d_0/weights/read@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 49
|
||||
@ -147,7 +147,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm"
|
||||
key: "MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 5
|
||||
@ -156,7 +156,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "MobilenetV1/MobilenetV1/Conv2d_0/Conv2D"
|
||||
key: "MobilenetV1/MobilenetV1/Conv2d_0/Conv2D@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 2
|
||||
@ -177,7 +177,7 @@ traces {
|
||||
}
|
||||
}
|
||||
traces {
|
||||
key: "input"
|
||||
key: "input@"
|
||||
value {
|
||||
file_line_cols {
|
||||
file_index: 40
|
||||
|
@ -0,0 +1,14 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir --use-external-constant - -o - | FileCheck --dump-input-on-failure %s
|
||||
// Ensure that `tfl.external_const` is imported when the flag `-use-external-constant` is enabled.
|
||||
|
||||
func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
^bb0(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>):
|
||||
%cst = constant dense<1.0> : tensor<40xf32>
|
||||
%0:2 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40xf32>) -> (tensor<40x40xf32>, tensor<40x40xf32>)
|
||||
return %0 : tensor<40x40xf32>
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32>
|
||||
// CHECK: %[[CONST:[0-9]+]] = "tfl.external_const"() {buffer_index = 3 : i32} : () -> tensor<40xf32>
|
||||
// CHECK-NEXT: %[[FULL:[0-9]+]]:2 = "tfl.fully_connected"(%arg0, %arg1, %[[CONST]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}
|
||||
// CHECK-NEXT: return %[[FULL]]#0
|
||||
}
|
@ -1290,3 +1290,53 @@ func @assert_remove(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi1>
|
||||
// CHECK-NOT: Assert
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @reciprocal_f16(%arg0: tensor<8xf16>) -> tensor<8xf16> {
|
||||
%0 = "tf.Reciprocal"(%arg0) : (tensor<8xf16>) -> tensor<8xf16>
|
||||
return %0: tensor<8xf16>
|
||||
|
||||
// CHECK-LABEL: reciprocal_f16
|
||||
// CHECK: %cst = constant dense<1.000000e+00> : tensor<1xf16>
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xf16>, tensor<8xf16>) -> tensor<8xf16>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @reciprocal_f32(%arg0: tensor<8xf32>) -> tensor<8xf32> {
|
||||
%0 = "tf.Reciprocal"(%arg0) : (tensor<8xf32>) -> tensor<8xf32>
|
||||
return %0: tensor<8xf32>
|
||||
|
||||
// CHECK-LABEL: reciprocal_f32
|
||||
// CHECK: %cst = constant dense<1.000000e+00> : tensor<1xf32>
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xf32>, tensor<8xf32>) -> tensor<8xf32>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @reciprocal_complex_f32(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
|
||||
%0 = "tf.Reciprocal"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
|
||||
return %0: tensor<8xcomplex<f32>>
|
||||
|
||||
// CHECK-LABEL: reciprocal_complex_f32
|
||||
// CHECK: %cst = constant opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C455836342074656E736F725F7368617065207B2064696D207B2073697A653A2031207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3230303F5C3030305C3030305C3030305C30303022"> : tensor<1xcomplex<f32>>
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xcomplex<f32>>, tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @reciprocal_i32(%arg0: tensor<8xi32>) -> tensor<8xi32> {
|
||||
%0 = "tf.Reciprocal"(%arg0) : (tensor<8xi32>) -> tensor<8xi32>
|
||||
return %0: tensor<8xi32>
|
||||
|
||||
// CHECK-LABEL: reciprocal_i32
|
||||
// CHECK: %cst = constant dense<1> : tensor<1xi32>
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xi32>, tensor<8xi32>) -> tensor<8xi32>
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @reciprocal_i64(%arg0: tensor<8xi64>) -> tensor<8xi64> {
|
||||
%0 = "tf.Reciprocal"(%arg0) : (tensor<8xi64>) -> tensor<8xi64>
|
||||
return %0: tensor<8xi64>
|
||||
|
||||
// CHECK-LABEL: reciprocal_i64
|
||||
// CHECK: %cst = constant dense<1> : tensor<1xi64>
|
||||
// CHECK: "tfl.div"(%cst, %arg0) {fused_activation_function = "NONE"} : (tensor<1xi64>, tensor<8xi64>) -> tensor<8xi64>
|
||||
// CHECK: return
|
||||
}
|
||||
|
@ -16,5 +16,6 @@ filegroup(
|
||||
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
|
||||
"//tensorflow/compiler/mlir/lite:flatbuffer_translate",
|
||||
"@llvm//:FileCheck",
|
||||
"@llvm//:not",
|
||||
],
|
||||
)
|
||||
|
@ -1,10 +1,10 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
|
||||
# CHECK: loc("disable_builtin.mlir":2:1): is a TFLite builtin op but builtin emission is not enabled
|
||||
# CHECK-NEXT: Verification failed.
|
||||
// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer -emit-builtin-tflite-ops=false %s 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: 'tfl.add' op is a TFLite builtin op but builtin emission is not enabled
|
||||
|
||||
func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
|
||||
^bb0(%arg0: tensor<3x2xi32>):
|
||||
%0 = "std.constant" () {name = "Const2", value = dense<10> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.Add" (%0, %1) {name = "add"} : (tensor<i32>, tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||
%0 = "std.constant"() {name = "Const2", value = dense<10> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tfl.add"(%0, %arg0) {fused_activation_function = "NONE", name = "add"} : (tensor<i32>, tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||
return %1 : tensor<3x2xi32>
|
||||
}
|
||||
|
@ -1,7 +1,8 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
|
||||
// CHECK: error: 'tf.Div' op is neither a custom op nor a flex op
|
||||
// CHECK: error: failed while converting: 'main'
|
||||
// CHECK: Ops that can be supported by the flex runtime (enabled via setting the -emit-select-tf-ops flag): Div.
|
||||
// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer %s 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: error: 'tf.Div' op is neither a custom op nor a flex op
|
||||
// CHECK: error: failed while converting: 'main'
|
||||
// CHECK: Ops that can be supported by the flex runtime (enabled via setting the -emit-select-tf-ops flag): Div.
|
||||
|
||||
func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
^bb0(%arg0: tensor<4xf32>):
|
||||
|
@ -0,0 +1,49 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
|
||||
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||
// CHECK-NEXT: custom_code: "NumericVerify"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "arg0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: type: UINT8,
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "arg1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-NEXT: scale: [ 0.1 ],
|
||||
// CHECK-NEXT: zero_point: [ 0 ]
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: inputs: [ 0, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 0 ],
|
||||
// CHECK-NEXT: operators: [ {
|
||||
// CHECK-NEXT: inputs: [ 1, 0 ],
|
||||
// CHECK-NEXT: outputs: [ ],
|
||||
// CHECK-NEXT: custom_options: [ 205, 204, 204, 61 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: name: "main"
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: description: "MLIR Converted.",
|
||||
// CHECK-NEXT: buffers: [ {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: } ]
|
||||
// CHECK-NEXT:}
|
||||
|
||||
func @main(%arg0: tensor<4xf32>, %arg1: tensor<4x!quant.uniform<u8:f32, 0.1>>) -> tensor<4xf32> {
|
||||
"tfl.NumericVerify"(%arg1, %arg0) {tolerance = 0.1 : f32} : (tensor<4x!quant.uniform<u8:f32, 0.1>>, tensor<4xf32>) -> ()
|
||||
return %arg0 : tensor<4xf32>
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
|
||||
// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - 2>&1 | FileCheck %s
|
||||
|
||||
func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
|
||||
^bb0(%arg0: tensor<3x2xi32>):
|
||||
|
@ -418,6 +418,18 @@ func @QuantizeConcatResToAllRequantizeArg(tensor<1x2x!quant.uniform<u8:f32, 2.0:
|
||||
// CHECK return %6 : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: RequantizeAlreadyQuantizedModel
|
||||
func @RequantizeAlreadyQuantizedModel(%arg0: tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>> {
|
||||
%9 = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 2.0>>
|
||||
%10 = "tfl.concatenation"(%arg0, %9) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, tensor<1x73x73x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>>
|
||||
return %10 : tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>>
|
||||
|
||||
// CHECK: %0 = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 2.000000e+00>>
|
||||
// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<1x73x73x96x!quant.uniform<u8:f32, 1.000000e+00>>} : (tensor<1x73x73x96x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 1.000000e+00>>
|
||||
// CHECK: %2 = "tfl.concatenation"(%arg0, %1) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform<u8:f32, 1.000000e+00>>, tensor<1x73x73x96x!quant.uniform<u8:f32, 1.000000e+00>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.000000e+00>>
|
||||
// CHECK: return %2 : tensor<1x73x73x160x!quant.uniform<u8:f32, 1.000000e+00>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeChain
|
||||
func @QuantizeChain(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x36x16xf32> {
|
||||
^bb0(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>):
|
||||
|
@ -1,4 +1,5 @@
|
||||
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize | FileCheck %s
|
||||
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize -tfl-numeric-verify | FileCheck --check-prefix=DEBUG %s
|
||||
|
||||
// CHECK-LABEL: QuantizeFloatConst
|
||||
func @QuantizeFloatConst() -> tensor<f32> {
|
||||
@ -48,22 +49,31 @@ func @DequantizeAndQuantize() -> tensor<2x2x!quant.uniform<u8:f32, 7.84313725490
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeConv2D
|
||||
// DEBUG-LABEL: QuantizeConv2D
|
||||
func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>> {
|
||||
^bb0(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>):
|
||||
%cst = constant dense<-1.23697901> : tensor<32xf32>
|
||||
%2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x224x224x3xf32>
|
||||
%3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>
|
||||
%4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>) -> tensor<32x3x3x3xf32>
|
||||
%w = constant dense<-1.0> : tensor<32x3x3x3xf32>
|
||||
%3 = "tfl.quantize"(%w) {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.1>>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.1>>
|
||||
%4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.1>>) -> tensor<32x3x3x3xf32>
|
||||
%5 = "tfl.conv_2d"(%2, %4, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
|
||||
%6 = "tfl.quantize"(%5) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
return %6 : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
|
||||
// CHECK: %[[cst0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>, value = dense<-7254> : tensor<32xi32>}
|
||||
// CHECK: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>}
|
||||
// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[cst1]], %[[cst0]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>, tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 0.021826678373682216:151>>, tensor<32x!quant.uniform<i32:f32, 1.7052092479439231E-4>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
// CHECK: %[[cst0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform<i32:f32, 7.812500e-04>>, value = dense<-1583> : tensor<32xi32>}
|
||||
// CHECK: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32, 1.000000e-01>>, value = dense<1> : tensor<32x3x3x3xi8>}
|
||||
// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[cst1]], %[[cst0]])
|
||||
// CHECK: return %[[conv]] : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
}
|
||||
|
||||
// DEBUG: %[[wt:.*]] = constant dense<-1.000000e+00> : tensor<32x3x3x3xf32>
|
||||
// DEBUG: %[[bias:.*]] = constant dense<-1.23697901> : tensor<32xf32>
|
||||
// DEBUG: %[[act:.*]] = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x224x224x3xf32>
|
||||
// DEBUG: %[[f_conv:.*]] = "tfl.conv_2d"(%[[act]], %[[wt]], %[[bias]])
|
||||
// DEBUG: %[[q_conv:.*]] = "tfl.conv_2d"
|
||||
// DEBUG: "tfl.NumericVerify"(%[[q_conv]], %[[f_conv]]) {tolerance = 1.000000e-01 : f32}
|
||||
// DEBUG: return %[[q_conv]] : tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeDepthwiseConv2D
|
||||
func @QuantizeDepthwiseConv2D(tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>> {
|
||||
@ -212,6 +222,7 @@ func @QuantizeMaxPool2D(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeSplit
|
||||
// DEBUG-LABEL: QuantizeSplit
|
||||
func @QuantizeSplit(%arg: tensor<4x!quant.uniform<u8:f32, 1.0>>, %cst: tensor<i32>) -> (tensor<2x!quant.uniform<u8:f32, 1.0>>,tensor<2x!quant.uniform<u8:f32, 1.0>>) {
|
||||
%0 = "tfl.dequantize"(%arg) : (tensor<4x!quant.uniform<u8:f32, 1.0>>) -> tensor<4xf32>
|
||||
%1:2 = "tfl.split"(%cst, %0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
@ -221,6 +232,11 @@ func @QuantizeSplit(%arg: tensor<4x!quant.uniform<u8:f32, 1.0>>, %cst: tensor<i3
|
||||
|
||||
// CHECK: %[[sp:.*]]:2 = "tfl.split"(%arg1, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<4x!quant.uniform<u8:f32, 1.000000e+00>>)
|
||||
// CHECK: return %[[sp]]#0, %[[sp]]#1
|
||||
|
||||
// DEUBG: %[[f_split:.*]]:2 = "tfl.split"
|
||||
// DEUBG: %[[q_split:.*]]:2 = "tfl.split"
|
||||
// DEUBG: "tfl.NumericVerify"(%[[q_split]]#1, %[[f_split]]#1) {tolerance = 1.000000e-01 : f32}
|
||||
// DEUBG: "tfl.NumericVerify"(%[[q_split]]#0, %[[f_split]]#0) {tolerance = 1.000000e-01 : f32}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: QuantizeSplitUnusedResults
|
||||
@ -265,6 +281,7 @@ func @QuantizeMultipleUsers(%arg1: tensor<4x!quant.uniform<u8:f32, 1.0>>) -> (te
|
||||
}
|
||||
|
||||
// CHECK-LABEL: NotQuantizePow
|
||||
// DEBUG-LABEL: NotQuantizePow
|
||||
func @NotQuantizePow(%arg0: tensor<4x!quant.uniform<u8:f32, 1.0>>,
|
||||
%arg1: tensor<4x!quant.uniform<u8:f32, 1.0>>) -> (tensor<4x!quant.uniform<u8:f32, 1.0>>) {
|
||||
%1 = "tfl.dequantize"(%arg0) : (tensor<4x!quant.uniform<u8:f32, 1.0>>) -> tensor<4xf32>
|
||||
@ -279,4 +296,6 @@ func @NotQuantizePow(%arg0: tensor<4x!quant.uniform<u8:f32, 1.0>>,
|
||||
// CHECK-NEXT: %[[pow:.*]] = tfl.pow %[[dq1]], %[[dq2]]
|
||||
// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%[[pow]])
|
||||
// CHECK-NEXT: return %[[q]]
|
||||
|
||||
// DEBUG-NOT: "tfl.NumericVerify"
|
||||
}
|
||||
|
@ -413,13 +413,13 @@ void PreprocessTopoSortGraph(
|
||||
}
|
||||
operation_to_in_degrees->try_emplace(&op, input_ops.size());
|
||||
for (auto* input_op : input_ops) {
|
||||
auto preceeding_op_it = operation_to_outputs->find(input_op);
|
||||
if (preceeding_op_it == operation_to_outputs->end()) {
|
||||
auto preceding_op_it = operation_to_outputs->find(input_op);
|
||||
if (preceding_op_it == operation_to_outputs->end()) {
|
||||
auto result = operation_to_outputs->try_emplace(
|
||||
input_op, llvm::DenseSet<Operation*>());
|
||||
preceeding_op_it = result.first;
|
||||
preceding_op_it = result.first;
|
||||
}
|
||||
preceeding_op_it->second.insert(&op);
|
||||
preceding_op_it->second.insert(&op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -233,6 +233,18 @@ def : Pat<(TF_FakeQuantWithMinMaxVarsOp $inputs,
|
||||
(ConvertToQuantTypeFromAttrs $inputs, $min, $max,
|
||||
$num_bits, $narrow_range)))>;
|
||||
|
||||
// TODO(rocky): Not all of the attributes are handled correctly. Make this
|
||||
// more general if there is a need.
|
||||
def : Pat<(TF_QuantizeAndDequantizeV2Op $inputs,
|
||||
(ConstantOp F32ElementsAttr:$min),
|
||||
(ConstantOp F32ElementsAttr:$max),
|
||||
$signed_input, $num_bits, $range_given, $round_mode,
|
||||
$narrow_range, $axis),
|
||||
(TFL_DequantizeOp
|
||||
(TFL_QuantizeOp $inputs,
|
||||
(ConvertToQuantTypeFromAttrs $inputs, $min, $max,
|
||||
$num_bits, $narrow_range)))>;
|
||||
|
||||
def : Pat<(TF_RankOp $input), (TFL_RankOp $input)>;
|
||||
|
||||
def : Pat<(TF_SquaredDifferenceOp $l, $r), (TFL_SquaredDifferenceOp $l, $r)>;
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
// constant folding support for the TensorFlow ops.
|
||||
|
||||
#include <climits>
|
||||
#include <complex>
|
||||
#include <cstdint>
|
||||
|
||||
#include "llvm/ADT/APInt.h"
|
||||
@ -42,6 +43,13 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
@ -50,6 +58,9 @@ namespace TFL {
|
||||
// The actual LegalizeTF Pass.
|
||||
namespace {
|
||||
|
||||
using xla::Status;
|
||||
using xla::StatusOr;
|
||||
|
||||
// Legalize operations in functions.
|
||||
struct LegalizeTF : public FunctionPass<LegalizeTF> {
|
||||
void runOnFunction() override;
|
||||
@ -80,6 +91,7 @@ DECL_CONVERT_OP(Split);
|
||||
DECL_CONVERT_OP(SplitV);
|
||||
DECL_CONVERT_OP(StridedSlice);
|
||||
DECL_CONVERT_OP(Unpack);
|
||||
DECL_CONVERT_OP(Reciprocal);
|
||||
|
||||
#undef DECL_CONVERT_OP
|
||||
|
||||
@ -383,6 +395,103 @@ PatternMatchResult ConvertTFAssertOp::matchAndRewrite(
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
StatusOr<ConstantOp> CreateConstOpWithSingleValue(PatternRewriter* rewriter,
|
||||
Location loc,
|
||||
ShapedType shaped_type,
|
||||
int value) {
|
||||
Type element_type = shaped_type.getElementType();
|
||||
ShapedType ranked_tensor_type = RankedTensorType::get({1}, element_type);
|
||||
Type type = ranked_tensor_type;
|
||||
Attribute attr;
|
||||
switch (element_type.getKind()) {
|
||||
case mlir::StandardTypes::F16: {
|
||||
auto floatType = mlir::FloatType::getF16(element_type.getContext());
|
||||
auto floatAttr =
|
||||
mlir::FloatAttr::get(floatType, static_cast<float>(value));
|
||||
std::vector<Attribute> floatValues({floatAttr});
|
||||
attr = DenseElementsAttr::get(ranked_tensor_type, floatValues);
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::F32: {
|
||||
attr = DenseElementsAttr::get<float>(ranked_tensor_type,
|
||||
static_cast<float>(value));
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::Complex: {
|
||||
auto etype = element_type.cast<mlir::ComplexType>().getElementType();
|
||||
if (etype.isF32()) {
|
||||
auto dialect = etype.getContext()->getRegisteredDialect("tf");
|
||||
tensorflow::TensorProto repr;
|
||||
repr.set_dtype(tensorflow::DT_COMPLEX64);
|
||||
|
||||
tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
|
||||
shape->set_unknown_rank(false);
|
||||
shape->add_dim()->set_size(int64_t{1});
|
||||
std::string content;
|
||||
auto complex_value =
|
||||
std::complex<float>(static_cast<float>(value), 0.0f);
|
||||
content.assign(reinterpret_cast<const char*>(&complex_value),
|
||||
sizeof(complex_value));
|
||||
repr.set_tensor_content(content);
|
||||
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
|
||||
|
||||
attr =
|
||||
mlir::OpaqueElementsAttr::get(dialect, ranked_tensor_type, mangled);
|
||||
break;
|
||||
}
|
||||
return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type");
|
||||
}
|
||||
case mlir::StandardTypes::Integer: {
|
||||
const auto& itype = element_type.cast<mlir::IntegerType>();
|
||||
switch (itype.getWidth()) {
|
||||
case 8:
|
||||
attr = DenseElementsAttr::get<int8_t>(ranked_tensor_type,
|
||||
static_cast<int8_t>(value));
|
||||
break;
|
||||
case 16:
|
||||
attr = DenseElementsAttr::get<int16_t>(ranked_tensor_type,
|
||||
static_cast<int16_t>(value));
|
||||
break;
|
||||
case 32:
|
||||
attr = DenseElementsAttr::get<int32_t>(ranked_tensor_type,
|
||||
static_cast<int32_t>(value));
|
||||
break;
|
||||
case 64:
|
||||
attr = DenseElementsAttr::get<int64_t>(ranked_tensor_type,
|
||||
static_cast<int64_t>(value));
|
||||
break;
|
||||
default:
|
||||
return Status(tensorflow::error::INVALID_ARGUMENT,
|
||||
"Unsupported type");
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type");
|
||||
}
|
||||
return rewriter->create<ConstantOp>(loc, type, attr);
|
||||
}
|
||||
|
||||
PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_reciprocal_op = cast<TF::ReciprocalOp>(op);
|
||||
|
||||
auto status_or_const_op = CreateConstOpWithSingleValue(
|
||||
&rewriter, op->getLoc(),
|
||||
tf_reciprocal_op.x()->getType().cast<ShapedType>(), 1);
|
||||
if (!status_or_const_op.ok()) {
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
StringAttr fused_activation_function =
|
||||
StringAttr::get("NONE", rewriter.getContext());
|
||||
|
||||
rewriter.replaceOpWithNewOp<TFL::DivOp>(op, status_or_const_op.ValueOrDie(),
|
||||
tf_reciprocal_op.x(),
|
||||
fused_activation_function);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
void LegalizeTF::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto* ctx = &getContext();
|
||||
@ -390,12 +499,11 @@ void LegalizeTF::runOnFunction() {
|
||||
|
||||
// Add the generated patterns to the list.
|
||||
populateWithGenerated(ctx, &patterns);
|
||||
patterns
|
||||
.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
|
||||
ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op, ConvertTFPackOp,
|
||||
ConvertTFReshapeOp, ConvertTFSplitOp, ConvertTFSplitVOp,
|
||||
ConvertTFStridedSliceOp, ConvertTFUnpackOp, ConvertTFAssertOp>(
|
||||
ctx);
|
||||
patterns.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
|
||||
ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op,
|
||||
ConvertTFPackOp, ConvertTFReshapeOp, ConvertTFSplitOp,
|
||||
ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp,
|
||||
ConvertTFAssertOp, ConvertTFReciprocalOp>(ctx);
|
||||
applyPatternsGreedily(func, patterns);
|
||||
}
|
||||
|
||||
|
@ -71,9 +71,7 @@ class TensorListPatternRewriter : public PatternRewriter {
|
||||
explicit TensorListPatternRewriter(FuncOp fn)
|
||||
: PatternRewriter(fn.getContext()) {}
|
||||
|
||||
Operation *createOperation(const OperationState &state) override {
|
||||
return OpBuilder::createOperation(state);
|
||||
}
|
||||
Operation *insert(Operation *op) override { return OpBuilder::insert(op); }
|
||||
};
|
||||
|
||||
/// Lower TensorList ops in functions for subsequent legalization.
|
||||
|
@ -394,14 +394,14 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
||||
// w * (x ' c) + b => (w ' c) x + b
|
||||
// so we have to update the weight.
|
||||
bool is_mul = llvm::isa<MulOp>(binary_op);
|
||||
auto new_fitler =
|
||||
auto new_filter =
|
||||
filter_cst.mapValues(filter_type.getElementType(), [&](APFloat it) {
|
||||
return (is_mul ? it * cst_value : it / cst_value).bitcastToAPInt();
|
||||
});
|
||||
// We recreate the constant op in case it is shared by the other ops. This
|
||||
// might increase the model size.
|
||||
auto new_filter_op = rewriter.create<ConstOp>(
|
||||
fc_op.getLoc(), filter->getType(), new_fitler);
|
||||
fc_op.getLoc(), filter->getType(), new_filter);
|
||||
fc_op.setOperand(0, binary_op->getOperand(0));
|
||||
if (fc_op.filter() != filter) {
|
||||
// This filter goes through quantize and dequantize ops. Then we just
|
||||
|
@ -132,8 +132,8 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
|
||||
|
||||
// Erases functions from the given candidates that are not referenced by any of
|
||||
// the ops in the module.
|
||||
static void EraseDeadFuncs(const FuncSet& candiate_funcs, ModuleOp module) {
|
||||
if (candiate_funcs.empty()) return;
|
||||
static void EraseDeadFuncs(const FuncSet& candidate_funcs, ModuleOp module) {
|
||||
if (candidate_funcs.empty()) return;
|
||||
|
||||
SymbolTable manager(module);
|
||||
|
||||
@ -149,7 +149,7 @@ static void EraseDeadFuncs(const FuncSet& candiate_funcs, ModuleOp module) {
|
||||
}
|
||||
});
|
||||
|
||||
for (FuncOp func : candiate_funcs) {
|
||||
for (FuncOp func : candidate_funcs) {
|
||||
if (!in_use_funcs.count(func)) manager.erase(func);
|
||||
}
|
||||
}
|
||||
|
@ -132,7 +132,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
|
||||
|
||||
int quant_dim = -1;
|
||||
if (PerAxis) {
|
||||
// This is a special case that the quant_dim is the last dimentions.
|
||||
// This is a special case that the quant_dim is the last dimensions.
|
||||
quant_dim = res->getType().template cast<ShapedType>().getRank() - 1;
|
||||
}
|
||||
// Use the min/max from the operands and the num_bits and narrow_range
|
||||
|
@ -35,6 +35,26 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<bool> enable_numeric_verify(
|
||||
"tfl-numeric-verify", llvm::cl::value_desc("bool"),
|
||||
llvm::cl::desc("Whether verify numericals at runtime."),
|
||||
llvm::cl::init(false));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<float> error_tolerance(
|
||||
"tfl-error-tolerance", llvm::cl::value_desc("float"),
|
||||
llvm::cl::desc("Error tolerance for numeric verify. Valid when "
|
||||
"`-tfl-numeric-verify` is set."),
|
||||
llvm::cl::init(1e-1f));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<bool> enable_single_layer_verify(
|
||||
"tfl-single-layer-verify", llvm::cl::value_desc("bool"),
|
||||
llvm::cl::desc("Whether verify numericals layer by layer. Valid when "
|
||||
"`-tfl-numeric-verify` is set."),
|
||||
llvm::cl::init(false));
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
@ -45,9 +65,11 @@ namespace {
|
||||
|
||||
// Full integer quantization rewrite pattern for TFLite.
|
||||
struct TFLFullQuantization
|
||||
: public QuantizationPattern<TFLFullQuantization, QuantizeOp,
|
||||
DequantizeOp> {
|
||||
explicit TFLFullQuantization(MLIRContext* ctx) : BaseType(ctx) {}
|
||||
: public QuantizationPattern<TFLFullQuantization, QuantizeOp, DequantizeOp,
|
||||
NumericVerifyOp> {
|
||||
explicit TFLFullQuantization(MLIRContext* ctx, bool verify_numeric,
|
||||
float tolerance, bool verify_single_layer)
|
||||
: BaseType(ctx, verify_numeric, tolerance, verify_single_layer) {}
|
||||
static bool AllowHybridOperand() { return false; }
|
||||
static bool AllowHybridResult() { return false; }
|
||||
};
|
||||
@ -64,7 +86,8 @@ void QuantizePass::runOnFunction() {
|
||||
auto func = getFunction();
|
||||
auto* ctx = func.getContext();
|
||||
TFL::populateWithGenerated(ctx, &patterns);
|
||||
patterns.insert<TFLFullQuantization>(ctx);
|
||||
patterns.insert<TFLFullQuantization>(
|
||||
ctx, enable_numeric_verify, error_tolerance, enable_single_layer_verify);
|
||||
applyPatternsGreedily(func, patterns);
|
||||
}
|
||||
} // namespace
|
||||
|
@ -98,7 +98,7 @@ Value* SliceRankedTensor(OpBuilder* builder, Value* input,
|
||||
ArrayRef<int64_t> size_values,
|
||||
mlir::Location location) {
|
||||
// If the size of the tensor to be sliced from the input overflows
|
||||
// the input tensor's dimenions, return 0-valued tensor of the requested
|
||||
// the input tensor's dimensions, return 0-valued tensor of the requested
|
||||
// shape.
|
||||
ArrayRef<int64_t> input_shape = GetRankedTensorShape(input);
|
||||
for (int i = 0; i < input_shape.size(); i++) {
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Value.h" // TF:local_config_mlir
|
||||
@ -71,26 +72,26 @@ llvm::StringRef OpOrArgNameMapper::GetUniqueName(llvm::StringRef prefix) {
|
||||
}
|
||||
}
|
||||
|
||||
llvm::StringRef OpOrArgNameMapper::GetUniqueName(OpOrArg op_or_arg) {
|
||||
auto& name = op_or_arg_to_name_[op_or_arg];
|
||||
llvm::StringRef OpOrArgNameMapper::GetUniqueName(OpOrVal op_or_val) {
|
||||
auto& name = op_or_val_to_name_[op_or_val];
|
||||
if (!name.empty()) return StringViewToRef(name);
|
||||
// Update the value in the map with unique name.
|
||||
llvm::StringRef ref = GetUniqueName(GetName(op_or_arg));
|
||||
llvm::StringRef ref = GetUniqueName(GetName(op_or_val));
|
||||
name = StringRefToView(ref);
|
||||
return ref;
|
||||
}
|
||||
|
||||
absl::string_view OpOrArgNameMapper::GetUniqueNameView(OpOrArg op_or_arg) {
|
||||
auto& name = op_or_arg_to_name_[op_or_arg];
|
||||
absl::string_view OpOrArgNameMapper::GetUniqueNameView(OpOrVal op_or_val) {
|
||||
auto& name = op_or_val_to_name_[op_or_val];
|
||||
if (!name.empty()) return name;
|
||||
// Update the value in the map with unique name.
|
||||
name = StringRefToView(GetUniqueName(GetName(op_or_arg)));
|
||||
name = StringRefToView(GetUniqueName(GetName(op_or_val)));
|
||||
return name;
|
||||
}
|
||||
|
||||
int OpOrArgNameMapper::InitOpName(OpOrArg op_or_arg, llvm::StringRef name) {
|
||||
int OpOrArgNameMapper::InitOpName(OpOrVal op_or_val, llvm::StringRef name) {
|
||||
auto it = name_to_count_.try_emplace(name, 0);
|
||||
op_or_arg_to_name_[op_or_arg] = StringRefToView(it.first->first());
|
||||
op_or_val_to_name_[op_or_val] = StringRefToView(it.first->first());
|
||||
return it.first->second++;
|
||||
}
|
||||
|
||||
@ -139,22 +140,31 @@ std::string GetNameFromLoc(mlir::Location loc) {
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
std::string OpOrArgLocNameMapper::GetName(OpOrArg op_or_arg) {
|
||||
if (auto* op = op_or_arg.dyn_cast<mlir::Operation*>()) {
|
||||
std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) {
|
||||
if (auto* op = op_or_val.dyn_cast<mlir::Operation*>()) {
|
||||
auto name_from_loc = GetNameFromLoc(op->getLoc());
|
||||
if (!name_from_loc.empty()) return name_from_loc;
|
||||
// If the location is none of the expected types, then simply use name
|
||||
// generated using the op type.
|
||||
return op->getName().getStringRef();
|
||||
}
|
||||
|
||||
if (auto* arg = op_or_arg.dyn_cast<mlir::BlockArgument*>())
|
||||
return GetNameFromLoc(arg->getLoc());
|
||||
|
||||
auto* val = op_or_val.dyn_cast<mlir::Value*>();
|
||||
auto name_from_loc = GetNameFromLoc(val->getLoc());
|
||||
if (!name_from_loc.empty()) return name_from_loc;
|
||||
// If the location is none of the expected types, then simply use name
|
||||
// generated using the op type. Follow TF convention and append the result
|
||||
// index unless 0.
|
||||
if (auto* result = llvm::dyn_cast<mlir::OpResult>(val)) {
|
||||
if (result->getResultNumber() > 0)
|
||||
return llvm::formatv("{0}:{1}",
|
||||
result->getOwner()->getName().getStringRef(),
|
||||
result->getResultNumber());
|
||||
return result->getOwner()->getName().getStringRef();
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string OpOrArgStripNameMapper::GetName(OpOrArg op_or_arg) {
|
||||
std::string OpOrArgStripNameMapper::GetName(OpOrVal op_or_val) {
|
||||
return llvm::APInt(32, count_++).toString(/*Radix=*/36, /*Signed=*/false);
|
||||
}
|
||||
|
||||
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_OP_OR_ARG_NAME_MAPPER_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_OP_OR_ARG_NAME_MAPPER_H_
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_op_or_val_NAME_MAPPER_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_op_or_val_NAME_MAPPER_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
@ -28,28 +28,29 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// PointerUnion for operation and argument.
|
||||
using OpOrArg = llvm::PointerUnion<mlir::Operation*, mlir::BlockArgument*>;
|
||||
// PointerUnion for operation and value.
|
||||
// TODO(jpienaar): Rename the files.
|
||||
using OpOrVal = llvm::PointerUnion<mlir::Operation*, mlir::Value*>;
|
||||
|
||||
// Mapper from operation or argument to name.
|
||||
// Mapper from operation or value to name.
|
||||
class OpOrArgNameMapper {
|
||||
public:
|
||||
// Returns unique name for the given prefix.
|
||||
llvm::StringRef GetUniqueName(llvm::StringRef prefix);
|
||||
|
||||
// Returns unique name for the operation or argument.
|
||||
llvm::StringRef GetUniqueName(OpOrArg op_or_arg);
|
||||
// Returns unique name for the operation or value.
|
||||
llvm::StringRef GetUniqueName(OpOrVal op_or_val);
|
||||
|
||||
// Returns unique name as a string_view for the operation or argument.
|
||||
absl::string_view GetUniqueNameView(OpOrArg op_or_arg);
|
||||
// Returns unique name as a string_view for the operation or value.
|
||||
absl::string_view GetUniqueNameView(OpOrVal op_or_val);
|
||||
|
||||
// Initializes operation or argument to map to name. Returns number of
|
||||
// operations or arguments already named 'name' which should be 0 else
|
||||
// Initializes operation or value to map to name. Returns number of
|
||||
// operations or value already named 'name' which should be 0 else
|
||||
// GetUniqueName could return the same names for different operations or
|
||||
// arguments.
|
||||
// values.
|
||||
// Note: Its up to the caller to decide the behavior when assigning two
|
||||
// operations or arguments to the same name.
|
||||
int InitOpName(OpOrArg op_or_arg, llvm::StringRef name);
|
||||
// operations or values to the same name.
|
||||
int InitOpName(OpOrVal op_or_val, llvm::StringRef name);
|
||||
|
||||
virtual ~OpOrArgNameMapper();
|
||||
|
||||
@ -59,35 +60,35 @@ class OpOrArgNameMapper {
|
||||
virtual bool IsUnique(llvm::StringRef name);
|
||||
|
||||
// Returns a constant view of the underlying map.
|
||||
const llvm::DenseMap<OpOrArg, absl::string_view>& GetMap() const {
|
||||
return op_or_arg_to_name_;
|
||||
const llvm::DenseMap<OpOrVal, absl::string_view>& GetMap() const {
|
||||
return op_or_val_to_name_;
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns name from the location of the operation or argument.
|
||||
virtual std::string GetName(OpOrArg op_or_arg) = 0;
|
||||
// Returns name from the location of the operation or value.
|
||||
virtual std::string GetName(OpOrVal op_or_val) = 0;
|
||||
|
||||
// Maps string name to count. This map is used to help keep track of unique
|
||||
// names for operations or arguments.
|
||||
// names for operations or values.
|
||||
llvm::StringMap<int64_t> name_to_count_;
|
||||
// Maps operation or argument to name. Value in map is a view of the string
|
||||
// Maps operation or values to name. Value in map is a view of the string
|
||||
// name in `name_to_count_`. Names in `name_to_count_` are never removed.
|
||||
llvm::DenseMap<OpOrArg, absl::string_view> op_or_arg_to_name_;
|
||||
llvm::DenseMap<OpOrVal, absl::string_view> op_or_val_to_name_;
|
||||
};
|
||||
|
||||
// OpOrArgNameMapper that returns, for operations or arguments not initialized
|
||||
// OpOrArgNameMapper that returns, for operations or values not initialized
|
||||
// to a specific name, a name based on the location of the operation or
|
||||
// argument.
|
||||
// value.
|
||||
class OpOrArgLocNameMapper : public OpOrArgNameMapper {
|
||||
private:
|
||||
std::string GetName(OpOrArg op_or_arg) override;
|
||||
std::string GetName(OpOrVal op_or_val) override;
|
||||
};
|
||||
|
||||
// OpOrArgNameMapper that returns, for operations or arguments not initialized
|
||||
// OpOrArgNameMapper that returns, for operations or values not initialized
|
||||
// to a specific name, a short name.
|
||||
class OpOrArgStripNameMapper : public OpOrArgNameMapper {
|
||||
private:
|
||||
std::string GetName(OpOrArg op_or_arg) override;
|
||||
std::string GetName(OpOrVal op_or_val) override;
|
||||
|
||||
// Number of ops mapped.
|
||||
int count_ = 0;
|
||||
@ -95,4 +96,4 @@ class OpOrArgStripNameMapper : public OpOrArgNameMapper {
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_OP_OR_ARG_NAME_MAPPER_H_
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_op_or_val_NAME_MAPPER_H_
|
||||
|
@ -11,7 +11,6 @@ package_group(
|
||||
includes = ["@local_config_mlir//:subpackages"],
|
||||
packages = [
|
||||
"//tensorflow/compiler/...",
|
||||
"//tensorflow/core/tfrt_delegate/...",
|
||||
"//tensorflow/python/...",
|
||||
],
|
||||
)
|
||||
@ -176,6 +175,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/lite:validators",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:Analysis",
|
||||
"@local_config_mlir//:CallOpInterfacesIncGen",
|
||||
@ -252,6 +252,7 @@ cc_library(
|
||||
"transforms/sink_constant.cc",
|
||||
"transforms/test_side_effect_analysis.cc",
|
||||
"transforms/tpu_cluster_formation.cc",
|
||||
"transforms/tpu_dynamic_padding_mapper.cc",
|
||||
"transforms/tpu_merge_variables_with_execute.cc",
|
||||
"transforms/tpu_rewrite_pass.cc",
|
||||
"translate/breakup-islands.cc",
|
||||
@ -921,6 +922,7 @@ cc_library(
|
||||
":lower_tf_inc_gen",
|
||||
":tensorflow",
|
||||
"//tensorflow/core:framework",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -30,11 +30,14 @@ limitations under the License.
|
||||
#include "mlir/IR/Block.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Value.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
||||
@ -75,6 +78,37 @@ int64_t GetOrCreateIdForVarHandle(TF::VarHandleOp handle, int64_t* next_id,
|
||||
return emplace_res.first->second;
|
||||
}
|
||||
|
||||
// If the return value for `func_op` at `return_index` is a pass-through of an
|
||||
// argument of this function, returns the argument index; otherwise, returns -1.
|
||||
int64_t FindPassthroughArgumentForReturnValue(int64_t return_index,
|
||||
FuncOp func_op) {
|
||||
auto value =
|
||||
func_op.getBody().front().getTerminator()->getOperand(return_index);
|
||||
assert(mlir::getElementTypeOrSelf(value->getType()).isa<TF::ResourceType>());
|
||||
int64_t arg_index = -1;
|
||||
auto try_parse_arg_index = [&arg_index](Value* v) {
|
||||
auto resource_arg = llvm::dyn_cast<BlockArgument>(v);
|
||||
if (resource_arg) arg_index = resource_arg->getArgNumber();
|
||||
return arg_index;
|
||||
};
|
||||
while (try_parse_arg_index(value) == -1) {
|
||||
auto op = value->getDefiningOp();
|
||||
assert(op);
|
||||
int64_t res_num = llvm::dyn_cast<OpResult>(value)->getResultNumber();
|
||||
if (auto graph = llvm::dyn_cast<tf_executor::GraphOp>(op)) {
|
||||
value = graph.GetFetch().getOperand(res_num);
|
||||
} else if (auto island = llvm::dyn_cast<tf_executor::IslandOp>(op)) {
|
||||
value = island.GetYield().getOperand(res_num);
|
||||
} else if (llvm::isa<TF::IdentityNOp>(op) ||
|
||||
llvm::isa<TF::IdentityOp>(op)) {
|
||||
value = op->getOperand(res_num);
|
||||
} else {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
return arg_index;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ResourceAliasAnalysis::ResourceAliasAnalysis(Operation* op) {
|
||||
@ -108,7 +142,8 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
|
||||
result_ids.insert(operand_it->getSecond().begin(),
|
||||
operand_it->getSecond().end());
|
||||
};
|
||||
// TODO(yuanzx): Consider control-flow ops.
|
||||
auto module = func_op.getParentOfType<ModuleOp>();
|
||||
|
||||
func_op.walk([&](Operation* op) {
|
||||
if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
|
||||
resource_value_to_ids_[var_handle.resource()].insert(
|
||||
@ -122,7 +157,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
|
||||
std::get<1>(operand_and_result));
|
||||
}
|
||||
} else if (auto replicate = llvm::dyn_cast<tf_device::ReplicateOp>(op)) {
|
||||
// The nested block for RepliateOp is handled separately in side-effect
|
||||
// The nested block for ReplicateOp is handled separately in side-effect
|
||||
// analysis. Inside that block, we can still treat its block arguments as
|
||||
// different resources.
|
||||
for (auto arg : replicate.GetBody().getArguments()) {
|
||||
@ -131,6 +166,49 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
|
||||
resource_value_to_ids_[arg].insert(next_unique_id++);
|
||||
}
|
||||
}
|
||||
} else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) {
|
||||
auto body = llvm::cast<FuncOp>(module.lookupSymbol(while_op.body()));
|
||||
// If a result is a passthrough of the body input, use the corresponding
|
||||
// operand's resource IDs.
|
||||
for (auto result : llvm::enumerate(while_op.getResults())) {
|
||||
if (!mlir::getElementTypeOrSelf(result.value()->getType())
|
||||
.isa<TF::ResourceType>()) {
|
||||
continue;
|
||||
}
|
||||
int64_t passthrough_operand =
|
||||
FindPassthroughArgumentForReturnValue(result.index(), body);
|
||||
if (passthrough_operand >= 0) {
|
||||
forward_input_to_output(while_op.getOperand(passthrough_operand),
|
||||
result.value());
|
||||
} else {
|
||||
resource_value_to_ids_[result.value()].insert(kUnknownResourceId);
|
||||
}
|
||||
}
|
||||
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(op)) {
|
||||
auto then_branch =
|
||||
llvm::cast<FuncOp>(module.lookupSymbol(if_op.then_branch()));
|
||||
auto else_branch =
|
||||
llvm::cast<FuncOp>(module.lookupSymbol(if_op.else_branch()));
|
||||
// If a result is a passthrough of both branches' inputs, merge the
|
||||
// resource IDs of corresponding operands for the two inputs.
|
||||
for (auto result : llvm::enumerate(if_op.getResults())) {
|
||||
if (!mlir::getElementTypeOrSelf(result.value()->getType())
|
||||
.isa<TF::ResourceType>()) {
|
||||
continue;
|
||||
}
|
||||
int64_t passthrough_then_arg =
|
||||
FindPassthroughArgumentForReturnValue(result.index(), then_branch);
|
||||
int64_t passthrough_else_arg =
|
||||
FindPassthroughArgumentForReturnValue(result.index(), else_branch);
|
||||
if (passthrough_then_arg >= 0 && passthrough_else_arg >= 0) {
|
||||
forward_input_to_output(if_op.getOperand(passthrough_then_arg + 1),
|
||||
result.value());
|
||||
forward_input_to_output(if_op.getOperand(passthrough_else_arg + 1),
|
||||
result.value());
|
||||
} else {
|
||||
resource_value_to_ids_[result.value()].insert(kUnknownResourceId);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (auto result : op->getResults()) {
|
||||
if (!mlir::getElementTypeOrSelf(result->getType())
|
||||
@ -223,6 +301,18 @@ bool OpIsDeclaration(Operation* op,
|
||||
!FindAccessedResources(op, alias_analysis).empty());
|
||||
}
|
||||
|
||||
// Returns if `op` is know to not have any side effect.
|
||||
bool OpIsKnownToHaveNoSideEffect(Operation* op) {
|
||||
if (op->hasNoSideEffect()) return true;
|
||||
if (auto if_op = llvm::dyn_cast<TF::IfOp>(op)) {
|
||||
return if_op.is_stateless();
|
||||
}
|
||||
if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) {
|
||||
return while_op.is_stateless();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void SideEffectAnalysis::TrackAccess(int64_t resource_id, Operation* op,
|
||||
@ -242,11 +332,15 @@ void SideEffectAnalysis::TrackAccess(int64_t resource_id, Operation* op,
|
||||
auto& info = per_resource_access_info_[resource_id];
|
||||
if (read_only) {
|
||||
info.reads_since_last_write.push_back(op);
|
||||
// Resource read must have carried control dependencies of unknown write.
|
||||
info.tracked_last_unknown_write = true;
|
||||
// Resource read must have carried control dependencies of unknown write. It
|
||||
// can only avoid adding control edges (from uknown accesses) for a later
|
||||
// write, but not for a later read, because this read can be reordered with
|
||||
// a later read.
|
||||
info.tracked_last_unknown_write_for_write = true;
|
||||
} else {
|
||||
// Resource write must have carried control dependencies of unknown access.
|
||||
info.tracked_last_unknown_write = true;
|
||||
info.tracked_last_unknown_write_for_read = true;
|
||||
info.tracked_last_unknown_write_for_write = true;
|
||||
info.tracked_last_unknown_read = true;
|
||||
info.last_write = op;
|
||||
info.reads_since_last_write.clear();
|
||||
@ -305,7 +399,7 @@ void SideEffectAnalysis::AnalyzeRegion(
|
||||
// region, and tracking resource accesses in per_resource_access_info_.
|
||||
|
||||
// Returns whether an access to `resource` can skip control edges from
|
||||
// prevoius accesses to unknown resources, due to that earlier accesses to
|
||||
// previous accesses to unknown resources, due to that earlier accesses to
|
||||
// `resource` already indirectly tracked previous accesses to uknown
|
||||
// resources. `read_only` specifies the type of access of the current op being
|
||||
// considered.
|
||||
@ -318,8 +412,8 @@ void SideEffectAnalysis::AnalyzeRegion(
|
||||
unknown_it == per_resource_access_info_.end() ||
|
||||
unknown_it->getSecond().reads_since_last_write.empty();
|
||||
return read_only
|
||||
? it->second.tracked_last_unknown_write
|
||||
: it->second.tracked_last_unknown_write &&
|
||||
? it->second.tracked_last_unknown_write_for_read
|
||||
: it->second.tracked_last_unknown_write_for_write &&
|
||||
(it->second.tracked_last_unknown_read || no_unknown_read);
|
||||
};
|
||||
|
||||
@ -340,7 +434,7 @@ void SideEffectAnalysis::AnalyzeRegion(
|
||||
if (OpIsDeclaration(&op, alias_analysis)) continue;
|
||||
|
||||
auto resource_op_info = GetResourceInfoForOp(&op);
|
||||
if (!resource_op_info && op.hasNoSideEffect()) continue;
|
||||
if (!resource_op_info && OpIsKnownToHaveNoSideEffect(&op)) continue;
|
||||
|
||||
llvm::SmallDenseSet<int64_t, 8> resources =
|
||||
resource_op_info ? FindAccessedResources(&op, alias_analysis)
|
||||
|
@ -105,7 +105,7 @@ class SideEffectAnalysis {
|
||||
void ConsumeChildAnalyses(
|
||||
llvm::SmallVector<SideEffectAnalysis, 4>&& children);
|
||||
|
||||
// Updates control_predecessors_ for `op` that is being visted, on the given
|
||||
// Updates control_predecessors_ for `op` that is being visited, on the given
|
||||
// `resource_id`.
|
||||
void AddPredecessorsForAccess(int64_t resource_id, Operation* op,
|
||||
bool read_only);
|
||||
@ -124,17 +124,22 @@ class SideEffectAnalysis {
|
||||
sorted_control_successors_;
|
||||
|
||||
// Internal per-resource data structure when we build the dependencies.
|
||||
struct PerResourceAcessInfo {
|
||||
struct PerResourceAccessInfo {
|
||||
// Last op that writes the resource before the current op being analyzed.
|
||||
Operation* last_write = nullptr;
|
||||
// Read ops since last_write before the current op being analyzed.
|
||||
llvm::SmallVector<Operation*, 8> reads_since_last_write;
|
||||
// Whether previous accesses of this resource already tracked last unknown
|
||||
// read/write.
|
||||
// read for the current access being analyzed.
|
||||
bool tracked_last_unknown_read = false;
|
||||
bool tracked_last_unknown_write = false;
|
||||
// Whether previous accesses of this resource already tracked last unknown
|
||||
// write for a the current read being analyzed.
|
||||
bool tracked_last_unknown_write_for_read = false;
|
||||
// Whether previous accesses of this resource already tracked last unknown
|
||||
// write for a the current write being analyzed.
|
||||
bool tracked_last_unknown_write_for_write = false;
|
||||
};
|
||||
llvm::SmallDenseMap<int64_t, PerResourceAcessInfo, 8>
|
||||
llvm::SmallDenseMap<int64_t, PerResourceAccessInfo, 8>
|
||||
per_resource_access_info_;
|
||||
};
|
||||
|
||||
|
@ -1269,6 +1269,93 @@ def TF_DivNoNanOp : TF_Op<"DivNoNan", [Broadcastable, NoSideEffect]>,
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_DynamicStitchOp : TF_Op<"DynamicStitch", [NoSideEffect, SameVariadicOperandSize]> {
|
||||
let summary = [{
|
||||
Interleave the values from the `data` tensors into a single tensor.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Builds a merged tensor such that
|
||||
|
||||
```python
|
||||
merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...]
|
||||
```
|
||||
|
||||
For example, if each `indices[m]` is scalar or vector, we have
|
||||
|
||||
```python
|
||||
# Scalar indices:
|
||||
merged[indices[m], ...] = data[m][...]
|
||||
|
||||
# Vector indices:
|
||||
merged[indices[m][i], ...] = data[m][i, ...]
|
||||
```
|
||||
|
||||
Each `data[i].shape` must start with the corresponding `indices[i].shape`,
|
||||
and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we
|
||||
must have `data[i].shape = indices[i].shape + constant`. In terms of this
|
||||
`constant`, the output shape is
|
||||
|
||||
merged.shape = [max(indices)] + constant
|
||||
|
||||
Values are merged in order, so if an index appears in both `indices[m][i]` and
|
||||
`indices[n][j]` for `(m,i) < (n,j)` the slice `data[n][j]` will appear in the
|
||||
merged result. If you do not need this guarantee, ParallelDynamicStitch might
|
||||
perform better on some devices.
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
indices[0] = 6
|
||||
indices[1] = [4, 1]
|
||||
indices[2] = [[5, 2], [0, 3]]
|
||||
data[0] = [61, 62]
|
||||
data[1] = [[41, 42], [11, 12]]
|
||||
data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]]
|
||||
merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42],
|
||||
[51, 52], [61, 62]]
|
||||
```
|
||||
|
||||
This method can be used to merge partitions created by `dynamic_partition`
|
||||
as illustrated on the following example:
|
||||
|
||||
```python
|
||||
# Apply function (increments x_i) on elements for which a certain condition
|
||||
# apply (x_i != -1 in this example).
|
||||
x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4])
|
||||
condition_mask=tf.not_equal(x,tf.constant(-1.))
|
||||
partitioned_data = tf.dynamic_partition(
|
||||
x, tf.cast(condition_mask, tf.int32) , 2)
|
||||
partitioned_data[1] = partitioned_data[1] + 1.0
|
||||
condition_indices = tf.dynamic_partition(
|
||||
tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2)
|
||||
x = tf.dynamic_stitch(condition_indices, partitioned_data)
|
||||
# Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain
|
||||
# unchanged.
|
||||
```
|
||||
|
||||
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||
<img style="width:100%" src="https://www.tensorflow.org/images/DynamicStitch.png" alt>
|
||||
</div>
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<I32Tensor>:$indices,
|
||||
Variadic<TF_Tensor>:$data
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$merged
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_EinsumOp : TF_Op<"Einsum", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Tensor contraction according to Einstein summation convention.
|
||||
@ -1317,7 +1404,7 @@ Operations are applied to the input(s) according to the following rules:
|
||||
Considering the batch matrix multiplication equation again
|
||||
(`bij,bjk->bik`), the contracted axis label is `j`.
|
||||
|
||||
(e) Expand Diagonal: If the output subcripts contain repeated (explicit) axis
|
||||
(e) Expand Diagonal: If the output subscripts contain repeated (explicit) axis
|
||||
labels, the opposite operation of (a) is applied. For example, in the
|
||||
equation `i->iii`, and input shape `[3]`, the output of shape `[3, 3, 3]`
|
||||
are all zeros, except for the (generalized) diagonal which is populated
|
||||
@ -1325,7 +1412,7 @@ Operations are applied to the input(s) according to the following rules:
|
||||
Note: This operation is not supported by `np.einsum` or `tf.einsum`; it is
|
||||
provided to enable computing the symbolic gradient of `tf.einsum`.
|
||||
|
||||
The output subcripts must contain only labels appearing in at least one of the
|
||||
The output subscripts must contain only labels appearing in at least one of the
|
||||
input subscripts. Furthermore, all dimensions mapping to the same axis label
|
||||
must be equal.
|
||||
|
||||
@ -1337,7 +1424,7 @@ according to standard NumPy broadcasting
|
||||
|
||||
The broadcasted dimensions are placed in the corresponding location of the
|
||||
ellipsis in the output subscript. If the broadcasted dimensions are non-empty
|
||||
and the output subcripts do not contain ellipsis, then an InvalidArgument error
|
||||
and the output subscripts do not contain ellipsis, then an InvalidArgument error
|
||||
is raised.
|
||||
|
||||
@compatibility(numpy)
|
||||
@ -2054,6 +2141,10 @@ See also `tf.batch_gather` and `tf.gather_nd`.
|
||||
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedOperandTypeAttr Tparams = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Taxis = TF_DerivedOperandTypeAttr<2>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_GreaterOp : TF_Op<"Greater", [Broadcastable, NoSideEffect]>,
|
||||
@ -2368,6 +2459,55 @@ def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_LeftShiftOp : TF_Op<"LeftShift", [Broadcastable, NoSideEffect]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Elementwise computes the bitwise left-shift of `x` and `y`.";
|
||||
|
||||
let description = [{
|
||||
If `y` is negative, or greater than or equal to the width of `x` in bits the
|
||||
result is implementation defined.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.ops import bitwise_ops
|
||||
import numpy as np
|
||||
dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64]
|
||||
|
||||
for dtype in dtype_list:
|
||||
lhs = tf.constant([-1, -5, -3, -14], dtype=dtype)
|
||||
rhs = tf.constant([5, 0, 7, 11], dtype=dtype)
|
||||
|
||||
left_shift_result = bitwise_ops.left_shift(lhs, rhs)
|
||||
|
||||
print(left_shift_result)
|
||||
|
||||
# This will print:
|
||||
# tf.Tensor([ -32 -5 -128 0], shape=(4,), dtype=int8)
|
||||
# tf.Tensor([ -32 -5 -384 -28672], shape=(4,), dtype=int16)
|
||||
# tf.Tensor([ -32 -5 -384 -28672], shape=(4,), dtype=int32)
|
||||
# tf.Tensor([ -32 -5 -384 -28672], shape=(4,), dtype=int64)
|
||||
|
||||
lhs = np.array([-2, 64, 101, 32], dtype=np.int8)
|
||||
rhs = np.array([-1, -5, -3, -14], dtype=np.int8)
|
||||
bitwise_ops.left_shift(lhs, rhs)
|
||||
# <tf.Tensor: shape=(4,), dtype=int8, numpy=array([ -2, 64, 101, 32], dtype=int8)>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_IntTensor:$x,
|
||||
TF_IntTensor:$y
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_IntTensor:$z
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_LessOp : TF_Op<"Less", [Broadcastable, NoSideEffect]>,
|
||||
WithBroadcastableCmpOpBuilder {
|
||||
let summary = "Returns the truth value of (x < y) element-wise.";
|
||||
@ -4366,10 +4506,10 @@ def TF_ResourceApplyAdamOp : TF_Op<"ResourceApplyAdam", []> {
|
||||
let summary = "Update '*var' according to the Adam algorithm.";
|
||||
|
||||
let description = [{
|
||||
$$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
|
||||
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
|
||||
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
|
||||
$$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
|
||||
$$\text{lr}_t := \mathrm{learning_rate} * \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)$$
|
||||
$$m_t := \beta_1 * m_{t-1} + (1 - \beta_1) * g$$
|
||||
$$v_t := \beta_2 * v_{t-1} + (1 - \beta_2) * g * g$$
|
||||
$$\text{variable} := \text{variable} - \text{lr}_t * m_t / (\sqrt{v_t} + \epsilon)$$
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
@ -4413,9 +4553,7 @@ def TF_ResourceApplyGradientDescentOp : TF_Op<"ResourceApplyGradientDescent", []
|
||||
}
|
||||
|
||||
def TF_ResourceApplyKerasMomentumOp : TF_Op<"ResourceApplyKerasMomentum", []> {
|
||||
let summary = [{
|
||||
Update '*var' according to the momentum scheme.
|
||||
}];
|
||||
let summary = "Update '*var' according to the momentum scheme.";
|
||||
|
||||
let description = [{
|
||||
Set use_nesterov = True if you want to use Nesterov momentum.
|
||||
@ -4581,6 +4719,58 @@ reverse(t, dims) ==> [[[[8, 9, 10, 11],
|
||||
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
|
||||
}
|
||||
|
||||
def TF_RightShiftOp : TF_Op<"RightShift", [Broadcastable, NoSideEffect]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Elementwise computes the bitwise right-shift of `x` and `y`.";
|
||||
|
||||
let description = [{
|
||||
Performs a logical shift for unsigned integer types, and an arithmetic shift
|
||||
for signed integer types.
|
||||
|
||||
If `y` is negative, or greater than or equal to than the width of `x` in bits
|
||||
the result is implementation defined.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.ops import bitwise_ops
|
||||
import numpy as np
|
||||
dtype_list = [tf.int8, tf.int16, tf.int32, tf.int64]
|
||||
|
||||
for dtype in dtype_list:
|
||||
lhs = tf.constant([-1, -5, -3, -14], dtype=dtype)
|
||||
rhs = tf.constant([5, 0, 7, 11], dtype=dtype)
|
||||
|
||||
right_shift_result = bitwise_ops.right_shift(lhs, rhs)
|
||||
|
||||
print(right_shift_result)
|
||||
|
||||
# This will print:
|
||||
# tf.Tensor([-1 -5 -1 -1], shape=(4,), dtype=int8)
|
||||
# tf.Tensor([-1 -5 -1 -1], shape=(4,), dtype=int16)
|
||||
# tf.Tensor([-1 -5 -1 -1], shape=(4,), dtype=int32)
|
||||
# tf.Tensor([-1 -5 -1 -1], shape=(4,), dtype=int64)
|
||||
|
||||
lhs = np.array([-2, 64, 101, 32], dtype=np.int8)
|
||||
rhs = np.array([-1, -5, -3, -14], dtype=np.int8)
|
||||
bitwise_ops.right_shift(lhs, rhs)
|
||||
# <tf.Tensor: shape=(4,), dtype=int8, numpy=array([ -2, 64, 101, 32], dtype=int8)>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_IntTensor:$x,
|
||||
TF_IntTensor:$y
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_IntTensor:$z
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = [{
|
||||
Rounds the values of a tensor to the nearest integer, element-wise.
|
||||
@ -5122,6 +5312,34 @@ x = [[[[1, 2, 3, 4],
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_SparseSoftmaxCrossEntropyWithLogitsOp : TF_Op<"SparseSoftmaxCrossEntropyWithLogits", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Computes softmax cross entropy cost and gradients to backpropagate.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept
|
||||
a matrix of label probabilities, but rather a single label per row
|
||||
of features. This label is considered to have probability 1.0 for the
|
||||
given row.
|
||||
|
||||
Inputs are the logits, not probabilities.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_FpTensor:$features,
|
||||
TF_I32OrI64Tensor:$labels
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_FpTensor:$loss,
|
||||
TF_FpTensor:$backprop
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tlabels = TF_DerivedOperandTypeAttr<1>;
|
||||
}
|
||||
|
||||
def TF_SparseToDenseOp : TF_Op<"SparseToDense", [NoSideEffect]> {
|
||||
let summary = "Converts a sparse representation into a dense tensor.";
|
||||
|
||||
@ -5455,8 +5673,67 @@ receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
let verifier = [{ return VerifyStridedSliceBase(*this); }];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// If sliced shape is able to be deduced, returns true, updates
|
||||
// `begin_indices`, `end_indices`, and `strides` with their canonical
|
||||
// values, respectively.
|
||||
bool GetSlicedBoundRanges(
|
||||
::llvm::ArrayRef<int64_t> shape,
|
||||
::llvm::SmallVectorImpl<int64_t> *begin_indices,
|
||||
::llvm::SmallVectorImpl<int64_t> *end_indices,
|
||||
::llvm::SmallVectorImpl<int64_t> *strides);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_StridedSliceGradOp : TF_Op<"StridedSliceGrad", [NoSideEffect]> {
|
||||
let summary = "Returns the gradient of `StridedSlice`.";
|
||||
|
||||
let description = [{
|
||||
Since `StridedSlice` cuts out pieces of its `input` which is size
|
||||
`shape`, its gradient will have the same shape (which is passed here
|
||||
as `shape`). The gradient will be zero in any element that the slice
|
||||
does not select.
|
||||
|
||||
Arguments are the same as StridedSliceGrad with the exception that
|
||||
`dy` is the input gradient to be propagated and `shape` is the
|
||||
shape of `StridedSlice`'s `input`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_I32OrI64Tensor:$shape,
|
||||
TF_I32OrI64Tensor:$begin,
|
||||
TF_I32OrI64Tensor:$end,
|
||||
TF_I32OrI64Tensor:$strides,
|
||||
TF_Tensor:$dy,
|
||||
|
||||
DefaultValuedAttr<I64Attr, "0">:$begin_mask,
|
||||
DefaultValuedAttr<I64Attr, "0">:$end_mask,
|
||||
DefaultValuedAttr<I64Attr, "0">:$ellipsis_mask,
|
||||
DefaultValuedAttr<I64Attr, "0">:$new_axis_mask,
|
||||
DefaultValuedAttr<I64Attr, "0">:$shrink_axis_mask
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<4>;
|
||||
TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// If sliced shape is able to be deduced, returns true, updates `shape`
|
||||
// with the final shape after performing StridedSlice, and updates
|
||||
// `begin_indices`, `end_indices`, and `strides` with their canonical
|
||||
// values, respectively.
|
||||
bool GetSlicedShapeAndBoundRanges(
|
||||
::llvm::SmallVectorImpl<int64_t> *shape,
|
||||
::llvm::SmallVectorImpl<int64_t> *begin_indices,
|
||||
::llvm::SmallVectorImpl<int64_t> *end_indices,
|
||||
::llvm::SmallVectorImpl<int64_t> *strides);
|
||||
}];
|
||||
}
|
||||
|
||||
@ -5986,7 +6263,7 @@ def TF_UniqueOp : TF_Op<"Unique", [NoSideEffect]> {
|
||||
let description = [{
|
||||
This operation returns a tensor `y` containing all of the unique elements of `x`
|
||||
sorted in the same order that they occur in `x`; `x` does not need to be sorted.
|
||||
This operation also returns a tensor `idx` the same size as `x` that contains
|
||||
This operation also returns a tensor `idx` the same size as `x` that contains
|
||||
the index of each value of `x` in the unique output `y`. In other words:
|
||||
|
||||
`y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]`
|
||||
@ -6057,6 +6334,205 @@ This is the opposite of `pack`.
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
}
|
||||
|
||||
def TF_UnsortedSegmentMaxOp : TF_Op<"UnsortedSegmentMax", [NoSideEffect]> {
|
||||
let summary = "Computes the maximum along segments of a tensor.";
|
||||
|
||||
let description = [{
|
||||
Read
|
||||
[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
|
||||
for an explanation of segments.
|
||||
|
||||
This operator is similar to the unsorted segment sum operator found
|
||||
[(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
|
||||
Instead of computing the sum over segments, it computes the maximum such that:
|
||||
|
||||
\\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such
|
||||
that `segment_ids[j...] == i`.
|
||||
|
||||
If the maximum is empty for a given segment ID `i`, it outputs the smallest
|
||||
possible value for the specific numeric type,
|
||||
`output[i] = numeric_limits<T>::lowest()`.
|
||||
|
||||
If the given segment ID `i` is negative, then the corresponding value is
|
||||
dropped, and will not be included in the result.
|
||||
|
||||
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||
<img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt>
|
||||
</div>
|
||||
|
||||
For example:
|
||||
|
||||
``` python
|
||||
c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]])
|
||||
tf.unsorted_segment_max(c, tf.constant([0, 1, 0]), num_segments=2)
|
||||
# ==> [[ 4, 3, 3, 4],
|
||||
# [5, 6, 7, 8]]
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_IntOrFpTensor:$data,
|
||||
TF_I32OrI64Tensor:$segment_ids,
|
||||
TF_I32OrI64Tensor:$num_segments
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_IntOrFpTensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tnumsegments = TF_DerivedOperandTypeAttr<2>;
|
||||
|
||||
let verifier = [{ return VerifyUnsortedSegmentReduction(*this); }];
|
||||
}
|
||||
|
||||
def TF_UnsortedSegmentMinOp : TF_Op<"UnsortedSegmentMin", [NoSideEffect]> {
|
||||
let summary = "Computes the minimum along segments of a tensor.";
|
||||
|
||||
let description = [{
|
||||
Read
|
||||
[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
|
||||
for an explanation of segments.
|
||||
|
||||
This operator is similar to the unsorted segment sum operator found
|
||||
[(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
|
||||
Instead of computing the sum over segments, it computes the minimum such that:
|
||||
|
||||
\\(output_i = \min_{j...} data_[j...]\\) where min is over tuples `j...` such
|
||||
that `segment_ids[j...] == i`.
|
||||
|
||||
If the minimum is empty for a given segment ID `i`, it outputs the largest
|
||||
possible value for the specific numeric type,
|
||||
`output[i] = numeric_limits<T>::max()`.
|
||||
|
||||
For example:
|
||||
|
||||
``` python
|
||||
c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]])
|
||||
tf.unsorted_segment_min(c, tf.constant([0, 1, 0]), num_segments=2)
|
||||
# ==> [[ 1, 2, 2, 1],
|
||||
# [5, 6, 7, 8]]
|
||||
```
|
||||
|
||||
If the given segment ID `i` is negative, then the corresponding value is
|
||||
dropped, and will not be included in the result.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_IntOrFpTensor:$data,
|
||||
TF_I32OrI64Tensor:$segment_ids,
|
||||
TF_I32OrI64Tensor:$num_segments
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_IntOrFpTensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tnumsegments = TF_DerivedOperandTypeAttr<2>;
|
||||
|
||||
let verifier = [{ return VerifyUnsortedSegmentReduction(*this); }];
|
||||
}
|
||||
|
||||
def TF_UnsortedSegmentProdOp : TF_Op<"UnsortedSegmentProd", [NoSideEffect]> {
|
||||
let summary = "Computes the product along segments of a tensor.";
|
||||
|
||||
let description = [{
|
||||
Read
|
||||
[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
|
||||
for an explanation of segments.
|
||||
|
||||
This operator is similar to the unsorted segment sum operator found
|
||||
[(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
|
||||
Instead of computing the sum over segments, it computes the product of all
|
||||
entries belonging to a segment such that:
|
||||
|
||||
\\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples
|
||||
`j...` such that `segment_ids[j...] == i`.
|
||||
|
||||
For example:
|
||||
|
||||
``` python
|
||||
c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]])
|
||||
tf.unsorted_segment_prod(c, tf.constant([0, 1, 0]), num_segments=2)
|
||||
# ==> [[ 4, 6, 6, 4],
|
||||
# [5, 6, 7, 8]]
|
||||
```
|
||||
|
||||
If there is no entry for a given segment ID `i`, it outputs 1.
|
||||
|
||||
If the given segment ID `i` is negative, then the corresponding value is
|
||||
dropped, and will not be included in the result.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data,
|
||||
TF_I32OrI64Tensor:$segment_ids,
|
||||
TF_I32OrI64Tensor:$num_segments
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tnumsegments = TF_DerivedOperandTypeAttr<2>;
|
||||
|
||||
let verifier = [{ return VerifyUnsortedSegmentReduction(*this); }];
|
||||
}
|
||||
|
||||
def TF_UnsortedSegmentSumOp : TF_Op<"UnsortedSegmentSum", [NoSideEffect]> {
|
||||
let summary = "Computes the sum along segments of a tensor.";
|
||||
|
||||
let description = [{
|
||||
Read
|
||||
[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
|
||||
for an explanation of segments.
|
||||
|
||||
Computes a tensor such that
|
||||
\\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
|
||||
that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`
|
||||
need not be sorted and need not cover all values in the full
|
||||
range of valid values.
|
||||
|
||||
If the sum is empty for a given segment ID `i`, `output[i] = 0`.
|
||||
If the given segment ID `i` is negative, the value is dropped and will not be
|
||||
added to the sum of the segment.
|
||||
|
||||
`num_segments` should equal the number of distinct segment IDs.
|
||||
|
||||
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||
<img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentSum.png" alt>
|
||||
</div>
|
||||
|
||||
``` python
|
||||
c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]])
|
||||
tf.unsorted_segment_sum(c, tf.constant([0, 1, 0]), num_segments=2)
|
||||
# ==> [[ 5, 5, 5, 5],
|
||||
# [5, 6, 7, 8]]
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data,
|
||||
TF_I32OrI64Tensor:$segment_ids,
|
||||
TF_I32OrI64Tensor:$num_segments
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tnumsegments = TF_DerivedOperandTypeAttr<2>;
|
||||
|
||||
let verifier = [{ return VerifyUnsortedSegmentReduction(*this); }];
|
||||
}
|
||||
|
||||
def TF_VariableShapeOp : TF_Op<"VariableShape", []> {
|
||||
let summary = "Returns the shape of the variable pointed to by `resource`.";
|
||||
|
||||
|
@ -24,6 +24,8 @@ limitations under the License.
|
||||
#include <type_traits>
|
||||
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
@ -726,6 +728,101 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
results.insert<DivWithSqrtDivisor>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DynamicStitchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(DynamicStitchOp op) {
|
||||
if (op.N() < 1) return op.emitOpError("requires attribute N with value >= 1");
|
||||
|
||||
if (RankedTensorType out_ty = op.getType().dyn_cast<RankedTensorType>()) {
|
||||
if (out_ty.getRank() == 0) {
|
||||
return op.emitOpError("requires non scalar output");
|
||||
}
|
||||
}
|
||||
|
||||
llvm::SmallDenseSet<int64_t, 8> index_values;
|
||||
bool all_indices_const = true;
|
||||
int32_t max_index = -1;
|
||||
llvm::Optional<SmallVector<int64_t, 4>> inferred_item_shape;
|
||||
for (auto it : llvm::zip(op.indices(), op.data())) {
|
||||
Value *index = std::get<0>(it);
|
||||
|
||||
DenseIntElementsAttr index_attr;
|
||||
if (matchPattern(index, m_Constant(&index_attr))) {
|
||||
for (int32_t index : index_attr.getValues<int32_t>()) {
|
||||
if (index < 0)
|
||||
return op.emitOpError()
|
||||
<< "requires non-negative index values; found " << index;
|
||||
max_index = std::max(index, max_index);
|
||||
index_values.insert(index);
|
||||
}
|
||||
} else {
|
||||
all_indices_const = false;
|
||||
}
|
||||
|
||||
Value *data = std::get<1>(it);
|
||||
RankedTensorType index_ty = index->getType().dyn_cast<RankedTensorType>();
|
||||
RankedTensorType data_ty = data->getType().dyn_cast<RankedTensorType>();
|
||||
if (!index_ty || !data_ty) continue;
|
||||
|
||||
int64_t index_rank = index_ty.getRank();
|
||||
ArrayRef<int64_t> data_shape = data_ty.getShape();
|
||||
ArrayRef<int64_t> index_shape = index_ty.getShape();
|
||||
if (failed(mlir::verifyCompatibleShape(index_shape,
|
||||
data_shape.take_front(index_rank))))
|
||||
return op.emitOpError() << "requires shape of data with type " << data_ty
|
||||
<< " to have prefix matching with shape of the "
|
||||
"corresponding index type "
|
||||
<< index_ty;
|
||||
|
||||
ArrayRef<int64_t> item_shape = data_shape.drop_front(index_rank);
|
||||
if (!inferred_item_shape) {
|
||||
inferred_item_shape = llvm::to_vector<4>(item_shape);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (failed(mlir::verifyCompatibleShape(item_shape, *inferred_item_shape)))
|
||||
return op.emitOpError() << "has inconsistent shaped data and index "
|
||||
"pairs; inferred item shapes ["
|
||||
<< llvm::makeArrayRef(*inferred_item_shape)
|
||||
<< "] and [" << item_shape << "] don't match";
|
||||
for (int i = 0, e = item_shape.size(); i < e; ++i) {
|
||||
int64_t &inferred_dim = (*inferred_item_shape)[i];
|
||||
int64_t dim = item_shape[i];
|
||||
if (ShapedType::isDynamic(inferred_dim)) inferred_dim = dim;
|
||||
}
|
||||
}
|
||||
|
||||
// If all indices are constants, then verify that they cover all indices in
|
||||
// the range [0, max_index] and the output type is legal.
|
||||
if (all_indices_const) {
|
||||
for (int32_t i = 0; i <= max_index; i++) {
|
||||
if (!index_values.count(i))
|
||||
return op.emitOpError() << "missing index " << i;
|
||||
}
|
||||
|
||||
if (inferred_item_shape) {
|
||||
SmallVector<int64_t, 4> expected_shape;
|
||||
expected_shape.push_back(max_index + 1);
|
||||
expected_shape.append(inferred_item_shape->begin(),
|
||||
inferred_item_shape->end());
|
||||
|
||||
auto out_ty = op.getType().cast<TensorType>();
|
||||
auto expected_out_ty =
|
||||
RankedTensorType::get(expected_shape, out_ty.getElementType());
|
||||
|
||||
if (!AreCastCompatible(out_ty, expected_out_ty)) {
|
||||
return op.emitOpError() << "has invalid output type; should be "
|
||||
"compatible with inferred type "
|
||||
<< expected_out_ty;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EinsumOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -891,6 +988,44 @@ static LogicalResult Verify(FusedBatchNormOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GatherV2Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(GatherV2Op op) {
|
||||
int64_t batch_dims = op.batch_dims().getSExtValue();
|
||||
if (auto ty = op.indices()->getType().dyn_cast<RankedTensorType>()) {
|
||||
int64_t rank = ty.getRank();
|
||||
if (batch_dims > rank || batch_dims < -rank)
|
||||
return op.emitOpError()
|
||||
<< "batch_dims (" << batch_dims << ") must be in range [" << -rank
|
||||
<< ", " << rank + 1 << ")";
|
||||
if (batch_dims < 0) batch_dims += rank;
|
||||
}
|
||||
|
||||
if (!HasRankAtMost(op.axis(), 1))
|
||||
return op.emitOpError("requires axis to have rank at most 1");
|
||||
|
||||
DenseIntElementsAttr axis_attr;
|
||||
if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
|
||||
int64_t axis = (*axis_attr.begin()).getSExtValue();
|
||||
if (auto ty = op.params()->getType().dyn_cast<RankedTensorType>()) {
|
||||
int64_t rank = ty.getRank();
|
||||
if (axis >= rank || axis < -rank)
|
||||
return op.emitOpError() << "axis (" << axis << ") must be in range ["
|
||||
<< -rank << ", " << rank << ")";
|
||||
if (axis < 0) axis += rank;
|
||||
}
|
||||
|
||||
if (batch_dims >= 0 && axis >= 0 && axis < batch_dims) {
|
||||
return op.emitOpError() << "requires axis (" << axis
|
||||
<< ") to be greater than or equal to batch_dims ("
|
||||
<< batch_dims << ")";
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IfOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1752,12 +1887,14 @@ void SumOp::build(Builder *builder, OperationState &result, Value *input,
|
||||
// elements. Here, the number of elements should be less than 32 to support
|
||||
// 32-bit mask attributes.
|
||||
// - None of the strides values are zero.
|
||||
//
|
||||
static LogicalResult Verify(StridedSliceOp op) {
|
||||
// - Ellipsis mask can have at most one bit set.
|
||||
|
||||
template <class OpTy>
|
||||
static LogicalResult VerifyStridedSliceBase(OpTy op) {
|
||||
// Expected size for operands begin, end and strides vector operands.
|
||||
int64_t expected_size = -1;
|
||||
|
||||
for (Value *val : llvm::drop_begin(op.getOperands(), 1)) {
|
||||
for (Value *val : {op.begin(), op.end(), op.strides()}) {
|
||||
auto operand_ty = val->getType().dyn_cast<ShapedType>();
|
||||
if (!operand_ty || !operand_ty.hasStaticShape()) {
|
||||
// TensorFlow constant ops may have non-static shape because the shape is
|
||||
@ -1797,11 +1934,179 @@ static LogicalResult Verify(StridedSliceOp op) {
|
||||
return op.emitOpError("requires non-zero strides");
|
||||
}
|
||||
|
||||
// TODO(hinsu): Validate attributes.
|
||||
// Use bit compares to ensure ellipsis_mask is 0 or a power of 2, i.e. there
|
||||
// exists only no more than one ellipsis.
|
||||
uint32_t ellipsis_mask = op.ellipsis_mask().getZExtValue();
|
||||
if (ellipsis_mask != 0 && !llvm::isPowerOf2_32(ellipsis_mask))
|
||||
return op.emitOpError("cannot have multiple ellipses");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// Clamps the given `val`: returns `low` if `val` is less than `low`; returns
|
||||
// `high` if `high` is less than `val`; otherwise returns `val`.
|
||||
template <class T>
|
||||
constexpr const T &Clamp(const T &val, const T &low, const T &high) {
|
||||
assert(!(high < low));
|
||||
return (val < low) ? low : (high < val) ? high : val;
|
||||
}
|
||||
|
||||
// For the given `input_shape`, calculates the sliced shape using the given
|
||||
// `begin`, `end`, and `stride` ranges and `begin_mask` and `end_mask` masks.
|
||||
// Updates the result back to `input_shape`. At the same time, canonicalizes
|
||||
// `begin`, `end`, and `strides. The calculation follows tf.StridedSlice op
|
||||
// semantics.
|
||||
static void CalculateSlicedShapeAndBoundRanges(
|
||||
MutableArrayRef<int64_t> input_shape, int32_t begin_mask, int32_t end_mask,
|
||||
MutableArrayRef<int64_t> begin, MutableArrayRef<int64_t> end,
|
||||
MutableArrayRef<int64_t> stride) {
|
||||
assert(input_shape.size() <= 32); // Only 32-bit masks are supported.
|
||||
|
||||
// Make sure ranges' ranks are consistent with the input.
|
||||
assert(input_shape.size() == begin.size());
|
||||
assert(input_shape.size() == end.size());
|
||||
assert(input_shape.size() == stride.size());
|
||||
|
||||
for (int i = 0, e = input_shape.size(); i < e; ++i) {
|
||||
if (ShapedType::isDynamic(input_shape[i])) continue;
|
||||
|
||||
int64_t dim_i = input_shape[i];
|
||||
int64_t begin_i = begin[i];
|
||||
int64_t end_i = end[i];
|
||||
int64_t stride_i = stride[i];
|
||||
|
||||
// [0]: mask for begin, [1]: mask for end
|
||||
int64_t masks[] = {begin_mask & (1 << i), end_mask & (1 << i)};
|
||||
// [0]: bound for begin, [1]: bound for end
|
||||
int64_t bounds[] = {stride_i > 0 ? 0 : -1,
|
||||
stride_i > 0 ? dim_i : dim_i - 1};
|
||||
|
||||
// Canonicalizes the given range `point` (begin/end) according to the
|
||||
// current dimension. `c` means case: 0 for begin, 1 for end.
|
||||
auto canonicalize = [&](int64_t point, int c) {
|
||||
if (masks[c]) return stride_i > 0 ? bounds[c] : bounds[(c + 1) & 1];
|
||||
|
||||
// Add dim as offset to negative range point.
|
||||
point = point < 0 ? dim_i + point : point;
|
||||
return Clamp(point, bounds[0], bounds[1]);
|
||||
};
|
||||
|
||||
begin_i = canonicalize(begin_i, 0);
|
||||
end_i = canonicalize(end_i, 1);
|
||||
|
||||
int64_t interval_len = end_i - begin_i;
|
||||
int64_t size_i = 0;
|
||||
// If internal length is zero or has different sign from stride, it's a
|
||||
// degenerated case: we are slicing nothing. Otherwise, calculate the sliced
|
||||
// size.
|
||||
if (interval_len != 0 && (interval_len < 0) == (stride_i < 0))
|
||||
size_i = (interval_len / stride_i) + (interval_len % stride_i != 0);
|
||||
|
||||
input_shape[i] = size_i;
|
||||
begin[i] = begin_i;
|
||||
end[i] = end_i;
|
||||
stride[i] = stride_i;
|
||||
}
|
||||
}
|
||||
|
||||
bool StridedSliceOp::GetSlicedBoundRanges(
|
||||
ArrayRef<int64_t> shape, SmallVectorImpl<int64_t> *begin_indices,
|
||||
SmallVectorImpl<int64_t> *end_indices, SmallVectorImpl<int64_t> *strides) {
|
||||
if (this->ellipsis_mask().getZExtValue() ||
|
||||
this->new_axis_mask().getZExtValue() ||
|
||||
this->shrink_axis_mask().getZExtValue())
|
||||
return false; // TODO(antiagainst): support these masks
|
||||
|
||||
// TODO(hinsu): Support lowering for ops with dynamic begin and end values
|
||||
// when it is possible to derive indices based on mask attributes.
|
||||
DenseIntElementsAttr begin_indices_attr, end_indices_attr, strides_attr;
|
||||
if (!matchPattern(this->begin(), m_Constant(&begin_indices_attr)) ||
|
||||
!matchPattern(this->end(), m_Constant(&end_indices_attr)) ||
|
||||
!matchPattern(this->strides(), m_Constant(&strides_attr)))
|
||||
return false;
|
||||
|
||||
auto input_shape = llvm::to_vector<4>(shape);
|
||||
int rank = input_shape.size();
|
||||
|
||||
begin_indices->clear();
|
||||
begin_indices->reserve(rank);
|
||||
end_indices->clear();
|
||||
end_indices->reserve(rank);
|
||||
strides->clear();
|
||||
strides->reserve(rank);
|
||||
|
||||
for (const APInt &index : begin_indices_attr)
|
||||
begin_indices->push_back(index.getSExtValue());
|
||||
for (const APInt &index : end_indices_attr)
|
||||
end_indices->push_back(index.getSExtValue());
|
||||
for (const APInt &stride : strides_attr)
|
||||
strides->push_back(stride.getSExtValue());
|
||||
|
||||
CalculateSlicedShapeAndBoundRanges(
|
||||
input_shape, this->begin_mask().getZExtValue(),
|
||||
this->end_mask().getZExtValue(), *begin_indices, *end_indices, *strides);
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StridedSliceGradOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(StridedSliceGradOp op) {
|
||||
auto shape_type = op.shape()->getType().dyn_cast<RankedTensorType>();
|
||||
if (shape_type && shape_type.getRank() != 1)
|
||||
return op.emitOpError("'shape' operand must be 1D tensor, but got ")
|
||||
<< shape_type.getRank() << "D tensor";
|
||||
|
||||
if (failed(VerifyStridedSliceBase(op))) return failure();
|
||||
|
||||
// TODO(antiagainst): verify the gradient op.dy()'s shape is consistent with
|
||||
// the sliced type from StridedSlice.
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges(
|
||||
SmallVectorImpl<int64_t> *shape, SmallVectorImpl<int64_t> *begin_indices,
|
||||
SmallVectorImpl<int64_t> *end_indices, SmallVectorImpl<int64_t> *strides) {
|
||||
if (this->ellipsis_mask().getZExtValue() ||
|
||||
this->new_axis_mask().getZExtValue() ||
|
||||
this->shrink_axis_mask().getZExtValue())
|
||||
return false; // TODO(antiagainst): support these masks
|
||||
|
||||
DenseIntElementsAttr shape_attr;
|
||||
DenseIntElementsAttr begin_indices_attr, end_indices_attr, strides_attr;
|
||||
if (!matchPattern(this->shape(), m_Constant(&shape_attr)) ||
|
||||
!matchPattern(this->begin(), m_Constant(&begin_indices_attr)) ||
|
||||
!matchPattern(this->end(), m_Constant(&end_indices_attr)) ||
|
||||
!matchPattern(this->strides(), m_Constant(&strides_attr)))
|
||||
return false;
|
||||
|
||||
int rank = std::distance(shape_attr.begin(), shape_attr.end());
|
||||
|
||||
shape->clear();
|
||||
shape->reserve(rank);
|
||||
begin_indices->clear();
|
||||
begin_indices->reserve(rank);
|
||||
end_indices->clear();
|
||||
end_indices->reserve(rank);
|
||||
strides->clear();
|
||||
strides->reserve(rank);
|
||||
|
||||
for (const APInt &dim : shape_attr) shape->push_back(dim.getSExtValue());
|
||||
for (const APInt &index : begin_indices_attr)
|
||||
begin_indices->push_back(index.getSExtValue());
|
||||
for (const APInt &index : end_indices_attr)
|
||||
end_indices->push_back(index.getSExtValue());
|
||||
for (const APInt &stride : strides_attr)
|
||||
strides->push_back(stride.getSExtValue());
|
||||
|
||||
CalculateSlicedShapeAndBoundRanges(*shape, this->begin_mask().getZExtValue(),
|
||||
this->end_mask().getZExtValue(),
|
||||
*begin_indices, *end_indices, *strides);
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorListReserveOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1943,6 +2248,49 @@ static LogicalResult Verify(UnpackOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Unsorted segment reduction ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <class Op>
|
||||
static LogicalResult VerifyUnsortedSegmentReduction(Op op) {
|
||||
if (!HasRankAtMost(op.num_segments(), 0))
|
||||
return op.emitOpError("number of segments should be a 0-D tensor");
|
||||
|
||||
auto data_type = op.data()->getType().template dyn_cast<RankedTensorType>();
|
||||
auto segment_ids_type =
|
||||
op.segment_ids()->getType().template dyn_cast<RankedTensorType>();
|
||||
if (data_type && segment_ids_type) {
|
||||
if (data_type.getRank() < segment_ids_type.getRank())
|
||||
return op.emitOpError(
|
||||
"requires segment ids rank to be less than or equal to data's rank");
|
||||
|
||||
int index = 0;
|
||||
for (auto shape_pair :
|
||||
llvm::zip_first(segment_ids_type.getShape(), data_type.getShape())) {
|
||||
int64_t segment_id_dim = std::get<0>(shape_pair);
|
||||
int64_t data_dim = std::get<1>(shape_pair);
|
||||
if (!ShapedType::isDynamic(segment_id_dim) &&
|
||||
!ShapedType::isDynamic(data_dim) && segment_id_dim != data_dim)
|
||||
return op.emitOpError(
|
||||
"requires segment ids shape to be a prefix of data shape, "
|
||||
"but dimension #")
|
||||
<< index << " differs: " << segment_id_dim << " vs. "
|
||||
<< data_dim;
|
||||
++index;
|
||||
}
|
||||
}
|
||||
|
||||
DenseIntElementsAttr num_segments_attr;
|
||||
if (matchPattern(op.num_segments(), m_Constant(&num_segments_attr))) {
|
||||
int64_t num_segments = (*num_segments_attr.begin()).getSExtValue();
|
||||
if (num_segments < 0)
|
||||
return op.emitOpError("num of segments cannot be negative");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VariableShapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -15,5 +15,6 @@ filegroup(
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir:tf-opt",
|
||||
"@llvm//:FileCheck",
|
||||
"@llvm//:not",
|
||||
],
|
||||
)
|
||||
|
@ -65,3 +65,162 @@ func @decompose_resource_apply_gradient_descent(%arg0: tensor<f32>) -> () {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that composite tf.ResourceApplyKerasMomentum (non-Nesterov) operation
|
||||
// is decomposed.
|
||||
|
||||
// CHECK-LABEL: func @decompose_resource_apply_keras_momentum_non_nesterov
|
||||
// CHECK-SAME: (%[[LR:.*]]: tensor<f32>, %[[GRAD:.*]]: tensor<f32>, %[[MOMENTUM:.*]]: tensor<f32>)
|
||||
func @decompose_resource_apply_keras_momentum_non_nesterov(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> () {
|
||||
|
||||
// CHECK: %[[VAR_HANDLE:[0-9]*]] = "tf.VarHandleOp"
|
||||
// CHECK: %[[ACCUM_HANDLE:[0-9]*]] = "tf.VarHandleOp"
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
// CHECK: %[[ACCUM:[0-9]*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
|
||||
// CHECK: %[[ACCUM_MOMENTUM:[0-9]*]] = "tf.Mul"(%[[ACCUM]], %[[MOMENTUM]])
|
||||
// CHECK: %[[GRAD_LR:[0-9]*]] = "tf.Mul"(%[[GRAD]], %[[LR]])
|
||||
// CHECK: %[[NEW_ACCUM:[0-9]*]] = "tf.Sub"(%[[ACCUM_MOMENTUM]], %[[GRAD_LR]])
|
||||
// CHECK: "tf.AssignVariableOp"(%[[ACCUM_HANDLE]], %[[NEW_ACCUM]])
|
||||
|
||||
// CHECK: %[[VAR:[0-9]*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE]])
|
||||
// CHECK: %[[NEW_VAR:[0-9]*]] = "tf.AddV2"(%[[VAR]], %[[NEW_ACCUM]])
|
||||
// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[NEW_VAR]])
|
||||
|
||||
"tf.ResourceApplyKerasMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that composite tf.ResourceApplyKerasMomentum (with Nesterov) operation
|
||||
// is decomposed.
|
||||
|
||||
// CHECK-LABEL: func @decompose_resource_apply_keras_momentum_nesterov
|
||||
// CHECK-SAME: (%[[LR:.*]]: tensor<f32>, %[[GRAD:.*]]: tensor<f32>, %[[MOMENTUM:.*]]: tensor<f32>)
|
||||
func @decompose_resource_apply_keras_momentum_nesterov(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> () {
|
||||
|
||||
// CHECK: %[[VAR_HANDLE:[0-9]*]] = "tf.VarHandleOp"
|
||||
// CHECK: %[[ACCUM_HANDLE:[0-9]*]] = "tf.VarHandleOp"
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
// CHECK: %[[ACCUM:[0-9]*]] = "tf.ReadVariableOp"(%[[ACCUM_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
|
||||
// CHECK: %[[ACCUM_MOMENTUM:[0-9]*]] = "tf.Mul"(%[[ACCUM]], %[[MOMENTUM]])
|
||||
// CHECK: %[[GRAD_LR:[0-9]*]] = "tf.Mul"(%[[GRAD]], %[[LR]])
|
||||
// CHECK: %[[NEW_ACCUM:[0-9]*]] = "tf.Sub"(%[[ACCUM_MOMENTUM]], %[[GRAD_LR]])
|
||||
// CHECK: "tf.AssignVariableOp"(%[[ACCUM_HANDLE]], %[[NEW_ACCUM]])
|
||||
|
||||
// CHECK: %[[NEW_ACCUM_MOMENTUM:[0-9]*]] = "tf.Mul"(%[[NEW_ACCUM]], %[[MOMENTUM]])
|
||||
// CHECK: %[[NEW_DELTA:[0-9]*]] = "tf.Sub"(%[[NEW_ACCUM_MOMENTUM]], %[[GRAD_LR]])
|
||||
// CHECK: %[[VAR:[0-9]*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE]])
|
||||
// CHECK: %[[NEW_VAR:[0-9]*]] = "tf.AddV2"(%[[VAR]], %[[NEW_DELTA]])
|
||||
// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[NEW_VAR]])
|
||||
|
||||
"tf.ResourceApplyKerasMomentum"(%0, %1, %arg0, %arg1, %arg2) {use_locking = false, use_nesterov = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that composite tf.ResourceApplyAdam (non-Nesterov) operation is
|
||||
// decomposed.
|
||||
|
||||
// CHECK-LABEL: func @decompose_resource_apply_adam_non_nesterov
|
||||
// CHECK-SAME: ([[BETA1_POWER:%.*]]: tensor<f32>, [[BETA2_POWER:%.*]]: tensor<f32>, [[LR:%.*]]: tensor<f32>, [[BETA1:%.*]]: tensor<f32>, [[BETA2:%.*]]: tensor<f32>, [[EPSILON:%.*]]: tensor<f32>, [[GRAD:%.*]]: tensor<f32>)
|
||||
func @decompose_resource_apply_adam_non_nesterov(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>) -> () {
|
||||
|
||||
// CHECK: [[ONE:%.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
|
||||
// CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"()
|
||||
// CHECK: [[M_HANDLE:%.*]] = "tf.VarHandleOp"()
|
||||
// CHECK: [[V_HANDLE:%.*]] = "tf.VarHandleOp"()
|
||||
// CHECK: [[ONE_MINUS_BETA2_POWER:%.*]] = "tf.Sub"([[ONE]], [[BETA2_POWER]])
|
||||
// CHECK: [[SQRT_ONE_MINUS_BETA2_POWER:%.*]] = "tf.Sqrt"([[ONE_MINUS_BETA2_POWER]])
|
||||
// CHECK: [[ONE_MINUS_BETA1_POWER:%.*]] = "tf.Sub"([[ONE]], [[BETA1_POWER]])
|
||||
// CHECK: [[ALPHA_NO_LR:%.*]] = "tf.Div"([[SQRT_ONE_MINUS_BETA2_POWER]], [[ONE_MINUS_BETA1_POWER]])
|
||||
// CHECK: [[ALPHA:%.*]] = "tf.Mul"([[LR]], [[ALPHA_NO_LR]])
|
||||
// CHECK: [[OLD_M:%.*]] = "tf.ReadVariableOp"([[M_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
|
||||
// CHECK: [[BETA1_OLD_M:%.*]] = "tf.Mul"([[BETA1]], [[OLD_M]])
|
||||
// CHECK: [[ONE_MINUS_BETA1:%.*]] = "tf.Sub"([[ONE]], [[BETA1]])
|
||||
// CHECK: [[ONE_MINUS_BETA1_GRAD:%.*]] = "tf.Mul"([[ONE_MINUS_BETA1]], [[GRAD]])
|
||||
// CHECK: [[NEW_M:%.*]] = "tf.AddV2"([[BETA1_OLD_M]], [[ONE_MINUS_BETA1_GRAD]])
|
||||
// CHECK: [[OLD_V:%.*]] = "tf.ReadVariableOp"([[V_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
|
||||
// CHECK: [[BETA2_OLD_V:%.*]] = "tf.Mul"([[BETA2]], [[OLD_V]])
|
||||
// CHECK: [[ONE_MINUS_BETA2:%.*]] = "tf.Sub"([[ONE]], [[BETA2]])
|
||||
// CHECK: [[GRAD_SQUARE:%.*]] = "tf.Square"([[GRAD]])
|
||||
// CHECK: [[V_DELTA:%.*]] = "tf.Mul"([[ONE_MINUS_BETA2]], [[GRAD_SQUARE]])
|
||||
// CHECK: [[NEW_V:%.*]] = "tf.AddV2"([[BETA2_OLD_V]], [[V_DELTA]])
|
||||
// CHECK: [[ALPHA_NEW_M:%.*]] = "tf.Mul"([[ALPHA]], [[NEW_M]])
|
||||
// CHECK: [[SQRT_NEW_V:%.*]] = "tf.Sqrt"([[NEW_V]])
|
||||
// CHECK: [[SQRT_NEW_V_EPSILON:%.*]] = "tf.AddV2"([[SQRT_NEW_V]], [[EPSILON]])
|
||||
// CHECK: [[VAR_DELTA:%.*]] = "tf.Div"([[ALPHA_NEW_M]], [[SQRT_NEW_V_EPSILON]])
|
||||
// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
|
||||
// CHECK: [[NEW_VAR:%.*]] = "tf.Sub"([[OLD_VAR]], [[VAR_DELTA]])
|
||||
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]])
|
||||
// CHECK: "tf.AssignVariableOp"([[M_HANDLE]], [[NEW_M]])
|
||||
// CHECK: "tf.AssignVariableOp"([[V_HANDLE]], [[NEW_V]])
|
||||
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
%2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
"tf.ResourceApplyAdam"(%0, %1, %2, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) {use_locking = false, use_nesterov = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that composite tf.ResourceApplyAdam (with Nesterov) operation is
|
||||
// decomposed.
|
||||
|
||||
// CHECK-LABEL: func @decompose_resource_apply_adam_nesterov(
|
||||
// CHECK-SAME: [[BETA1_POWER:%.*]]: tensor<f32>, [[BETA2_POWER:%.*]]: tensor<f32>, [[LR:%.*]]: tensor<f32>, [[BETA1:%.*]]: tensor<f32>, [[BETA2:%.*]]: tensor<f32>, [[EPSILON:%.*]]: tensor<f32>, [[GRAD:%.*]]: tensor<f32>) {
|
||||
func @decompose_resource_apply_adam_nesterov(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>) -> () {
|
||||
|
||||
// CHECK: [[ONE:%.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
|
||||
// CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"}
|
||||
// CHECK: [[M_HANDLE:%.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"}
|
||||
// CHECK: [[V_HANDLE:%.*]] = "tf.VarHandleOp"() {container = "c", shared_name = "v"}
|
||||
// CHECK: [[VAL_82:%.*]] = "tf.Sub"([[ONE]], [[BETA2_POWER]])
|
||||
// CHECK: [[VAL_83:%.*]] = "tf.Sqrt"([[VAL_82]])
|
||||
// CHECK: [[VAL_84:%.*]] = "tf.Sub"([[ONE]], [[BETA1_POWER]])
|
||||
// CHECK: [[VAL_85:%.*]] = "tf.Div"([[VAL_83]], [[VAL_84]])
|
||||
// CHECK: [[VAL_86:%.*]] = "tf.Mul"([[LR]], [[VAL_85]])
|
||||
// CHECK: [[OLD_M:%.*]] = "tf.ReadVariableOp"([[M_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
|
||||
// CHECK: [[VAL_88:%.*]] = "tf.Mul"([[BETA1]], [[OLD_M]])
|
||||
// CHECK: [[VAL_89:%.*]] = "tf.Sub"([[ONE]], [[BETA1]])
|
||||
// CHECK: [[VAL_90:%.*]] = "tf.Mul"([[VAL_89]], [[GRAD]])
|
||||
// CHECK: [[NEW_M:%.*]] = "tf.AddV2"([[VAL_88]], [[VAL_90]])
|
||||
// CHECK: [[OLD_V:%.*]] = "tf.ReadVariableOp"([[V_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
|
||||
// CHECK: [[VAL_93:%.*]] = "tf.Mul"([[BETA2]], [[OLD_V]])
|
||||
// CHECK: [[VAL_94:%.*]] = "tf.Sub"([[ONE]], [[BETA2]])
|
||||
// CHECK: [[VAL_95:%.*]] = "tf.Square"([[GRAD]])
|
||||
// CHECK: [[VAL_96:%.*]] = "tf.Mul"([[VAL_94]], [[VAL_95]])
|
||||
// CHECK: [[NEW_V:%.*]] = "tf.AddV2"([[VAL_93]], [[VAL_96]])
|
||||
// CHECK: [[VAL_98:%.*]] = "tf.Mul"([[NEW_M]], [[BETA1]])
|
||||
// CHECK: [[VAL_99:%.*]] = "tf.Sub"([[ONE]], [[BETA1]])
|
||||
// CHECK: [[VAL_100:%.*]] = "tf.Mul"([[VAL_99]], [[GRAD]])
|
||||
// CHECK: [[VAL_101:%.*]] = "tf.AddV2"([[VAL_98]], [[VAL_100]])
|
||||
// CHECK: [[VAL_102:%.*]] = "tf.Mul"([[VAL_86]], [[VAL_101]])
|
||||
// CHECK: [[VAL_103:%.*]] = "tf.Sqrt"([[NEW_V]])
|
||||
// CHECK: [[VAL_104:%.*]] = "tf.AddV2"([[VAL_103]], [[EPSILON]])
|
||||
// CHECK: [[VAL_105:%.*]] = "tf.Div"([[VAL_102]], [[VAL_104]])
|
||||
// CHECK: [[OLD_VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]]) : (tensor<*x!tf.resource>) -> tensor<*xf32>
|
||||
// CHECK: [[NEW_VAR:%.*]] = "tf.Sub"([[OLD_VAR]], [[VAL_105]])
|
||||
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[NEW_VAR]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
|
||||
// CHECK: "tf.AssignVariableOp"([[M_HANDLE]], [[NEW_M]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
|
||||
// CHECK: "tf.AssignVariableOp"([[V_HANDLE]], [[NEW_V]]) : (tensor<*x!tf.resource>, tensor<*xf32>) -> ()
|
||||
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
%2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
"tf.ResourceApplyAdam"(%0, %1, %2, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) {use_locking = false, use_nesterov = true} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -1,7 +1,8 @@
|
||||
// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=FunctionalizeControlFlowForXlaPass 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
|
||||
// RUN: not tf-opt %s --run-tf-graph-optimization --graph-passes=FunctionalizeControlFlowForXlaPass 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: FunctionalizeControlFlowPass: Graph contains node with inputs predicated on incompatible predicates: {s(Cond:0,then)} and {s(Cond:0,else)}
|
||||
// CHECK: error: FunctionalizeControlFlowForXlaPass: Graph contains node with inputs predicated on incompatible predicates: {s(Cond:0,then)} and {s(Cond:0,else)}
|
||||
// CHECK-NEXT: for node {{[{][{]node Add[}][}]}}
|
||||
|
||||
func @main() {
|
||||
%0 = "_tf._TPUReplicate"() {computation = @foo, Tinputs = [], Tbroadcast_inputs = [], NumVariables = 0, Tguaranteed_constants = [], output_types = []} : () -> !_tf.control loc("_TPUReplicate")
|
||||
return
|
||||
|
@ -167,3 +167,16 @@ func @control_fetch(%arg0 : i32) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Check that @main function is pruned.
|
||||
// CHECK-LABEL: func @main
|
||||
func @main() {
|
||||
tf_executor.graph {
|
||||
// CHECK-NOT: tf_executor.island
|
||||
%0 = tf_executor.island {
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -0,0 +1,14 @@
|
||||
// RUN: tf-opt %s -tf-executor-graph-pruning=skip-main-func | FileCheck %s --dump-input=fail
|
||||
|
||||
// Check that @main function is skipped by default.
|
||||
// CHECK-LABEL: func @main
|
||||
func @main() {
|
||||
tf_executor.graph {
|
||||
// CHECKT: tf_executor.island
|
||||
%0 = tf_executor.island {
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
@ -15,5 +15,6 @@ filegroup(
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir:tf-mlir-translate",
|
||||
"@llvm//:FileCheck",
|
||||
"@llvm//:not",
|
||||
],
|
||||
)
|
||||
|
@ -0,0 +1,99 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-output-arrays=func_call -o - | FileCheck %s
|
||||
|
||||
node {
|
||||
name: "x"
|
||||
op: "VarHandleOp"
|
||||
device: "/CPU:0"
|
||||
attr {
|
||||
key: "container"
|
||||
value {
|
||||
s: "a"
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shared_name"
|
||||
value {
|
||||
s: "x"
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "func_call"
|
||||
op: "test_func_name"
|
||||
input: "x"
|
||||
input: "x"
|
||||
attr {
|
||||
key: "_disable_call_shape_inference"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
}
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "test_func_name"
|
||||
input_arg {
|
||||
name: "a_0"
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
input_arg {
|
||||
name: "a_1"
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
output_arg {
|
||||
name: "a"
|
||||
type: DT_RESOURCE
|
||||
}
|
||||
}
|
||||
resource_arg_unique_id {
|
||||
key: 0
|
||||
value: 0
|
||||
}
|
||||
resource_arg_unique_id {
|
||||
key: 1
|
||||
value: 0
|
||||
}
|
||||
ret {
|
||||
key: "a"
|
||||
value: "a_0"
|
||||
}
|
||||
attr {
|
||||
key: "_disable_call_shape_inference"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Check that the `resource_arg_unique_id` for each argument is propagated to the
|
||||
# `tf.resource_arg_unique_id` argument attribute of the function
|
||||
# @test_func_name0.
|
||||
|
||||
# CHECK: func @main
|
||||
# CHECK: tf_executor.graph
|
||||
# CHECK: "tf.VarHandleOp"()
|
||||
# CHECK: "tf.LegacyCall"
|
||||
# CHECK-SAME: {_disable_call_shape_inference = true, f = @test_func_name0}
|
||||
# CHECK: tf_executor.fetch
|
||||
# CHECK: return
|
||||
# CHECK: func @test_func_name0
|
||||
# CHECK-SAME: tf.resource_arg_unique_id = 0
|
||||
# CHECK-SAME tf.resource_arg_unique_id = 0
|
||||
# CHECK: tf_executor.graph
|
||||
# CHECK: tf_executor.fetch
|
||||
# CHECK: return
|
@ -1,6 +1,6 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
|
||||
# RUN: not tf-mlir-translate -graphdef-to-mlir %s -o - 2>&1 | FileCheck %s
|
||||
|
||||
this is not a valid graph def
|
||||
|
||||
#CHECK: {{(.|\.)+ Error parsing Protobuf:.*}}
|
||||
#CHECK: {{(.|\.)+ Graph import failed: Invalid argument: Could not parse input file}}
|
||||
# CHECK: Error parsing Protobuf
|
||||
# CHECK: Graph import failed: Invalid argument: Could not parse input proto
|
||||
|
@ -1,4 +1,4 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=NotANodeInTheGraph -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
|
||||
# RUN: not tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=NotANodeInTheGraph -o - 2>&1 | FileCheck %s
|
||||
|
||||
# CHECK: Graph import failed: Invalid argument: Output NotANodeInTheGraph was not found in graph
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=input:1 -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
|
||||
# RUN: not tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=input:1 -o - 2>&1 | FileCheck %s
|
||||
|
||||
# CHECK: Graph import failed: Invalid argument: Invalid output index 1 specified for node: input
|
||||
|
||||
|
@ -115,6 +115,16 @@ func @pad(%arg0: tensor<3xf32>) -> tensor<6xf32> {
|
||||
return %0 : tensor<6xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @pad_bf16
|
||||
func @pad_bf16(%arg0: tensor<3xbf16>) -> tensor<6xbf16> {
|
||||
%padding = "tf.Const"() { value = dense<[[1, 2]]> : tensor<1x2xi64> } : () -> tensor<1x2xi64>
|
||||
// CHECK-DAG: [[PAD:%.+]] = "tf.Const"() {
|
||||
// CHECK-DAG: [[CST:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<bf16>}
|
||||
// CHECK: "tf.PadV2"(%arg0, [[PAD]], [[CST]])
|
||||
%0 = "tf.Pad"(%arg0, %padding) : (tensor<3xbf16>, tensor<1x2xi64>) -> tensor<6xbf16>
|
||||
return %0 : tensor<6xbf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @BiasAddGrad_NHWC
|
||||
func @BiasAddGrad_NHWC(%arg0: tensor<2x3x4x5xf32>) -> tensor<5xf32> {
|
||||
// CHECK: "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi64>}
|
||||
@ -266,3 +276,92 @@ func @addN_variant(%arg0: tensor<!tf.variant<tensor<2xf32>>>, %arg1: tensor<!tf.
|
||||
%0 = "tf.AddN"(%arg0, %arg1, %arg2) : (tensor<!tf.variant<tensor<2xf32>>>, tensor<!tf.variant<tensor<2xf32>>>, tensor<!tf.variant<tensor<2xf32>>>) -> tensor<!tf.variant<tensor<2xf32>>>
|
||||
return %0 : tensor<!tf.variant<tensor<2xf32>>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @DynamicStitch_simple
|
||||
func @DynamicStitch_simple(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[-1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK-DAG: %[[INP:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2x2xf32>, tensor<2xi64>) -> tensor<2x2xf32>
|
||||
// CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%[[INP]]) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
// CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
|
||||
// CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor<2xf32>, tensor<2xf32>, tensor<i64>) -> tensor<2x2xf32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
%indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: DynamicStitch_scalar_matrix_indices
|
||||
func @DynamicStitch_scalar_matrix_indices(%arg0: tensor<2xf32>, %arg1: tensor<2x2x2xf32>) -> (tensor<5x2xf32>) {
|
||||
// CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[-1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK-DAG: %[[INP0:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2xf32>, tensor<2xi64>) -> tensor<1x2xf32>
|
||||
// CHECK-DAG: %[[ITEMS0:.*]] = "tf.Unpack"(%[[INP0]]) {axis = 0 : i64} : (tensor<1x2xf32>) -> tensor<2xf32>
|
||||
// CHECK-DAG: %[[INP1:.*]] = "tf.Reshape"(%arg1, %[[SHAPE]]) : (tensor<2x2x2xf32>, tensor<2xi64>) -> tensor<4x2xf32>
|
||||
// CHECK-DAG: %[[ITEMS1:.*]]:4 = "tf.Unpack"(%[[INP1]]) {axis = 0 : i64} : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>)
|
||||
// CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
|
||||
// CHECK-DAG: %6 = "tf.ConcatV2"(%[[ITEMS1]]#3, %[[ITEMS1]]#2, %[[ITEMS1]]#1, %[[ITEMS1]]#0, %[[ITEMS0]], %[[AXIS]]) : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<i64>) -> tensor<5x2xf32>
|
||||
|
||||
%indices0 = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
|
||||
%indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
|
||||
%0 = "tf.DynamicStitch"(%indices0, %indices1, %arg0, %arg1) : (tensor<i32>, tensor<2x2xi32>, tensor<2xf32>, tensor<2x2x2xf32>) -> tensor<5x2xf32>
|
||||
return %0 : tensor<5x2xf32>
|
||||
}
|
||||
|
||||
// Verify that custom types are lowered and have legal output.
|
||||
// CHECK-LABEL: func @DynamicStitch_uint8
|
||||
func @DynamicStitch_uint8(%arg0: tensor<2x2x!tf.uint8>) -> tensor<2x2x!tf.uint8> {
|
||||
// CHECK-NOT: tf.DynamicStitch
|
||||
|
||||
%indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2x!tf.uint8>) -> tensor<2x2x!tf.uint8>
|
||||
return %0 : tensor<2x2x!tf.uint8>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @DynamicStitch_scalar_item
|
||||
func @DynamicStitch_scalar_item(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi64>} : () -> tensor<1xi64>
|
||||
// CHECK-DAG: %[[INP:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2xf32>, tensor<1xi64>) -> tensor<2xf32>
|
||||
// CHECK-DAG: %[[ITEMS]]:2 = "tf.Unpack"(%[[INP]]) {axis = 0 : i64} : (tensor<2xf32>) -> (tensor<f32>, tensor<f32>)
|
||||
// CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
|
||||
// CHECK-DAG: %[[RESULT]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor<f32>, tensor<f32>, tensor<i64>) -> tensor<2xf32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
%indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2xf32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @DynamicStitch_matrix_item
|
||||
func @DynamicStitch_matrix_item(%arg0: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> {
|
||||
// CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[-1, 2, 2]> : tensor<3xi64>} : () -> tensor<3xi64>
|
||||
// CHECK-DAG: %[[INP:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2x2x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32>
|
||||
// CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%[[INP]]) {axis = 0 : i64} : (tensor<2x2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>)
|
||||
// CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
|
||||
// CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<i64>) -> tensor<2x2x2xf32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
%indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2x2xf32>) -> tensor<2x2x2xf32>
|
||||
return %0 : tensor<2x2x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @DynamicStitch_dynamic
|
||||
func @DynamicStitch_dynamic(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: tf.DynamicStitch
|
||||
%0 = "tf.DynamicStitch"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @DynamicStitch_duplicates
|
||||
func @DynamicStitch_duplicates(%arg0: tensor<2x2xf32>) -> tensor<1x2xf32> {
|
||||
// CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[-1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK-DAG: %[[INP:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2x2xf32>, tensor<2xi64>) -> tensor<2x2xf32>
|
||||
// CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%[[INP]]) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
// CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
|
||||
// CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[AXIS]]) : (tensor<2xf32>, tensor<i64>) -> tensor<1x2xf32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
%indices = "tf.Const"() {value = dense<[0, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xf32>) -> tensor<1x2xf32>
|
||||
return %0 : tensor<1x2xf32>
|
||||
}
|
||||
|
@ -15,5 +15,6 @@ filegroup(
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir:tf-mlir-translate",
|
||||
"@llvm//:FileCheck",
|
||||
"@llvm//:not",
|
||||
],
|
||||
)
|
||||
|
@ -0,0 +1,62 @@
|
||||
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
|
||||
|
||||
func @main() -> tensor<*x!tf.resource> attributes {tf.entry_function = {inputs = "", outputs = "func_call"}} {
|
||||
%0 = tf_executor.graph {
|
||||
%outputs, %control = tf_executor.island wraps "tf.VarHandleOp"() {container = "a", device = "/CPU:0", dtype = i64, name = "x", shape = "tfshape$", shared_name = "x"} : () -> tensor<!tf.resource<tensor<i64>>>
|
||||
%outputs_0, %control_1 = tf_executor.island wraps "tf.LegacyCall"(%outputs, %outputs) {_disable_call_shape_inference = true, f = @test_func_name0} : (tensor<!tf.resource<tensor<i64>>>, tensor<!tf.resource<tensor<i64>>>) -> tensor<*x!tf.resource>
|
||||
tf_executor.fetch %outputs_0 : tensor<*x!tf.resource>
|
||||
}
|
||||
return %0 : tensor<*x!tf.resource>
|
||||
}
|
||||
func @test_func_name0(%arg0: tensor<*x!tf.resource> {tf.resource_arg_unique_id = 0 : i64}, %arg1: tensor<*x!tf.resource> {tf.resource_arg_unique_id = 0 : i64}) -> tensor<*x!tf.resource> attributes {tf._disable_call_shape_inference = true} {
|
||||
%0 = tf_executor.graph {
|
||||
tf_executor.fetch %arg0 : tensor<*x!tf.resource>
|
||||
}
|
||||
return %0 : tensor<*x!tf.resource>
|
||||
}
|
||||
|
||||
// Check that the `tf.resource_arg_unique_id` argument attributes of
|
||||
// test_func_name0 are propagated to the function's arg_attr and
|
||||
// resource_arg_unique_id.
|
||||
|
||||
// CHECK: name: "x"
|
||||
// CHECK: op: "VarHandleOp"
|
||||
|
||||
// CHECK: name: "func_call"
|
||||
// CHECK: input: "x"
|
||||
// CHECK: input: "x"
|
||||
|
||||
// CHECK: library
|
||||
// CHECK: function
|
||||
// CHECK: signature
|
||||
// CHECK: input_arg
|
||||
// CHECK: type: DT_RESOURCE
|
||||
// CHECK: input_arg
|
||||
// CHECK: type: DT_RESOURCE
|
||||
// CHECK: output_arg
|
||||
// CHECK: type: DT_RESOURCE
|
||||
// CHECK: ret
|
||||
|
||||
// Check _resource_arg_unique_id for each argument. Since they alias each other,
|
||||
// both values are 0.
|
||||
// CHECK: arg_attr
|
||||
// CHECK-NEXT: key: 0
|
||||
// CHECK-NEXT: value
|
||||
// CHECK: key: "_resource_arg_unique_id"
|
||||
// CHECK-NEXT: value
|
||||
// CHECK-NEXT: i: 0
|
||||
// CHECK: arg_attr
|
||||
// CHECK-NEXT: key: 1
|
||||
// CHECK-NEXT: value
|
||||
// CHECK: key: "_resource_arg_unique_id"
|
||||
// CHECK-NEXT: value
|
||||
// CHECK-NEXT: i: 0
|
||||
|
||||
// Check resource_arg_unique_id for each argument. Since they alias each other,
|
||||
// both values are 0.
|
||||
// CHECK: resource_arg_unique_id
|
||||
// CHECK-NEXT: key: 0
|
||||
// CHECK-NEXT: value: 0
|
||||
// CHECK: resource_arg_unique_id
|
||||
// CHECK-NEXT: key: 1
|
||||
// CHECK-NEXT: value: 0
|
@ -1,4 +1,4 @@
|
||||
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0
|
||||
// RUN: not tf-mlir-translate -mlir-to-graphdef %s -o - 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: Graph export failed: Failed precondition: entry function `main` must be present
|
||||
|
||||
@ -7,4 +7,3 @@ func @const() {
|
||||
%0:2 = "_tf.Const"() {device = "TPU:0", name = "const", dtype = "tfdtype$DT_INT32", value = dense<[1, 2]> : tensor<2xi32>} : () -> (tensor<2xi32>, !_tf.control)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -79,4 +79,87 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
%1 = "tf.Conv2DBackpropInput"(%0, %arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<4xi32>, tensor<3x3x32x64xf32>, tensor<200x24x24x64xf32>) -> tensor<?x?x?x?xf32>
|
||||
return %1 : tensor<?x?x?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shape_from_if_to_branch_functions
|
||||
func @shape_from_if_to_branch_functions(%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
|
||||
%0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @if_else_branch, is_stateless = true, name = "if", output_shapes = ["tfshape$"], then_branch = @if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
|
||||
return %0 : tensor<1x2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @if_then_branch
|
||||
// CHECK-SAME: (%arg0: tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
|
||||
func @if_then_branch(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: return
|
||||
// CHECK-SAME: tensor<1x2x3xf32>
|
||||
return %arg0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @if_else_branch
|
||||
// CHECK-SAME: (%arg0: tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
|
||||
func @if_else_branch(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.Identity"(%arg0) : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
|
||||
%0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>)
|
||||
// CHECK: return
|
||||
// CHECK-SAME: tensor<1x2x3xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shape_from_while_to_cond_body_functions
|
||||
func @shape_from_while_to_cond_body_functions(%arg0: tensor<4xf32>) -> tensor<4xf32> {
|
||||
%0 = "tf.While"(%arg0) {cond = @while_cond_func, body = @while_body_func, is_stateless = true} : (tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @while_cond_func
|
||||
// CHECK-SAME: %arg0: tensor<4xf32>) -> tensor<i1>
|
||||
func @while_cond_func(%arg0: tensor<*xf32>) -> tensor<i1> {
|
||||
%0 = "tf.Const"() {value = dense<[1.000000e-04,2.000000e-04,3.000000e-04,4.000000e-04]> : tensor<4xf32>} : () -> tensor<4xf32>
|
||||
%1 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: tf.Equal
|
||||
// CHECK-SAME: (tensor<4xf32>, tensor<4xf32>) -> tensor<*xi1>
|
||||
// TODO(ycao): Investigate why result type of tf.Equal is not inferred.
|
||||
%2 = "tf.Equal"(%0, %arg0) : (tensor<4xf32>, tensor<*xf32>) -> tensor<*xi1>
|
||||
%3 = "tf.Any"(%2, %1) : (tensor<*xi1>, tensor<i32>) -> (tensor<i1>)
|
||||
return %3 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @while_body_func
|
||||
func @while_body_func(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = "tf.Const"() {value = dense<1.000000e-04> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: tf.AddV2
|
||||
// CHECK-SAME: (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
|
||||
%1 = "tf.AddV2"(%arg0, %0) : (tensor<*xf32>, tensor<f32>) -> tensor<*xf32>
|
||||
// CHECK: return
|
||||
// CHECK-SAME: tensor<4xf32>
|
||||
return %1 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @invalid_function_reused_by_control_flows
|
||||
func @invalid_function_reused_by_control_flows(%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
|
||||
// expected-warning @+1 {{unable to refine shape}}
|
||||
%0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", output_shapes = ["tfshape$"], then_branch = @reused_if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
|
||||
// expected-warning @+1 {{unable to refine shape}}
|
||||
%1 = "tf.If"(%arg0, %0) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @reused_if_else_branch, is_stateless = true, name = "if", output_shapes = ["tfshape$"], then_branch = @reused_if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
|
||||
return %0 : tensor<1x2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @reused_if_then_branch
|
||||
// CHECK-SAME: (%arg0: tensor<*xf32>) -> tensor<*xf32>
|
||||
// expected-error @+1 {{expected control flow function reused_if_then_branch to have exactly 1 use}}
|
||||
func @reused_if_then_branch(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: return
|
||||
// CHECK-SAME: tensor<*xf32>
|
||||
return %arg0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @reused_if_else_branch
|
||||
// CHECK-SAME: (%arg0: tensor<*xf32>) -> tensor<*xf32>
|
||||
// expected-error @+1 {{expected control flow function reused_if_else_branch to have exactly 1 use}}
|
||||
func @reused_if_else_branch(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>)
|
||||
// CHECK: return
|
||||
// CHECK-SAME: tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
}
|
||||
|
@ -120,14 +120,14 @@ func @aliasing_reads_writes(%arg0: tensor<32xf32>) -> () {
|
||||
|
||||
// CHECK-LABEL: func @unknown_side_effecting_op
|
||||
func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () {
|
||||
// expected-remark@above {{ID: 13}}
|
||||
// expected-remark@above {{ID: 14}}
|
||||
tf_executor.graph {
|
||||
// expected-remark@above {{ID: 11}}
|
||||
// expected-remark@above {{Successors: {12}}}
|
||||
// expected-remark@above {{ID: 12}}
|
||||
// expected-remark@above {{Successors: {13}}}
|
||||
// CHECK: tf_executor.island
|
||||
%island = tf_executor.island {
|
||||
// expected-remark@above {{ID: 9}}
|
||||
// expected-remark@above {{Successors: {10}}}
|
||||
// expected-remark@above {{ID: 10}}
|
||||
// expected-remark@above {{Successors: {11}}}
|
||||
%vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
// expected-remark@above {{ID: 0}}
|
||||
%vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v1"} : () -> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
@ -141,30 +141,34 @@ func @unknown_side_effecting_op(%arg0: tensor<32xf32>) -> () {
|
||||
"tf._UnknownSideEffectingOp_"() : () -> ()
|
||||
// expected-remark@above {{ID: 4}}
|
||||
// expected-remark@above {{Predecessors: {2,3}}}
|
||||
// expected-remark@above {{Successors: {5,6}}}
|
||||
// expected-remark@above {{Successors: {5,6,7}}}
|
||||
%read1 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
// expected-remark@above {{ID: 5}}
|
||||
// expected-remark@above {{Predecessors: {4}}}
|
||||
// expected-remark@above {{Successors: {7}}}
|
||||
"tf.AssignVariableOp"(%vh0, %read1) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
// expected-remark@above {{Successors: {8}}}
|
||||
%read2 = "tf.ReadVariableOp"(%vh1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
// expected-remark@above {{ID: 6}}
|
||||
// expected-remark@above {{Predecessors: {4}}}
|
||||
// expected-remark@above {{Successors: {8}}}
|
||||
"tf.AssignVariableOp"(%vh1, %read0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
"tf.AssignVariableOp"(%vh0, %read1) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
// expected-remark@above {{ID: 7}}
|
||||
// expected-remark@above {{Predecessors: {5}}}
|
||||
// expected-remark@above {{Successors: {8}}}
|
||||
tf_executor.yield
|
||||
// expected-remark@above {{Predecessors: {4}}}
|
||||
// expected-remark@above {{Successors: {9}}}
|
||||
"tf.AssignVariableOp"(%vh1, %read0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
// expected-remark@above {{ID: 8}}
|
||||
// expected-remark@above {{Predecessors: {6,7}}}
|
||||
// expected-remark@above {{Predecessors: {5,6}}}
|
||||
// expected-remark@above {{Successors: {9}}}
|
||||
tf_executor.yield
|
||||
// expected-remark@above {{ID: 9}}
|
||||
// expected-remark@above {{Predecessors: {7,8}}}
|
||||
}
|
||||
tf_executor.fetch %island : !tf_executor.control
|
||||
// expected-remark@above {{ID: 10}}
|
||||
// expected-remark@above {{Predecessors: {9}}}
|
||||
// expected-remark@above {{ID: 11}}
|
||||
// expected-remark@above {{Predecessors: {10}}}
|
||||
}
|
||||
return
|
||||
// expected-remark@above {{ID: 12}}
|
||||
// expected-remark@above {{Predecessors: {11}}}
|
||||
// expected-remark@above {{ID: 13}}
|
||||
// expected-remark@above {{Predecessors: {12}}}
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -270,3 +274,466 @@ func @with_replicate(
|
||||
// expected-remark@above {{ID: 11}}
|
||||
// expected-remark@above {{Predecessors: {10}}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass does not add control dependencies a stateless if op.
|
||||
|
||||
// CHECK-LABEL: func @stateless_if_op
|
||||
func @stateless_if_op(
|
||||
// expected-remark@above {{ID: 8}}
|
||||
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
%arg1: tensor<i1>) {
|
||||
tf_executor.graph {
|
||||
// expected-remark@above {{ID: 6}}
|
||||
// expected-remark@above {{Successors: {7}}}
|
||||
// CHECK: tf_executor.island
|
||||
%island = tf_executor.island {
|
||||
// expected-remark@above {{ID: 4}}
|
||||
// expected-remark@above {{Successors: {5}}}
|
||||
%r0 = "tf.ReadVariableOp"(%arg0) :
|
||||
// expected-remark@above {{ID: 0}}
|
||||
// expected-remark@above {{Successors: {2}}}
|
||||
(tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
%if = "tf.If"(%arg1, %arg1) {
|
||||
// expected-remark@above {{ID: 1}}
|
||||
then_branch = @if_then, else_branch = @if_else, is_stateless = true}
|
||||
: (tensor<i1>, tensor<i1>) -> tensor<i1>
|
||||
"tf.AssignVariableOp"(%arg0, %r0) :
|
||||
// expected-remark@above {{ID: 2}}
|
||||
// expected-remark@above {{Predecessors: {0}}}
|
||||
// expected-remark@above {{Successors: {3}}}
|
||||
(tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
tf_executor.yield
|
||||
// expected-remark@above {{ID: 3}}
|
||||
// expected-remark@above {{Predecessors: {2}}}
|
||||
}
|
||||
tf_executor.fetch %island : !tf_executor.control
|
||||
// expected-remark@above {{ID: 5}}
|
||||
// expected-remark@above {{Predecessors: {4}}}
|
||||
}
|
||||
return
|
||||
// expected-remark@above {{ID: 7}}
|
||||
// expected-remark@above {{Predecessors: {6}}}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @if_then
|
||||
func @if_then(%arg0: tensor<i1>) -> tensor<i1> {
|
||||
// expected-remark@above {{ID: 5}}
|
||||
%graph = tf_executor.graph {
|
||||
// expected-remark@above {{ID: 3}}
|
||||
// expected-remark@above {{Successors: {4}}}
|
||||
%island:2 = tf_executor.island {
|
||||
// expected-remark@above {{ID: 1}}
|
||||
// expected-remark@above {{Successors: {2}}}
|
||||
tf_executor.yield %arg0 : tensor<i1>
|
||||
// expected-remark@above {{ID: 0}}
|
||||
}
|
||||
tf_executor.fetch %island#0 : tensor<i1>
|
||||
// expected-remark@above {{ID: 2}}
|
||||
// expected-remark@above {{Predecessors: {1}}}
|
||||
}
|
||||
return %graph : tensor<i1>
|
||||
// expected-remark@above {{ID: 4}}
|
||||
// expected-remark@above {{Predecessors: {3}}}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @if_else
|
||||
func @if_else(%arg0: tensor<i1>) -> tensor<i1> {
|
||||
// expected-remark@above {{ID: 5}}
|
||||
%graph = tf_executor.graph {
|
||||
// expected-remark@above {{ID: 3}}
|
||||
// expected-remark@above {{Successors: {4}}}
|
||||
%island:2 = tf_executor.island {
|
||||
// expected-remark@above {{ID: 1}}
|
||||
// expected-remark@above {{Successors: {2}}}
|
||||
tf_executor.yield %arg0 : tensor<i1>
|
||||
// expected-remark@above {{ID: 0}}
|
||||
}
|
||||
tf_executor.fetch %island#0 : tensor<i1>
|
||||
// expected-remark@above {{ID: 2}}
|
||||
// expected-remark@above {{Predecessors: {1}}}
|
||||
}
|
||||
return %graph : tensor<i1>
|
||||
// expected-remark@above {{ID: 4}}
|
||||
// expected-remark@above {{Predecessors: {3}}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass does not add control dependencies a stateless while op.
|
||||
|
||||
// CHECK-LABEL: func @stateless_if_op
|
||||
func @stateless_if_op(
|
||||
// expected-remark@above {{ID: 8}}
|
||||
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
%arg1: tensor<i1>) {
|
||||
tf_executor.graph {
|
||||
// expected-remark@above {{ID: 6}}
|
||||
// expected-remark@above {{Successors: {7}}}
|
||||
// CHECK: tf_executor.island
|
||||
%island = tf_executor.island {
|
||||
// expected-remark@above {{ID: 4}}
|
||||
// expected-remark@above {{Successors: {5}}}
|
||||
%r0 = "tf.ReadVariableOp"(%arg0) :
|
||||
// expected-remark@above {{ID: 0}}
|
||||
// expected-remark@above {{Successors: {2}}}
|
||||
(tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
%if = "tf.While"(%arg1) {
|
||||
// expected-remark@above {{ID: 1}}
|
||||
body = @while_body, cond = @while_cond, is_stateless = true}
|
||||
: (tensor<i1>) -> tensor<i1>
|
||||
"tf.AssignVariableOp"(%arg0, %r0) :
|
||||
// expected-remark@above {{ID: 2}}
|
||||
// expected-remark@above {{Predecessors: {0}}}
|
||||
// expected-remark@above {{Successors: {3}}}
|
||||
(tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
tf_executor.yield
|
||||
// expected-remark@above {{ID: 3}}
|
||||
// expected-remark@above {{Predecessors: {2}}}
|
||||
}
|
||||
tf_executor.fetch %island : !tf_executor.control
|
||||
// expected-remark@above {{ID: 5}}
|
||||
// expected-remark@above {{Predecessors: {4}}}
|
||||
}
|
||||
return
|
||||
// expected-remark@above {{ID: 7}}
|
||||
// expected-remark@above {{Predecessors: {6}}}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @while_body
|
||||
func @while_body(%arg0: tensor<i1>) -> tensor<i1> {
|
||||
// expected-remark@above {{ID: 5}}
|
||||
%graph = tf_executor.graph {
|
||||
// expected-remark@above {{ID: 3}}
|
||||
// expected-remark@above {{Successors: {4}}}
|
||||
%island:2 = tf_executor.island {
|
||||
// expected-remark@above {{ID: 1}}
|
||||
// expected-remark@above {{Successors: {2}}}
|
||||
tf_executor.yield %arg0 : tensor<i1>
|
||||
// expected-remark@above {{ID: 0}}
|
||||
}
|
||||
tf_executor.fetch %island#0 : tensor<i1>
|
||||
// expected-remark@above {{ID: 2}}
|
||||
// expected-remark@above {{Predecessors: {1}}}
|
||||
}
|
||||
return %graph : tensor<i1>
|
||||
// expected-remark@above {{ID: 4}}
|
||||
// expected-remark@above {{Predecessors: {3}}}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @while_cond
|
||||
func @while_cond(%arg0: tensor<i1>) -> tensor<i1> {
|
||||
// expected-remark@above {{ID: 5}}
|
||||
%graph = tf_executor.graph {
|
||||
// expected-remark@above {{ID: 3}}
|
||||
// expected-remark@above {{Successors: {4}}}
|
||||
%island:2 = tf_executor.island {
|
||||
// expected-remark@above {{ID: 1}}
|
||||
// expected-remark@above {{Successors: {2}}}
|
||||
tf_executor.yield %arg0 : tensor<i1>
|
||||
// expected-remark@above {{ID: 0}}
|
||||
}
|
||||
tf_executor.fetch %island#0 : tensor<i1>
|
||||
// expected-remark@above {{ID: 2}}
|
||||
// expected-remark@above {{Predecessors: {1}}}
|
||||
}
|
||||
return %graph : tensor<i1>
|
||||
// expected-remark@above {{ID: 4}}
|
||||
// expected-remark@above {{Predecessors: {3}}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass tracks control dependencies for variables from an if op's
|
||||
// output.
|
||||
|
||||
// CHECK-LABEL: func @output_of_if_op
|
||||
func @output_of_if_op(
|
||||
// expected-remark@above {{ID: 12}}
|
||||
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
%arg2: tensor<i1>) {
|
||||
tf_executor.graph {
|
||||
// expected-remark@above {{ID: 10}}
|
||||
// expected-remark@above {{Successors: {11}}}
|
||||
// CHECK: tf_executor.island
|
||||
%island = tf_executor.island {
|
||||
// expected-remark@above {{ID: 8}}
|
||||
// expected-remark@above {{Successors: {9}}}
|
||||
%id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
|
||||
// expected-remark@above {{ID: 0}}
|
||||
-> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
%if:3 = "tf.If"(%arg2, %id0, %arg1) {
|
||||
// expected-remark@above {{ID: 1}}
|
||||
// expected-remark@above {{Successors: {2,3,4}}}
|
||||
then_branch = @if_then, else_branch = @if_else, is_stateless = false}
|
||||
: (tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>) ->
|
||||
(tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>)
|
||||
%r0 = "tf.ReadVariableOp"(%if#0) :
|
||||
// expected-remark@above {{ID: 2}}
|
||||
// expected-remark@above {{Predecessors: {1}}}
|
||||
// expected-remark@above {{Successors: {5,6}}}
|
||||
(tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
%r1 = "tf.ReadVariableOp"(%if#1) :
|
||||
// expected-remark@above {{ID: 3}}
|
||||
// expected-remark@above {{Predecessors: {1}}}
|
||||
// expected-remark@above {{Successors: {5}}}
|
||||
(tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
%r2 = "tf.ReadVariableOp"(%if#2) :
|
||||
// expected-remark@above {{ID: 4}}
|
||||
// expected-remark@above {{Predecessors: {1}}}
|
||||
// expected-remark@above {{Successors: {5,6}}}
|
||||
(tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
"tf.AssignVariableOp"(%arg0, %r0) :
|
||||
// expected-remark@above {{ID: 5}}
|
||||
// expected-remark@above {{Predecessors: {2,3,4}}}
|
||||
// expected-remark@above {{Successors: {7}}}
|
||||
(tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
"tf.AssignVariableOp"(%arg1, %r0) :
|
||||
// expected-remark@above {{ID: 6}}
|
||||
// expected-remark@above {{Predecessors: {2,4}}}
|
||||
// expected-remark@above {{Successors: {7}}}
|
||||
(tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
tf_executor.yield
|
||||
// expected-remark@above {{ID: 7}}
|
||||
// expected-remark@above {{Predecessors: {5,6}}}
|
||||
}
|
||||
tf_executor.fetch %island : !tf_executor.control
|
||||
// expected-remark@above {{ID: 9}}
|
||||
// expected-remark@above {{Predecessors: {8}}}
|
||||
}
|
||||
return
|
||||
// expected-remark@above {{ID: 11}}
|
||||
// expected-remark@above {{Predecessors: {10}}}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @if_then
|
||||
func @if_then(
|
||||
// expected-remark@above {{ID: 7}}
|
||||
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>) ->
|
||||
(tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>) {
|
||||
%graph:3 = tf_executor.graph {
|
||||
// expected-remark@above {{ID: 5}}
|
||||
// expected-remark@above {{Successors: {6}}}
|
||||
%island:4 = tf_executor.island {
|
||||
// expected-remark@above {{ID: 3}}
|
||||
// expected-remark@above {{Successors: {4}}}
|
||||
%u0 = "tf._UnknownSideEffectingOp_"() : ()
|
||||
// expected-remark@above {{ID: 0}}
|
||||
// expected-remark@above {{Successors: {2}}}
|
||||
-> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
%id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
|
||||
// expected-remark@above {{ID: 1}}
|
||||
-> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
tf_executor.yield %u0, %id0, %id0 :
|
||||
// expected-remark@above {{ID: 2}}
|
||||
// expected-remark@above {{Predecessors: {0}}}
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
}
|
||||
tf_executor.fetch %island#0, %island#1, %island#2 :
|
||||
// expected-remark@above {{ID: 4}}
|
||||
// expected-remark@above {{Predecessors: {3}}}
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
}
|
||||
return %graph#0, %graph#1, %graph#2 :
|
||||
// expected-remark@above {{ID: 6}}
|
||||
// expected-remark@above {{Predecessors: {5}}}
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @if_else
|
||||
func @if_else(
|
||||
// expected-remark@above {{ID: 7}}
|
||||
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>) ->
|
||||
(tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>) {
|
||||
%graph:3 = tf_executor.graph {
|
||||
// expected-remark@above {{ID: 5}}
|
||||
// expected-remark@above {{Successors: {6}}}
|
||||
%island:4 = tf_executor.island {
|
||||
// expected-remark@above {{ID: 3}}
|
||||
// expected-remark@above {{Successors: {4}}}
|
||||
%id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
|
||||
// expected-remark@above {{ID: 0}}
|
||||
-> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
%id1 = "tf.Identity"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>)
|
||||
// expected-remark@above {{ID: 1}}
|
||||
-> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
tf_executor.yield %id0, %id0, %id1 :
|
||||
// expected-remark@above {{ID: 2}}
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
}
|
||||
tf_executor.fetch %island#0, %island#1, %island#2 :
|
||||
// expected-remark@above {{ID: 4}}
|
||||
// expected-remark@above {{Predecessors: {3}}}
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
}
|
||||
return %graph#0, %graph#1, %graph#2 :
|
||||
// expected-remark@above {{ID: 6}}
|
||||
// expected-remark@above {{Predecessors: {5}}}
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass tracks control dependencies for variables from a while
|
||||
// op's output.
|
||||
|
||||
// CHECK-LABEL: func @output_of_while_op
|
||||
func @output_of_while_op(
|
||||
// expected-remark@above {{ID: 12}}
|
||||
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
%arg2: tensor<i1>) {
|
||||
tf_executor.graph {
|
||||
// expected-remark@above {{ID: 10}}
|
||||
// expected-remark@above {{Successors: {11}}}
|
||||
// CHECK: tf_executor.island
|
||||
%island = tf_executor.island {
|
||||
// expected-remark@above {{ID: 8}}
|
||||
// expected-remark@above {{Successors: {9}}}
|
||||
%id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
|
||||
// expected-remark@above {{ID: 0}}
|
||||
-> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
%while:4 = "tf.While"(%arg2, %id0, %arg1, %arg1) {
|
||||
// expected-remark@above {{ID: 1}}
|
||||
// expected-remark@above {{Successors: {2,3,4}}}
|
||||
body = @while_body, cond = @while_cond, is_stateless = false}
|
||||
: (tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>) ->
|
||||
(tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>)
|
||||
%r0 = "tf.ReadVariableOp"(%while#1) :
|
||||
// expected-remark@above {{ID: 2}}
|
||||
// expected-remark@above {{Predecessors: {1}}}
|
||||
// expected-remark@above {{Successors: {5}}}
|
||||
(tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
%r1 = "tf.ReadVariableOp"(%while#1) :
|
||||
// expected-remark@above {{ID: 3}}
|
||||
// expected-remark@above {{Predecessors: {1}}}
|
||||
// expected-remark@above {{Successors: {5}}}
|
||||
(tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
%r2 = "tf.ReadVariableOp"(%while#2) :
|
||||
// expected-remark@above {{ID: 4}}
|
||||
// expected-remark@above {{Predecessors: {1}}}
|
||||
// expected-remark@above {{Successors: {6}}}
|
||||
(tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
"tf.AssignVariableOp"(%arg0, %r0) :
|
||||
// expected-remark@above {{ID: 5}}
|
||||
// expected-remark@above {{Predecessors: {2,3}}}
|
||||
// expected-remark@above {{Successors: {7}}}
|
||||
(tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
"tf.AssignVariableOp"(%arg1, %r0) :
|
||||
// expected-remark@above {{ID: 6}}
|
||||
// expected-remark@above {{Predecessors: {4}}}
|
||||
// expected-remark@above {{Successors: {7}}}
|
||||
(tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
tf_executor.yield
|
||||
// expected-remark@above {{ID: 7}}
|
||||
// expected-remark@above {{Predecessors: {5,6}}}
|
||||
}
|
||||
tf_executor.fetch %island : !tf_executor.control
|
||||
// expected-remark@above {{ID: 9}}
|
||||
// expected-remark@above {{Predecessors: {8}}}
|
||||
}
|
||||
return
|
||||
// expected-remark@above {{ID: 11}}
|
||||
// expected-remark@above {{Predecessors: {10}}}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @while_body
|
||||
func @while_body(
|
||||
// expected-remark@above {{ID: 7}}
|
||||
%pred: tensor<i1>,
|
||||
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
%arg2: tensor<*x!tf.resource<tensor<32xf32>>>) ->
|
||||
(tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>) {
|
||||
%graph:4 = tf_executor.graph {
|
||||
// expected-remark@above {{ID: 5}}
|
||||
// expected-remark@above {{Successors: {6}}}
|
||||
%island:5 = tf_executor.island {
|
||||
// expected-remark@above {{ID: 3}}
|
||||
// expected-remark@above {{Successors: {4}}}
|
||||
%id0 = "tf.Identity"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>)
|
||||
// expected-remark@above {{ID: 0}}
|
||||
-> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
%u0 = "tf._UnknownSideEffectingOp_"() : ()
|
||||
// expected-remark@above {{ID: 1}}
|
||||
// expected-remark@above {{Successors: {2}}}
|
||||
-> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
tf_executor.yield %pred, %id0, %arg1, %u0 :
|
||||
// expected-remark@above {{ID: 2}}
|
||||
// expected-remark@above {{Predecessors: {1}}}
|
||||
tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
}
|
||||
tf_executor.fetch %island#0, %island#1, %island#2, %island#3 :
|
||||
// expected-remark@above {{ID: 4}}
|
||||
// expected-remark@above {{Predecessors: {3}}}
|
||||
tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
}
|
||||
return %graph#0, %graph#1, %graph#2, %graph#3 :
|
||||
// expected-remark@above {{ID: 6}}
|
||||
// expected-remark@above {{Predecessors: {5}}}
|
||||
tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @while_cond
|
||||
func @while_cond(
|
||||
// expected-remark@above {{ID: 7}}
|
||||
%pred: tensor<i1>,
|
||||
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
%arg2: tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<i1> {
|
||||
%graph = tf_executor.graph {
|
||||
// expected-remark@above {{ID: 5}}
|
||||
// expected-remark@above {{Successors: {6}}}
|
||||
%island:2 = tf_executor.island {
|
||||
// expected-remark@above {{ID: 3}}
|
||||
// expected-remark@above {{Successors: {4}}}
|
||||
%const = "tf.Const"() { value = dense<0> : tensor<i1> } : () -> tensor<i1>
|
||||
// expected-remark@above {{ID: 0}}
|
||||
%eq = "tf.Equal"(%pred, %const) : (tensor<i1>, tensor<i1>) -> tensor<i1>
|
||||
// expected-remark@above {{ID: 1}}
|
||||
tf_executor.yield %eq : tensor<i1>
|
||||
// expected-remark@above {{ID: 2}}
|
||||
}
|
||||
tf_executor.fetch %island#0 : tensor<i1>
|
||||
// expected-remark@above {{ID: 4}}
|
||||
// expected-remark@above {{Predecessors: {3}}}
|
||||
}
|
||||
return %graph : tensor<i1>
|
||||
// expected-remark@above {{ID: 6}}
|
||||
// expected-remark@above {{Predecessors: {5}}}
|
||||
}
|
||||
|
@ -1538,6 +1538,14 @@ func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<2xi32>, %end: ten
|
||||
|
||||
// -----
|
||||
|
||||
func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>, %strides: tensor<2xi64>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{cannot have multiple ellipses}}
|
||||
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {ellipsis_mask = 3}: (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testOneHot(%indices: tensor<3xi32>, %depth: tensor<i32>, %on_value: tensor<f32>, %off_value: tensor<f32>) -> tensor<3x5xf32> {
|
||||
%result = "tf.OneHot"(%indices, %depth, %on_value, %off_value) {axis = -1 : i64} : (tensor<3xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<3x5xf32>
|
||||
return %result : tensor<3x5xf32>
|
||||
@ -1823,3 +1831,272 @@ func @testAxisDim(%input: tensor<2x6xf32>) {
|
||||
%0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<2x6xf32>) -> (tensor<6xf32>, tensor<6xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// tf.UnsortedSegment{Max|Min|Prod|Sum}
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: unsortedSegmentReduction
|
||||
func @unsortedSegmentReduction(%data: tensor<?x10x8xf32>, %segment_ids: tensor<7x?xi32>, %num_segments: tensor<i32>) {
|
||||
// CHECK: tf.UnsortedSegmentMin
|
||||
%0 = "tf.UnsortedSegmentMin"(%data, %segment_ids, %num_segments) : (tensor<?x10x8xf32>, tensor<7x?xi32>, tensor<i32>) -> (tensor<?x8xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @unsortedSegmentReduction(%data: tensor<7x10x8xf32>, %segment_ids: tensor<7x10xi32>, %num_segments: tensor<2x3xi32>) {
|
||||
// expected-error @+1 {{number of segments should be a 0-D tensor}}
|
||||
%0 = "tf.UnsortedSegmentMax"(%data, %segment_ids, %num_segments) : (tensor<7x10x8xf32>, tensor<7x10xi32>, tensor<2x3xi32>) -> (tensor<?x8xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @unsortedSegmentReduction(%data: tensor<7x10x8xf32>, %segment_ids: tensor<7x9xi32>, %num_segments: tensor<i32>) {
|
||||
// expected-error @+1 {{requires segment ids shape to be a prefix of data shape, but dimension #1 differs: 9 vs. 10}}
|
||||
%0 = "tf.UnsortedSegmentProd"(%data, %segment_ids, %num_segments) : (tensor<7x10x8xf32>, tensor<7x9xi32>, tensor<i32>) -> (tensor<?x8xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @unsortedSegmentReduction(%data: tensor<7x10x8xf32>, %segment_ids: tensor<7x10x8x1xi32>, %num_segments: tensor<i32>) {
|
||||
// expected-error @+1 {{requires segment ids rank to be less than or equal to data's rank}}
|
||||
%0 = "tf.UnsortedSegmentSum"(%data, %segment_ids, %num_segments) : (tensor<7x10x8xf32>, tensor<7x10x8x1xi32>, tensor<i32>) -> (tensor<?x8xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @unsortedSegmentReduction(%data: tensor<7x10x8xf32>, %segment_ids: tensor<7x10xi32>) {
|
||||
%num_segments = "tf.Const"() {value = dense<-5> : tensor<i32>} : () -> (tensor<i32>)
|
||||
// expected-error @+1 {{num of segments cannot be negative}}
|
||||
%0 = "tf.UnsortedSegmentSum"(%data, %segment_ids, %num_segments) : (tensor<7x10x8xf32>, tensor<7x10xi32>, tensor<i32>) -> (tensor<?x8xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// tf.GatherV2
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
func @testGatherV2(%arg0: tensor<16x2x3xf32>, %arg1: tensor<16x5xi32>) -> tensor<16x2x5x3xf32> {
|
||||
%0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
|
||||
%1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5x3xf32>
|
||||
return %1 : tensor<16x2x5x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Verify that the batch_dims can be equal to the rank of the indices.
|
||||
func @testGatherV2(%arg0: tensor<16x4xf32>, %arg1: tensor<16xi32>) -> tensor<16xf32> {
|
||||
%0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32>
|
||||
%1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = 1 : i64} : (tensor<16x4xf32>, tensor<16xi32>, tensor<1xi32>) -> tensor<16xf32>
|
||||
return %1 : tensor<16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testGatherV2(%arg0: tensor<16x2x3xf32>, %arg1: tensor<16x5xi32>) -> tensor<16x2x5x3xf32> {
|
||||
%0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
|
||||
// expected-error @+1 {{batch_dims (-3) must be in range [-2, 3)}}
|
||||
%1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -3 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5x3xf32>
|
||||
return %1 : tensor<16x2x5x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testGatherV2(%arg0: tensor<16x2x3xf32>, %arg1: tensor<16x5xi32>) -> tensor<16x2x5x3xf32> {
|
||||
%0 = "tf.Const"() { value = dense<[[-4]]> : tensor<1x1xi32> } : () -> tensor<1x1xi32>
|
||||
// expected-error @+1 {{requires axis to have rank at most 1}}
|
||||
%1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1x1xi32>) -> tensor<16x2x5x3xf32>
|
||||
return %1 : tensor<16x2x5x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testGatherV2(%arg0: tensor<16x2x3xf32>, %arg1: tensor<16x5xi32>) -> tensor<16x2x5x3xf32> {
|
||||
%0 = "tf.Const"() { value = dense<[-4]> : tensor<1xi32> } : () -> tensor<1xi32>
|
||||
// expected-error @+1 {{axis (-4) must be in range [-3, 3)}}
|
||||
%1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5x3xf32>
|
||||
return %1 : tensor<16x2x5x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testGatherV2(%arg0: tensor<16x2x3xf32>, %arg1: tensor<16x5xi32>) -> tensor<16x2x5x3xf32> {
|
||||
%0 = "tf.Const"() { value = dense<[0]> : tensor<1xi32> } : () -> tensor<1xi32>
|
||||
// expected-error @+1 {{requires axis (0) to be greater than or equal to batch_dims (1)}}
|
||||
%1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5x3xf32>
|
||||
return %1 : tensor<16x2x5x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// tf.StridedSliceGrad
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
func @stridedSliceGrad(%dy: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>, %strides: tensor<2xi64>, %shape: tensor<2xi64>) -> tensor<?x?xf32> {
|
||||
// CHECK: tf.StridedSliceGrad
|
||||
%0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %dy) : (tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<4x8xf32>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @stridedSliceGrad(%dy: tensor<4x8xf32>, %begin: tensor<i64>, %end: tensor<2xi64>, %strides: tensor<2xi64>, %shape: tensor<2xi64>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{requires begin, end and strides to be 1D tensors}}
|
||||
%0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %dy) : (tensor<2xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<4x8xf32>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @stridedSliceGrad(%dy: tensor<4x8xf32>, %begin: tensor<32xi64>, %end: tensor<2xi64>, %strides: tensor<2xi64>, %shape: tensor<2xi64>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{with less than 32 elements}}
|
||||
%0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %dy) : (tensor<2xi64>, tensor<32xi64>, tensor<2xi64>, tensor<2xi64>, tensor<4x8xf32>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @stridedSliceGrad(%dy: tensor<4x8xf32>, %begin: tensor<?xi64>, %end: tensor<3xi64>, %strides: tensor<2xi64>, %shape: tensor<2xi64>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{have the same number of elements}}
|
||||
%0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %dy) : (tensor<2xi64>, tensor<?xi64>, tensor<3xi64>, tensor<2xi64>, tensor<4x8xf32>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @stridedSliceGrad(%dy: tensor<4x8xf32>, %shape: tensor<2xi64>) -> tensor<?x?xf32> {
|
||||
%begin = "tf.Const"() { value = dense<[0, 0]> : tensor<2xi64> } : () -> tensor<?xi64>
|
||||
%end = "tf.Const"() { value = dense<[5, 10]> : tensor<2xi64> } : () -> tensor<?xi64>
|
||||
%strides = "tf.Const"() { value = dense<[2, 3, 4]> : tensor<3xi64> } : () -> tensor<?xi64>
|
||||
|
||||
// expected-error @+1 {{have the same number of elements}}
|
||||
%0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %dy) : (tensor<2xi64>, tensor<?xi64>, tensor<?xi64>, tensor<?xi64>, tensor<4x8xf32>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @stridedSliceGrad(%dy: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>, %shape: tensor<2xi64>) -> tensor<?x?xf32> {
|
||||
%strides = "tf.Const"() { value = dense<[2, 0]> : tensor<2xi32> } : () -> tensor<2xi32>
|
||||
|
||||
// expected-error @+1 {{requires non-zero strides}}
|
||||
%0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %dy) : (tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi32>, tensor<4x8xf32>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @stridedSliceGrad(%dy: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>, %strides: tensor<2xi64>, %shape: tensor<2xi64>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{cannot have multiple ellipses}}
|
||||
%0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %dy) {ellipsis_mask = 3} : (tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<4x8xf32>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @stridedSliceGrad(%dy: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>, %strides: tensor<2xi64>, %shape: tensor<1x2xi64>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{'shape' operand must be 1D tensor, but got 2D tensor}}
|
||||
%0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %dy) : (tensor<1x2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>, tensor<4x8xf32>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testDynamicStitch(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testDynamicStitch() -> tensor<2x2xf32> {
|
||||
// expected-error @+1 {{requires attribute N with value >= 1}}
|
||||
%0 = "tf.DynamicStitch"() : () -> (tensor<2x2xf32>)
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testDynamicStitch(%arg0: tensor<2x2xf32>) -> tensor<f32> {
|
||||
%indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// expected-error @+1 {{requires non scalar output}}
|
||||
%0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xf32>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testDynamicStitch(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%indices = "tf.Const"() {value = dense<[-1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// expected-error @+1 {{requires non-negative index values; found -1}}
|
||||
%0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testDynamicStitch(%arg0: tensor<3x2xf32>) -> tensor<2x2xf32> {
|
||||
%indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// expected-error @+1 {{requires shape of data with type 'tensor<3x2xf32>' to have prefix matching with shape of the corresponding index type 'tensor<2xi32>'}}
|
||||
%0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<3x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testDynamicStitch(%arg0: tensor<2xf32>, %arg1: tensor<2x2x3xf32>) -> (tensor<5x2xf32>) {
|
||||
%indices0 = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
|
||||
%indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
|
||||
|
||||
// expected-error @+1 {{inconsistent shaped data and index pairs; inferred item shapes [2] and [3] don't match}}
|
||||
%0 = "tf.DynamicStitch"(%indices0, %indices1, %arg0, %arg1) : (tensor<i32>, tensor<2x2xi32>, tensor<2xf32>, tensor<2x2x3xf32>) -> tensor<5x2xf32>
|
||||
return %0 : tensor<5x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testDynamicStitch(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%indices = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// expected-error @+1 {{missing index 1}}
|
||||
%0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testDynamicStitch(%arg0: tensor<2x2xf32>) -> tensor<3x2xf32> {
|
||||
%indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// expected-error @+1 {{has invalid output type; should be compatible with inferred type 'tensor<2x2xf32>'}}
|
||||
%0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xf32>) -> tensor<3x2xf32>
|
||||
return %0 : tensor<3x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testDynamicStitch(%arg0: tensor<?x2xi32>, %arg1: tensor<?x3x3xf32>) -> (tensor<*xf32>) {
|
||||
// expected-error @+1 {{requires shape of data with type 'tensor<?x3x3xf32>' to have prefix matching with shape of the corresponding index type 'tensor<?x2xi32>'}}
|
||||
%0 = "tf.DynamicStitch"(%arg0, %arg1) : (tensor<?x2xi32>, tensor<?x3x3xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testDynamicStitch(%arg0: tensor<?x3xf32>, %arg1: tensor<2x?xf32>) -> (tensor<2x3x2xf32>) {
|
||||
%indices0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%indices1 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
|
||||
// expected-error @+1 {{has invalid output type; should be compatible with inferred type 'tensor<2x2x3xf32>'}}
|
||||
%0 = "tf.DynamicStitch"(%indices0, %indices1, %arg0, %arg1) : (tensor<i32>, tensor<i32>, tensor<?x3xf32>, tensor<2x?xf32>) -> tensor<2x3x2xf32>
|
||||
return %0 : tensor<2x3x2xf32>
|
||||
}
|
||||
|
@ -344,7 +344,8 @@ func @replication(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<f32>) ->
|
||||
// CHECK: %[[OP_B:[0-9]*]] = "tf.opB"
|
||||
// CHECK: %[[OP_C:[0-9]*]] = "tf.opC"
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
|
||||
// CHECK-SAME: ([%[[ARG_0]], %[[OP_A]]] as %[[RI_0:[a-z0-9]*]]: tensor<i1>, [%[[OP_B]], %[[ARG_1]]] as %[[RI_1:[a-z0-9]*]]: tensor<i32>)
|
||||
// CHECK-DAG: [%[[ARG_0]], %[[OP_A]]] as %[[RI_0:[a-z0-9]*]]: tensor<i1>
|
||||
// CHECK-DAG: [%[[OP_B]], %[[ARG_1]]] as %[[RI_1:[a-z0-9]*]]: tensor<i32>
|
||||
// CHECK-SAME: n = 2 : i32
|
||||
// CHECK-NEXT: %[[LAUNCH:[0-9]*]]:2 = "tf_device.launch"() ( {
|
||||
// CHECK: %[[OP_D:[0-9]*]] = "tf.opD"(%[[RI_0]], %[[RI_1]], %[[ARG_2]], %[[OP_C]])
|
||||
@ -357,6 +358,32 @@ func @replication(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<f32>) ->
|
||||
// CHECK: return %[[REPLICATE]]#0, %[[REPLICATE]]#3
|
||||
|
||||
|
||||
// Test `tf.TPUReplicatedInput` ops are sorted by their `index` attribute.
|
||||
// Non-negative `index` should preceed `index` of -1, and ordering of ops with
|
||||
// `index` of -1 does not matter.
|
||||
// CHECK-LABEL: func @sort_replicated_input
|
||||
// CHECK-SAME: (%[[ARG_0:.*]]: tensor<i1>, %[[ARG_1:.*]]: tensor<i1>, %[[ARG_2:.*]]: tensor<i1>, %[[ARG_3:.*]]: tensor<i1>, %[[ARG_4:.*]]: tensor<i1>, %[[ARG_5:.*]]: tensor<i1>)
|
||||
func @sort_replicated_input(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>, %arg3: tensor<i1>, %arg4: tensor<i1>, %arg5: tensor<i1>) {
|
||||
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = -1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
|
||||
%1 = "tf.TPUReplicatedInput"(%arg1, %arg1) {index = 2 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
|
||||
%2 = "tf.TPUReplicatedInput"(%arg2, %arg2) {index = 0 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
|
||||
%3 = "tf.TPUReplicatedInput"(%arg3, %arg3) {index = -1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
|
||||
%4 = "tf.TPUReplicatedInput"(%arg4, %arg4) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
|
||||
%5 = "tf.TPUReplicatedInput"(%arg5, %arg5) {index = -1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
|
||||
"tf.opA"(%0, %1, %2, %3, %4, %5) {_tpu_replicate = "replicate", device = "device"} : (tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>) -> ()
|
||||
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: tf_device.replicate
|
||||
// CHECK-SAME: [%[[ARG_2]], %[[ARG_2]]] as %{{[a-z0-9]*}}
|
||||
// CHECK-SAME: [%[[ARG_4]], %[[ARG_4]]] as %{{[a-z0-9]*}}
|
||||
// CHECK-SAME: [%[[ARG_1]], %[[ARG_1]]] as %{{[a-z0-9]*}}
|
||||
// CHECK-DAG: [%[[ARG_0]], %[[ARG_0]]] as %{{[a-z0-9]*}}
|
||||
// CHECK-DAG: [%[[ARG_3]], %[[ARG_3]]] as %{{[a-z0-9]*}}
|
||||
// CHECK-DAG: [%[[ARG_5]], %[[ARG_5]]] as %{{[a-z0-9]*}}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
@ -441,3 +468,44 @@ func @leftover_replicated_output(%arg0: tensor<i1>) {
|
||||
%0:2 = "tf.TPUReplicatedOutput"(%arg0) : (tensor<i1>) -> (tensor<i1>, tensor<i1>)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
// Test bad TPUReplicatedInput positive `index` attribute.
|
||||
func @bad_positive_index_input(%arg0: tensor<i1>) {
|
||||
// expected-error@+1 {{'tf.TPUReplicatedInput' index is not in range [-1, 1), got 1}}
|
||||
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
|
||||
"tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
|
||||
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
// Test bad TPUReplicatedInput negative `index` attribute.
|
||||
func @bad_negative_index_input(%arg0: tensor<i1>) {
|
||||
// expected-error@+1 {{'tf.TPUReplicatedInput' index is not in range [-1, 1), got -2}}
|
||||
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = -2 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
|
||||
"tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
|
||||
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
// Test TPUReplicatedInput with conflicting `index` attribute. This will result
|
||||
// in gaps in the TPUReplicatedInput ordering.
|
||||
func @input_index_gaps(%arg0: tensor<i1>) {
|
||||
// expected-error@+1 {{failed to sort 'tf.TPUReplicatedInput' ops, gap(s) found in indices}}
|
||||
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
|
||||
%1 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
|
||||
"tf.opA"(%0, %1) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>, tensor<i1>) -> ()
|
||||
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
|
||||
return
|
||||
}
|
||||
|
@ -0,0 +1,329 @@
|
||||
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-dynamic-padding | FileCheck %s --dump-input=fail
|
||||
|
||||
// Test single argument with padding map lifted to associated encapsulated
|
||||
// function.
|
||||
//
|
||||
// Padding map "\10\02\18\01":
|
||||
// arg_index: 0
|
||||
// shape_index: 2
|
||||
// padding_arg_index: 1
|
||||
// CHECK-LABEL: func @single_arg_single_shape
|
||||
func @single_arg_single_shape(%arg0: tensor<i1>) {
|
||||
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
|
||||
"tf_device.launch_func"(%ri_0, %ri_1) {device = "", func = @func0, padding_map = ["\10\02\18\01"]} : (tensor<i1>, tensor<i1>) -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @func0
|
||||
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [1 : i32], shape_indices = [2 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>)
|
||||
func @func0(%arg0: tensor<i1>, %arg1: tensor<i1>) {
|
||||
return
|
||||
}
|
||||
|
||||
// Test single argument with multiple padding maps lifted to associated
|
||||
// encapsulated function.
|
||||
//
|
||||
// Padding map "\10\02\18\01":
|
||||
// arg_index: 0
|
||||
// shape_index: 2
|
||||
// padding_arg_index: 1
|
||||
//
|
||||
// Padding map "\10\03\18\02":
|
||||
// arg_index: 0
|
||||
// shape_index: 3
|
||||
// padding_arg_index: 2
|
||||
// CHECK-LABEL: func @single_arg_multiple_shapes
|
||||
func @single_arg_multiple_shapes(%arg0: tensor<i1>) {
|
||||
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>) {n = 2 : i32} {
|
||||
"tf_device.launch_func"(%ri_0, %ri_1, %ri_2) {device = "", func = @func1, padding_map = ["\10\02\18\01", "\10\03\18\02"]} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @func1
|
||||
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [1 : i32, 2 : i32], shape_indices = [2 : i32, 3 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>)
|
||||
func @func1(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
|
||||
return
|
||||
}
|
||||
|
||||
// Test multiple arguments with multiple padding maps lifted to associated
|
||||
// encapsulated function.
|
||||
//
|
||||
// Padding map "\10\02\18\01":
|
||||
// arg_index: 0
|
||||
// shape_index: 2
|
||||
// padding_arg_index: 1
|
||||
//
|
||||
// Padding map "\10\03\18\02":
|
||||
// arg_index: 0
|
||||
// shape_index: 3
|
||||
// padding_arg_index: 2
|
||||
//
|
||||
// Padding map "\08\04\10\01\18\03":
|
||||
// arg_index: 4
|
||||
// shape_index: 1
|
||||
// padding_arg_index: 3
|
||||
// CHECK-LABEL: func @multiple_args
|
||||
func @multiple_args(%arg0: tensor<i1>) {
|
||||
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>, [%arg0, %arg0] as %ri_3: tensor<i1>, [%arg0, %arg0] as %ri_4: tensor<i1>) {n = 2 : i32} {
|
||||
"tf_device.launch_func"(%ri_0, %ri_1, %ri_2, %ri_3, %ri_4) {device = "", func = @func2, padding_map = ["\10\02\18\01", "\10\03\18\02", "\08\04\10\01\18\03"]} : (tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>) -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @func2
|
||||
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [1 : i32, 2 : i32], shape_indices = [2 : i32, 3 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [3 : i32], shape_indices = [1 : i32]}})
|
||||
func @func2(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>, %arg3: tensor<i1>, %arg4: tensor<i1>) {
|
||||
return
|
||||
}
|
||||
|
||||
// Test remapping of replicated inputs to encapsulated function arguments.
|
||||
//
|
||||
// Padding map "\10\02\18\01":
|
||||
// arg_index: 0
|
||||
// shape_index: 2
|
||||
// padding_arg_index: 1
|
||||
// CHECK-LABEL: func @remap_indices
|
||||
func @remap_indices(%arg0: tensor<i1>) {
|
||||
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
|
||||
"tf_device.launch_func"(%ri_1, %arg0, %ri_0) {device = "", func = @func3, padding_map = ["\10\02\18\01"]} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @func3
|
||||
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [0 : i32], shape_indices = [2 : i32]}})
|
||||
func @func3(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
|
||||
return
|
||||
}
|
||||
|
||||
// Test no padding maps are added to encapsulated function if there is no
|
||||
// replication.
|
||||
//
|
||||
// Padding map "\10\02\18\01":
|
||||
// arg_index: 0
|
||||
// shape_index: 2
|
||||
// padding_arg_index: 1
|
||||
// CHECK-LABEL: func @no_replicate
|
||||
func @no_replicate(%arg0: tensor<i1>) {
|
||||
"tf_device.launch_func"(%arg0, %arg0, %arg0) {device = "", func = @func4, padding_map = ["\10\02\18\01"]} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @func4
|
||||
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>)
|
||||
func @func4(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
|
||||
return
|
||||
}
|
||||
|
||||
// Test encapsulated function is not modified when there are no padding maps.
|
||||
// CHECK-LABEL: func @no_padding_map
|
||||
func @no_padding_map(%arg0: tensor<i1>) {
|
||||
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
|
||||
"tf_device.launch_func"(%ri_1, %arg0, %ri_0) {device = "", func = @func5} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @func5
|
||||
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>)
|
||||
func @func5(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
|
||||
return
|
||||
}
|
||||
|
||||
// Test encapsulated function is not modified when padding maps is empty.
|
||||
// CHECK-LABEL: func @empty_padding_map
|
||||
func @empty_padding_map(%arg0: tensor<i1>) {
|
||||
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
|
||||
"tf_device.launch_func"(%ri_1, %arg0, %ri_0) {device = "", func = @func6, padding_map = []} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @func6
|
||||
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>)
|
||||
func @func6(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
|
||||
return
|
||||
}
|
||||
|
||||
// Test unused padding map is not added to the encapsulated function.
|
||||
//
|
||||
// Padding map "\10\02\18\01":
|
||||
// arg_index: 0
|
||||
// shape_index: 2
|
||||
// padding_arg_index: 1
|
||||
// CHECK-LABEL: func @unused_padding_map
|
||||
func @unused_padding_map(%arg0: tensor<i1>) {
|
||||
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
|
||||
"tf_device.launch_func"(%ri_1) {device = "", func = @func7, padding_map = ["\10\02\18\01"]} : (tensor<i1>) -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @func7
|
||||
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>)
|
||||
func @func7(%arg0: tensor<i1>) {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test bad padding map attribute (not an array).
|
||||
func @bad_padding_map() {
|
||||
tf_device.replicate {n = 2 : i32} {
|
||||
// expected-error@+1 {{'tf_device.launch_func' op requires 'padding_map' array attribute}}
|
||||
"tf_device.launch_func"() {device = "", func = @_func, padding_map = 0 : i32} : () -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func @_func() {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test bad padding map attribute (element in array is not a string).
|
||||
func @bad_padding_map_element() {
|
||||
tf_device.replicate {n = 2 : i32} {
|
||||
// expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, not a string}}
|
||||
"tf_device.launch_func"() {device = "", func = @_func, padding_map = [0 : i32]} : () -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func @_func() {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test unparsable padding map.
|
||||
func @bad_padding_map_proto() {
|
||||
tf_device.replicate {n = 2 : i32} {
|
||||
// expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, failed to parse 'z' as tensorflow::tpu::PaddingMap}}
|
||||
"tf_device.launch_func"() {device = "", func = @_func, padding_map = ["z"]} : () -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func @_func() {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test negative arg index.
|
||||
//
|
||||
// Padding map "\08\FF\FF\FF\FF\FF\FF\FF\FF\FF\01\10\02\18\01":
|
||||
// arg_index: -1
|
||||
// shape_index: 2
|
||||
// padding_arg_index: 1
|
||||
func @negative_arg_index(%arg0: tensor<i1>) {
|
||||
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
|
||||
// expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, arg_index must be in [0, 2), got -1}}
|
||||
"tf_device.launch_func"(%ri_0, %ri_1) {device = "", func = @_func, padding_map = ["\08\FF\FF\FF\FF\FF\FF\FF\FF\FF\01\10\02\18\01"]} : (tensor<i1>, tensor<i1>) -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test out of bound arg index.
|
||||
//
|
||||
// Padding map "\08\02\10\02\18\01":
|
||||
// arg_index: 2
|
||||
// shape_index: 2
|
||||
// padding_arg_index: 1
|
||||
func @bad_arg_index(%arg0: tensor<i1>) {
|
||||
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
|
||||
// expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, arg_index must be in [0, 2), got 2}}
|
||||
"tf_device.launch_func"(%ri_0, %ri_1) {device = "", func = @_func, padding_map = ["\08\02\10\02\18\01"]} : (tensor<i1>, tensor<i1>) -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test negative padding arg index.
|
||||
//
|
||||
// Padding map "\08\01\10\02\18\FF\FF\FF\FF\FF\FF\FF\FF\FF\01":
|
||||
// arg_index: 1
|
||||
// shape_index: 2
|
||||
// padding_arg_index: -1
|
||||
func @negative_padding_arg_index(%arg0: tensor<i1>) {
|
||||
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
|
||||
// expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, padding_arg_index must be in [0, 2), got -1}}
|
||||
"tf_device.launch_func"(%ri_0, %ri_1) {device = "", func = @_func, padding_map = ["\08\01\10\02\18\FF\FF\FF\FF\FF\FF\FF\FF\FF\01"]} : (tensor<i1>, tensor<i1>) -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test out of bound padding arg index.
|
||||
//
|
||||
// Padding map "\08\01\10\02\18\02":
|
||||
// arg_index: 1
|
||||
// shape_index: 2
|
||||
// padding_arg_index: 2
|
||||
func @bad_padding_arg_index(%arg0: tensor<i1>) {
|
||||
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
|
||||
// expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, padding_arg_index must be in [0, 2), got 2}}
|
||||
"tf_device.launch_func"(%ri_0, %ri_1) {device = "", func = @_func, padding_map = ["\08\01\10\02\18\02"]} : (tensor<i1>, tensor<i1>) -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test arg that requires a padding arg but padding arg is not an arg to the
|
||||
// encapsulated function.
|
||||
//
|
||||
// Padding map "\10\02\18\01":
|
||||
// arg_index: 0
|
||||
// shape_index: 2
|
||||
// padding_arg_index: 1
|
||||
func @missing_padding_arg(%arg0: tensor<i1>) {
|
||||
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
|
||||
// expected-error@+1 {{'tf_device.launch_func' op bad 'padding_map' attribute at index 0, unused padding_arg_index 1}}
|
||||
"tf_device.launch_func"(%ri_0) {device = "", func = @_func, padding_map = ["\10\02\18\01"]} : (tensor<i1>) -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func @_func(%arg0: tensor<i1>) {
|
||||
return
|
||||
}
|
@ -40,6 +40,7 @@ void CreateTPUBridge(OpPassManager &pm) {
|
||||
|
||||
pm.addPass(TF::CreateResourceDeviceInferencePass());
|
||||
pm.addPass(TFDevice::CreateClusterOutliningPass());
|
||||
pm.addPass(CreateTPUDynamicPaddingMapperPass());
|
||||
pm.addPass(CreateTPURewritePass());
|
||||
pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateInvariantOpHoistingPass());
|
||||
pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user