Merge branch 'master' into pluggable_device_load

Zhoulong Jiang 2020-10-23 16:03:48 +00:00
commit 22d22ff5b3
1932 changed files with 56329 additions and 31099 deletions


@ -174,6 +174,12 @@ build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_opensource_only --define=build_with_mkl_opensource=true
build:mkl_opensource_only -c opt
# Config setting to build with oneDNN for Arm.
build:mkl_aarch64 --define=build_with_mkl_aarch64=true --define=enable_mkl=true
build:mkl_aarch64 --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_aarch64 --define=build_with_mkl_opensource=true
build:mkl_aarch64 -c opt
# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true


@ -12,12 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
-#
-# THIS IS A GENERATED DOCKERFILE.
-#
-# This file was assembled from multiple pieces, whose use is documented
-# throughout. Please refer to the TensorFlow dockerfiles documentation
-# for more information.
# A list of assignees
assignees:

.github/workflows/update-nightly.yml (new file)

@ -0,0 +1,28 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
on:
  workflow_dispatch:  # Allow manual triggers
  schedule:
    - cron: 0 4 * * *  # 4am UTC is 9pm PDT and 8pm PST
name: Set nightly branch to master HEAD
jobs:
  master-to-nightly:
    runs-on: ubuntu-latest
    steps:
      - uses: zofrex/mirror-branch@v1
        name: Set nightly branch to master HEAD
        with:
          target-branch: 'nightly'


@ -4,7 +4,7 @@
/tensorflow/core/common_runtime/eager @qqfish @kkimdev
/tenosrflow/core/debug @caisq
/tensorflow/core/nccl/ @azaks2 @chsigg
-/tensorflow/core/platform/windows/ @gunan @mihaimaruseac
+/tensorflow/core/platform/windows/ @mihaimaruseac
/tensorflow/lite/experimental/micro @petewarden @advaitjain
/tensorflow/python/autograph/ @mdanatg @kkimdev
/tensorflow/python/debug @caisq


@ -34,6 +34,7 @@
    shape assumptions (note that you can pass shapes with `None` entries for
    axes that are meant to be dynamic). You can also disable the input checking
    entirely by setting `model.input_spec = None`.
*   TF pip packages now use CUDA11 and cuDNN 8.0.2.
*   XLA:CPU and XLA:GPU devices are no longer registered by default. Use
    `TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be
    removed).
@ -46,6 +47,13 @@
*   `tf.data.experimental.service.WorkerServer` now takes a config tuple
    instead of individual arguments. Usages should be updated to
    `tf.data.experimental.service.WorkerServer(worker_config)`.
*   `tf.quantization.quantize_and_dequantize_v2` has been introduced, which
    updates the gradient definition for quantization that is outside the range
    to be 0. To simulate the V1 behavior of
    `tf.quantization.quantize_and_dequantize(...)`, use
    `tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...)`
    (see the sketch after this list).
*   `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please
    use `tf.data.Dataset.from_tensor_slices` instead.
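A minimal sketch of the gradient migration, assuming a variable with entries on
both sides of the quantization range (the values and bounds here are
arbitrary):

```
import tensorflow as tf

x = tf.Variable([-1.5, 0.3, 2.7])  # two entries fall outside the range

with tf.GradientTape() as tape:
    # V2 zeroes the gradient wherever x is outside [input_min, input_max].
    y = tf.quantization.quantize_and_dequantize_v2(
        x, input_min=-1.0, input_max=1.0)
print(tape.gradient(y, x))  # [0., 1., 0.]

with tf.GradientTape() as tape:
    # grad_pass_through restores the V1 straight-through gradient.
    y = tf.grad_pass_through(
        lambda v: tf.quantization.quantize_and_dequantize_v2(
            v, input_min=-1.0, input_max=1.0))(x)
print(tape.gradient(y, x))  # [1., 1., 1.]
```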
## Known Caveats
@ -63,143 +71,168 @@
## Bug Fixes and Other Changes

*   <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
*   <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
*   <NOTES SHOULD BE GROUPED PER AREA>
*   Security:
    *   Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
        ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
    *   Fixes three vulnerabilities in conversion to DLPack format
        ([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
        [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
        [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
    *   Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
        ([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
        [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
    *   Fixes several vulnerabilities in `RaggedCountSparseOutput` and
        `SparseCountSparseOutput` operations
        ([CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196),
        [CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197),
        [CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198),
        [CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199),
        [CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200),
        [CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201))
    *   Fixes an integer truncation vulnerability in code using the work
        sharder API
        ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
    *   Fixes a format string vulnerability in `tf.strings.as_string`
        ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
    *   Fixes segfault raised by calling session-only ops in eager mode
        ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
    *   Fixes data leak and potential ASLR violation from
        `tf.raw_ops.StringNGrams`
        ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
    *   Fixes segfaults caused by incomplete `SavedModel` validation
        ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
    *   Fixes a data corruption due to a bug in negative indexing support in
        TFLite
        ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
    *   Fixes a data corruption due to dimension mismatch in TFLite
        ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
    *   Fixes several vulnerabilities in TFLite saved model format
        ([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
        [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
        [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
    *   Fixes several vulnerabilities in TFLite implementation of segment sum
        ([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
        [CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
        [CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
*   TF Core:
    *   `tf.types.experimental.TensorLike` is a new `Union` type that can be
        used as a type annotation for variables representing a Tensor or a
        value that can be converted to Tensor by `tf.convert_to_tensor`.
    *   Calling ops with Python constants or NumPy values is now consistent
        with `tf.convert_to_tensor` behavior. This avoids operations like
        `tf.reshape` truncating inputs such as from int64 to int32.
    *   Added `tf.sparse.map_values` to apply a function to the `.values` of
        `SparseTensor` arguments.
    *   The Python bitwise operators for `Tensor` (`__and__`, `__or__`,
        `__xor__` and `__invert__`) now support non-`bool` arguments and apply
        the corresponding bitwise ops. `bool` arguments continue to be
        supported and dispatch to logical ops. This brings them more in line
        with Python and NumPy behavior.
    *   Added `tf.SparseTensor.with_values`. This returns a new SparseTensor
        with the same sparsity pattern, but with new provided values. It is
        similar to the `with_values` function of `RaggedTensor`.
    *   Added `StatelessCase` op, and uses it if none of the case branches has
        stateful ops.
    *   Added `tf.config.experimental.get_memory_usage` to return total memory
        usage of the device.
*   `tf.data`:
    *   tf.data service:
        *   Added new `tf.data.experimental.service.register_dataset` and
            `tf.data.experimental.service.from_dataset_id` APIs to enable one
            process to register a dataset with the tf.data service, and
            another process to consume data from the dataset.
        *   Added support for dispatcher fault tolerance. To enable fault
            tolerance, configure a `work_dir` when running your dispatcher
            server and set `dispatcher_fault_tolerance=True`. The dispatcher
            will store its state to `work_dir`, so that on restart it can
            continue from its previous state.
        *   Added support for sharing dataset graphs via shared filesystem
            instead of over RPC. This reduces load on the dispatcher,
            improving performance of distributing datasets. For this to work,
            the dispatcher's `work_dir` must be accessible from workers. If
            the worker fails to read from the `work_dir`, it falls back to
            using RPC for dataset graph transfer.
        *   Added support for a new "distributed_epoch" processing mode. This
            processing mode distributes a dataset across all tf.data workers,
            instead of having each worker process the full dataset. See
            [the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode)
            to learn more.
    *   Added optional `exclude_cols` parameter to CsvDataset. This parameter
        is the complement of `select_cols`; at most one of these should be
        specified.
    *   We have implemented an optimization which reorders data-discarding
        transformations such as `take` and `shard` to happen earlier in the
        dataset when it is safe to do so. The optimization can be disabled via
        the `experimental_optimization.reorder_data_discarding_ops` dataset
        option.
    *   `tf.data.Options` were previously immutable and can now be overridden.
    *   `tf.data.Dataset.from_generator` now supports Ragged and Sparse
        tensors with a new `output_signature` argument, which allows
        `from_generator` to produce any type describable by a `tf.TypeSpec`.
    *   `tf.data.experimental.AUTOTUNE` is now available in the core API as
        `tf.data.AUTOTUNE`.
*   `tf.image`:
    *   Added deterministic `tf.image.stateless_random_*` functions for each
        `tf.image.random_*` function. Added a new op
        `stateless_sample_distorted_bounding_box` which is a deterministic
        version of the `sample_distorted_bounding_box` op. Given the same
        seed, these stateless functions/ops produce the same results
        independent of how many times the function is called, and independent
        of global seed settings. A sketch of the seeding contract follows this
        list.
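For illustration, a minimal sketch of the seeding contract, using
`stateless_random_brightness` as one of the generated stateless variants:

```
import tensorflow as tf

image = tf.zeros([2, 64, 64, 3])
seed = (1, 2)  # stateless image ops take an explicit two-element seed

a = tf.image.stateless_random_brightness(image, max_delta=0.5, seed=seed)
b = tf.image.stateless_random_brightness(image, max_delta=0.5, seed=seed)
# Same seed, same result -- independent of call count and tf.random.set_seed.
assert tf.reduce_all(a == b)
```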
*   `tf.distribute`:
    *   <ADD RELEASE NOTES HERE>
*   `tf.keras`:
    *   Improvements from the functional API refactoring:
        *   Functional model construction does not need to maintain a global
            workspace graph, removing memory leaks especially when building
            many models or very large models.
        *   Functional model construction should be ~8-10% faster on average.
        *   Functional models can now contain non-symbolic values in their call
            inputs inside of the first positional argument.
        *   Several classes of TF ops that were not reliably converted to Keras
            layers during functional API construction should now work, e.g.
            `tf.image.ssim_multiscale`.
        *   Error messages when Functional API construction goes wrong (and
            when ops cannot be converted to Keras layers automatically) should
            be clearer and easier to understand.
    *   `Optimizer.minimize` can now accept a loss `Tensor` and a
        `GradientTape` as an alternative to accepting a `callable` loss.
    *   Added `beta` hyperparameter to FTRL optimizer classes (Keras and
        others) to match the FTRL paper
        (https://research.google.com/pubs/archive/41159.pdf).
    *   Added `mobilenet_v3` to the Keras application models.
    *   `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
        customization of how gradients are aggregated across devices, as well
        as `gradients_transformers` to allow for custom gradient
        transformations (such as gradient clipping).
    *   The `steps_per_execution` argument in `compile()` is no longer
        experimental; if you were passing `experimental_steps_per_execution`,
        rename it to `steps_per_execution` in your code. This argument controls
        the number of batches to run during each `tf.function` call when
        calling `fit()`. Running multiple batches inside a single `tf.function`
        call can greatly improve performance on TPUs or small models with a
        large Python overhead. A sketch follows this section.
    *   Improvements to Keras preprocessing layers:
        *   TextVectorization can now accept a vocabulary list or file as an
            init arg.
        *   Normalization can now accept mean and variance values as init
            args.
    *   In `Attention` and `AdditiveAttention` layers, the `call()` method now
        accepts a `return_attention_scores` argument. When set to True, the
        layer returns the attention scores as an additional output argument.
    *   Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints
        with the same implementation as their `tf.losses` equivalent.
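A minimal sketch of the renamed `steps_per_execution` argument (the tiny model
here is only a placeholder):

```
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(8,))])
model.compile(
    optimizer="sgd",
    loss="mse",
    steps_per_execution=50,  # run 50 batches per tf.function call in fit()
)
```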
*   `tf.function` / AutoGraph:
    *   Added `experimental_follow_type_hints` argument for `tf.function`. When
        True, the function may use type annotations to optimize the tracing
        performance.
    *   Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
    *   AutoGraph now allows creating new symbols inside a TensorFlow loop, if
        the values of these symbols at an iteration do not depend on the
        previous iteration. These types of loops must run at least one
        iteration, and will raise a runtime error otherwise.
    Example:

@ -208,51 +241,97 @@

      outputs = train_step(batch)
    tf.print('final outputs', outputs)
    ```

    See tensorflow/python/autograph/g3doc/reference/limitations.md for more
    info.
*   `tf.lite`:
    *   `TFLiteConverter`:
        *   Support optional flags `inference_input_type` and
            `inference_output_type` for full integer quantized models. This
            allows users to modify the model input and output type to integer
            types (`tf.int8`, `tf.uint8`) instead of defaulting to float type
            (`tf.float32`). A conversion sketch follows this list.
    *   TFLite Profiler for Android is available. See the detailed
        [guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android).
    *   NNAPI
        *   Added NNAPI Delegation support for requantization use cases by
            converting the operation into a dequantize-quantize pair.
        *   Removed deprecated `Interpreter.setUseNNAPI(boolean)` Java API. Use
            `Interpreter.Options.setUseNNAPI` instead.
        *   Deprecate `Interpreter::UseNNAPI(bool)` C++ API. Use
            `NnApiDelegate()` and related delegate configuration methods
            directly.
    *   Deprecate `Interpreter::SetAllowFp16PrecisionForFp32(bool)` C++ API.
        Prefer controlling this via delegate options, e.g.
        `tflite::StatefulNnApiDelegate::Options::allow_fp16` or
        `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
    *   `DynamicBuffer::AddJoinedString()` will now add a separator if the
        first string to be joined is empty.
    *   <ADD RELEASE NOTES HERE>
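A minimal conversion sketch, assuming a SavedModel at a placeholder path and a
hypothetical representative dataset with an assumed input shape:

```
import tensorflow as tf

def representative_dataset():
    for _ in range(100):
        yield [tf.random.normal([1, 8])]  # assumed input shape

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/model")  # assumed
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# The new optional flags: make the model interface integer as well.
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
```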
*   `tf.random`:
    *   <ADD RELEASE NOTES HERE>
*   Math and Linear Algebra:
    *   <ADD RELEASE NOTES HERE>
*   TPU Enhancements:
    *   Added support for the `beta` parameter of the FTRL optimizer for TPU
        embeddings. Users of other TensorFlow platforms can implement
        equivalent behavior by adjusting the `l2` parameter.
    *   <ADD RELEASE NOTES HERE>
*   XLA Support:
    *   xla.experimental.compile is deprecated, use
        `tf.function(experimental_compile=True)` instead (see the sketch after
        this list).
    *   Added `tf.function.experimental_get_compiler_ir` which returns compiler
        IR (currently 'hlo' and 'optimized_hlo') for a given input for a given
        function.
    *   <ADD RELEASE NOTES HERE>
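A minimal sketch of the replacement:

```
import tensorflow as tf

@tf.function(experimental_compile=True)  # replaces xla.experimental.compile
def f(x):
    return tf.math.square(x) + 1.0

print(f(tf.constant(2.0)))  # the first call triggers XLA compilation
```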
*   Tracing and Debugging:
    *   <ADD RELEASE NOTES HERE>
*   `tf.train.Checkpoint`:
    *   Now accepts a `root` argument in the initialization, which generates a
        checkpoint with a root object. This allows users to create a
        `Checkpoint` object that is compatible with Keras
        `model.save_weights()` and `model.load_weights`. The checkpoint is also
        compatible with the checkpoint saved in the `variables/` folder in the
        SavedModel. A sketch follows this list.
    *   When restoring, `save_path` can be a path to a SavedModel. The function
        will automatically find the checkpoint in the SavedModel.
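A minimal sketch of the `root` argument, assuming a scratch path:

```
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
# With `root`, the checkpoint layout matches model.save_weights().
ckpt = tf.train.Checkpoint(root=model)
path = ckpt.write("/tmp/ckpt_demo")  # assumed scratch location
model.load_weights(path)
```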
*   `tf.nn`:
    *   `tf.nn.max_pool2d` now supports explicit padding (see the sketch
        below).
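A minimal sketch of explicit padding:

```
import tensorflow as tf

x = tf.random.normal([1, 28, 28, 3])
# One (before, after) pad pair per NHWC dimension, instead of "SAME"/"VALID".
y = tf.nn.max_pool2d(
    x, ksize=2, strides=2, padding=[[0, 0], [1, 1], [1, 1], [0, 0]])
print(y.shape)  # (1, 15, 15, 3)
```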
*   `tf.debugging`:
    *   `tf.debugging.assert_shapes()` now works on `SparseTensor`s (#36268);
        see the sketch below.
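A minimal sketch, mixing a `SparseTensor` with a dense tensor in the same
assertion:

```
import tensorflow as tf

st = tf.sparse.SparseTensor(
    indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=[3, 4])
dense = tf.zeros([3, 7])
# Both share the leading dimension "N"; the shapes are checked together.
tf.debugging.assert_shapes([(st, ("N", "D")), (dense, ("N", "E"))])
```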
*   `tf.print`:
    *   Fixed a bug in `tf.print()` with `OrderedDict`: when the keys of an
        `OrderedDict` were not sorted, keys and values could be printed with a
        mismatched mapping.
*   Other:
    *   We have replaced uses of "whitelist" and "blacklist" with "allowlist"
        and "denylist" where possible. Please see
        https://developers.google.com/style/word-list#blacklist for more
        context.
    *   Add `tf.config.experimental.mlir_bridge_rollout` which will help us
        roll out the new MLIR TPU bridge.
    *   <ADD RELEASE NOTES HERE>
## Thanks to our Contributors
@ -500,42 +579,87 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
# Release 2.3.0

## Major Features and Improvements
*   `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and
    save resources:
    *   [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot)
    *   [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service).

    In addition, check out the detailed
    [guide](https://www.tensorflow.org/guide/data_performance_analysis) for
    analyzing input pipeline performance with TF Profiler.

*   [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy)
    is now a stable API and no longer considered experimental for TensorFlow
    (earlier `tf.distribute.experimental.TPUStrategy`).

*   [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new
    tools: a memory profiler to visualize your model's memory usage over time
    and a [python tracer](https://www.tensorflow.org/guide/profiler#events)
    which allows you to trace python function calls in your model. Usability
    improvements include better diagnostic messages and
    [profile options](https://tensorflow.org/guide/profiler#collect_performance_data)
    to customize the host and device trace verbosity level.

*   Introduces experimental support for Keras Preprocessing Layers API
    ([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly))
    to handle data preprocessing operations, with support for composite tensor
    inputs. Please see below for additional details on these layers.

*   TFLite now properly supports dynamic shapes during conversion and
    inference. We've also added opt-in support on Android and iOS for
    [XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack),
    a highly optimized set of CPU kernels, as well as opt-in support for
    [executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental).

*   Libtensorflow packages are available in GCS starting this release. We have
    also started to
    [release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds).

*   The experimental Python API
    [`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info)
    now allows you to instrument a TensorFlow program and dump debugging
    information to a directory on the file system. The directory can be read
    and visualized by a new interactive dashboard in TensorBoard 2.3 called
    [Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which
    reveals the details of the TensorFlow program including graph structures,
    history of op executions at the Python (eager) and intra-graph levels, the
    runtime dtype, shape, and numerical composition of tensors, as well as
    their code locations.
## Breaking Changes

*   Increases the **minimum bazel version** required to build TF to **3.1.0**.
*   `tf.data`
    *   Makes the following (breaking) changes to `tf.data`:
        *   C++ API: `IteratorBase::RestoreInternal`,
            `IteratorBase::SaveInternal`, and
            `DatasetBase::CheckExternalState` become pure-virtual and
            subclasses are now expected to provide an implementation.
    *   The deprecated `DatasetBase::IsStateful` method is removed in favor of
        `DatasetBase::CheckExternalState`.
    *   Deprecated overrides of `DatasetBase::MakeIterator` and
        `MakeIteratorFromInputElement` are removed.
    *   The signature of `tensorflow::data::IteratorBase::SaveInternal` and
        `tensorflow::data::IteratorBase::SaveInput` has been extended with a
        `SerializationContext` argument to enable overriding the default
        policy for handling external state during iterator checkpointing. This
        is not a backwards compatible change and all subclasses of
        `IteratorBase` *need to be updated* accordingly.
*   `tf.keras`
    *   Add a new `BackupAndRestore` callback for handling distributed training
        failures & restarts. Please take a look at this
        [tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
        for details on how to use the callback; a minimal sketch follows below.
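A minimal sketch of the callback, assuming a scratch backup directory and a toy
in-memory dataset:

```
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="sgd", loss="mse")
ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([32, 4]), tf.random.normal([32, 1]))).batch(8)

# If training is interrupted, fit() resumes from the last completed epoch.
backup = tf.keras.callbacks.experimental.BackupAndRestore(
    backup_dir="/tmp/backup")  # assumed scratch directory
model.fit(ds, epochs=5, callbacks=[backup])
```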
*   `tf.image.extract_glimpse` has been updated to correctly process the case
    where `centered=False` and `normalized=False`. This is a breaking change as
    the output is different from (incorrect) previous versions. Note this
    breaking change only impacts `tf.image.extract_glimpse` and
    `tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
    `tf.compat.v1.image.extract_glimpse` does not change. The behavior of the
    existing C++ kernel `ExtractGlimpse` does not change either, so saved
    models using `tf.raw_ops.ExtractGlimpse` will not be impacted.
## Known Caveats

*   `tf.lite`
@ -1525,8 +1649,8 @@ If you experience any snags when using TF 2.0, please let us know at the [TF 2.0
      conversion. TensorRT initialization arguments are now passed wrapped in
      a named-tuple, `TrtConversionParams`, rather than as separate arguments
      as in `TrtGraphConverter`.
    * Changed API to optimize TensorRT engines during graph optimization. This
      is now done by calling `converter.build()` where previously
      `is_dynamic_op=False` would be set.
    * `converter.convert()` no longer returns a `tf.function`. Now the
      function must be accessed from the saved model.


@ -1485,6 +1485,7 @@ def main():
'adding "--config=<>" to your build command. See .bazelrc for more ' 'adding "--config=<>" to your build command. See .bazelrc for more '
'details.') 'details.')
config_info_line('mkl', 'Build with MKL support.') config_info_line('mkl', 'Build with MKL support.')
config_info_line('mkl_aarch64', 'Build with oneDNN support for Aarch64.')
config_info_line('monolithic', 'Config for mostly static monolithic build.') config_info_line('monolithic', 'Config for mostly static monolithic build.')
config_info_line('ngraph', 'Build with Intel nGraph support.') config_info_line('ngraph', 'Build with Intel nGraph support.')
config_info_line('numa', 'Build with NUMA support.') config_info_line('numa', 'Build with NUMA support.')


@ -568,17 +568,7 @@ selects.config_setting_group(
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
package_group(
    name = "internal",
-    packages = [
-        "//learning/brain/distribute/...",
-        "//learning/brain/swift/x10/...",
-        "//perftools/accelerators/xprof/api/...",
-        "//tensorflow/...",
-        "//tensorflow_estimator/python/estimator/...",
-        "//tensorflow_models/official/...",
-        "//third_party/py/autograph/...",
-        "//third_party/swift/tensorflow/x10/...",
-        "//third_party/swift/tensorflow_apis/...",
-    ],
+    packages = ["//tensorflow/..."],
)
package_group(
@ -588,10 +578,8 @@ package_group(
# Packages that use private types symbols, until they are exported.
# TODO(b/154650521) Remove.
-package_group(
-    name = "types_whitelist",
-    packages = ["//learning/deepmind/tensorflow/replicator/..."],
-)
+# If this is modified, then copy.bara.sky must also be modified.
+package_group(name = "types_whitelist")

# Packages that use StructuredTensors.
# TODO(b/159007891) Remove this package once StructuredTensor is exported.
@ -719,7 +707,7 @@ tf_cc_shared_object(
    deps = [
        "//tensorflow/c/experimental/filesystem:filesystem_interface",
        "//tensorflow/cc/saved_model:loader_lite_impl",
-        "//tensorflow/core:core_cpu_impl",
+        "//tensorflow/core/common_runtime:core_cpu_impl",
        "//tensorflow/core:framework_internal_impl",
        "//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",


@ -217,6 +217,8 @@ tf_cuda_library(
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/kernels:logging_ops", "//tensorflow/core/kernels:logging_ops",
"//tensorflow/compiler/mlir/tfr:node_expansion_pass",
"//tensorflow/compiler/mlir/tfr:graph_decompose_pass",
], ],
}), }),
alwayslink = 1, alwayslink = 1,
@ -254,6 +256,30 @@ tf_cuda_library(
    }),
)

cc_library(
    name = "tf_shape",
    srcs = ["tf_shape.cc"],
    hdrs = ["tf_shape.h"],
    copts = tf_copts(),
    visibility = ["//visibility:public"],
    deps = [
        ":c_api_macros",
        ":tf_shape_internal",
        "//tensorflow/core:framework",
    ],
)

cc_library(
    name = "tf_shape_internal",
    hdrs = ["tf_shape_internal.h"],
    copts = tf_copts(),
    visibility = ["//tensorflow:internal"],
    deps = [
        ":conversion_macros",
        "//tensorflow/core:framework",
    ],
)
cc_library(
    name = "tf_status",
    srcs = ["tf_status.cc"],


@ -2488,6 +2488,48 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
  return ret;
}
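// Reconnects the `dst` input to read from output `new_src`, after checking
// with the shape refiner that the producer and consumer shapes are compatible.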
void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
                   TF_Status* status) {
  using tensorflow::RecordMutation;
  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(&new_src.oper->node);

  if (ic->num_outputs() <= new_src.index) {
    status->status = tensorflow::errors::OutOfRange(
        "Cannot update edge. Output index [", new_src.index,
        "] is greater than the number of total outputs [", ic->num_outputs(),
        "].");
    return;
  }
  tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);

  tensorflow::shape_inference::InferenceContext* ic_dst =
      graph->refiner.GetContext(&dst.oper->node);
  if (ic_dst->num_inputs() <= dst.index) {
    status->status = tensorflow::errors::OutOfRange(
        "Cannot update edge. Input index [", dst.index,
        "] is greater than the number of total inputs [", ic_dst->num_inputs(),
        "].");
    return;
  }
  if (!ic_dst->MergeInput(dst.index, shape)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
        " and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
    return;
  }
  status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
                                           &dst.oper->node, dst.index);

  if (TF_GetCode(status) == TF_OK) {
    // This modification only updates the destination node for
    // the purposes of running this graph in a session. Thus, we don't
    // record the source node as being modified.
    RecordMutation(graph, *dst.oper, "updating input tensor");
  }
}
// TF_Server functions ----------------------------------------------

#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)


@ -1524,6 +1524,10 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status);
TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp(
    const char* name, TF_Status* status);
// Update an edge: reconnect the input `dst` to read from the output `new_src`.
TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src,
TF_Input dst, TF_Status* status);
// --------------------------------------------------------------------------
// In-process TensorFlow server functionality, for use in distributed training.
// A Server instance encapsulates a set of devices and a Session target that


@ -634,6 +634,40 @@ TEST(CAPI, Graph) {
  TF_DeleteStatus(s);
}
TEST(CAPI, UpdateEdge) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Make two scalar constants.
  TF_Operation* one = ScalarConst(1, graph, s, "one");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* two = ScalarConst(2, graph, s, "two");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Add oper.
  TF_Operation* add = Add(one, two, graph, s, "add");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Add another oper to the graph.
  TF_Operation* neg = Neg(add, graph, s, "neg");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  NodeDef node_def_neg;
  ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
  EXPECT_EQ(string("add"), node_def_neg.input(0));

  // update edge of neg
  TF_UpdateEdge(graph, TF_Output{one, 0}, TF_Input{neg, 0}, s);

  ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
  EXPECT_EQ(string("one:0"), node_def_neg.input(0));

  // Clean up
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
/*
TODO(skyewm): this test currently DCHECKs, change to bad status


@ -3,7 +3,7 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load( load(
"//tensorflow:tensorflow.bzl", "//tensorflow:tensorflow.bzl",
"if_tpu", "if_libtpu",
"tf_cc_test", "tf_cc_test",
"tf_copts", "tf_copts",
"tf_cuda_cc_test", "tf_cuda_cc_test",
@ -116,7 +116,6 @@ filegroup(
"immediate_execution_context.h", "immediate_execution_context.h",
"immediate_execution_operation.h", "immediate_execution_operation.h",
"immediate_execution_tensor_handle.h", "immediate_execution_tensor_handle.h",
"mnist_gradients_testutil.h",
"tape.h", "tape.h",
"tfe_cancellation_manager_internal.h", "tfe_cancellation_manager_internal.h",
"tfe_context_internal.h", "tfe_context_internal.h",
@ -290,7 +289,7 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti", "//tensorflow/core/lib/llvm_rtti",
] + if_tpu( ] + if_libtpu(
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"], if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
if_true = [], if_true = [],
), ),
@ -314,6 +313,7 @@ cc_library(
":gradients_internal", ":gradients_internal",
":gradients_util", ":gradients_util",
":tape", ":tape",
"//tensorflow/c/experimental/gradients/tape:tape_context",
"//tensorflow/c/experimental/ops:array_ops", "//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops", "//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops", "//tensorflow/c/experimental/ops:nn_ops",
@ -354,7 +354,7 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti", "//tensorflow/core/lib/llvm_rtti",
] + if_tpu( ] + if_libtpu(
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"], if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
if_true = [], if_true = [],
), ),


@ -39,7 +39,7 @@ limitations under the License.
#include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/c/tf_tensor_internal.h"
#if defined(PLATFORM_GOOGLE) && !defined(LIBTFTPU) #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h" #include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
#endif #endif
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device.h"
@ -729,7 +729,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  if (opts->use_tfrt) {
-#if defined(PLATFORM_GOOGLE) && !defined(LIBTFTPU)
+#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
    return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
#else
    status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
@ -904,9 +904,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
    TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  context->SetThreadLocalDevicePlacementPolicy(
+  tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
      static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}
@ -915,10 +913,8 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
// safe to call this function from the async EagerExecutor threads.
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
    TFE_Context* ctx) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  return static_cast<TFE_ContextDevicePlacementPolicy>(
-      context->GetDevicePlacementPolicy());
+      tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
}

TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
@ -1429,21 +1425,15 @@ void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
}

unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  return context->FindFunctionDef(name) != nullptr;
+  return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
}

void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  context->SetShouldStoreGraphs(true);
+  tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
}

void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  context->SetShouldStoreGraphs(false);
+  tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
}

}  // extern "C"


@ -74,7 +74,7 @@ typedef enum TFE_ContextDevicePlacementPolicy {
// Placement policy which silently copies int32 tensors but not other dtypes.
  TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
} TFE_ContextDevicePlacementPolicy;
-// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)
+// LINT.ThenChange(//tensorflow/c/eager/immediate_execution_context.h)

// Sets the default execution mode (sync/async). Note that this can be
// overridden per thread using TFE_ContextSetExecutorForThread.


@ -545,7 +545,9 @@ TEST(CAPI, DistributedFunctionNoError) {
  TestDistributedFunctionCancellation(false);
}

-TEST(CAPI, DistributedFunctionCancelledOnError) {
+// TODO(b/170399182): Update test once an alternative to using the function
+// optimization hook is in place.
+TEST(CAPI, DISABLED_DistributedFunctionCancelledOnError) {
  TestDistributedFunctionCancellation(true);
}


@ -49,15 +49,11 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
}

void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  context->SetShouldStoreGraphs(true);
+  tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
}

void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  context->SetShouldStoreGraphs(false);
+  tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
}

uint64_t TFE_GetContextId(TFE_Context* ctx) {
@ -544,22 +540,16 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
}

void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  context->SetExecutorForThread(executor->executor());
+  tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor());
}

TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  return new TFE_Executor(&context->Executor());
+  return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor());
}

void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
-      context->HostCPU()->parsed_name());
+      tensorflow::unwrap(ctx)->HostCPUParsedName());
  auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
  void* data = tensorflow::port::Malloc(str.length());
  str.copy(static_cast<char*>(data), str.length(), 0);
@ -572,9 +562,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
TF_Buffer* buf, TF_Status* status) { TF_Buffer* buf, TF_Status* status) {
tensorflow::EagerContext* context = auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto* function_def = context->FindFunctionDef(function_name);
if (function_def == nullptr) { if (function_def == nullptr) {
status->status = tensorflow::errors::NotFound( status->status = tensorflow::errors::NotFound(
"Unable to find FunctionDef with name: ", function_name); "Unable to find FunctionDef with name: ", function_name);
@ -643,14 +631,10 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable, void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) { TF_Status* status) {
tensorflow::EagerContext* context = tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetAllowSoftPlacement(enable);
} }
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable, void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) { TF_Status* status) {
tensorflow::EagerContext* context = tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetLogDevicePlacement(enable);
} }

View File

@@ -191,7 +191,7 @@ Status TapeVSpace::CallBackwardFunction(
       &ctx, incoming_gradients, result);
 }

-Status TapeVSpace::BuildOnesLike(TapeTensor t,
+Status TapeVSpace::BuildOnesLike(const TapeTensor& t,
                                  AbstractTensorHandle** result) const {
   AbstractOperationPtr op(ctx_->CreateOperation());
   TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));

View File

@@ -180,10 +180,6 @@ int64 ToId(AbstractTensorHandle* t);
 // allow us to trace the data dependencies between operations and hence compute
 // gradients.
 //
-// This also implements `OnesLike` to create the default
-// incoming gradients for tensors which do not already have an incoming
-// gradient.
-//
 // `ZerosLike` is not expected to be called and returns a nullptr. The creation
 // of default zeros grads is handled by the `DefaultGradientFunction` registered
 // for each op.
@@ -233,7 +229,7 @@ class TapeVSpace
                    std::vector<AbstractTensorHandle*>* result) const override;

   // Builds a tensor filled with ones with the same shape and dtype as `t`.
-  Status BuildOnesLike(TapeTensor t,
+  Status BuildOnesLike(const TapeTensor& t,
                        AbstractTensorHandle** result) const override;

   // Looks up the ID of a Gradient.

View File

@@ -61,6 +61,7 @@ Status RegisterGradients(GradientRegistry* registry) {
   TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
   TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
   TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
+  TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
   return Status::OK();
 }
@@ -131,6 +132,37 @@ Status ExpGradModel(AbstractContext* ctx,
   return Status::OK();
 }

+// Computes
+// y = sqrt(inputs[0])
+// return grad(y, {inputs[0]})
+Status SqrtGradModel(AbstractContext* ctx,
+                     absl::Span<AbstractTensorHandle* const> inputs,
+                     absl::Span<AbstractTensorHandle*> outputs,
+                     const GradientRegistry& registry) {
+  TapeVSpace vspace(ctx);
+  auto tape = new Tape(/*persistent=*/false);
+  tape->Watch(ToId(inputs[0]));  // Watch x.
+  std::vector<AbstractTensorHandle*> sqrt_outputs(1);
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(
+      ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
+  std::unordered_map<tensorflow::int64, TapeTensor>
+      source_tensors_that_are_targets;
+  std::vector<AbstractTensorHandle*> out_grads;
+  TF_RETURN_IF_ERROR(tape->ComputeGradient(
+      vspace, /*target_tensor_ids=*/{ToId(sqrt_outputs[0])},
+      /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
+      /*output_gradients=*/{}, &out_grads,
+      /*build_default_zeros_grads=*/false));
+  for (auto sqrt_output : sqrt_outputs) {
+    sqrt_output->Unref();
+  }
+  outputs[0] = out_grads[0];
+  delete tape;
+  return Status::OK();
+}
+
 // Computes
 // ignored, y = IdentityN(inputs[0], inputs[1])
 // return grad(y, {inputs[0], inputs[1]})
@@ -401,6 +433,50 @@ TEST_P(CppGradients, TestExpGrad) {
   result_tensor = nullptr;
 }

+TEST_P(CppGradients, TestSqrtGrad) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  AbstractContextPtr ctx;
+  {
+    AbstractContext* ctx_raw = nullptr;
+    Status s =
+        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    ctx.reset(ctx_raw);
+  }
+
+  AbstractTensorHandlePtr x;
+  {
+    AbstractTensorHandle* x_raw = nullptr;
+    Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    x.reset(x_raw);
+  }
+
+  GradientRegistry registry;
+  Status s = RegisterGradients(&registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  // Pseudo-code:
+  //
+  // tape.watch(x)
+  // y = sqrt(x)
+  // outputs = tape.gradient(y, x)
+  std::vector<AbstractTensorHandle*> outputs(1);
+  s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
+               /*use_function=*/!std::get<2>(GetParam()), registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  TF_Tensor* result_tensor;
+  s = getValue(outputs[0], &result_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+  auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
+  EXPECT_NEAR(*result_value, 0.5, 0.001);
+
+  outputs[0]->Unref();
+  TF_DeleteTensor(result_tensor);
+  result_tensor = nullptr;
+}
+
 TEST_P(CppGradients, TestIdentityNGrad) {
   // Pseudo-code:
   //
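The expected value of 0.5 in the assertion is just the derivative of the square root evaluated at the test input x = 1:

```latex
\frac{d}{dx}\sqrt{x} = \frac{1}{2\sqrt{x}}
\qquad\Longrightarrow\qquad
\left.\frac{d}{dx}\sqrt{x}\right|_{x=1} = \tfrac{1}{2} = 0.5
```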

View File

@@ -29,8 +29,25 @@ limitations under the License.
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/tstring.h"
+#include "tensorflow/core/util/device_name_utils.h"

 namespace tensorflow {
+class EagerExecutor;
+
+// LINT.IfChange
+// Note: Keep in sync with exported copy of enum in eager/c_api.h.
+enum ContextDevicePlacementPolicy {
+  // Running operations with input tensors on the wrong device will fail.
+  DEVICE_PLACEMENT_EXPLICIT = 0,
+  // Copy the tensor to the right device but log a warning.
+  DEVICE_PLACEMENT_WARN = 1,
+  // Silently copy the tensor, which has a performance cost since the operation
+  // will be blocked till the copy completes. This is the default policy.
+  DEVICE_PLACEMENT_SILENT = 2,
+  // Placement policy which silently copies int32 tensors but not other dtypes.
+  DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
+};
+// LINT.ThenChange(//tensorflow/c/eager/c_api.h)

 // Abstract interface to a context.
 //
@@ -81,14 +98,6 @@ class ImmediateExecutionContext : public AbstractContext {
   // List attributes of available devices
   virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;

-  virtual void ClearCachesAndThreadExecutors() = 0;
-
-  // Initialize the step resource container for a training step. This is used
-  // in current TF runtime. For tfrt, it is used by fallback op handler.
-  virtual void StartStep() = 0;
-  // Destroy the step resource container for a training step.
-  virtual void EndStep() = 0;
-
   // Block until all pending nodes are finished.
   virtual Status AsyncWait() = 0;
@@ -97,11 +106,52 @@ class ImmediateExecutionContext : public AbstractContext {
   // already exists.
   virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;

+  // Find and return an added function by its name.
+  virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;
+
+  // Return the ParsedName of the host CPU device.
+  virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
+
+  // Configure soft device placement policy.
+  virtual void SetAllowSoftPlacement(bool enable) = 0;
+
+  // Configure device placement policy logging.
+  virtual void SetLogDevicePlacement(bool enable) = 0;
+
+  // Sets the device placement policy for the current thread.
+  virtual void SetThreadLocalDevicePlacementPolicy(
+      ContextDevicePlacementPolicy policy) = 0;
+  // Returns the device placement policy for the current thread.
+  virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
+
   // For LLVM style RTTI.
   static bool classof(const AbstractContext* ptr) {
     return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
   }

+  //===--------------------------------------------------------------------===//
+  // Following are legacy features in TF Eager Runtime.
+  // TODO(tf-runtime): Figure out a way to deprecate the following features
+  // after migrating to TFRT.
+  //===--------------------------------------------------------------------===//
+
+  // Clear pending nodes in thread executors and kernel caches.
+  virtual void ClearCachesAndThreadExecutors() = 0;
+
+  // Initialize the step resource container for a training step. This is used
+  // in current TF runtime. For tfrt, it is used by fallback op handler.
+  virtual void StartStep() = 0;
+  // Destroy the step resource container for a training step.
+  virtual void EndStep() = 0;
+
+  // Return the Eager Executor for the current thread. Please note that the
+  // Eager Executor is only used in current TF but not in TFRT.
+  virtual EagerExecutor& Executor() = 0;
+  // Update the Eager Executor for the current thread.
+  virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
+
+  // Configure graph collection in RunMetadata.
+  virtual void SetShouldStoreGraphs(bool value) = 0;
+
  protected:
   explicit ImmediateExecutionContext(AbstractContextKind kind)
       : AbstractContext(kind) {}
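The classof() member above is the hook for LLVM-style RTTI: isa<> and dyn_cast<> consult it instead of relying on C++ RTTI. A self-contained sketch of the idiom, with illustrative types and a hand-rolled dyn_cast standing in for llvm::dyn_cast:

```cpp
#include <cassert>

// Each node carries an explicit kind tag set at construction time.
enum class Kind { kEager, kTfrt, kOther };

struct AbstractContext {
  explicit AbstractContext(Kind k) : kind(k) {}
  Kind getKind() const { return kind; }
  Kind kind;
};

struct ImmediateExecutionContext : AbstractContext {
  using AbstractContext::AbstractContext;
  // An object "is" an ImmediateExecutionContext iff its kind says so.
  static bool classof(const AbstractContext* ptr) {
    return ptr->getKind() == Kind::kEager || ptr->getKind() == Kind::kTfrt;
  }
};

// dyn_cast built on classof(), as the LLVM helpers would do.
template <typename To, typename From>
To* dyn_cast(From* p) {
  return To::classof(p) ? static_cast<To*>(p) : nullptr;
}

int main() {
  ImmediateExecutionContext eager(Kind::kEager);
  AbstractContext* base = &eager;
  assert(dyn_cast<ImmediateExecutionContext>(base) != nullptr);
}
```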

View File

@@ -25,133 +25,18 @@ limitations under the License.
 #include "tensorflow/c/eager/gradients.h"
 #include "tensorflow/c/eager/gradients_internal.h"
 #include "tensorflow/c/eager/gradients_util.h"
+#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
 #include "tensorflow/c/experimental/ops/array_ops.h"
 #include "tensorflow/c/experimental/ops/math_ops.h"
 #include "tensorflow/c/experimental/ops/nn_ops.h"
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"

-// ========================== Tape Ops ==============================
 namespace tensorflow {
 namespace gradients {
 namespace internal {

 using std::vector;
-using tensorflow::tracing::TracingOperation;
-
-// Computes `inputs[0] + inputs[1]` and records it on the tape.
-Status Add(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs,
-           const GradientRegistry& registry) {
-  AbstractOperationPtr add_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(
-      Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(add_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
-  }
-  TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
-  int num_retvals = 1;
-  return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
-Status MatMul(AbstractContext* ctx, Tape* tape,
-              absl::Span<AbstractTensorHandle* const> inputs,
-              absl::Span<AbstractTensorHandle*> outputs, const char* name,
-              bool transpose_a, bool transpose_b,
-              const GradientRegistry& registry) {
-  AbstractOperationPtr matmul_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul",
-                           /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(matmul_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(matmul_op.get())->SetOpName(name));
-  }
-  TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op));
-  TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
-      matmul_op.get(), "transpose_a", transpose_a, &forward_op));
-  TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
-      matmul_op.get(), "transpose_b", transpose_b, &forward_op));
-  int num_retvals = 1;
-  return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-Status Mul(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs, const char* name,
-           const GradientRegistry& registry) {
-  AbstractOperationPtr mul_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(
-      Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(mul_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
-  }
-  TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op));
-  int num_retvals = 1;
-  return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-// Computes `Relu(inputs[0])` and records it on the tape.
-Status Relu(AbstractContext* ctx, Tape* tape,
-            absl::Span<AbstractTensorHandle* const> inputs,
-            absl::Span<AbstractTensorHandle*> outputs, const char* name,
-            const GradientRegistry& registry) {
-  AbstractOperationPtr relu_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(
-      Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(relu_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(relu_op.get())->SetOpName(name));
-  }
-  TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op));
-  int num_retvals = 1;
-  return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-// Computes `SoftmaxLoss(scores, labels)` where labels are categorical (not
-// one-hot) and records it on the tape.
-Status SparseSoftmaxCrossEntropyWithLogits(
-    AbstractContext* ctx, Tape* tape,
-    absl::Span<AbstractTensorHandle* const> inputs,
-    absl::Span<AbstractTensorHandle*> outputs, const char* name,
-    const GradientRegistry& registry) {
-  AbstractTensorHandle* scores = inputs[0];
-  AbstractTensorHandle* labels = inputs[1];
-  AbstractOperationPtr sm_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits",
-                           /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(sm_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(sm_op.get())->SetOpName(name));
-  }
-  TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op));
-  int num_retvals = 2;  // returns loss values and backprop
-  return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}

 //===================== Test Models to run =========================
@@ -167,8 +52,9 @@ Status AddGradModel(AbstractContext* ctx,
   tape->Watch(ToId(inputs[0]));  // Watch x.
   tape->Watch(ToId(inputs[1]));  // Watch y.
   std::vector<AbstractTensorHandle*> add_outputs(1);
-  TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
-                         registry));  // Compute x+y.
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(
+      ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));

   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
@@ -200,9 +86,11 @@ Status MatMulGradModel(AbstractContext* ctx,
   tape->Watch(ToId(inputs[0]));  // Watch x.
   tape->Watch(ToId(inputs[1]));  // Watch y.
   vector<AbstractTensorHandle*> mm_outputs(1);
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute x*y.
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs,
+                                 absl::MakeSpan(mm_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute x*y.

   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
@@ -256,25 +144,27 @@ Status MNISTForwardModel(AbstractContext* ctx,
   tape->Watch(ToId(W2));  // Watch W2.
   vector<AbstractTensorHandle*> temp_outputs(1);

-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute X*W1

-  TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]},
-                          absl::MakeSpan(temp_outputs), "relu",
-                          registry));  // Compute Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {temp_outputs[0]},
+                               absl::MakeSpan(temp_outputs),
+                               "relu"));  // Compute Relu(X*W1)

-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2},
-                            absl::MakeSpan(temp_outputs), "matmul1",
-                            /*transpose_a=*/false, /*transpose_b=*/false,
-                            registry));  // Compute W2*Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::MatMul(
+      tape_ctx.get(), {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs),
+      "matmul1",
+      /*transpose_a=*/false, /*transpose_b=*/false));  // Compute W2*Relu(X*W1)

   AbstractTensorHandle* scores = temp_outputs[0];

   temp_outputs.resize(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
-      "softmax_loss", registry));  // Compute Softmax(Scores,labels)
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
+      "softmax_loss"));  // Compute Softmax(Scores,labels)

   AbstractTensorHandle* loss_vals = temp_outputs[0];
@@ -297,9 +187,11 @@ Status MatMulTransposeModel(AbstractContext* ctx,
   tape->Watch(ToId(W1));
   vector<AbstractTensorHandle*> temp_outputs(1);

-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/true,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/true,
+                                 /*transpose_b=*/false));  // Compute X*W1

   outputs[0] = temp_outputs[0];
@@ -315,8 +207,10 @@ Status ReluGradModel(AbstractContext* ctx,
   auto tape = new Tape(/*persistent=*/false);
   tape->Watch(ToId(inputs[0]));  // Watch X
   vector<AbstractTensorHandle*> relu_outputs(1);
-  TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs),
-                          "relu0", registry));  // Relu(X)
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
+                               absl::MakeSpan(relu_outputs),
+                               "relu0"));  // Relu(X)

   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
@@ -346,8 +240,9 @@ Status SoftmaxLossGradModel(AbstractContext* ctx,
   tape->Watch(ToId(inputs[0]));  // Watch scores.
   tape->Watch(ToId(inputs[1]));  // Watch labels.
   vector<AbstractTensorHandle*> sm_outputs(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry));
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0"));

   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
@@ -381,29 +276,30 @@ Status MNISTGradModel(AbstractContext* ctx,
   tape->Watch(ToId(W1));  // Watch W1.
   tape->Watch(ToId(W2));  // Watch W2.
   vector<AbstractTensorHandle*> temp_outputs(1);
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute X*W1

   AbstractTensorHandle* mm = temp_outputs[0];

-  TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm},
-                          absl::MakeSpan(temp_outputs),  // Relu(X*W1)
-                          "relu0", registry));
+  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm},
+                               absl::MakeSpan(temp_outputs),  // Relu(X*W1)
+                               "relu0"));

   AbstractTensorHandle* hidden = temp_outputs[0];

-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2},
-                            absl::MakeSpan(temp_outputs), "matmul1",
-                            /*transpose_a=*/false, /*transpose_b=*/false,
-                            registry));  // W2*Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::MatMul(
+      tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1",
+      /*transpose_a=*/false, /*transpose_b=*/false));  // W2*Relu(X*W1)

   AbstractTensorHandle* scores = temp_outputs[0];

   temp_outputs.resize(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
-      "softmaxloss", registry));  // W2*Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
+      "softmaxloss"));  // Compute softmax(scores, y_labels)

   AbstractTensorHandle* loss = temp_outputs[0];
@@ -440,8 +336,10 @@ Status ScalarMulModel(AbstractContext* ctx,
   auto tape = new Tape(/*persistent=*/false);
   vector<AbstractTensorHandle*> temp_outputs(1);

-  TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs),
-                         "scalarMul0", registry));  // Compute eta*A
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {eta, A},
+                              absl::MakeSpan(temp_outputs),
+                              "scalarMul0"));  // Compute eta*A

   outputs[0] = temp_outputs[0];
@@ -459,9 +357,11 @@ Status MatMulModel(AbstractContext* ctx,
   TapeVSpace vspace(ctx);
   auto tape = new Tape(/*persistent=*/false);
   std::vector<AbstractTensorHandle*> temp_outputs(1);
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute X*W1

   outputs[0] = temp_outputs[0];
   delete tape;
@@ -478,8 +378,10 @@ Status MulModel(AbstractContext* ctx,
   TapeVSpace vspace(ctx);
   auto tape = new Tape(/*persistent=*/false);
   std::vector<AbstractTensorHandle*> temp_outputs(1);
-  TF_RETURN_IF_ERROR(Mul(ctx, tape, {x, y}, absl::MakeSpan(temp_outputs),
-                         "mul0", registry));  // Compute x*y
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {x, y},
+                              absl::MakeSpan(temp_outputs),
+                              "mul0"));  // Compute x*y

   outputs[0] = temp_outputs[0];
   delete tape;
@@ -496,9 +398,9 @@ Status SoftmaxModel(AbstractContext* ctx,
   TapeVSpace vspace(ctx);
   auto tape = new Tape(/*persistent=*/false);
   std::vector<AbstractTensorHandle*> temp_outputs(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss",
-      registry));
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss"));

   outputs[0] = temp_outputs[0];  // loss values
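All of these models now share one idea: instead of bespoke per-op helpers that thread (tape, registry) through every call, a TapeContext wraps the real context so that the ordinary ops:: builders get recorded on the tape as a side effect. A stripped-down sketch of that decorator shape (illustrative names, not the real TapeContext API):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Toy stand-ins for the real abstractions.
struct Op { std::string name; };

struct Context {
  virtual ~Context() = default;
  virtual Op CreateOperation(const std::string& name) = 0;
};

struct PlainContext : Context {
  Op CreateOperation(const std::string& name) override { return Op{name}; }
};

// Wraps another context and records every operation it hands out, so
// callers can use ordinary op builders and still get tape recording.
struct TapeContext : Context {
  Context* inner;
  std::vector<std::string>* tape;
  TapeContext(Context* c, std::vector<std::string>* t) : inner(c), tape(t) {}
  Op CreateOperation(const std::string& name) override {
    tape->push_back(name);                // record on the "tape"
    return inner->CreateOperation(name);  // then delegate
  }
};

int main() {
  PlainContext base;
  std::vector<std::string> tape;
  TapeContext ctx(&base, &tape);
  ctx.CreateOperation("MatMul");  // recorded transparently
  ctx.CreateOperation("Relu");
  for (const auto& n : tape) std::cout << n << "\n";
}
```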

View File

@@ -29,45 +29,10 @@ limitations under the License.
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/status.h"

-// ========================== Tape Ops ==============================
 namespace tensorflow {
 namespace gradients {
 namespace internal {

-// Computes `inputs[0] + inputs[1]` and records it on the tape.
-Status Add(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs,
-           const GradientRegistry& registry);
-
-// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
-Status MatMul(AbstractContext* ctx, Tape* tape,
-              absl::Span<AbstractTensorHandle* const> inputs,
-              absl::Span<AbstractTensorHandle*> outputs, const char* name,
-              bool transpose_a, bool transpose_b,
-              const GradientRegistry& registry);
-
-// Computes `inputs[0] * inputs[1]` and records it on the tape.
-Status Mul(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs, const char* name,
-           const GradientRegistry& registry);
-
-// Computes `Relu(inputs[0])` and records it on the tape.
-Status Relu(AbstractContext* ctx, Tape* tape,
-            absl::Span<AbstractTensorHandle* const> inputs,
-            absl::Span<AbstractTensorHandle*> outputs, const char* name,
-            const GradientRegistry& registry);
-
-// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
-// tape.
-Status SparseSoftmaxCrossEntropyWithLogits(
-    AbstractContext* ctx, Tape* tape,
-    absl::Span<AbstractTensorHandle* const> inputs,
-    absl::Span<AbstractTensorHandle*> outputs, const char* name,
-    const GradientRegistry& registry);
-
-// ====================== End Tape Ops ============================
-
 // Computes
 // y = inputs[0] + inputs[1]

View File

@@ -100,7 +100,8 @@ class VSpace {
                              std::vector<Gradient*>* result) const = 0;

   // Builds a tensor filled with ones with the same shape and dtype as `t`.
-  virtual Status BuildOnesLike(TapeTensor t, Gradient** result) const = 0;
+  virtual Status BuildOnesLike(const TapeTensor& t,
+                               Gradient** result) const = 0;

   // Looks up the ID of a Gradient.
   virtual int64 TensorId(Gradient* tensor) const = 0;

View File

@@ -29,6 +29,7 @@ cc_library(
     }),
     deps = [
        "//tensorflow/c:env",
+        "//tensorflow/c:logging",
        "//tensorflow/c:tf_status",
        "//tensorflow/c/experimental/filesystem:filesystem_interface",
        "//third_party/hadoop:hdfs",

View File

@@ -22,11 +22,10 @@ limitations under the License.
 #include <sstream>
 #include <string>

-#include "absl/synchronization/mutex.h"
 #include "tensorflow/c/env.h"
 #include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
+#include "tensorflow/c/logging.h"
 #include "tensorflow/c/tf_status.h"
-#include "third_party/hadoop/hdfs.h"

 // Implementation of a filesystem for HADOOP environments.
 // This filesystem will support `hdfs://`, `viewfs://` and `har://` URI schemes.
@@ -149,15 +148,20 @@ class LibHDFS {
     char* hdfs_home = getenv("HADOOP_HDFS_HOME");
     if (hdfs_home != nullptr) {
       auto JoinPath = [](std::string home, std::string lib) {
+#if defined(_WIN32)
+        if (home.back() != '\\') home.push_back('\\');
+        return home + "lib\\native\\" + lib;
+#else
         if (home.back() != '/') home.push_back('/');
         return home + "lib/native/" + lib;
+#endif
       };
       std::string path = JoinPath(hdfs_home, kLibHdfsDso);
       TryLoadAndBind(path.c_str(), &handle_, status);
       if (TF_GetCode(status) == TF_OK) {
         return;
       } else {
-        std::cerr << "HadoopFileSystem load error: " << TF_Message(status);
+        TF_Log(TF_FATAL, "HadoopFileSystem load error: %s", TF_Message(status));
       }
     }
@@ -169,13 +173,15 @@ class LibHDFS {
   void* handle_;
 };

-// We rely on HDFS connection caching here. The HDFS client calls
-// org.apache.hadoop.fs.FileSystem.get(), which caches the connection
-// internally.
-hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) {
+// We implement connection caching in TensorFlow, which can significantly
+// improve performance. Fixes #43187.
+hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file,
+               const std::string& path, TF_Status* status) {
+  auto libhdfs = hadoop_file->libhdfs;
   std::string scheme, namenode, hdfs_path;
   ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
+  std::string cacheKey(scheme);

   hdfsBuilder* builder = libhdfs->hdfsNewBuilder();
   if (scheme == "file") {
     libhdfs->hdfsBuilderSetNameNode(builder, nullptr);
@@ -200,15 +206,24 @@ hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) {
     SplitArchiveNameAndPath(&path_har, &namenode, status);
     if (TF_GetCode(status) != TF_OK) return nullptr;
     libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str());
+    cacheKey += namenode;
   } else {
     libhdfs->hdfsBuilderSetNameNode(
         builder, namenode.empty() ? "default" : namenode.c_str());
+    cacheKey += namenode;
   }
-  auto fs = libhdfs->hdfsBuilderConnect(builder);
-  if (fs == nullptr)
-    TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno));
-  else
-    TF_SetStatus(status, TF_OK, "");
+  absl::MutexLock l(&hadoop_file->connection_cache_lock);
+  if (hadoop_file->connection_cache.find(cacheKey) ==
+      hadoop_file->connection_cache.end()) {
+    auto cacheFs = libhdfs->hdfsBuilderConnect(builder);
+    if (cacheFs == nullptr) {
+      TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno));
+      return cacheFs;
+    }
+    hadoop_file->connection_cache[cacheKey] = cacheFs;
+  }
+  auto fs = hadoop_file->connection_cache[cacheKey];
+  TF_SetStatus(status, TF_OK, "");
   return fs;
 }
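Connect() now memoizes hdfsFS handles in a mutex-guarded map keyed by scheme plus namenode, so repeated opens against the same cluster reuse one connection instead of redialing. The same shape in a self-contained sketch (illustrative names, std::mutex standing in for absl::Mutex):

```cpp
#include <iostream>
#include <map>
#include <mutex>
#include <string>

// Stand-in for an expensive-to-create connection handle.
using Connection = int;

Connection Dial(const std::string& key) {
  std::cout << "dialing " << key << "\n";  // happens once per key
  return static_cast<Connection>(key.size());
}

std::mutex cache_lock;
std::map<std::string, Connection> cache;

Connection Connect(const std::string& scheme, const std::string& namenode) {
  const std::string key = scheme + namenode;  // cache key, as in the diff
  std::lock_guard<std::mutex> l(cache_lock);
  auto it = cache.find(key);
  if (it == cache.end()) it = cache.emplace(key, Dial(key)).first;
  return it->second;
}

int main() {
  Connect("hdfs://", "namenode:8020");
  Connect("hdfs://", "namenode:8020");  // served from the cache
}
```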
@@ -409,30 +424,36 @@ void Close(const TF_WritableFile* file, TF_Status* status) {
 // SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
 // ----------------------------------------------------------------------------
 namespace tf_read_only_memory_region {
-
-// Hadoop doesn't support Readonly Memory Region
-// TODO(vnvo2409): Implement later
-
 } // namespace tf_read_only_memory_region

 // SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
 // ----------------------------------------------------------------------------
 namespace tf_hadoop_filesystem {

+HadoopFile::HadoopFile(TF_Status* status)
+    : libhdfs(new LibHDFS(status)),
+      connection_cache_lock(),
+      connection_cache() {}
+
 void Init(TF_Filesystem* filesystem, TF_Status* status) {
-  filesystem->plugin_filesystem = new LibHDFS(status);
+  filesystem->plugin_filesystem = new HadoopFile(status);
   if (TF_GetCode(status) != TF_OK) return;
   TF_SetStatus(status, TF_OK, "");
 }

 void Cleanup(TF_Filesystem* filesystem) {
-  auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
+  auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
+  auto libhdfs = hadoop_file->libhdfs;
   delete libhdfs;
+  delete hadoop_file;
 }

 void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
                          TF_RandomAccessFile* file, TF_Status* status) {
-  auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
-  auto fs = Connect(libhdfs, path, status);
+  auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
+  auto libhdfs = hadoop_file->libhdfs;
+  auto fs = Connect(hadoop_file, path, status);
   if (TF_GetCode(status) != TF_OK) return;

   std::string scheme, namenode, hdfs_path;
@@ -448,8 +469,9 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,

 void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
                      TF_WritableFile* file, TF_Status* status) {
-  auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
-  auto fs = Connect(libhdfs, path, status);
+  auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
+  auto libhdfs = hadoop_file->libhdfs;
+  auto fs = Connect(hadoop_file, path, status);
   if (TF_GetCode(status) != TF_OK) return;

   std::string scheme, namenode, hdfs_path;
@@ -465,8 +487,9 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path,

 void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
                        TF_WritableFile* file, TF_Status* status) {
-  auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
-  auto fs = Connect(libhdfs, path, status);
+  auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
+  auto libhdfs = hadoop_file->libhdfs;
+  auto fs = Connect(hadoop_file, path, status);
   if (TF_GetCode(status) != TF_OK) return;

   std::string scheme, namenode, hdfs_path;
@@ -497,8 +520,9 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,

 void PathExists(const TF_Filesystem* filesystem, const char* path,
                 TF_Status* status) {
-  auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
-  auto fs = Connect(libhdfs, path, status);
+  auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
+  auto libhdfs = hadoop_file->libhdfs;
+  auto fs = Connect(hadoop_file, path, status);
   if (TF_GetCode(status) != TF_OK) return;

   std::string scheme, namenode, hdfs_path;
@@ -513,8 +537,9 @@ void PathExists(const TF_Filesystem* filesystem, const char* path,

 void Stat(const TF_Filesystem* filesystem, const char* path,
           TF_FileStatistics* stats, TF_Status* status) {
-  auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
-  auto fs = Connect(libhdfs, path, status);
+  auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
+  auto libhdfs = hadoop_file->libhdfs;
+  auto fs = Connect(hadoop_file, path, status);
   if (TF_GetCode(status) != TF_OK) return;

   std::string scheme, namenode, hdfs_path;
@@ -532,8 +557,9 @@ void Stat(const TF_Filesystem* filesystem, const char* path,

 int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
                     TF_Status* status) {
-  auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
-  auto fs = Connect(libhdfs, path, status);
+  auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
+  auto libhdfs = hadoop_file->libhdfs;
+  auto fs = Connect(hadoop_file, path, status);
   if (TF_GetCode(status) != TF_OK) return -1;

   std::string scheme, namenode, hdfs_path;
@@ -553,8 +579,9 @@ int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,

 void DeleteFile(const TF_Filesystem* filesystem, const char* path,
                 TF_Status* status) {
-  auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
-  auto fs = Connect(libhdfs, path, status);
+  auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
+  auto libhdfs = hadoop_file->libhdfs;
+  auto fs = Connect(hadoop_file, path, status);
   if (TF_GetCode(status) != TF_OK) return;

   std::string scheme, namenode, hdfs_path;
@@ -568,8 +595,9 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path,

 void CreateDir(const TF_Filesystem* filesystem, const char* path,
                TF_Status* status) {
-  auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
-  auto fs = Connect(libhdfs, path, status);
+  auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
+  auto libhdfs = hadoop_file->libhdfs;
+  auto fs = Connect(hadoop_file, path, status);
   if (TF_GetCode(status) != TF_OK) return;

   std::string scheme, namenode, hdfs_path;
@@ -583,8 +611,9 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,

 void DeleteDir(const TF_Filesystem* filesystem, const char* path,
                TF_Status* status) {
-  auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
-  auto fs = Connect(libhdfs, path, status);
+  auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
+  auto libhdfs = hadoop_file->libhdfs;
+  auto fs = Connect(hadoop_file, path, status);
   if (TF_GetCode(status) != TF_OK) return;

   std::string scheme, namenode, hdfs_path;
@@ -619,8 +648,9 @@ void DeleteDir(const TF_Filesystem* filesystem, const char* path,

 void RenameFile(const TF_Filesystem* filesystem, const char* src,
                 const char* dst, TF_Status* status) {
-  auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
-  auto fs = Connect(libhdfs, src, status);
+  auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
+  auto libhdfs = hadoop_file->libhdfs;
+  auto fs = Connect(hadoop_file, src, status);
   if (TF_GetCode(status) != TF_OK) return;

   std::string scheme, namenode, hdfs_path_src, hdfs_path_dst;
@@ -640,8 +670,9 @@ void RenameFile(const TF_Filesystem* filesystem, const char* src,

 int GetChildren(const TF_Filesystem* filesystem, const char* path,
                 char*** entries, TF_Status* status) {
-  auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
-  auto fs = Connect(libhdfs, path, status);
+  auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
+  auto libhdfs = hadoop_file->libhdfs;
+  auto fs = Connect(hadoop_file, path, status);
   if (TF_GetCode(status) != TF_OK) return -1;

   std::string scheme, namenode, hdfs_path;
@@ -677,7 +708,9 @@ int GetChildren(const TF_Filesystem* filesystem, const char* path,
   return num_entries;
 }

-// TODO(vnvo2409): Implement later
+static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
+  return strdup(uri);
+}

 } // namespace tf_hadoop_filesystem
@@ -685,6 +718,42 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
                                         const char* uri) {
   TF_SetFilesystemVersionMetadata(ops);
   ops->scheme = strdup(uri);
+
+  ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
+      plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
+  ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
+  ops->random_access_file_ops->read = tf_random_access_file::Read;
+
+  ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
+      plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
+  ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
+  ops->writable_file_ops->append = tf_writable_file::Append;
+  ops->writable_file_ops->tell = tf_writable_file::Tell;
+  ops->writable_file_ops->flush = tf_writable_file::Flush;
+  ops->writable_file_ops->sync = tf_writable_file::Sync;
+  ops->writable_file_ops->close = tf_writable_file::Close;
+
+  ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
+      plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
+  ops->filesystem_ops->init = tf_hadoop_filesystem::Init;
+  ops->filesystem_ops->cleanup = tf_hadoop_filesystem::Cleanup;
+  ops->filesystem_ops->new_random_access_file =
+      tf_hadoop_filesystem::NewRandomAccessFile;
+  ops->filesystem_ops->new_writable_file =
+      tf_hadoop_filesystem::NewWritableFile;
+  ops->filesystem_ops->new_appendable_file =
+      tf_hadoop_filesystem::NewAppendableFile;
+  ops->filesystem_ops->new_read_only_memory_region_from_file =
+      tf_hadoop_filesystem::NewReadOnlyMemoryRegionFromFile;
+  ops->filesystem_ops->path_exists = tf_hadoop_filesystem::PathExists;
+  ops->filesystem_ops->stat = tf_hadoop_filesystem::Stat;
+  ops->filesystem_ops->get_file_size = tf_hadoop_filesystem::GetFileSize;
+  ops->filesystem_ops->delete_file = tf_hadoop_filesystem::DeleteFile;
+  ops->filesystem_ops->create_dir = tf_hadoop_filesystem::CreateDir;
+  ops->filesystem_ops->delete_dir = tf_hadoop_filesystem::DeleteDir;
+  ops->filesystem_ops->rename_file = tf_hadoop_filesystem::RenameFile;
+  ops->filesystem_ops->get_children = tf_hadoop_filesystem::GetChildren;
+  ops->filesystem_ops->translate_name = tf_hadoop_filesystem::TranslateName;
 }

 void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
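ProvideFilesystemSupportFor fills in plain C tables of function pointers, which is what keeps the core/plugin boundary ABI-stable: the core only ever sees struct layouts and raw pointers, never C++ symbols from the plugin. A toy version of the ops-table pattern (illustrative struct and names, not the real plugin interface):

```cpp
#include <cstdlib>
#include <iostream>

// Illustrative ops table: one function pointer per capability.
struct FilesystemOps {
  void (*init)();
  void (*cleanup)();
};

namespace my_filesystem {
void Init() { std::cout << "init\n"; }
void Cleanup() { std::cout << "cleanup\n"; }
}  // namespace my_filesystem

// The plugin heap-allocates and populates the table; the core calls
// through it without linking against any plugin symbols directly.
FilesystemOps* ProvideOps() {
  auto* ops = static_cast<FilesystemOps*>(std::malloc(sizeof(FilesystemOps)));
  ops->init = my_filesystem::Init;
  ops->cleanup = my_filesystem::Cleanup;
  return ops;
}

int main() {
  FilesystemOps* ops = ProvideOps();
  ops->init();
  ops->cleanup();
  std::free(ops);
}
```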

View File

@@ -15,10 +15,13 @@ limitations under the License.
 #ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
 #define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_

+#include <map>
 #include <string>

+#include "absl/synchronization/mutex.h"
 #include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
 #include "tensorflow/c/tf_status.h"
+#include "third_party/hadoop/hdfs.h"

 void ParseHadoopPath(const std::string& fname, std::string* scheme,
                      std::string* namenode, std::string* path);
@@ -43,6 +46,14 @@ void Close(const TF_WritableFile* file, TF_Status* status);
 }  // namespace tf_writable_file

 namespace tf_hadoop_filesystem {
+typedef struct HadoopFile {
+  LibHDFS* libhdfs;
+  absl::Mutex connection_cache_lock;
+  std::map<std::string, hdfsFS> connection_cache
+      ABSL_GUARDED_BY(connection_cache_lock);
+  HadoopFile(TF_Status* status);
+} HadoopFile;
+
 void Init(TF_Filesystem* filesystem, TF_Status* status);
 void Cleanup(TF_Filesystem* filesystem);
 void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
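ABSL_GUARDED_BY is a Clang thread-safety annotation (from absl/base/thread_annotations.h): when compiled with -Wthread-safety, any access to the annotated member made without holding the named mutex becomes a compile-time warning. A minimal sketch, assuming Abseil is available:

```cpp
#include <map>
#include <string>

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"

// With -Wthread-safety, Clang flags any access to `cache_` made without
// holding `lock_`.
class Registry {
 public:
  void Put(const std::string& k, int v) {
    absl::MutexLock l(&lock_);
    cache_[k] = v;  // OK: lock_ is held for the full scope
  }

 private:
  absl::Mutex lock_;
  std::map<std::string, int> cache_ ABSL_GUARDED_BY(lock_);
};

int main() {
  Registry r;
  r.Put("hdfs://namenode", 1);
}
```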

View File

@@ -352,6 +352,48 @@ TEST_F(HadoopFileSystemTest, WriteWhileReading) {
   EXPECT_TF_OK(status_);
 }

+TEST_F(HadoopFileSystemTest, ReadWhileOverwriting) {
+  static char set_disable_var[] = "HDFS_DISABLE_READ_EOF_RETRIED=1";
+  putenv(set_disable_var);
+
+  const std::string path = TmpDir("ReadWhileOverwriting");
+  if (path.find("hdfs://") != 0) GTEST_SKIP();
+
+  const string content1 = "content1";
+  WriteString(path, content1);
+  ASSERT_TF_OK(status_);
+
+  auto reader = GetReader();
+  tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
+                                            reader.get(), status_);
+  EXPECT_TF_OK(status_);
+
+  std::string result;
+  result.resize(content1.size());
+  auto read = tf_random_access_file::Read(reader.get(), 0, content1.size(),
+                                          &result[0], status_);
+  result.resize(read);
+  EXPECT_TF_OK(status_);
+  EXPECT_EQ(content1, result);
+
+  tf_hadoop_filesystem::DeleteFile(filesystem_, path.c_str(), status_);
+  EXPECT_TF_OK(status_);
+
+  string content2 = "overwrite";
+  WriteString(path, content1 + content2);
+  ASSERT_TF_OK(status_);
+
+  result.resize(content2.size());
+  read = tf_random_access_file::Read(reader.get(), content1.size(),
+                                     content2.size(), &result[0], status_);
+  result.resize(read);
+  EXPECT_TF_OK(status_);
+  EXPECT_EQ(0, result.size());
+
+  static char set_enable_var[] = "HDFS_DISABLE_READ_EOF_RETRIED=0";
+  putenv(set_enable_var);
+}
+
 TEST_F(HadoopFileSystemTest, HarSplit) {
   const std::string har_path =
       "har://hdfs-root/user/j.doe/my_archive.har/dir0/dir1/file.txt";

View File

@@ -24,6 +24,7 @@ using std::vector;
 using tensorflow::ops::Conj;
 using tensorflow::ops::MatMul;
 using tensorflow::ops::Mul;
+using tensorflow::ops::SqrtGrad;

 namespace tensorflow {
 namespace gradients {
@@ -72,6 +73,25 @@ class ExpGradientFunction : public GradientFunction {
   AbstractTensorHandlePtr exp_;
 };

+class SqrtGradientFunction : public GradientFunction {
+ public:
+  explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) {
+    sqrt->Ref();
+  }
+  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
+                 vector<AbstractTensorHandle*>* grad_outputs) override {
+    std::string name = "Sqrt_Grad";
+    grad_outputs->resize(1);
+    TF_RETURN_IF_ERROR(SqrtGrad(ctx->ctx, {sqrt_.get(), grad_inputs[0]},
+                                absl::MakeSpan(*grad_outputs), name.c_str()));
+    return Status::OK();
+  }
+  ~SqrtGradientFunction() override {}
+
+ private:
+  AbstractTensorHandlePtr sqrt_;
+};
+
 class MatMulGradientFunction : public GradientFunction {
  public:
   explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
@@ -210,5 +230,14 @@ BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
   return new BackwardFunction(gradient_function, default_gradients);
 }

+BackwardFunction* SqrtRegisterer(const ForwardOperation& op) {
+  auto gradient_function = new SqrtGradientFunction(op.outputs[0]);
+  // For ops with a single output, the gradient function is not called if
+  // there is no incoming gradient. So we do not need to worry about creating
+  // zeros grads in this case.
+  auto default_gradients = new PassThroughDefaultGradients(op);
+  return new BackwardFunction(gradient_function, default_gradients);
+}
+
 } // namespace gradients
 } // namespace tensorflow
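SqrtRegisterer captures the forward output (op.outputs[0]) rather than the input because the square root's gradient is naturally expressed in terms of y = sqrt(x); holding a reference to y lets the backward pass avoid recomputing the square root:

```latex
y = \sqrt{x}
\quad\Longrightarrow\quad
\frac{\partial L}{\partial x}
  = \frac{\partial L}{\partial y}\cdot\frac{dy}{dx}
  = \frac{\partial L}{\partial y}\cdot\frac{1}{2\sqrt{x}}
  = \frac{\partial L}{\partial y}\cdot\frac{1}{2y}
```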

View File

@@ -19,9 +19,12 @@ limitations under the License.

 namespace tensorflow {
 namespace gradients {
+
 BackwardFunction* AddRegisterer(const ForwardOperation& op);
 BackwardFunction* ExpRegisterer(const ForwardOperation& op);
 BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
+BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
+
 } // namespace gradients
 } // namespace tensorflow

View File

@@ -38,3 +38,29 @@ cc_library(
        "//tensorflow/c/eager:gradients_internal",
     ],
 )
+
+cc_library(
+    name = "tape",
+    hdrs = [
+        "tape_context.h",
+        "tape_operation.h",
+    ],
+    visibility = [
+        "//tensorflow:internal",
+    ],
+    deps = [
+        ":tape_context",
+        ":tape_operation",
+    ],
+)
+
+filegroup(
+    name = "pywrap_required_hdrs",
+    srcs = [
+        "tape_context.h",
+        "tape_operation.h",
+    ],
+    visibility = [
+        "//tensorflow:internal",
+    ],
+)

View File

@@ -144,5 +144,33 @@ Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
   return exp_op->Execute(outputs, &num_retvals);
 }

+Status Sqrt(AbstractContext* ctx,
+            absl::Span<AbstractTensorHandle* const> inputs,
+            absl::Span<AbstractTensorHandle*> outputs, const char* name) {
+  AbstractOperationPtr sqrt_op(ctx->CreateOperation());
+  TF_RETURN_IF_ERROR(sqrt_op->Reset("Sqrt", /*raw_device_name=*/nullptr));
+  TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_op.get(), name));
+  TF_RETURN_IF_ERROR(sqrt_op->AddInput(inputs[0]));
+
+  int num_retvals = 1;
+  Status s = sqrt_op->Execute(outputs, &num_retvals);
+  return s;
+}
+
+Status SqrtGrad(AbstractContext* ctx,
+                absl::Span<AbstractTensorHandle* const> inputs,
+                absl::Span<AbstractTensorHandle*> outputs, const char* name) {
+  AbstractOperationPtr sqrt_grad_op(ctx->CreateOperation());
+  TF_RETURN_IF_ERROR(
+      sqrt_grad_op->Reset("SqrtGrad", /*raw_device_name=*/nullptr));
+  TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_grad_op.get(), name));
+  TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[0]));
+  TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[1]));
+
+  int num_retvals = 1;
+  Status s = sqrt_grad_op->Execute(outputs, &num_retvals);
+  return s;
+}
+
 } // namespace ops
 } // namespace tensorflow
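Both wrappers follow the same four-step builder recipe: create an operation, Reset it to the registered op name, AddInput for each operand, then Execute with the expected number of return values. A toy, self-contained imitation of that flow (illustrative types, not the real AbstractOperation API):

```cpp
#include <cmath>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy operation mirroring the create/Reset/AddInput/Execute recipe.
struct Operation {
  std::string name;
  std::vector<double> inputs;
  bool Reset(const std::string& n) { name = n; inputs.clear(); return true; }
  void AddInput(double v) { inputs.push_back(v); }
  bool Execute(std::vector<double>* outputs, int* num_retvals) {
    if (name == "Sqrt") outputs->at(0) = std::sqrt(inputs.at(0));
    *num_retvals = 1;
    return true;
  }
};

int main() {
  auto op = std::make_unique<Operation>();
  op->Reset("Sqrt");        // step 2: bind to the op name
  op->AddInput(2.25);       // step 3: operands
  std::vector<double> out(1);
  int num_retvals = 1;
  op->Execute(&out, &num_retvals);  // step 4: run
  std::cout << out[0] << "\n";      // prints 1.5
}
```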

View File

@@ -50,6 +50,15 @@ Status DivNoNan(AbstractContext* ctx,
 Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name);

+Status Sqrt(AbstractContext* ctx,
+            absl::Span<AbstractTensorHandle* const> inputs,
+            absl::Span<AbstractTensorHandle*> outputs, const char* name);
+
+Status SqrtGrad(AbstractContext* ctx,
+                absl::Span<AbstractTensorHandle* const> inputs,
+                absl::Span<AbstractTensorHandle*> outputs, const char* name);
+
 } // namespace ops
 } // namespace tensorflow

View File

@ -91,15 +91,24 @@ cc_library(
":signature_def_function_metadata", ":signature_def_function_metadata",
"//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
], ],
) )
cc_library( cc_library(
name = "signature_def_function_metadata", name = "signature_def_function_metadata",
srcs = [
"signature_def_function_metadata.cc",
],
hdrs = [
"signature_def_function_metadata.h",
],
deps = [
":tensor_spec",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
@ -268,6 +277,20 @@ tf_cc_test(
],
)
cc_library(
name = "tensor_spec",
srcs = [
"tensor_spec.cc",
],
hdrs = [
"tensor_spec.h",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
],
)
tf_cc_test(
name = "tf_concrete_function_loading_test",
srcs = [

View File

@ -92,6 +92,8 @@ cc_library(
"//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
"//tensorflow/c/experimental/saved_model/core:tensor_spec",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti", "//tensorflow/core/lib/llvm_rtti",
@ -164,6 +166,8 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
], ],
) )

View File

@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
#include <algorithm>
#include <memory> #include <memory>
#include <string>
#include <utility> #include <utility>
#include "absl/types/span.h" #include "absl/types/span.h"
@ -30,14 +32,26 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
using StructuredValueDictEntry =
protobuf::MapPair<std::string, StructuredValue>;
using NamedParamMap =
gtl::FlatMap<StringPiece, const TensorSpecProto*, StringPieceHasher>;
Status AssertAllCreateResourceFunctionsHaveNoCaptures(
const PartiallyRevivedObjects& objects) {
for (const auto& id_and_resource : objects.restored_resources) {
@ -124,6 +138,142 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph,
}
}
std::vector<SignatureDefParam> SignatureDefParamsFromNamedParamMap(
const NamedParamMap& params) {
// The underlying functiondef associated with the SignatureDef has
// nest.flattened inputs and outputs, which are sorted by string key.
std::vector<SignatureDefParam> result;
result.reserve(params.size());
for (const auto& named_param : params) {
result.push_back(SignatureDefParam(std::string(named_param.first),
TensorSpec(*named_param.second)));
}
std::sort(result.begin(), result.end(),
[](const SignatureDefParam& x, const SignatureDefParam& y) {
return x.name() < y.name();
});
return result;
}
// SignatureDefArgsFromInputs takes the "canonicalized_input_signature"
// field of a SavedConcreteFunction, ensures it conforms to the structure of
// tuple(tuple(), dict<string,TensorSpec>()), and "returns" a list of
// SignatureDefParams of the SignatureDefFunction's arguments.
Status SignatureDefArgsFromInputs(
const StructuredValue& canonicalized_input_signature,
std::vector<SignatureDefParam>* out) {
// Note(bmzhao): canonicalized_input_signature should be a tuple of
// (args, kwargs), where args is an empty tuple, and kwargs is a dictionary of
// string keys to TensorSpecs.
if (!canonicalized_input_signature.has_tuple_value()) {
return errors::FailedPrecondition(
"SignatureDefFunction's canonicalized_input_signature should be "
"of form tuple(tuple(), dict()), but was instead: \n",
canonicalized_input_signature.DebugString());
}
const TupleValue& args_kwargs_tuple =
canonicalized_input_signature.tuple_value();
if (args_kwargs_tuple.values_size() != 2) {
return errors::FailedPrecondition(
"SignatureDefFunction's canonicalized_input_signature should be "
"a tuple of two elements (args, kwargs), but was instead: \n",
args_kwargs_tuple.DebugString());
}
const StructuredValue& args = args_kwargs_tuple.values(0);
if (!args.has_tuple_value() || !args.tuple_value().values().empty()) {
return errors::FailedPrecondition(
"SignatureDefFunction's canonicalized_input_signature's args"
"should be an empty tuple, but instead got: \n",
args.DebugString());
}
const StructuredValue& kwargs = args_kwargs_tuple.values(1);
if (!kwargs.has_dict_value()) {
return errors::FailedPrecondition(
"SignatureDefFunction's canonicalized_input_signature's kwargs"
"should be a dictionary, but instead got: \n",
kwargs.DebugString());
}
const DictValue& kwargs_dict = kwargs.dict_value();
NamedParamMap result;
result.reserve(kwargs_dict.fields_size());
for (const auto& key_value : kwargs_dict.fields()) {
const std::string& key = key_value.first;
const StructuredValue& value = key_value.second;
if (!value.has_tensor_spec_value()) {
return errors::FailedPrecondition(
"SignatureDefFunction's canonicalized_input_signature's kwargs"
"dictionary contained a non-tensorspec value for key-value pair: \n",
"Key: ", key, "Value: \n", value.DebugString());
}
result[key] = &value.tensor_spec_value();
}
*out = SignatureDefParamsFromNamedParamMap(result);
return Status();
}
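As a hedged illustration of the shape being validated here, a conforming canonicalized_input_signature looks roughly like the following in proto text format (field names from struct.proto; the key "a" is made up):

// tuple_value {                 # (args, kwargs)
//   values { tuple_value { } }  # args: must be an empty tuple
//   values { dict_value {       # kwargs: string key -> TensorSpec
//     fields {
//       key: "a"
//       value { tensor_spec_value { dtype: DT_FLOAT shape { } } }
//     }
//   } }
// }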
// SignatureDefReturnsFromOutputs takes the "output_signature" field of a
// SavedConcreteFunction, ensures it conforms to the structure of
// dict<string,TensorSpec>(), and "returns" a list of SignatureDefParams of the
// SignatureDefFunction's returns.
Status SignatureDefReturnsFromOutputs(const StructuredValue& output_signature,
std::vector<SignatureDefParam>* out) {
if (!output_signature.has_dict_value()) {
return errors::FailedPrecondition(
"SignatureDefFunction's output_signature must be a dictionary, but "
"instead got: ",
output_signature.DebugString());
}
const DictValue& output_dict = output_signature.dict_value();
NamedParamMap result;
result.reserve(output_dict.fields_size());
for (const auto& key_value : output_dict.fields()) {
const std::string& key = key_value.first;
const StructuredValue& value = key_value.second;
if (!value.has_tensor_spec_value()) {
return errors::FailedPrecondition(
"SignatureDefFunction's output_signature dictionary contained a "
"non-tensorspec value for key-value pair: \n",
"Key: ", key, "Value: \n", value.DebugString());
}
result[key] = &value.tensor_spec_value();
}
*out = SignatureDefParamsFromNamedParamMap(result);
return Status();
}
// The implementation takes advantage of the fact that SignatureDefFunction's
// "traced" Signature wrapper function always has inputs/outputs of dictionaries
// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L119-L126
// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L153-L178
// Additionally, we take advantage of the fact that the SignatureDefFunction's
// associated functiondef has lexicographically ordered inputs/outputs due to
// nest.flatten.
Status LoadSignatureDefFunctionMetadata(
const SavedConcreteFunction& saved_concrete_function,
SignatureDefFunctionMetadata* out) {
std::vector<SignatureDefParam> args;
TF_RETURN_IF_ERROR(SignatureDefArgsFromInputs(
saved_concrete_function.canonicalized_input_signature(), &args));
std::vector<SignatureDefParam> rets;
TF_RETURN_IF_ERROR(SignatureDefReturnsFromOutputs(
saved_concrete_function.output_signature(), &rets));
*out = SignatureDefFunctionMetadata(std::move(args), std::move(rets));
return Status();
}
// This function finds the necessary captures, then forwards to the builder
// method
Status CreateConcreteFunction(ImmediateExecutionContext* ctx,
@ -162,10 +312,14 @@ Status CreateSignatureDefFunction(
&capture_handle));
captures.push_back(capture_handle);
}
// TODO(bmzhao): Create Metadata here
SignatureDefFunctionMetadata metadata;
TF_RETURN_IF_ERROR(LoadSignatureDefFunctionMetadata(
*builder.saved_concrete_func, &metadata));
return TFSignatureDefFunction::Create(/*function_def=*/builder.fdef,
/*captures=*/std::move(captures),
/*metadata=*/{},
/*metadata=*/std::move(metadata),
/*ctx=*/ctx,
/*out=*/out);
} }
@ -378,6 +532,7 @@ Status PartiallyRevivedObjects::Build(ImmediateExecutionContext* ctx,
revived->variables = std::move(variables);
revived->assets = std::move(assets);
revived->constants = std::move(constants);
revived->signatures_map = std::move(signatures_map);
// 3b. Move over resources.
TF_RETURN_IF_ERROR(BuildResources(ctx, obj_graph, this, revived));

View File

@ -36,7 +36,14 @@ namespace tensorflow {
// Notably, resources and functions can be in a state where they reference
// other resources/functions that have not been constructed yet. We collect
// *all* objects in a partially valid state here, then properly initialize
// resources and functions.
// resources and functions. Implementation-wise, PartiallyRevivedObjects
// contains maps keyed by the node number of the SavedObjectGraph, and map to an
// object of the corresponding type. So, if node 2 in the object graph is a
// variable, PartiallyRevivedObjects.variables[2] exists, and corresponds to a
// tensorflow::Variable object. The only exception to this is the
// "signatures_map", which is keyed by the "signature" key
// (https://github.com/tensorflow/tensorflow/blob/372918decee7f558b3c194b04f77c20dcc679a31/tensorflow/core/protobuf/meta_graph.proto#L89),
// and maps to the SignatureDefFunction node in the SavedObjectGraph.
struct PartiallyRevivedObjects {
gtl::FlatMap<int, std::unique_ptr<Variable>> variables;
gtl::FlatMap<int, std::unique_ptr<Asset>> assets;
@ -44,6 +51,7 @@ struct PartiallyRevivedObjects {
gtl::FlatMap<int, TFConcreteFunctionRevivalState> concrete_functions;
gtl::FlatMap<int, TFSignatureDefFunctionRevivalState> signature_def_functions;
gtl::FlatMap<int, RestoredResourceRevivalState> restored_resources;
gtl::FlatMap<std::string, int> signatures_map;
Status Build(ImmediateExecutionContext* ctx,
const SavedObjectGraph& obj_graph, RevivedObjects* revived);
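A short sketch of the keying scheme described in the comment above; the node id 2 and the "serving_default" key are illustrative:

// Hypothetical access pattern after partial revival:
//   PartiallyRevivedObjects objs = ...;
//   Variable* v = objs.variables.at(2).get();  // node 2 is a variable
//   int sig_node = objs.signatures_map.at("serving_default");  // key -> node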

View File

@ -44,6 +44,7 @@ struct RevivedObjects {
gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>>
signature_def_functions;
gtl::FlatMap<int, RestoredResource> restored_resources;
gtl::FlatMap<std::string, int> signatures_map;
};
} // namespace tensorflow

View File

@ -20,8 +20,10 @@ limitations under the License.
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h" #include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -62,15 +64,53 @@ Status Variable::ReadValue(ImmediateTensorHandlePtr* out) {
return internal::ReadVariable(ctx_, handle_.get(), dtype_, out);
}
Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
absl::optional<std::string> name,
const char* raw_device_name,
std::unique_ptr<Variable>* output) {
Status Variable::CreateUninitialized(
ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
absl::optional<std::string> name, const char* raw_device_name,
const std::vector<std::string>& component_devices,
std::unique_ptr<Variable>* output) {
ImmediateTensorHandlePtr handle;
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
ctx, dtype, shape, raw_device_name, &handle));
if (component_devices.empty()) {
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
ctx, dtype, shape, raw_device_name, &handle));
output->reset(
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
return Status();
}
if (!tensorflow::isa<EagerContext>(ctx)) {
return errors::InvalidArgument(
"Can only load distributed variables with EagerContext.");
}
EagerContext* eager_ctx = reinterpret_cast<EagerContext*>(ctx);
std::vector<TensorHandle*> handles;
for (const auto& device : component_devices) {
ImmediateTensorHandlePtr handlePtr;
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
ctx, dtype, shape, device.empty() ? nullptr : device.c_str(),
&handlePtr));
if (!tensorflow::isa<TensorHandle>(handlePtr.get())) {
return errors::Internal("Returned replica handle has unsupported type.");
}
handles.push_back(reinterpret_cast<TensorHandle*>(handlePtr.release()));
}
TensorHandle* packed_handle;
TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
std::move(handles), eager_ctx, &packed_handle));
// The call to `CreatePackedHandle` incremented the handles' reference count,
// which we must now decrement to make the packed handle the owner of those
// handles. We can't loop through the `handles` vector because it was
// `std::move`d in the call above.
for (int i = 0; i != packed_handle->NumPackedHandles(); ++i) {
TensorHandle* component;
TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &component));
component->Unref();
}
handle.reset(packed_handle);
output->reset(
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
return Status();

View File

@ -34,11 +34,11 @@ class Variable : public TensorHandleConvertible {
public:
// Creates an uninitialized resource variable. Note that a caller must
// call "assign" to associate a value with the variable.
static Status CreateUninitialized(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
absl::optional<std::string> name,
const char* raw_device_name,
std::unique_ptr<Variable>* output);
static Status CreateUninitialized(
ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
absl::optional<std::string> name, const char* raw_device_name,
const std::vector<std::string>& component_devices,
std::unique_ptr<Variable>* output);
// The dtype of the underlying variable.
DataType dtype();

View File

@ -235,10 +235,17 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx,
const std::string& name = variable.name();
tensorflow::TensorShape shape(variable.shape());
tensorflow::DataType dtype = variable.dtype();
std::vector<std::string> component_devices;
for (const auto& component :
variable.experimental_distributed_variable_components()) {
component_devices.push_back(component.device());
}
TF_RETURN_IF_ERROR(Variable::CreateUninitialized(
ctx, dtype, shape, name,
variable.device().empty() ? nullptr : variable.device().c_str(), output));
variable.device().empty() ? nullptr : variable.device().c_str(),
component_devices, output));
return Status();
}
@ -519,6 +526,8 @@ Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph,
}
}
objects->signatures_map = std::move(signatures_map);
return Status();
}

View File

@ -119,7 +119,7 @@ TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) {
Status status;
std::unique_ptr<Variable> var;
TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
absl::nullopt, nullptr, &var));
absl::nullopt, nullptr, {}, &var));
// Create a TensorHandle
ImmediateTensorHandlePtr expected_handle =

View File

@ -0,0 +1,42 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
namespace tensorflow {
SignatureDefParam::SignatureDefParam(std::string name, TensorSpec spec)
: name_(std::move(name)), spec_(std::move(spec)) {}
const std::string& SignatureDefParam::name() const { return name_; }
const TensorSpec& SignatureDefParam::spec() const { return spec_; }
SignatureDefFunctionMetadata::SignatureDefFunctionMetadata(
std::vector<SignatureDefParam> arguments,
std::vector<SignatureDefParam> returns)
: arguments_(std::move(arguments)), returns_(std::move(returns)) {}
const std::vector<SignatureDefParam>& SignatureDefFunctionMetadata::arguments()
const {
return arguments_;
}
const std::vector<SignatureDefParam>& SignatureDefFunctionMetadata::returns()
const {
return returns_;
}
} // namespace tensorflow

View File

@ -16,10 +16,42 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
#include <string>
#include <vector>
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
// SignatureDefParam represents a named Tensor input or output to a
// SignatureDefFunction.
class SignatureDefParam {
public:
SignatureDefParam(std::string name, TensorSpec spec);
const std::string& name() const;
const TensorSpec& spec() const;
private:
std::string name_;
TensorSpec spec_;
};
class SignatureDefFunctionMetadata {
// TODO(bmzhao): Fill in with fields as necessary
public:
SignatureDefFunctionMetadata() = default;
SignatureDefFunctionMetadata(std::vector<SignatureDefParam> arguments,
std::vector<SignatureDefParam> returns);
const std::vector<SignatureDefParam>& arguments() const;
const std::vector<SignatureDefParam>& returns() const;
private:
std::vector<SignatureDefParam> arguments_;
std::vector<SignatureDefParam> returns_;
};
} // namespace tensorflow
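A minimal construction sketch for the classes above; the values are made up, and in practice the loader in partially_revived_objects.cc assembles this from the SavedObjectGraph:

std::vector<tensorflow::SignatureDefParam> args;
// TensorSpec() defaults to a scalar DT_FLOAT spec.
args.push_back(tensorflow::SignatureDefParam("a", tensorflow::TensorSpec()));
tensorflow::SignatureDefFunctionMetadata metadata(std::move(args),
                                                  /*returns=*/{});
const std::string& first_arg_name = metadata.arguments()[0].name();  // "a"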

View File

@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
#include <initializer_list>
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
namespace tensorflow {
TensorSpec::TensorSpec()
: shape_(std::initializer_list<int64>()), dtype_(DT_FLOAT) {}
TensorSpec::TensorSpec(PartialTensorShape shape, DataType dtype)
: shape_(std::move(shape)), dtype_(dtype) {}
TensorSpec::TensorSpec(const TensorSpecProto& proto)
: shape_(proto.shape()), dtype_(proto.dtype()) {}
const PartialTensorShape& TensorSpec::shape() const { return shape_; }
DataType TensorSpec::dtype() const { return dtype_; }
} // namespace tensorflow

View File

@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
// Note(bmzhao): TensorSpec deliberately does not store the "name" from a
// TensorSpecProto. From edloper@, "Names should really be associated with
// parameters, not the tensors inside those parameters. This would be
// inconsistent with the corresponding Python class, but I don't think that's
// necessarily a problem. If it turns out later that we really need a name
// attribute here, we can always add it back in; but let's see how far we can
// get without it."
class TensorSpec {
public:
// Constructs a scalar, DT_FLOAT TensorSpec
TensorSpec();
TensorSpec(PartialTensorShape shape, DataType dtype);
explicit TensorSpec(const TensorSpecProto& proto);
const PartialTensorShape& shape() const;
DataType dtype() const;
private:
PartialTensorShape shape_;
DataType dtype_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
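A quick usage sketch for TensorSpec; the shape and dtype here are arbitrary:

// A float32 spec of shape [None, 128]; -1 marks an unknown dimension.
tensorflow::PartialTensorShape shape({-1, 128});
tensorflow::TensorSpec spec(shape, tensorflow::DT_FLOAT);
CHECK_EQ(spec.dtype(), tensorflow::DT_FLOAT);
CHECK_EQ(spec.shape().dims(), 2);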

View File

@ -192,9 +192,23 @@ Status TFSavedModelAPI::GetFunction(const std::string& function_path,
Status TFSavedModelAPI::GetSignatureDefFunction(
const std::string& signature_def_key, SignatureDefFunction** function) {
// TODO(bmzhao): Add support for retrieving a signaturedef function.
return errors::Unimplemented(
"Retrieving SignatureDef functions is unimplemented currently");
auto signatures_iter =
revived_objects_.signatures_map.find(signature_def_key);
if (signatures_iter == revived_objects_.signatures_map.end()) {
return errors::NotFound("No signature with key ", signature_def_key,
" was found");
}
int node = signatures_iter->second;
auto function_iter = revived_objects_.signature_def_functions.find(node);
if (function_iter == revived_objects_.signature_def_functions.end()) {
return errors::Internal(
"Unable to find SignatureDefFunction associated with key ",
signature_def_key, " despite key being valid.");
}
*function = function_iter->second.get();
return Status();
}
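A hedged caller-side sketch of the two-step lookup implemented above; `api` stands in for a loaded TFSavedModelAPI:

SignatureDefFunction* serving_fn = nullptr;
TF_RETURN_IF_ERROR(
    api->GetSignatureDefFunction("serving_default", &serving_fn));
// The function is owned by the SavedModel; the caller must not delete it.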
std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {

View File

@ -224,6 +224,8 @@ cc_library(
],
deps = [
":signature_def_function_metadata_type",
":signature_def_param_list",
":signature_def_param_list_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
],
@ -240,6 +242,104 @@ cc_library(
],
)
cc_library(
name = "signature_def_param",
srcs = [
"signature_def_param.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:signature_def_param.h",
],
copts = tf_copts(),
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":signature_def_param_type",
":tensor_spec",
":tensor_spec_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_shape_internal",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
],
)
cc_library(
name = "signature_def_param_type",
hdrs = [
"signature_def_param_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
],
)
cc_library(
name = "signature_def_param_list",
srcs = [
"signature_def_param_list.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:signature_def_param_list.h",
],
copts = tf_copts(),
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":signature_def_param",
":signature_def_param_list_type",
":signature_def_param_type",
"//tensorflow/c:c_api_macros",
],
)
cc_library(
name = "signature_def_param_list_type",
hdrs = [
"signature_def_param_list_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
],
)
cc_library(
name = "tensor_spec",
srcs = [
"tensor_spec.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:tensor_spec.h",
],
copts = tf_copts(),
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":tensor_spec_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_datatype",
"//tensorflow/c:tf_shape",
"//tensorflow/c:tf_shape_internal",
"//tensorflow/c/experimental/saved_model/core:tensor_spec",
],
)
cc_library(
name = "tensor_spec_type",
hdrs = [
"tensor_spec_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c:tf_shape_internal",
"//tensorflow/c/experimental/saved_model/core:tensor_spec",
],
)
tf_cc_test(
name = "saved_model_api_test",
size = "small",
@ -252,6 +352,8 @@ tf_cc_test(
],
deps = [
":saved_model_api_type",
"//tensorflow/c:tf_datatype",
"//tensorflow/c:tf_shape",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_tensor",
"//tensorflow/c/eager:c_api",
@ -260,6 +362,11 @@ tf_cc_test(
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_api", "//tensorflow/c/experimental/saved_model/core:tf_saved_model_api",
"//tensorflow/c/experimental/saved_model/public:concrete_function", "//tensorflow/c/experimental/saved_model/public:concrete_function",
"//tensorflow/c/experimental/saved_model/public:saved_model_api", "//tensorflow/c/experimental/saved_model/public:saved_model_api",
"//tensorflow/c/experimental/saved_model/public:signature_def_function",
"//tensorflow/c/experimental/saved_model/public:signature_def_function_metadata",
"//tensorflow/c/experimental/saved_model/public:signature_def_param",
"//tensorflow/c/experimental/saved_model/public:signature_def_param_list",
"//tensorflow/c/experimental/saved_model/public:tensor_spec",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",

View File

@ -24,6 +24,13 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h" #include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h" #include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" #include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_shape.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/path.h"
@ -143,6 +150,146 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) {
TFE_DeleteContext(ctx);
}
// This tests running the "serving_default" SignatureDefFunction from the
// VarsAndArithmeticObjectGraph savedmodel. Here's what the signature_defs
// protobuf in the metagraph looks like:
// signature_def: {
// key : "serving_default"
// value: {
// inputs: {
// key : "a"
// value: {
// name : "serving_default_a:0"
// dtype: DT_FLOAT
// tensor_shape: {
// }
// }
// }
// inputs: {
// key : "b"
// value: {
// name : "serving_default_b:0"
// dtype: DT_FLOAT
// tensor_shape: {
// }
// }
// }
// outputs: {
// key : "output_0"
// value: {
// name : "StatefulPartitionedCall:0"
// dtype: DT_FLOAT
// tensor_shape: {
// }
// }
// }
// method_name: "tensorflow/serving/predict"
// }
// }
TEST_P(CSavedModelAPITest, RunsSignatureDefFunction) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
bool use_tfrt = GetParam();
if (use_tfrt) {
TFE_DeleteContextOptions(opts);
TF_DeleteStatus(status);
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
TF_SavedModel* saved_model =
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_SignatureDefFunction* serving_default =
TF_GetSavedModelSignatureDefFunction(saved_model, "serving_default",
status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_SignatureDefFunctionMetadata* metadata =
TF_SignatureDefFunctionGetMetadata(serving_default);
const TF_SignatureDefParamList* args =
TF_SignatureDefFunctionMetadataArgs(metadata);
const TF_SignatureDefParamList* returns =
TF_SignatureDefFunctionMetadataReturns(metadata);
EXPECT_EQ(TF_SignatureDefParamListSize(args), 2);
const TF_SignatureDefParam* param_a = TF_SignatureDefParamListGet(args, 0);
const TF_TensorSpec* tensor_spec_a = TF_SignatureDefParamTensorSpec(param_a);
const TF_Shape* shape_a = TF_TensorSpecShape(tensor_spec_a);
// Input "a" is a scalar, float32 tensor
EXPECT_EQ("a", std::string(TF_SignatureDefParamName(param_a)));
EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_a));
EXPECT_EQ(0, TF_ShapeDims(shape_a));
const TF_SignatureDefParam* param_b = TF_SignatureDefParamListGet(args, 1);
const TF_TensorSpec* tensor_spec_b = TF_SignatureDefParamTensorSpec(param_b);
const TF_Shape* shape_b = TF_TensorSpecShape(tensor_spec_b);
// Input "b" is a scalar, float32 tensor
EXPECT_EQ("b", std::string(TF_SignatureDefParamName(param_b)));
EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_b));
EXPECT_EQ(0, TF_ShapeDims(shape_b));
EXPECT_EQ(TF_SignatureDefParamListSize(returns), 1);
const TF_SignatureDefParam* param_out =
TF_SignatureDefParamListGet(returns, 0);
const TF_TensorSpec* tensor_spec_out =
TF_SignatureDefParamTensorSpec(param_out);
const TF_Shape* shape_out = TF_TensorSpecShape(tensor_spec_out);
// Output "output_0" is a scalar, float32 tensor
EXPECT_EQ("output_0", std::string(TF_SignatureDefParamName(param_out)));
EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_out));
EXPECT_EQ(0, TF_ShapeDims(shape_out));
std::vector<TFE_TensorHandle*> compute_fn_inputs;
TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f);
TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f);
compute_fn_inputs.push_back(input_a);
compute_fn_inputs.push_back(input_b);
TFE_Op* serving_default_op = TF_SignatureDefFunctionMakeCallOp(
serving_default, compute_fn_inputs.data(), compute_fn_inputs.size(),
status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
std::vector<TFE_TensorHandle*> compute_fn_outputs(
TF_SignatureDefParamListSize(returns));
int num_retvals = TF_SignatureDefParamListSize(returns);
TFE_Execute(serving_default_op, compute_fn_outputs.data(), &num_retvals,
status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TF_NumDims(result), 0);
float output_value = *static_cast<float*>(TF_TensorData(result));
// (1 + 2) * (2 + 1) / 3 + 5 should be 8
EXPECT_FLOAT_EQ(output_value, 8.0);
TF_DeleteTensor(result);
TFE_DeleteTensorHandle(compute_fn_outputs[0]);
TFE_DeleteTensorHandle(input_a);
TFE_DeleteTensorHandle(input_b);
TFE_DeleteOp(serving_default_op);
TF_DeleteSavedModel(saved_model);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
}
TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();

View File

@ -16,5 +16,18 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" #include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h" #include "tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h"
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h"
// TODO(bmzhao): Add getter functions here as necessary. extern "C" {
extern const TF_SignatureDefParamList* TF_SignatureDefFunctionMetadataArgs(
const TF_SignatureDefFunctionMetadata* list) {
return tensorflow::wrap(&tensorflow::unwrap(list)->arguments());
}
extern const TF_SignatureDefParamList* TF_SignatureDefFunctionMetadataReturns(
const TF_SignatureDefFunctionMetadata* list) {
return tensorflow::wrap(&tensorflow::unwrap(list)->returns());
}
} // end extern "C"

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h"
#include "tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h"
extern "C" {
extern const char* TF_SignatureDefParamName(const TF_SignatureDefParam* param) {
return tensorflow::unwrap(param)->name().c_str();
}
extern const TF_TensorSpec* TF_SignatureDefParamTensorSpec(
const TF_SignatureDefParam* param) {
return tensorflow::wrap(&tensorflow::unwrap(param)->spec());
}
} // end extern "C"

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h"
extern "C" {
extern size_t TF_SignatureDefParamListSize(
const TF_SignatureDefParamList* list) {
return tensorflow::unwrap(list)->size();
}
extern const TF_SignatureDefParam* TF_SignatureDefParamListGet(
const TF_SignatureDefParamList* list, int i) {
return tensorflow::wrap(&tensorflow::unwrap(list)->at(i));
}
} // end extern "C"

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_
#include <vector>
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
typedef struct TF_SignatureDefParamList TF_SignatureDefParamList;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(std::vector<SignatureDefParam>,
TF_SignatureDefParamList)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
typedef struct TF_SignatureDefParam TF_SignatureDefParam;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::SignatureDefParam, TF_SignatureDefParam)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_

View File

@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
#include "tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h"
#include "tensorflow/c/tf_shape_internal.h"
extern "C" {
TF_DataType TF_TensorSpecDataType(const TF_TensorSpec* spec) {
return static_cast<TF_DataType>(tensorflow::unwrap(spec)->dtype());
}
const TF_Shape* TF_TensorSpecShape(const TF_TensorSpec* spec) {
return tensorflow::wrap(&tensorflow::unwrap(spec)->shape());
}
} // end extern "C"

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
typedef struct TF_TensorSpec TF_TensorSpec;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::TensorSpec, TF_TensorSpec)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_

View File

@ -28,6 +28,9 @@ exports_files(
"saved_model_api.h", "saved_model_api.h",
"signature_def_function.h", "signature_def_function.h",
"signature_def_function_metadata.h", "signature_def_function_metadata.h",
"signature_def_param.h",
"signature_def_param_list.h",
"tensor_spec.h",
],
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
)
@ -45,6 +48,9 @@ cc_library(
":saved_model_api", ":saved_model_api",
":signature_def_function", ":signature_def_function",
":signature_def_function_metadata", ":signature_def_function_metadata",
":signature_def_param",
":signature_def_param_list",
":tensor_spec",
],
)
@ -77,3 +83,18 @@ alias(
name = "signature_def_function_metadata", name = "signature_def_function_metadata",
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_function_metadata", actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_function_metadata",
) )
alias(
name = "signature_def_param",
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_param",
)
alias(
name = "signature_def_param_list",
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_param_list",
)
alias(
name = "tensor_spec",
actual = "//tensorflow/c/experimental/saved_model/internal:tensor_spec",
)

View File

@ -23,6 +23,9 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" #include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h" #include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" #include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
// IWYU pragma: end_exports
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_

View File

@ -16,6 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
@ -24,6 +27,18 @@ extern "C" {
// SavedModel.
typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata;
// Retrieves the arguments of the SignatureDefFunction. The caller is not
// responsible for freeing the returned pointer.
TF_CAPI_EXPORT extern const TF_SignatureDefParamList*
TF_SignatureDefFunctionMetadataArgs(
const TF_SignatureDefFunctionMetadata* list);
// Retrieves the returns of the SignatureDefFunction. The caller is not
// responsible for freeing the returned pointer.
TF_CAPI_EXPORT extern const TF_SignatureDefParamList*
TF_SignatureDefFunctionMetadataReturns(
const TF_SignatureDefFunctionMetadata* list);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
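A short consumer-side sketch of these getters, assuming a TF_SignatureDefFunction* fn already obtained via TF_GetSavedModelSignatureDefFunction:

const TF_SignatureDefFunctionMetadata* meta =
    TF_SignatureDefFunctionGetMetadata(fn);
const TF_SignatureDefParamList* args =
    TF_SignatureDefFunctionMetadataArgs(meta);
for (int i = 0; i < static_cast<int>(TF_SignatureDefParamListSize(args)); ++i) {
  const TF_SignatureDefParam* param = TF_SignatureDefParamListGet(args, i);
  // Both the name and the TensorSpec are borrowed; nothing is freed here.
  const char* name = TF_SignatureDefParamName(param);
  TF_DataType dtype =
      TF_TensorSpecDataType(TF_SignatureDefParamTensorSpec(param));
  (void)name;
  (void)dtype;
}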

View File

@ -0,0 +1,44 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type containing metadata of an input/output of a
// TF_SignatureDefFunction loaded from a SavedModel.
typedef struct TF_SignatureDefParam TF_SignatureDefParam;
// Returns the name of the given parameter. The caller is not responsible for
// freeing the returned char*.
TF_CAPI_EXPORT extern const char* TF_SignatureDefParamName(
const TF_SignatureDefParam* param);
// Returns the TensorSpec associated with the given parameter. The caller is
// not responsible for freeing the returned TF_TensorSpec*.
TF_CAPI_EXPORT extern const TF_TensorSpec* TF_SignatureDefParamTensorSpec(
const TF_SignatureDefParam* param);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_

View File

@ -0,0 +1,44 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_
#include <stddef.h>
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type containing metadata of an input/output of a
// ConcreteFunction loaded from a SavedModel.
typedef struct TF_SignatureDefParamList TF_SignatureDefParamList;
// Returns the size of `list`.
TF_CAPI_EXPORT extern size_t TF_SignatureDefParamListSize(
const TF_SignatureDefParamList* list);
// Returns the `i`th TF_SignatureDefParam in the list.
TF_CAPI_EXPORT extern const TF_SignatureDefParam* TF_SignatureDefParamListGet(
const TF_SignatureDefParamList* list, int i);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_

View File

@ -0,0 +1,46 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_
#include <stddef.h>
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_shape.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type corresponding to TensorSpec
typedef struct TF_TensorSpec TF_TensorSpec;
// Returns the dtype associated with the TensorSpec.
TF_CAPI_EXPORT extern TF_DataType TF_TensorSpecDataType(
const TF_TensorSpec* spec);
// Returns the shape associated with the TensorSpec. The returned Shape is not
// owned by the caller. Caller must not call TF_DeleteShape on the returned
// shape.
TF_CAPI_EXPORT extern const TF_Shape* TF_TensorSpecShape(
const TF_TensorSpec* spec);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_

View File

@ -22,6 +22,8 @@ cc_library(
"//tensorflow/c:tf_status", "//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper", "//tensorflow/c:tf_status_helper",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core/platform:regexp",
"//tensorflow/core/platform:strcat",
"//tensorflow/stream_executor:executor_cache", "//tensorflow/stream_executor:executor_cache",
"//tensorflow/stream_executor:multi_platform_manager", "//tensorflow/stream_executor:multi_platform_manager",
"//tensorflow/stream_executor:platform", "//tensorflow/stream_executor:platform",

View File

@ -27,7 +27,10 @@ limitations under the License.
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/stream_executor/executor_cache.h" #include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/multi_platform_manager.h" #include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform.h"
@ -39,6 +42,8 @@ limitations under the License.
using tensorflow::StatusFromTF_Status;
namespace stream_executor {
using tensorflow::StringPiece;
namespace {
#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \
@ -58,10 +63,35 @@ namespace {
} \ } \
} while (0) } while (0)
port::Status ValidateDeviceType(StringPiece type) {
// Validate device type. Device type must start with a capital letter and
// consist of capital letters and underscores. Reasoning behind this decision:
// * At the minimum we want to disallow '/' and ':' since
// these characters are used in device specs, e.g.
// /job:foo/replica:12/device:GPU:1.
// * Underscores seem useful, e.g. XLA_GPU uses underscores.
// * Allowing lowercase might get confusing. For example, say someone
// registers a new type called "Gpu". It might be confusing for users that
// "Gpu" is not the same device type as "GPU".
// Note that lowercase "cpu" and "gpu" are currently supported only for
// legacy reasons:
// https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/python/framework/device_spec.py;l=46;drc=d3a378f9665d8eee827c74cb9ecbee81e4c288dd
static const LazyRE2 kTfDeviceTypeRegEx = {"[A-Z][A-Z_]*"};
bool matches = RE2::FullMatch(type, *kTfDeviceTypeRegEx);
if (!matches) {
return port::FailedPreconditionError(
tensorflow::strings::StrCat("Device name/type '", type, "' must match ",
kTfDeviceTypeRegEx->pattern(), "."));
}
return port::Status::OK();
}
port::Status ValidateSPPlatform(const SP_Platform& platform) { port::Status ValidateSPPlatform(const SP_Platform& platform) {
VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE); VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE);
VALIDATE_MEMBER(SP_Platform, platform, name); VALIDATE_MEMBER(SP_Platform, platform, name);
VALIDATE_MEMBER(SP_Platform, platform, type); VALIDATE_MEMBER(SP_Platform, platform, type);
TF_RETURN_IF_ERROR(ValidateDeviceType(platform.name));
TF_RETURN_IF_ERROR(ValidateDeviceType(platform.type));
// `visible_device_count` could be 0 at initialization time. // `visible_device_count` could be 0 at initialization time.
return port::Status::OK(); return port::Status::OK();
} }
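For reference, the naming rule enforced above can be exercised standalone; a sketch using RE2 directly, with illustrative names (`Gpu` and `INVALID:NAME` fail the full match, while `MY_DEVICE` and `XLA_GPU` pass):

```cpp
#include <iostream>

#include "re2/re2.h"

int main() {
  // Same pattern as kTfDeviceTypeRegEx: a capital letter, then capitals and
  // underscores only.
  RE2 pattern("[A-Z][A-Z_]*");
  for (const char* name : {"MY_DEVICE", "XLA_GPU", "Gpu", "INVALID:NAME"}) {
    std::cout << name << " -> "
              << (RE2::FullMatch(name, pattern) ? "valid" : "invalid") << "\n";
  }
  return 0;
}
```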

View File

@ -52,7 +52,7 @@ limitations under the License.
// params.device = &device; // params.device = &device;
// //
// /* Plugin code below */ // /* Plugin code below */
// constexpr char DEVICE_NAME[] = "MyDevice"; // constexpr char DEVICE_NAME[] = "MY_DEVICE";
// constexpr char DEVICE_TYPE[] = "GPU"; // constexpr char DEVICE_TYPE[] = "GPU";
// //
// void create_device(const SP_Platform* platform, // void create_device(const SP_Platform* platform,
@ -416,10 +416,15 @@ typedef struct SP_Platform {
void* ext; // free-form data set by plugin void* ext; // free-form data set by plugin
// Platform name. Must be null-terminated. // Platform name (also referred to as subtype), for example MY_DEVICE.
// The name must start with a capital letter and consist of
// capital letters and underscores.
// Must be null-terminated.
const char* name; const char* name;
// Device type name, for example GPU. Must be null-terminated. // Device type name, for example GPU. Must be null-terminated.
// The name must start with a capital letter and consist of
// capital letters and underscores.
const char* type; const char* type;
// Number of visible devices // Number of visible devices

View File

@ -41,9 +41,9 @@ struct SP_Timer_st {
namespace stream_executor { namespace stream_executor {
namespace { namespace {
constexpr int DEVICE_COUNT = 2; constexpr int kDeviceCount = 2;
constexpr char DEVICE_NAME[] = "MyDevice"; constexpr char kDeviceName[] = "MY_DEVICE";
constexpr char DEVICE_TYPE[] = "GPU"; constexpr char kDeviceType[] = "GPU";
/*** Create SP_StreamExecutor (with empty functions) ***/ /*** Create SP_StreamExecutor (with empty functions) ***/
void allocate(const SP_Device* const device, uint64_t size, void allocate(const SP_Device* const device, uint64_t size,
@ -190,9 +190,9 @@ void destroy_device_fns(const SP_Platform* platform, SP_DeviceFns* device_fns) {
void PopulateDefaultPlatform(SP_Platform* platform, void PopulateDefaultPlatform(SP_Platform* platform,
SP_PlatformFns* platform_fns) { SP_PlatformFns* platform_fns) {
*platform = {SP_PLATFORM_STRUCT_SIZE}; *platform = {SP_PLATFORM_STRUCT_SIZE};
platform->name = DEVICE_NAME; platform->name = kDeviceName;
platform->type = DEVICE_TYPE; platform->type = kDeviceType;
platform->visible_device_count = DEVICE_COUNT; platform->visible_device_count = kDeviceCount;
platform_fns->create_device = create_device; platform_fns->create_device = create_device;
platform_fns->destroy_device = destroy_device; platform_fns->destroy_device = destroy_device;
platform_fns->create_device_fns = create_device_fns; platform_fns->create_device_fns = create_device_fns;
@ -218,11 +218,11 @@ TEST(StreamExecutor, SuccessfulRegistration) {
port::Status status = InitStreamExecutorPlugin(plugin_init); port::Status status = InitStreamExecutorPlugin(plugin_init);
TF_ASSERT_OK(status); TF_ASSERT_OK(status);
port::StatusOr<Platform*> maybe_platform = port::StatusOr<Platform*> maybe_platform =
MultiPlatformManager::PlatformWithName("MyDevice"); MultiPlatformManager::PlatformWithName("MY_DEVICE");
TF_ASSERT_OK(maybe_platform.status()); TF_ASSERT_OK(maybe_platform.status());
Platform* platform = maybe_platform.ConsumeValueOrDie(); Platform* platform = maybe_platform.ConsumeValueOrDie();
ASSERT_EQ(platform->Name(), DEVICE_NAME); ASSERT_EQ(platform->Name(), kDeviceName);
ASSERT_EQ(platform->VisibleDeviceCount(), DEVICE_COUNT); ASSERT_EQ(platform->VisibleDeviceCount(), kDeviceCount);
port::StatusOr<StreamExecutor*> maybe_executor = port::StatusOr<StreamExecutor*> maybe_executor =
platform->ExecutorForDevice(0); platform->ExecutorForDevice(0);
@ -244,6 +244,39 @@ TEST(StreamExecutor, NameNotSet) {
ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set."); ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set.");
} }
TEST(StreamExecutor, InvalidNameWithSemicolon) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultPlatform(params->platform, params->platform_fns);
params->platform->name = "INVALID:NAME";
params->destroy_platform = destroy_platform;
params->destroy_platform_fns = destroy_platform_fns;
};
port::Status status = InitStreamExecutorPlugin(plugin_init);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
EXPECT_THAT(
status.error_message(),
testing::ContainsRegex("Device name/type 'INVALID:NAME' must match"));
}
TEST(StreamExecutor, InvalidNameWithSlash) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultPlatform(params->platform, params->platform_fns);
params->platform->name = "INVALID/";
params->destroy_platform = destroy_platform;
params->destroy_platform_fns = destroy_platform_fns;
};
port::Status status = InitStreamExecutorPlugin(plugin_init);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
EXPECT_THAT(status.error_message(),
testing::ContainsRegex("Device name/type 'INVALID/' must match"));
}
TEST(StreamExecutor, CreateDeviceNotSet) { TEST(StreamExecutor, CreateDeviceNotSet) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params, auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void { TF_Status* const status) -> void {

View File

@ -57,43 +57,7 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
TF_Status* status) { TF_Status* status) {
mutex_lock l(graph->mu); TF_UpdateEdge(graph, new_src, dst, status);
tensorflow::shape_inference::InferenceContext* ic =
graph->refiner.GetContext(&new_src.oper->node);
if (ic->num_outputs() <= new_src.index) {
status->status = tensorflow::errors::OutOfRange(
"Cannot update edge. Output index [", new_src.index,
"] is greater than the number of total outputs [", ic->num_outputs(),
"].");
return;
}
tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
tensorflow::shape_inference::InferenceContext* ic_dst =
graph->refiner.GetContext(&dst.oper->node);
if (ic_dst->num_inputs() <= dst.index) {
status->status = tensorflow::errors::OutOfRange(
"Cannot update edge. Input index [", dst.index,
"] is greater than the number of total inputs [", ic_dst->num_inputs(),
"].");
return;
}
if (!ic_dst->MergeInput(dst.index, shape)) {
status->status = tensorflow::errors::InvalidArgument(
"Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
" and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
return;
}
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
&dst.oper->node, dst.index);
if (TF_GetCode(status) == TF_OK) {
// This modification only updates the destination node for
// the purposes of running this graph in a session. Thus, we don't
// record the source node as being modified.
RecordMutation(graph, *dst.oper, "updating input tensor");
}
} }
void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) { void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
@ -136,6 +100,7 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
auto* out_shape_and_type = handle_data.add_shape_and_type(); auto* out_shape_and_type = handle_data.add_shape_and_type();
ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape()); ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
out_shape_and_type->set_dtype(p.dtype); out_shape_and_type->set_dtype(p.dtype);
out_shape_and_type->set_specialized_type(p.specialized_type);
} }
} }
string result; string result;
@ -163,7 +128,8 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
status->status = status->status =
ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
if (TF_GetCode(status) != TF_OK) return; if (TF_GetCode(status) != TF_OK) return;
shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype()); shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(),
shape_and_type_proto.specialized_type());
} }
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
} }

tensorflow/c/tf_shape.cc Normal file
View File

@ -0,0 +1,39 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/tf_shape.h"
#include <stdint.h>
#include "tensorflow/c/tf_shape_internal.h"
#include "tensorflow/core/framework/tensor_shape.h"
extern "C" {
TF_Shape* TF_NewShape() {
return tensorflow::wrap(new tensorflow::PartialTensorShape());
}
int TF_ShapeDims(const TF_Shape* shape) {
return tensorflow::unwrap(shape)->dims();
}
int64_t TF_ShapeDimSize(const TF_Shape* shape, int d) {
return tensorflow::unwrap(shape)->dim_size(d);
}
void TF_DeleteShape(TF_Shape* shape) { delete tensorflow::unwrap(shape); }
} // end extern "C"

tensorflow/c/tf_shape.h Normal file
View File

@ -0,0 +1,50 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>
#include "tensorflow/c/c_api_macros.h"
#ifndef TENSORFLOW_C_TF_SHAPE_H_
#define TENSORFLOW_C_TF_SHAPE_H_
#ifdef __cplusplus
extern "C" {
#endif
// An opaque type corresponding to a shape in tensorflow. In the future,
// we may expose the ABI of TF_Shape for performance reasons.
typedef struct TF_Shape TF_Shape;
// Return a new, unknown rank shape object. The caller is responsible for
// calling TF_DeleteShape to deallocate and destroy the returned shape.
TF_CAPI_EXPORT extern TF_Shape* TF_NewShape();
// Returns the rank of `shape`. If `shape` has unknown rank, returns -1.
TF_CAPI_EXPORT extern int TF_ShapeDims(const TF_Shape* shape);
// Returns the `d`th dimension of `shape`. If `shape` has unknown rank,
// invoking this function is undefined behavior. Returns -1 if dimension is
// unknown.
TF_CAPI_EXPORT extern int64_t TF_ShapeDimSize(const TF_Shape* shape, int d);
// Deletes `shape`.
TF_CAPI_EXPORT extern void TF_DeleteShape(TF_Shape* shape);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_TF_SHAPE_H_
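A short lifecycle sketch for the owning case (contrast with `TF_TensorSpecShape`, whose result is borrowed):

```cpp
#include <stdio.h>

#include "tensorflow/c/tf_shape.h"

int main() {
  TF_Shape* shape = TF_NewShape();           // fresh shape has unknown rank
  printf("rank=%d\n", TF_ShapeDims(shape));  // prints -1
  TF_DeleteShape(shape);                     // caller owns shapes it creates
  return 0;
}
```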

View File

@ -0,0 +1,30 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_TF_SHAPE_INTERNAL_H_
#define TENSORFLOW_C_TF_SHAPE_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/core/framework/tensor_shape.h"
typedef struct TF_Shape TF_Shape;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::PartialTensorShape, TF_Shape);
}
#endif // TENSORFLOW_C_TF_SHAPE_INTERNAL_H_

View File

@ -251,7 +251,6 @@ cc_library_with_android_deps(
deps = [ deps = [
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_experimental",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
], ],
) )
@ -266,7 +265,6 @@ tf_cc_test(
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_experimental",
"//tensorflow/core:tensorflow", "//tensorflow/core:tensorflow",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",

View File

@ -15,13 +15,12 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/array_ops_internal.h" #include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
namespace tensorflow { namespace tensorflow {
namespace ops { namespace ops {
namespace { namespace {
@ -90,15 +89,25 @@ Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
} }
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad); REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);
Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op, Status QuantizeAndDequantizeV4GradHelper(const Scope& scope,
const std::vector<Output>& grad_inputs, const Operation& op,
std::vector<Output>* grad_outputs) { const std::vector<Output>& grad_inputs,
grad_outputs->push_back(Identity(scope, grad_inputs[0])); std::vector<Output>* grad_outputs) {
grad_outputs->push_back(NoGradient()); Input input = Shape(scope, op.input(0));
grad_outputs->push_back(NoGradient()); Input input_min = op.input(1);
Input input_max = op.input(2);
int64 axis;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
auto qdq_v4_grad = QuantizeAndDequantizeV4Grad(
scope, grad_inputs[0], input, input_min, input_max,
QuantizeAndDequantizeV4Grad::Axis(axis));
grad_outputs->push_back(qdq_v4_grad.input_backprop);
grad_outputs->push_back(qdq_v4_grad.input_min_backprop);
grad_outputs->push_back(qdq_v4_grad.input_max_backprop);
return scope.status(); return scope.status();
} }
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad); REGISTER_GRADIENT_OP("QuantizeAndDequantizeV4",
QuantizeAndDequantizeV4GradHelper);
Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op, Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs, const std::vector<Output>& grad_inputs,

View File

@ -21,10 +21,7 @@ package(
licenses = ["notice"], # Apache 2.0 licenses = ["notice"], # Apache 2.0
) )
exports_files([ exports_files(["loader.h"])
"LICENSE",
"loader.h",
])
cc_library( cc_library(
name = "constants", name = "constants",
@ -45,13 +42,15 @@ cc_library(
name = "reader", name = "reader",
srcs = ["reader.cc"], srcs = ["reader.cc"],
hdrs = ["reader.h"], hdrs = ["reader.h"],
deps = [":constants"] + if_not_mobile([ deps = [
":constants",
"//tensorflow/core:protos_all_cc",
] + if_not_mobile([
# TODO(b/111634734): :lib and :protos_all contain dependencies that # TODO(b/111634734): :lib and :protos_all contain dependencies that
# cannot be built on mobile platforms. Instead, include the appropriate # cannot be built on mobile platforms. Instead, include the appropriate
# tf_lib depending on the build platform. # tf_lib depending on the build platform.
"@com_google_absl//absl/memory:memory", "@com_google_absl//absl/memory:memory",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
]), ]),
) )

View File

@ -12,8 +12,6 @@ package(
licenses = ["notice"], # Apache 2.0 licenses = ["notice"], # Apache 2.0
) )
exports_files(["LICENSE"])
cc_library( cc_library(
name = "freeze_saved_model", name = "freeze_saved_model",
srcs = ["freeze_saved_model.cc"], srcs = ["freeze_saved_model.cc"],

View File

@ -75,7 +75,7 @@ cc_library(
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//llvm:Target", "@llvm-project//llvm:Target",
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep "@llvm-project//llvm:X86CodeGen", # fixdeps: keep
"//tensorflow/core:regexp_internal", "//tensorflow/core/platform:regexp",
] + if_llvm_system_z_available([ ] + if_llvm_system_z_available([
"@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep "@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep
]) + if_llvm_aarch64_available([ ]) + if_llvm_aarch64_available([

View File

@ -336,9 +336,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo_profile_printer", "//tensorflow/compiler/xla/service:hlo_profile_printer",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core/platform:regexp",
"//third_party/eigen3", "//third_party/eigen3",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
@ -559,9 +559,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo_profile_printer", "//tensorflow/compiler/xla/service:hlo_profile_printer",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core/platform:regexp",
"//third_party/eigen3", "//third_party/eigen3",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],

View File

@ -127,7 +127,7 @@ def tf_library(
"$(location " + tfcompile_tool + ")" + "$(location " + tfcompile_tool + ")" +
" --config=$(location " + config + ")" + " --config=$(location " + config + ")" +
" --dump_fetch_nodes > $@"), " --dump_fetch_nodes > $@"),
tools = [tfcompile_tool], exec_tools = [tfcompile_tool],
# Run tfcompile on the build host, rather than forge, since it's # Run tfcompile on the build host, rather than forge, since it's
# typically way faster on the local machine. # typically way faster on the local machine.
local = 1, local = 1,
@ -242,7 +242,7 @@ def tf_library(
" --out_function_object=$(@D)/" + function_object_file + " --out_function_object=$(@D)/" + function_object_file +
" " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag " " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag
), ),
tools = [tfcompile_tool], exec_tools = [tfcompile_tool],
visibility = visibility, visibility = visibility,
testonly = testonly, testonly = testonly,
# Run tfcompile on the build host since it's typically faster on the # Run tfcompile on the build host since it's typically faster on the
@ -281,7 +281,7 @@ def tf_library(
" --out_session_module=$(@D)/" + session_module_pb + " --out_session_module=$(@D)/" + session_module_pb +
" " + flags " " + flags
), ),
tools = [tfcompile_tool], exec_tools = [tfcompile_tool],
visibility = visibility, visibility = visibility,
testonly = testonly, testonly = testonly,
local = 1, local = 1,

View File

@ -4,7 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_test")
# buildifier: disable=same-origin-load # buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "if_tpu", "tf_copts") load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts")
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm") load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
# buildifier: disable=same-origin-load # buildifier: disable=same-origin-load
@ -77,7 +77,7 @@ cc_library(
"//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops",
] + if_tpu( ] + if_libtpu(
if_false = ["//tensorflow/compiler/xla/service:cpu_plugin"], if_false = ["//tensorflow/compiler/xla/service:cpu_plugin"],
if_true = [], if_true = [],
), ),
@ -114,7 +114,7 @@ cc_library(
"//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
] + if_tpu( ] + if_libtpu(
if_false = [ if_false = [
"//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep "//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep
], ],
@ -141,7 +141,7 @@ cc_library(
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core/common_runtime/gpu:gpu_init", "//tensorflow/core/common_runtime/gpu:gpu_init",
] + if_tpu( ] + if_libtpu(
if_false = [ if_false = [
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep "//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
], ],
@ -204,7 +204,7 @@ XLA_DEVICE_DEPS = [
"//tensorflow/core:resource_variable_ops_op_lib", "//tensorflow/core:resource_variable_ops_op_lib",
"//tensorflow/core:sendrecv_ops_op_lib", "//tensorflow/core:sendrecv_ops_op_lib",
"//tensorflow/core:state_ops_op_lib", "//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform:stream_executor_no_cuda",
"//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:fifo_queue", "//tensorflow/core/kernels:fifo_queue",
"//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:function_ops",
@ -375,7 +375,7 @@ cc_library(
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:logging", "//tensorflow/core/platform:logging",
] + if_tpu( ] + if_libtpu(
if_false = [ if_false = [
"//tensorflow/compiler/mlir:array_container_utils", "//tensorflow/compiler/mlir:array_container_utils",
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
@ -435,6 +435,7 @@ cc_library(
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core/common_runtime:core_cpu_internal", "//tensorflow/core/common_runtime:core_cpu_internal",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
@ -1022,10 +1023,10 @@ tf_cc_test(
"//tensorflow/cc:ops", "//tensorflow/cc:ops",
"//tensorflow/core:all_kernels", "//tensorflow/core:all_kernels",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:direct_session_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:ops", "//tensorflow/core:ops",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core/common_runtime:direct_session_internal",
"//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:matmul_op", "//tensorflow/core/kernels:matmul_op",
"//tensorflow/core/kernels:partitioned_function_ops", "//tensorflow/core/kernels:partitioned_function_ops",

View File

@ -84,6 +84,23 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
return Status::OK(); return Status::OK();
} }
xla::StatusOr<std::vector<NodeDef>> MakeCallNodesFromAttribute(
const Node& node, absl::string_view attr_name,
absl::string_view call_name) {
std::vector<NameAttrList> attr_lists;
TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &attr_lists));
std::vector<NodeDef> out;
for (int i = 0; i < attr_lists.size(); i++) {
out.emplace_back();
NodeDef& inserted = out.back();
inserted.set_name(absl::StrCat(call_name, "_", i));
inserted.set_op(attr_lists[i].name());
*inserted.mutable_attr() = attr_lists[i].attr();
}
return out;
}
// Utility which searches for values in a sorted list by scanning over it once. // Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most // No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is // once. However, if a call to ScanForValue skips over a value, that value is
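The comment above is cut off by the hunk boundary; it describes a stateful one-pass scan over a sorted list. A hedged reconstruction of the idea, with hypothetical names (not the actual implementation):

```cpp
#include <cstddef>
#include <vector>

// Remembers its position between calls, so the sorted list is traversed at
// most once in total; values that a call skips over are never revisited.
class SingleScanSearcher {
 public:
  explicit SingleScanSearcher(const std::vector<int>& sorted)
      : list_(sorted) {}

  bool ScanForValue(int value) {
    while (pos_ < list_.size() && list_[pos_] < value) ++pos_;
    return pos_ < list_.size() && list_[pos_] == value;
  }

 private:
  const std::vector<int>& list_;
  size_t pos_ = 0;
};
```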
@ -227,6 +244,30 @@ bool RecursiveCompilabilityChecker::IsCompilableIf(
return is_compilable; return is_compilable;
} }
bool RecursiveCompilabilityChecker::IsCompilableCase(
const Node& case_node, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
NameAttrList* encapsulating_function,
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
const {
xla::StatusOr<std::vector<NodeDef>> calls =
MakeCallNodesFromAttribute(case_node, "branches", "branch");
if (!calls.ok()) {
VLOG(2) << "Rejecting node " << case_node.name() << ": "
<< "missing attribute 'branches'";
return false;
}
bool is_compilable = true;
for (const NodeDef& call : *calls) {
is_compilable &=
IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
uncompilable_nodes);
}
return is_compilable;
}
// Tests whether 'while_node' is a completely compilable loop. // Tests whether 'while_node' is a completely compilable loop.
// Every operator in the condition and body functions must be compilable for a // Every operator in the condition and body functions must be compilable for a
// while loop to be compilable. // while loop to be compilable.
@ -417,6 +458,13 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
return false; return false;
} }
if (op_filter_.require_always_compilable && node.IsCaseNode() &&
!IsCompilableCase(node, lib_runtime, stack_trace, encapsulating_function,
uncompilable_nodes)) {
LogNotCompilable(node, "unsupported case");
return false;
}
if (!op_filter_.allow_stateful_rng_ops && if (!op_filter_.allow_stateful_rng_ops &&
IsStatefulRandomOp(node.type_string())) { IsStatefulRandomOp(node.type_string())) {
absl::string_view uncompilable_reason = "stateful random op"; absl::string_view uncompilable_reason = "stateful random op";

View File

@ -124,6 +124,10 @@ class RecursiveCompilabilityChecker {
// Whether ops known to have numerical accuracy issues should be considered // Whether ops known to have numerical accuracy issues should be considered
// compilable. // compilable.
bool allow_inaccurate_ops = false; bool allow_inaccurate_ops = false;
// Require the function to be always compilable, regardless of whether some
// control flow branches might be dead for a given input.
bool require_always_compilable = false;
}; };
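A minimal sketch of how the new flag is intended to be used, mirroring the mark_for_compilation_pass change later in this diff (the device type string is illustrative):

```cpp
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {

bool IsAlwaysCompilable(const Node& node,
                        FunctionLibraryRuntime* lib_runtime) {
  RecursiveCompilabilityChecker::OperationFilter filter;
  // New in this diff: also reject a functional Case whose branches are not
  // all compilable, even if the uncompilable branch may be dead at runtime.
  filter.require_always_compilable = true;
  RecursiveCompilabilityChecker checker(filter, DeviceType("XLA_CPU_JIT"));
  return checker.IsCompilableNode(node, lib_runtime);
}

}  // namespace tensorflow
```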
RecursiveCompilabilityChecker(OperationFilter op_filter, RecursiveCompilabilityChecker(OperationFilter op_filter,
@ -211,6 +215,14 @@ class RecursiveCompilabilityChecker {
NameAttrList* encapsulating_function, NameAttrList* encapsulating_function,
UncompilableNodesMap* uncompilable_nodes) const; UncompilableNodesMap* uncompilable_nodes) const;
// Tests whether 'case_node' is compilable. Every operator in all branches
// must be compilable.
bool IsCompilableCase(const Node& case_node,
FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
NameAttrList* encapsulating_function,
UncompilableNodesMap* uncompilable_nodes) const;
// Returns compilability of node def retrieved from `node`'s attribute with // Returns compilability of node def retrieved from `node`'s attribute with
// name `attr_name`. // name `attr_name`.
bool ExtractNodeDefAndCheckCompilability( bool ExtractNodeDefAndCheckCompilability(

View File

@ -34,7 +34,16 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace { namespace {
AttrValue FuncListAttr(const absl::Span<const char* const> names) {
AttrValue attr;
for (const char* name : names) {
attr.mutable_list()->add_func()->set_name(name);
}
return attr;
}
constexpr char kFunctionalIfNodeName[] = "If"; constexpr char kFunctionalIfNodeName[] = "If";
constexpr char kFunctionalCaseNodeName[] = "Case";
constexpr char kFunctionalWhileNodeName[] = "While"; constexpr char kFunctionalWhileNodeName[] = "While";
constexpr char kCompilableFunctionName[] = "CompilableFn"; constexpr char kCompilableFunctionName[] = "CompilableFn";
constexpr char kCompilableFunctionNodeName[] = "n_c"; constexpr char kCompilableFunctionNodeName[] = "n_c";
@ -76,8 +85,12 @@ class CompilabilityCheckUtilTest : public ::testing::Test {
op_filter_.allow_inaccurate_ops = false; op_filter_.allow_inaccurate_ops = false;
op_filter_.allow_slow_ops = false; op_filter_.allow_slow_ops = false;
checker_ = absl::make_unique<RecursiveCompilabilityChecker>(op_filter_, checker_ = CreateCompilabilityChecker();
device_type_); }
std::unique_ptr<RecursiveCompilabilityChecker> CreateCompilabilityChecker() {
return absl::make_unique<RecursiveCompilabilityChecker>(op_filter_,
device_type_);
} }
FunctionLibraryRuntime* GetFunctionLibraryRuntime() { FunctionLibraryRuntime* GetFunctionLibraryRuntime() {
@ -355,6 +368,57 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
"unsupported op")); "unsupported op"));
} }
TEST_F(CompilabilityCheckUtilTest, CheckFunctionalCaseNode) {
FunctionDefLibrary flib;
*flib.add_function() = FunctionDefHelper::Define(
/*Function*/ kUncompilableFunctionName,
/*Inputs*/ {"n_a:float"},
/*Outputs*/ {"n_c_uncompilable:float"},
/*Attributes*/ {},
// Node info
{{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});
*flib.add_function() = FunctionDefHelper::Define(
/*Function*/ kUncompilableFunctionTwoName,
/*Inputs*/ {"n_a:float"},
/*Outputs*/ {"n_d_uncompilable:float"},
/*Attribute*/ {},
// Node info
{{{kUncompilableFunctionNodeTwoName}, "MissingKernel", {"n_a"}}});
Scope root = Scope::NewRootScope().ExitOnError();
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib));
auto branch_index = ops::Placeholder(root.WithOpName("pred"), DT_INT32);
auto placeholder = ops::Placeholder(root.WithOpName("A"), DT_INT32);
std::vector<NodeBuilder::NodeOut> inputs(
{NodeBuilder::NodeOut(placeholder.node())});
Node* case_node;
TF_ASSERT_OK(
NodeBuilder(kFunctionalCaseNodeName, "Case", &root.graph()->flib_def())
.Input(branch_index.node())
.Input(inputs)
.Attr("branches", FuncListAttr({kUncompilableFunctionName,
kUncompilableFunctionTwoName}))
.Attr("Tout", {DT_INT32})
.Finalize(root.graph(), &case_node));
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
auto case_node_it = std::find_if(
graph->nodes().begin(), graph->nodes().end(),
[&](const Node* n) { return n->name() == kFunctionalCaseNodeName; });
EXPECT_NE(case_node_it, graph->nodes().end());
auto* flib_runtime = GetFunctionLibraryRuntime();
op_filter_.require_always_compilable = false;
checker_ = CreateCompilabilityChecker();
EXPECT_TRUE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
op_filter_.require_always_compilable = true;
checker_ = CreateCompilabilityChecker();
EXPECT_FALSE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
}
TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) { TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) {
GraphDefBuilder b(GraphDefBuilder::kFailImmediately); GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Scope root = Scope::NewRootScope().ExitOnError(); Scope root = Scope::NewRootScope().ExitOnError();

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
@ -47,8 +48,8 @@ static xla::StatusOr<xla::LocalExecutable*> GetLocalExecutable(
xla::StatusOr<std::string> GetCompilerIr( xla::StatusOr<std::string> GetCompilerIr(
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr, IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
absl::string_view func_name, Device* dev, absl::string_view func_name, Device* dev, EagerContext* context,
absl::Span<const Tensor* const> inputs) { absl::Span<const TensorHandle* const> inputs_handles) {
NameAttrList function; NameAttrList function;
function.set_name(std::string{func_name}); function.set_name(std::string{func_name});
@ -65,6 +66,25 @@ xla::StatusOr<std::string> GetCompilerIr(
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices); GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody); MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
std::deque<Tensor> inputs_storage;
std::vector<const Tensor*> inputs;
inputs.reserve(inputs_handles.size());
for (int i = 0; i < inputs_handles.size(); i++) {
const TensorHandle* th = inputs_handles[i];
const Tensor* t;
// Handle owns the tensor.
TF_RETURN_IF_ERROR(th->Tensor(&t));
if (absl::c_binary_search(constant_arg_indices, i)) {
// Need to make sure it's on the host.
inputs_storage.emplace_back(t->dtype(), t->shape());
TF_RETURN_IF_ERROR(
th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back()));
inputs.push_back(&inputs_storage.back());
} else {
inputs.push_back(t);
}
}
std::vector<VariableInfo> variable_infos; std::vector<VariableInfo> variable_infos;
TF_RETURN_IF_ERROR(GetVariableInfosFromInputs( TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
rmgr, dev, inputs, resource_arg_indices, &variable_infos)); rmgr, dev, inputs, resource_arg_indices, &variable_infos));

View File

@ -24,6 +24,8 @@ namespace tensorflow {
class ProcessFunctionLibraryRuntime; class ProcessFunctionLibraryRuntime;
class Device; class Device;
class Tensor; class Tensor;
class TensorHandle;
class EagerContext;
enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT }; enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
@ -31,8 +33,8 @@ enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
// `runtime` on a device `dev` with given `inputs`. // `runtime` on a device `dev` with given `inputs`.
xla::StatusOr<std::string> GetCompilerIr( xla::StatusOr<std::string> GetCompilerIr(
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr, IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
absl::string_view func_name, Device* dev, absl::string_view func_name, Device* dev, EagerContext* context,
absl::Span<const Tensor* const> inputs); absl::Span<const TensorHandle* const> inputs);
} // namespace tensorflow } // namespace tensorflow

View File

@ -34,7 +34,7 @@ XLA_OPS_DEPS = [
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:state_ops_op_lib", "//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform:stream_executor_no_cuda",
"//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor:tf_allocator_adapter", "//tensorflow/stream_executor:tf_allocator_adapter",
] ]

View File

@ -1196,10 +1196,14 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
continue; continue;
} }
if (!RecursiveCompilabilityChecker{ RecursiveCompilabilityChecker::OperationFilter filter =
CreateOperationFilter(*registration), CreateOperationFilter(*registration);
DeviceType{registration->compilation_device_name}} filter.require_always_compilable = true;
.IsCompilableNode(*node, lib_runtime)) {
RecursiveCompilabilityChecker checker(
filter, DeviceType{registration->compilation_device_name});
if (!checker.IsCompilableNode(*node, lib_runtime)) {
continue; continue;
} }
@ -2062,6 +2066,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"XlaSpmdFullToShardShape", "XlaSpmdFullToShardShape",
"XlaSpmdShardToFullShape", "XlaSpmdShardToFullShape",
"XlaSvd", "XlaSvd",
"XlaVariadicReduce",
"XlaWhile", "XlaWhile",
"Zeta", "Zeta",
"_Arg", "_Arg",

View File

@ -47,7 +47,7 @@ limitations under the License.
#include "tensorflow/core/public/version.h" #include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h" #include "tensorflow/core/util/dump_graph.h"
#if !defined(LIBTFTPU) #if !defined(LIBTPU_ON_GCE)
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/mlir/utils/array_container_utils.h" #include "tensorflow/compiler/mlir/utils/array_container_utils.h"
#endif #endif
@ -289,7 +289,7 @@ Status XlaCompilationCache::CompileSingleOp(
}); });
const ConfigProto* config = ctx->function_library()->config_proto(); const ConfigProto* config = ctx->function_library()->config_proto();
bool use_mlir = config && config->experimental().enable_mlir_bridge(); bool use_mlir = config && config->experimental().enable_mlir_bridge();
#ifdef LIBTFTPU #ifdef LIBTPU_ON_GCE
if (use_mlir && has_tensor_list_arg) { if (use_mlir && has_tensor_list_arg) {
LOG(WARNING) << "MLIR is not supported in this environment."; LOG(WARNING) << "MLIR is not supported in this environment.";
} }
@ -303,8 +303,12 @@ Status XlaCompilationCache::CompileSingleOp(
} }
GraphDebugInfo debug_info; GraphDebugInfo debug_info;
std::vector<std::string> control_rets;
if (result_dtypes.empty()) {
control_rets.push_back(node_def.name());
}
return CompileGraphToXlaHlo( return CompileGraphToXlaHlo(
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), *graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), control_rets,
options.device_type.type_string(), compile_options.use_tuple_arg, options.device_type.type_string(), compile_options.use_tuple_arg,
*options.flib_def, debug_info, options.shape_representation_fn, result); *options.flib_def, debug_info, options.shape_representation_fn, result);
#endif #endif

View File

@ -9,3 +9,31 @@ dialects and utilities for
3. TF Lite 3. TF Lite
See [MLIR's website](https://mlir.llvm.org) for complete documentation. See [MLIR's website](https://mlir.llvm.org) for complete documentation.
## Getting started
Building the dialects and utilities here follows the standard approach using
`bazel`, as in the rest of TensorFlow.
### Using local LLVM repo
To develop across MLIR core and TensorFlow, it is useful to override the repo
to use a local version instead of fetching from head. This can be achieved as
shown below, but note that the BUILD files are not automatically generated
(nor is CMake used), so if your change requires a BUILD file change (or you
are using a different version of LLVM than the LLVM_COMMIT set in
tensorflow/workspace.bzl), manual BUILD file changes may be required.
```sh
LLVM_SRC=...
# Create basic workspace file
echo 'workspace(name = "llvm-project")' > $LLVM_SRC/WORKSPACE
# and copy over the bazel BUILD files.
cp third_party/llvm/llvm.autogenerated.BUILD $LLVM_SRC/llvm/BUILD
cp third_party/mlir/BUILD $LLVM_SRC/mlir
cp third_party/mlir/test.BUILD $LLVM_SRC/mlir/test/BUILD
bazel build --override_repository=llvm-project=$LLVM_SRC \
-c opt tensorflow/compiler/mlir:tf-opt
```

View File

@ -48,6 +48,7 @@ filegroup(
"include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td", "include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td",
"include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td", "include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td",
"@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
@ -539,6 +540,8 @@ cc_library(
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:ShapeTransforms",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms", "@llvm-project//mlir:Transforms",

View File

@ -360,6 +360,19 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
}]; }];
} }
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
HLO_FpOrComplexTensor> {
let summary = "Atan operator";
let description = [{
Returns `Atan(operand)` element-wise.
$$
\mathrm{atan}(x) = \mathrm{atan2}(x, 1)
$$
}];
}
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [], def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
HLO_FpOrComplexTensor> { HLO_FpOrComplexTensor> {
let summary = "Sinh operation"; let summary = "Sinh operation";

View File

@ -157,6 +157,9 @@ def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
>]; >];
} }
def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt",
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CbrtOp;
def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil", def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil",
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CeilOp; [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CeilOp;
@ -193,12 +196,10 @@ def HLO_Expm1Op: HLO_UnaryElementwiseOp<"exponential_minus_one",
def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor", def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor",
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp; [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp;
def HLO_ImagOp: HLO_Op< def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag",
"imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp { [NoSideEffect, SameOperandsAndResultShape,
let builders = [OpBuilder< DeclareOpInterfaceMethods<InferTypeOpInterface>],
"OpBuilder &, OperationState &tblgen_state, Value val">]; HLO_ComplexTensor>, BASE_HLO_ImagOp {
let arguments = (ins HLO_ComplexTensor);
let results = (outs HLO_FpTensor); let results = (outs HLO_FpTensor);
let hasFolder = 1; let hasFolder = 1;
} }
@ -237,12 +238,10 @@ def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt",
[NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>, [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>,
BASE_HLO_PopulationCountOp; BASE_HLO_PopulationCountOp;
def HLO_RealOp: HLO_Op< def HLO_RealOp: HLO_UnaryElementwiseOp<"real",
"real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp { [NoSideEffect, SameOperandsAndResultShape,
let builders = [OpBuilder< DeclareOpInterfaceMethods<InferTypeOpInterface>],
"OpBuilder &, OperationState &tblgen_state, Value val">]; HLO_ComplexTensor>, BASE_HLO_RealOp {
let arguments = (ins HLO_ComplexTensor);
let results = (outs HLO_FpTensor); let results = (outs HLO_FpTensor);
let hasFolder = 1; let hasFolder = 1;
} }
@ -321,12 +320,10 @@ def HLO_AddOp : HLO_BinaryElementwiseOp<"add",
def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2", def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op; [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op;
def HLO_ComplexOp: HLO_Op<"complex", def HLO_ComplexOp: HLO_BinaryElementwiseOp<"complex",
[NoSideEffect, SameOperandsAndResultShape]>, [NoSideEffect, SameOperandsAndResultShape,
DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
BASE_HLO_ComplexOp { BASE_HLO_ComplexOp {
let builders = [OpBuilder<
"OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">];
let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs); let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs);
let results = (outs HLO_ComplexTensor); let results = (outs HLO_ComplexTensor);
let hasFolder = 1; let hasFolder = 1;
@ -356,7 +353,9 @@ def HLO_PowOp : HLO_BinaryElementwiseOp<"power",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp; [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp;
def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder", def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp; [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp {
let hasFolder = 1;
}
def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left", def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp; [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp;
@ -913,39 +912,12 @@ def HLO_CollectivePermuteOp: HLO_Op<"collective_permute",
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
} }
// TODO(hinsu): Make this struct dialect independent so that it can be shared
// between HLO and LHLO dialect.
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp { def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
let arguments = (ins let arguments = !con(
HLO_Tensor:$lhs, (ins
HLO_Tensor:$rhs, HLO_Tensor:$lhs,
// Default value: one for each of the spatial dimension. HLO_Tensor:$rhs),
OptionalAttr<I64ElementsAttr>:$window_strides, ConvolutionAttributes<HLO_Dialect>.attributes);
// Default value: zero for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
ConvDimensionNumbers:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
HLO_PrecisionConfigAttr:$precision_config
);
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
} }
@ -1198,14 +1170,14 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>,
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
} }
def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects]>, BASE_HLO_SortOp { def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp {
let arguments = (ins let arguments = (ins
Variadic<HLO_Tensor>:$operands, Variadic<HLO_Tensor>:$operands,
DefaultValuedAttr<I64Attr, "-1">:$dimension, DefaultValuedAttr<I64Attr, "-1">:$dimension,
DefaultValuedAttr<BoolAttr, "false">:$is_stable DefaultValuedAttr<BoolAttr, "false">:$is_stable
); );
let results = (outs HLO_TensorOrTuple); let results = (outs Variadic<HLO_Tensor>);
let regions = (region SizedRegion<1>:$comparator); let regions = (region SizedRegion<1>:$comparator);
@ -1429,4 +1401,21 @@ def HLO_FusionOp : HLO_Op<"fusion", []> {
let hasCustomHLOConverter = 1; let hasCustomHLOConverter = 1;
} }
// This is an op for purposes internal to XLA/GPU.
def HLO_BitcastOp : HLO_Op<"bitcast", [NoSideEffect]>, BASE_HLO_BitcastOp {
let arguments = (ins HLO_Tensor:$operand);
let results = (outs HLO_Tensor);
let hasCustomHLOConverter = 1;
}
def HLO_ReducePrecisionOp: HLO_Op<"reduce_precision", [SameOperandsAndResultShape]>,
BASE_HLO_ReducePrecisionOp {
let arguments = (ins
HLO_FpTensor:$operand,
I32Attr:$exponent_bits,
I32Attr:$mantissa_bits
);
let results = (outs HLO_FpTensor:$output);
}
#endif // HLO_OPS #endif // HLO_OPS

View File

@ -127,6 +127,17 @@ class BASE_HLO_AbsOp {
}]; }];
} }
class BASE_HLO_CbrtOp {
string summary = "Cubic root operator";
string description = [{
Returns element-wise cubic root of the operand.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
class BASE_HLO_CeilOp { class BASE_HLO_CeilOp {
string summary = "Ceil operator"; string summary = "Ceil operator";
@ -996,6 +1007,42 @@ class BASE_HLO_ConcatenateOp {
}]; }];
} }
//===----------------------------------------------------------------------===//
// Common convolution attributes
//===----------------------------------------------------------------------===//
class ConvDimensionNumbersBase<Dialect dialect>
: StructAttr<"ConvDimensionNumbers", dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
class ConvolutionAttributes<Dialect dialect> {
dag attributes = (ins
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$window_strides,
// Default value: zero for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
ConvDimensionNumbersBase<dialect>:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
HLO_PrecisionConfigAttr:$precision_config
);
}
class BASE_HLO_ConvOp { class BASE_HLO_ConvOp {
string summary = "Convolution operator"; string summary = "Convolution operator";
@ -1336,4 +1383,17 @@ class BASE_HLO_WhileOp {
}]; }];
} }
class BASE_HLO_BitcastOp {
string summary = "Bitcast operator";
string description = [{
This op changes the shape of the input in such a way that the physical
arrangement of elements is unchanged.
However, the op needs layout information to make sense of "physical
arrangement of elements". Layout support in MHLO is currently under
exploration.
}];
}
#endif // HLO_OPS_BASE #endif // HLO_OPS_BASE

View File

@ -37,38 +37,13 @@ include "mlir/IR/OpBase.td"
include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td" include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
def LHLO_Dialect : Dialect { def LHLO_Dialect : Dialect {
let name = "lmhlo"; let name = "lmhlo";
let cppNamespace = "::mlir::lmhlo"; let cppNamespace = "::mlir::lmhlo";
} }
//===----------------------------------------------------------------------===//
// LMHLO type definitions.
//===----------------------------------------------------------------------===//
// Any integer tensor types
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
// Any floating-point tensor types
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
// Any integer or floating-point tensor types
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// LMHLO nullary op definitions. // LMHLO nullary op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -289,6 +264,16 @@ def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
} }
def LHLO_CustomCallOp : LHLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$args,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
StrAttr:$call_target_name,
DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
DefaultValuedAttr<StrAttr, "">:$backend_config
);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// LMHLO tuple op definitions. // LMHLO tuple op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -335,10 +320,11 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast", def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
[NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> { [NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
let summary = [{ let summary = [{
"modifies the offset, sizes and strides of a statically shaped memref. modifies the offset, sizes and strides of a statically shaped memref
}]; }];
let description = [{ let description = [{
Allows to modify the offset, sizes and strides of a statically shaped memref. Casts the statically shaped memref operand to a memref with optionally
modified offsets, sizes and strides.
Example: Example:
```mlir ```mlir
@ -354,12 +340,11 @@ def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
let arguments = (ins Arg<LHLO_Buffer, "", []>:$operand); let arguments = (ins Arg<LHLO_Buffer, "", []>:$operand);
let results = (outs Res<LHLO_Buffer, "", []>:$result); let results = (outs Res<LHLO_Buffer, "", []>:$result);
let builders = [OpBuilder< let builders = [OpBuilder<"MemRefType resultType, Value operand",
"OpBuilder &builder, OperationState &result, MemRefType resultType, " # [{
"Value operand", [{ $_state.addOperands(operand);
result.addOperands(operand); $_state.types.push_back(resultType);
result.types.push_back(resultType); }]>];
}]>];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
MemRefType getType() { return getResult().getType().cast<MemRefType>(); } MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
@ -400,13 +385,13 @@ def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
); );
let results = (outs Res<LHLO_Buffer, "", []>:$result); let results = (outs Res<LHLO_Buffer, "", []>:$result);
let builders = [OpBuilder< let builders = [
"OpBuilder &builder, OperationState &result, MemRefType resultType, " # OpBuilder<"MemRefType resultType, Value operand, ValueRange sizes, "
"Value operand, ValueRange sizes, ValueRange strides", [{ "ValueRange strides", [{
result.addOperands(operand); $_state.addOperands(operand);
result.addOperands(sizes); $_state.addOperands(sizes);
result.addOperands(strides); $_state.addOperands(strides);
result.types.push_back(resultType); $_state.types.push_back(resultType);
}]>]; }]>];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
@ -582,40 +567,13 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
); );
} }
// TODO(bondhugula): Make this struct dialect independent so that it can be
// shared between the HLO and LHLO dialects.
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp { def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
let arguments = (ins let arguments = !con(
Arg<LHLO_Buffer, "", [MemRead]>:$lhs, (ins
Arg<LHLO_Buffer, "", [MemRead]>:$rhs, Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$output, Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
// Default value: one for each of the spatial dimension. Arg<LHLO_Buffer, "", [MemWrite]>:$output),
OptionalAttr<I64ElementsAttr>:$window_strides, ConvolutionAttributes<LHLO_Dialect>.attributes);
// Default value: zero for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
ConvDimensionNumbers:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
HLO_PrecisionConfigAttr:$precision_config
);
} }
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp { def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {
@ -856,9 +814,8 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
let skipDefaultBuilders = 1; let skipDefaultBuilders = 1;
let builders = [ let builders = [
OpBuilder<"OpBuilder &builder, OperationState &result, " OpBuilder<"ArrayRef<NamedAttribute> attributes">
"ArrayRef<NamedAttribute> attributes"> ];
];
} }
def TerminatorOp : def TerminatorOp :
@ -867,9 +824,8 @@ def TerminatorOp :
let description = [{ let description = [{
Terminator operation for the LHLO dialect. Terminator operation for the LHLO dialect.
}]; }];
let builders = [OpBuilder< let builders = [OpBuilder<"ValueRange operands",
"OpBuilder &b, OperationState &result, ValueRange operands", [{ build($_builder, $_state, llvm::None, operands, llvm::None); }]
[{ build(b, result, llvm::None, operands, llvm::None); }]
>]; >];
} }
View File
@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef LHLO_OPS_BASE
#define LHLO_OPS_BASE
include "mlir/IR/OpBase.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
//===----------------------------------------------------------------------===//
// LMHLO type definitions.
//===----------------------------------------------------------------------===//
// Any integer tensor types
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
// Any floating-point tensor types
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
// Any integer or floating-point tensor types
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
#endif // LHLO_OPS_BASE
View File
@ -50,6 +50,7 @@ MAP_HLO_TO_LHLO(ConvOp);
MAP_HLO_TO_LHLO(ConvertOp); MAP_HLO_TO_LHLO(ConvertOp);
MAP_HLO_TO_LHLO(CopyOp); MAP_HLO_TO_LHLO(CopyOp);
MAP_HLO_TO_LHLO(CosOp); MAP_HLO_TO_LHLO(CosOp);
MAP_HLO_TO_LHLO(CustomCallOp);
MAP_HLO_TO_LHLO(DivOp); MAP_HLO_TO_LHLO(DivOp);
MAP_HLO_TO_LHLO(DotOp); MAP_HLO_TO_LHLO(DotOp);
MAP_HLO_TO_LHLO(ExpOp); MAP_HLO_TO_LHLO(ExpOp);
@ -57,11 +58,13 @@ MAP_HLO_TO_LHLO(FloorOp);
MAP_HLO_TO_LHLO(GatherOp); MAP_HLO_TO_LHLO(GatherOp);
MAP_HLO_TO_LHLO(ImagOp); MAP_HLO_TO_LHLO(ImagOp);
MAP_HLO_TO_LHLO(IotaOp); MAP_HLO_TO_LHLO(IotaOp);
MAP_HLO_TO_LHLO(IsFiniteOp);
MAP_HLO_TO_LHLO(LogOp); MAP_HLO_TO_LHLO(LogOp);
MAP_HLO_TO_LHLO(MaxOp); MAP_HLO_TO_LHLO(MaxOp);
MAP_HLO_TO_LHLO(MinOp); MAP_HLO_TO_LHLO(MinOp);
MAP_HLO_TO_LHLO(MulOp); MAP_HLO_TO_LHLO(MulOp);
MAP_HLO_TO_LHLO(NegOp); MAP_HLO_TO_LHLO(NegOp);
MAP_HLO_TO_LHLO(NotOp);
MAP_HLO_TO_LHLO(RealOp); MAP_HLO_TO_LHLO(RealOp);
MAP_HLO_TO_LHLO(ReduceOp); MAP_HLO_TO_LHLO(ReduceOp);
MAP_HLO_TO_LHLO(ReshapeOp); MAP_HLO_TO_LHLO(ReshapeOp);
View File
@ -149,6 +149,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
loc, result_types, args, b); loc, result_types, args, b);
} }
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::Atan2Op>{}(
loc, result_types, args, b);
}
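This mapping dispatches float operands to the standard dialect's Atan2Op, whose scalar semantics match C's quadrant-aware atan2; a small illustrative check:

```cpp
#include <cassert>
#include <cmath>

int main() {
  // atan2 keeps quadrant information that atan(y / x) loses.
  assert(std::atan2(1.0, 1.0) > 0.0);    // first quadrant
  assert(std::atan2(-1.0, -1.0) < 0.0);  // third quadrant, not first
}
```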
template <typename PredicateType> template <typename PredicateType>
inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) { inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
return llvm::None; return llvm::None;
@ -345,6 +354,22 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc,
loc, result_types, args, b); loc, result_types, args, b);
} }
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::IsFiniteOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
if (args[0].getType().isa<FloatType>()) {
auto pos_inf = APFloat::getInf(
args[0].getType().cast<FloatType>().getFloatSemantics());
auto const_pos_inf =
b->create<ConstantOp>(loc, b->getFloatAttr(args[0].getType(), pos_inf));
Value abs_x = b->create<::mlir::AbsFOp>(loc, args[0]);
return b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, abs_x,
const_pos_inf);
}
return nullptr;
}
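CmpFPredicate::ONE is "ordered and not equal", so a NaN operand compares false and is correctly reported non-finite. A scalar C++ model of the lowering above (names and test values are illustrative):

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// isfinite(x) <=> |x| is ordered-and-not-equal to +inf: infinities fail
// the inequality, and NaN fails the orderedness check.
bool is_finite(float x) {
  float abs_x = std::fabs(x);
  float pos_inf = std::numeric_limits<float>::infinity();
  return !std::isnan(abs_x) && abs_x != pos_inf;
}

int main() {
  assert(is_finite(1.5f));
  assert(!is_finite(std::numeric_limits<float>::infinity()));
  assert(!is_finite(std::nanf("")));
}
```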
/// Implements the conversion of HLO op to scalar op (to use within region of a /// Implements the conversion of HLO op to scalar op (to use within region of a
/// linalg.generic op) for compare-select style operations like min/max. /// linalg.generic op) for compare-select style operations like min/max.
template <typename... Args> template <typename... Args>
@ -431,6 +456,21 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
return nullptr; return nullptr;
} }
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
// lmhlo.not(x) -> x ^ -1
auto all_ones =
b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth());
return b->create<::mlir::XOrOp>(loc, all_ones, args[0]);
}
return nullptr;
}
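The pattern above materializes an all-ones constant because the standard dialect has no dedicated bitwise-not; for two's-complement integers x ^ -1 is exactly ~x, and for i1 it degenerates to logical negation. A small C++ check of that identity:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // lmhlo.not(x) lowers to all_ones ^ x, where all_ones is -1 in the
  // operand's bit width; for two's-complement integers this is exactly ~x.
  int32_t x = 0x0F0F0F0F;
  assert((-1 ^ x) == ~x);
  // For i1 (booleans), -1 truncated to one bit is 1, and b ^ 1 flips it,
  // so the same pattern gives logical negation.
  assert((0 ^ 1) == 1 && (1 ^ 1) == 0);
}
```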
template <> template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc, inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
ArrayRef<Type> result_types, ArrayRef<Type> result_types,
@ -454,11 +494,27 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
ArrayRef<Value> args, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = args.front().getType();
if (element_type.isa<FloatType>()) { if (auto float_type = element_type.dyn_cast<FloatType>()) {
FloatType float_type = element_type.cast<FloatType>(); bool ignored;
APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0); APFloat one_apfloat(1.0f);
Value one = b->create<mlir::ConstantFloatOp>(loc, const_value, float_type); one_apfloat.convert(float_type.getFloatSemantics(),
APFloat::rmNearestTiesToEven, &ignored);
Value one = b->create<mlir::ConstantFloatOp>(loc, one_apfloat, float_type);
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]); return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
} else if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
// sign(x) = x == 0 ? 0 : ((x s>> (bitwidth - 1)) | 1)
Value zero =
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
Value cmp =
b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero);
Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>(
loc, integer_type.getWidth() - 1, integer_type.getWidth());
Value ashr =
b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
Value one =
b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth());
Value or_op = b->create<::mlir::OrOp>(loc, ashr, one);
return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op);
} }
return nullptr; return nullptr;
} }
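Scalar C++ models of the two paths above (illustrative, not the MLIR builder code): the float path materializes 1.0 in the operand's float semantics and applies copysign, and the integer path uses the shift-or trick with a select for zero:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Float path: copysign(1.0, x). Like the lowering, this maps +-0 to +-1
// because there is no zero special case.
float fsign(float x) { return std::copysign(1.0f, x); }

// Integer path: the arithmetic shift yields -1 for negatives and 0
// otherwise; OR 1 maps that to -1 / +1, and the select handles zero.
// (x >> 31 on a negative signed int is arithmetic on all mainstream
// compilers, and guaranteed since C++20.)
int32_t isign(int32_t x) { return x == 0 ? 0 : ((x >> 31) | 1); }

int main() {
  assert(fsign(3.0f) == 1.0f && fsign(-2.5f) == -1.0f);
  assert(isign(42) == 1 && isign(0) == 0 && isign(-7) == -1);
}
```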
View File
@ -15,9 +15,9 @@ limitations under the License.
include "mlir/Pass/PassBase.td" include "mlir/Pass/PassBase.td"
def TestChloLegalizeToHloPass : Pass<"mhlo-test-chlo-legalize-to-hlo", "FuncOp"> { def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "FuncOp"> {
let summary = "Test pass for applying chlo -> hlo legalization patterns."; let summary = "Legalize CHLO to HLO.";
let constructor = "createTestChloLegalizeToHloPass()"; let constructor = "createChloLegalizeToHloPass()";
} }
def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> { def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
View File
@ -44,6 +44,9 @@ std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass();
/// Lowers from HLO dialect to Standard dialect. /// Lowers from HLO dialect to Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass(); std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
/// Lowers from the CHLO dialect to the HLO dialect.
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass();
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary /// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
/// buffers if necessary. If `results_escape_functions` is set to true, /// buffers if necessary. If `results_escape_functions` is set to true,
/// allocated buffers for function results will be returned and escape the /// allocated buffers for function results will be returned and escape the
@ -63,7 +66,7 @@ std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass(); std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
/// Lowers trigonometric operations from the standard dialect to approximations /// Lowers trigonometric operations from the standard dialect to approximations
// that do not use intrinsics. /// that do not use intrinsics.
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<FuncOp>>
createLegalizeTrigonometricToApproximationPass(); createLegalizeTrigonometricToApproximationPass();
View File
@ -22,7 +22,6 @@ limitations under the License.
namespace mlir { namespace mlir {
namespace mhlo { namespace mhlo {
std::unique_ptr<Pass> createTestChloLegalizeToHloPass();
std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass(); std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass();
std::unique_ptr<Pass> createTestMaterializeBroadcastsPass(); std::unique_ptr<Pass> createTestMaterializeBroadcastsPass();
std::unique_ptr<Pass> createTestUnfuseBatchNormPass(); std::unique_ptr<Pass> createTestUnfuseBatchNormPass();
View File
@ -20,7 +20,7 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/BufferPlacement.h" #include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
namespace mlir { namespace mlir {
View File
@ -185,8 +185,7 @@ struct GatherSlice : public OpRewritePattern<GatherOp> {
return failure(); return failure();
const auto& dnums = gather.dimension_numbers(); const auto& dnums = gather.dimension_numbers();
if (dnums.collapsed_slice_dims().getNumElements() != 0 || if (dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1)
dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1)
return failure(); return failure();
// TODO(tberghammer): Remove when the verifier catches this case what is // TODO(tberghammer): Remove when the verifier catches this case what is
@ -206,11 +205,35 @@ struct GatherSlice : public OpRewritePattern<GatherOp> {
} }
llvm::SmallVector<int64_t, 8> slice_stride(slice_end.size(), 1); llvm::SmallVector<int64_t, 8> slice_stride(slice_end.size(), 1);
rewriter.replaceOpWithNewOp<SliceOp>( llvm::SmallVector<int64_t, 8> slice_shape(slice_end.size());
gather, gather.getType(), gather.getOperand(0), for (int64_t i = 0; i < slice_end.size(); ++i) {
slice_shape[i] = slice_end[i] - slice_start[i];
}
Type element_type = gather.getType().cast<TensorType>().getElementType();
auto slice_type = RankedTensorType::get(slice_shape, element_type);
Value result = rewriter.create<SliceOp>(
gather.getLoc(), slice_type, gather.getOperand(0),
GetI64ElementsAttr(slice_start, &rewriter), GetI64ElementsAttr(slice_start, &rewriter),
GetI64ElementsAttr(slice_end, &rewriter), GetI64ElementsAttr(slice_end, &rewriter),
GetI64ElementsAttr(slice_stride, &rewriter)); GetI64ElementsAttr(slice_stride, &rewriter));
if (dnums.collapsed_slice_dims().getNumElements() > 0) {
auto collapsed_slice_dims = llvm::to_vector<8>(llvm::map_range(
dnums.collapsed_slice_dims().getIntValues(),
[](const llvm::APInt& i) { return i.getSExtValue(); }));
llvm::SmallVector<int64_t, 8> reshape_shape;
for (int64_t i = 0; i < slice_shape.size(); ++i) {
if (llvm::count(collapsed_slice_dims, i) == 0) {
reshape_shape.push_back(slice_shape[i]);
}
}
auto reshape_type = RankedTensorType::get(reshape_shape, element_type);
result =
rewriter.create<ReshapeOp>(gather.getLoc(), reshape_type, result);
}
result.setType(gather.getType());
rewriter.replaceOp(gather, result);
return success(); return success();
} }
}; };
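A sketch of the shape computation this rewrite performs; the helper name and test values are illustrative, not from this patch. The slice shape is end minus start per dimension, and dimensions listed in collapsed_slice_dims are dropped by the reshape:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> ReshapeShape(const std::vector<int64_t>& start,
                                  const std::vector<int64_t>& end,
                                  const std::vector<int64_t>& collapsed) {
  std::vector<int64_t> shape;
  for (size_t i = 0; i < end.size(); ++i) {
    bool is_collapsed = false;
    for (int64_t d : collapsed) is_collapsed |= (d == static_cast<int64_t>(i));
    // Keep end - start for every dimension that is not collapsed away.
    if (!is_collapsed) shape.push_back(end[i] - start[i]);
  }
  return shape;
}

int main() {
  // Slice [0,2) x [1,2) x [3,7) with dimension 1 collapsed -> shape {2, 4}.
  auto s = ReshapeShape({0, 1, 3}, {2, 2, 7}, {1});
  assert(s == (std::vector<int64_t>{2, 4}));
}
```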
@ -889,9 +912,10 @@ static LogicalResult Verify(ClampOp op) {
// ComplexOp // ComplexOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs, LogicalResult ComplexOp::inferReturnTypes(
Value rhs) { MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
auto type = lhs.getType(); RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
auto type = operands[0].getType();
auto element_ty = ComplexType::get(getElementTypeOrSelf(type)); auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
Type result_ty; Type result_ty;
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) { if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
@ -901,8 +925,8 @@ void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs,
} else { } else {
result_ty = element_ty; result_ty = element_ty;
} }
inferredReturnTypes.push_back(result_ty);
build(builder, state, result_ty, lhs, rhs); return success();
} }
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) { OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
@ -932,8 +956,11 @@ Type CreateRealType(Type type) {
} }
} // namespace } // namespace
void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) { LogicalResult ImagOp::inferReturnTypes(
build(builder, state, CreateRealType(val.getType()), val); MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
return success();
} }
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) { OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
@ -945,8 +972,11 @@ OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
return {}; return {};
} }
void RealOp::build(OpBuilder& builder, OperationState& state, Value val) { LogicalResult RealOp::inferReturnTypes(
build(builder, state, CreateRealType(val.getType()), val); MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
return success();
} }
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) { OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
@ -1971,6 +2001,23 @@ struct divide<APInt> {
APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); } APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
}; };
template <typename T>
struct remainder : std::modulus<T> {};
template <>
struct remainder<APInt> {
APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); }
};
template <>
struct remainder<APFloat> {
APFloat operator()(const APFloat& a, const APFloat& b) const {
APFloat result(a);
result.remainder(b);
return result;
}
};
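A quick C++ model of this folder's scalar semantics: APInt::srem truncates toward zero like C's % operator, and the float case is assumed here to follow APFloat::remainder's IEEE-754 remainder (std::remainder, not std::fmod):

```cpp
#include <cassert>
#include <cmath>

int main() {
  // Integer fold: truncating remainder, sign follows the dividend.
  assert(-7 % 3 == -1);
  assert(7 % -3 == 1);
  // Float fold (assumed IEEE-754 remainder): the quotient is rounded to
  // nearest, so the result can be negative even for positive operands.
  assert(std::remainder(5.0, 2.0) == 1.0);
  assert(std::remainder(7.0, 2.0) == -1.0);  // nearest quotient is 4
  assert(std::fmod(7.0, 2.0) == 1.0);        // truncating fmod differs
}
```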
template <typename T> template <typename T>
struct max { struct max {
T operator()(const T& a, const T& b) const { return std::max<T>(a, b); } T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
@ -2012,6 +2059,7 @@ BINARY_FOLDER(AddOp, std::plus);
BINARY_FOLDER(SubOp, std::minus); BINARY_FOLDER(SubOp, std::minus);
BINARY_FOLDER(MulOp, std::multiplies); BINARY_FOLDER(MulOp, std::multiplies);
BINARY_FOLDER(DivOp, divide); BINARY_FOLDER(DivOp, divide);
BINARY_FOLDER(RemOp, remainder);
BINARY_FOLDER(MaxOp, max); BINARY_FOLDER(MaxOp, max);
BINARY_FOLDER(MinOp, min); BINARY_FOLDER(MinOp, min);
@ -2261,10 +2309,7 @@ void SortOp::build(OpBuilder& builder, OperationState& state,
state.addAttribute("dimension", builder.getI64IntegerAttr(dimension)); state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
state.addAttribute("is_stable", builder.getBoolAttr(dimension)); state.addAttribute("is_stable", builder.getBoolAttr(dimension));
SmallVector<Type, 2> element_types; for (Value operand : operands) state.addTypes(operand.getType());
element_types.reserve(operands.size());
for (Value operand : operands) element_types.push_back(operand.getType());
state.addTypes(builder.getTupleType(element_types));
state.addRegion(); state.addRegion();
} }
View File
@ -283,7 +283,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
auto if_op = rewriter.create<scf::IfOp>( auto if_op = rewriter.create<scf::IfOp>(
loc, result_type, IsScalarTensor(rewriter, op, lhs), true); loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder(); OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder();
Value reshaped_lhs = if_lhs_scalar_builder.create<mhlo::ReshapeOp>( Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs); loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>( Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs}, loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
@ -300,7 +300,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
else_lhs_scalar_builder.create<scf::YieldOp>(loc, else_lhs_scalar_builder.create<scf::YieldOp>(loc,
if_rhs_scalar_op.getResult(0)); if_rhs_scalar_op.getResult(0));
OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder(); OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder();
Value reshaped_rhs = if_rhs_scalar_builder.create<mhlo::ReshapeOp>( Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs); loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>( Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs}, loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
@ -516,7 +516,7 @@ struct HloCompareAdaptor {
void PopulateLegalizeChloToHloPatterns(MLIRContext *context, void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) { OwningRewritePatternList *patterns) {
populateWithGenerated(context, patterns); populateWithGenerated(context, *patterns);
// Instantiate conversion templates for conforming binary elementwise ops // Instantiate conversion templates for conforming binary elementwise ops
// that do not have different dtypes between operands and results and do // that do not have different dtypes between operands and results and do
Some files were not shown because too many files have changed in this diff.