Merge branch 'master' into pluggable_device_load
This commit is contained in:
commit
22d22ff5b3
6
.bazelrc
6
.bazelrc
@ -174,6 +174,12 @@ build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0
|
|||||||
build:mkl_opensource_only --define=build_with_mkl_opensource=true
|
build:mkl_opensource_only --define=build_with_mkl_opensource=true
|
||||||
build:mkl_opensource_only -c opt
|
build:mkl_opensource_only -c opt
|
||||||
|
|
||||||
|
# Config setting to build with oneDNN for Arm.
|
||||||
|
build:mkl_aarch64 --define=build_with_mkl_aarch64=true --define=enable_mkl=true
|
||||||
|
build:mkl_aarch64 --define=tensorflow_mkldnn_contraction_kernel=0
|
||||||
|
build:mkl_aarch64 --define=build_with_mkl_opensource=true
|
||||||
|
build:mkl_aarch64 -c opt
|
||||||
|
|
||||||
# This config refers to building with CUDA available. It does not necessarily
|
# This config refers to building with CUDA available. It does not necessarily
|
||||||
# mean that we build CUDA op kernels.
|
# mean that we build CUDA op kernels.
|
||||||
build:using_cuda --define=using_cuda=true
|
build:using_cuda --define=using_cuda=true
|
||||||
|
6
.github/bot_config.yml
vendored
6
.github/bot_config.yml
vendored
@ -12,12 +12,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
#
|
|
||||||
# THIS IS A GENERATED DOCKERFILE.
|
|
||||||
#
|
|
||||||
# This file was assembled from multiple pieces, whose use is documented
|
|
||||||
# throughout. Please refer to the TensorFlow dockerfiles documentation
|
|
||||||
# for more information.
|
|
||||||
|
|
||||||
# A list of assignees
|
# A list of assignees
|
||||||
assignees:
|
assignees:
|
||||||
|
28
.github/workflows/update-nightly.yml
vendored
Normal file
28
.github/workflows/update-nightly.yml
vendored
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch: # Allow manual triggers
|
||||||
|
schedule:
|
||||||
|
- cron: 0 4 * * * # 4am UTC is 9pm PDT and 8pm PST
|
||||||
|
name: Set nightly branch to master HEAD
|
||||||
|
jobs:
|
||||||
|
master-to-nightly:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: zofrex/mirror-branch@v1
|
||||||
|
name: Set nightly branch to master HEAD
|
||||||
|
with:
|
||||||
|
target-branch: 'nightly'
|
@ -4,7 +4,7 @@
|
|||||||
/tensorflow/core/common_runtime/eager @qqfish @kkimdev
|
/tensorflow/core/common_runtime/eager @qqfish @kkimdev
|
||||||
/tenosrflow/core/debug @caisq
|
/tenosrflow/core/debug @caisq
|
||||||
/tensorflow/core/nccl/ @azaks2 @chsigg
|
/tensorflow/core/nccl/ @azaks2 @chsigg
|
||||||
/tensorflow/core/platform/windows/ @gunan @mihaimaruseac
|
/tensorflow/core/platform/windows/ @mihaimaruseac
|
||||||
/tensorflow/lite/experimental/micro @petewarden @advaitjain
|
/tensorflow/lite/experimental/micro @petewarden @advaitjain
|
||||||
/tensorflow/python/autograph/ @mdanatg @kkimdev
|
/tensorflow/python/autograph/ @mdanatg @kkimdev
|
||||||
/tensorflow/python/debug @caisq
|
/tensorflow/python/debug @caisq
|
||||||
|
538
RELEASE.md
538
RELEASE.md
@ -34,6 +34,7 @@
|
|||||||
shape assumptions (note that you can pass shapes with `None` entries for axes
|
shape assumptions (note that you can pass shapes with `None` entries for axes
|
||||||
that are meant to be dynamic). You can also disable the input checking
|
that are meant to be dynamic). You can also disable the input checking
|
||||||
entirely by setting `model.input_spec = None`.
|
entirely by setting `model.input_spec = None`.
|
||||||
|
* TF pip packages now use CUDA11 and cuDNN 8.0.2.
|
||||||
* XLA:CPU and XLA:GPU devices are no longer registered by default. Use
|
* XLA:CPU and XLA:GPU devices are no longer registered by default. Use
|
||||||
`TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be
|
`TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be
|
||||||
removed).
|
removed).
|
||||||
@ -46,6 +47,13 @@
|
|||||||
* `tf.data.experimental.service.WorkerServer` now takes a config tuple
|
* `tf.data.experimental.service.WorkerServer` now takes a config tuple
|
||||||
instead of individual arguments. Usages should be updated to
|
instead of individual arguments. Usages should be updated to
|
||||||
`tf.data.experimental.service.WorkerServer(worker_config)`.
|
`tf.data.experimental.service.WorkerServer(worker_config)`.
|
||||||
|
* `tf.quantization.quantize_and_dequantize_v2` has been introduced, which
|
||||||
|
updates the gradient definition for quantization which is outside the range
|
||||||
|
to be 0. To simulate the V1 the behavior of
|
||||||
|
tf.quantization.quantize_and_dequantize(...) use
|
||||||
|
tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...).
|
||||||
|
* `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please
|
||||||
|
use `tf.data.Dataset.from_tensor_slices` instead.
|
||||||
|
|
||||||
## Known Caveats
|
## Known Caveats
|
||||||
|
|
||||||
@ -63,143 +71,168 @@
|
|||||||
|
|
||||||
## Bug Fixes and Other Changes
|
## Bug Fixes and Other Changes
|
||||||
|
|
||||||
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
|
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
|
||||||
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
|
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
|
||||||
* <NOTES SHOULD BE GROUPED PER AREA>
|
* <NOTES SHOULD BE GROUPED PER AREA>
|
||||||
* Security:
|
* Security:
|
||||||
* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
|
* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
|
||||||
([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
|
([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
|
||||||
* Fixes three vulnerabilities in conversion to DLPack format
|
* Fixes three vulnerabilities in conversion to DLPack format
|
||||||
([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
|
([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
|
||||||
[CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
|
[CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
|
||||||
[CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
|
[CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
|
||||||
* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
|
* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
|
||||||
([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
|
([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
|
||||||
[CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
|
[CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
|
||||||
* Fixes several vulnerabilities in `RaggedCountSparseOutput` and
|
* Fixes several vulnerabilities in `RaggedCountSparseOutput` and
|
||||||
`SparseCountSparseOutput` operations
|
`SparseCountSparseOutput` operations
|
||||||
([CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196),
|
([CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196),
|
||||||
[CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197),
|
[CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197),
|
||||||
[CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198),
|
[CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198),
|
||||||
[CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199),
|
[CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199),
|
||||||
[CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200),
|
[CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200),
|
||||||
[CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201))
|
[CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201))
|
||||||
* Fixes an integer truncation vulnerability in code using the work sharder API
|
* Fixes an integer truncation vulnerability in code using the work sharder
|
||||||
([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
|
API
|
||||||
* Fixes a format string vulnerability in `tf.strings.as_string`
|
([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
|
||||||
([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
|
* Fixes a format string vulnerability in `tf.strings.as_string`
|
||||||
* Fixes segfault raised by calling session-only ops in eager mode
|
([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
|
||||||
([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
|
* Fixes segfault raised by calling session-only ops in eager mode
|
||||||
* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams`
|
([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
|
||||||
([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
|
* Fixes data leak and potential ASLR violation from
|
||||||
* Fixes segfaults caused by incomplete `SavedModel` validation
|
`tf.raw_ops.StringNGrams`
|
||||||
([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
|
([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
|
||||||
* Fixes a data corruption due to a bug in negative indexing support in TFLite
|
* Fixes segfaults caused by incomplete `SavedModel` validation
|
||||||
([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
|
([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
|
||||||
* Fixes a data corruption due to dimension mismatch in TFLite
|
* Fixes a data corruption due to a bug in negative indexing support in
|
||||||
([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
|
TFLite
|
||||||
* Fixes several vulnerabilities in TFLite saved model format
|
([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
|
||||||
([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
|
* Fixes a data corruption due to dimension mismatch in TFLite
|
||||||
[CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
|
([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
|
||||||
[CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
|
* Fixes several vulnerabilities in TFLite saved model format
|
||||||
* Fixes several vulnerabilities in TFLite implementation of segment sum
|
([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
|
||||||
([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
|
[CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
|
||||||
[CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
|
[CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
|
||||||
[CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
|
* Fixes several vulnerabilities in TFLite implementation of segment sum
|
||||||
* TF Core:
|
([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
|
||||||
* `tf.types.experimental.TensorLike` is a new `Union` type that can be used as
|
[CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
|
||||||
type annotation for variables representing a Tensor or a value that can be
|
[CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
|
||||||
converted to Tensor by `tf.convert_to_tensor`.
|
* TF Core:
|
||||||
* Calling ops with a python constants or numpy values is now consistent with
|
* `tf.types.experimental.TensorLike` is a new `Union` type that can be
|
||||||
tf.convert_to_tensor behavior. This avoids operations like tf.reshape
|
used as type annotation for variables representing a Tensor or a value
|
||||||
truncating inputs such as from int64 to int32.
|
that can be converted to Tensor by `tf.convert_to_tensor`.
|
||||||
* Added `tf.sparse.map_values` to apply a function to the `.value`s of `SparseTensror` arguments.
|
* Calling ops with a python constants or numpy values is now consistent
|
||||||
* The Python bitwise operators for `Tensor` (`__and__`, `__or__`, `__xor__`
|
with tf.convert_to_tensor behavior. This avoids operations like
|
||||||
and `__invert__` now support non-`bool` arguments and apply the
|
tf.reshape truncating inputs such as from int64 to int32.
|
||||||
corresponding bitwise ops. `bool` arguments continue to be supported and
|
* Added `tf.sparse.map_values` to apply a function to the `.value`s of
|
||||||
dispatch to logical ops. This brings them more in line with Python and NumPy
|
`SparseTensor` arguments.
|
||||||
benavior.
|
* The Python bitwise operators for `Tensor` (`__and__`, `__or__`,
|
||||||
* Added `tf.SparseTensor.with_values`. This returns a new SparseTensor with
|
`__xor__` and `__invert__` now support non-`bool` arguments and apply
|
||||||
the same sparsity pattern, but with new provided values. It is similar to
|
the corresponding bitwise ops. `bool` arguments continue to be supported
|
||||||
the `with_values` function of `RaggedTensor`.
|
and dispatch to logical ops. This brings them more in line with Python
|
||||||
* Added `StatelessCase` op, and uses it if none of case branches has stateful ops.
|
and NumPy behavior.
|
||||||
* Added `tf.config.experimental.get_memory_usage` to return total memory usage
|
* Added `tf.SparseTensor.with_values`. This returns a new SparseTensor
|
||||||
of the device.
|
with the same sparsity pattern, but with new provided values. It is
|
||||||
* `tf.data`:
|
similar to the `with_values` function of `RaggedTensor`.
|
||||||
* tf.data service:
|
* Added `StatelessCase` op, and uses it if none of case branches has
|
||||||
* Added new `tf.data.experimental.service.register_dataset` and
|
stateful ops.
|
||||||
`tf.data.experimental.service.from_dataset_id` APIs to enable one process
|
* Added `tf.config.experimental.get_memory_usage` to return total memory
|
||||||
to register a dataset with the tf.data service, and another process to
|
usage of the device.
|
||||||
consume data from the dataset.
|
* `tf.data`:
|
||||||
* Added support for dispatcher fault tolerance. To enable fault tolerance,
|
* tf.data service:
|
||||||
configure a `work_dir` when running your dispatcher server and set
|
* Added new `tf.data.experimental.service.register_dataset` and
|
||||||
`dispatcher_fault_tolerance=True`. The dispatcher will store its state to
|
`tf.data.experimental.service.from_dataset_id` APIs to enable one
|
||||||
`work_dir`, so that on restart it can continue from its previous state
|
process to register a dataset with the tf.data service, and another
|
||||||
after restart.
|
process to consume data from the dataset.
|
||||||
* Added support for sharing dataset graphs via shared filesystem instead of
|
* Added support for dispatcher fault tolerance. To enable fault tolerance,
|
||||||
over RPC. This reduces load on the dispatcher, improving performance of
|
configure a `work_dir` when running your dispatcher server and set
|
||||||
distributing datasets. For this to work, the dispatcher's `work_dir` must
|
`dispatcher_fault_tolerance=True`. The dispatcher will store its state
|
||||||
be accessible from workers. If the worker fails to read from the
|
to `work_dir`, so that on restart it can continue from its previous
|
||||||
`work_dir`, it falls back to using RPC for dataset graph transfer.
|
state after restart.
|
||||||
* Added support for a new "distributed_epoch" processing mode. This
|
* Added support for sharing dataset graphs via shared filesystem instead
|
||||||
processing mode distributes a dataset across all tf.data workers, instead
|
of over RPC. This reduces load on the dispatcher, improving performance
|
||||||
of having each worker process the full dataset. See
|
of distributing datasets. For this to work, the dispatcher's `work_dir`
|
||||||
[the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode)
|
must be accessible from workers. If the worker fails to read from the
|
||||||
to learn more.
|
`work_dir`, it falls back to using RPC for dataset graph transfer.
|
||||||
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
|
* Added support for a new "distributed_epoch" processing mode. This
|
||||||
the complement of `select_cols`; at most one of these should be specified.
|
processing mode distributes a dataset across all tf.data workers,
|
||||||
* We have implemented an optimization which reorders data-discarding
|
instead of having each worker process the full dataset. See
|
||||||
transformations such as `take` and `shard` to happen earlier in the
|
[the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode)
|
||||||
dataset when it is safe to do so. The optimization can be disabled via
|
to learn more.
|
||||||
the `experimental_optimization.reorder_data_discarding_ops` dataset
|
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
|
||||||
option.
|
the complement of `select_cols`; at most one of these should be
|
||||||
* `tf.data.Options` were previously immutable and can now be overriden.
|
specified.
|
||||||
* `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors
|
* We have implemented an optimization which reorders data-discarding
|
||||||
with a new `output_signature` argument, which allows `from_generator` to
|
transformations such as `take` and `shard` to happen earlier in the
|
||||||
produce any type describable by a `tf.TypeSpec`.
|
dataset when it is safe to do so. The optimization can be disabled via
|
||||||
* `tf.data.experimental.AUTOTUNE` is now available in the core API as
|
the `experimental_optimization.reorder_data_discarding_ops` dataset
|
||||||
`tf.data.AUTOTUNE`.
|
option.
|
||||||
* `tf.image`:
|
* `tf.data.Options` were previously immutable and can now be overridden.
|
||||||
* Added deterministic `tf.image.stateless_random_*` functions for each
|
* `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors
|
||||||
`tf.image.random_*` function. Added a new op
|
with a new `output_signature` argument, which allows `from_generator` to
|
||||||
`stateless_sample_distorted_bounding_box` which is a determinstic
|
produce any type describable by a `tf.TypeSpec`.
|
||||||
version of `sample_distorted_bounding_box` op. Given the same seed, these
|
* `tf.data.experimental.AUTOTUNE` is now available in the core API as
|
||||||
stateless functions/ops produce the same results independent of how many
|
`tf.data.AUTOTUNE`.
|
||||||
times the function is called, and independent of global seed settings.
|
* `tf.image`:
|
||||||
|
* Added deterministic `tf.image.stateless_random_*` functions for each
|
||||||
|
`tf.image.random_*` function. Added a new op
|
||||||
|
`stateless_sample_distorted_bounding_box` which is a deterministic
|
||||||
|
version of `sample_distorted_bounding_box` op. Given the same seed,
|
||||||
|
these stateless functions/ops produce the same results independent of
|
||||||
|
how many times the function is called, and independent of global seed
|
||||||
|
settings.
|
||||||
* `tf.distribute`:
|
* `tf.distribute`:
|
||||||
* <ADD RELEASE NOTES HERE>
|
* <ADD RELEASE NOTES HERE>
|
||||||
* `tf.keras`:
|
* `tf.keras`:
|
||||||
* Improvements from the functional API refactoring:
|
* Improvements from the functional API refactoring:
|
||||||
* Functional model construction does not need to maintain a global workspace graph, removing memory leaks especially when building many models or very large models.
|
* Functional model construction does not need to maintain a global
|
||||||
* Functional model construction should be ~8-10% faster on average.
|
workspace graph, removing memory leaks especially when building many
|
||||||
* Functional models can now contain non-symbolic values in their call inputs inside of the first positional argument.
|
models or very large models.
|
||||||
* Several classes of TF ops that were not reliably converted to Keras layers during functional API construction should now work, e.g. `tf.image.ssim_multiscale`
|
* Functional model construction should be ~8-10% faster on average.
|
||||||
* Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
|
* Functional models can now contain non-symbolic values in their call
|
||||||
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
|
inputs inside of the first positional argument.
|
||||||
as an alternative to accepting a `callable` loss.
|
* Several classes of TF ops that were not reliably converted to Keras
|
||||||
* Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
|
layers during functional API construction should now work, e.g.
|
||||||
to match FTRL paper (https://research.google.com/pubs/archive/41159.pdf).
|
`tf.image.ssim_multiscale`
|
||||||
* Added `mobilenet_v3` to keras application model.
|
* Error messages when Functional API construction goes wrong (and when
|
||||||
* `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
|
ops cannot be converted to Keras layers automatically) should be
|
||||||
customization of how gradients are aggregated across devices, as well as
|
clearer and easier to understand.
|
||||||
`gradients_transformers` to allow for custom gradient transformations
|
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
|
||||||
(such as gradient clipping).
|
as an alternative to accepting a `callable` loss.
|
||||||
* The `steps_per_execution` argument in `compile()` is no longer
|
* Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
|
||||||
experimental; if you were passing `experimental_steps_per_execution`,
|
to match FTRL paper
|
||||||
rename it to `steps_per_execution` in your code. This argument controls
|
(https://research.google.com/pubs/archive/41159.pdf).
|
||||||
the number of batches to run during each `tf.function` call when calling
|
* Added `mobilenet_v3` to keras application model.
|
||||||
`fit()`. Running multiple batches inside a single `tf.function` call can
|
* `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
|
||||||
greatly improve performance on TPUs or small models with a large Python
|
customization of how gradients are aggregated across devices, as well as
|
||||||
overhead.
|
`gradients_transformers` to allow for custom gradient transformations
|
||||||
* `tf.function` / AutoGraph:
|
(such as gradient clipping).
|
||||||
* Added `experimental_follow_type_hints` argument for `tf.function`. When
|
* The `steps_per_execution` argument in `compile()` is no longer
|
||||||
True, the function may use type annotations to optimize the tracing
|
experimental; if you were passing `experimental_steps_per_execution`,
|
||||||
performance.
|
rename it to `steps_per_execution` in your code. This argument controls
|
||||||
* Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
|
the number of batches to run during each `tf.function` call when calling
|
||||||
* AutoGraph now allows creating new symbols inside a TensorFLow loop, if
|
`fit()`. Running multiple batches inside a single `tf.function` call can
|
||||||
the values of these symbols at an iteration does not depend on the previous
|
greatly improve performance on TPUs or small models with a large Python
|
||||||
iteration. These types of loops must run at least one iteration, and will
|
overhead.
|
||||||
raise a runtime error otherwise.
|
* Improvements to Keras preprocessing layers:
|
||||||
|
* TextVectorization can now accept a vocabulary list or file as an
|
||||||
|
init arg.
|
||||||
|
* Normalization can now accept mean and variance values as init args.
|
||||||
|
* In `Attention` and `AdditiveAttention` layers, the `call()` method now
|
||||||
|
accepts a `return_attention_scores` argument. When set to
|
||||||
|
True, the layer returns the attention scores as an additional output
|
||||||
|
argument.
|
||||||
|
* Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints
|
||||||
|
with the same implementation as their `tf.losses` equivalent.
|
||||||
|
* `tf.function` / AutoGraph:
|
||||||
|
* Added `experimental_follow_type_hints` argument for `tf.function`. When
|
||||||
|
True, the function may use type annotations to optimize the tracing
|
||||||
|
performance.
|
||||||
|
* Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
|
||||||
|
* AutoGraph now allows creating new symbols inside a TensorFLow loop, if
|
||||||
|
the values of these symbols at an iteration does not depend on the
|
||||||
|
previous iteration. These types of loops must run at least one
|
||||||
|
iteration, and will raise a runtime error otherwise.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@ -208,51 +241,97 @@
|
|||||||
outputs = train_step(batch)
|
outputs = train_step(batch)
|
||||||
tf.print('final outputs', outputs)
|
tf.print('final outputs', outputs)
|
||||||
```
|
```
|
||||||
|
|
||||||
See tensorflow/python/autograph/g3doc/reference/limitations.md for more
|
See tensorflow/python/autograph/g3doc/reference/limitations.md for more
|
||||||
info.
|
info.
|
||||||
|
|
||||||
* `tf.lite`:
|
* `tf.lite`:
|
||||||
* `DynamicBuffer::AddJoinedString()` will now add a separator if the first
|
|
||||||
string to be joined is empty.
|
* `TFLiteConverter`:
|
||||||
* `TFLiteConverter`:
|
* Support optional flags `inference_input_type` and
|
||||||
* Support optional flags `inference_input_type` and `inference_output_type` for full integer quantized models. This allows users to modify the model input and output type to integer types (`tf.int8`, `tf.uint8`) instead of defaulting to float type (`tf.float32`).
|
`inference_output_type` for full integer quantized models. This
|
||||||
* Deprecate `Interpreter::UseNNAPI(bool)` C++ API
|
allows users to modify the model input and output type to integer
|
||||||
* Prefer using `NnApiDelegate()` and related delegate configuration methods directly.
|
types (`tf.int8`, `tf.uint8`) instead of defaulting to float type
|
||||||
* Add NNAPI Delegation support for requantization use cases by converting the operation into a dequantize-quantize pair.
|
(`tf.float32`).
|
||||||
* TFLite Profiler for Android is available. See the detailed
|
* TFLite Profiler for Android is available. See the detailed
|
||||||
[guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android).
|
[guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android).
|
||||||
* <ADD RELEASE NOTES HERE>
|
* NNAPI
|
||||||
|
* Added NNAPI Delegation support for requantization use cases by
|
||||||
|
converting the operation into a dequantize-quantize pair.
|
||||||
|
* Removed deprecated `Interpreter.setUseNNAPI(boolean)` Java API.
|
||||||
|
* Use `Interpreter.Options.setUseNNAPI` instead.
|
||||||
|
* Deprecate `Interpreter::UseNNAPI(bool)` C++ API.
|
||||||
|
* Use `NnApiDelegate()` and related delegate configuration methods
|
||||||
|
directly.
|
||||||
|
* Deprecate `Interpreter::SetAllowFp16PrecisionForFp32(bool)` C++ API
|
||||||
|
* Prefer controlling this via delegate options, e.g.
|
||||||
|
`tflite::StatefulNnApiDelegate::Options::allow_fp16' or
|
||||||
|
`TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
|
||||||
|
* `DynamicBuffer::AddJoinedString()` will now add a separator if the first
|
||||||
|
string to be joined is empty.
|
||||||
|
* <ADD RELEASE NOTES HERE>
|
||||||
|
|
||||||
* `tf.random`:
|
* `tf.random`:
|
||||||
* <ADD RELEASE NOTES HERE>
|
|
||||||
|
* <ADD RELEASE NOTES HERE>
|
||||||
|
|
||||||
* Math and Linear Algebra:
|
* Math and Linear Algebra:
|
||||||
* <ADD RELEASE NOTES HERE>
|
|
||||||
|
* <ADD RELEASE NOTES HERE>
|
||||||
|
|
||||||
* TPU Enhancements:
|
* TPU Enhancements:
|
||||||
* Added support for the `beta` parameter of the FTRL optimizer for TPU
|
|
||||||
embeddings. Users of other TensorFlow platforms can implement equivalent
|
* Added support for the `beta` parameter of the FTRL optimizer for TPU
|
||||||
behavior by adjusting the `l2` parameter.
|
embeddings. Users of other TensorFlow platforms can implement equivalent
|
||||||
* <ADD RELEASE NOTES HERE>
|
behavior by adjusting the `l2` parameter.
|
||||||
|
* <ADD RELEASE NOTES HERE>
|
||||||
|
|
||||||
* XLA Support:
|
* XLA Support:
|
||||||
* xla.experimental.compile is deprecated, use
|
|
||||||
`tf.function(experimental_compile=True)` instead
|
* xla.experimental.compile is deprecated, use
|
||||||
* Added `tf.function.experimental_get_compiler_ir` which returns compiler IR
|
`tf.function(experimental_compile=True)` instead
|
||||||
(currently 'hlo' and 'optimized_hlo') for given input for given function.
|
* Added `tf.function.experimental_get_compiler_ir` which returns compiler
|
||||||
* <ADD RELEASE NOTES HERE>
|
IR (currently 'hlo' and 'optimized_hlo') for given input for given
|
||||||
|
function.
|
||||||
|
* <ADD RELEASE NOTES HERE>
|
||||||
|
|
||||||
* Tracing and Debugging:
|
* Tracing and Debugging:
|
||||||
* <ADD RELEASE NOTES HERE>
|
|
||||||
|
* <ADD RELEASE NOTES HERE>
|
||||||
|
|
||||||
* `tf.train.Checkpoint`:
|
* `tf.train.Checkpoint`:
|
||||||
* Now accepts a `root` argument in the initialization, which generates a
|
|
||||||
checkpoint with a root object. This allows users to create a `Checkpoint`
|
* Now accepts a `root` argument in the initialization, which generates a
|
||||||
object that is compatible with Keras `model.save_weights()` and
|
checkpoint with a root object. This allows users to create a
|
||||||
`model.load_weights`. The checkpoint is also compatible with the
|
`Checkpoint` object that is compatible with Keras `model.save_weights()`
|
||||||
checkpoint saved in the `variables/` folder in the SavedModel.
|
and `model.load_weights`. The checkpoint is also compatible with the
|
||||||
* When restoring, `save_path` can be a path to a SavedModel. The function
|
checkpoint saved in the `variables/` folder in the SavedModel.
|
||||||
will automatically find the checkpoint in the SavedModel.
|
* When restoring, `save_path` can be a path to a SavedModel. The function
|
||||||
|
will automatically find the checkpoint in the SavedModel.
|
||||||
|
|
||||||
* `tf.nn`:
|
* `tf.nn`:
|
||||||
* `tf.nn.max_pool2d` now supports explicit padding.
|
|
||||||
|
* `tf.nn.max_pool2d` now supports explicit padding.
|
||||||
|
|
||||||
|
* `tf.debugging`:
|
||||||
|
|
||||||
|
* `tf.debugging.assert_shapes()` now works on `SparseTensor`s (#36268).
|
||||||
|
|
||||||
|
* `tf.print`:
|
||||||
|
|
||||||
|
* Bug fix in `tf.print()` with `OrderedDict` where if an `OrderedDict`
|
||||||
|
didn't have the keys sorted, the keys and values were not being printed
|
||||||
|
in accordance with their correct mapping.
|
||||||
|
|
||||||
* Other:
|
* Other:
|
||||||
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
|
|
||||||
and "denylist" where possible. Please see
|
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
|
||||||
https://developers.google.com/style/word-list#blacklist for more context.
|
and "denylist" where possible. Please see
|
||||||
<ADD RELEASE NOTES HERE>
|
https://developers.google.com/style/word-list#blacklist for more
|
||||||
|
context.
|
||||||
|
* Add `tf.config.experimental.mlir_bridge_rollout` which will help us
|
||||||
|
rollout the new MLIR TPU bridge.
|
||||||
|
* <ADD RELEASE NOTES HERE>
|
||||||
|
|
||||||
## Thanks to our Contributors
|
## Thanks to our Contributors
|
||||||
|
|
||||||
@ -500,42 +579,87 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
|
|||||||
# Release 2.3.0
|
# Release 2.3.0
|
||||||
|
|
||||||
## Major Features and Improvements
|
## Major Features and Improvements
|
||||||
* `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and save resources:
|
|
||||||
* [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot)
|
|
||||||
* [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service).
|
|
||||||
|
|
||||||
In addition checkout the detailed [guide](https://www.tensorflow.org/guide/data_performance_analysis) for analyzing input pipeline performance with TF Profiler.
|
* `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and
|
||||||
|
save resources:
|
||||||
|
|
||||||
* [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy) is now a stable API and no longer considered experimental for TensorFlow. (earlier `tf.distribute.experimental.TPUStrategy`).
|
* [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot)
|
||||||
|
* [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service).
|
||||||
|
|
||||||
* [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new tools: a memory profiler to visualize your model’s memory usage over time and a [python tracer](https://www.tensorflow.org/guide/profiler#events) which allows you to trace python function calls in your model. Usability improvements include better diagnostic messages and [profile options](https://tensorflow.org/guide/profiler#collect_performance_data) to customize the host and device trace verbosity level.
|
In addition checkout the detailed
|
||||||
|
[guide](https://www.tensorflow.org/guide/data_performance_analysis) for
|
||||||
|
analyzing input pipeline performance with TF Profiler.
|
||||||
|
|
||||||
* Introduces experimental support for Keras Preprocessing Layers API ([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly)) to handle data preprocessing operations, with support for composite tensor inputs. Please see below for additional details on these layers.
|
* [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy)
|
||||||
|
is now a stable API and no longer considered experimental for TensorFlow.
|
||||||
|
(earlier `tf.distribute.experimental.TPUStrategy`).
|
||||||
|
|
||||||
* TFLite now properly supports dynamic shapes during conversion and inference. We’ve also added opt-in support on Android and iOS for [XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack), a highly optimized set of CPU kernels, as well as opt-in support for [executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental).
|
* [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new
|
||||||
|
tools: a memory profiler to visualize your model’s memory usage over time
|
||||||
|
and a [python tracer](https://www.tensorflow.org/guide/profiler#events)
|
||||||
|
which allows you to trace python function calls in your model. Usability
|
||||||
|
improvements include better diagnostic messages and
|
||||||
|
[profile options](https://tensorflow.org/guide/profiler#collect_performance_data)
|
||||||
|
to customize the host and device trace verbosity level.
|
||||||
|
|
||||||
* Libtensorflow packages are available in GCS starting this release. We have also started to [release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds).
|
* Introduces experimental support for Keras Preprocessing Layers API
|
||||||
|
([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly))
|
||||||
|
to handle data preprocessing operations, with support for composite tensor
|
||||||
|
inputs. Please see below for additional details on these layers.
|
||||||
|
|
||||||
* The experimental Python API [`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info) now allows you to instrument a TensorFlow program and dump debugging information to a directory on the file system. The directory can be read and visualized by a new interactive dashboard in TensorBoard 2.3 called [Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which reveals the details of the TensorFlow program including graph structures, history of op executions at the Python (eager) and intra-graph levels, the runtime dtype, shape, and numerical composistion of tensors, as well as their code locations.
|
* TFLite now properly supports dynamic shapes during conversion and inference.
|
||||||
|
We’ve also added opt-in support on Android and iOS for
|
||||||
|
[XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack),
|
||||||
|
a highly optimized set of CPU kernels, as well as opt-in support for
|
||||||
|
[executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental).
|
||||||
|
|
||||||
|
* Libtensorflow packages are available in GCS starting this release. We have
|
||||||
|
also started to
|
||||||
|
[release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds).
|
||||||
|
|
||||||
|
* The experimental Python API
|
||||||
|
[`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info)
|
||||||
|
now allows you to instrument a TensorFlow program and dump debugging
|
||||||
|
information to a directory on the file system. The directory can be read and
|
||||||
|
visualized by a new interactive dashboard in TensorBoard 2.3 called
|
||||||
|
[Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which
|
||||||
|
reveals the details of the TensorFlow program including graph structures,
|
||||||
|
history of op executions at the Python (eager) and intra-graph levels, the
|
||||||
|
runtime dtype, shape, and numerical composition of tensors, as well as their
|
||||||
|
code locations.
|
||||||
|
|
||||||
## Breaking Changes
|
## Breaking Changes
|
||||||
* Increases the **minimum bazel version** required to build TF to **3.1.0**.
|
|
||||||
* `tf.data`
|
* Increases the **minimum bazel version** required to build TF to **3.1.0**.
|
||||||
* Makes the following (breaking) changes to the `tf.data`.
|
* `tf.data`
|
||||||
* C++ API: - `IteratorBase::RestoreInternal`, `IteratorBase::SaveInternal`, and `DatasetBase::CheckExternalState` become pure-virtual and subclasses are now expected to provide an implementation.
|
* Makes the following (breaking) changes to the `tf.data`.
|
||||||
* The deprecated `DatasetBase::IsStateful` method is removed in favor of `DatasetBase::CheckExternalState`.
|
* C++ API: - `IteratorBase::RestoreInternal`,
|
||||||
* Deprecated overrides of `DatasetBase::MakeIterator` and `MakeIteratorFromInputElement` are removed.
|
`IteratorBase::SaveInternal`, and `DatasetBase::CheckExternalState`
|
||||||
* The signature of `tensorflow::data::IteratorBase::SaveInternal` and `tensorflow::data::IteratorBase::SaveInput` has been extended with `SerializationContext` argument to enable overriding the default policy for the handling external state during iterator checkpointing. This is not a backwards compatible change and all subclasses of `IteratorBase` *need to be updated* accordingly.
|
become pure-virtual and subclasses are now expected to provide an
|
||||||
* `tf.keras`
|
implementation.
|
||||||
* Add a new `BackupAndRestore` callback for handling distributed training failures & restarts. Please take a look at this [tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for details on how to use the callback.
|
* The deprecated `DatasetBase::IsStateful` method is removed in favor of
|
||||||
* `tf.image.extract_glimpse` has been updated to correctly process the case
|
`DatasetBase::CheckExternalState`.
|
||||||
where `centered=False` and `normalized=False`. This is a breaking change as
|
* Deprecated overrides of `DatasetBase::MakeIterator` and
|
||||||
the output is different from (incorrect) previous versions. Note this
|
`MakeIteratorFromInputElement` are removed.
|
||||||
breaking change only impacts `tf.image.extract_glimpse` and
|
* The signature of `tensorflow::data::IteratorBase::SaveInternal` and
|
||||||
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
|
`tensorflow::data::IteratorBase::SaveInput` has been extended with
|
||||||
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
|
`SerializationContext` argument to enable overriding the default policy
|
||||||
exsiting C++ kernel `ExtractGlimpse` does not change either, so saved
|
for the handling external state during iterator checkpointing. This is
|
||||||
models using `tf.raw_ops.ExtractGlimpse` will not be impacted.
|
not a backwards compatible change and all subclasses of `IteratorBase`
|
||||||
|
*need to be updated* accordingly.
|
||||||
|
* `tf.keras`
|
||||||
|
* Add a new `BackupAndRestore` callback for handling distributed training
|
||||||
|
failures & restarts. Please take a look at this
|
||||||
|
[tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
|
||||||
|
for details on how to use the callback.
|
||||||
|
* `tf.image.extract_glimpse` has been updated to correctly process the case
|
||||||
|
where `centered=False` and `normalized=False`. This is a breaking change as
|
||||||
|
the output is different from (incorrect) previous versions. Note this
|
||||||
|
breaking change only impacts `tf.image.extract_glimpse` and
|
||||||
|
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
|
||||||
|
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
|
||||||
|
existing C++ kernel `ExtractGlimpse` does not change either, so saved models
|
||||||
|
using `tf.raw_ops.ExtractGlimpse` will not be impacted.
|
||||||
|
|
||||||
## Known Caveats
|
## Known Caveats
|
||||||
* `tf.lite`
|
* `tf.lite`
|
||||||
@ -1105,7 +1229,7 @@ This release contains contributions from many people at Google, as well as:
|
|||||||
8bitmp3, Aaron Ma, AbdüLhamit Yilmaz, Abhai Kollara, aflc, Ag Ramesh, Albert Z. Guo, Alex Torres, amoitra, Andrii Prymostka, angeliand, Anshuman Tripathy, Anthony Barbier, Anton Kachatkou, Anubh-V, Anuja Jakhade, Artem Ryabov, autoih, Bairen Yi, Bas Aarts, Basit Ayantunde, Ben Barsdell, Bhavani Subramanian, Brett Koonce, candy.dc, Captain-Pool, caster, cathy, Chong Yan, Choong Yin Thong, Clayne Robison, Colle, Dan Ganea, David Norman, David Refaeli, dengziming, Diego Caballero, Divyanshu, djshen, Douman, Duncan Riach, EFanZh, Elena Zhelezina, Eric Schweitz, Evgenii Zheltonozhskii, Fei Hu, fo40225, Fred Reiss, Frederic Bastien, Fredrik Knutsson, fsx950223, fwcore, George Grzegorz Pawelczak, George Sterpu, Gian Marco Iodice, Giorgio Arena, giuros01, Gomathi Ramamurthy, Guozhong Zhuang, Haifeng Jin, Haoyu Wu, HarikrishnanBalagopal, HJYOO, Huang Chen-Yi, Ilham Firdausi Putra, Imran Salam, Jared Nielsen, Jason Zaman, Jasper Vicenti, Jeff Daily, Jeff Poznanovic, Jens Elofsson, Jerry Shih, jerryyin, Jesper Dramsch, jim.meyer, Jongwon Lee, Jun Wan, Junyuan Xie, Kaixi Hou, kamalkraj, Kan Chen, Karthik Muthuraman, Keiji Ariyama, Kevin Rose, Kevin Wang, Koan-Sin Tan, kstuedem, Kwabena W. Agyeman, Lakshay Tokas, latyas, Leslie-Fang-Intel, Li, Guizi, Luciano Resende, Lukas Folle, Lukas Geiger, Mahmoud Abuzaina, Manuel Freiberger, Mark Ryan, Martin Mlostek, Masaki Kozuki, Matthew Bentham, Matthew Denton, mbhuiyan, mdfaijul, Muhwan Kim, Nagy Mostafa, nammbash, Nathan Luehr, Nathan Wells, Niranjan Hasabnis, Oleksii Volkovskyi, Olivier Moindrot, olramde, Ouyang Jin, OverLordGoldDragon, Pallavi G, Paul Andrey, Paul Wais, pkanwar23, Pooya Davoodi, Prabindh Sundareson, Rajeshwar Reddy T, Ralovich, Kristof, Refraction-Ray, Richard Barnes, richardbrks, Robert Herbig, Romeo Kienzler, Ryan Mccormick, saishruthi, Saket Khandelwal, Sami Kama, Sana Damani, Satoshi Tanaka, Sergey Mironov, Sergii Khomenko, Shahid, Shawn Presser, ShengYang1, Siddhartha Bagaria, Simon Plovyt, skeydan, srinivasan.narayanamoorthy, Stephen Mugisha, sunway513, Takeshi Watanabe, Taylor Jakobson, TengLu, TheMindVirus, ThisIsIsaac, Tim Gates, Timothy Liu, Tomer Gafner, Trent Lo, Trevor Hickey, Trevor Morris, vcarpani, Wei Wang, Wen-Heng (Jack) Chung, wenshuai, Wenshuai-Xiaomi, wenxizhu, william, William D. Irons, Xinan Jiang, Yannic, Yasir Modak, Yasuhiro Matsumoto, Yong Tang, Yongfeng Gu, Youwei Song, Zaccharie Ramzi, Zhang, Zhenyu Guo, 王振华 (Zhenhua Wang), 韩董, 이중건 Isaac Lee
|
8bitmp3, Aaron Ma, AbdüLhamit Yilmaz, Abhai Kollara, aflc, Ag Ramesh, Albert Z. Guo, Alex Torres, amoitra, Andrii Prymostka, angeliand, Anshuman Tripathy, Anthony Barbier, Anton Kachatkou, Anubh-V, Anuja Jakhade, Artem Ryabov, autoih, Bairen Yi, Bas Aarts, Basit Ayantunde, Ben Barsdell, Bhavani Subramanian, Brett Koonce, candy.dc, Captain-Pool, caster, cathy, Chong Yan, Choong Yin Thong, Clayne Robison, Colle, Dan Ganea, David Norman, David Refaeli, dengziming, Diego Caballero, Divyanshu, djshen, Douman, Duncan Riach, EFanZh, Elena Zhelezina, Eric Schweitz, Evgenii Zheltonozhskii, Fei Hu, fo40225, Fred Reiss, Frederic Bastien, Fredrik Knutsson, fsx950223, fwcore, George Grzegorz Pawelczak, George Sterpu, Gian Marco Iodice, Giorgio Arena, giuros01, Gomathi Ramamurthy, Guozhong Zhuang, Haifeng Jin, Haoyu Wu, HarikrishnanBalagopal, HJYOO, Huang Chen-Yi, Ilham Firdausi Putra, Imran Salam, Jared Nielsen, Jason Zaman, Jasper Vicenti, Jeff Daily, Jeff Poznanovic, Jens Elofsson, Jerry Shih, jerryyin, Jesper Dramsch, jim.meyer, Jongwon Lee, Jun Wan, Junyuan Xie, Kaixi Hou, kamalkraj, Kan Chen, Karthik Muthuraman, Keiji Ariyama, Kevin Rose, Kevin Wang, Koan-Sin Tan, kstuedem, Kwabena W. Agyeman, Lakshay Tokas, latyas, Leslie-Fang-Intel, Li, Guizi, Luciano Resende, Lukas Folle, Lukas Geiger, Mahmoud Abuzaina, Manuel Freiberger, Mark Ryan, Martin Mlostek, Masaki Kozuki, Matthew Bentham, Matthew Denton, mbhuiyan, mdfaijul, Muhwan Kim, Nagy Mostafa, nammbash, Nathan Luehr, Nathan Wells, Niranjan Hasabnis, Oleksii Volkovskyi, Olivier Moindrot, olramde, Ouyang Jin, OverLordGoldDragon, Pallavi G, Paul Andrey, Paul Wais, pkanwar23, Pooya Davoodi, Prabindh Sundareson, Rajeshwar Reddy T, Ralovich, Kristof, Refraction-Ray, Richard Barnes, richardbrks, Robert Herbig, Romeo Kienzler, Ryan Mccormick, saishruthi, Saket Khandelwal, Sami Kama, Sana Damani, Satoshi Tanaka, Sergey Mironov, Sergii Khomenko, Shahid, Shawn Presser, ShengYang1, Siddhartha Bagaria, Simon Plovyt, skeydan, srinivasan.narayanamoorthy, Stephen Mugisha, sunway513, Takeshi Watanabe, Taylor Jakobson, TengLu, TheMindVirus, ThisIsIsaac, Tim Gates, Timothy Liu, Tomer Gafner, Trent Lo, Trevor Hickey, Trevor Morris, vcarpani, Wei Wang, Wen-Heng (Jack) Chung, wenshuai, Wenshuai-Xiaomi, wenxizhu, william, William D. Irons, Xinan Jiang, Yannic, Yasir Modak, Yasuhiro Matsumoto, Yong Tang, Yongfeng Gu, Youwei Song, Zaccharie Ramzi, Zhang, Zhenyu Guo, 王振华 (Zhenhua Wang), 韩董, 이중건 Isaac Lee
|
||||||
|
|
||||||
# Release 1.15.0
|
# Release 1.15.0
|
||||||
This is the last 1.x release for TensorFlow. We do not expect to update the 1.x branch with features, although we will issue patch releases to fix vulnerabilities for at least one year.
|
This is the last 1.x release for TensorFlow. We do not expect to update the 1.x branch with features, although we will issue patch releases to fix vulnerabilities for at least one year.
|
||||||
|
|
||||||
## Major Features and Improvements
|
## Major Features and Improvements
|
||||||
* As [announced](https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0), `tensorflow` pip package will by default include GPU support (same as `tensorflow-gpu` now) for the platforms we currently have GPU support (Linux and Windows). It will work on machines with and without Nvidia GPUs. `tensorflow-gpu` will still be available, and CPU-only packages can be downloaded at `tensorflow-cpu` for users who are concerned about package size.
|
* As [announced](https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0), `tensorflow` pip package will by default include GPU support (same as `tensorflow-gpu` now) for the platforms we currently have GPU support (Linux and Windows). It will work on machines with and without Nvidia GPUs. `tensorflow-gpu` will still be available, and CPU-only packages can be downloaded at `tensorflow-cpu` for users who are concerned about package size.
|
||||||
@ -1115,7 +1239,7 @@ This enables writing forward compatible code: by explicitly importing either `te
|
|||||||
* Add toggles `tf.enable_control_flow_v2()` and `tf.disable_control_flow_v2()` for enabling/disabling v2 control flow.
|
* Add toggles `tf.enable_control_flow_v2()` and `tf.disable_control_flow_v2()` for enabling/disabling v2 control flow.
|
||||||
* Enable v2 control flow as part of `tf.enable_v2_behavior()` and `TF2_BEHAVIOR=1`.
|
* Enable v2 control flow as part of `tf.enable_v2_behavior()` and `TF2_BEHAVIOR=1`.
|
||||||
* AutoGraph translates Python control flow into TensorFlow expressions, allowing users to write regular Python inside `tf.function`-decorated functions. AutoGraph is also applied in functions used with `tf.data`, `tf.distribute` and `tf.keras` APIS.
|
* AutoGraph translates Python control flow into TensorFlow expressions, allowing users to write regular Python inside `tf.function`-decorated functions. AutoGraph is also applied in functions used with `tf.data`, `tf.distribute` and `tf.keras` APIS.
|
||||||
* Adds `enable_tensor_equality()`, which switches the behavior such that:
|
* Adds `enable_tensor_equality()`, which switches the behavior such that:
|
||||||
* Tensors are no longer hashable.
|
* Tensors are no longer hashable.
|
||||||
* Tensors can be compared with `==` and `!=`, yielding a Boolean Tensor with element-wise comparison results. This will be the default behavior in 2.0.
|
* Tensors can be compared with `==` and `!=`, yielding a Boolean Tensor with element-wise comparison results. This will be the default behavior in 2.0.
|
||||||
|
|
||||||
@ -1271,12 +1395,12 @@ For information on upgrading your existing TensorFlow 1.x models, please refer t
|
|||||||
* TensorFlow 2.0.0 is built using devtoolset7 (GCC7) on Ubuntu 16. This may lead to ABI incompatibilities with extensions built against earlier versions of TensorFlow.
|
* TensorFlow 2.0.0 is built using devtoolset7 (GCC7) on Ubuntu 16. This may lead to ABI incompatibilities with extensions built against earlier versions of TensorFlow.
|
||||||
* Tensorflow code now produces 2 different pip packages: tensorflow_core containing all the code (in the future it will contain only the private implementation) and tensorflow which is a virtual pip package doing forwarding to tensorflow_core (and in the future will contain only the public API of tensorflow). We don't expect this to be breaking, unless you were importing directly from the implementation.
|
* Tensorflow code now produces 2 different pip packages: tensorflow_core containing all the code (in the future it will contain only the private implementation) and tensorflow which is a virtual pip package doing forwarding to tensorflow_core (and in the future will contain only the public API of tensorflow). We don't expect this to be breaking, unless you were importing directly from the implementation.
|
||||||
Removed the `freeze_graph` command line tool; `SavedModel` should be used in place of frozen graphs.
|
Removed the `freeze_graph` command line tool; `SavedModel` should be used in place of frozen graphs.
|
||||||
|
|
||||||
* `tf.contrib`:
|
* `tf.contrib`:
|
||||||
* `tf.contrib` has been deprecated, and functionality has been either migrated to the core TensorFlow API, to an ecosystem project such as [tensorflow/addons](https://www.github.com/tensorflow/addons) or [tensorflow/io](https://www.github.com/tensorflow/io), or removed entirely.
|
* `tf.contrib` has been deprecated, and functionality has been either migrated to the core TensorFlow API, to an ecosystem project such as [tensorflow/addons](https://www.github.com/tensorflow/addons) or [tensorflow/io](https://www.github.com/tensorflow/io), or removed entirely.
|
||||||
* Remove `tf.contrib.timeseries` dependency on TF distributions.
|
* Remove `tf.contrib.timeseries` dependency on TF distributions.
|
||||||
* Replace contrib references with `tf.estimator.experimental.*` for apis in `early_stopping.py`.
|
* Replace contrib references with `tf.estimator.experimental.*` for apis in `early_stopping.py`.
|
||||||
|
|
||||||
* `tf.estimator`:
|
* `tf.estimator`:
|
||||||
* Premade estimators in the tf.estimator.DNN/Linear/DNNLinearCombined family have been updated to use `tf.keras.optimizers` instead of the `tf.compat.v1.train.Optimizer`s. If you do not pass in an `optimizer=` arg or if you use a string, the premade estimator will use the Keras optimizer. This is checkpoint breaking, as the optimizers have separate variables. A checkpoint converter tool for converting optimizers is included with the release, but if you want to avoid any change, switch to the v1 version of the estimator: `tf.compat.v1.estimator.DNN/Linear/DNNLinearCombined*`.
|
* Premade estimators in the tf.estimator.DNN/Linear/DNNLinearCombined family have been updated to use `tf.keras.optimizers` instead of the `tf.compat.v1.train.Optimizer`s. If you do not pass in an `optimizer=` arg or if you use a string, the premade estimator will use the Keras optimizer. This is checkpoint breaking, as the optimizers have separate variables. A checkpoint converter tool for converting optimizers is included with the release, but if you want to avoid any change, switch to the v1 version of the estimator: `tf.compat.v1.estimator.DNN/Linear/DNNLinearCombined*`.
|
||||||
* Default aggregation for canned Estimators is now `SUM_OVER_BATCH_SIZE`. To maintain previous default behavior, please pass `SUM` as the loss aggregation method.
|
* Default aggregation for canned Estimators is now `SUM_OVER_BATCH_SIZE`. To maintain previous default behavior, please pass `SUM` as the loss aggregation method.
|
||||||
@ -1284,13 +1408,13 @@ For information on upgrading your existing TensorFlow 1.x models, please refer t
|
|||||||
* `Estimator.export_savedmodel` has been renamed to `export_saved_model`.
|
* `Estimator.export_savedmodel` has been renamed to `export_saved_model`.
|
||||||
* When saving to SavedModel, Estimators will strip default op attributes. This is almost always the correct behavior, as it is more forwards compatible, but if you require that default attributes to be saved with the model, please use `tf.compat.v1.Estimator`.
|
* When saving to SavedModel, Estimators will strip default op attributes. This is almost always the correct behavior, as it is more forwards compatible, but if you require that default attributes to be saved with the model, please use `tf.compat.v1.Estimator`.
|
||||||
* Feature Columns have been upgraded to be more Eager-friendly and to work with Keras. As a result, `tf.feature_column.input_layer` has been deprecated in favor of `tf.keras.layers.DenseFeatures`. v1 feature columns have direct analogues in v2 except for `shared_embedding_columns`, which are not cross-compatible with v1 and v2. Use `tf.feature_column.shared_embeddings` instead.
|
* Feature Columns have been upgraded to be more Eager-friendly and to work with Keras. As a result, `tf.feature_column.input_layer` has been deprecated in favor of `tf.keras.layers.DenseFeatures`. v1 feature columns have direct analogues in v2 except for `shared_embedding_columns`, which are not cross-compatible with v1 and v2. Use `tf.feature_column.shared_embeddings` instead.
|
||||||
|
|
||||||
* `tf.keras`:
|
* `tf.keras`:
|
||||||
* `OMP_NUM_THREADS` is no longer used by the default Keras config. To configure the number of threads, use `tf.config.threading` APIs.
|
* `OMP_NUM_THREADS` is no longer used by the default Keras config. To configure the number of threads, use `tf.config.threading` APIs.
|
||||||
* `tf.keras.model.save_model` and `model.save` now defaults to saving a TensorFlow SavedModel. HDF5 files are still supported.
|
* `tf.keras.model.save_model` and `model.save` now defaults to saving a TensorFlow SavedModel. HDF5 files are still supported.
|
||||||
* Deprecated `tf.keras.experimental.export_saved_model` and `tf.keras.experimental.function`. Please use `tf.keras.models.save_model(..., save_format='tf')` and `tf.keras.models.load_model` instead.
|
* Deprecated `tf.keras.experimental.export_saved_model` and `tf.keras.experimental.function`. Please use `tf.keras.models.save_model(..., save_format='tf')` and `tf.keras.models.load_model` instead.
|
||||||
* Layers now default to float32, and automatically cast their inputs to the layer's dtype. If you had a model that used float64, it will probably silently use float32 in TensorFlow 2, and a warning will be issued that starts with `Layer <layer-name>` is casting an input tensor from dtype float64 to the layer's dtype of float32. To fix, either set the default dtype to float64 with `tf.keras.backend.set_floatx('float64')`, or pass `dtype='float64'` to each of the Layer constructors. See `tf.keras.layers.Layer` for more information.
|
* Layers now default to float32, and automatically cast their inputs to the layer's dtype. If you had a model that used float64, it will probably silently use float32 in TensorFlow 2, and a warning will be issued that starts with `Layer <layer-name>` is casting an input tensor from dtype float64 to the layer's dtype of float32. To fix, either set the default dtype to float64 with `tf.keras.backend.set_floatx('float64')`, or pass `dtype='float64'` to each of the Layer constructors. See `tf.keras.layers.Layer` for more information.
|
||||||
|
|
||||||
* `tf.lite`:
|
* `tf.lite`:
|
||||||
* Removed `lite.OpHint`, `lite.experimental`, and `lite.constant` from 2.0 API.
|
* Removed `lite.OpHint`, `lite.experimental`, and `lite.constant` from 2.0 API.
|
||||||
* Tensors are no longer hashable, but instead compare element-wise with `==` and `!=`. Use `tf.compat.v1.disable_tensor_equality()` to return to the previous behavior.
|
* Tensors are no longer hashable, but instead compare element-wise with `==` and `!=`. Use `tf.compat.v1.disable_tensor_equality()` to return to the previous behavior.
|
||||||
@ -1525,8 +1649,8 @@ If you experience any snags when using TF 2.0, please let us know at the [TF 2.0
|
|||||||
conversion. TensorRT initialization arguments are now passed wrapped in
|
conversion. TensorRT initialization arguments are now passed wrapped in
|
||||||
a named-tuple, `TrtConversionParams`, rather than as separate arguments
|
a named-tuple, `TrtConversionParams`, rather than as separate arguments
|
||||||
as in `TrtGraphConverter`.
|
as in `TrtGraphConverter`.
|
||||||
* Changed API to optimize TensorRT enginges during graph optimization.
|
* Changed API to optimize TensorRT engines during graph optimization. This
|
||||||
This is now done by calling `converter.build()` where previously
|
is now done by calling `converter.build()` where previously
|
||||||
`is_dynamic_op=False` would be set.
|
`is_dynamic_op=False` would be set.
|
||||||
* `converter.convert()` no longer returns a `tf.function`. Now the
|
* `converter.convert()` no longer returns a `tf.function`. Now the
|
||||||
function must be accessed from the saved model.
|
function must be accessed from the saved model.
|
||||||
@ -2536,7 +2660,7 @@ Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, A
|
|||||||
* [`tf.contrib.estimator.RNNEstimator`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/estimator/RNNClassifier)
|
* [`tf.contrib.estimator.RNNEstimator`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/estimator/RNNClassifier)
|
||||||
* The [distributions.Bijector](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/distributions/bijectors/Bijector)
|
* The [distributions.Bijector](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/distributions/bijectors/Bijector)
|
||||||
API supports broadcasting for Bijectors with new API changes.
|
API supports broadcasting for Bijectors with new API changes.
|
||||||
|
|
||||||
## Breaking Changes
|
## Breaking Changes
|
||||||
* If you're opening empty variable scopes; replace `variable_scope('', ...)` by
|
* If you're opening empty variable scopes; replace `variable_scope('', ...)` by
|
||||||
`variable_scope(tf.get_variable_scope(), ...)`.
|
`variable_scope(tf.get_variable_scope(), ...)`.
|
||||||
@ -3015,7 +3139,7 @@ Samuel He, Sandeep Dcunha, sandipmgiri, Sang Han, scott, Scott Mudge, Se-Won Kim
|
|||||||
Simone Cirillo, Steffen Schmitz, Suvojit Manna, Sylvus, Taehoon Lee, Ted Chang, Thomas Deegan,
|
Simone Cirillo, Steffen Schmitz, Suvojit Manna, Sylvus, Taehoon Lee, Ted Chang, Thomas Deegan,
|
||||||
Till Hoffmann, Tim, Toni Kunic, Toon Verstraelen, Tristan Rice, Urs KöSter, Utkarsh Upadhyay,
|
Till Hoffmann, Tim, Toni Kunic, Toon Verstraelen, Tristan Rice, Urs KöSter, Utkarsh Upadhyay,
|
||||||
Vish (Ishaya) Abrams, Winnie Tsang, Yan Chen, Yan Facai (颜发才), Yi Yang, Yong Tang,
|
Vish (Ishaya) Abrams, Winnie Tsang, Yan Chen, Yan Facai (颜发才), Yi Yang, Yong Tang,
|
||||||
Youssef Hesham, Yuan (Terry) Tang, Zhengsheng Wei, zxcqwe4906, 张志豪, 田传武
|
Youssef Hesham, Yuan (Terry) Tang, Zhengsheng Wei, zxcqwe4906, 张志豪, 田传武
|
||||||
|
|
||||||
We are also grateful to all who filed issues or helped resolve them, asked and
|
We are also grateful to all who filed issues or helped resolve them, asked and
|
||||||
answered questions, and were part of inspiring discussions.
|
answered questions, and were part of inspiring discussions.
|
||||||
|
@ -1485,6 +1485,7 @@ def main():
|
|||||||
'adding "--config=<>" to your build command. See .bazelrc for more '
|
'adding "--config=<>" to your build command. See .bazelrc for more '
|
||||||
'details.')
|
'details.')
|
||||||
config_info_line('mkl', 'Build with MKL support.')
|
config_info_line('mkl', 'Build with MKL support.')
|
||||||
|
config_info_line('mkl_aarch64', 'Build with oneDNN support for Aarch64.')
|
||||||
config_info_line('monolithic', 'Config for mostly static monolithic build.')
|
config_info_line('monolithic', 'Config for mostly static monolithic build.')
|
||||||
config_info_line('ngraph', 'Build with Intel nGraph support.')
|
config_info_line('ngraph', 'Build with Intel nGraph support.')
|
||||||
config_info_line('numa', 'Build with NUMA support.')
|
config_info_line('numa', 'Build with NUMA support.')
|
||||||
|
@ -568,17 +568,7 @@ selects.config_setting_group(
|
|||||||
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
|
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
|
||||||
package_group(
|
package_group(
|
||||||
name = "internal",
|
name = "internal",
|
||||||
packages = [
|
packages = ["//tensorflow/..."],
|
||||||
"//learning/brain/distribute/...",
|
|
||||||
"//learning/brain/swift/x10/...",
|
|
||||||
"//perftools/accelerators/xprof/api/...",
|
|
||||||
"//tensorflow/...",
|
|
||||||
"//tensorflow_estimator/python/estimator/...",
|
|
||||||
"//tensorflow_models/official/...",
|
|
||||||
"//third_party/py/autograph/...",
|
|
||||||
"//third_party/swift/tensorflow/x10/...",
|
|
||||||
"//third_party/swift/tensorflow_apis/...",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
package_group(
|
package_group(
|
||||||
@ -588,10 +578,8 @@ package_group(
|
|||||||
|
|
||||||
# Packages that use private types symbols, until they are exported.
|
# Packages that use private types symbols, until they are exported.
|
||||||
# TODO(b/154650521) Remove.
|
# TODO(b/154650521) Remove.
|
||||||
package_group(
|
# If this is modified, then copy.bara.sky must also be modified.
|
||||||
name = "types_whitelist",
|
package_group(name = "types_whitelist")
|
||||||
packages = ["//learning/deepmind/tensorflow/replicator/..."],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Packages that use StructuredTensors.
|
# Packages that use StructuredTensors.
|
||||||
# TODO(b/159007891) Remove this package once StructuredTensor is exported.
|
# TODO(b/159007891) Remove this package once StructuredTensor is exported.
|
||||||
@ -719,7 +707,7 @@ tf_cc_shared_object(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||||
"//tensorflow/cc/saved_model:loader_lite_impl",
|
"//tensorflow/cc/saved_model:loader_lite_impl",
|
||||||
"//tensorflow/core:core_cpu_impl",
|
"//tensorflow/core/common_runtime:core_cpu_impl",
|
||||||
"//tensorflow/core:framework_internal_impl",
|
"//tensorflow/core:framework_internal_impl",
|
||||||
"//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
|
"//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
|
||||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
|
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
|
||||||
|
@ -217,6 +217,8 @@ tf_cuda_library(
|
|||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core/distributed_runtime:server_lib",
|
"//tensorflow/core/distributed_runtime:server_lib",
|
||||||
"//tensorflow/core/kernels:logging_ops",
|
"//tensorflow/core/kernels:logging_ops",
|
||||||
|
"//tensorflow/compiler/mlir/tfr:node_expansion_pass",
|
||||||
|
"//tensorflow/compiler/mlir/tfr:graph_decompose_pass",
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
@ -254,6 +256,30 @@ tf_cuda_library(
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tf_shape",
|
||||||
|
srcs = ["tf_shape.cc"],
|
||||||
|
hdrs = ["tf_shape.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":c_api_macros",
|
||||||
|
":tf_shape_internal",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tf_shape_internal",
|
||||||
|
hdrs = ["tf_shape_internal.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = ["//tensorflow:internal"],
|
||||||
|
deps = [
|
||||||
|
":conversion_macros",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tf_status",
|
name = "tf_status",
|
||||||
srcs = ["tf_status.cc"],
|
srcs = ["tf_status.cc"],
|
||||||
|
@ -2488,6 +2488,48 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
|
||||||
|
TF_Status* status) {
|
||||||
|
using tensorflow::RecordMutation;
|
||||||
|
mutex_lock l(graph->mu);
|
||||||
|
tensorflow::shape_inference::InferenceContext* ic =
|
||||||
|
graph->refiner.GetContext(&new_src.oper->node);
|
||||||
|
|
||||||
|
if (ic->num_outputs() <= new_src.index) {
|
||||||
|
status->status = tensorflow::errors::OutOfRange(
|
||||||
|
"Cannot update edge. Output index [", new_src.index,
|
||||||
|
"] is greater than the number of total outputs [", ic->num_outputs(),
|
||||||
|
"].");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
|
||||||
|
|
||||||
|
tensorflow::shape_inference::InferenceContext* ic_dst =
|
||||||
|
graph->refiner.GetContext(&dst.oper->node);
|
||||||
|
if (ic_dst->num_inputs() <= dst.index) {
|
||||||
|
status->status = tensorflow::errors::OutOfRange(
|
||||||
|
"Cannot update edge. Input index [", dst.index,
|
||||||
|
"] is greater than the number of total inputs [", ic_dst->num_inputs(),
|
||||||
|
"].");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!ic_dst->MergeInput(dst.index, shape)) {
|
||||||
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
|
||||||
|
" and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
|
||||||
|
&dst.oper->node, dst.index);
|
||||||
|
|
||||||
|
if (TF_GetCode(status) == TF_OK) {
|
||||||
|
// This modification only updates the destination node for
|
||||||
|
// the purposes of running this graph in a session. Thus, we don't
|
||||||
|
// record the source node as being modified.
|
||||||
|
RecordMutation(graph, *dst.oper, "updating input tensor");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TF_Server functions ----------------------------------------------
|
// TF_Server functions ----------------------------------------------
|
||||||
|
|
||||||
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||||
|
@ -1524,6 +1524,10 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status);
|
|||||||
TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp(
|
TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp(
|
||||||
const char* name, TF_Status* status);
|
const char* name, TF_Status* status);
|
||||||
|
|
||||||
|
// Update edge, switch input/ output in a node
|
||||||
|
TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src,
|
||||||
|
TF_Input dst, TF_Status* status);
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
// In-process TensorFlow server functionality, for use in distributed training.
|
// In-process TensorFlow server functionality, for use in distributed training.
|
||||||
// A Server instance encapsulates a set of devices and a Session target that
|
// A Server instance encapsulates a set of devices and a Session target that
|
||||||
|
@ -634,6 +634,40 @@ TEST(CAPI, Graph) {
|
|||||||
TF_DeleteStatus(s);
|
TF_DeleteStatus(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, UpdateEdge) {
|
||||||
|
TF_Status* s = TF_NewStatus();
|
||||||
|
TF_Graph* graph = TF_NewGraph();
|
||||||
|
|
||||||
|
// Make two scalar constants.
|
||||||
|
TF_Operation* one = ScalarConst(1, graph, s, "one");
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
|
TF_Operation* two = ScalarConst(2, graph, s, "two");
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
|
// Add oper.
|
||||||
|
TF_Operation* add = Add(one, two, graph, s, "add");
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
|
// Add another oper to the graph.
|
||||||
|
TF_Operation* neg = Neg(add, graph, s, "neg");
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
|
NodeDef node_def_neg;
|
||||||
|
ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
|
||||||
|
EXPECT_EQ(string("add"), node_def_neg.input(0));
|
||||||
|
|
||||||
|
// update edge of neg
|
||||||
|
TF_UpdateEdge(graph, TF_Output{one, 0}, TF_Input{neg, 0}, s);
|
||||||
|
|
||||||
|
ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
|
||||||
|
EXPECT_EQ(string("one:0"), node_def_neg.input(0));
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
TF_DeleteGraph(graph);
|
||||||
|
TF_DeleteStatus(s);
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
TODO(skyewm): this test currently DCHECKs, change to bad status
|
TODO(skyewm): this test currently DCHECKs, change to bad status
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||||
load(
|
load(
|
||||||
"//tensorflow:tensorflow.bzl",
|
"//tensorflow:tensorflow.bzl",
|
||||||
"if_tpu",
|
"if_libtpu",
|
||||||
"tf_cc_test",
|
"tf_cc_test",
|
||||||
"tf_copts",
|
"tf_copts",
|
||||||
"tf_cuda_cc_test",
|
"tf_cuda_cc_test",
|
||||||
@ -116,7 +116,6 @@ filegroup(
|
|||||||
"immediate_execution_context.h",
|
"immediate_execution_context.h",
|
||||||
"immediate_execution_operation.h",
|
"immediate_execution_operation.h",
|
||||||
"immediate_execution_tensor_handle.h",
|
"immediate_execution_tensor_handle.h",
|
||||||
"mnist_gradients_testutil.h",
|
|
||||||
"tape.h",
|
"tape.h",
|
||||||
"tfe_cancellation_manager_internal.h",
|
"tfe_cancellation_manager_internal.h",
|
||||||
"tfe_context_internal.h",
|
"tfe_context_internal.h",
|
||||||
@ -290,7 +289,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/lib/llvm_rtti",
|
"//tensorflow/core/lib/llvm_rtti",
|
||||||
] + if_tpu(
|
] + if_libtpu(
|
||||||
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
|
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
|
||||||
if_true = [],
|
if_true = [],
|
||||||
),
|
),
|
||||||
@ -314,6 +313,7 @@ cc_library(
|
|||||||
":gradients_internal",
|
":gradients_internal",
|
||||||
":gradients_util",
|
":gradients_util",
|
||||||
":tape",
|
":tape",
|
||||||
|
"//tensorflow/c/experimental/gradients/tape:tape_context",
|
||||||
"//tensorflow/c/experimental/ops:array_ops",
|
"//tensorflow/c/experimental/ops:array_ops",
|
||||||
"//tensorflow/c/experimental/ops:math_ops",
|
"//tensorflow/c/experimental/ops:math_ops",
|
||||||
"//tensorflow/c/experimental/ops:nn_ops",
|
"//tensorflow/c/experimental/ops:nn_ops",
|
||||||
@ -354,7 +354,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/lib/llvm_rtti",
|
"//tensorflow/core/lib/llvm_rtti",
|
||||||
] + if_tpu(
|
] + if_libtpu(
|
||||||
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
|
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
|
||||||
if_true = [],
|
if_true = [],
|
||||||
),
|
),
|
||||||
|
@ -39,7 +39,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||||
#include "tensorflow/c/tf_tensor_internal.h"
|
#include "tensorflow/c/tf_tensor_internal.h"
|
||||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTFTPU)
|
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
|
||||||
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
||||||
#endif
|
#endif
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
@ -729,7 +729,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
|
|||||||
|
|
||||||
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||||
if (opts->use_tfrt) {
|
if (opts->use_tfrt) {
|
||||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTFTPU)
|
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
|
||||||
return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
|
return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
|
||||||
#else
|
#else
|
||||||
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
|
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
|
||||||
@ -904,9 +904,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
|||||||
|
|
||||||
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||||
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
|
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetThreadLocalDevicePlacementPolicy(
|
|
||||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
|
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -915,10 +913,8 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
|||||||
// safe to call this function from the async EagerExecutor threads.
|
// safe to call this function from the async EagerExecutor threads.
|
||||||
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
||||||
TFE_Context* ctx) {
|
TFE_Context* ctx) {
|
||||||
tensorflow::EagerContext* context =
|
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
return static_cast<TFE_ContextDevicePlacementPolicy>(
|
return static_cast<TFE_ContextDevicePlacementPolicy>(
|
||||||
context->GetDevicePlacementPolicy());
|
tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
|
||||||
}
|
}
|
||||||
|
|
||||||
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
|
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
|
||||||
@ -1429,21 +1425,15 @@ void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
|
|||||||
}
|
}
|
||||||
|
|
||||||
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
||||||
tensorflow::EagerContext* context =
|
return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
return context->FindFunctionDef(name) != nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
|
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetShouldStoreGraphs(true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetShouldStoreGraphs(false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
|
@ -74,7 +74,7 @@ typedef enum TFE_ContextDevicePlacementPolicy {
|
|||||||
// Placement policy which silently copies int32 tensors but not other dtypes.
|
// Placement policy which silently copies int32 tensors but not other dtypes.
|
||||||
TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
|
TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
|
||||||
} TFE_ContextDevicePlacementPolicy;
|
} TFE_ContextDevicePlacementPolicy;
|
||||||
// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)
|
// LINT.ThenChange(//tensorflow/c/eager/immediate_execution_context.h)
|
||||||
|
|
||||||
// Sets the default execution mode (sync/async). Note that this can be
|
// Sets the default execution mode (sync/async). Note that this can be
|
||||||
// overridden per thread using TFE_ContextSetExecutorForThread.
|
// overridden per thread using TFE_ContextSetExecutorForThread.
|
||||||
|
@ -545,7 +545,9 @@ TEST(CAPI, DistributedFunctionNoError) {
|
|||||||
TestDistributedFunctionCancellation(false);
|
TestDistributedFunctionCancellation(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CAPI, DistributedFunctionCancelledOnError) {
|
// TODO(b/170399182): Update test once an alternative to using the function
|
||||||
|
// optimization hook is in place.
|
||||||
|
TEST(CAPI, DISABLED_DistributedFunctionCancelledOnError) {
|
||||||
TestDistributedFunctionCancellation(true);
|
TestDistributedFunctionCancellation(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,15 +49,11 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetShouldStoreGraphs(true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetShouldStoreGraphs(false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t TFE_GetContextId(TFE_Context* ctx) {
|
uint64_t TFE_GetContextId(TFE_Context* ctx) {
|
||||||
@ -544,22 +540,16 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
|
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor());
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetExecutorForThread(executor->executor());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
||||||
tensorflow::EagerContext* context =
|
return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor());
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
return new TFE_Executor(&context->Executor());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
||||||
tensorflow::EagerContext* context =
|
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
|
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
|
||||||
context->HostCPU()->parsed_name());
|
tensorflow::unwrap(ctx)->HostCPUParsedName());
|
||||||
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
|
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
|
||||||
void* data = tensorflow::port::Malloc(str.length());
|
void* data = tensorflow::port::Malloc(str.length());
|
||||||
str.copy(static_cast<char*>(data), str.length(), 0);
|
str.copy(static_cast<char*>(data), str.length(), 0);
|
||||||
@ -572,9 +562,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
|||||||
|
|
||||||
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
||||||
TF_Buffer* buf, TF_Status* status) {
|
TF_Buffer* buf, TF_Status* status) {
|
||||||
tensorflow::EagerContext* context =
|
auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
auto* function_def = context->FindFunctionDef(function_name);
|
|
||||||
if (function_def == nullptr) {
|
if (function_def == nullptr) {
|
||||||
status->status = tensorflow::errors::NotFound(
|
status->status = tensorflow::errors::NotFound(
|
||||||
"Unable to find FunctionDef with name: ", function_name);
|
"Unable to find FunctionDef with name: ", function_name);
|
||||||
@ -643,14 +631,10 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
|
|||||||
|
|
||||||
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetAllowSoftPlacement(enable);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetLogDevicePlacement(enable);
|
|
||||||
}
|
}
|
||||||
|
@ -191,7 +191,7 @@ Status TapeVSpace::CallBackwardFunction(
|
|||||||
&ctx, incoming_gradients, result);
|
&ctx, incoming_gradients, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TapeVSpace::BuildOnesLike(TapeTensor t,
|
Status TapeVSpace::BuildOnesLike(const TapeTensor& t,
|
||||||
AbstractTensorHandle** result) const {
|
AbstractTensorHandle** result) const {
|
||||||
AbstractOperationPtr op(ctx_->CreateOperation());
|
AbstractOperationPtr op(ctx_->CreateOperation());
|
||||||
TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));
|
TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));
|
||||||
|
@ -180,10 +180,6 @@ int64 ToId(AbstractTensorHandle* t);
|
|||||||
// allow us to trace the data dependencies between operations and hence compute
|
// allow us to trace the data dependencies between operations and hence compute
|
||||||
// gradients.
|
// gradients.
|
||||||
//
|
//
|
||||||
// This also implements `OnesLike` to create the default
|
|
||||||
// incoming gradients for tensors which do not already have an incoming
|
|
||||||
// gradient.
|
|
||||||
//
|
|
||||||
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
|
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
|
||||||
// of default zeros grads is handled by the `DefaultGradientFunction` registered
|
// of default zeros grads is handled by the `DefaultGradientFunction` registered
|
||||||
// for each op.
|
// for each op.
|
||||||
@ -233,7 +229,7 @@ class TapeVSpace
|
|||||||
std::vector<AbstractTensorHandle*>* result) const override;
|
std::vector<AbstractTensorHandle*>* result) const override;
|
||||||
|
|
||||||
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
||||||
Status BuildOnesLike(TapeTensor t,
|
Status BuildOnesLike(const TapeTensor& t,
|
||||||
AbstractTensorHandle** result) const override;
|
AbstractTensorHandle** result) const override;
|
||||||
|
|
||||||
// Looks up the ID of a Gradient.
|
// Looks up the ID of a Gradient.
|
||||||
|
@ -61,6 +61,7 @@ Status RegisterGradients(GradientRegistry* registry) {
|
|||||||
TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
|
TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
|
||||||
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
|
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
|
||||||
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
|
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
|
||||||
|
TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -131,6 +132,37 @@ Status ExpGradModel(AbstractContext* ctx,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Computes
|
||||||
|
// y = sqrt(inputs[0])
|
||||||
|
// return grad(y, {inputs[0]})
|
||||||
|
Status SqrtGradModel(AbstractContext* ctx,
|
||||||
|
absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
|
absl::Span<AbstractTensorHandle*> outputs,
|
||||||
|
const GradientRegistry& registry) {
|
||||||
|
TapeVSpace vspace(ctx);
|
||||||
|
auto tape = new Tape(/*persistent=*/false);
|
||||||
|
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||||
|
std::vector<AbstractTensorHandle*> sqrt_outputs(1);
|
||||||
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
|
||||||
|
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||||
|
source_tensors_that_are_targets;
|
||||||
|
|
||||||
|
std::vector<AbstractTensorHandle*> out_grads;
|
||||||
|
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||||
|
vspace, /*target_tensor_ids=*/{ToId(sqrt_outputs[0])},
|
||||||
|
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
|
||||||
|
/*output_gradients=*/{}, &out_grads,
|
||||||
|
/*build_default_zeros_grads=*/false));
|
||||||
|
for (auto sqrt_output : sqrt_outputs) {
|
||||||
|
sqrt_output->Unref();
|
||||||
|
}
|
||||||
|
outputs[0] = out_grads[0];
|
||||||
|
delete tape;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
// Computes
|
// Computes
|
||||||
// ignored, y = IdentityN(inputs[0], inputs[1])
|
// ignored, y = IdentityN(inputs[0], inputs[1])
|
||||||
// return grad(y, {inputs[0], inputs[1]})
|
// return grad(y, {inputs[0], inputs[1]})
|
||||||
@ -401,6 +433,50 @@ TEST_P(CppGradients, TestExpGrad) {
|
|||||||
result_tensor = nullptr;
|
result_tensor = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(CppGradients, TestSqrtGrad) {
|
||||||
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
|
AbstractContextPtr ctx;
|
||||||
|
{
|
||||||
|
AbstractContext* ctx_raw = nullptr;
|
||||||
|
Status s =
|
||||||
|
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
ctx.reset(ctx_raw);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractTensorHandlePtr x;
|
||||||
|
{
|
||||||
|
AbstractTensorHandle* x_raw = nullptr;
|
||||||
|
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
x.reset(x_raw);
|
||||||
|
}
|
||||||
|
|
||||||
|
GradientRegistry registry;
|
||||||
|
Status s = RegisterGradients(®istry);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
|
||||||
|
// Pseudo-code:
|
||||||
|
//
|
||||||
|
// tape.watch(x)
|
||||||
|
// y = sqrt(x)
|
||||||
|
// outputs = tape.gradient(y, x)
|
||||||
|
std::vector<AbstractTensorHandle*> outputs(1);
|
||||||
|
s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
|
||||||
|
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
|
||||||
|
TF_Tensor* result_tensor;
|
||||||
|
s = getValue(outputs[0], &result_tensor);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||||
|
EXPECT_NEAR(*result_value, 0.5, 0.001);
|
||||||
|
outputs[0]->Unref();
|
||||||
|
TF_DeleteTensor(result_tensor);
|
||||||
|
result_tensor = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(CppGradients, TestIdentityNGrad) {
|
TEST_P(CppGradients, TestIdentityNGrad) {
|
||||||
// Pseudo-code:
|
// Pseudo-code:
|
||||||
//
|
//
|
||||||
|
@ -29,8 +29,25 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
#include "tensorflow/core/platform/tstring.h"
|
#include "tensorflow/core/platform/tstring.h"
|
||||||
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
class EagerExecutor;
|
||||||
|
|
||||||
|
// LINT.IfChange
|
||||||
|
// Note: Keep in sync with exported copy of enum in eager/c_api.h.
|
||||||
|
enum ContextDevicePlacementPolicy {
|
||||||
|
// Running operations with input tensors on the wrong device will fail.
|
||||||
|
DEVICE_PLACEMENT_EXPLICIT = 0,
|
||||||
|
// Copy the tensor to the right device but log a warning.
|
||||||
|
DEVICE_PLACEMENT_WARN = 1,
|
||||||
|
// Silently copy the tensor, which has a performance cost since the operation
|
||||||
|
// will be blocked till the copy completes. This is the default policy.
|
||||||
|
DEVICE_PLACEMENT_SILENT = 2,
|
||||||
|
// Placement policy which silently copies int32 tensors but not other dtypes.
|
||||||
|
DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
|
||||||
|
};
|
||||||
|
// LINT.ThenChange(//tensorflow/c/eager/c_api.h)
|
||||||
|
|
||||||
// Abstract interface to a context.
|
// Abstract interface to a context.
|
||||||
//
|
//
|
||||||
@ -81,14 +98,6 @@ class ImmediateExecutionContext : public AbstractContext {
|
|||||||
// List attributes of available devices
|
// List attributes of available devices
|
||||||
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
|
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
|
||||||
|
|
||||||
virtual void ClearCachesAndThreadExecutors() = 0;
|
|
||||||
|
|
||||||
// Initialize the step resource container for a training step. This is used
|
|
||||||
// in current TF runtime. For tfrt, it is used by fallback op handler.
|
|
||||||
virtual void StartStep() = 0;
|
|
||||||
// Destroy the step resource container for a training step.
|
|
||||||
virtual void EndStep() = 0;
|
|
||||||
|
|
||||||
// Block until all pending nodes are finished.
|
// Block until all pending nodes are finished.
|
||||||
virtual Status AsyncWait() = 0;
|
virtual Status AsyncWait() = 0;
|
||||||
|
|
||||||
@ -97,11 +106,52 @@ class ImmediateExecutionContext : public AbstractContext {
|
|||||||
// already exists.
|
// already exists.
|
||||||
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
|
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
|
||||||
|
|
||||||
|
// Find and return a added function by its name.
|
||||||
|
virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;
|
||||||
|
|
||||||
|
// Return the ParsedName of Host CPU device.
|
||||||
|
virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
|
||||||
|
|
||||||
|
// Configure soft device placement policy.
|
||||||
|
virtual void SetAllowSoftPlacement(bool enable) = 0;
|
||||||
|
|
||||||
|
// Configure device placement policy logging.
|
||||||
|
virtual void SetLogDevicePlacement(bool enable) = 0;
|
||||||
|
|
||||||
|
// Sets the device placement policy for the current thread.
|
||||||
|
virtual void SetThreadLocalDevicePlacementPolicy(
|
||||||
|
ContextDevicePlacementPolicy policy) = 0;
|
||||||
|
// Returns the device placement policy for the current thread.
|
||||||
|
virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
|
||||||
|
|
||||||
// For LLVM style RTTI.
|
// For LLVM style RTTI.
|
||||||
static bool classof(const AbstractContext* ptr) {
|
static bool classof(const AbstractContext* ptr) {
|
||||||
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
|
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===--------------------------------------------------------------------===//
|
||||||
|
// Following are legacy features in TF Eager Runtime.
|
||||||
|
// TODO(tf-runtime): Figure out a way to deprecate following features after
|
||||||
|
// migrated to TFRT.
|
||||||
|
//===--------------------------------------------------------------------===//
|
||||||
|
// Clear pending nodes in thread executors and kernel caches.
|
||||||
|
virtual void ClearCachesAndThreadExecutors() = 0;
|
||||||
|
|
||||||
|
// Initialize the step resource container for a training step. This is used
|
||||||
|
// in current TF runtime. For tfrt, it is used by fallback op handler.
|
||||||
|
virtual void StartStep() = 0;
|
||||||
|
// Destroy the step resource container for a training step.
|
||||||
|
virtual void EndStep() = 0;
|
||||||
|
|
||||||
|
// Return the Eager Executor for current thread. Please note that Eager
|
||||||
|
// Executor is only used in current TF but not in TFRT.
|
||||||
|
virtual EagerExecutor& Executor() = 0;
|
||||||
|
// Update the Eager Executor for current thread.
|
||||||
|
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
|
||||||
|
|
||||||
|
// Configure graph collection in RunMetadata.
|
||||||
|
virtual void SetShouldStoreGraphs(bool value) = 0;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
explicit ImmediateExecutionContext(AbstractContextKind kind)
|
explicit ImmediateExecutionContext(AbstractContextKind kind)
|
||||||
: AbstractContext(kind) {}
|
: AbstractContext(kind) {}
|
||||||
|
@ -25,133 +25,18 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/gradients.h"
|
#include "tensorflow/c/eager/gradients.h"
|
||||||
#include "tensorflow/c/eager/gradients_internal.h"
|
#include "tensorflow/c/eager/gradients_internal.h"
|
||||||
#include "tensorflow/c/eager/gradients_util.h"
|
#include "tensorflow/c/eager/gradients_util.h"
|
||||||
|
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
|
||||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||||
#include "tensorflow/c/experimental/ops/math_ops.h"
|
#include "tensorflow/c/experimental/ops/math_ops.h"
|
||||||
#include "tensorflow/c/experimental/ops/nn_ops.h"
|
#include "tensorflow/c/experimental/ops/nn_ops.h"
|
||||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
|
|
||||||
// ========================== Tape Ops ==============================
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace gradients {
|
namespace gradients {
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
using std::vector;
|
using std::vector;
|
||||||
using tensorflow::tracing::TracingOperation;
|
|
||||||
|
|
||||||
// Computes `inputs[0] + inputs[1]` and records it on the tape.
|
|
||||||
Status Add(AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs,
|
|
||||||
const GradientRegistry& registry) {
|
|
||||||
AbstractOperationPtr add_op(ctx->CreateOperation());
|
|
||||||
ForwardOperation forward_op;
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
|
|
||||||
if (isa<TracingOperation>(add_op.get())) {
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
|
|
||||||
}
|
|
||||||
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
|
|
||||||
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
|
|
||||||
int num_retvals = 1;
|
|
||||||
return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
|
||||||
registry);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
|
|
||||||
Status MatMul(AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
|
||||||
bool transpose_a, bool transpose_b,
|
|
||||||
const GradientRegistry& registry) {
|
|
||||||
AbstractOperationPtr matmul_op(ctx->CreateOperation());
|
|
||||||
ForwardOperation forward_op;
|
|
||||||
TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul",
|
|
||||||
/*raw_device_name=*/nullptr, &forward_op));
|
|
||||||
if (isa<TracingOperation>(matmul_op.get())) {
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
dyn_cast<TracingOperation>(matmul_op.get())->SetOpName(name));
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op));
|
|
||||||
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op));
|
|
||||||
TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
|
|
||||||
matmul_op.get(), "transpose_a", transpose_a, &forward_op));
|
|
||||||
TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
|
|
||||||
matmul_op.get(), "transpose_b", transpose_b, &forward_op));
|
|
||||||
|
|
||||||
int num_retvals = 1;
|
|
||||||
return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
|
||||||
registry);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status Mul(AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
|
||||||
const GradientRegistry& registry) {
|
|
||||||
AbstractOperationPtr mul_op(ctx->CreateOperation());
|
|
||||||
ForwardOperation forward_op;
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
|
|
||||||
if (isa<TracingOperation>(mul_op.get())) {
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op));
|
|
||||||
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op));
|
|
||||||
|
|
||||||
int num_retvals = 1;
|
|
||||||
return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
|
||||||
registry);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Computes `Relu(inputs[0])` and records it on the tape.
|
|
||||||
Status Relu(AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
|
||||||
const GradientRegistry& registry) {
|
|
||||||
AbstractOperationPtr relu_op(ctx->CreateOperation());
|
|
||||||
ForwardOperation forward_op;
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
|
|
||||||
if (isa<TracingOperation>(relu_op.get())) {
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
dyn_cast<TracingOperation>(relu_op.get())->SetOpName(name));
|
|
||||||
}
|
|
||||||
TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op));
|
|
||||||
int num_retvals = 1;
|
|
||||||
return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
|
||||||
registry);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Computes `SoftmaxLoss(scores, labels)` where labels are categorical (not
|
|
||||||
// one-hot) and records it on the tape.
|
|
||||||
Status SparseSoftmaxCrossEntropyWithLogits(
|
|
||||||
AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
|
||||||
const GradientRegistry& registry) {
|
|
||||||
AbstractTensorHandle* scores = inputs[0];
|
|
||||||
AbstractTensorHandle* labels = inputs[1];
|
|
||||||
|
|
||||||
AbstractOperationPtr sm_op(ctx->CreateOperation());
|
|
||||||
ForwardOperation forward_op;
|
|
||||||
TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits",
|
|
||||||
/*raw_device_name=*/nullptr, &forward_op));
|
|
||||||
if (isa<TracingOperation>(sm_op.get())) {
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
dyn_cast<TracingOperation>(sm_op.get())->SetOpName(name));
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op));
|
|
||||||
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op));
|
|
||||||
|
|
||||||
int num_retvals = 2; // returns loss values and backprop
|
|
||||||
return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
|
||||||
registry);
|
|
||||||
}
|
|
||||||
|
|
||||||
//===================== Test Models to run =========================
|
//===================== Test Models to run =========================
|
||||||
|
|
||||||
@ -167,8 +52,9 @@ Status AddGradModel(AbstractContext* ctx,
|
|||||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||||
std::vector<AbstractTensorHandle*> add_outputs(1);
|
std::vector<AbstractTensorHandle*> add_outputs(1);
|
||||||
TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
registry)); // Compute x+y.
|
TF_RETURN_IF_ERROR(
|
||||||
|
ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));
|
||||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||||
source_tensors_that_are_targets;
|
source_tensors_that_are_targets;
|
||||||
|
|
||||||
@ -200,9 +86,11 @@ Status MatMulGradModel(AbstractContext* ctx,
|
|||||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||||
vector<AbstractTensorHandle*> mm_outputs(1);
|
vector<AbstractTensorHandle*> mm_outputs(1);
|
||||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs),
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
"matmul0", /*transpose_a=*/false,
|
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs,
|
||||||
/*transpose_b=*/false, registry)); // Compute x*y.
|
absl::MakeSpan(mm_outputs), "matmul0",
|
||||||
|
/*transpose_a=*/false,
|
||||||
|
/*transpose_b=*/false)); // Compute x*y.
|
||||||
|
|
||||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||||
source_tensors_that_are_targets;
|
source_tensors_that_are_targets;
|
||||||
@ -256,25 +144,27 @@ Status MNISTForwardModel(AbstractContext* ctx,
|
|||||||
tape->Watch(ToId(W2)); // Watch W2.
|
tape->Watch(ToId(W2)); // Watch W2.
|
||||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
"matmul0", /*transpose_a=*/false,
|
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
|
||||||
/*transpose_b=*/false, registry)); // Compute X*W1
|
absl::MakeSpan(temp_outputs), "matmul0",
|
||||||
|
/*transpose_a=*/false,
|
||||||
|
/*transpose_b=*/false)); // Compute X*W1
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]},
|
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {temp_outputs[0]},
|
||||||
absl::MakeSpan(temp_outputs), "relu",
|
absl::MakeSpan(temp_outputs),
|
||||||
registry)); // Compute Relu(X*W1)
|
"relu")); // Compute Relu(X*W1)
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2},
|
TF_RETURN_IF_ERROR(ops::MatMul(
|
||||||
absl::MakeSpan(temp_outputs), "matmul1",
|
tape_ctx.get(), {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs),
|
||||||
/*transpose_a=*/false, /*transpose_b=*/false,
|
"matmul1",
|
||||||
registry)); // Compute W2*Relu(X*W1)
|
/*transpose_a=*/false, /*transpose_b=*/false)); // Compute W2*Relu(X*W1)
|
||||||
|
|
||||||
AbstractTensorHandle* scores = temp_outputs[0];
|
AbstractTensorHandle* scores = temp_outputs[0];
|
||||||
|
|
||||||
temp_outputs.resize(2);
|
temp_outputs.resize(2);
|
||||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
|
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
|
||||||
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
||||||
"softmax_loss", registry)); // Compute Softmax(Scores,labels)
|
"softmax_loss")); // Compute Softmax(Scores,labels)
|
||||||
|
|
||||||
AbstractTensorHandle* loss_vals = temp_outputs[0];
|
AbstractTensorHandle* loss_vals = temp_outputs[0];
|
||||||
|
|
||||||
@ -297,9 +187,11 @@ Status MatMulTransposeModel(AbstractContext* ctx,
|
|||||||
tape->Watch(ToId(W1));
|
tape->Watch(ToId(W1));
|
||||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
"matmul0", /*transpose_a=*/true,
|
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
|
||||||
/*transpose_b=*/false, registry)); // Compute X*W1
|
absl::MakeSpan(temp_outputs), "matmul0",
|
||||||
|
/*transpose_a=*/true,
|
||||||
|
/*transpose_b=*/false)); // Compute X*W1
|
||||||
|
|
||||||
outputs[0] = temp_outputs[0];
|
outputs[0] = temp_outputs[0];
|
||||||
|
|
||||||
@ -315,8 +207,10 @@ Status ReluGradModel(AbstractContext* ctx,
|
|||||||
auto tape = new Tape(/*persistent=*/false);
|
auto tape = new Tape(/*persistent=*/false);
|
||||||
tape->Watch(ToId(inputs[0])); // Watch X
|
tape->Watch(ToId(inputs[0])); // Watch X
|
||||||
vector<AbstractTensorHandle*> relu_outputs(1);
|
vector<AbstractTensorHandle*> relu_outputs(1);
|
||||||
TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs),
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
"relu0", registry)); // Relu(X)
|
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
|
||||||
|
absl::MakeSpan(relu_outputs),
|
||||||
|
"relu0")); // Relu(X)
|
||||||
|
|
||||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||||
source_tensors_that_are_targets;
|
source_tensors_that_are_targets;
|
||||||
@ -346,8 +240,9 @@ Status SoftmaxLossGradModel(AbstractContext* ctx,
|
|||||||
tape->Watch(ToId(inputs[0])); // Watch scores.
|
tape->Watch(ToId(inputs[0])); // Watch scores.
|
||||||
tape->Watch(ToId(inputs[1])); // Watch labels.
|
tape->Watch(ToId(inputs[1])); // Watch labels.
|
||||||
vector<AbstractTensorHandle*> sm_outputs(2);
|
vector<AbstractTensorHandle*> sm_outputs(2);
|
||||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry));
|
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
|
||||||
|
tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0"));
|
||||||
|
|
||||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||||
source_tensors_that_are_targets;
|
source_tensors_that_are_targets;
|
||||||
@ -381,29 +276,30 @@ Status MNISTGradModel(AbstractContext* ctx,
|
|||||||
tape->Watch(ToId(W1)); // Watch W1.
|
tape->Watch(ToId(W1)); // Watch W1.
|
||||||
tape->Watch(ToId(W2)); // Watch W1.
|
tape->Watch(ToId(W2)); // Watch W1.
|
||||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
"matmul0", /*transpose_a=*/false,
|
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
|
||||||
/*transpose_b=*/false, registry)); // Compute X*W1
|
absl::MakeSpan(temp_outputs), "matmul0",
|
||||||
|
/*transpose_a=*/false,
|
||||||
|
/*transpose_b=*/false)); // Compute X*W1
|
||||||
|
|
||||||
AbstractTensorHandle* mm = temp_outputs[0];
|
AbstractTensorHandle* mm = temp_outputs[0];
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm},
|
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm},
|
||||||
absl::MakeSpan(temp_outputs), // Relu(X*W1)
|
absl::MakeSpan(temp_outputs), // Relu(X*W1)
|
||||||
"relu0", registry));
|
"relu0"));
|
||||||
|
|
||||||
AbstractTensorHandle* hidden = temp_outputs[0];
|
AbstractTensorHandle* hidden = temp_outputs[0];
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2},
|
TF_RETURN_IF_ERROR(ops::MatMul(
|
||||||
absl::MakeSpan(temp_outputs), "matmul1",
|
tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1",
|
||||||
/*transpose_a=*/false, /*transpose_b=*/false,
|
/*transpose_a=*/false, /*transpose_b=*/false)); // W2*Relu(X*W1)
|
||||||
registry)); // W2*Relu(X*W1)
|
|
||||||
|
|
||||||
AbstractTensorHandle* scores = temp_outputs[0];
|
AbstractTensorHandle* scores = temp_outputs[0];
|
||||||
|
|
||||||
temp_outputs.resize(2);
|
temp_outputs.resize(2);
|
||||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
|
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
|
||||||
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
||||||
"softmaxloss", registry)); // W2*Relu(X*W1)
|
"softmaxloss")); // W2*Relu(X*W1)
|
||||||
|
|
||||||
AbstractTensorHandle* loss = temp_outputs[0];
|
AbstractTensorHandle* loss = temp_outputs[0];
|
||||||
|
|
||||||
@ -440,8 +336,10 @@ Status ScalarMulModel(AbstractContext* ctx,
|
|||||||
auto tape = new Tape(/*persistent=*/false);
|
auto tape = new Tape(/*persistent=*/false);
|
||||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs),
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
"scalarMul0", registry)); // Compute eta*A
|
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {eta, A},
|
||||||
|
absl::MakeSpan(temp_outputs),
|
||||||
|
"scalarMul0")); // Compute eta*A
|
||||||
|
|
||||||
outputs[0] = temp_outputs[0];
|
outputs[0] = temp_outputs[0];
|
||||||
|
|
||||||
@ -459,9 +357,11 @@ Status MatMulModel(AbstractContext* ctx,
|
|||||||
TapeVSpace vspace(ctx);
|
TapeVSpace vspace(ctx);
|
||||||
auto tape = new Tape(/*persistent=*/false);
|
auto tape = new Tape(/*persistent=*/false);
|
||||||
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
||||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
"matmul0", /*transpose_a=*/false,
|
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
|
||||||
/*transpose_b=*/false, registry)); // Compute X*W1
|
absl::MakeSpan(temp_outputs), "matmul0",
|
||||||
|
/*transpose_a=*/false,
|
||||||
|
/*transpose_b=*/false)); // Compute X*W1
|
||||||
|
|
||||||
outputs[0] = temp_outputs[0];
|
outputs[0] = temp_outputs[0];
|
||||||
delete tape;
|
delete tape;
|
||||||
@ -478,8 +378,10 @@ Status MulModel(AbstractContext* ctx,
|
|||||||
TapeVSpace vspace(ctx);
|
TapeVSpace vspace(ctx);
|
||||||
auto tape = new Tape(/*persistent=*/false);
|
auto tape = new Tape(/*persistent=*/false);
|
||||||
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
||||||
TF_RETURN_IF_ERROR(Mul(ctx, tape, {x, y}, absl::MakeSpan(temp_outputs),
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
"mul0", registry)); // Compute x*y
|
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {x, y},
|
||||||
|
absl::MakeSpan(temp_outputs),
|
||||||
|
"mul0")); // Compute x*y
|
||||||
|
|
||||||
outputs[0] = temp_outputs[0];
|
outputs[0] = temp_outputs[0];
|
||||||
delete tape;
|
delete tape;
|
||||||
@ -496,9 +398,9 @@ Status SoftmaxModel(AbstractContext* ctx,
|
|||||||
TapeVSpace vspace(ctx);
|
TapeVSpace vspace(ctx);
|
||||||
auto tape = new Tape(/*persistent=*/false);
|
auto tape = new Tape(/*persistent=*/false);
|
||||||
std::vector<AbstractTensorHandle*> temp_outputs(2);
|
std::vector<AbstractTensorHandle*> temp_outputs(2);
|
||||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
ctx, tape, {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss",
|
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
|
||||||
registry));
|
tape_ctx.get(), {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss"));
|
||||||
|
|
||||||
outputs[0] = temp_outputs[0]; // loss values
|
outputs[0] = temp_outputs[0]; // loss values
|
||||||
|
|
||||||
|
@ -29,45 +29,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
|
||||||
// ========================== Tape Ops ==============================
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace gradients {
|
namespace gradients {
|
||||||
namespace internal {
|
namespace internal {
|
||||||
// Computes `inputs[0] + inputs[1]` and records it on the tape.
|
|
||||||
Status Add(AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs,
|
|
||||||
const GradientRegistry& registry);
|
|
||||||
|
|
||||||
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
|
|
||||||
Status MatMul(AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
|
||||||
bool transpose_a, bool transpose_b,
|
|
||||||
const GradientRegistry& registry);
|
|
||||||
|
|
||||||
// Computes `inputs[0] * inputs[1]` and records it on the tape.
|
|
||||||
Status Mul(AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
|
||||||
const GradientRegistry& registry);
|
|
||||||
|
|
||||||
// Computes `Relu(inputs[0])` and records it on the tape.
|
|
||||||
Status Relu(AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
|
||||||
const GradientRegistry& registry);
|
|
||||||
|
|
||||||
// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
|
|
||||||
// tape.
|
|
||||||
Status SparseSoftmaxCrossEntropyWithLogits(
|
|
||||||
AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
|
||||||
const GradientRegistry& registry);
|
|
||||||
|
|
||||||
// ====================== End Tape Ops ============================
|
|
||||||
|
|
||||||
// Computes
|
// Computes
|
||||||
// y = inputs[0] + inputs[1]
|
// y = inputs[0] + inputs[1]
|
||||||
|
@ -100,7 +100,8 @@ class VSpace {
|
|||||||
std::vector<Gradient*>* result) const = 0;
|
std::vector<Gradient*>* result) const = 0;
|
||||||
|
|
||||||
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
||||||
virtual Status BuildOnesLike(TapeTensor t, Gradient** result) const = 0;
|
virtual Status BuildOnesLike(const TapeTensor& t,
|
||||||
|
Gradient** result) const = 0;
|
||||||
|
|
||||||
// Looks up the ID of a Gradient.
|
// Looks up the ID of a Gradient.
|
||||||
virtual int64 TensorId(Gradient* tensor) const = 0;
|
virtual int64 TensorId(Gradient* tensor) const = 0;
|
||||||
|
@ -29,6 +29,7 @@ cc_library(
|
|||||||
}),
|
}),
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/c:env",
|
"//tensorflow/c:env",
|
||||||
|
"//tensorflow/c:logging",
|
||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||||
"//third_party/hadoop:hdfs",
|
"//third_party/hadoop:hdfs",
|
||||||
|
@ -22,11 +22,10 @@ limitations under the License.
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "absl/synchronization/mutex.h"
|
|
||||||
#include "tensorflow/c/env.h"
|
#include "tensorflow/c/env.h"
|
||||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||||
|
#include "tensorflow/c/logging.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "third_party/hadoop/hdfs.h"
|
|
||||||
|
|
||||||
// Implementation of a filesystem for HADOOP environments.
|
// Implementation of a filesystem for HADOOP environments.
|
||||||
// This filesystem will support `hdfs://`, `viewfs://` and `har://` URI schemes.
|
// This filesystem will support `hdfs://`, `viewfs://` and `har://` URI schemes.
|
||||||
@ -149,15 +148,20 @@ class LibHDFS {
|
|||||||
char* hdfs_home = getenv("HADOOP_HDFS_HOME");
|
char* hdfs_home = getenv("HADOOP_HDFS_HOME");
|
||||||
if (hdfs_home != nullptr) {
|
if (hdfs_home != nullptr) {
|
||||||
auto JoinPath = [](std::string home, std::string lib) {
|
auto JoinPath = [](std::string home, std::string lib) {
|
||||||
|
#if defined(_WIN32)
|
||||||
|
if (home.back() != '\\') home.push_back('\\');
|
||||||
|
return home + "lib\\native\\" + lib;
|
||||||
|
#else
|
||||||
if (home.back() != '/') home.push_back('/');
|
if (home.back() != '/') home.push_back('/');
|
||||||
return home + "lib/native/" + lib;
|
return home + "lib/native/" + lib;
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
std::string path = JoinPath(hdfs_home, kLibHdfsDso);
|
std::string path = JoinPath(hdfs_home, kLibHdfsDso);
|
||||||
TryLoadAndBind(path.c_str(), &handle_, status);
|
TryLoadAndBind(path.c_str(), &handle_, status);
|
||||||
if (TF_GetCode(status) == TF_OK) {
|
if (TF_GetCode(status) == TF_OK) {
|
||||||
return;
|
return;
|
||||||
} else {
|
} else {
|
||||||
std::cerr << "HadoopFileSystem load error: " << TF_Message(status);
|
TF_Log(TF_FATAL, "HadoopFileSystem load error: %s", TF_Message(status));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -169,13 +173,15 @@ class LibHDFS {
|
|||||||
void* handle_;
|
void* handle_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// We rely on HDFS connection caching here. The HDFS client calls
|
// We implement connection caching in Tensorflow, which can significantly
|
||||||
// org.apache.hadoop.fs.FileSystem.get(), which caches the connection
|
// improve performance. Fixes #43187
|
||||||
// internally.
|
hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file,
|
||||||
hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) {
|
const std::string& path, TF_Status* status) {
|
||||||
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
|
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
|
||||||
|
|
||||||
|
std::string cacheKey(scheme);
|
||||||
hdfsBuilder* builder = libhdfs->hdfsNewBuilder();
|
hdfsBuilder* builder = libhdfs->hdfsNewBuilder();
|
||||||
if (scheme == "file") {
|
if (scheme == "file") {
|
||||||
libhdfs->hdfsBuilderSetNameNode(builder, nullptr);
|
libhdfs->hdfsBuilderSetNameNode(builder, nullptr);
|
||||||
@ -200,15 +206,24 @@ hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) {
|
|||||||
SplitArchiveNameAndPath(&path_har, &namenode, status);
|
SplitArchiveNameAndPath(&path_har, &namenode, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||||
libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str());
|
libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str());
|
||||||
|
cacheKey += namenode;
|
||||||
} else {
|
} else {
|
||||||
libhdfs->hdfsBuilderSetNameNode(
|
libhdfs->hdfsBuilderSetNameNode(
|
||||||
builder, namenode.empty() ? "default" : namenode.c_str());
|
builder, namenode.empty() ? "default" : namenode.c_str());
|
||||||
|
cacheKey += namenode;
|
||||||
}
|
}
|
||||||
auto fs = libhdfs->hdfsBuilderConnect(builder);
|
absl::MutexLock l(&hadoop_file->connection_cache_lock);
|
||||||
if (fs == nullptr)
|
if (hadoop_file->connection_cache.find(cacheKey) ==
|
||||||
TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno));
|
hadoop_file->connection_cache.end()) {
|
||||||
else
|
auto cacheFs = libhdfs->hdfsBuilderConnect(builder);
|
||||||
TF_SetStatus(status, TF_OK, "");
|
if (cacheFs == nullptr) {
|
||||||
|
TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno));
|
||||||
|
return cacheFs;
|
||||||
|
}
|
||||||
|
hadoop_file->connection_cache[cacheKey] = cacheFs;
|
||||||
|
}
|
||||||
|
auto fs = hadoop_file->connection_cache[cacheKey];
|
||||||
|
TF_SetStatus(status, TF_OK, "");
|
||||||
return fs;
|
return fs;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -409,30 +424,36 @@ void Close(const TF_WritableFile* file, TF_Status* status) {
|
|||||||
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
|
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
namespace tf_read_only_memory_region {
|
namespace tf_read_only_memory_region {
|
||||||
|
// Hadoop doesn't support Readonly Memory Region
|
||||||
// TODO(vnvo2409): Implement later
|
|
||||||
|
|
||||||
} // namespace tf_read_only_memory_region
|
} // namespace tf_read_only_memory_region
|
||||||
|
|
||||||
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
|
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
namespace tf_hadoop_filesystem {
|
namespace tf_hadoop_filesystem {
|
||||||
|
|
||||||
|
HadoopFile::HadoopFile(TF_Status* status)
|
||||||
|
: libhdfs(new LibHDFS(status)),
|
||||||
|
connection_cache_lock(),
|
||||||
|
connection_cache() {}
|
||||||
|
|
||||||
void Init(TF_Filesystem* filesystem, TF_Status* status) {
|
void Init(TF_Filesystem* filesystem, TF_Status* status) {
|
||||||
filesystem->plugin_filesystem = new LibHDFS(status);
|
filesystem->plugin_filesystem = new HadoopFile(status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
TF_SetStatus(status, TF_OK, "");
|
TF_SetStatus(status, TF_OK, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cleanup(TF_Filesystem* filesystem) {
|
void Cleanup(TF_Filesystem* filesystem) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
delete libhdfs;
|
delete libhdfs;
|
||||||
|
delete hadoop_file;
|
||||||
}
|
}
|
||||||
|
|
||||||
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
||||||
TF_RandomAccessFile* file, TF_Status* status) {
|
TF_RandomAccessFile* file, TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -448,8 +469,9 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
|||||||
|
|
||||||
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
|
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
|
||||||
TF_WritableFile* file, TF_Status* status) {
|
TF_WritableFile* file, TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -465,8 +487,9 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
|
|||||||
|
|
||||||
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
|
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
|
||||||
TF_WritableFile* file, TF_Status* status) {
|
TF_WritableFile* file, TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -497,8 +520,9 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
|
|||||||
|
|
||||||
void PathExists(const TF_Filesystem* filesystem, const char* path,
|
void PathExists(const TF_Filesystem* filesystem, const char* path,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -513,8 +537,9 @@ void PathExists(const TF_Filesystem* filesystem, const char* path,
|
|||||||
|
|
||||||
void Stat(const TF_Filesystem* filesystem, const char* path,
|
void Stat(const TF_Filesystem* filesystem, const char* path,
|
||||||
TF_FileStatistics* stats, TF_Status* status) {
|
TF_FileStatistics* stats, TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -532,8 +557,9 @@ void Stat(const TF_Filesystem* filesystem, const char* path,
|
|||||||
|
|
||||||
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
|
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return -1;
|
if (TF_GetCode(status) != TF_OK) return -1;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -553,8 +579,9 @@ int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
|
|||||||
|
|
||||||
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
|
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -568,8 +595,9 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path,
|
|||||||
|
|
||||||
void CreateDir(const TF_Filesystem* filesystem, const char* path,
|
void CreateDir(const TF_Filesystem* filesystem, const char* path,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -583,8 +611,9 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,
|
|||||||
|
|
||||||
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
|
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -619,8 +648,9 @@ void DeleteDir(const TF_Filesystem* filesystem, const char* path,
|
|||||||
|
|
||||||
void RenameFile(const TF_Filesystem* filesystem, const char* src,
|
void RenameFile(const TF_Filesystem* filesystem, const char* src,
|
||||||
const char* dst, TF_Status* status) {
|
const char* dst, TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, src, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, src, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path_src, hdfs_path_dst;
|
std::string scheme, namenode, hdfs_path_src, hdfs_path_dst;
|
||||||
@ -640,8 +670,9 @@ void RenameFile(const TF_Filesystem* filesystem, const char* src,
|
|||||||
|
|
||||||
int GetChildren(const TF_Filesystem* filesystem, const char* path,
|
int GetChildren(const TF_Filesystem* filesystem, const char* path,
|
||||||
char*** entries, TF_Status* status) {
|
char*** entries, TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return -1;
|
if (TF_GetCode(status) != TF_OK) return -1;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -677,7 +708,9 @@ int GetChildren(const TF_Filesystem* filesystem, const char* path,
|
|||||||
return num_entries;
|
return num_entries;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(vnvo2409): Implement later
|
static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
|
||||||
|
return strdup(uri);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tf_hadoop_filesystem
|
} // namespace tf_hadoop_filesystem
|
||||||
|
|
||||||
@ -685,6 +718,42 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
|
|||||||
const char* uri) {
|
const char* uri) {
|
||||||
TF_SetFilesystemVersionMetadata(ops);
|
TF_SetFilesystemVersionMetadata(ops);
|
||||||
ops->scheme = strdup(uri);
|
ops->scheme = strdup(uri);
|
||||||
|
|
||||||
|
ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
|
||||||
|
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
|
||||||
|
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
|
||||||
|
ops->random_access_file_ops->read = tf_random_access_file::Read;
|
||||||
|
|
||||||
|
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
|
||||||
|
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
|
||||||
|
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
|
||||||
|
ops->writable_file_ops->append = tf_writable_file::Append;
|
||||||
|
ops->writable_file_ops->tell = tf_writable_file::Tell;
|
||||||
|
ops->writable_file_ops->flush = tf_writable_file::Flush;
|
||||||
|
ops->writable_file_ops->sync = tf_writable_file::Sync;
|
||||||
|
ops->writable_file_ops->close = tf_writable_file::Close;
|
||||||
|
|
||||||
|
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
|
||||||
|
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
|
||||||
|
ops->filesystem_ops->init = tf_hadoop_filesystem::Init;
|
||||||
|
ops->filesystem_ops->cleanup = tf_hadoop_filesystem::Cleanup;
|
||||||
|
ops->filesystem_ops->new_random_access_file =
|
||||||
|
tf_hadoop_filesystem::NewRandomAccessFile;
|
||||||
|
ops->filesystem_ops->new_writable_file =
|
||||||
|
tf_hadoop_filesystem::NewWritableFile;
|
||||||
|
ops->filesystem_ops->new_appendable_file =
|
||||||
|
tf_hadoop_filesystem::NewAppendableFile;
|
||||||
|
ops->filesystem_ops->new_read_only_memory_region_from_file =
|
||||||
|
tf_hadoop_filesystem::NewReadOnlyMemoryRegionFromFile;
|
||||||
|
ops->filesystem_ops->path_exists = tf_hadoop_filesystem::PathExists;
|
||||||
|
ops->filesystem_ops->stat = tf_hadoop_filesystem::Stat;
|
||||||
|
ops->filesystem_ops->get_file_size = tf_hadoop_filesystem::GetFileSize;
|
||||||
|
ops->filesystem_ops->delete_file = tf_hadoop_filesystem::DeleteFile;
|
||||||
|
ops->filesystem_ops->create_dir = tf_hadoop_filesystem::CreateDir;
|
||||||
|
ops->filesystem_ops->delete_dir = tf_hadoop_filesystem::DeleteDir;
|
||||||
|
ops->filesystem_ops->rename_file = tf_hadoop_filesystem::RenameFile;
|
||||||
|
ops->filesystem_ops->get_children = tf_hadoop_filesystem::GetChildren;
|
||||||
|
ops->filesystem_ops->translate_name = tf_hadoop_filesystem::TranslateName;
|
||||||
}
|
}
|
||||||
|
|
||||||
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
|
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
|
||||||
|
@ -15,10 +15,13 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
|
||||||
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
|
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
|
||||||
|
|
||||||
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/synchronization/mutex.h"
|
||||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
#include "third_party/hadoop/hdfs.h"
|
||||||
|
|
||||||
void ParseHadoopPath(const std::string& fname, std::string* scheme,
|
void ParseHadoopPath(const std::string& fname, std::string* scheme,
|
||||||
std::string* namenode, std::string* path);
|
std::string* namenode, std::string* path);
|
||||||
@ -43,6 +46,14 @@ void Close(const TF_WritableFile* file, TF_Status* status);
|
|||||||
} // namespace tf_writable_file
|
} // namespace tf_writable_file
|
||||||
|
|
||||||
namespace tf_hadoop_filesystem {
|
namespace tf_hadoop_filesystem {
|
||||||
|
typedef struct HadoopFile {
|
||||||
|
LibHDFS* libhdfs;
|
||||||
|
absl::Mutex connection_cache_lock;
|
||||||
|
std::map<std::string, hdfsFS> connection_cache
|
||||||
|
ABSL_GUARDED_BY(connection_cache_lock);
|
||||||
|
HadoopFile(TF_Status* status);
|
||||||
|
} HadoopFile;
|
||||||
|
|
||||||
void Init(TF_Filesystem* filesystem, TF_Status* status);
|
void Init(TF_Filesystem* filesystem, TF_Status* status);
|
||||||
void Cleanup(TF_Filesystem* filesystem);
|
void Cleanup(TF_Filesystem* filesystem);
|
||||||
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
||||||
|
@ -352,6 +352,48 @@ TEST_F(HadoopFileSystemTest, WriteWhileReading) {
|
|||||||
EXPECT_TF_OK(status_);
|
EXPECT_TF_OK(status_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(HadoopFileSystemTest, ReadWhileOverwriting) {
|
||||||
|
static char set_disable_var[] = "HDFS_DISABLE_READ_EOF_RETRIED=1";
|
||||||
|
putenv(set_disable_var);
|
||||||
|
|
||||||
|
const std::string path = TmpDir("ReadWhileOverwriting");
|
||||||
|
if (path.find_first_of("hdfs://") != 0) GTEST_SKIP();
|
||||||
|
|
||||||
|
const string content1 = "content1";
|
||||||
|
WriteString(path, content1);
|
||||||
|
ASSERT_TF_OK(status_);
|
||||||
|
|
||||||
|
auto reader = GetReader();
|
||||||
|
tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
|
||||||
|
reader.get(), status_);
|
||||||
|
EXPECT_TF_OK(status_);
|
||||||
|
|
||||||
|
std::string result;
|
||||||
|
result.resize(content1.size());
|
||||||
|
auto read = tf_random_access_file::Read(reader.get(), 0, content1.size(),
|
||||||
|
&result[0], status_);
|
||||||
|
result.resize(read);
|
||||||
|
EXPECT_TF_OK(status_);
|
||||||
|
EXPECT_EQ(content1, result);
|
||||||
|
|
||||||
|
tf_hadoop_filesystem::DeleteFile(filesystem_, path.c_str(), status_);
|
||||||
|
EXPECT_TF_OK(status_);
|
||||||
|
|
||||||
|
string content2 = "overwrite";
|
||||||
|
WriteString(path, content1 + content2);
|
||||||
|
ASSERT_TF_OK(status_);
|
||||||
|
|
||||||
|
result.resize(content2.size());
|
||||||
|
read = tf_random_access_file::Read(reader.get(), content1.size(),
|
||||||
|
content2.size(), &result[0], status_);
|
||||||
|
result.resize(read);
|
||||||
|
EXPECT_TF_OK(status_);
|
||||||
|
EXPECT_EQ(0, result.size());
|
||||||
|
|
||||||
|
static char set_enable_var[] = "HDFS_DISABLE_READ_EOF_RETRIED=0";
|
||||||
|
putenv(set_enable_var);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(HadoopFileSystemTest, HarSplit) {
|
TEST_F(HadoopFileSystemTest, HarSplit) {
|
||||||
const std::string har_path =
|
const std::string har_path =
|
||||||
"har://hdfs-root/user/j.doe/my_archive.har/dir0/dir1/file.txt";
|
"har://hdfs-root/user/j.doe/my_archive.har/dir0/dir1/file.txt";
|
||||||
|
@ -24,6 +24,7 @@ using std::vector;
|
|||||||
using tensorflow::ops::Conj;
|
using tensorflow::ops::Conj;
|
||||||
using tensorflow::ops::MatMul;
|
using tensorflow::ops::MatMul;
|
||||||
using tensorflow::ops::Mul;
|
using tensorflow::ops::Mul;
|
||||||
|
using tensorflow::ops::SqrtGrad;
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace gradients {
|
namespace gradients {
|
||||||
@ -72,6 +73,25 @@ class ExpGradientFunction : public GradientFunction {
|
|||||||
AbstractTensorHandlePtr exp_;
|
AbstractTensorHandlePtr exp_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class SqrtGradientFunction : public GradientFunction {
|
||||||
|
public:
|
||||||
|
explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) {
|
||||||
|
sqrt->Ref();
|
||||||
|
}
|
||||||
|
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||||
|
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||||
|
std::string name = "Sqrt_Grad";
|
||||||
|
grad_outputs->resize(1);
|
||||||
|
TF_RETURN_IF_ERROR(SqrtGrad(ctx->ctx, {sqrt_.get(), grad_inputs[0]},
|
||||||
|
absl::MakeSpan(*grad_outputs), name.c_str()));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
~SqrtGradientFunction() override {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
AbstractTensorHandlePtr sqrt_;
|
||||||
|
};
|
||||||
|
|
||||||
class MatMulGradientFunction : public GradientFunction {
|
class MatMulGradientFunction : public GradientFunction {
|
||||||
public:
|
public:
|
||||||
explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
|
explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
|
||||||
@ -210,5 +230,14 @@ BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
|
|||||||
return new BackwardFunction(gradient_function, default_gradients);
|
return new BackwardFunction(gradient_function, default_gradients);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BackwardFunction* SqrtRegisterer(const ForwardOperation& op) {
|
||||||
|
auto gradient_function = new SqrtGradientFunction(op.outputs[0]);
|
||||||
|
// For ops with a single output, the gradient function is not called if there
|
||||||
|
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||||
|
// grads in this case.
|
||||||
|
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||||
|
return new BackwardFunction(gradient_function, default_gradients);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gradients
|
} // namespace gradients
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -19,10 +19,13 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace gradients {
|
namespace gradients {
|
||||||
|
|
||||||
BackwardFunction* AddRegisterer(const ForwardOperation& op);
|
BackwardFunction* AddRegisterer(const ForwardOperation& op);
|
||||||
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
|
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
|
||||||
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
|
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
|
||||||
|
BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
|
||||||
|
|
||||||
} // namespace gradients
|
} // namespace gradients
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
|
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
|
||||||
|
@ -38,3 +38,29 @@ cc_library(
|
|||||||
"//tensorflow/c/eager:gradients_internal",
|
"//tensorflow/c/eager:gradients_internal",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tape",
|
||||||
|
hdrs = [
|
||||||
|
"tape_context.h",
|
||||||
|
"tape_operation.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":tape_context",
|
||||||
|
":tape_operation",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "pywrap_required_hdrs",
|
||||||
|
srcs = [
|
||||||
|
"tape_context.h",
|
||||||
|
"tape_operation.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -144,5 +144,33 @@ Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
|||||||
return exp_op->Execute(outputs, &num_retvals);
|
return exp_op->Execute(outputs, &num_retvals);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status Sqrt(AbstractContext* ctx,
|
||||||
|
absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
|
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||||
|
AbstractOperationPtr sqrt_op(ctx->CreateOperation());
|
||||||
|
TF_RETURN_IF_ERROR(sqrt_op->Reset("Sqrt", /*raw_device_name=*/nullptr));
|
||||||
|
TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_op.get(), name));
|
||||||
|
TF_RETURN_IF_ERROR(sqrt_op->AddInput(inputs[0]));
|
||||||
|
|
||||||
|
int num_retvals = 1;
|
||||||
|
Status s = sqrt_op->Execute(outputs, &num_retvals);
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status SqrtGrad(AbstractContext* ctx,
|
||||||
|
absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
|
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||||
|
AbstractOperationPtr sqrt_grad_op(ctx->CreateOperation());
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
sqrt_grad_op->Reset("SqrtGrad", /*raw_device_name=*/nullptr));
|
||||||
|
TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_grad_op.get(), name));
|
||||||
|
TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[0]));
|
||||||
|
TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[1]));
|
||||||
|
|
||||||
|
int num_retvals = 1;
|
||||||
|
Status s = sqrt_grad_op->Execute(outputs, &num_retvals);
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -50,6 +50,15 @@ Status DivNoNan(AbstractContext* ctx,
|
|||||||
|
|
||||||
Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||||
|
|
||||||
|
Status Sqrt(AbstractContext* ctx,
|
||||||
|
absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
|
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||||
|
|
||||||
|
Status SqrtGrad(AbstractContext* ctx,
|
||||||
|
absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
|
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -91,15 +91,24 @@ cc_library(
|
|||||||
":signature_def_function_metadata",
|
":signature_def_function_metadata",
|
||||||
"//tensorflow/c/eager:immediate_execution_operation",
|
"//tensorflow/c/eager:immediate_execution_operation",
|
||||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "signature_def_function_metadata",
|
name = "signature_def_function_metadata",
|
||||||
|
srcs = [
|
||||||
|
"signature_def_function_metadata.cc",
|
||||||
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"signature_def_function_metadata.h",
|
"signature_def_function_metadata.h",
|
||||||
],
|
],
|
||||||
|
deps = [
|
||||||
|
":tensor_spec",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
@ -268,6 +277,20 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tensor_spec",
|
||||||
|
srcs = [
|
||||||
|
"tensor_spec.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"tensor_spec.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "tf_concrete_function_loading_test",
|
name = "tf_concrete_function_loading_test",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
@ -92,6 +92,8 @@ cc_library(
|
|||||||
"//tensorflow/c/eager:immediate_execution_context",
|
"//tensorflow/c/eager:immediate_execution_context",
|
||||||
"//tensorflow/c/eager:immediate_execution_operation",
|
"//tensorflow/c/eager:immediate_execution_operation",
|
||||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core:tensor_spec",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/lib/llvm_rtti",
|
"//tensorflow/core/lib/llvm_rtti",
|
||||||
@ -164,6 +166,8 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/common_runtime/eager:context",
|
"//tensorflow/core/common_runtime/eager:context",
|
||||||
|
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||||
|
"//tensorflow/core/lib/llvm_rtti",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -15,7 +15,9 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
@ -30,14 +32,26 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||||
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/platform/stringpiece.h"
|
||||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||||
|
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using StructuredValueDictEntry =
|
||||||
|
protobuf::MapPair<std::string, StructuredValue>;
|
||||||
|
|
||||||
|
using NamedParamMap =
|
||||||
|
gtl::FlatMap<StringPiece, const TensorSpecProto*, StringPieceHasher>;
|
||||||
|
|
||||||
Status AssertAllCreateResourceFunctionsHaveNoCaptures(
|
Status AssertAllCreateResourceFunctionsHaveNoCaptures(
|
||||||
const PartiallyRevivedObjects& objects) {
|
const PartiallyRevivedObjects& objects) {
|
||||||
for (const auto& id_and_resource : objects.restored_resources) {
|
for (const auto& id_and_resource : objects.restored_resources) {
|
||||||
@ -124,6 +138,142 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<SignatureDefParam> SignatureDefParamsFromNamedParamMap(
|
||||||
|
const NamedParamMap& params) {
|
||||||
|
// The underlying functiondef associated with the SignatureDef has
|
||||||
|
// nest.flattened inputs and outputs, which are sorted by string key.
|
||||||
|
std::vector<SignatureDefParam> result;
|
||||||
|
result.reserve(params.size());
|
||||||
|
for (const auto& named_param : params) {
|
||||||
|
result.push_back(SignatureDefParam(std::string(named_param.first),
|
||||||
|
TensorSpec(*named_param.second)));
|
||||||
|
}
|
||||||
|
std::sort(result.begin(), result.end(),
|
||||||
|
[](const SignatureDefParam& x, const SignatureDefParam& y) {
|
||||||
|
return x.name() < y.name();
|
||||||
|
});
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignatureDefArgsFromInputs takes the "canonicalized_input_signature"
|
||||||
|
// field of a SavedConcreteFunction, ensures it conforms to the structure of
|
||||||
|
// tuple(tuple(), dict<string,TensorSpec>()), and "returns" a list of
|
||||||
|
// SignatureDefParams of the SignatureDefFunction's arguments.
|
||||||
|
Status SignatureDefArgsFromInputs(
|
||||||
|
const StructuredValue& canonicalized_input_signature,
|
||||||
|
std::vector<SignatureDefParam>* out) {
|
||||||
|
// Note(bmzhao): canonicalized_input_signature should be a tuple of
|
||||||
|
// (args, kwargs), where args is an empty tuple, and kwargs is a dictionary of
|
||||||
|
// string keys to TensorSpecs.
|
||||||
|
if (!canonicalized_input_signature.has_tuple_value()) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"SignatureDefFunction's canonicalized_input_signature should be "
|
||||||
|
"of form tuple(tuple(), dict()), but was instead: \n",
|
||||||
|
canonicalized_input_signature.DebugString());
|
||||||
|
}
|
||||||
|
|
||||||
|
const TupleValue& args_kwargs_tuple =
|
||||||
|
canonicalized_input_signature.tuple_value();
|
||||||
|
if (args_kwargs_tuple.values_size() != 2) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"SignatureDefFunction's canonicalized_input_signature should be "
|
||||||
|
"a tuple of two elements (args, kwargs), but was instead: \n",
|
||||||
|
args_kwargs_tuple.DebugString());
|
||||||
|
}
|
||||||
|
|
||||||
|
const StructuredValue& args = args_kwargs_tuple.values(0);
|
||||||
|
if (!args.has_tuple_value() || !args.tuple_value().values().empty()) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"SignatureDefFunction's canonicalized_input_signature's args"
|
||||||
|
"should be an empty tuple, but instead got: \n",
|
||||||
|
args.DebugString());
|
||||||
|
}
|
||||||
|
|
||||||
|
const StructuredValue& kwargs = args_kwargs_tuple.values(1);
|
||||||
|
if (!kwargs.has_dict_value()) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"SignatureDefFunction's canonicalized_input_signature's kwargs"
|
||||||
|
"should be a dictionary, but instead got: \n",
|
||||||
|
kwargs.DebugString());
|
||||||
|
}
|
||||||
|
|
||||||
|
const DictValue& kwargs_dict = kwargs.dict_value();
|
||||||
|
NamedParamMap result;
|
||||||
|
result.reserve(kwargs_dict.fields_size());
|
||||||
|
|
||||||
|
for (const auto& key_value : kwargs_dict.fields()) {
|
||||||
|
const std::string& key = key_value.first;
|
||||||
|
const StructuredValue& value = key_value.second;
|
||||||
|
if (!value.has_tensor_spec_value()) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"SignatureDefFunction's canonicalized_input_signature's kwargs"
|
||||||
|
"dictionary contained a non-tensorspec value for key-value pair: \n",
|
||||||
|
"Key: ", key, "Value: \n", value.DebugString());
|
||||||
|
}
|
||||||
|
result[key] = &value.tensor_spec_value();
|
||||||
|
}
|
||||||
|
|
||||||
|
*out = SignatureDefParamsFromNamedParamMap(result);
|
||||||
|
|
||||||
|
return Status();
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignatureDefReturnsFromOutputs takes the "output_signature" field of a
|
||||||
|
// SavedConcreteFunction, ensures it conforms to the structure of
|
||||||
|
// dict<string,TensorSpec>(), and "returns" a list of SignatureDefParams of the
|
||||||
|
// SignatureDefFunction's returns.
|
||||||
|
Status SignatureDefReturnsFromOutputs(const StructuredValue& output_signature,
|
||||||
|
std::vector<SignatureDefParam>* out) {
|
||||||
|
if (!output_signature.has_dict_value()) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"SignatureDefFunction's output_signature must be a dictionary, but "
|
||||||
|
"instead got: ",
|
||||||
|
output_signature.DebugString());
|
||||||
|
}
|
||||||
|
|
||||||
|
const DictValue& output_dict = output_signature.dict_value();
|
||||||
|
NamedParamMap result;
|
||||||
|
result.reserve(output_dict.fields_size());
|
||||||
|
|
||||||
|
for (const auto& key_value : output_dict.fields()) {
|
||||||
|
const std::string& key = key_value.first;
|
||||||
|
const StructuredValue& value = key_value.second;
|
||||||
|
if (!value.has_tensor_spec_value()) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"SignatureDefFunction's output_signature dictionary contained a "
|
||||||
|
"non-tensorspec value for key-value pair: \n",
|
||||||
|
"Key: ", key, "Value: \n", value.DebugString());
|
||||||
|
}
|
||||||
|
result[key] = &value.tensor_spec_value();
|
||||||
|
}
|
||||||
|
*out = SignatureDefParamsFromNamedParamMap(result);
|
||||||
|
|
||||||
|
return Status();
|
||||||
|
}
|
||||||
|
|
||||||
|
// The implementation takes advantage of the fact that SignatureDefFunction's
|
||||||
|
// "traced" Signature wrapper function always has inputs/outputs of dictionaries
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L119-L126
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L153-L178
|
||||||
|
// Additionally, we take advantage of the fact that the SignatureDefFunction's
|
||||||
|
// associated functiondef has lexicographically ordered inputs/outputs due to
|
||||||
|
// nest.flatten.
|
||||||
|
Status LoadSignatureDefFunctionMetadata(
|
||||||
|
const SavedConcreteFunction& saved_concrete_function,
|
||||||
|
SignatureDefFunctionMetadata* out) {
|
||||||
|
std::vector<SignatureDefParam> args;
|
||||||
|
TF_RETURN_IF_ERROR(SignatureDefArgsFromInputs(
|
||||||
|
saved_concrete_function.canonicalized_input_signature(), &args));
|
||||||
|
|
||||||
|
std::vector<SignatureDefParam> rets;
|
||||||
|
TF_RETURN_IF_ERROR(SignatureDefReturnsFromOutputs(
|
||||||
|
saved_concrete_function.output_signature(), &rets));
|
||||||
|
|
||||||
|
*out = SignatureDefFunctionMetadata(std::move(args), std::move(rets));
|
||||||
|
return Status();
|
||||||
|
}
|
||||||
|
|
||||||
// This function finds the necessary captures, then forwards to the builder
|
// This function finds the necessary captures, then forwards to the builder
|
||||||
// method
|
// method
|
||||||
Status CreateConcreteFunction(ImmediateExecutionContext* ctx,
|
Status CreateConcreteFunction(ImmediateExecutionContext* ctx,
|
||||||
@ -162,10 +312,14 @@ Status CreateSignatureDefFunction(
|
|||||||
&capture_handle));
|
&capture_handle));
|
||||||
captures.push_back(capture_handle);
|
captures.push_back(capture_handle);
|
||||||
}
|
}
|
||||||
// TODO(bmzhao): Create Metadata here
|
|
||||||
|
SignatureDefFunctionMetadata metadata;
|
||||||
|
TF_RETURN_IF_ERROR(LoadSignatureDefFunctionMetadata(
|
||||||
|
*builder.saved_concrete_func, &metadata));
|
||||||
|
|
||||||
return TFSignatureDefFunction::Create(/*function_def=*/builder.fdef,
|
return TFSignatureDefFunction::Create(/*function_def=*/builder.fdef,
|
||||||
/*captures=*/std::move(captures),
|
/*captures=*/std::move(captures),
|
||||||
/*metadata=*/{},
|
/*metadata=*/std::move(metadata),
|
||||||
/*ctx=*/ctx,
|
/*ctx=*/ctx,
|
||||||
/*out=*/out);
|
/*out=*/out);
|
||||||
}
|
}
|
||||||
@ -378,6 +532,7 @@ Status PartiallyRevivedObjects::Build(ImmediateExecutionContext* ctx,
|
|||||||
revived->variables = std::move(variables);
|
revived->variables = std::move(variables);
|
||||||
revived->assets = std::move(assets);
|
revived->assets = std::move(assets);
|
||||||
revived->constants = std::move(constants);
|
revived->constants = std::move(constants);
|
||||||
|
revived->signatures_map = std::move(signatures_map);
|
||||||
|
|
||||||
// 3b. Move over resources.
|
// 3b. Move over resources.
|
||||||
TF_RETURN_IF_ERROR(BuildResources(ctx, obj_graph, this, revived));
|
TF_RETURN_IF_ERROR(BuildResources(ctx, obj_graph, this, revived));
|
||||||
|
@ -36,7 +36,14 @@ namespace tensorflow {
|
|||||||
// Notably, resources and functions can be in a state where they reference
|
// Notably, resources and functions can be in a state where they reference
|
||||||
// other resources/functions that have not been constructed yet. We collect
|
// other resources/functions that have not been constructed yet. We collect
|
||||||
// *all* objects in a partially valid state here, then properly initialize
|
// *all* objects in a partially valid state here, then properly initialize
|
||||||
// resources and functions.
|
// resources and functions. Implementation-wise, PartiallyRevivedObjects
|
||||||
|
// contains maps keyed by the node number of the SavedObjectGraph, and map to an
|
||||||
|
// object of the corresponding type. So, if node 2 in the object graph is a
|
||||||
|
// variable, PartiallyRevivedObjects.variables[2] exists, and corresponds to a
|
||||||
|
// tensorflow::Variable object. The only exception to this is the
|
||||||
|
// "signatures_map", which is keyed by the "signature" key
|
||||||
|
// (https://github.com/tensorflow/tensorflow/blob/372918decee7f558b3c194b04f77c20dcc679a31/tensorflow/core/protobuf/meta_graph.proto#L89),
|
||||||
|
// and maps to the SignatureDefFunction node in the SavedObjectGraph.
|
||||||
struct PartiallyRevivedObjects {
|
struct PartiallyRevivedObjects {
|
||||||
gtl::FlatMap<int, std::unique_ptr<Variable>> variables;
|
gtl::FlatMap<int, std::unique_ptr<Variable>> variables;
|
||||||
gtl::FlatMap<int, std::unique_ptr<Asset>> assets;
|
gtl::FlatMap<int, std::unique_ptr<Asset>> assets;
|
||||||
@ -44,6 +51,7 @@ struct PartiallyRevivedObjects {
|
|||||||
gtl::FlatMap<int, TFConcreteFunctionRevivalState> concrete_functions;
|
gtl::FlatMap<int, TFConcreteFunctionRevivalState> concrete_functions;
|
||||||
gtl::FlatMap<int, TFSignatureDefFunctionRevivalState> signature_def_functions;
|
gtl::FlatMap<int, TFSignatureDefFunctionRevivalState> signature_def_functions;
|
||||||
gtl::FlatMap<int, RestoredResourceRevivalState> restored_resources;
|
gtl::FlatMap<int, RestoredResourceRevivalState> restored_resources;
|
||||||
|
gtl::FlatMap<std::string, int> signatures_map;
|
||||||
|
|
||||||
Status Build(ImmediateExecutionContext* ctx,
|
Status Build(ImmediateExecutionContext* ctx,
|
||||||
const SavedObjectGraph& obj_graph, RevivedObjects* revived);
|
const SavedObjectGraph& obj_graph, RevivedObjects* revived);
|
||||||
|
@ -44,6 +44,7 @@ struct RevivedObjects {
|
|||||||
gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>>
|
gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>>
|
||||||
signature_def_functions;
|
signature_def_functions;
|
||||||
gtl::FlatMap<int, RestoredResource> restored_resources;
|
gtl::FlatMap<int, RestoredResource> restored_resources;
|
||||||
|
gtl::FlatMap<std::string, int> signatures_map;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -20,8 +20,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
|
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
@ -62,15 +64,53 @@ Status Variable::ReadValue(ImmediateTensorHandlePtr* out) {
|
|||||||
return internal::ReadVariable(ctx_, handle_.get(), dtype_, out);
|
return internal::ReadVariable(ctx_, handle_.get(), dtype_, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx,
|
Status Variable::CreateUninitialized(
|
||||||
DataType dtype, TensorShape shape,
|
ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
|
||||||
absl::optional<std::string> name,
|
absl::optional<std::string> name, const char* raw_device_name,
|
||||||
const char* raw_device_name,
|
const std::vector<std::string>& component_devices,
|
||||||
std::unique_ptr<Variable>* output) {
|
std::unique_ptr<Variable>* output) {
|
||||||
ImmediateTensorHandlePtr handle;
|
ImmediateTensorHandlePtr handle;
|
||||||
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
|
|
||||||
ctx, dtype, shape, raw_device_name, &handle));
|
|
||||||
|
|
||||||
|
if (component_devices.empty()) {
|
||||||
|
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
|
||||||
|
ctx, dtype, shape, raw_device_name, &handle));
|
||||||
|
output->reset(
|
||||||
|
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
|
||||||
|
return Status();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!tensorflow::isa<EagerContext>(ctx)) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Can only load distributed variables with EagerContext.");
|
||||||
|
}
|
||||||
|
|
||||||
|
EagerContext* eager_ctx = reinterpret_cast<EagerContext*>(ctx);
|
||||||
|
|
||||||
|
std::vector<TensorHandle*> handles;
|
||||||
|
for (const auto& device : component_devices) {
|
||||||
|
ImmediateTensorHandlePtr handlePtr;
|
||||||
|
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
|
||||||
|
ctx, dtype, shape, device.empty() ? nullptr : device.c_str(),
|
||||||
|
&handlePtr));
|
||||||
|
if (!tensorflow::isa<TensorHandle>(handlePtr.get())) {
|
||||||
|
return errors::Internal("Returned replica handle has unsupported type.");
|
||||||
|
}
|
||||||
|
handles.push_back(reinterpret_cast<TensorHandle*>(handlePtr.release()));
|
||||||
|
}
|
||||||
|
TensorHandle* packed_handle;
|
||||||
|
TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
|
||||||
|
std::move(handles), eager_ctx, &packed_handle));
|
||||||
|
// The call to `CreatePackedHandle` incremented the handles' reference count,
|
||||||
|
// which we must now decrement to make the packed handle the owner of those
|
||||||
|
// handles. We can't loop through the `handles` vector because it was
|
||||||
|
// `std::move`d in the call above.
|
||||||
|
for (int i = 0; i != packed_handle->NumPackedHandles(); ++i) {
|
||||||
|
TensorHandle* component;
|
||||||
|
TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &component));
|
||||||
|
component->Unref();
|
||||||
|
}
|
||||||
|
|
||||||
|
handle.reset(packed_handle);
|
||||||
output->reset(
|
output->reset(
|
||||||
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
|
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
|
||||||
return Status();
|
return Status();
|
||||||
|
@ -34,11 +34,11 @@ class Variable : public TensorHandleConvertible {
|
|||||||
public:
|
public:
|
||||||
// Creates an uninitialized resource variable. Note that a caller must
|
// Creates an uninitialized resource variable. Note that a caller must
|
||||||
// call "assign" to associate a value with the variable.
|
// call "assign" to associate a value with the variable.
|
||||||
static Status CreateUninitialized(ImmediateExecutionContext* ctx,
|
static Status CreateUninitialized(
|
||||||
DataType dtype, TensorShape shape,
|
ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
|
||||||
absl::optional<std::string> name,
|
absl::optional<std::string> name, const char* raw_device_name,
|
||||||
const char* raw_device_name,
|
const std::vector<std::string>& component_devices,
|
||||||
std::unique_ptr<Variable>* output);
|
std::unique_ptr<Variable>* output);
|
||||||
|
|
||||||
// The dtype of the underlying variable.
|
// The dtype of the underlying variable.
|
||||||
DataType dtype();
|
DataType dtype();
|
||||||
|
@ -235,10 +235,17 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx,
|
|||||||
const std::string& name = variable.name();
|
const std::string& name = variable.name();
|
||||||
tensorflow::TensorShape shape(variable.shape());
|
tensorflow::TensorShape shape(variable.shape());
|
||||||
tensorflow::DataType dtype = variable.dtype();
|
tensorflow::DataType dtype = variable.dtype();
|
||||||
|
std::vector<std::string> component_devices;
|
||||||
|
|
||||||
|
for (const auto& component :
|
||||||
|
variable.experimental_distributed_variable_components()) {
|
||||||
|
component_devices.push_back(component.device());
|
||||||
|
}
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(Variable::CreateUninitialized(
|
TF_RETURN_IF_ERROR(Variable::CreateUninitialized(
|
||||||
ctx, dtype, shape, name,
|
ctx, dtype, shape, name,
|
||||||
variable.device().empty() ? nullptr : variable.device().c_str(), output));
|
variable.device().empty() ? nullptr : variable.device().c_str(),
|
||||||
|
component_devices, output));
|
||||||
return Status();
|
return Status();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -519,6 +526,8 @@ Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
objects->signatures_map = std::move(signatures_map);
|
||||||
|
|
||||||
return Status();
|
return Status();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,7 +119,7 @@ TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) {
|
|||||||
Status status;
|
Status status;
|
||||||
std::unique_ptr<Variable> var;
|
std::unique_ptr<Variable> var;
|
||||||
TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
|
TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
|
||||||
absl::nullopt, nullptr, &var));
|
absl::nullopt, nullptr, {}, &var));
|
||||||
|
|
||||||
// Create a TensorHandle
|
// Create a TensorHandle
|
||||||
ImmediateTensorHandlePtr expected_handle =
|
ImmediateTensorHandlePtr expected_handle =
|
||||||
|
@ -0,0 +1,42 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
SignatureDefParam::SignatureDefParam(std::string name, TensorSpec spec)
|
||||||
|
: name_(std::move(name)), spec_(std::move(spec)) {}
|
||||||
|
|
||||||
|
const std::string& SignatureDefParam::name() const { return name_; }
|
||||||
|
|
||||||
|
const TensorSpec& SignatureDefParam::spec() const { return spec_; }
|
||||||
|
|
||||||
|
SignatureDefFunctionMetadata::SignatureDefFunctionMetadata(
|
||||||
|
std::vector<SignatureDefParam> arguments,
|
||||||
|
std::vector<SignatureDefParam> returns)
|
||||||
|
: arguments_(std::move(arguments)), returns_(std::move(returns)) {}
|
||||||
|
|
||||||
|
const std::vector<SignatureDefParam>& SignatureDefFunctionMetadata::arguments()
|
||||||
|
const {
|
||||||
|
return arguments_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<SignatureDefParam>& SignatureDefFunctionMetadata::returns()
|
||||||
|
const {
|
||||||
|
return returns_;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -16,10 +16,42 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
||||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
|
||||||
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// SignatureDefParam represents a named Tensor input or output to a
|
||||||
|
// SignatureDefFunction.
|
||||||
|
class SignatureDefParam {
|
||||||
|
public:
|
||||||
|
SignatureDefParam(std::string name, TensorSpec spec);
|
||||||
|
|
||||||
|
const std::string& name() const;
|
||||||
|
|
||||||
|
const TensorSpec& spec() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string name_;
|
||||||
|
TensorSpec spec_;
|
||||||
|
};
|
||||||
|
|
||||||
class SignatureDefFunctionMetadata {
|
class SignatureDefFunctionMetadata {
|
||||||
// TODO(bmzhao): Fill in with fields as necessary
|
public:
|
||||||
|
SignatureDefFunctionMetadata() = default;
|
||||||
|
SignatureDefFunctionMetadata(std::vector<SignatureDefParam> arguments,
|
||||||
|
std::vector<SignatureDefParam> returns);
|
||||||
|
|
||||||
|
const std::vector<SignatureDefParam>& arguments() const;
|
||||||
|
const std::vector<SignatureDefParam>& returns() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<SignatureDefParam> arguments_;
|
||||||
|
std::vector<SignatureDefParam> returns_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
38
tensorflow/c/experimental/saved_model/core/tensor_spec.cc
Normal file
38
tensorflow/c/experimental/saved_model/core/tensor_spec.cc
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
|
||||||
|
|
||||||
|
#include <initializer_list>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
TensorSpec::TensorSpec()
|
||||||
|
: shape_(std::initializer_list<int64>()), dtype_(DT_FLOAT) {}
|
||||||
|
|
||||||
|
TensorSpec::TensorSpec(PartialTensorShape shape, DataType dtype)
|
||||||
|
: shape_(std::move(shape)), dtype_(dtype) {}
|
||||||
|
|
||||||
|
TensorSpec::TensorSpec(const TensorSpecProto& proto)
|
||||||
|
: shape_(proto.shape()), dtype_(proto.dtype()) {}
|
||||||
|
|
||||||
|
const PartialTensorShape& TensorSpec::shape() const { return shape_; }
|
||||||
|
|
||||||
|
DataType TensorSpec::dtype() const { return dtype_; }
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
51
tensorflow/c/experimental/saved_model/core/tensor_spec.h
Normal file
51
tensorflow/c/experimental/saved_model/core/tensor_spec.h
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Note(bmzhao): TensorSpec deliberately does not store the "name" from a
|
||||||
|
// TensorSpecProto. From edloper@, "Names should really be associated with
|
||||||
|
// parameters, not the tensors inside those parameters. This would be
|
||||||
|
// inconsistent with the corresponding Python class, but I don't think that's
|
||||||
|
// necessarily a problem. If it turns out later that we really need a name
|
||||||
|
// attribute here, we can always add it back in; but let's see how far we can
|
||||||
|
// get without it."
|
||||||
|
class TensorSpec {
|
||||||
|
public:
|
||||||
|
// Constructs a scalar, DT_FLOAT TensorSpec
|
||||||
|
TensorSpec();
|
||||||
|
|
||||||
|
TensorSpec(PartialTensorShape shape, DataType dtype);
|
||||||
|
|
||||||
|
explicit TensorSpec(const TensorSpecProto& proto);
|
||||||
|
|
||||||
|
const PartialTensorShape& shape() const;
|
||||||
|
DataType dtype() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
PartialTensorShape shape_;
|
||||||
|
DataType dtype_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
|
@ -192,9 +192,23 @@ Status TFSavedModelAPI::GetFunction(const std::string& function_path,
|
|||||||
|
|
||||||
Status TFSavedModelAPI::GetSignatureDefFunction(
|
Status TFSavedModelAPI::GetSignatureDefFunction(
|
||||||
const std::string& signature_def_key, SignatureDefFunction** function) {
|
const std::string& signature_def_key, SignatureDefFunction** function) {
|
||||||
// TODO(bmzhao): Add support for retrieving a signaturedef function.
|
auto signatures_iter =
|
||||||
return errors::Unimplemented(
|
revived_objects_.signatures_map.find(signature_def_key);
|
||||||
"Retrieving SignatureDef functions is unimplemented currently");
|
if (signatures_iter == revived_objects_.signatures_map.end()) {
|
||||||
|
return errors::NotFound("No signature with key ", signature_def_key,
|
||||||
|
" was found");
|
||||||
|
}
|
||||||
|
int node = signatures_iter->second;
|
||||||
|
|
||||||
|
auto function_iter = revived_objects_.signature_def_functions.find(node);
|
||||||
|
if (function_iter == revived_objects_.signature_def_functions.end()) {
|
||||||
|
return errors::Internal(
|
||||||
|
"Unable to find SignatureDefFunction associated with key ",
|
||||||
|
signature_def_key, " despite key being valid.");
|
||||||
|
}
|
||||||
|
|
||||||
|
*function = function_iter->second.get();
|
||||||
|
return Status();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
|
std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
|
||||||
|
@ -224,6 +224,8 @@ cc_library(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":signature_def_function_metadata_type",
|
":signature_def_function_metadata_type",
|
||||||
|
":signature_def_param_list",
|
||||||
|
":signature_def_param_list_type",
|
||||||
"//tensorflow/c:c_api_macros",
|
"//tensorflow/c:c_api_macros",
|
||||||
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||||
],
|
],
|
||||||
@ -240,6 +242,104 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "signature_def_param",
|
||||||
|
srcs = [
|
||||||
|
"signature_def_param.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:signature_def_param.h",
|
||||||
|
],
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":signature_def_param_type",
|
||||||
|
":tensor_spec",
|
||||||
|
":tensor_spec_type",
|
||||||
|
"//tensorflow/c:c_api_macros",
|
||||||
|
"//tensorflow/c:tf_shape_internal",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "signature_def_param_type",
|
||||||
|
hdrs = [
|
||||||
|
"signature_def_param_type.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/c:conversion_macros",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "signature_def_param_list",
|
||||||
|
srcs = [
|
||||||
|
"signature_def_param_list.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:signature_def_param_list.h",
|
||||||
|
],
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":signature_def_param",
|
||||||
|
":signature_def_param_list_type",
|
||||||
|
":signature_def_param_type",
|
||||||
|
"//tensorflow/c:c_api_macros",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "signature_def_param_list_type",
|
||||||
|
hdrs = [
|
||||||
|
"signature_def_param_list_type.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/c:conversion_macros",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tensor_spec",
|
||||||
|
srcs = [
|
||||||
|
"tensor_spec.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:tensor_spec.h",
|
||||||
|
],
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":tensor_spec_type",
|
||||||
|
"//tensorflow/c:c_api_macros",
|
||||||
|
"//tensorflow/c:tf_datatype",
|
||||||
|
"//tensorflow/c:tf_shape",
|
||||||
|
"//tensorflow/c:tf_shape_internal",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core:tensor_spec",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tensor_spec_type",
|
||||||
|
hdrs = [
|
||||||
|
"tensor_spec_type.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/c:conversion_macros",
|
||||||
|
"//tensorflow/c:tf_shape_internal",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core:tensor_spec",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "saved_model_api_test",
|
name = "saved_model_api_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
@ -252,6 +352,8 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":saved_model_api_type",
|
":saved_model_api_type",
|
||||||
|
"//tensorflow/c:tf_datatype",
|
||||||
|
"//tensorflow/c:tf_shape",
|
||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
"//tensorflow/c:tf_tensor",
|
"//tensorflow/c:tf_tensor",
|
||||||
"//tensorflow/c/eager:c_api",
|
"//tensorflow/c/eager:c_api",
|
||||||
@ -260,6 +362,11 @@ tf_cc_test(
|
|||||||
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_api",
|
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_api",
|
||||||
"//tensorflow/c/experimental/saved_model/public:concrete_function",
|
"//tensorflow/c/experimental/saved_model/public:concrete_function",
|
||||||
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
|
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:signature_def_function",
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:signature_def_function_metadata",
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:signature_def_param",
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:signature_def_param_list",
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:tensor_spec",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
@ -24,6 +24,13 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
|
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
|
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
|
||||||
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
|
#include "tensorflow/c/tf_shape.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/c/tf_tensor.h"
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
@ -143,6 +150,146 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) {
|
|||||||
TFE_DeleteContext(ctx);
|
TFE_DeleteContext(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This tests running the "serving_default" SignatureDefFunction from the
|
||||||
|
// VarsAndArithmeticObjectGraph savedmodel. Here's what the signature_defs
|
||||||
|
// protobuf in the metagraph looks like:
|
||||||
|
// signature_def: {
|
||||||
|
// key : "serving_default"
|
||||||
|
// value: {
|
||||||
|
// inputs: {
|
||||||
|
// key : "a"
|
||||||
|
// value: {
|
||||||
|
// name : "serving_default_a:0"
|
||||||
|
// dtype: DT_FLOAT
|
||||||
|
// tensor_shape: {
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// inputs: {
|
||||||
|
// key : "b"
|
||||||
|
// value: {
|
||||||
|
// name : "serving_default_b:0"
|
||||||
|
// dtype: DT_FLOAT
|
||||||
|
// tensor_shape: {
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// outputs: {
|
||||||
|
// key : "output_0"
|
||||||
|
// value: {
|
||||||
|
// name : "StatefulPartitionedCall:0"
|
||||||
|
// dtype: DT_FLOAT
|
||||||
|
// tensor_shape: {
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// method_name: "tensorflow/serving/predict"
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
TEST_P(CSavedModelAPITest, RunsSignatureDefFunction) {
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
bool use_tfrt = GetParam();
|
||||||
|
if (use_tfrt) {
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||||
|
|
||||||
|
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
|
||||||
|
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||||
|
|
||||||
|
TF_SavedModel* saved_model =
|
||||||
|
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
|
||||||
|
|
||||||
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
TF_SignatureDefFunction* serving_default =
|
||||||
|
TF_GetSavedModelSignatureDefFunction(saved_model, "serving_default",
|
||||||
|
status);
|
||||||
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
|
||||||
|
TF_SignatureDefFunctionMetadata* metadata =
|
||||||
|
TF_SignatureDefFunctionGetMetadata(serving_default);
|
||||||
|
|
||||||
|
const TF_SignatureDefParamList* args =
|
||||||
|
TF_SignatureDefFunctionMetadataArgs(metadata);
|
||||||
|
const TF_SignatureDefParamList* returns =
|
||||||
|
TF_SignatureDefFunctionMetadataReturns(metadata);
|
||||||
|
|
||||||
|
EXPECT_EQ(TF_SignatureDefParamListSize(args), 2);
|
||||||
|
const TF_SignatureDefParam* param_a = TF_SignatureDefParamListGet(args, 0);
|
||||||
|
const TF_TensorSpec* tensor_spec_a = TF_SignatureDefParamTensorSpec(param_a);
|
||||||
|
const TF_Shape* shape_a = TF_TensorSpecShape(tensor_spec_a);
|
||||||
|
|
||||||
|
// Input "a" is a scalar, float32 tensor
|
||||||
|
EXPECT_EQ("a", std::string(TF_SignatureDefParamName(param_a)));
|
||||||
|
EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_a));
|
||||||
|
EXPECT_EQ(0, TF_ShapeDims(shape_a));
|
||||||
|
|
||||||
|
const TF_SignatureDefParam* param_b = TF_SignatureDefParamListGet(args, 1);
|
||||||
|
const TF_TensorSpec* tensor_spec_b = TF_SignatureDefParamTensorSpec(param_b);
|
||||||
|
const TF_Shape* shape_b = TF_TensorSpecShape(tensor_spec_b);
|
||||||
|
|
||||||
|
// Input "b" is a scalar, float32 tensor
|
||||||
|
EXPECT_EQ("b", std::string(TF_SignatureDefParamName(param_b)));
|
||||||
|
EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_b));
|
||||||
|
EXPECT_EQ(0, TF_ShapeDims(shape_b));
|
||||||
|
|
||||||
|
EXPECT_EQ(TF_SignatureDefParamListSize(returns), 1);
|
||||||
|
|
||||||
|
const TF_SignatureDefParam* param_out =
|
||||||
|
TF_SignatureDefParamListGet(returns, 0);
|
||||||
|
const TF_TensorSpec* tensor_spec_out =
|
||||||
|
TF_SignatureDefParamTensorSpec(param_out);
|
||||||
|
const TF_Shape* shape_out = TF_TensorSpecShape(tensor_spec_out);
|
||||||
|
|
||||||
|
// Output "output_0" is a scalar, float32 tensor
|
||||||
|
EXPECT_EQ("output_0", std::string(TF_SignatureDefParamName(param_out)));
|
||||||
|
EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_out));
|
||||||
|
EXPECT_EQ(0, TF_ShapeDims(shape_out));
|
||||||
|
|
||||||
|
std::vector<TFE_TensorHandle*> compute_fn_inputs;
|
||||||
|
TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f);
|
||||||
|
TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f);
|
||||||
|
compute_fn_inputs.push_back(input_a);
|
||||||
|
compute_fn_inputs.push_back(input_b);
|
||||||
|
|
||||||
|
TFE_Op* serving_default_op = TF_SignatureDefFunctionMakeCallOp(
|
||||||
|
serving_default, compute_fn_inputs.data(), compute_fn_inputs.size(),
|
||||||
|
status);
|
||||||
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
|
||||||
|
std::vector<TFE_TensorHandle*> compute_fn_outputs(
|
||||||
|
TF_SignatureDefParamListSize(returns));
|
||||||
|
int num_retvals = TF_SignatureDefParamListSize(returns);
|
||||||
|
|
||||||
|
TFE_Execute(serving_default_op, compute_fn_outputs.data(), &num_retvals,
|
||||||
|
status);
|
||||||
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
|
||||||
|
TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status);
|
||||||
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
|
||||||
|
EXPECT_EQ(TF_NumDims(result), 0);
|
||||||
|
float output_value = *static_cast<float*>(TF_TensorData(result));
|
||||||
|
// (1 + 2) * (2 + 1) / 3 + 5 should be 8
|
||||||
|
EXPECT_FLOAT_EQ(output_value, 8.0);
|
||||||
|
|
||||||
|
TF_DeleteTensor(result);
|
||||||
|
TFE_DeleteTensorHandle(compute_fn_outputs[0]);
|
||||||
|
TFE_DeleteTensorHandle(input_a);
|
||||||
|
TFE_DeleteTensorHandle(input_b);
|
||||||
|
TFE_DeleteOp(serving_default_op);
|
||||||
|
TF_DeleteSavedModel(saved_model);
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
TFE_DeleteContext(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
|
TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
|
||||||
TF_Status* status = TF_NewStatus();
|
TF_Status* status = TF_NewStatus();
|
||||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
@ -16,5 +16,18 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
|
||||||
|
|
||||||
#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h"
|
#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h"
|
||||||
|
|
||||||
// TODO(bmzhao): Add getter functions here as necessary.
|
extern "C" {
|
||||||
|
|
||||||
|
extern const TF_SignatureDefParamList* TF_SignatureDefFunctionMetadataArgs(
|
||||||
|
const TF_SignatureDefFunctionMetadata* list) {
|
||||||
|
return tensorflow::wrap(&tensorflow::unwrap(list)->arguments());
|
||||||
|
}
|
||||||
|
|
||||||
|
extern const TF_SignatureDefParamList* TF_SignatureDefFunctionMetadataReturns(
|
||||||
|
const TF_SignatureDefFunctionMetadata* list) {
|
||||||
|
return tensorflow::wrap(&tensorflow::unwrap(list)->returns());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end extern "C"
|
||||||
|
@ -0,0 +1,33 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h"
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
extern const char* TF_SignatureDefParamName(const TF_SignatureDefParam* param) {
|
||||||
|
return tensorflow::unwrap(param)->name().c_str();
|
||||||
|
}
|
||||||
|
|
||||||
|
extern const TF_TensorSpec* TF_SignatureDefParamTensorSpec(
|
||||||
|
const TF_SignatureDefParam* param) {
|
||||||
|
return tensorflow::wrap(&tensorflow::unwrap(param)->spec());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end extern "C"
|
@ -0,0 +1,33 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h"
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
extern size_t TF_SignatureDefParamListSize(
|
||||||
|
const TF_SignatureDefParamList* list) {
|
||||||
|
return tensorflow::unwrap(list)->size();
|
||||||
|
}
|
||||||
|
|
||||||
|
extern const TF_SignatureDefParam* TF_SignatureDefParamListGet(
|
||||||
|
const TF_SignatureDefParamList* list, int i) {
|
||||||
|
return tensorflow::wrap(&tensorflow::unwrap(list)->at(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end extern "C"
|
@ -0,0 +1,33 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/c/conversion_macros.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||||
|
|
||||||
|
typedef struct TF_SignatureDefParamList TF_SignatureDefParamList;
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
DEFINE_CONVERSION_FUNCTIONS(std::vector<SignatureDefParam>,
|
||||||
|
TF_SignatureDefParamList)
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_
|
@ -0,0 +1,30 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/conversion_macros.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||||
|
|
||||||
|
typedef struct TF_SignatureDefParam TF_SignatureDefParam;
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
DEFINE_CONVERSION_FUNCTIONS(tensorflow::SignatureDefParam, TF_SignatureDefParam)
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_
|
@ -0,0 +1,32 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h"
|
||||||
|
#include "tensorflow/c/tf_shape_internal.h"
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
TF_DataType TF_TensorSpecDataType(const TF_TensorSpec* spec) {
|
||||||
|
return static_cast<TF_DataType>(tensorflow::unwrap(spec)->dtype());
|
||||||
|
}
|
||||||
|
|
||||||
|
const TF_Shape* TF_TensorSpecShape(const TF_TensorSpec* spec) {
|
||||||
|
return tensorflow::wrap(&tensorflow::unwrap(spec)->shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end extern "C"
|
@ -0,0 +1,30 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/conversion_macros.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
|
||||||
|
|
||||||
|
typedef struct TF_TensorSpec TF_TensorSpec;
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
DEFINE_CONVERSION_FUNCTIONS(tensorflow::TensorSpec, TF_TensorSpec)
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_
|
@ -28,6 +28,9 @@ exports_files(
|
|||||||
"saved_model_api.h",
|
"saved_model_api.h",
|
||||||
"signature_def_function.h",
|
"signature_def_function.h",
|
||||||
"signature_def_function_metadata.h",
|
"signature_def_function_metadata.h",
|
||||||
|
"signature_def_param.h",
|
||||||
|
"signature_def_param_list.h",
|
||||||
|
"tensor_spec.h",
|
||||||
],
|
],
|
||||||
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
|
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
|
||||||
)
|
)
|
||||||
@ -45,6 +48,9 @@ cc_library(
|
|||||||
":saved_model_api",
|
":saved_model_api",
|
||||||
":signature_def_function",
|
":signature_def_function",
|
||||||
":signature_def_function_metadata",
|
":signature_def_function_metadata",
|
||||||
|
":signature_def_param",
|
||||||
|
":signature_def_param_list",
|
||||||
|
":tensor_spec",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -77,3 +83,18 @@ alias(
|
|||||||
name = "signature_def_function_metadata",
|
name = "signature_def_function_metadata",
|
||||||
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_function_metadata",
|
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_function_metadata",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "signature_def_param",
|
||||||
|
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_param",
|
||||||
|
)
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "signature_def_param_list",
|
||||||
|
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_param_list",
|
||||||
|
)
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "tensor_spec",
|
||||||
|
actual = "//tensorflow/c/experimental/saved_model/internal:tensor_spec",
|
||||||
|
)
|
||||||
|
@ -23,6 +23,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
|
||||||
// IWYU pragma: end_exports
|
// IWYU pragma: end_exports
|
||||||
|
|
||||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
|
||||||
|
@ -16,6 +16,9 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
||||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api_macros.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif // __cplusplus
|
#endif // __cplusplus
|
||||||
@ -24,6 +27,18 @@ extern "C" {
|
|||||||
// SavedModel.
|
// SavedModel.
|
||||||
typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata;
|
typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata;
|
||||||
|
|
||||||
|
// Retrieves the arguments of the SignatureDefFunction. The caller is not
|
||||||
|
// responsible for freeing the returned pointer.
|
||||||
|
TF_CAPI_EXPORT extern const TF_SignatureDefParamList*
|
||||||
|
TF_SignatureDefFunctionMetadataArgs(
|
||||||
|
const TF_SignatureDefFunctionMetadata* list);
|
||||||
|
|
||||||
|
// Retrieves the returns of the SignatureDefFunction. The caller is not
|
||||||
|
// responsible for freeing the returned pointer.
|
||||||
|
TF_CAPI_EXPORT extern const TF_SignatureDefParamList*
|
||||||
|
TF_SignatureDefFunctionMetadataReturns(
|
||||||
|
const TF_SignatureDefFunctionMetadata* list);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} // end extern "C"
|
} // end extern "C"
|
||||||
#endif // __cplusplus
|
#endif // __cplusplus
|
||||||
|
@ -0,0 +1,44 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api_macros.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif // __cplusplus
|
||||||
|
|
||||||
|
// An opaque type that containing metadata of an input/output of a
|
||||||
|
// TF_SignatureDefFunction loaded from a SavedModel.
|
||||||
|
typedef struct TF_SignatureDefParam TF_SignatureDefParam;
|
||||||
|
|
||||||
|
// Returns the name of the given parameter. The caller is not responsible for
|
||||||
|
// freeing the returned char*.
|
||||||
|
TF_CAPI_EXPORT extern const char* TF_SignatureDefParamName(
|
||||||
|
const TF_SignatureDefParam* param);
|
||||||
|
|
||||||
|
// Returns the TensorSpec associated with the given parameter. The caller is
|
||||||
|
// not reponsible for freeing the returned TF_TensorSpec*.
|
||||||
|
TF_CAPI_EXPORT extern const TF_TensorSpec* TF_SignatureDefParamTensorSpec(
|
||||||
|
const TF_SignatureDefParam* param);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} // end extern "C"
|
||||||
|
#endif // __cplusplus
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_
|
@ -0,0 +1,44 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api_macros.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif // __cplusplus
|
||||||
|
|
||||||
|
// An opaque type that containing metadata of an input/output of a
|
||||||
|
// ConcreteFunction loaded from a SavedModel.
|
||||||
|
typedef struct TF_SignatureDefParamList TF_SignatureDefParamList;
|
||||||
|
|
||||||
|
// Returns the size of `list`.
|
||||||
|
TF_CAPI_EXPORT extern size_t TF_SignatureDefParamListSize(
|
||||||
|
const TF_SignatureDefParamList* list);
|
||||||
|
|
||||||
|
// Returns the `i`th TF_SignatureDefParam in the list.
|
||||||
|
TF_CAPI_EXPORT extern const TF_SignatureDefParam* TF_SignatureDefParamListGet(
|
||||||
|
const TF_SignatureDefParamList* list, int i);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} // end extern "C"
|
||||||
|
#endif // __cplusplus
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_
|
46
tensorflow/c/experimental/saved_model/public/tensor_spec.h
Normal file
46
tensorflow/c/experimental/saved_model/public/tensor_spec.h
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api_macros.h"
|
||||||
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
|
#include "tensorflow/c/tf_shape.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif // __cplusplus
|
||||||
|
|
||||||
|
// An opaque type corresponding to TensorSpec
|
||||||
|
typedef struct TF_TensorSpec TF_TensorSpec;
|
||||||
|
|
||||||
|
// Returns the dtype associated with the TensorSpec.
|
||||||
|
TF_CAPI_EXPORT extern TF_DataType TF_TensorSpecDataType(
|
||||||
|
const TF_TensorSpec* spec);
|
||||||
|
|
||||||
|
// Returns the shape associated with the TensorSpec. The returned Shape is not
|
||||||
|
// owned by the caller. Caller must not call TF_DeleteShape on the returned
|
||||||
|
// shape.
|
||||||
|
TF_CAPI_EXPORT extern const TF_Shape* TF_TensorSpecShape(
|
||||||
|
const TF_TensorSpec* spec);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} // end extern "C"
|
||||||
|
#endif // __cplusplus
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_
|
@ -22,6 +22,8 @@ cc_library(
|
|||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
"//tensorflow/c:tf_status_helper",
|
"//tensorflow/c:tf_status_helper",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/platform:regexp",
|
||||||
|
"//tensorflow/core/platform:strcat",
|
||||||
"//tensorflow/stream_executor:executor_cache",
|
"//tensorflow/stream_executor:executor_cache",
|
||||||
"//tensorflow/stream_executor:multi_platform_manager",
|
"//tensorflow/stream_executor:multi_platform_manager",
|
||||||
"//tensorflow/stream_executor:platform",
|
"//tensorflow/stream_executor:platform",
|
||||||
|
@ -27,7 +27,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/regexp.h"
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
#include "tensorflow/core/platform/strcat.h"
|
||||||
|
#include "tensorflow/core/platform/stringpiece.h"
|
||||||
#include "tensorflow/stream_executor/executor_cache.h"
|
#include "tensorflow/stream_executor/executor_cache.h"
|
||||||
#include "tensorflow/stream_executor/multi_platform_manager.h"
|
#include "tensorflow/stream_executor/multi_platform_manager.h"
|
||||||
#include "tensorflow/stream_executor/platform.h"
|
#include "tensorflow/stream_executor/platform.h"
|
||||||
@ -39,6 +42,8 @@ limitations under the License.
|
|||||||
using tensorflow::StatusFromTF_Status;
|
using tensorflow::StatusFromTF_Status;
|
||||||
|
|
||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
|
using tensorflow::StringPiece;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \
|
#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \
|
||||||
@ -58,10 +63,35 @@ namespace {
|
|||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
|
port::Status ValidateDeviceType(StringPiece type) {
|
||||||
|
// Validate device type. Device type must start with a capital letter and
|
||||||
|
// consist of capital letters and underscores. Reasoning behind this decision:
|
||||||
|
// * At the minimum we want to disallow '/' and ':' since
|
||||||
|
// these characters are used in device spec, for e.g.
|
||||||
|
// /job:foo/replica:12/device:GPU:1.
|
||||||
|
// * Underscores seem useful, for e.g. XLA_GPU uses underscores.
|
||||||
|
// * Allowing lowercase might get confusing. For example, say someone
|
||||||
|
// registers a new type called "Gpu". It might be confusing for users that
|
||||||
|
// "Gpu" is not the same device type as "GPU".
|
||||||
|
// Note that lowercase "cpu" and "gpu" are currently supported only for
|
||||||
|
// legacy reasons:
|
||||||
|
// https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/python/framework/device_spec.py;l=46;drc=d3a378f9665d8eee827c74cb9ecbee81e4c288dd
|
||||||
|
static const LazyRE2 kTfDeviceTypeRegEx = {"[A-Z][A-Z_]*"};
|
||||||
|
bool matches = RE2::FullMatch(type, *kTfDeviceTypeRegEx);
|
||||||
|
if (!matches) {
|
||||||
|
return port::FailedPreconditionError(
|
||||||
|
tensorflow::strings::StrCat("Device name/type '", type, "' must match ",
|
||||||
|
kTfDeviceTypeRegEx->pattern(), "."));
|
||||||
|
}
|
||||||
|
return port::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
port::Status ValidateSPPlatform(const SP_Platform& platform) {
|
port::Status ValidateSPPlatform(const SP_Platform& platform) {
|
||||||
VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE);
|
VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE);
|
||||||
VALIDATE_MEMBER(SP_Platform, platform, name);
|
VALIDATE_MEMBER(SP_Platform, platform, name);
|
||||||
VALIDATE_MEMBER(SP_Platform, platform, type);
|
VALIDATE_MEMBER(SP_Platform, platform, type);
|
||||||
|
TF_RETURN_IF_ERROR(ValidateDeviceType(platform.name));
|
||||||
|
TF_RETURN_IF_ERROR(ValidateDeviceType(platform.type));
|
||||||
// `visible_device_count` could be 0 at initialization time.
|
// `visible_device_count` could be 0 at initialization time.
|
||||||
return port::Status::OK();
|
return port::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -52,7 +52,7 @@ limitations under the License.
|
|||||||
// params.device = &device;
|
// params.device = &device;
|
||||||
//
|
//
|
||||||
// /* Plugin code below */
|
// /* Plugin code below */
|
||||||
// constexpr char DEVICE_NAME[] = "MyDevice";
|
// constexpr char DEVICE_NAME[] = "MY_DEVICE";
|
||||||
// constexpr char DEVICE_TYPE[] = "GPU";
|
// constexpr char DEVICE_TYPE[] = "GPU";
|
||||||
//
|
//
|
||||||
// void create_device(const SP_Platform* platform,
|
// void create_device(const SP_Platform* platform,
|
||||||
@ -416,10 +416,15 @@ typedef struct SP_Platform {
|
|||||||
|
|
||||||
void* ext; // free-form data set by plugin
|
void* ext; // free-form data set by plugin
|
||||||
|
|
||||||
// Platform name. Must be null-terminated.
|
// Platform name (also referred to as subtype), for example MY_DEVICE.
|
||||||
|
// The name must start with a capital letter and consist of
|
||||||
|
// capital letters and underscores.
|
||||||
|
// Must be null-terminated.
|
||||||
const char* name;
|
const char* name;
|
||||||
|
|
||||||
// Device type name, for example GPU. Must be null-terminated.
|
// Device type name, for example GPU. Must be null-terminated.
|
||||||
|
// The name must start with a capital letter and consist of
|
||||||
|
// capital letters and underscores.
|
||||||
const char* type;
|
const char* type;
|
||||||
|
|
||||||
// Number of visible devices
|
// Number of visible devices
|
||||||
|
@ -41,9 +41,9 @@ struct SP_Timer_st {
|
|||||||
|
|
||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr int DEVICE_COUNT = 2;
|
constexpr int kDeviceCount = 2;
|
||||||
constexpr char DEVICE_NAME[] = "MyDevice";
|
constexpr char kDeviceName[] = "MY_DEVICE";
|
||||||
constexpr char DEVICE_TYPE[] = "GPU";
|
constexpr char kDeviceType[] = "GPU";
|
||||||
|
|
||||||
/*** Create SP_StreamExecutor (with empty functions) ***/
|
/*** Create SP_StreamExecutor (with empty functions) ***/
|
||||||
void allocate(const SP_Device* const device, uint64_t size,
|
void allocate(const SP_Device* const device, uint64_t size,
|
||||||
@ -190,9 +190,9 @@ void destroy_device_fns(const SP_Platform* platform, SP_DeviceFns* device_fns) {
|
|||||||
void PopulateDefaultPlatform(SP_Platform* platform,
|
void PopulateDefaultPlatform(SP_Platform* platform,
|
||||||
SP_PlatformFns* platform_fns) {
|
SP_PlatformFns* platform_fns) {
|
||||||
*platform = {SP_PLATFORM_STRUCT_SIZE};
|
*platform = {SP_PLATFORM_STRUCT_SIZE};
|
||||||
platform->name = DEVICE_NAME;
|
platform->name = kDeviceName;
|
||||||
platform->type = DEVICE_TYPE;
|
platform->type = kDeviceType;
|
||||||
platform->visible_device_count = DEVICE_COUNT;
|
platform->visible_device_count = kDeviceCount;
|
||||||
platform_fns->create_device = create_device;
|
platform_fns->create_device = create_device;
|
||||||
platform_fns->destroy_device = destroy_device;
|
platform_fns->destroy_device = destroy_device;
|
||||||
platform_fns->create_device_fns = create_device_fns;
|
platform_fns->create_device_fns = create_device_fns;
|
||||||
@ -218,11 +218,11 @@ TEST(StreamExecutor, SuccessfulRegistration) {
|
|||||||
port::Status status = InitStreamExecutorPlugin(plugin_init);
|
port::Status status = InitStreamExecutorPlugin(plugin_init);
|
||||||
TF_ASSERT_OK(status);
|
TF_ASSERT_OK(status);
|
||||||
port::StatusOr<Platform*> maybe_platform =
|
port::StatusOr<Platform*> maybe_platform =
|
||||||
MultiPlatformManager::PlatformWithName("MyDevice");
|
MultiPlatformManager::PlatformWithName("MY_DEVICE");
|
||||||
TF_ASSERT_OK(maybe_platform.status());
|
TF_ASSERT_OK(maybe_platform.status());
|
||||||
Platform* platform = maybe_platform.ConsumeValueOrDie();
|
Platform* platform = maybe_platform.ConsumeValueOrDie();
|
||||||
ASSERT_EQ(platform->Name(), DEVICE_NAME);
|
ASSERT_EQ(platform->Name(), kDeviceName);
|
||||||
ASSERT_EQ(platform->VisibleDeviceCount(), DEVICE_COUNT);
|
ASSERT_EQ(platform->VisibleDeviceCount(), kDeviceCount);
|
||||||
|
|
||||||
port::StatusOr<StreamExecutor*> maybe_executor =
|
port::StatusOr<StreamExecutor*> maybe_executor =
|
||||||
platform->ExecutorForDevice(0);
|
platform->ExecutorForDevice(0);
|
||||||
@ -244,6 +244,39 @@ TEST(StreamExecutor, NameNotSet) {
|
|||||||
ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set.");
|
ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(StreamExecutor, InvalidNameWithSemicolon) {
|
||||||
|
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
|
||||||
|
TF_Status* const status) -> void {
|
||||||
|
TF_SetStatus(status, TF_OK, "");
|
||||||
|
PopulateDefaultPlatform(params->platform, params->platform_fns);
|
||||||
|
params->platform->name = "INVALID:NAME";
|
||||||
|
params->destroy_platform = destroy_platform;
|
||||||
|
params->destroy_platform_fns = destroy_platform_fns;
|
||||||
|
};
|
||||||
|
|
||||||
|
port::Status status = InitStreamExecutorPlugin(plugin_init);
|
||||||
|
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
|
||||||
|
EXPECT_THAT(
|
||||||
|
status.error_message(),
|
||||||
|
testing::ContainsRegex("Device name/type 'INVALID:NAME' must match"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StreamExecutor, InvalidNameWithSlash) {
|
||||||
|
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
|
||||||
|
TF_Status* const status) -> void {
|
||||||
|
TF_SetStatus(status, TF_OK, "");
|
||||||
|
PopulateDefaultPlatform(params->platform, params->platform_fns);
|
||||||
|
params->platform->name = "INVALID/";
|
||||||
|
params->destroy_platform = destroy_platform;
|
||||||
|
params->destroy_platform_fns = destroy_platform_fns;
|
||||||
|
};
|
||||||
|
|
||||||
|
port::Status status = InitStreamExecutorPlugin(plugin_init);
|
||||||
|
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
|
||||||
|
EXPECT_THAT(status.error_message(),
|
||||||
|
testing::ContainsRegex("Device name/type 'INVALID/' must match"));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(StreamExecutor, CreateDeviceNotSet) {
|
TEST(StreamExecutor, CreateDeviceNotSet) {
|
||||||
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
|
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
|
||||||
TF_Status* const status) -> void {
|
TF_Status* const status) -> void {
|
||||||
|
@ -57,43 +57,7 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
|
|||||||
|
|
||||||
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
|
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
mutex_lock l(graph->mu);
|
TF_UpdateEdge(graph, new_src, dst, status);
|
||||||
tensorflow::shape_inference::InferenceContext* ic =
|
|
||||||
graph->refiner.GetContext(&new_src.oper->node);
|
|
||||||
|
|
||||||
if (ic->num_outputs() <= new_src.index) {
|
|
||||||
status->status = tensorflow::errors::OutOfRange(
|
|
||||||
"Cannot update edge. Output index [", new_src.index,
|
|
||||||
"] is greater than the number of total outputs [", ic->num_outputs(),
|
|
||||||
"].");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
|
|
||||||
|
|
||||||
tensorflow::shape_inference::InferenceContext* ic_dst =
|
|
||||||
graph->refiner.GetContext(&dst.oper->node);
|
|
||||||
if (ic_dst->num_inputs() <= dst.index) {
|
|
||||||
status->status = tensorflow::errors::OutOfRange(
|
|
||||||
"Cannot update edge. Input index [", dst.index,
|
|
||||||
"] is greater than the number of total inputs [", ic_dst->num_inputs(),
|
|
||||||
"].");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (!ic_dst->MergeInput(dst.index, shape)) {
|
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
|
||||||
"Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
|
|
||||||
" and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
|
|
||||||
&dst.oper->node, dst.index);
|
|
||||||
|
|
||||||
if (TF_GetCode(status) == TF_OK) {
|
|
||||||
// This modification only updates the destination node for
|
|
||||||
// the purposes of running this graph in a session. Thus, we don't
|
|
||||||
// record the source node as being modified.
|
|
||||||
RecordMutation(graph, *dst.oper, "updating input tensor");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
|
void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
|
||||||
@ -136,6 +100,7 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
|
|||||||
auto* out_shape_and_type = handle_data.add_shape_and_type();
|
auto* out_shape_and_type = handle_data.add_shape_and_type();
|
||||||
ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
|
ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
|
||||||
out_shape_and_type->set_dtype(p.dtype);
|
out_shape_and_type->set_dtype(p.dtype);
|
||||||
|
out_shape_and_type->set_specialized_type(p.specialized_type);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
string result;
|
string result;
|
||||||
@ -163,7 +128,8 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
|
|||||||
status->status =
|
status->status =
|
||||||
ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
|
ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
|
shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(),
|
||||||
|
shape_and_type_proto.specialized_type());
|
||||||
}
|
}
|
||||||
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
|
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
|
||||||
}
|
}
|
||||||
|
39
tensorflow/c/tf_shape.cc
Normal file
39
tensorflow/c/tf_shape.cc
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/tf_shape.h"
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "tensorflow/c/tf_shape_internal.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
TF_Shape* TF_NewShape() {
|
||||||
|
return tensorflow::wrap(new tensorflow::PartialTensorShape());
|
||||||
|
}
|
||||||
|
|
||||||
|
int TF_ShapeDims(const TF_Shape* shape) {
|
||||||
|
return tensorflow::unwrap(shape)->dims();
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t TF_ShapeDimSize(const TF_Shape* shape, int d) {
|
||||||
|
return tensorflow::unwrap(shape)->dim_size(d);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TF_DeleteShape(TF_Shape* shape) { delete tensorflow::unwrap(shape); }
|
||||||
|
|
||||||
|
} // end extern "C"
|
50
tensorflow/c/tf_shape.h
Normal file
50
tensorflow/c/tf_shape.h
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api_macros.h"
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_TF_SHAPE_H_
|
||||||
|
#define TENSORFLOW_C_TF_SHAPE_H_
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// An opaque type corresponding to a shape in tensorflow. In the future,
|
||||||
|
// we may expose the ABI of TF_Shape for performance reasons.
|
||||||
|
typedef struct TF_Shape TF_Shape;
|
||||||
|
|
||||||
|
// Return a new, unknown rank shape object. The caller is responsible for
|
||||||
|
// calling TF_DeleteShape to deallocate and destroy the returned shape.
|
||||||
|
TF_CAPI_EXPORT extern TF_Shape* TF_NewShape();
|
||||||
|
|
||||||
|
// Returns the rank of `shape`. If `shape` has unknown rank, returns -1.
|
||||||
|
TF_CAPI_EXPORT extern int TF_ShapeDims(const TF_Shape* shape);
|
||||||
|
|
||||||
|
// Returns the `d`th dimension of `shape`. If `shape` has unknown rank,
|
||||||
|
// invoking this function is undefined behavior. Returns -1 if dimension is
|
||||||
|
// unknown.
|
||||||
|
TF_CAPI_EXPORT extern int64_t TF_ShapeDimSize(const TF_Shape* shape, int d);
|
||||||
|
|
||||||
|
// Deletes `shape`.
|
||||||
|
TF_CAPI_EXPORT extern void TF_DeleteShape(TF_Shape* shape);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} /* end extern "C" */
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_TF_SHAPE_H_
|
30
tensorflow/c/tf_shape_internal.h
Normal file
30
tensorflow/c/tf_shape_internal.h
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_TF_SHAPE_INTERNAL_H_
|
||||||
|
#define TENSORFLOW_C_TF_SHAPE_INTERNAL_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/conversion_macros.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
|
||||||
|
typedef struct TF_Shape TF_Shape;
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
DEFINE_CONVERSION_FUNCTIONS(tensorflow::PartialTensorShape, TF_Shape);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_TF_SHAPE_INTERNAL_H_
|
@ -251,7 +251,6 @@ cc_library_with_android_deps(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_experimental",
|
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -266,7 +265,6 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_experimental",
|
|
||||||
"//tensorflow/core:tensorflow",
|
"//tensorflow/core:tensorflow",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
@ -15,13 +15,12 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||||
|
#include "tensorflow/cc/framework/gradients.h"
|
||||||
#include "tensorflow/cc/ops/array_ops_internal.h"
|
#include "tensorflow/cc/ops/array_ops_internal.h"
|
||||||
#include "tensorflow/cc/ops/standard_ops.h"
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
|
||||||
#include "tensorflow/cc/framework/grad_op_registry.h"
|
|
||||||
#include "tensorflow/cc/framework/gradients.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace {
|
namespace {
|
||||||
@ -90,15 +89,25 @@ Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
|
|||||||
}
|
}
|
||||||
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);
|
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);
|
||||||
|
|
||||||
Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op,
|
Status QuantizeAndDequantizeV4GradHelper(const Scope& scope,
|
||||||
const std::vector<Output>& grad_inputs,
|
const Operation& op,
|
||||||
std::vector<Output>* grad_outputs) {
|
const std::vector<Output>& grad_inputs,
|
||||||
grad_outputs->push_back(Identity(scope, grad_inputs[0]));
|
std::vector<Output>* grad_outputs) {
|
||||||
grad_outputs->push_back(NoGradient());
|
Input input = Shape(scope, op.input(0));
|
||||||
grad_outputs->push_back(NoGradient());
|
Input input_min = op.input(1);
|
||||||
|
Input input_max = op.input(2);
|
||||||
|
int64 axis;
|
||||||
|
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
|
||||||
|
auto qdq_v4_grad = QuantizeAndDequantizeV4Grad(
|
||||||
|
scope, grad_inputs[0], input, input_min, input_max,
|
||||||
|
QuantizeAndDequantizeV4Grad::Axis(axis));
|
||||||
|
grad_outputs->push_back(qdq_v4_grad.input_backprop);
|
||||||
|
grad_outputs->push_back(qdq_v4_grad.input_min_backprop);
|
||||||
|
grad_outputs->push_back(qdq_v4_grad.input_max_backprop);
|
||||||
return scope.status();
|
return scope.status();
|
||||||
}
|
}
|
||||||
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad);
|
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV4",
|
||||||
|
QuantizeAndDequantizeV4GradHelper);
|
||||||
|
|
||||||
Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
|
Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
|
||||||
const std::vector<Output>& grad_inputs,
|
const std::vector<Output>& grad_inputs,
|
||||||
|
@ -21,10 +21,7 @@ package(
|
|||||||
licenses = ["notice"], # Apache 2.0
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
exports_files([
|
exports_files(["loader.h"])
|
||||||
"LICENSE",
|
|
||||||
"loader.h",
|
|
||||||
])
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "constants",
|
name = "constants",
|
||||||
@ -45,13 +42,15 @@ cc_library(
|
|||||||
name = "reader",
|
name = "reader",
|
||||||
srcs = ["reader.cc"],
|
srcs = ["reader.cc"],
|
||||||
hdrs = ["reader.h"],
|
hdrs = ["reader.h"],
|
||||||
deps = [":constants"] + if_not_mobile([
|
deps = [
|
||||||
|
":constants",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
] + if_not_mobile([
|
||||||
# TODO(b/111634734): :lib and :protos_all contain dependencies that
|
# TODO(b/111634734): :lib and :protos_all contain dependencies that
|
||||||
# cannot be built on mobile platforms. Instead, include the appropriate
|
# cannot be built on mobile platforms. Instead, include the appropriate
|
||||||
# tf_lib depending on the build platform.
|
# tf_lib depending on the build platform.
|
||||||
"@com_google_absl//absl/memory:memory",
|
"@com_google_absl//absl/memory:memory",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
|
||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -12,8 +12,6 @@ package(
|
|||||||
licenses = ["notice"], # Apache 2.0
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
exports_files(["LICENSE"])
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "freeze_saved_model",
|
name = "freeze_saved_model",
|
||||||
srcs = ["freeze_saved_model.cc"],
|
srcs = ["freeze_saved_model.cc"],
|
||||||
|
@ -75,7 +75,7 @@ cc_library(
|
|||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//llvm:Target",
|
"@llvm-project//llvm:Target",
|
||||||
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
|
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
|
||||||
"//tensorflow/core:regexp_internal",
|
"//tensorflow/core/platform:regexp",
|
||||||
] + if_llvm_system_z_available([
|
] + if_llvm_system_z_available([
|
||||||
"@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep
|
"@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep
|
||||||
]) + if_llvm_aarch64_available([
|
]) + if_llvm_aarch64_available([
|
||||||
|
@ -336,9 +336,9 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
"//tensorflow/compiler/xla/service:hlo_profile_printer",
|
"//tensorflow/compiler/xla/service:hlo_profile_printer",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:regexp_internal",
|
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/platform:regexp",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
@ -559,9 +559,9 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
"//tensorflow/compiler/xla/service:hlo_profile_printer",
|
"//tensorflow/compiler/xla/service:hlo_profile_printer",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:regexp_internal",
|
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/platform:regexp",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
|
@ -127,7 +127,7 @@ def tf_library(
|
|||||||
"$(location " + tfcompile_tool + ")" +
|
"$(location " + tfcompile_tool + ")" +
|
||||||
" --config=$(location " + config + ")" +
|
" --config=$(location " + config + ")" +
|
||||||
" --dump_fetch_nodes > $@"),
|
" --dump_fetch_nodes > $@"),
|
||||||
tools = [tfcompile_tool],
|
exec_tools = [tfcompile_tool],
|
||||||
# Run tfcompile on the build host, rather than forge, since it's
|
# Run tfcompile on the build host, rather than forge, since it's
|
||||||
# typically way faster on the local machine.
|
# typically way faster on the local machine.
|
||||||
local = 1,
|
local = 1,
|
||||||
@ -242,7 +242,7 @@ def tf_library(
|
|||||||
" --out_function_object=$(@D)/" + function_object_file +
|
" --out_function_object=$(@D)/" + function_object_file +
|
||||||
" " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag
|
" " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag
|
||||||
),
|
),
|
||||||
tools = [tfcompile_tool],
|
exec_tools = [tfcompile_tool],
|
||||||
visibility = visibility,
|
visibility = visibility,
|
||||||
testonly = testonly,
|
testonly = testonly,
|
||||||
# Run tfcompile on the build host since it's typically faster on the
|
# Run tfcompile on the build host since it's typically faster on the
|
||||||
@ -281,7 +281,7 @@ def tf_library(
|
|||||||
" --out_session_module=$(@D)/" + session_module_pb +
|
" --out_session_module=$(@D)/" + session_module_pb +
|
||||||
" " + flags
|
" " + flags
|
||||||
),
|
),
|
||||||
tools = [tfcompile_tool],
|
exec_tools = [tfcompile_tool],
|
||||||
visibility = visibility,
|
visibility = visibility,
|
||||||
testonly = testonly,
|
testonly = testonly,
|
||||||
local = 1,
|
local = 1,
|
||||||
|
@ -4,7 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
|||||||
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_test")
|
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_test")
|
||||||
|
|
||||||
# buildifier: disable=same-origin-load
|
# buildifier: disable=same-origin-load
|
||||||
load("//tensorflow:tensorflow.bzl", "if_tpu", "tf_copts")
|
load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts")
|
||||||
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
|
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
|
||||||
|
|
||||||
# buildifier: disable=same-origin-load
|
# buildifier: disable=same-origin-load
|
||||||
@ -77,7 +77,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||||
] + if_tpu(
|
] + if_libtpu(
|
||||||
if_false = ["//tensorflow/compiler/xla/service:cpu_plugin"],
|
if_false = ["//tensorflow/compiler/xla/service:cpu_plugin"],
|
||||||
if_true = [],
|
if_true = [],
|
||||||
),
|
),
|
||||||
@ -114,7 +114,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
] + if_tpu(
|
] + if_libtpu(
|
||||||
if_false = [
|
if_false = [
|
||||||
"//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep
|
"//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep
|
||||||
],
|
],
|
||||||
@ -141,7 +141,7 @@ cc_library(
|
|||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core/common_runtime/gpu:gpu_init",
|
"//tensorflow/core/common_runtime/gpu:gpu_init",
|
||||||
] + if_tpu(
|
] + if_libtpu(
|
||||||
if_false = [
|
if_false = [
|
||||||
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
|
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
|
||||||
],
|
],
|
||||||
@ -204,7 +204,7 @@ XLA_DEVICE_DEPS = [
|
|||||||
"//tensorflow/core:resource_variable_ops_op_lib",
|
"//tensorflow/core:resource_variable_ops_op_lib",
|
||||||
"//tensorflow/core:sendrecv_ops_op_lib",
|
"//tensorflow/core:sendrecv_ops_op_lib",
|
||||||
"//tensorflow/core:state_ops_op_lib",
|
"//tensorflow/core:state_ops_op_lib",
|
||||||
"//tensorflow/core:stream_executor_no_cuda",
|
"//tensorflow/core/platform:stream_executor_no_cuda",
|
||||||
"//tensorflow/core/kernels:constant_op",
|
"//tensorflow/core/kernels:constant_op",
|
||||||
"//tensorflow/core/kernels:fifo_queue",
|
"//tensorflow/core/kernels:fifo_queue",
|
||||||
"//tensorflow/core/kernels:function_ops",
|
"//tensorflow/core/kernels:function_ops",
|
||||||
@ -375,7 +375,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/platform:logging",
|
"//tensorflow/core/platform:logging",
|
||||||
] + if_tpu(
|
] + if_libtpu(
|
||||||
if_false = [
|
if_false = [
|
||||||
"//tensorflow/compiler/mlir:array_container_utils",
|
"//tensorflow/compiler/mlir:array_container_utils",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
||||||
@ -435,6 +435,7 @@ cc_library(
|
|||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core/common_runtime:core_cpu_internal",
|
"//tensorflow/core/common_runtime:core_cpu_internal",
|
||||||
|
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
@ -1022,10 +1023,10 @@ tf_cc_test(
|
|||||||
"//tensorflow/cc:ops",
|
"//tensorflow/cc:ops",
|
||||||
"//tensorflow/core:all_kernels",
|
"//tensorflow/core:all_kernels",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:direct_session_internal",
|
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:ops",
|
"//tensorflow/core:ops",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core/common_runtime:direct_session_internal",
|
||||||
"//tensorflow/core/kernels:cwise_op",
|
"//tensorflow/core/kernels:cwise_op",
|
||||||
"//tensorflow/core/kernels:matmul_op",
|
"//tensorflow/core/kernels:matmul_op",
|
||||||
"//tensorflow/core/kernels:partitioned_function_ops",
|
"//tensorflow/core/kernels:partitioned_function_ops",
|
||||||
|
@ -84,6 +84,23 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
xla::StatusOr<std::vector<NodeDef>> MakeCallNodesFromAttribute(
|
||||||
|
const Node& node, absl::string_view attr_name,
|
||||||
|
absl::string_view call_name) {
|
||||||
|
std::vector<NameAttrList> attr_lists;
|
||||||
|
TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &attr_lists));
|
||||||
|
|
||||||
|
std::vector<NodeDef> out;
|
||||||
|
for (int i = 0; i < attr_lists.size(); i++) {
|
||||||
|
out.emplace_back();
|
||||||
|
NodeDef& inserted = out.back();
|
||||||
|
inserted.set_name(absl::StrCat(call_name, "_", i));
|
||||||
|
inserted.set_op(attr_lists[i].name());
|
||||||
|
*inserted.mutable_attr() = attr_lists[i].attr();
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
// Utility which searches for values in a sorted list by scanning over it once.
|
// Utility which searches for values in a sorted list by scanning over it once.
|
||||||
// No matter how many times ScanForValue is called, the list is scanned at most
|
// No matter how many times ScanForValue is called, the list is scanned at most
|
||||||
// once. However, if a call to ScanForValue skips over a value, that value is
|
// once. However, if a call to ScanForValue skips over a value, that value is
|
||||||
@ -227,6 +244,30 @@ bool RecursiveCompilabilityChecker::IsCompilableIf(
|
|||||||
return is_compilable;
|
return is_compilable;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool RecursiveCompilabilityChecker::IsCompilableCase(
|
||||||
|
const Node& case_node, FunctionLibraryRuntime* lib_runtime,
|
||||||
|
std::vector<StackFrameView>* stack_trace,
|
||||||
|
NameAttrList* encapsulating_function,
|
||||||
|
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
|
||||||
|
const {
|
||||||
|
xla::StatusOr<std::vector<NodeDef>> calls =
|
||||||
|
MakeCallNodesFromAttribute(case_node, "branches", "branch");
|
||||||
|
if (!calls.ok()) {
|
||||||
|
VLOG(2) << "Rejecting node " << case_node.name() << ": "
|
||||||
|
<< "missing attribute 'branches'";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_compilable = true;
|
||||||
|
|
||||||
|
for (const NodeDef& call : *calls) {
|
||||||
|
is_compilable &=
|
||||||
|
IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
|
||||||
|
uncompilable_nodes);
|
||||||
|
}
|
||||||
|
return is_compilable;
|
||||||
|
}
|
||||||
|
|
||||||
// Tests whether 'while_node' is a completely compilable loop.
|
// Tests whether 'while_node' is a completely compilable loop.
|
||||||
// Every operator in the condition and body functions must be compilable for a
|
// Every operator in the condition and body functions must be compilable for a
|
||||||
// while loop to be compilable.
|
// while loop to be compilable.
|
||||||
@ -417,6 +458,13 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (op_filter_.require_always_compilable && node.IsCaseNode() &&
|
||||||
|
!IsCompilableCase(node, lib_runtime, stack_trace, encapsulating_function,
|
||||||
|
uncompilable_nodes)) {
|
||||||
|
LogNotCompilable(node, "unsupported case");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
if (!op_filter_.allow_stateful_rng_ops &&
|
if (!op_filter_.allow_stateful_rng_ops &&
|
||||||
IsStatefulRandomOp(node.type_string())) {
|
IsStatefulRandomOp(node.type_string())) {
|
||||||
absl::string_view uncompilable_reason = "stateful random op";
|
absl::string_view uncompilable_reason = "stateful random op";
|
||||||
|
@ -124,6 +124,10 @@ class RecursiveCompilabilityChecker {
|
|||||||
// Whether ops known to have numerical accuracy issues should be considered
|
// Whether ops known to have numerical accuracy issues should be considered
|
||||||
// compilable..
|
// compilable..
|
||||||
bool allow_inaccurate_ops = false;
|
bool allow_inaccurate_ops = false;
|
||||||
|
|
||||||
|
// Require the function to be always compilable, regardless whether some
|
||||||
|
// control flow branches might be dead for a given input.
|
||||||
|
bool require_always_compilable = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
RecursiveCompilabilityChecker(OperationFilter op_filter,
|
RecursiveCompilabilityChecker(OperationFilter op_filter,
|
||||||
@ -211,6 +215,14 @@ class RecursiveCompilabilityChecker {
|
|||||||
NameAttrList* encapsulating_function,
|
NameAttrList* encapsulating_function,
|
||||||
UncompilableNodesMap* uncompilable_nodes) const;
|
UncompilableNodesMap* uncompilable_nodes) const;
|
||||||
|
|
||||||
|
// Tests whether 'case_node' is compilable. Every operator in all branches
|
||||||
|
// must be compilable.
|
||||||
|
bool IsCompilableCase(const Node& case_node,
|
||||||
|
FunctionLibraryRuntime* lib_runtime,
|
||||||
|
std::vector<StackFrameView>* stack_trace,
|
||||||
|
NameAttrList* encapsulating_function,
|
||||||
|
UncompilableNodesMap* uncompilable_nodes) const;
|
||||||
|
|
||||||
// Returns compilability of node def retrieved from `node`'s attribute with
|
// Returns compilability of node def retrieved from `node`'s attribute with
|
||||||
// name `attr_name`.
|
// name `attr_name`.
|
||||||
bool ExtractNodeDefAndCheckCompilability(
|
bool ExtractNodeDefAndCheckCompilability(
|
||||||
|
@ -34,7 +34,16 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
AttrValue FuncListAttr(const absl::Span<const char* const> names) {
|
||||||
|
AttrValue attr;
|
||||||
|
for (const char* name : names) {
|
||||||
|
attr.mutable_list()->add_func()->set_name(name);
|
||||||
|
}
|
||||||
|
return attr;
|
||||||
|
}
|
||||||
|
|
||||||
constexpr char kFunctionalIfNodeName[] = "If";
|
constexpr char kFunctionalIfNodeName[] = "If";
|
||||||
|
constexpr char kFunctionalCaseNodeName[] = "Case";
|
||||||
constexpr char kFunctionalWhileNodeName[] = "While";
|
constexpr char kFunctionalWhileNodeName[] = "While";
|
||||||
constexpr char kCompilableFunctionName[] = "CompilableFn";
|
constexpr char kCompilableFunctionName[] = "CompilableFn";
|
||||||
constexpr char kCompilableFunctionNodeName[] = "n_c";
|
constexpr char kCompilableFunctionNodeName[] = "n_c";
|
||||||
@ -76,8 +85,12 @@ class CompilabilityCheckUtilTest : public ::testing::Test {
|
|||||||
op_filter_.allow_inaccurate_ops = false;
|
op_filter_.allow_inaccurate_ops = false;
|
||||||
op_filter_.allow_slow_ops = false;
|
op_filter_.allow_slow_ops = false;
|
||||||
|
|
||||||
checker_ = absl::make_unique<RecursiveCompilabilityChecker>(op_filter_,
|
checker_ = CreateCompilabilityChecker();
|
||||||
device_type_);
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<RecursiveCompilabilityChecker> CreateCompilabilityChecker() {
|
||||||
|
return absl::make_unique<RecursiveCompilabilityChecker>(op_filter_,
|
||||||
|
device_type_);
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionLibraryRuntime* GetFunctionLibraryRuntime() {
|
FunctionLibraryRuntime* GetFunctionLibraryRuntime() {
|
||||||
@ -355,6 +368,57 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
|
|||||||
"unsupported op"));
|
"unsupported op"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(CompilabilityCheckUtilTest, CheckFunctionalCaseNode) {
|
||||||
|
FunctionDefLibrary flib;
|
||||||
|
*flib.add_function() = FunctionDefHelper::Define(
|
||||||
|
/*Function*/ kUncompilableFunctionName,
|
||||||
|
/*Inputs*/ {"n_a:float"},
|
||||||
|
/*Outputs*/ {"n_c_uncompilable:float"},
|
||||||
|
/*Attributes*/ {},
|
||||||
|
// Node info
|
||||||
|
{{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});
|
||||||
|
*flib.add_function() = FunctionDefHelper::Define(
|
||||||
|
/*Function*/ kUncompilableFunctionTwoName,
|
||||||
|
/*Inputs*/ {"n_a:float"},
|
||||||
|
/*Outputs*/ {"n_d_uncompilable:float"},
|
||||||
|
/*Attribute*/ {},
|
||||||
|
// Node info
|
||||||
|
{{{kUncompilableFunctionNodeTwoName}, "MissingKernel", {"n_a"}}});
|
||||||
|
|
||||||
|
Scope root = Scope::NewRootScope().ExitOnError();
|
||||||
|
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib));
|
||||||
|
auto branch_index = ops::Placeholder(root.WithOpName("pred"), DT_INT32);
|
||||||
|
auto placeholder = ops::Placeholder(root.WithOpName("A"), DT_INT32);
|
||||||
|
std::vector<NodeBuilder::NodeOut> inputes(
|
||||||
|
{NodeBuilder::NodeOut(placeholder.node())});
|
||||||
|
Node* case_node;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
NodeBuilder(kFunctionalCaseNodeName, "Case", &root.graph()->flib_def())
|
||||||
|
.Input(branch_index.node())
|
||||||
|
.Input(inputes)
|
||||||
|
.Attr("branches", FuncListAttr({kUncompilableFunctionName,
|
||||||
|
kUncompilableFunctionTwoName}))
|
||||||
|
.Attr("Tout", {DT_INT32})
|
||||||
|
.Finalize(root.graph(), &case_node));
|
||||||
|
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||||
|
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||||
|
|
||||||
|
flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
|
||||||
|
|
||||||
|
auto case_node_it = std::find_if(
|
||||||
|
graph->nodes().begin(), graph->nodes().end(),
|
||||||
|
[&](const Node* n) { return n->name() == kFunctionalCaseNodeName; });
|
||||||
|
EXPECT_NE(case_node_it, graph->nodes().end());
|
||||||
|
auto* flib_runtime = GetFunctionLibraryRuntime();
|
||||||
|
|
||||||
|
op_filter_.require_always_compilable = false;
|
||||||
|
checker_ = CreateCompilabilityChecker();
|
||||||
|
EXPECT_TRUE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
|
||||||
|
op_filter_.require_always_compilable = true;
|
||||||
|
checker_ = CreateCompilabilityChecker();
|
||||||
|
EXPECT_FALSE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) {
|
TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) {
|
||||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||||
Scope root = Scope::NewRootScope().ExitOnError();
|
Scope root = Scope::NewRootScope().ExitOnError();
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/jit/xla_platform_info.h"
|
#include "tensorflow/compiler/jit/xla_platform_info.h"
|
||||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
@ -47,8 +48,8 @@ static xla::StatusOr<xla::LocalExecutable*> GetLocalExecutable(
|
|||||||
|
|
||||||
xla::StatusOr<std::string> GetCompilerIr(
|
xla::StatusOr<std::string> GetCompilerIr(
|
||||||
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
|
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
|
||||||
absl::string_view func_name, Device* dev,
|
absl::string_view func_name, Device* dev, EagerContext* context,
|
||||||
absl::Span<const Tensor* const> inputs) {
|
absl::Span<const TensorHandle* const> inputs_handles) {
|
||||||
NameAttrList function;
|
NameAttrList function;
|
||||||
function.set_name(std::string{func_name});
|
function.set_name(std::string{func_name});
|
||||||
|
|
||||||
@ -65,6 +66,25 @@ xla::StatusOr<std::string> GetCompilerIr(
|
|||||||
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
|
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
|
||||||
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
|
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
|
||||||
|
|
||||||
|
std::deque<Tensor> inputs_storage;
|
||||||
|
std::vector<const Tensor*> inputs;
|
||||||
|
inputs.reserve(inputs_handles.size());
|
||||||
|
for (int i = 0; i < inputs_handles.size(); i++) {
|
||||||
|
const TensorHandle* th = inputs_handles[i];
|
||||||
|
const Tensor* t;
|
||||||
|
// Handle owns the tensor.
|
||||||
|
TF_RETURN_IF_ERROR(th->Tensor(&t));
|
||||||
|
if (absl::c_binary_search(constant_arg_indices, i)) {
|
||||||
|
// Need to make sure it's on the host.
|
||||||
|
inputs_storage.emplace_back(t->dtype(), t->shape());
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back()));
|
||||||
|
inputs.push_back(&inputs_storage.back());
|
||||||
|
} else {
|
||||||
|
inputs.push_back(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<VariableInfo> variable_infos;
|
std::vector<VariableInfo> variable_infos;
|
||||||
TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
|
TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
|
||||||
rmgr, dev, inputs, resource_arg_indices, &variable_infos));
|
rmgr, dev, inputs, resource_arg_indices, &variable_infos));
|
||||||
|
@ -24,6 +24,8 @@ namespace tensorflow {
|
|||||||
class ProcessFunctionLibraryRuntime;
|
class ProcessFunctionLibraryRuntime;
|
||||||
class Device;
|
class Device;
|
||||||
class Tensor;
|
class Tensor;
|
||||||
|
class TensorHandle;
|
||||||
|
class EagerContext;
|
||||||
|
|
||||||
enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
|
enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
|
||||||
|
|
||||||
@ -31,8 +33,8 @@ enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
|
|||||||
// `runtime` on a device `dev` with given `inputs`.
|
// `runtime` on a device `dev` with given `inputs`.
|
||||||
xla::StatusOr<std::string> GetCompilerIr(
|
xla::StatusOr<std::string> GetCompilerIr(
|
||||||
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
|
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
|
||||||
absl::string_view func_name, Device* dev,
|
absl::string_view func_name, Device* dev, EagerContext* context,
|
||||||
absl::Span<const Tensor* const> inputs);
|
absl::Span<const TensorHandle* const> inputs);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ XLA_OPS_DEPS = [
|
|||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:state_ops_op_lib",
|
"//tensorflow/core:state_ops_op_lib",
|
||||||
"//tensorflow/core:stream_executor_no_cuda",
|
"//tensorflow/core/platform:stream_executor_no_cuda",
|
||||||
"//tensorflow/core/profiler/lib:traceme",
|
"//tensorflow/core/profiler/lib:traceme",
|
||||||
"//tensorflow/stream_executor:tf_allocator_adapter",
|
"//tensorflow/stream_executor:tf_allocator_adapter",
|
||||||
]
|
]
|
||||||
|
@ -1196,10 +1196,14 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!RecursiveCompilabilityChecker{
|
RecursiveCompilabilityChecker::OperationFilter filter =
|
||||||
CreateOperationFilter(*registration),
|
CreateOperationFilter(*registration);
|
||||||
DeviceType{registration->compilation_device_name}}
|
filter.require_always_compilable = true;
|
||||||
.IsCompilableNode(*node, lib_runtime)) {
|
|
||||||
|
RecursiveCompilabilityChecker checker(
|
||||||
|
filter, DeviceType{registration->compilation_device_name});
|
||||||
|
|
||||||
|
if (!checker.IsCompilableNode(*node, lib_runtime)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2062,6 +2066,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
|||||||
"XlaSpmdFullToShardShape",
|
"XlaSpmdFullToShardShape",
|
||||||
"XlaSpmdShardToFullShape",
|
"XlaSpmdShardToFullShape",
|
||||||
"XlaSvd",
|
"XlaSvd",
|
||||||
|
"XlaVariadicReduce",
|
||||||
"XlaWhile",
|
"XlaWhile",
|
||||||
"Zeta",
|
"Zeta",
|
||||||
"_Arg",
|
"_Arg",
|
||||||
|
@ -47,7 +47,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/public/version.h"
|
#include "tensorflow/core/public/version.h"
|
||||||
#include "tensorflow/core/util/dump_graph.h"
|
#include "tensorflow/core/util/dump_graph.h"
|
||||||
|
|
||||||
#if !defined(LIBTFTPU)
|
#if !defined(LIBTPU_ON_GCE)
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
||||||
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
|
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
|
||||||
#endif
|
#endif
|
||||||
@ -289,7 +289,7 @@ Status XlaCompilationCache::CompileSingleOp(
|
|||||||
});
|
});
|
||||||
const ConfigProto* config = ctx->function_library()->config_proto();
|
const ConfigProto* config = ctx->function_library()->config_proto();
|
||||||
bool use_mlir = config && config->experimental().enable_mlir_bridge();
|
bool use_mlir = config && config->experimental().enable_mlir_bridge();
|
||||||
#ifdef LIBTFTPU
|
#ifdef LIBTPU_ON_GCE
|
||||||
if (use_mlir && has_tensor_list_arg) {
|
if (use_mlir && has_tensor_list_arg) {
|
||||||
LOG(WARNING) << "MLIR is not supported in this environment.";
|
LOG(WARNING) << "MLIR is not supported in this environment.";
|
||||||
}
|
}
|
||||||
@ -303,8 +303,12 @@ Status XlaCompilationCache::CompileSingleOp(
|
|||||||
}
|
}
|
||||||
|
|
||||||
GraphDebugInfo debug_info;
|
GraphDebugInfo debug_info;
|
||||||
|
std::vector<std::string> control_rets;
|
||||||
|
if (result_dtypes.empty()) {
|
||||||
|
control_rets.push_back(node_def.name());
|
||||||
|
}
|
||||||
return CompileGraphToXlaHlo(
|
return CompileGraphToXlaHlo(
|
||||||
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
|
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), control_rets,
|
||||||
options.device_type.type_string(), compile_options.use_tuple_arg,
|
options.device_type.type_string(), compile_options.use_tuple_arg,
|
||||||
*options.flib_def, debug_info, options.shape_representation_fn, result);
|
*options.flib_def, debug_info, options.shape_representation_fn, result);
|
||||||
#endif
|
#endif
|
||||||
|
@ -9,3 +9,31 @@ dialects and utilities for
|
|||||||
3. TF Lite
|
3. TF Lite
|
||||||
|
|
||||||
See [MLIR's website](https://mlir.llvm.org) for complete documentation.
|
See [MLIR's website](https://mlir.llvm.org) for complete documentation.
|
||||||
|
|
||||||
|
## Getting started
|
||||||
|
|
||||||
|
Building dialects and utilities here follow the standard approach using
|
||||||
|
`bazel` as the rest of TensorFlow.
|
||||||
|
|
||||||
|
### Using local LLVM repo
|
||||||
|
|
||||||
|
To develop across MLIR core and TensorFlow, it is useful to override the repo
|
||||||
|
to use a local version instead of fetching from head. This can be achieved as
|
||||||
|
below but note, the BUILD files are not automatically generated from or CMake
|
||||||
|
used, so if your change requires a BUILD file change (or you are using a
|
||||||
|
different version of LLVM than set in tensorflow/workspace.bzl's LLVM_COMMIT)
|
||||||
|
then manual BUILD file changes may be required.
|
||||||
|
|
||||||
|
```sh
|
||||||
|
LLVM_SRC=...
|
||||||
|
|
||||||
|
# Create basic workspace file
|
||||||
|
echo 'workspace(name = "llvm-project")' > $LLVM_SRC/WORKSPACE
|
||||||
|
# and copy over the bazel BUILD files.
|
||||||
|
cp third_party/llvm/llvm.autogenerated.BUILD $LLVM_SRC/llvm/BUILD
|
||||||
|
cp third_party/mlir/BUILD $LLVM_SRC/mlir
|
||||||
|
cp third_party/mlir/test.BUILD $LLVM_SRC/mlir/test/BUILD
|
||||||
|
|
||||||
|
bazel build --override_repository=llvm-project=$LLVM_SRC \
|
||||||
|
-c opt tensorflow/compiler/mlir:tf-opt
|
||||||
|
```
|
||||||
|
@ -48,6 +48,7 @@ filegroup(
|
|||||||
"include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td",
|
"include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td",
|
||||||
"include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td",
|
"include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td",
|
||||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td",
|
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td",
|
||||||
|
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td",
|
||||||
"@llvm-project//mlir:OpBaseTdFiles",
|
"@llvm-project//mlir:OpBaseTdFiles",
|
||||||
"@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td",
|
"@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td",
|
||||||
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
||||||
@ -539,6 +540,8 @@ cc_library(
|
|||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
|
"@llvm-project//mlir:Shape",
|
||||||
|
"@llvm-project//mlir:ShapeTransforms",
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
"@llvm-project//mlir:Support",
|
"@llvm-project//mlir:Support",
|
||||||
"@llvm-project//mlir:Transforms",
|
"@llvm-project//mlir:Transforms",
|
||||||
|
@ -360,6 +360,19 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
|
||||||
|
HLO_FpOrComplexTensor> {
|
||||||
|
let summary = "Atan operator";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Returns `Atan(operand)` element-wise.
|
||||||
|
|
||||||
|
$$
|
||||||
|
\atan(x) = \atan2(x, 1)
|
||||||
|
$$
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
|
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
|
||||||
HLO_FpOrComplexTensor> {
|
HLO_FpOrComplexTensor> {
|
||||||
let summary = "Sinh operation";
|
let summary = "Sinh operation";
|
||||||
|
@ -157,6 +157,9 @@ def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
|
|||||||
>];
|
>];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt",
|
||||||
|
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CbrtOp;
|
||||||
|
|
||||||
def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil",
|
def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil",
|
||||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CeilOp;
|
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CeilOp;
|
||||||
|
|
||||||
@ -193,12 +196,10 @@ def HLO_Expm1Op: HLO_UnaryElementwiseOp<"exponential_minus_one",
|
|||||||
def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor",
|
def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor",
|
||||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp;
|
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp;
|
||||||
|
|
||||||
def HLO_ImagOp: HLO_Op<
|
def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag",
|
||||||
"imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp {
|
[NoSideEffect, SameOperandsAndResultShape,
|
||||||
let builders = [OpBuilder<
|
DeclareOpInterfaceMethods<InferTypeOpInterface>],
|
||||||
"OpBuilder &, OperationState &tblgen_state, Value val">];
|
HLO_ComplexTensor>, BASE_HLO_ImagOp {
|
||||||
|
|
||||||
let arguments = (ins HLO_ComplexTensor);
|
|
||||||
let results = (outs HLO_FpTensor);
|
let results = (outs HLO_FpTensor);
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
@ -237,12 +238,10 @@ def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt",
|
|||||||
[NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>,
|
[NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>,
|
||||||
BASE_HLO_PopulationCountOp;
|
BASE_HLO_PopulationCountOp;
|
||||||
|
|
||||||
def HLO_RealOp: HLO_Op<
|
def HLO_RealOp: HLO_UnaryElementwiseOp<"real",
|
||||||
"real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp {
|
[NoSideEffect, SameOperandsAndResultShape,
|
||||||
let builders = [OpBuilder<
|
DeclareOpInterfaceMethods<InferTypeOpInterface>],
|
||||||
"OpBuilder &, OperationState &tblgen_state, Value val">];
|
HLO_ComplexTensor>, BASE_HLO_RealOp {
|
||||||
|
|
||||||
let arguments = (ins HLO_ComplexTensor);
|
|
||||||
let results = (outs HLO_FpTensor);
|
let results = (outs HLO_FpTensor);
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
@ -321,12 +320,10 @@ def HLO_AddOp : HLO_BinaryElementwiseOp<"add",
|
|||||||
def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2",
|
def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2",
|
||||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op;
|
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op;
|
||||||
|
|
||||||
def HLO_ComplexOp: HLO_Op<"complex",
|
def HLO_ComplexOp: HLO_BinaryElementwiseOp<"complex",
|
||||||
[NoSideEffect, SameOperandsAndResultShape]>,
|
[NoSideEffect, SameOperandsAndResultShape,
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
|
||||||
BASE_HLO_ComplexOp {
|
BASE_HLO_ComplexOp {
|
||||||
let builders = [OpBuilder<
|
|
||||||
"OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">];
|
|
||||||
|
|
||||||
let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs);
|
let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs);
|
||||||
let results = (outs HLO_ComplexTensor);
|
let results = (outs HLO_ComplexTensor);
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
@ -356,7 +353,9 @@ def HLO_PowOp : HLO_BinaryElementwiseOp<"power",
|
|||||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp;
|
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp;
|
||||||
|
|
||||||
def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
|
def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
|
||||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp;
|
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp {
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
|
def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
|
||||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp;
|
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp;
|
||||||
@ -913,39 +912,12 @@ def HLO_CollectivePermuteOp: HLO_Op<"collective_permute",
|
|||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(hinsu): Make this struct dialect independent so that it can be shared
|
|
||||||
// between HLO and LHLO dialect.
|
|
||||||
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [
|
|
||||||
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
|
||||||
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
|
||||||
|
|
||||||
let description = "Structure of dimension information for conv op";
|
|
||||||
}
|
|
||||||
|
|
||||||
def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
|
def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
|
||||||
let arguments = (ins
|
let arguments = !con(
|
||||||
HLO_Tensor:$lhs,
|
(ins
|
||||||
HLO_Tensor:$rhs,
|
HLO_Tensor:$lhs,
|
||||||
// Default value: one for each of the spatial dimension.
|
HLO_Tensor:$rhs),
|
||||||
OptionalAttr<I64ElementsAttr>:$window_strides,
|
ConvolutionAttributes<HLO_Dialect>.attributes);
|
||||||
// Default value: zero for each of the spatial dimension.
|
|
||||||
OptionalAttr<I64ElementsAttr>:$padding,
|
|
||||||
// Default value: one for each of the spatial dimension.
|
|
||||||
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
|
|
||||||
// Default value: one for each of the spatial dimension.
|
|
||||||
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
|
||||||
ConvDimensionNumbers:$dimension_numbers,
|
|
||||||
I64Attr:$feature_group_count,
|
|
||||||
I64Attr:$batch_group_count,
|
|
||||||
HLO_PrecisionConfigAttr:$precision_config
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
}
|
}
|
||||||
@ -1198,14 +1170,14 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>,
|
|||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects]>, BASE_HLO_SortOp {
|
def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
Variadic<HLO_Tensor>:$operands,
|
Variadic<HLO_Tensor>:$operands,
|
||||||
DefaultValuedAttr<I64Attr, "-1">:$dimension,
|
DefaultValuedAttr<I64Attr, "-1">:$dimension,
|
||||||
DefaultValuedAttr<BoolAttr, "false">:$is_stable
|
DefaultValuedAttr<BoolAttr, "false">:$is_stable
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs HLO_TensorOrTuple);
|
let results = (outs Variadic<HLO_Tensor>);
|
||||||
|
|
||||||
let regions = (region SizedRegion<1>:$comparator);
|
let regions = (region SizedRegion<1>:$comparator);
|
||||||
|
|
||||||
@ -1429,4 +1401,21 @@ def HLO_FusionOp : HLO_Op<"fusion", []> {
|
|||||||
let hasCustomHLOConverter = 1;
|
let hasCustomHLOConverter = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This is an op for purposes internal to XLA/GPU.
|
||||||
|
def HLO_BitcastOp : HLO_Op<"bitcast", [NoSideEffect]>, BASE_HLO_BitcastOp {
|
||||||
|
let arguments = (ins HLO_Tensor:$operand);
|
||||||
|
let results = (outs HLO_Tensor);
|
||||||
|
let hasCustomHLOConverter = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLO_ReducePrecisionOp: HLO_Op<"reduce_precision", [SameOperandsAndResultShape]>,
|
||||||
|
BASE_HLO_ReducePrecisionOp {
|
||||||
|
let arguments = (ins
|
||||||
|
HLO_FpTensor:$operand,
|
||||||
|
I32Attr:$exponent_bits,
|
||||||
|
I32Attr:$mantissa_bits
|
||||||
|
);
|
||||||
|
let results = (outs HLO_FpTensor:$output);
|
||||||
|
}
|
||||||
|
|
||||||
#endif // HLO_OPS
|
#endif // HLO_OPS
|
||||||
|
@ -127,6 +127,17 @@ class BASE_HLO_AbsOp {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class BASE_HLO_CbrtOp {
|
||||||
|
string summary = "Cubic root operator";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns element-wise cubic root of the operand.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
class BASE_HLO_CeilOp {
|
class BASE_HLO_CeilOp {
|
||||||
string summary = "Ceil operator";
|
string summary = "Ceil operator";
|
||||||
|
|
||||||
@ -996,6 +1007,42 @@ class BASE_HLO_ConcatenateOp {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Common convolution attributes
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class ConvDimensionNumbersBase<Dialect dialect>
|
||||||
|
: StructAttr<"ConvDimensionNumbers", dialect, [
|
||||||
|
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
||||||
|
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
||||||
|
|
||||||
|
let description = "Structure of dimension information for conv op";
|
||||||
|
}
|
||||||
|
|
||||||
|
class ConvolutionAttributes<Dialect dialect> {
|
||||||
|
dag attributes = (ins
|
||||||
|
// Default value: one for each of the spatial dimension.
|
||||||
|
OptionalAttr<I64ElementsAttr>:$window_strides,
|
||||||
|
// Default value: zero for each of the spatial dimension.
|
||||||
|
OptionalAttr<I64ElementsAttr>:$padding,
|
||||||
|
// Default value: one for each of the spatial dimension.
|
||||||
|
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
|
||||||
|
// Default value: one for each of the spatial dimension.
|
||||||
|
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
||||||
|
ConvDimensionNumbersBase<dialect>:$dimension_numbers,
|
||||||
|
I64Attr:$feature_group_count,
|
||||||
|
I64Attr:$batch_group_count,
|
||||||
|
HLO_PrecisionConfigAttr:$precision_config
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
class BASE_HLO_ConvOp {
|
class BASE_HLO_ConvOp {
|
||||||
string summary = "Convolution operator";
|
string summary = "Convolution operator";
|
||||||
|
|
||||||
@ -1336,4 +1383,17 @@ class BASE_HLO_WhileOp {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class BASE_HLO_BitcastOp {
|
||||||
|
string summary = "Bitcast operator";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
This op changes the shape of the input in the way that the physical
|
||||||
|
arranggment of elements are unchanged.
|
||||||
|
|
||||||
|
However, the op needs layout information to make sense of "physical
|
||||||
|
arrangement of elements". Layout support in MHLO is currently under
|
||||||
|
exploration.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
#endif // HLO_OPS_BASE
|
#endif // HLO_OPS_BASE
|
||||||
|
@ -37,38 +37,13 @@ include "mlir/IR/OpBase.td"
|
|||||||
include "mlir/Interfaces/CopyOpInterface.td"
|
include "mlir/Interfaces/CopyOpInterface.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
|
||||||
|
|
||||||
def LHLO_Dialect : Dialect {
|
def LHLO_Dialect : Dialect {
|
||||||
let name = "lmhlo";
|
let name = "lmhlo";
|
||||||
let cppNamespace = "::mlir::lmhlo";
|
let cppNamespace = "::mlir::lmhlo";
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// LMHLO type definitions.
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
// Any integer tensor types
|
|
||||||
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
|
|
||||||
|
|
||||||
// Any floating-point tensor types
|
|
||||||
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
|
|
||||||
|
|
||||||
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
|
|
||||||
|
|
||||||
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
|
|
||||||
|
|
||||||
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
|
|
||||||
|
|
||||||
// Any integer or floating-point tensor types
|
|
||||||
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
|
|
||||||
|
|
||||||
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
|
|
||||||
|
|
||||||
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
|
|
||||||
|
|
||||||
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// LMHLO nullary op definitions.
|
// LMHLO nullary op definitions.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -289,6 +264,16 @@ def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
|
|||||||
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def LHLO_CustomCallOp : LHLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$args,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
StrAttr:$call_target_name,
|
||||||
|
DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
|
||||||
|
DefaultValuedAttr<StrAttr, "">:$backend_config
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// LMHLO tuple op definitions.
|
// LMHLO tuple op definitions.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -335,10 +320,11 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
|
|||||||
def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
|
def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
|
||||||
[NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
|
||||||
let summary = [{
|
let summary = [{
|
||||||
"modifies the offset, sizes and strides of a statically shaped memref.
|
modifies the offset, sizes and strides of a statically shaped memref
|
||||||
}];
|
}];
|
||||||
let description = [{
|
let description = [{
|
||||||
Allows to modify the offset, sizes and strides of a statically shaped memref.
|
Casts the statically shaped memref operand to a memref with optionally
|
||||||
|
modified offsets, sizes and strides.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```mlir
|
```mlir
|
||||||
@ -354,12 +340,11 @@ def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
|
|||||||
let arguments = (ins Arg<LHLO_Buffer, "", []>:$operand);
|
let arguments = (ins Arg<LHLO_Buffer, "", []>:$operand);
|
||||||
let results = (outs Res<LHLO_Buffer, "", []>:$result);
|
let results = (outs Res<LHLO_Buffer, "", []>:$result);
|
||||||
|
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<"MemRefType resultType, Value operand",
|
||||||
"OpBuilder &builder, OperationState &result, MemRefType resultType, " #
|
[{
|
||||||
"Value operand", [{
|
$_state.addOperands(operand);
|
||||||
result.addOperands(operand);
|
$_state.types.push_back(resultType);
|
||||||
result.types.push_back(resultType);
|
}]>];
|
||||||
}]>];
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
|
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
|
||||||
@ -400,13 +385,13 @@ def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
|
|||||||
);
|
);
|
||||||
let results = (outs Res<LHLO_Buffer, "", []>:$result);
|
let results = (outs Res<LHLO_Buffer, "", []>:$result);
|
||||||
|
|
||||||
let builders = [OpBuilder<
|
let builders = [
|
||||||
"OpBuilder &builder, OperationState &result, MemRefType resultType, " #
|
OpBuilder<"MemRefType resultType, Value operand, ValueRange sizes, "
|
||||||
"Value operand, ValueRange sizes, ValueRange strides", [{
|
"ValueRange strides", [{
|
||||||
result.addOperands(operand);
|
$_state.addOperands(operand);
|
||||||
result.addOperands(sizes);
|
$_state.addOperands(sizes);
|
||||||
result.addOperands(strides);
|
$_state.addOperands(strides);
|
||||||
result.types.push_back(resultType);
|
$_state.types.push_back(resultType);
|
||||||
}]>];
|
}]>];
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
@ -582,40 +567,13 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(bondhugula): Make this struct dialect independent so that it can be
|
|
||||||
// shared between the HLO and LHLO dialects.
|
|
||||||
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
|
|
||||||
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
|
||||||
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
|
||||||
|
|
||||||
let description = "Structure of dimension information for conv op";
|
|
||||||
}
|
|
||||||
|
|
||||||
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
|
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
|
||||||
let arguments = (ins
|
let arguments = !con(
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
(ins
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||||
// Default value: one for each of the spatial dimension.
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output),
|
||||||
OptionalAttr<I64ElementsAttr>:$window_strides,
|
ConvolutionAttributes<LHLO_Dialect>.attributes);
|
||||||
// Default value: zero for each of the spatial dimension.
|
|
||||||
OptionalAttr<I64ElementsAttr>:$padding,
|
|
||||||
// Default value: one for each of the spatial dimension.
|
|
||||||
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
|
|
||||||
// Default value: one for each of the spatial dimension.
|
|
||||||
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
|
||||||
ConvDimensionNumbers:$dimension_numbers,
|
|
||||||
I64Attr:$feature_group_count,
|
|
||||||
I64Attr:$batch_group_count,
|
|
||||||
HLO_PrecisionConfigAttr:$precision_config
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {
|
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {
|
||||||
@ -856,9 +814,8 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
|
|||||||
|
|
||||||
let skipDefaultBuilders = 1;
|
let skipDefaultBuilders = 1;
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &result, "
|
OpBuilder<"ArrayRef<NamedAttribute> attributes">
|
||||||
"ArrayRef<NamedAttribute> attributes">
|
];
|
||||||
];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def TerminatorOp :
|
def TerminatorOp :
|
||||||
@ -867,9 +824,8 @@ def TerminatorOp :
|
|||||||
let description = [{
|
let description = [{
|
||||||
Terminator operation for the LHLO dialect.
|
Terminator operation for the LHLO dialect.
|
||||||
}];
|
}];
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<"ValueRange operands",
|
||||||
"OpBuilder &b, OperationState &result, ValueRange operands",
|
[{ build($_builder, $_state, llvm::None, operands, llvm::None); }]
|
||||||
[{ build(b, result, llvm::None, operands, llvm::None); }]
|
|
||||||
>];
|
>];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,47 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef LHLO_OPS_BASE
|
||||||
|
#define LHLO_OPS_BASE
|
||||||
|
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// LMHLO type definitions.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Any integer tensor types
|
||||||
|
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
|
||||||
|
|
||||||
|
// Any floating-point tensor types
|
||||||
|
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
|
||||||
|
|
||||||
|
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
|
||||||
|
|
||||||
|
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
|
||||||
|
|
||||||
|
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
|
||||||
|
|
||||||
|
// Any integer or floating-point tensor types
|
||||||
|
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
|
||||||
|
|
||||||
|
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
|
||||||
|
|
||||||
|
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
|
||||||
|
|
||||||
|
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
|
||||||
|
|
||||||
|
#endif // LHLO_OPS_BASE
|
@ -50,6 +50,7 @@ MAP_HLO_TO_LHLO(ConvOp);
|
|||||||
MAP_HLO_TO_LHLO(ConvertOp);
|
MAP_HLO_TO_LHLO(ConvertOp);
|
||||||
MAP_HLO_TO_LHLO(CopyOp);
|
MAP_HLO_TO_LHLO(CopyOp);
|
||||||
MAP_HLO_TO_LHLO(CosOp);
|
MAP_HLO_TO_LHLO(CosOp);
|
||||||
|
MAP_HLO_TO_LHLO(CustomCallOp);
|
||||||
MAP_HLO_TO_LHLO(DivOp);
|
MAP_HLO_TO_LHLO(DivOp);
|
||||||
MAP_HLO_TO_LHLO(DotOp);
|
MAP_HLO_TO_LHLO(DotOp);
|
||||||
MAP_HLO_TO_LHLO(ExpOp);
|
MAP_HLO_TO_LHLO(ExpOp);
|
||||||
@ -57,11 +58,13 @@ MAP_HLO_TO_LHLO(FloorOp);
|
|||||||
MAP_HLO_TO_LHLO(GatherOp);
|
MAP_HLO_TO_LHLO(GatherOp);
|
||||||
MAP_HLO_TO_LHLO(ImagOp);
|
MAP_HLO_TO_LHLO(ImagOp);
|
||||||
MAP_HLO_TO_LHLO(IotaOp);
|
MAP_HLO_TO_LHLO(IotaOp);
|
||||||
|
MAP_HLO_TO_LHLO(IsFiniteOp);
|
||||||
MAP_HLO_TO_LHLO(LogOp);
|
MAP_HLO_TO_LHLO(LogOp);
|
||||||
MAP_HLO_TO_LHLO(MaxOp);
|
MAP_HLO_TO_LHLO(MaxOp);
|
||||||
MAP_HLO_TO_LHLO(MinOp);
|
MAP_HLO_TO_LHLO(MinOp);
|
||||||
MAP_HLO_TO_LHLO(MulOp);
|
MAP_HLO_TO_LHLO(MulOp);
|
||||||
MAP_HLO_TO_LHLO(NegOp);
|
MAP_HLO_TO_LHLO(NegOp);
|
||||||
|
MAP_HLO_TO_LHLO(NotOp);
|
||||||
MAP_HLO_TO_LHLO(RealOp);
|
MAP_HLO_TO_LHLO(RealOp);
|
||||||
MAP_HLO_TO_LHLO(ReduceOp);
|
MAP_HLO_TO_LHLO(ReduceOp);
|
||||||
MAP_HLO_TO_LHLO(ReshapeOp);
|
MAP_HLO_TO_LHLO(ReshapeOp);
|
||||||
|
@ -149,6 +149,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
|
|||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc,
|
||||||
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::Atan2Op>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename PredicateType>
|
template <typename PredicateType>
|
||||||
inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
|
inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
@ -345,6 +354,22 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc,
|
|||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::IsFiniteOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
if (args[0].getType().isa<FloatType>()) {
|
||||||
|
auto pos_inf = APFloat::getInf(
|
||||||
|
args[0].getType().cast<FloatType>().getFloatSemantics());
|
||||||
|
auto const_pos_inf =
|
||||||
|
b->create<ConstantOp>(loc, b->getFloatAttr(args[0].getType(), pos_inf));
|
||||||
|
Value abs_x = b->create<::mlir::AbsFOp>(loc, args[0]);
|
||||||
|
return b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, abs_x,
|
||||||
|
const_pos_inf);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
/// Implements the conversion of HLO op to scalar op (to use within region of a
|
/// Implements the conversion of HLO op to scalar op (to use within region of a
|
||||||
/// linalg.generic op) for compare-select style operations like min/max.
|
/// linalg.generic op) for compare-select style operations like min/max.
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
@ -431,6 +456,21 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
|
||||||
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
Type element_type = args.front().getType();
|
||||||
|
if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
|
||||||
|
// lmhlo.not(x) -> x ^ -1
|
||||||
|
auto all_ones =
|
||||||
|
b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth());
|
||||||
|
return b->create<::mlir::XOrOp>(loc, all_ones, args[0]);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
@ -454,11 +494,27 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
|
|||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
Type element_type = args.front().getType();
|
Type element_type = args.front().getType();
|
||||||
if (element_type.isa<FloatType>()) {
|
if (auto float_type = element_type.dyn_cast<FloatType>()) {
|
||||||
FloatType float_type = element_type.cast<FloatType>();
|
bool ignored;
|
||||||
APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0);
|
APFloat one_apfloat(1.0f);
|
||||||
Value one = b->create<mlir::ConstantFloatOp>(loc, const_value, float_type);
|
one_apfloat.convert(float_type.getFloatSemantics(),
|
||||||
|
APFloat::rmNearestTiesToEven, &ignored);
|
||||||
|
Value one = b->create<mlir::ConstantFloatOp>(loc, one_apfloat, float_type);
|
||||||
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
|
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
|
||||||
|
} else if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
|
||||||
|
// sign(x) = x == 0 ? 0 : ((x s>> 31) | 1)
|
||||||
|
Value zero =
|
||||||
|
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
|
||||||
|
Value cmp =
|
||||||
|
b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero);
|
||||||
|
Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>(
|
||||||
|
loc, integer_type.getWidth() - 1, integer_type.getWidth());
|
||||||
|
Value ashr =
|
||||||
|
b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
|
||||||
|
Value one =
|
||||||
|
b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth());
|
||||||
|
Value or_op = b->create<::mlir::OrOp>(loc, ashr, one);
|
||||||
|
return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op);
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -15,9 +15,9 @@ limitations under the License.
|
|||||||
|
|
||||||
include "mlir/Pass/PassBase.td"
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
def TestChloLegalizeToHloPass : Pass<"mhlo-test-chlo-legalize-to-hlo", "FuncOp"> {
|
def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "FuncOp"> {
|
||||||
let summary = "Test pass for applying chlo -> hlo legalization patterns.";
|
let summary = "Legalize CHLO to HLO.";
|
||||||
let constructor = "createTestChloLegalizeToHloPass()";
|
let constructor = "createChloLegalizeToHloPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
|
def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
|
||||||
|
@ -44,6 +44,9 @@ std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass();
|
|||||||
/// Lowers from HLO dialect to Standard dialect.
|
/// Lowers from HLO dialect to Standard dialect.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
|
||||||
|
|
||||||
|
/// Lowers from the CHLO dialect to the HLO dialect.
|
||||||
|
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass();
|
||||||
|
|
||||||
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
|
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
|
||||||
/// buffers if necessary. If `results_escape_functions` is set to true,
|
/// buffers if necessary. If `results_escape_functions` is set to true,
|
||||||
/// allocated buffers for function results will be returned and escape the
|
/// allocated buffers for function results will be returned and escape the
|
||||||
@ -63,7 +66,7 @@ std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
|
|||||||
std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
|
std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
|
||||||
|
|
||||||
/// Lowers trigonometric operations from the standard dialect to approximations
|
/// Lowers trigonometric operations from the standard dialect to approximations
|
||||||
// that do not use intrinsics.
|
/// that do not use intrinsics.
|
||||||
std::unique_ptr<OperationPass<FuncOp>>
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
createLegalizeTrigonometricToApproximationPass();
|
createLegalizeTrigonometricToApproximationPass();
|
||||||
|
|
||||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
|||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace mhlo {
|
namespace mhlo {
|
||||||
|
|
||||||
std::unique_ptr<Pass> createTestChloLegalizeToHloPass();
|
|
||||||
std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass();
|
std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass();
|
||||||
std::unique_ptr<Pass> createTestMaterializeBroadcastsPass();
|
std::unique_ptr<Pass> createTestMaterializeBroadcastsPass();
|
||||||
std::unique_ptr<Pass> createTestUnfuseBatchNormPass();
|
std::unique_ptr<Pass> createTestUnfuseBatchNormPass();
|
||||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Transforms/BufferPlacement.h"
|
#include "mlir/Transforms/Bufferize.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
@ -185,8 +185,7 @@ struct GatherSlice : public OpRewritePattern<GatherOp> {
|
|||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
const auto& dnums = gather.dimension_numbers();
|
const auto& dnums = gather.dimension_numbers();
|
||||||
if (dnums.collapsed_slice_dims().getNumElements() != 0 ||
|
if (dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1)
|
||||||
dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1)
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// TODO(tberghammer): Remove when the verifier catches this case what is
|
// TODO(tberghammer): Remove when the verifier catches this case what is
|
||||||
@ -206,11 +205,35 @@ struct GatherSlice : public OpRewritePattern<GatherOp> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallVector<int64_t, 8> slice_stride(slice_end.size(), 1);
|
llvm::SmallVector<int64_t, 8> slice_stride(slice_end.size(), 1);
|
||||||
rewriter.replaceOpWithNewOp<SliceOp>(
|
llvm::SmallVector<int64_t, 8> slice_shape(slice_end.size());
|
||||||
gather, gather.getType(), gather.getOperand(0),
|
for (int64_t i = 0; i < slice_end.size(); ++i) {
|
||||||
|
slice_shape[i] = slice_end[i] - slice_start[i];
|
||||||
|
}
|
||||||
|
Type element_type = gather.getType().cast<TensorType>().getElementType();
|
||||||
|
auto slice_type = RankedTensorType::get(slice_shape, element_type);
|
||||||
|
Value result = rewriter.create<SliceOp>(
|
||||||
|
gather.getLoc(), slice_type, gather.getOperand(0),
|
||||||
GetI64ElementsAttr(slice_start, &rewriter),
|
GetI64ElementsAttr(slice_start, &rewriter),
|
||||||
GetI64ElementsAttr(slice_end, &rewriter),
|
GetI64ElementsAttr(slice_end, &rewriter),
|
||||||
GetI64ElementsAttr(slice_stride, &rewriter));
|
GetI64ElementsAttr(slice_stride, &rewriter));
|
||||||
|
|
||||||
|
if (dnums.collapsed_slice_dims().getNumElements() > 0) {
|
||||||
|
auto collapsed_slice_dims = llvm::to_vector<8>(llvm::map_range(
|
||||||
|
dnums.collapsed_slice_dims().getIntValues(),
|
||||||
|
[](const llvm::APInt& i) { return i.getSExtValue(); }));
|
||||||
|
llvm::SmallVector<int64_t, 8> reshape_shape;
|
||||||
|
for (int64_t i = 0; i < slice_shape.size(); ++i) {
|
||||||
|
if (llvm::count(collapsed_slice_dims, i) == 0) {
|
||||||
|
reshape_shape.push_back(slice_shape[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto reshape_type = RankedTensorType::get(reshape_shape, element_type);
|
||||||
|
result =
|
||||||
|
rewriter.create<ReshapeOp>(gather.getLoc(), reshape_type, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
result.setType(gather.getType());
|
||||||
|
rewriter.replaceOp(gather, result);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -889,9 +912,10 @@ static LogicalResult Verify(ClampOp op) {
|
|||||||
// ComplexOp
|
// ComplexOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs,
|
LogicalResult ComplexOp::inferReturnTypes(
|
||||||
Value rhs) {
|
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
|
||||||
auto type = lhs.getType();
|
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||||
|
auto type = operands[0].getType();
|
||||||
auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
|
auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
|
||||||
Type result_ty;
|
Type result_ty;
|
||||||
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
|
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
|
||||||
@ -901,8 +925,8 @@ void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs,
|
|||||||
} else {
|
} else {
|
||||||
result_ty = element_ty;
|
result_ty = element_ty;
|
||||||
}
|
}
|
||||||
|
inferredReturnTypes.push_back(result_ty);
|
||||||
build(builder, state, result_ty, lhs, rhs);
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
|
||||||
@ -932,8 +956,11 @@ Type CreateRealType(Type type) {
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) {
|
LogicalResult ImagOp::inferReturnTypes(
|
||||||
build(builder, state, CreateRealType(val.getType()), val);
|
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
|
||||||
|
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||||
|
inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
|
||||||
@ -945,8 +972,11 @@ OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
void RealOp::build(OpBuilder& builder, OperationState& state, Value val) {
|
LogicalResult RealOp::inferReturnTypes(
|
||||||
build(builder, state, CreateRealType(val.getType()), val);
|
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
|
||||||
|
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||||
|
inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
|
||||||
@ -1971,6 +2001,23 @@ struct divide<APInt> {
|
|||||||
APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
|
APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct remainder : std::modulus<T> {};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct remainder<APInt> {
|
||||||
|
APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct remainder<APFloat> {
|
||||||
|
APFloat operator()(const APFloat& a, const APFloat& b) const {
|
||||||
|
APFloat result(a);
|
||||||
|
result.remainder(b);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct max {
|
struct max {
|
||||||
T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
|
T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
|
||||||
@ -2012,6 +2059,7 @@ BINARY_FOLDER(AddOp, std::plus);
|
|||||||
BINARY_FOLDER(SubOp, std::minus);
|
BINARY_FOLDER(SubOp, std::minus);
|
||||||
BINARY_FOLDER(MulOp, std::multiplies);
|
BINARY_FOLDER(MulOp, std::multiplies);
|
||||||
BINARY_FOLDER(DivOp, divide);
|
BINARY_FOLDER(DivOp, divide);
|
||||||
|
BINARY_FOLDER(RemOp, remainder);
|
||||||
BINARY_FOLDER(MaxOp, max);
|
BINARY_FOLDER(MaxOp, max);
|
||||||
BINARY_FOLDER(MinOp, min);
|
BINARY_FOLDER(MinOp, min);
|
||||||
|
|
||||||
@ -2261,10 +2309,7 @@ void SortOp::build(OpBuilder& builder, OperationState& state,
|
|||||||
state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
|
state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
|
||||||
state.addAttribute("is_stable", builder.getBoolAttr(dimension));
|
state.addAttribute("is_stable", builder.getBoolAttr(dimension));
|
||||||
|
|
||||||
SmallVector<Type, 2> element_types;
|
for (Value operand : operands) state.addTypes(operand.getType());
|
||||||
element_types.reserve(operands.size());
|
|
||||||
for (Value operand : operands) element_types.push_back(operand.getType());
|
|
||||||
state.addTypes(builder.getTupleType(element_types));
|
|
||||||
|
|
||||||
state.addRegion();
|
state.addRegion();
|
||||||
}
|
}
|
||||||
|
@ -283,7 +283,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
|||||||
auto if_op = rewriter.create<scf::IfOp>(
|
auto if_op = rewriter.create<scf::IfOp>(
|
||||||
loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
|
loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
|
||||||
OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder();
|
OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder();
|
||||||
Value reshaped_lhs = if_lhs_scalar_builder.create<mhlo::ReshapeOp>(
|
Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
|
||||||
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
|
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
|
||||||
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
|
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
|
||||||
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
|
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
|
||||||
@ -300,7 +300,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
|||||||
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
|
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
|
||||||
if_rhs_scalar_op.getResult(0));
|
if_rhs_scalar_op.getResult(0));
|
||||||
OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder();
|
OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder();
|
||||||
Value reshaped_rhs = if_rhs_scalar_builder.create<mhlo::ReshapeOp>(
|
Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
|
||||||
loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
|
loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
|
||||||
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
|
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
|
||||||
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
|
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
|
||||||
@ -516,7 +516,7 @@ struct HloCompareAdaptor {
|
|||||||
|
|
||||||
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
||||||
OwningRewritePatternList *patterns) {
|
OwningRewritePatternList *patterns) {
|
||||||
populateWithGenerated(context, patterns);
|
populateWithGenerated(context, *patterns);
|
||||||
|
|
||||||
// Instantiate conversion templates for conforming binary elementwise ops
|
// Instantiate conversion templates for conforming binary elementwise ops
|
||||||
// that do not have different dtypes between operands and results and do
|
// that do not have different dtypes between operands and results and do
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user