Merge branch 'master' into pluggable_device_load
This commit is contained in:
commit
22d22ff5b3
6
.bazelrc
6
.bazelrc
@ -174,6 +174,12 @@ build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0
|
|||||||
build:mkl_opensource_only --define=build_with_mkl_opensource=true
|
build:mkl_opensource_only --define=build_with_mkl_opensource=true
|
||||||
build:mkl_opensource_only -c opt
|
build:mkl_opensource_only -c opt
|
||||||
|
|
||||||
|
# Config setting to build with oneDNN for Arm.
|
||||||
|
build:mkl_aarch64 --define=build_with_mkl_aarch64=true --define=enable_mkl=true
|
||||||
|
build:mkl_aarch64 --define=tensorflow_mkldnn_contraction_kernel=0
|
||||||
|
build:mkl_aarch64 --define=build_with_mkl_opensource=true
|
||||||
|
build:mkl_aarch64 -c opt
|
||||||
|
|
||||||
# This config refers to building with CUDA available. It does not necessarily
|
# This config refers to building with CUDA available. It does not necessarily
|
||||||
# mean that we build CUDA op kernels.
|
# mean that we build CUDA op kernels.
|
||||||
build:using_cuda --define=using_cuda=true
|
build:using_cuda --define=using_cuda=true
|
||||||
|
6
.github/bot_config.yml
vendored
6
.github/bot_config.yml
vendored
@ -12,12 +12,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
#
|
|
||||||
# THIS IS A GENERATED DOCKERFILE.
|
|
||||||
#
|
|
||||||
# This file was assembled from multiple pieces, whose use is documented
|
|
||||||
# throughout. Please refer to the TensorFlow dockerfiles documentation
|
|
||||||
# for more information.
|
|
||||||
|
|
||||||
# A list of assignees
|
# A list of assignees
|
||||||
assignees:
|
assignees:
|
||||||
|
28
.github/workflows/update-nightly.yml
vendored
Normal file
28
.github/workflows/update-nightly.yml
vendored
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch: # Allow manual triggers
|
||||||
|
schedule:
|
||||||
|
- cron: 0 4 * * * # 4am UTC is 9pm PDT and 8pm PST
|
||||||
|
name: Set nightly branch to master HEAD
|
||||||
|
jobs:
|
||||||
|
master-to-nightly:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: zofrex/mirror-branch@v1
|
||||||
|
name: Set nightly branch to master HEAD
|
||||||
|
with:
|
||||||
|
target-branch: 'nightly'
|
@ -4,7 +4,7 @@
|
|||||||
/tensorflow/core/common_runtime/eager @qqfish @kkimdev
|
/tensorflow/core/common_runtime/eager @qqfish @kkimdev
|
||||||
/tensorflow/core/debug @caisq
|
/tensorflow/core/debug @caisq
|
||||||
/tensorflow/core/nccl/ @azaks2 @chsigg
|
/tensorflow/core/nccl/ @azaks2 @chsigg
|
||||||
/tensorflow/core/platform/windows/ @gunan @mihaimaruseac
|
/tensorflow/core/platform/windows/ @mihaimaruseac
|
||||||
/tensorflow/lite/experimental/micro @petewarden @advaitjain
|
/tensorflow/lite/experimental/micro @petewarden @advaitjain
|
||||||
/tensorflow/python/autograph/ @mdanatg @kkimdev
|
/tensorflow/python/autograph/ @mdanatg @kkimdev
|
||||||
/tensorflow/python/debug @caisq
|
/tensorflow/python/debug @caisq
|
||||||
|
278
RELEASE.md
278
RELEASE.md
@ -34,6 +34,7 @@
|
|||||||
shape assumptions (note that you can pass shapes with `None` entries for axes
|
shape assumptions (note that you can pass shapes with `None` entries for axes
|
||||||
that are meant to be dynamic). You can also disable the input checking
|
that are meant to be dynamic). You can also disable the input checking
|
||||||
entirely by setting `model.input_spec = None`.
|
entirely by setting `model.input_spec = None`.
|
||||||
|
* TF pip packages now use CUDA11 and cuDNN 8.0.2.
|
||||||
* XLA:CPU and XLA:GPU devices are no longer registered by default. Use
|
* XLA:CPU and XLA:GPU devices are no longer registered by default. Use
|
||||||
`TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be
|
`TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be
|
||||||
removed).
|
removed).
|
||||||
@ -46,6 +47,13 @@
|
|||||||
* `tf.data.experimental.service.WorkerServer` now takes a config tuple
|
* `tf.data.experimental.service.WorkerServer` now takes a config tuple
|
||||||
instead of individual arguments. Usages should be updated to
|
instead of individual arguments. Usages should be updated to
|
||||||
`tf.data.experimental.service.WorkerServer(worker_config)`.
|
`tf.data.experimental.service.WorkerServer(worker_config)`.
|
||||||
|
* `tf.quantization.quantize_and_dequantize_v2` has been introduced, which
|
||||||
|
updates the gradient definition for quantization which is outside the range
|
||||||
|
to be 0. To simulate the V1 behavior of
|
||||||
|
tf.quantization.quantize_and_dequantize(...) use
|
||||||
|
tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...).
|
||||||
|
* `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please
|
||||||
|
use `tf.data.Dataset.from_tensor_slices` instead.
|
||||||
|
|
||||||
## Known Caveats
|
## Known Caveats
|
||||||
|
|
||||||
@ -84,17 +92,20 @@
|
|||||||
[CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199),
|
[CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199),
|
||||||
[CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200),
|
[CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200),
|
||||||
[CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201))
|
[CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201))
|
||||||
* Fixes an integer truncation vulnerability in code using the work sharder API
|
* Fixes an integer truncation vulnerability in code using the work sharder
|
||||||
|
API
|
||||||
([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
|
([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
|
||||||
* Fixes a format string vulnerability in `tf.strings.as_string`
|
* Fixes a format string vulnerability in `tf.strings.as_string`
|
||||||
([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
|
([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
|
||||||
* Fixes segfault raised by calling session-only ops in eager mode
|
* Fixes segfault raised by calling session-only ops in eager mode
|
||||||
([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
|
([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
|
||||||
* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams`
|
* Fixes data leak and potential ASLR violation from
|
||||||
|
`tf.raw_ops.StringNGrams`
|
||||||
([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
|
([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
|
||||||
* Fixes segfaults caused by incomplete `SavedModel` validation
|
* Fixes segfaults caused by incomplete `SavedModel` validation
|
||||||
([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
|
([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
|
||||||
* Fixes a data corruption due to a bug in negative indexing support in TFLite
|
* Fixes a data corruption due to a bug in negative indexing support in
|
||||||
|
TFLite
|
||||||
([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
|
([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
|
||||||
* Fixes a data corruption due to dimension mismatch in TFLite
|
* Fixes a data corruption due to dimension mismatch in TFLite
|
||||||
([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
|
([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
|
||||||
@ -107,53 +118,56 @@
|
|||||||
[CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
|
[CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
|
||||||
[CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
|
[CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
|
||||||
* TF Core:
|
* TF Core:
|
||||||
* `tf.types.experimental.TensorLike` is a new `Union` type that can be used as
|
* `tf.types.experimental.TensorLike` is a new `Union` type that can be
|
||||||
type annotation for variables representing a Tensor or a value that can be
|
used as type annotation for variables representing a Tensor or a value
|
||||||
converted to Tensor by `tf.convert_to_tensor`.
|
that can be converted to Tensor by `tf.convert_to_tensor`.
|
||||||
* Calling ops with a python constants or numpy values is now consistent with
|
* Calling ops with Python constants or NumPy values is now consistent
|
||||||
tf.convert_to_tensor behavior. This avoids operations like tf.reshape
|
with tf.convert_to_tensor behavior. This avoids operations like
|
||||||
truncating inputs such as from int64 to int32.
|
tf.reshape truncating inputs such as from int64 to int32.
|
||||||
* Added `tf.sparse.map_values` to apply a function to the `.value`s of `SparseTensror` arguments.
|
* Added `tf.sparse.map_values` to apply a function to the `.value`s of
|
||||||
* The Python bitwise operators for `Tensor` (`__and__`, `__or__`, `__xor__`
|
`SparseTensor` arguments.
|
||||||
and `__invert__` now support non-`bool` arguments and apply the
|
* The Python bitwise operators for `Tensor` (`__and__`, `__or__`,
|
||||||
corresponding bitwise ops. `bool` arguments continue to be supported and
|
`__xor__` and `__invert__`) now support non-`bool` arguments and apply
|
||||||
dispatch to logical ops. This brings them more in line with Python and NumPy
|
the corresponding bitwise ops. `bool` arguments continue to be supported
|
||||||
benavior.
|
and dispatch to logical ops. This brings them more in line with Python
|
||||||
* Added `tf.SparseTensor.with_values`. This returns a new SparseTensor with
|
and NumPy behavior.
|
||||||
the same sparsity pattern, but with new provided values. It is similar to
|
* Added `tf.SparseTensor.with_values`. This returns a new SparseTensor
|
||||||
the `with_values` function of `RaggedTensor`.
|
with the same sparsity pattern, but with new provided values. It is
|
||||||
* Added `StatelessCase` op, and uses it if none of case branches has stateful ops.
|
similar to the `with_values` function of `RaggedTensor`.
|
||||||
* Added `tf.config.experimental.get_memory_usage` to return total memory usage
|
* Added `StatelessCase` op, and uses it if none of the case branches has
|
||||||
of the device.
|
stateful ops.
|
||||||
|
* Added `tf.config.experimental.get_memory_usage` to return total memory
|
||||||
|
usage of the device.
|
||||||
* `tf.data`:
|
* `tf.data`:
|
||||||
* tf.data service:
|
* tf.data service:
|
||||||
* Added new `tf.data.experimental.service.register_dataset` and
|
* Added new `tf.data.experimental.service.register_dataset` and
|
||||||
`tf.data.experimental.service.from_dataset_id` APIs to enable one process
|
`tf.data.experimental.service.from_dataset_id` APIs to enable one
|
||||||
to register a dataset with the tf.data service, and another process to
|
process to register a dataset with the tf.data service, and another
|
||||||
consume data from the dataset.
|
process to consume data from the dataset.
|
||||||
* Added support for dispatcher fault tolerance. To enable fault tolerance,
|
* Added support for dispatcher fault tolerance. To enable fault tolerance,
|
||||||
configure a `work_dir` when running your dispatcher server and set
|
configure a `work_dir` when running your dispatcher server and set
|
||||||
`dispatcher_fault_tolerance=True`. The dispatcher will store its state to
|
`dispatcher_fault_tolerance=True`. The dispatcher will store its state
|
||||||
`work_dir`, so that on restart it can continue from its previous state
|
to `work_dir`, so that on restart it can continue from its previous
|
||||||
after restart.
|
state.
|
||||||
* Added support for sharing dataset graphs via shared filesystem instead of
|
* Added support for sharing dataset graphs via shared filesystem instead
|
||||||
over RPC. This reduces load on the dispatcher, improving performance of
|
of over RPC. This reduces load on the dispatcher, improving performance
|
||||||
distributing datasets. For this to work, the dispatcher's `work_dir` must
|
of distributing datasets. For this to work, the dispatcher's `work_dir`
|
||||||
be accessible from workers. If the worker fails to read from the
|
must be accessible from workers. If the worker fails to read from the
|
||||||
`work_dir`, it falls back to using RPC for dataset graph transfer.
|
`work_dir`, it falls back to using RPC for dataset graph transfer.
|
||||||
* Added support for a new "distributed_epoch" processing mode. This
|
* Added support for a new "distributed_epoch" processing mode. This
|
||||||
processing mode distributes a dataset across all tf.data workers, instead
|
processing mode distributes a dataset across all tf.data workers,
|
||||||
of having each worker process the full dataset. See
|
instead of having each worker process the full dataset. See
|
||||||
[the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode)
|
[the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode)
|
||||||
to learn more.
|
to learn more.
|
||||||
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
|
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
|
||||||
the complement of `select_cols`; at most one of these should be specified.
|
the complement of `select_cols`; at most one of these should be
|
||||||
|
specified.
|
||||||
* We have implemented an optimization which reorders data-discarding
|
* We have implemented an optimization which reorders data-discarding
|
||||||
transformations such as `take` and `shard` to happen earlier in the
|
transformations such as `take` and `shard` to happen earlier in the
|
||||||
dataset when it is safe to do so. The optimization can be disabled via
|
dataset when it is safe to do so. The optimization can be disabled via
|
||||||
the `experimental_optimization.reorder_data_discarding_ops` dataset
|
the `experimental_optimization.reorder_data_discarding_ops` dataset
|
||||||
option.
|
option.
|
||||||
* `tf.data.Options` were previously immutable and can now be overriden.
|
* `tf.data.Options` were previously immutable and can now be overridden.
|
||||||
* `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors
|
* `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors
|
||||||
with a new `output_signature` argument, which allows `from_generator` to
|
with a new `output_signature` argument, which allows `from_generator` to
|
||||||
produce any type describable by a `tf.TypeSpec`.
|
produce any type describable by a `tf.TypeSpec`.
|
||||||
@ -162,23 +176,32 @@
|
|||||||
* `tf.image`:
|
* `tf.image`:
|
||||||
* Added deterministic `tf.image.stateless_random_*` functions for each
|
* Added deterministic `tf.image.stateless_random_*` functions for each
|
||||||
`tf.image.random_*` function. Added a new op
|
`tf.image.random_*` function. Added a new op
|
||||||
`stateless_sample_distorted_bounding_box` which is a determinstic
|
`stateless_sample_distorted_bounding_box` which is a deterministic
|
||||||
version of `sample_distorted_bounding_box` op. Given the same seed, these
|
version of `sample_distorted_bounding_box` op. Given the same seed,
|
||||||
stateless functions/ops produce the same results independent of how many
|
these stateless functions/ops produce the same results independent of
|
||||||
times the function is called, and independent of global seed settings.
|
how many times the function is called, and independent of global seed
|
||||||
|
settings.
|
||||||
* `tf.distribute`:
|
* `tf.distribute`:
|
||||||
* <ADD RELEASE NOTES HERE>
|
* <ADD RELEASE NOTES HERE>
|
||||||
* `tf.keras`:
|
* `tf.keras`:
|
||||||
* Improvements from the functional API refactoring:
|
* Improvements from the functional API refactoring:
|
||||||
* Functional model construction does not need to maintain a global workspace graph, removing memory leaks especially when building many models or very large models.
|
* Functional model construction does not need to maintain a global
|
||||||
|
workspace graph, removing memory leaks especially when building many
|
||||||
|
models or very large models.
|
||||||
* Functional model construction should be ~8-10% faster on average.
|
* Functional model construction should be ~8-10% faster on average.
|
||||||
* Functional models can now contain non-symbolic values in their call inputs inside of the first positional argument.
|
* Functional models can now contain non-symbolic values in their call
|
||||||
* Several classes of TF ops that were not reliably converted to Keras layers during functional API construction should now work, e.g. `tf.image.ssim_multiscale`
|
inputs inside of the first positional argument.
|
||||||
* Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
|
* Several classes of TF ops that were not reliably converted to Keras
|
||||||
|
layers during functional API construction should now work, e.g.
|
||||||
|
`tf.image.ssim_multiscale`
|
||||||
|
* Error messages when Functional API construction goes wrong (and when
|
||||||
|
ops cannot be converted to Keras layers automatically) should be
|
||||||
|
clearer and easier to understand.
|
||||||
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
|
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
|
||||||
as an alternative to accepting a `callable` loss.
|
as an alternative to accepting a `callable` loss.
|
||||||
* Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
|
* Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
|
||||||
to match FTRL paper (https://research.google.com/pubs/archive/41159.pdf).
|
to match FTRL paper
|
||||||
|
(https://research.google.com/pubs/archive/41159.pdf).
|
||||||
* Added `mobilenet_v3` to keras application model.
|
* Added `mobilenet_v3` to keras application model.
|
||||||
* `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
|
* `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
|
||||||
customization of how gradients are aggregated across devices, as well as
|
customization of how gradients are aggregated across devices, as well as
|
||||||
@ -191,15 +214,25 @@
|
|||||||
`fit()`. Running multiple batches inside a single `tf.function` call can
|
`fit()`. Running multiple batches inside a single `tf.function` call can
|
||||||
greatly improve performance on TPUs or small models with a large Python
|
greatly improve performance on TPUs or small models with a large Python
|
||||||
overhead.
|
overhead.
|
||||||
|
* Improvements to Keras preprocessing layers:
|
||||||
|
* TextVectorization can now accept a vocabulary list or file as an
|
||||||
|
init arg.
|
||||||
|
* Normalization can now accept mean and variance values as init args.
|
||||||
|
* In `Attention` and `AdditiveAttention` layers, the `call()` method now
|
||||||
|
accepts a `return_attention_scores` argument. When set to
|
||||||
|
True, the layer returns the attention scores as an additional output
|
||||||
|
argument.
|
||||||
|
* Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints
|
||||||
|
with the same implementation as their `tf.losses` equivalent.
|
||||||
* `tf.function` / AutoGraph:
|
* `tf.function` / AutoGraph:
|
||||||
* Added `experimental_follow_type_hints` argument for `tf.function`. When
|
* Added `experimental_follow_type_hints` argument for `tf.function`. When
|
||||||
True, the function may use type annotations to optimize the tracing
|
True, the function may use type annotations to optimize the tracing
|
||||||
performance.
|
performance.
|
||||||
* Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
|
* Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
|
||||||
* AutoGraph now allows creating new symbols inside a TensorFlow loop, if
|
* AutoGraph now allows creating new symbols inside a TensorFlow loop, if
|
||||||
the values of these symbols at an iteration does not depend on the previous
|
the values of these symbols at an iteration do not depend on the
|
||||||
iteration. These types of loops must run at least one iteration, and will
|
previous iteration. These types of loops must run at least one
|
||||||
raise a runtime error otherwise.
|
iteration, and will raise a runtime error otherwise.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@ -208,51 +241,97 @@
|
|||||||
outputs = train_step(batch)
|
outputs = train_step(batch)
|
||||||
tf.print('final outputs', outputs)
|
tf.print('final outputs', outputs)
|
||||||
```
|
```
|
||||||
|
|
||||||
See tensorflow/python/autograph/g3doc/reference/limitations.md for more
|
See tensorflow/python/autograph/g3doc/reference/limitations.md for more
|
||||||
info.
|
info.
|
||||||
|
|
||||||
* `tf.lite`:
|
* `tf.lite`:
|
||||||
* `DynamicBuffer::AddJoinedString()` will now add a separator if the first
|
|
||||||
string to be joined is empty.
|
|
||||||
* `TFLiteConverter`:
|
* `TFLiteConverter`:
|
||||||
* Support optional flags `inference_input_type` and `inference_output_type` for full integer quantized models. This allows users to modify the model input and output type to integer types (`tf.int8`, `tf.uint8`) instead of defaulting to float type (`tf.float32`).
|
* Support optional flags `inference_input_type` and
|
||||||
* Deprecate `Interpreter::UseNNAPI(bool)` C++ API
|
`inference_output_type` for full integer quantized models. This
|
||||||
* Prefer using `NnApiDelegate()` and related delegate configuration methods directly.
|
allows users to modify the model input and output type to integer
|
||||||
* Add NNAPI Delegation support for requantization use cases by converting the operation into a dequantize-quantize pair.
|
types (`tf.int8`, `tf.uint8`) instead of defaulting to float type
|
||||||
|
(`tf.float32`).
|
||||||
* TFLite Profiler for Android is available. See the detailed
|
* TFLite Profiler for Android is available. See the detailed
|
||||||
[guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android).
|
[guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android).
|
||||||
|
* NNAPI
|
||||||
|
* Added NNAPI Delegation support for requantization use cases by
|
||||||
|
converting the operation into a dequantize-quantize pair.
|
||||||
|
* Removed deprecated `Interpreter.setUseNNAPI(boolean)` Java API.
|
||||||
|
* Use `Interpreter.Options.setUseNNAPI` instead.
|
||||||
|
* Deprecate `Interpreter::UseNNAPI(bool)` C++ API.
|
||||||
|
* Use `NnApiDelegate()` and related delegate configuration methods
|
||||||
|
directly.
|
||||||
|
* Deprecate `Interpreter::SetAllowFp16PrecisionForFp32(bool)` C++ API
|
||||||
|
* Prefer controlling this via delegate options, e.g.
|
||||||
|
`tflite::StatefulNnApiDelegate::Options::allow_fp16` or
|
||||||
|
`TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
|
||||||
|
* `DynamicBuffer::AddJoinedString()` will now add a separator if the first
|
||||||
|
string to be joined is empty.
|
||||||
* <ADD RELEASE NOTES HERE>
|
* <ADD RELEASE NOTES HERE>
|
||||||
|
|
||||||
* `tf.random`:
|
* `tf.random`:
|
||||||
|
|
||||||
* <ADD RELEASE NOTES HERE>
|
* <ADD RELEASE NOTES HERE>
|
||||||
|
|
||||||
* Math and Linear Algebra:
|
* Math and Linear Algebra:
|
||||||
|
|
||||||
* <ADD RELEASE NOTES HERE>
|
* <ADD RELEASE NOTES HERE>
|
||||||
|
|
||||||
* TPU Enhancements:
|
* TPU Enhancements:
|
||||||
|
|
||||||
* Added support for the `beta` parameter of the FTRL optimizer for TPU
|
* Added support for the `beta` parameter of the FTRL optimizer for TPU
|
||||||
embeddings. Users of other TensorFlow platforms can implement equivalent
|
embeddings. Users of other TensorFlow platforms can implement equivalent
|
||||||
behavior by adjusting the `l2` parameter.
|
behavior by adjusting the `l2` parameter.
|
||||||
* <ADD RELEASE NOTES HERE>
|
* <ADD RELEASE NOTES HERE>
|
||||||
|
|
||||||
* XLA Support:
|
* XLA Support:
|
||||||
|
|
||||||
* xla.experimental.compile is deprecated, use
|
* xla.experimental.compile is deprecated, use
|
||||||
`tf.function(experimental_compile=True)` instead
|
`tf.function(experimental_compile=True)` instead
|
||||||
* Added `tf.function.experimental_get_compiler_ir` which returns compiler IR
|
* Added `tf.function.experimental_get_compiler_ir` which returns compiler
|
||||||
(currently 'hlo' and 'optimized_hlo') for given input for given function.
|
IR (currently 'hlo' and 'optimized_hlo') for given input for given
|
||||||
|
function.
|
||||||
* <ADD RELEASE NOTES HERE>
|
* <ADD RELEASE NOTES HERE>
|
||||||
|
|
||||||
* Tracing and Debugging:
|
* Tracing and Debugging:
|
||||||
|
|
||||||
* <ADD RELEASE NOTES HERE>
|
* <ADD RELEASE NOTES HERE>
|
||||||
|
|
||||||
* `tf.train.Checkpoint`:
|
* `tf.train.Checkpoint`:
|
||||||
|
|
||||||
* Now accepts a `root` argument in the initialization, which generates a
|
* Now accepts a `root` argument in the initialization, which generates a
|
||||||
checkpoint with a root object. This allows users to create a `Checkpoint`
|
checkpoint with a root object. This allows users to create a
|
||||||
object that is compatible with Keras `model.save_weights()` and
|
`Checkpoint` object that is compatible with Keras `model.save_weights()`
|
||||||
`model.load_weights`. The checkpoint is also compatible with the
|
and `model.load_weights`. The checkpoint is also compatible with the
|
||||||
checkpoint saved in the `variables/` folder in the SavedModel.
|
checkpoint saved in the `variables/` folder in the SavedModel.
|
||||||
* When restoring, `save_path` can be a path to a SavedModel. The function
|
* When restoring, `save_path` can be a path to a SavedModel. The function
|
||||||
will automatically find the checkpoint in the SavedModel.
|
will automatically find the checkpoint in the SavedModel.
|
||||||
|
|
||||||
* `tf.nn`:
|
* `tf.nn`:
|
||||||
|
|
||||||
* `tf.nn.max_pool2d` now supports explicit padding.
|
* `tf.nn.max_pool2d` now supports explicit padding.
|
||||||
|
|
||||||
|
* `tf.debugging`:
|
||||||
|
|
||||||
|
* `tf.debugging.assert_shapes()` now works on `SparseTensor`s (#36268).
|
||||||
|
|
||||||
|
* `tf.print`:
|
||||||
|
|
||||||
|
* Bug fix in `tf.print()` with `OrderedDict` where if an `OrderedDict`
|
||||||
|
didn't have the keys sorted, the keys and values were not being printed
|
||||||
|
in accordance with their correct mapping.
|
||||||
|
|
||||||
* Other:
|
* Other:
|
||||||
|
|
||||||
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
|
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
|
||||||
and "denylist" where possible. Please see
|
and "denylist" where possible. Please see
|
||||||
https://developers.google.com/style/word-list#blacklist for more context.
|
https://developers.google.com/style/word-list#blacklist for more
|
||||||
<ADD RELEASE NOTES HERE>
|
context.
|
||||||
|
* Add `tf.config.experimental.mlir_bridge_rollout` which will help us
|
||||||
|
roll out the new MLIR TPU bridge.
|
||||||
|
* <ADD RELEASE NOTES HERE>
|
||||||
|
|
||||||
## Thanks to our Contributors
|
## Thanks to our Contributors
|
||||||
|
|
||||||
@ -500,42 +579,87 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
|
|||||||
# Release 2.3.0
|
# Release 2.3.0
|
||||||
|
|
||||||
## Major Features and Improvements
|
## Major Features and Improvements
|
||||||
* `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and save resources:
|
|
||||||
|
* `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and
|
||||||
|
save resources:
|
||||||
|
|
||||||
* [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot)
|
* [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot)
|
||||||
* [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service).
|
* [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service).
|
||||||
|
|
||||||
In addition, check out the detailed [guide](https://www.tensorflow.org/guide/data_performance_analysis) for analyzing input pipeline performance with TF Profiler.
|
In addition, check out the detailed
|
||||||
|
[guide](https://www.tensorflow.org/guide/data_performance_analysis) for
|
||||||
|
analyzing input pipeline performance with TF Profiler.
|
||||||
|
|
||||||
* [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy) is now a stable API and no longer considered experimental for TensorFlow. (earlier `tf.distribute.experimental.TPUStrategy`).
|
* [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy)
|
||||||
|
is now a stable API and no longer considered experimental for TensorFlow.
|
||||||
|
(earlier `tf.distribute.experimental.TPUStrategy`).
|
||||||
|
|
||||||
* [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new tools: a memory profiler to visualize your model’s memory usage over time and a [python tracer](https://www.tensorflow.org/guide/profiler#events) which allows you to trace python function calls in your model. Usability improvements include better diagnostic messages and [profile options](https://tensorflow.org/guide/profiler#collect_performance_data) to customize the host and device trace verbosity level.
|
* [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new
|
||||||
|
tools: a memory profiler to visualize your model’s memory usage over time
|
||||||
|
and a [python tracer](https://www.tensorflow.org/guide/profiler#events)
|
||||||
|
which allows you to trace python function calls in your model. Usability
|
||||||
|
improvements include better diagnostic messages and
|
||||||
|
[profile options](https://tensorflow.org/guide/profiler#collect_performance_data)
|
||||||
|
to customize the host and device trace verbosity level.
|
||||||
|
|
||||||
* Introduces experimental support for Keras Preprocessing Layers API ([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly)) to handle data preprocessing operations, with support for composite tensor inputs. Please see below for additional details on these layers.
|
* Introduces experimental support for Keras Preprocessing Layers API
|
||||||
|
([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly))
|
||||||
|
to handle data preprocessing operations, with support for composite tensor
|
||||||
|
inputs. Please see below for additional details on these layers.
|
||||||
|
|
||||||
* TFLite now properly supports dynamic shapes during conversion and inference. We’ve also added opt-in support on Android and iOS for [XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack), a highly optimized set of CPU kernels, as well as opt-in support for [executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental).
|
* TFLite now properly supports dynamic shapes during conversion and inference.
|
||||||
|
We’ve also added opt-in support on Android and iOS for
|
||||||
|
[XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack),
|
||||||
|
a highly optimized set of CPU kernels, as well as opt-in support for
|
||||||
|
[executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental).
|
||||||
|
|
||||||
* Libtensorflow packages are available in GCS starting this release. We have also started to [release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds).
|
* Libtensorflow packages are available in GCS starting this release. We have
|
||||||
|
also started to
|
||||||
|
[release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds).
|
||||||
|
|
||||||
* The experimental Python API [`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info) now allows you to instrument a TensorFlow program and dump debugging information to a directory on the file system. The directory can be read and visualized by a new interactive dashboard in TensorBoard 2.3 called [Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which reveals the details of the TensorFlow program including graph structures, history of op executions at the Python (eager) and intra-graph levels, the runtime dtype, shape, and numerical composistion of tensors, as well as their code locations.
|
* The experimental Python API
|
||||||
|
[`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info)
|
||||||
|
now allows you to instrument a TensorFlow program and dump debugging
|
||||||
|
information to a directory on the file system. The directory can be read and
|
||||||
|
visualized by a new interactive dashboard in TensorBoard 2.3 called
|
||||||
|
[Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which
|
||||||
|
reveals the details of the TensorFlow program including graph structures,
|
||||||
|
history of op executions at the Python (eager) and intra-graph levels, the
|
||||||
|
runtime dtype, shape, and numerical composition of tensors, as well as their
|
||||||
|
code locations.
|
||||||
|
|
||||||
## Breaking Changes
|
## Breaking Changes
|
||||||
|
|
||||||
* Increases the **minimum bazel version** required to build TF to **3.1.0**.
|
* Increases the **minimum bazel version** required to build TF to **3.1.0**.
|
||||||
* `tf.data`
|
* `tf.data`
|
||||||
* Makes the following (breaking) changes to `tf.data`.
|
* Makes the following (breaking) changes to `tf.data`.
|
||||||
* C++ API: - `IteratorBase::RestoreInternal`, `IteratorBase::SaveInternal`, and `DatasetBase::CheckExternalState` become pure-virtual and subclasses are now expected to provide an implementation.
|
* C++ API: - `IteratorBase::RestoreInternal`,
|
||||||
* The deprecated `DatasetBase::IsStateful` method is removed in favor of `DatasetBase::CheckExternalState`.
|
`IteratorBase::SaveInternal`, and `DatasetBase::CheckExternalState`
|
||||||
* Deprecated overrides of `DatasetBase::MakeIterator` and `MakeIteratorFromInputElement` are removed.
|
become pure-virtual and subclasses are now expected to provide an
|
||||||
* The signature of `tensorflow::data::IteratorBase::SaveInternal` and `tensorflow::data::IteratorBase::SaveInput` has been extended with `SerializationContext` argument to enable overriding the default policy for handling external state during iterator checkpointing. This is not a backwards compatible change and all subclasses of `IteratorBase` *need to be updated* accordingly.
|
implementation.
|
||||||
|
* The deprecated `DatasetBase::IsStateful` method is removed in favor of
|
||||||
|
`DatasetBase::CheckExternalState`.
|
||||||
|
* Deprecated overrides of `DatasetBase::MakeIterator` and
|
||||||
|
`MakeIteratorFromInputElement` are removed.
|
||||||
|
* The signature of `tensorflow::data::IteratorBase::SaveInternal` and
|
||||||
|
`tensorflow::data::IteratorBase::SaveInput` has been extended with
|
||||||
|
`SerializationContext` argument to enable overriding the default policy
|
||||||
|
for handling external state during iterator checkpointing. This is
|
||||||
|
not a backwards compatible change and all subclasses of `IteratorBase`
|
||||||
|
*need to be updated* accordingly.
|
||||||
* `tf.keras`
|
* `tf.keras`
|
||||||
* Add a new `BackupAndRestore` callback for handling distributed training failures & restarts. Please take a look at this [tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for details on how to use the callback.
|
* Add a new `BackupAndRestore` callback for handling distributed training
|
||||||
|
failures & restarts. Please take a look at this
|
||||||
|
[tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
|
||||||
|
for details on how to use the callback.
|
||||||
* `tf.image.extract_glimpse` has been updated to correctly process the case
|
* `tf.image.extract_glimpse` has been updated to correctly process the case
|
||||||
where `centered=False` and `normalized=False`. This is a breaking change as
|
where `centered=False` and `normalized=False`. This is a breaking change as
|
||||||
the output is different from (incorrect) previous versions. Note this
|
the output is different from (incorrect) previous versions. Note this
|
||||||
breaking change only impacts `tf.image.extract_glimpse` and
|
breaking change only impacts `tf.image.extract_glimpse` and
|
||||||
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
|
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
|
||||||
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
|
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
|
||||||
exsiting C++ kernel `ExtractGlimpse` does not change either, so saved
|
existing C++ kernel `ExtractGlimpse` does not change either, so saved models
|
||||||
models using `tf.raw_ops.ExtractGlimpse` will not be impacted.
|
using `tf.raw_ops.ExtractGlimpse` will not be impacted.
|
||||||
|
|
||||||
## Known Caveats
|
## Known Caveats
|
||||||
* `tf.lite`
|
* `tf.lite`
|
||||||
@ -1525,8 +1649,8 @@ If you experience any snags when using TF 2.0, please let us know at the [TF 2.0
|
|||||||
conversion. TensorRT initialization arguments are now passed wrapped in
|
conversion. TensorRT initialization arguments are now passed wrapped in
|
||||||
a named-tuple, `TrtConversionParams`, rather than as separate arguments
|
a named-tuple, `TrtConversionParams`, rather than as separate arguments
|
||||||
as in `TrtGraphConverter`.
|
as in `TrtGraphConverter`.
|
||||||
* Changed API to optimize TensorRT enginges during graph optimization.
|
* Changed API to optimize TensorRT engines during graph optimization. This
|
||||||
This is now done by calling `converter.build()` where previously
|
is now done by calling `converter.build()` where previously
|
||||||
`is_dynamic_op=False` would be set.
|
`is_dynamic_op=False` would be set.
|
||||||
* `converter.convert()` no longer returns a `tf.function`. Now the
|
* `converter.convert()` no longer returns a `tf.function`. Now the
|
||||||
function must be accessed from the saved model.
|
function must be accessed from the saved model.
|
||||||
|
@ -1485,6 +1485,7 @@ def main():
|
|||||||
'adding "--config=<>" to your build command. See .bazelrc for more '
|
'adding "--config=<>" to your build command. See .bazelrc for more '
|
||||||
'details.')
|
'details.')
|
||||||
config_info_line('mkl', 'Build with MKL support.')
|
config_info_line('mkl', 'Build with MKL support.')
|
||||||
|
config_info_line('mkl_aarch64', 'Build with oneDNN support for Aarch64.')
|
||||||
config_info_line('monolithic', 'Config for mostly static monolithic build.')
|
config_info_line('monolithic', 'Config for mostly static monolithic build.')
|
||||||
config_info_line('ngraph', 'Build with Intel nGraph support.')
|
config_info_line('ngraph', 'Build with Intel nGraph support.')
|
||||||
config_info_line('numa', 'Build with NUMA support.')
|
config_info_line('numa', 'Build with NUMA support.')
|
||||||
|
@ -568,17 +568,7 @@ selects.config_setting_group(
|
|||||||
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
|
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
|
||||||
package_group(
|
package_group(
|
||||||
name = "internal",
|
name = "internal",
|
||||||
packages = [
|
packages = ["//tensorflow/..."],
|
||||||
"//learning/brain/distribute/...",
|
|
||||||
"//learning/brain/swift/x10/...",
|
|
||||||
"//perftools/accelerators/xprof/api/...",
|
|
||||||
"//tensorflow/...",
|
|
||||||
"//tensorflow_estimator/python/estimator/...",
|
|
||||||
"//tensorflow_models/official/...",
|
|
||||||
"//third_party/py/autograph/...",
|
|
||||||
"//third_party/swift/tensorflow/x10/...",
|
|
||||||
"//third_party/swift/tensorflow_apis/...",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
package_group(
|
package_group(
|
||||||
@ -588,10 +578,8 @@ package_group(
|
|||||||
|
|
||||||
# Packages that use private types symbols, until they are exported.
|
# Packages that use private types symbols, until they are exported.
|
||||||
# TODO(b/154650521) Remove.
|
# TODO(b/154650521) Remove.
|
||||||
package_group(
|
# If this is modified, then copy.bara.sky must also be modified.
|
||||||
name = "types_whitelist",
|
package_group(name = "types_whitelist")
|
||||||
packages = ["//learning/deepmind/tensorflow/replicator/..."],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Packages that use StructuredTensors.
|
# Packages that use StructuredTensors.
|
||||||
# TODO(b/159007891) Remove this package once StructuredTensor is exported.
|
# TODO(b/159007891) Remove this package once StructuredTensor is exported.
|
||||||
@ -719,7 +707,7 @@ tf_cc_shared_object(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||||
"//tensorflow/cc/saved_model:loader_lite_impl",
|
"//tensorflow/cc/saved_model:loader_lite_impl",
|
||||||
"//tensorflow/core:core_cpu_impl",
|
"//tensorflow/core/common_runtime:core_cpu_impl",
|
||||||
"//tensorflow/core:framework_internal_impl",
|
"//tensorflow/core:framework_internal_impl",
|
||||||
"//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
|
"//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
|
||||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
|
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
|
||||||
|
@ -217,6 +217,8 @@ tf_cuda_library(
|
|||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core/distributed_runtime:server_lib",
|
"//tensorflow/core/distributed_runtime:server_lib",
|
||||||
"//tensorflow/core/kernels:logging_ops",
|
"//tensorflow/core/kernels:logging_ops",
|
||||||
|
"//tensorflow/compiler/mlir/tfr:node_expansion_pass",
|
||||||
|
"//tensorflow/compiler/mlir/tfr:graph_decompose_pass",
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
@ -254,6 +256,30 @@ tf_cuda_library(
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tf_shape",
|
||||||
|
srcs = ["tf_shape.cc"],
|
||||||
|
hdrs = ["tf_shape.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":c_api_macros",
|
||||||
|
":tf_shape_internal",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tf_shape_internal",
|
||||||
|
hdrs = ["tf_shape_internal.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = ["//tensorflow:internal"],
|
||||||
|
deps = [
|
||||||
|
":conversion_macros",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tf_status",
|
name = "tf_status",
|
||||||
srcs = ["tf_status.cc"],
|
srcs = ["tf_status.cc"],
|
||||||
|
@ -2488,6 +2488,48 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
|
||||||
|
TF_Status* status) {
|
||||||
|
using tensorflow::RecordMutation;
|
||||||
|
mutex_lock l(graph->mu);
|
||||||
|
tensorflow::shape_inference::InferenceContext* ic =
|
||||||
|
graph->refiner.GetContext(&new_src.oper->node);
|
||||||
|
|
||||||
|
if (ic->num_outputs() <= new_src.index) {
|
||||||
|
status->status = tensorflow::errors::OutOfRange(
|
||||||
|
"Cannot update edge. Output index [", new_src.index,
|
||||||
|
"] is greater than the number of total outputs [", ic->num_outputs(),
|
||||||
|
"].");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
|
||||||
|
|
||||||
|
tensorflow::shape_inference::InferenceContext* ic_dst =
|
||||||
|
graph->refiner.GetContext(&dst.oper->node);
|
||||||
|
if (ic_dst->num_inputs() <= dst.index) {
|
||||||
|
status->status = tensorflow::errors::OutOfRange(
|
||||||
|
"Cannot update edge. Input index [", dst.index,
|
||||||
|
"] is greater than the number of total inputs [", ic_dst->num_inputs(),
|
||||||
|
"].");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!ic_dst->MergeInput(dst.index, shape)) {
|
||||||
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
|
||||||
|
" and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
|
||||||
|
&dst.oper->node, dst.index);
|
||||||
|
|
||||||
|
if (TF_GetCode(status) == TF_OK) {
|
||||||
|
// This modification only updates the destination node for
|
||||||
|
// the purposes of running this graph in a session. Thus, we don't
|
||||||
|
// record the source node as being modified.
|
||||||
|
RecordMutation(graph, *dst.oper, "updating input tensor");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TF_Server functions ----------------------------------------------
|
// TF_Server functions ----------------------------------------------
|
||||||
|
|
||||||
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
|
||||||
|
@ -1524,6 +1524,10 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status);
|
|||||||
TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp(
|
TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp(
|
||||||
const char* name, TF_Status* status);
|
const char* name, TF_Status* status);
|
||||||
|
|
||||||
|
// Updates an edge: replaces the tensor feeding input `dst` with output `new_src`.
|
||||||
|
TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src,
|
||||||
|
TF_Input dst, TF_Status* status);
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
// In-process TensorFlow server functionality, for use in distributed training.
|
// In-process TensorFlow server functionality, for use in distributed training.
|
||||||
// A Server instance encapsulates a set of devices and a Session target that
|
// A Server instance encapsulates a set of devices and a Session target that
|
||||||
|
@ -634,6 +634,40 @@ TEST(CAPI, Graph) {
|
|||||||
TF_DeleteStatus(s);
|
TF_DeleteStatus(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, UpdateEdge) {
|
||||||
|
TF_Status* s = TF_NewStatus();
|
||||||
|
TF_Graph* graph = TF_NewGraph();
|
||||||
|
|
||||||
|
// Make two scalar constants.
|
||||||
|
TF_Operation* one = ScalarConst(1, graph, s, "one");
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
|
TF_Operation* two = ScalarConst(2, graph, s, "two");
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
|
// Add oper.
|
||||||
|
TF_Operation* add = Add(one, two, graph, s, "add");
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
|
// Add another oper to the graph.
|
||||||
|
TF_Operation* neg = Neg(add, graph, s, "neg");
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
|
NodeDef node_def_neg;
|
||||||
|
ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
|
||||||
|
EXPECT_EQ(string("add"), node_def_neg.input(0));
|
||||||
|
|
||||||
|
// update edge of neg
|
||||||
|
TF_UpdateEdge(graph, TF_Output{one, 0}, TF_Input{neg, 0}, s);
|
||||||
|
|
||||||
|
ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
|
||||||
|
EXPECT_EQ(string("one:0"), node_def_neg.input(0));
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
TF_DeleteGraph(graph);
|
||||||
|
TF_DeleteStatus(s);
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
TODO(skyewm): this test currently DCHECKs, change to bad status
|
TODO(skyewm): this test currently DCHECKs, change to bad status
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||||
load(
|
load(
|
||||||
"//tensorflow:tensorflow.bzl",
|
"//tensorflow:tensorflow.bzl",
|
||||||
"if_tpu",
|
"if_libtpu",
|
||||||
"tf_cc_test",
|
"tf_cc_test",
|
||||||
"tf_copts",
|
"tf_copts",
|
||||||
"tf_cuda_cc_test",
|
"tf_cuda_cc_test",
|
||||||
@ -116,7 +116,6 @@ filegroup(
|
|||||||
"immediate_execution_context.h",
|
"immediate_execution_context.h",
|
||||||
"immediate_execution_operation.h",
|
"immediate_execution_operation.h",
|
||||||
"immediate_execution_tensor_handle.h",
|
"immediate_execution_tensor_handle.h",
|
||||||
"mnist_gradients_testutil.h",
|
|
||||||
"tape.h",
|
"tape.h",
|
||||||
"tfe_cancellation_manager_internal.h",
|
"tfe_cancellation_manager_internal.h",
|
||||||
"tfe_context_internal.h",
|
"tfe_context_internal.h",
|
||||||
@ -290,7 +289,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/lib/llvm_rtti",
|
"//tensorflow/core/lib/llvm_rtti",
|
||||||
] + if_tpu(
|
] + if_libtpu(
|
||||||
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
|
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
|
||||||
if_true = [],
|
if_true = [],
|
||||||
),
|
),
|
||||||
@ -314,6 +313,7 @@ cc_library(
|
|||||||
":gradients_internal",
|
":gradients_internal",
|
||||||
":gradients_util",
|
":gradients_util",
|
||||||
":tape",
|
":tape",
|
||||||
|
"//tensorflow/c/experimental/gradients/tape:tape_context",
|
||||||
"//tensorflow/c/experimental/ops:array_ops",
|
"//tensorflow/c/experimental/ops:array_ops",
|
||||||
"//tensorflow/c/experimental/ops:math_ops",
|
"//tensorflow/c/experimental/ops:math_ops",
|
||||||
"//tensorflow/c/experimental/ops:nn_ops",
|
"//tensorflow/c/experimental/ops:nn_ops",
|
||||||
@ -354,7 +354,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/lib/llvm_rtti",
|
"//tensorflow/core/lib/llvm_rtti",
|
||||||
] + if_tpu(
|
] + if_libtpu(
|
||||||
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
|
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
|
||||||
if_true = [],
|
if_true = [],
|
||||||
),
|
),
|
||||||
|
@ -39,7 +39,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||||
#include "tensorflow/c/tf_tensor_internal.h"
|
#include "tensorflow/c/tf_tensor_internal.h"
|
||||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTFTPU)
|
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
|
||||||
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
||||||
#endif
|
#endif
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
@ -729,7 +729,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
|
|||||||
|
|
||||||
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||||
if (opts->use_tfrt) {
|
if (opts->use_tfrt) {
|
||||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTFTPU)
|
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
|
||||||
return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
|
return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
|
||||||
#else
|
#else
|
||||||
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
|
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
|
||||||
@ -904,9 +904,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
|||||||
|
|
||||||
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
||||||
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
|
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetThreadLocalDevicePlacementPolicy(
|
|
||||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
|
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -915,10 +913,8 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
|
|||||||
// safe to call this function from the async EagerExecutor threads.
|
// safe to call this function from the async EagerExecutor threads.
|
||||||
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
||||||
TFE_Context* ctx) {
|
TFE_Context* ctx) {
|
||||||
tensorflow::EagerContext* context =
|
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
return static_cast<TFE_ContextDevicePlacementPolicy>(
|
return static_cast<TFE_ContextDevicePlacementPolicy>(
|
||||||
context->GetDevicePlacementPolicy());
|
tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
|
||||||
}
|
}
|
||||||
|
|
||||||
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
|
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
|
||||||
@ -1429,21 +1425,15 @@ void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
|
|||||||
}
|
}
|
||||||
|
|
||||||
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
||||||
tensorflow::EagerContext* context =
|
return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
return context->FindFunctionDef(name) != nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
|
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetShouldStoreGraphs(true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetShouldStoreGraphs(false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
|
@ -74,7 +74,7 @@ typedef enum TFE_ContextDevicePlacementPolicy {
|
|||||||
// Placement policy which silently copies int32 tensors but not other dtypes.
|
// Placement policy which silently copies int32 tensors but not other dtypes.
|
||||||
TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
|
TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
|
||||||
} TFE_ContextDevicePlacementPolicy;
|
} TFE_ContextDevicePlacementPolicy;
|
||||||
// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)
|
// LINT.ThenChange(//tensorflow/c/eager/immediate_execution_context.h)
|
||||||
|
|
||||||
// Sets the default execution mode (sync/async). Note that this can be
|
// Sets the default execution mode (sync/async). Note that this can be
|
||||||
// overridden per thread using TFE_ContextSetExecutorForThread.
|
// overridden per thread using TFE_ContextSetExecutorForThread.
|
||||||
|
@ -545,7 +545,9 @@ TEST(CAPI, DistributedFunctionNoError) {
|
|||||||
TestDistributedFunctionCancellation(false);
|
TestDistributedFunctionCancellation(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CAPI, DistributedFunctionCancelledOnError) {
|
// TODO(b/170399182): Update test once an alternative to using the function
|
||||||
|
// optimization hook is in place.
|
||||||
|
TEST(CAPI, DISABLED_DistributedFunctionCancelledOnError) {
|
||||||
TestDistributedFunctionCancellation(true);
|
TestDistributedFunctionCancellation(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,15 +49,11 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetShouldStoreGraphs(true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetShouldStoreGraphs(false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t TFE_GetContextId(TFE_Context* ctx) {
|
uint64_t TFE_GetContextId(TFE_Context* ctx) {
|
||||||
@ -544,22 +540,16 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
|
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor());
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetExecutorForThread(executor->executor());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
||||||
tensorflow::EagerContext* context =
|
return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor());
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
return new TFE_Executor(&context->Executor());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
||||||
tensorflow::EagerContext* context =
|
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
|
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
|
||||||
context->HostCPU()->parsed_name());
|
tensorflow::unwrap(ctx)->HostCPUParsedName());
|
||||||
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
|
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
|
||||||
void* data = tensorflow::port::Malloc(str.length());
|
void* data = tensorflow::port::Malloc(str.length());
|
||||||
str.copy(static_cast<char*>(data), str.length(), 0);
|
str.copy(static_cast<char*>(data), str.length(), 0);
|
||||||
@ -572,9 +562,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
|||||||
|
|
||||||
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
||||||
TF_Buffer* buf, TF_Status* status) {
|
TF_Buffer* buf, TF_Status* status) {
|
||||||
tensorflow::EagerContext* context =
|
auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
auto* function_def = context->FindFunctionDef(function_name);
|
|
||||||
if (function_def == nullptr) {
|
if (function_def == nullptr) {
|
||||||
status->status = tensorflow::errors::NotFound(
|
status->status = tensorflow::errors::NotFound(
|
||||||
"Unable to find FunctionDef with name: ", function_name);
|
"Unable to find FunctionDef with name: ", function_name);
|
||||||
@ -643,14 +631,10 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
|
|||||||
|
|
||||||
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetAllowSoftPlacement(enable);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
tensorflow::EagerContext* context =
|
tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
|
||||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
||||||
context->SetLogDevicePlacement(enable);
|
|
||||||
}
|
}
|
||||||
|
@ -191,7 +191,7 @@ Status TapeVSpace::CallBackwardFunction(
|
|||||||
&ctx, incoming_gradients, result);
|
&ctx, incoming_gradients, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status TapeVSpace::BuildOnesLike(TapeTensor t,
|
Status TapeVSpace::BuildOnesLike(const TapeTensor& t,
|
||||||
AbstractTensorHandle** result) const {
|
AbstractTensorHandle** result) const {
|
||||||
AbstractOperationPtr op(ctx_->CreateOperation());
|
AbstractOperationPtr op(ctx_->CreateOperation());
|
||||||
TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));
|
TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));
|
||||||
|
@ -180,10 +180,6 @@ int64 ToId(AbstractTensorHandle* t);
|
|||||||
// allow us to trace the data dependencies between operations and hence compute
|
// allow us to trace the data dependencies between operations and hence compute
|
||||||
// gradients.
|
// gradients.
|
||||||
//
|
//
|
||||||
// This also implements `OnesLike` to create the default
|
|
||||||
// incoming gradients for tensors which do not already have an incoming
|
|
||||||
// gradient.
|
|
||||||
//
|
|
||||||
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
|
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
|
||||||
// of default zeros grads is handled by the `DefaultGradientFunction` registered
|
// of default zeros grads is handled by the `DefaultGradientFunction` registered
|
||||||
// for each op.
|
// for each op.
|
||||||
@ -233,7 +229,7 @@ class TapeVSpace
|
|||||||
std::vector<AbstractTensorHandle*>* result) const override;
|
std::vector<AbstractTensorHandle*>* result) const override;
|
||||||
|
|
||||||
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
||||||
Status BuildOnesLike(TapeTensor t,
|
Status BuildOnesLike(const TapeTensor& t,
|
||||||
AbstractTensorHandle** result) const override;
|
AbstractTensorHandle** result) const override;
|
||||||
|
|
||||||
// Looks up the ID of a Gradient.
|
// Looks up the ID of a Gradient.
|
||||||
|
@ -61,6 +61,7 @@ Status RegisterGradients(GradientRegistry* registry) {
|
|||||||
TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
|
TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
|
||||||
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
|
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
|
||||||
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
|
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
|
||||||
|
TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -131,6 +132,37 @@ Status ExpGradModel(AbstractContext* ctx,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Computes
|
||||||
|
// y = sqrt(inputs[0])
|
||||||
|
// return grad(y, {inputs[0]})
|
||||||
|
Status SqrtGradModel(AbstractContext* ctx,
|
||||||
|
absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
|
absl::Span<AbstractTensorHandle*> outputs,
|
||||||
|
const GradientRegistry& registry) {
|
||||||
|
TapeVSpace vspace(ctx);
|
||||||
|
auto tape = new Tape(/*persistent=*/false);
|
||||||
|
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||||
|
std::vector<AbstractTensorHandle*> sqrt_outputs(1);
|
||||||
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
|
||||||
|
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||||
|
source_tensors_that_are_targets;
|
||||||
|
|
||||||
|
std::vector<AbstractTensorHandle*> out_grads;
|
||||||
|
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||||
|
vspace, /*target_tensor_ids=*/{ToId(sqrt_outputs[0])},
|
||||||
|
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
|
||||||
|
/*output_gradients=*/{}, &out_grads,
|
||||||
|
/*build_default_zeros_grads=*/false));
|
||||||
|
for (auto sqrt_output : sqrt_outputs) {
|
||||||
|
sqrt_output->Unref();
|
||||||
|
}
|
||||||
|
outputs[0] = out_grads[0];
|
||||||
|
delete tape;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
// Computes
|
// Computes
|
||||||
// ignored, y = IdentityN(inputs[0], inputs[1])
|
// ignored, y = IdentityN(inputs[0], inputs[1])
|
||||||
// return grad(y, {inputs[0], inputs[1]})
|
// return grad(y, {inputs[0], inputs[1]})
|
||||||
@ -401,6 +433,50 @@ TEST_P(CppGradients, TestExpGrad) {
|
|||||||
result_tensor = nullptr;
|
result_tensor = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(CppGradients, TestSqrtGrad) {
|
||||||
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
|
AbstractContextPtr ctx;
|
||||||
|
{
|
||||||
|
AbstractContext* ctx_raw = nullptr;
|
||||||
|
Status s =
|
||||||
|
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
ctx.reset(ctx_raw);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractTensorHandlePtr x;
|
||||||
|
{
|
||||||
|
AbstractTensorHandle* x_raw = nullptr;
|
||||||
|
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
x.reset(x_raw);
|
||||||
|
}
|
||||||
|
|
||||||
|
GradientRegistry registry;
|
||||||
|
Status s = RegisterGradients(®istry);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
|
||||||
|
// Pseudo-code:
|
||||||
|
//
|
||||||
|
// tape.watch(x)
|
||||||
|
// y = sqrt(x)
|
||||||
|
// outputs = tape.gradient(y, x)
|
||||||
|
std::vector<AbstractTensorHandle*> outputs(1);
|
||||||
|
s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
|
||||||
|
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
|
||||||
|
TF_Tensor* result_tensor;
|
||||||
|
s = getValue(outputs[0], &result_tensor);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||||
|
EXPECT_NEAR(*result_value, 0.5, 0.001);
|
||||||
|
outputs[0]->Unref();
|
||||||
|
TF_DeleteTensor(result_tensor);
|
||||||
|
result_tensor = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(CppGradients, TestIdentityNGrad) {
|
TEST_P(CppGradients, TestIdentityNGrad) {
|
||||||
// Pseudo-code:
|
// Pseudo-code:
|
||||||
//
|
//
|
||||||
|
@ -29,8 +29,25 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
#include "tensorflow/core/platform/tstring.h"
|
#include "tensorflow/core/platform/tstring.h"
|
||||||
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
class EagerExecutor;
|
||||||
|
|
||||||
|
// LINT.IfChange
|
||||||
|
// Note: Keep in sync with exported copy of enum in eager/c_api.h.
|
||||||
|
enum ContextDevicePlacementPolicy {
|
||||||
|
// Running operations with input tensors on the wrong device will fail.
|
||||||
|
DEVICE_PLACEMENT_EXPLICIT = 0,
|
||||||
|
// Copy the tensor to the right device but log a warning.
|
||||||
|
DEVICE_PLACEMENT_WARN = 1,
|
||||||
|
// Silently copy the tensor, which has a performance cost since the operation
|
||||||
|
// will be blocked till the copy completes. This is the default policy.
|
||||||
|
DEVICE_PLACEMENT_SILENT = 2,
|
||||||
|
// Placement policy which silently copies int32 tensors but not other dtypes.
|
||||||
|
DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
|
||||||
|
};
|
||||||
|
// LINT.ThenChange(//tensorflow/c/eager/c_api.h)
|
||||||
|
|
||||||
// Abstract interface to a context.
|
// Abstract interface to a context.
|
||||||
//
|
//
|
||||||
@ -81,14 +98,6 @@ class ImmediateExecutionContext : public AbstractContext {
|
|||||||
// List attributes of available devices
|
// List attributes of available devices
|
||||||
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
|
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
|
||||||
|
|
||||||
virtual void ClearCachesAndThreadExecutors() = 0;
|
|
||||||
|
|
||||||
// Initialize the step resource container for a training step. This is used
|
|
||||||
// in current TF runtime. For tfrt, it is used by fallback op handler.
|
|
||||||
virtual void StartStep() = 0;
|
|
||||||
// Destroy the step resource container for a training step.
|
|
||||||
virtual void EndStep() = 0;
|
|
||||||
|
|
||||||
// Block until all pending nodes are finished.
|
// Block until all pending nodes are finished.
|
||||||
virtual Status AsyncWait() = 0;
|
virtual Status AsyncWait() = 0;
|
||||||
|
|
||||||
@ -97,11 +106,52 @@ class ImmediateExecutionContext : public AbstractContext {
|
|||||||
// already exists.
|
// already exists.
|
||||||
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
|
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
|
||||||
|
|
||||||
|
// Find and return a added function by its name.
|
||||||
|
virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;
|
||||||
|
|
||||||
|
// Return the ParsedName of Host CPU device.
|
||||||
|
virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
|
||||||
|
|
||||||
|
// Configure soft device placement policy.
|
||||||
|
virtual void SetAllowSoftPlacement(bool enable) = 0;
|
||||||
|
|
||||||
|
// Configure device placement policy logging.
|
||||||
|
virtual void SetLogDevicePlacement(bool enable) = 0;
|
||||||
|
|
||||||
|
// Sets the device placement policy for the current thread.
|
||||||
|
virtual void SetThreadLocalDevicePlacementPolicy(
|
||||||
|
ContextDevicePlacementPolicy policy) = 0;
|
||||||
|
// Returns the device placement policy for the current thread.
|
||||||
|
virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
|
||||||
|
|
||||||
// For LLVM style RTTI.
|
// For LLVM style RTTI.
|
||||||
static bool classof(const AbstractContext* ptr) {
|
static bool classof(const AbstractContext* ptr) {
|
||||||
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
|
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===--------------------------------------------------------------------===//
|
||||||
|
// Following are legacy features in TF Eager Runtime.
|
||||||
|
// TODO(tf-runtime): Figure out a way to deprecate following features after
|
||||||
|
// migrated to TFRT.
|
||||||
|
//===--------------------------------------------------------------------===//
|
||||||
|
// Clear pending nodes in thread executors and kernel caches.
|
||||||
|
virtual void ClearCachesAndThreadExecutors() = 0;
|
||||||
|
|
||||||
|
// Initialize the step resource container for a training step. This is used
|
||||||
|
// in current TF runtime. For tfrt, it is used by fallback op handler.
|
||||||
|
virtual void StartStep() = 0;
|
||||||
|
// Destroy the step resource container for a training step.
|
||||||
|
virtual void EndStep() = 0;
|
||||||
|
|
||||||
|
// Return the Eager Executor for current thread. Please note that Eager
|
||||||
|
// Executor is only used in current TF but not in TFRT.
|
||||||
|
virtual EagerExecutor& Executor() = 0;
|
||||||
|
// Update the Eager Executor for current thread.
|
||||||
|
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
|
||||||
|
|
||||||
|
// Configure graph collection in RunMetadata.
|
||||||
|
virtual void SetShouldStoreGraphs(bool value) = 0;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
explicit ImmediateExecutionContext(AbstractContextKind kind)
|
explicit ImmediateExecutionContext(AbstractContextKind kind)
|
||||||
: AbstractContext(kind) {}
|
: AbstractContext(kind) {}
|
||||||
|
@ -25,133 +25,18 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/gradients.h"
|
#include "tensorflow/c/eager/gradients.h"
|
||||||
#include "tensorflow/c/eager/gradients_internal.h"
|
#include "tensorflow/c/eager/gradients_internal.h"
|
||||||
#include "tensorflow/c/eager/gradients_util.h"
|
#include "tensorflow/c/eager/gradients_util.h"
|
||||||
|
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
|
||||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||||
#include "tensorflow/c/experimental/ops/math_ops.h"
|
#include "tensorflow/c/experimental/ops/math_ops.h"
|
||||||
#include "tensorflow/c/experimental/ops/nn_ops.h"
|
#include "tensorflow/c/experimental/ops/nn_ops.h"
|
||||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
|
|
||||||
// ========================== Tape Ops ==============================
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace gradients {
|
namespace gradients {
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
using std::vector;
|
using std::vector;
|
||||||
using tensorflow::tracing::TracingOperation;
|
|
||||||
|
|
||||||
// Computes `inputs[0] + inputs[1]` and records it on the tape.
|
|
||||||
Status Add(AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs,
|
|
||||||
const GradientRegistry& registry) {
|
|
||||||
AbstractOperationPtr add_op(ctx->CreateOperation());
|
|
||||||
ForwardOperation forward_op;
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
|
|
||||||
if (isa<TracingOperation>(add_op.get())) {
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
|
|
||||||
}
|
|
||||||
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
|
|
||||||
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
|
|
||||||
int num_retvals = 1;
|
|
||||||
return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
|
||||||
registry);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
|
|
||||||
Status MatMul(AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
|
||||||
bool transpose_a, bool transpose_b,
|
|
||||||
const GradientRegistry& registry) {
|
|
||||||
AbstractOperationPtr matmul_op(ctx->CreateOperation());
|
|
||||||
ForwardOperation forward_op;
|
|
||||||
TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul",
|
|
||||||
/*raw_device_name=*/nullptr, &forward_op));
|
|
||||||
if (isa<TracingOperation>(matmul_op.get())) {
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
dyn_cast<TracingOperation>(matmul_op.get())->SetOpName(name));
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op));
|
|
||||||
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op));
|
|
||||||
TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
|
|
||||||
matmul_op.get(), "transpose_a", transpose_a, &forward_op));
|
|
||||||
TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
|
|
||||||
matmul_op.get(), "transpose_b", transpose_b, &forward_op));
|
|
||||||
|
|
||||||
int num_retvals = 1;
|
|
||||||
return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
|
||||||
registry);
|
|
||||||
}
|
|
||||||
|
|
||||||
Status Mul(AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
|
||||||
const GradientRegistry& registry) {
|
|
||||||
AbstractOperationPtr mul_op(ctx->CreateOperation());
|
|
||||||
ForwardOperation forward_op;
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
|
|
||||||
if (isa<TracingOperation>(mul_op.get())) {
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op));
|
|
||||||
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op));
|
|
||||||
|
|
||||||
int num_retvals = 1;
|
|
||||||
return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
|
||||||
registry);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Computes `Relu(inputs[0])` and records it on the tape.
|
|
||||||
Status Relu(AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
|
||||||
const GradientRegistry& registry) {
|
|
||||||
AbstractOperationPtr relu_op(ctx->CreateOperation());
|
|
||||||
ForwardOperation forward_op;
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
|
|
||||||
if (isa<TracingOperation>(relu_op.get())) {
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
dyn_cast<TracingOperation>(relu_op.get())->SetOpName(name));
|
|
||||||
}
|
|
||||||
TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op));
|
|
||||||
int num_retvals = 1;
|
|
||||||
return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
|
||||||
registry);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Computes `SoftmaxLoss(scores, labels)` where labels are categorical (not
|
|
||||||
// one-hot) and records it on the tape.
|
|
||||||
Status SparseSoftmaxCrossEntropyWithLogits(
|
|
||||||
AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
|
||||||
const GradientRegistry& registry) {
|
|
||||||
AbstractTensorHandle* scores = inputs[0];
|
|
||||||
AbstractTensorHandle* labels = inputs[1];
|
|
||||||
|
|
||||||
AbstractOperationPtr sm_op(ctx->CreateOperation());
|
|
||||||
ForwardOperation forward_op;
|
|
||||||
TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits",
|
|
||||||
/*raw_device_name=*/nullptr, &forward_op));
|
|
||||||
if (isa<TracingOperation>(sm_op.get())) {
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
dyn_cast<TracingOperation>(sm_op.get())->SetOpName(name));
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op));
|
|
||||||
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op));
|
|
||||||
|
|
||||||
int num_retvals = 2; // returns loss values and backprop
|
|
||||||
return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
|
||||||
registry);
|
|
||||||
}
|
|
||||||
|
|
||||||
//===================== Test Models to run =========================
|
//===================== Test Models to run =========================
|
||||||
|
|
||||||
@ -167,8 +52,9 @@ Status AddGradModel(AbstractContext* ctx,
|
|||||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||||
std::vector<AbstractTensorHandle*> add_outputs(1);
|
std::vector<AbstractTensorHandle*> add_outputs(1);
|
||||||
TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
registry)); // Compute x+y.
|
TF_RETURN_IF_ERROR(
|
||||||
|
ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));
|
||||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||||
source_tensors_that_are_targets;
|
source_tensors_that_are_targets;
|
||||||
|
|
||||||
@ -200,9 +86,11 @@ Status MatMulGradModel(AbstractContext* ctx,
|
|||||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||||
vector<AbstractTensorHandle*> mm_outputs(1);
|
vector<AbstractTensorHandle*> mm_outputs(1);
|
||||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs),
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
"matmul0", /*transpose_a=*/false,
|
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs,
|
||||||
/*transpose_b=*/false, registry)); // Compute x*y.
|
absl::MakeSpan(mm_outputs), "matmul0",
|
||||||
|
/*transpose_a=*/false,
|
||||||
|
/*transpose_b=*/false)); // Compute x*y.
|
||||||
|
|
||||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||||
source_tensors_that_are_targets;
|
source_tensors_that_are_targets;
|
||||||
@ -256,25 +144,27 @@ Status MNISTForwardModel(AbstractContext* ctx,
|
|||||||
tape->Watch(ToId(W2)); // Watch W2.
|
tape->Watch(ToId(W2)); // Watch W2.
|
||||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
"matmul0", /*transpose_a=*/false,
|
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
|
||||||
/*transpose_b=*/false, registry)); // Compute X*W1
|
absl::MakeSpan(temp_outputs), "matmul0",
|
||||||
|
/*transpose_a=*/false,
|
||||||
|
/*transpose_b=*/false)); // Compute X*W1
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]},
|
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {temp_outputs[0]},
|
||||||
absl::MakeSpan(temp_outputs), "relu",
|
absl::MakeSpan(temp_outputs),
|
||||||
registry)); // Compute Relu(X*W1)
|
"relu")); // Compute Relu(X*W1)
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2},
|
TF_RETURN_IF_ERROR(ops::MatMul(
|
||||||
absl::MakeSpan(temp_outputs), "matmul1",
|
tape_ctx.get(), {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs),
|
||||||
/*transpose_a=*/false, /*transpose_b=*/false,
|
"matmul1",
|
||||||
registry)); // Compute W2*Relu(X*W1)
|
/*transpose_a=*/false, /*transpose_b=*/false)); // Compute W2*Relu(X*W1)
|
||||||
|
|
||||||
AbstractTensorHandle* scores = temp_outputs[0];
|
AbstractTensorHandle* scores = temp_outputs[0];
|
||||||
|
|
||||||
temp_outputs.resize(2);
|
temp_outputs.resize(2);
|
||||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
|
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
|
||||||
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
||||||
"softmax_loss", registry)); // Compute Softmax(Scores,labels)
|
"softmax_loss")); // Compute Softmax(Scores,labels)
|
||||||
|
|
||||||
AbstractTensorHandle* loss_vals = temp_outputs[0];
|
AbstractTensorHandle* loss_vals = temp_outputs[0];
|
||||||
|
|
||||||
@ -297,9 +187,11 @@ Status MatMulTransposeModel(AbstractContext* ctx,
|
|||||||
tape->Watch(ToId(W1));
|
tape->Watch(ToId(W1));
|
||||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
"matmul0", /*transpose_a=*/true,
|
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
|
||||||
/*transpose_b=*/false, registry)); // Compute X*W1
|
absl::MakeSpan(temp_outputs), "matmul0",
|
||||||
|
/*transpose_a=*/true,
|
||||||
|
/*transpose_b=*/false)); // Compute X*W1
|
||||||
|
|
||||||
outputs[0] = temp_outputs[0];
|
outputs[0] = temp_outputs[0];
|
||||||
|
|
||||||
@ -315,8 +207,10 @@ Status ReluGradModel(AbstractContext* ctx,
|
|||||||
auto tape = new Tape(/*persistent=*/false);
|
auto tape = new Tape(/*persistent=*/false);
|
||||||
tape->Watch(ToId(inputs[0])); // Watch X
|
tape->Watch(ToId(inputs[0])); // Watch X
|
||||||
vector<AbstractTensorHandle*> relu_outputs(1);
|
vector<AbstractTensorHandle*> relu_outputs(1);
|
||||||
TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs),
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
"relu0", registry)); // Relu(X)
|
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
|
||||||
|
absl::MakeSpan(relu_outputs),
|
||||||
|
"relu0")); // Relu(X)
|
||||||
|
|
||||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||||
source_tensors_that_are_targets;
|
source_tensors_that_are_targets;
|
||||||
@ -346,8 +240,9 @@ Status SoftmaxLossGradModel(AbstractContext* ctx,
|
|||||||
tape->Watch(ToId(inputs[0])); // Watch scores.
|
tape->Watch(ToId(inputs[0])); // Watch scores.
|
||||||
tape->Watch(ToId(inputs[1])); // Watch labels.
|
tape->Watch(ToId(inputs[1])); // Watch labels.
|
||||||
vector<AbstractTensorHandle*> sm_outputs(2);
|
vector<AbstractTensorHandle*> sm_outputs(2);
|
||||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry));
|
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
|
||||||
|
tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0"));
|
||||||
|
|
||||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||||
source_tensors_that_are_targets;
|
source_tensors_that_are_targets;
|
||||||
@ -381,29 +276,30 @@ Status MNISTGradModel(AbstractContext* ctx,
|
|||||||
tape->Watch(ToId(W1)); // Watch W1.
|
tape->Watch(ToId(W1)); // Watch W1.
|
||||||
tape->Watch(ToId(W2)); // Watch W1.
|
tape->Watch(ToId(W2)); // Watch W1.
|
||||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
"matmul0", /*transpose_a=*/false,
|
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
|
||||||
/*transpose_b=*/false, registry)); // Compute X*W1
|
absl::MakeSpan(temp_outputs), "matmul0",
|
||||||
|
/*transpose_a=*/false,
|
||||||
|
/*transpose_b=*/false)); // Compute X*W1
|
||||||
|
|
||||||
AbstractTensorHandle* mm = temp_outputs[0];
|
AbstractTensorHandle* mm = temp_outputs[0];
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm},
|
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm},
|
||||||
absl::MakeSpan(temp_outputs), // Relu(X*W1)
|
absl::MakeSpan(temp_outputs), // Relu(X*W1)
|
||||||
"relu0", registry));
|
"relu0"));
|
||||||
|
|
||||||
AbstractTensorHandle* hidden = temp_outputs[0];
|
AbstractTensorHandle* hidden = temp_outputs[0];
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2},
|
TF_RETURN_IF_ERROR(ops::MatMul(
|
||||||
absl::MakeSpan(temp_outputs), "matmul1",
|
tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1",
|
||||||
/*transpose_a=*/false, /*transpose_b=*/false,
|
/*transpose_a=*/false, /*transpose_b=*/false)); // W2*Relu(X*W1)
|
||||||
registry)); // W2*Relu(X*W1)
|
|
||||||
|
|
||||||
AbstractTensorHandle* scores = temp_outputs[0];
|
AbstractTensorHandle* scores = temp_outputs[0];
|
||||||
|
|
||||||
temp_outputs.resize(2);
|
temp_outputs.resize(2);
|
||||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
|
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
|
||||||
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
||||||
"softmaxloss", registry)); // W2*Relu(X*W1)
|
"softmaxloss")); // W2*Relu(X*W1)
|
||||||
|
|
||||||
AbstractTensorHandle* loss = temp_outputs[0];
|
AbstractTensorHandle* loss = temp_outputs[0];
|
||||||
|
|
||||||
@ -440,8 +336,10 @@ Status ScalarMulModel(AbstractContext* ctx,
|
|||||||
auto tape = new Tape(/*persistent=*/false);
|
auto tape = new Tape(/*persistent=*/false);
|
||||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs),
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
"scalarMul0", registry)); // Compute eta*A
|
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {eta, A},
|
||||||
|
absl::MakeSpan(temp_outputs),
|
||||||
|
"scalarMul0")); // Compute eta*A
|
||||||
|
|
||||||
outputs[0] = temp_outputs[0];
|
outputs[0] = temp_outputs[0];
|
||||||
|
|
||||||
@ -459,9 +357,11 @@ Status MatMulModel(AbstractContext* ctx,
|
|||||||
TapeVSpace vspace(ctx);
|
TapeVSpace vspace(ctx);
|
||||||
auto tape = new Tape(/*persistent=*/false);
|
auto tape = new Tape(/*persistent=*/false);
|
||||||
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
||||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
"matmul0", /*transpose_a=*/false,
|
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
|
||||||
/*transpose_b=*/false, registry)); // Compute X*W1
|
absl::MakeSpan(temp_outputs), "matmul0",
|
||||||
|
/*transpose_a=*/false,
|
||||||
|
/*transpose_b=*/false)); // Compute X*W1
|
||||||
|
|
||||||
outputs[0] = temp_outputs[0];
|
outputs[0] = temp_outputs[0];
|
||||||
delete tape;
|
delete tape;
|
||||||
@ -478,8 +378,10 @@ Status MulModel(AbstractContext* ctx,
|
|||||||
TapeVSpace vspace(ctx);
|
TapeVSpace vspace(ctx);
|
||||||
auto tape = new Tape(/*persistent=*/false);
|
auto tape = new Tape(/*persistent=*/false);
|
||||||
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
||||||
TF_RETURN_IF_ERROR(Mul(ctx, tape, {x, y}, absl::MakeSpan(temp_outputs),
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
"mul0", registry)); // Compute x*y
|
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {x, y},
|
||||||
|
absl::MakeSpan(temp_outputs),
|
||||||
|
"mul0")); // Compute x*y
|
||||||
|
|
||||||
outputs[0] = temp_outputs[0];
|
outputs[0] = temp_outputs[0];
|
||||||
delete tape;
|
delete tape;
|
||||||
@ -496,9 +398,9 @@ Status SoftmaxModel(AbstractContext* ctx,
|
|||||||
TapeVSpace vspace(ctx);
|
TapeVSpace vspace(ctx);
|
||||||
auto tape = new Tape(/*persistent=*/false);
|
auto tape = new Tape(/*persistent=*/false);
|
||||||
std::vector<AbstractTensorHandle*> temp_outputs(2);
|
std::vector<AbstractTensorHandle*> temp_outputs(2);
|
||||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
|
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||||
ctx, tape, {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss",
|
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
|
||||||
registry));
|
tape_ctx.get(), {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss"));
|
||||||
|
|
||||||
outputs[0] = temp_outputs[0]; // loss values
|
outputs[0] = temp_outputs[0]; // loss values
|
||||||
|
|
||||||
|
@ -29,45 +29,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
|
||||||
// ========================== Tape Ops ==============================
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace gradients {
|
namespace gradients {
|
||||||
namespace internal {
|
namespace internal {
|
||||||
// Computes `inputs[0] + inputs[1]` and records it on the tape.
|
|
||||||
Status Add(AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs,
|
|
||||||
const GradientRegistry& registry);
|
|
||||||
|
|
||||||
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
|
|
||||||
Status MatMul(AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
|
||||||
bool transpose_a, bool transpose_b,
|
|
||||||
const GradientRegistry& registry);
|
|
||||||
|
|
||||||
// Computes `inputs[0] * inputs[1]` and records it on the tape.
|
|
||||||
Status Mul(AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
|
||||||
const GradientRegistry& registry);
|
|
||||||
|
|
||||||
// Computes `Relu(inputs[0])` and records it on the tape.
|
|
||||||
Status Relu(AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
|
||||||
const GradientRegistry& registry);
|
|
||||||
|
|
||||||
// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
|
|
||||||
// tape.
|
|
||||||
Status SparseSoftmaxCrossEntropyWithLogits(
|
|
||||||
AbstractContext* ctx, Tape* tape,
|
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
|
||||||
const GradientRegistry& registry);
|
|
||||||
|
|
||||||
// ====================== End Tape Ops ============================
|
|
||||||
|
|
||||||
// Computes
|
// Computes
|
||||||
// y = inputs[0] + inputs[1]
|
// y = inputs[0] + inputs[1]
|
||||||
|
@ -100,7 +100,8 @@ class VSpace {
|
|||||||
std::vector<Gradient*>* result) const = 0;
|
std::vector<Gradient*>* result) const = 0;
|
||||||
|
|
||||||
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
||||||
virtual Status BuildOnesLike(TapeTensor t, Gradient** result) const = 0;
|
virtual Status BuildOnesLike(const TapeTensor& t,
|
||||||
|
Gradient** result) const = 0;
|
||||||
|
|
||||||
// Looks up the ID of a Gradient.
|
// Looks up the ID of a Gradient.
|
||||||
virtual int64 TensorId(Gradient* tensor) const = 0;
|
virtual int64 TensorId(Gradient* tensor) const = 0;
|
||||||
|
@ -29,6 +29,7 @@ cc_library(
|
|||||||
}),
|
}),
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/c:env",
|
"//tensorflow/c:env",
|
||||||
|
"//tensorflow/c:logging",
|
||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||||
"//third_party/hadoop:hdfs",
|
"//third_party/hadoop:hdfs",
|
||||||
|
@ -22,11 +22,10 @@ limitations under the License.
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "absl/synchronization/mutex.h"
|
|
||||||
#include "tensorflow/c/env.h"
|
#include "tensorflow/c/env.h"
|
||||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||||
|
#include "tensorflow/c/logging.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "third_party/hadoop/hdfs.h"
|
|
||||||
|
|
||||||
// Implementation of a filesystem for HADOOP environments.
|
// Implementation of a filesystem for HADOOP environments.
|
||||||
// This filesystem will support `hdfs://`, `viewfs://` and `har://` URI schemes.
|
// This filesystem will support `hdfs://`, `viewfs://` and `har://` URI schemes.
|
||||||
@ -149,15 +148,20 @@ class LibHDFS {
|
|||||||
char* hdfs_home = getenv("HADOOP_HDFS_HOME");
|
char* hdfs_home = getenv("HADOOP_HDFS_HOME");
|
||||||
if (hdfs_home != nullptr) {
|
if (hdfs_home != nullptr) {
|
||||||
auto JoinPath = [](std::string home, std::string lib) {
|
auto JoinPath = [](std::string home, std::string lib) {
|
||||||
|
#if defined(_WIN32)
|
||||||
|
if (home.back() != '\\') home.push_back('\\');
|
||||||
|
return home + "lib\\native\\" + lib;
|
||||||
|
#else
|
||||||
if (home.back() != '/') home.push_back('/');
|
if (home.back() != '/') home.push_back('/');
|
||||||
return home + "lib/native/" + lib;
|
return home + "lib/native/" + lib;
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
std::string path = JoinPath(hdfs_home, kLibHdfsDso);
|
std::string path = JoinPath(hdfs_home, kLibHdfsDso);
|
||||||
TryLoadAndBind(path.c_str(), &handle_, status);
|
TryLoadAndBind(path.c_str(), &handle_, status);
|
||||||
if (TF_GetCode(status) == TF_OK) {
|
if (TF_GetCode(status) == TF_OK) {
|
||||||
return;
|
return;
|
||||||
} else {
|
} else {
|
||||||
std::cerr << "HadoopFileSystem load error: " << TF_Message(status);
|
TF_Log(TF_FATAL, "HadoopFileSystem load error: %s", TF_Message(status));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -169,13 +173,15 @@ class LibHDFS {
|
|||||||
void* handle_;
|
void* handle_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// We rely on HDFS connection caching here. The HDFS client calls
|
// We implement connection caching in Tensorflow, which can significantly
|
||||||
// org.apache.hadoop.fs.FileSystem.get(), which caches the connection
|
// improve performance. Fixes #43187
|
||||||
// internally.
|
hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file,
|
||||||
hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) {
|
const std::string& path, TF_Status* status) {
|
||||||
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
|
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
|
||||||
|
|
||||||
|
std::string cacheKey(scheme);
|
||||||
hdfsBuilder* builder = libhdfs->hdfsNewBuilder();
|
hdfsBuilder* builder = libhdfs->hdfsNewBuilder();
|
||||||
if (scheme == "file") {
|
if (scheme == "file") {
|
||||||
libhdfs->hdfsBuilderSetNameNode(builder, nullptr);
|
libhdfs->hdfsBuilderSetNameNode(builder, nullptr);
|
||||||
@ -200,14 +206,23 @@ hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) {
|
|||||||
SplitArchiveNameAndPath(&path_har, &namenode, status);
|
SplitArchiveNameAndPath(&path_har, &namenode, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||||
libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str());
|
libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str());
|
||||||
|
cacheKey += namenode;
|
||||||
} else {
|
} else {
|
||||||
libhdfs->hdfsBuilderSetNameNode(
|
libhdfs->hdfsBuilderSetNameNode(
|
||||||
builder, namenode.empty() ? "default" : namenode.c_str());
|
builder, namenode.empty() ? "default" : namenode.c_str());
|
||||||
|
cacheKey += namenode;
|
||||||
}
|
}
|
||||||
auto fs = libhdfs->hdfsBuilderConnect(builder);
|
absl::MutexLock l(&hadoop_file->connection_cache_lock);
|
||||||
if (fs == nullptr)
|
if (hadoop_file->connection_cache.find(cacheKey) ==
|
||||||
|
hadoop_file->connection_cache.end()) {
|
||||||
|
auto cacheFs = libhdfs->hdfsBuilderConnect(builder);
|
||||||
|
if (cacheFs == nullptr) {
|
||||||
TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno));
|
TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno));
|
||||||
else
|
return cacheFs;
|
||||||
|
}
|
||||||
|
hadoop_file->connection_cache[cacheKey] = cacheFs;
|
||||||
|
}
|
||||||
|
auto fs = hadoop_file->connection_cache[cacheKey];
|
||||||
TF_SetStatus(status, TF_OK, "");
|
TF_SetStatus(status, TF_OK, "");
|
||||||
return fs;
|
return fs;
|
||||||
}
|
}
|
||||||
@ -409,30 +424,36 @@ void Close(const TF_WritableFile* file, TF_Status* status) {
|
|||||||
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
|
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
namespace tf_read_only_memory_region {
|
namespace tf_read_only_memory_region {
|
||||||
|
// Hadoop doesn't support Readonly Memory Region
|
||||||
// TODO(vnvo2409): Implement later
|
|
||||||
|
|
||||||
} // namespace tf_read_only_memory_region
|
} // namespace tf_read_only_memory_region
|
||||||
|
|
||||||
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
|
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
namespace tf_hadoop_filesystem {
|
namespace tf_hadoop_filesystem {
|
||||||
|
|
||||||
|
HadoopFile::HadoopFile(TF_Status* status)
|
||||||
|
: libhdfs(new LibHDFS(status)),
|
||||||
|
connection_cache_lock(),
|
||||||
|
connection_cache() {}
|
||||||
|
|
||||||
void Init(TF_Filesystem* filesystem, TF_Status* status) {
|
void Init(TF_Filesystem* filesystem, TF_Status* status) {
|
||||||
filesystem->plugin_filesystem = new LibHDFS(status);
|
filesystem->plugin_filesystem = new HadoopFile(status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
TF_SetStatus(status, TF_OK, "");
|
TF_SetStatus(status, TF_OK, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cleanup(TF_Filesystem* filesystem) {
|
void Cleanup(TF_Filesystem* filesystem) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
delete libhdfs;
|
delete libhdfs;
|
||||||
|
delete hadoop_file;
|
||||||
}
|
}
|
||||||
|
|
||||||
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
||||||
TF_RandomAccessFile* file, TF_Status* status) {
|
TF_RandomAccessFile* file, TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -448,8 +469,9 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
|||||||
|
|
||||||
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
|
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
|
||||||
TF_WritableFile* file, TF_Status* status) {
|
TF_WritableFile* file, TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -465,8 +487,9 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
|
|||||||
|
|
||||||
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
|
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
|
||||||
TF_WritableFile* file, TF_Status* status) {
|
TF_WritableFile* file, TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -497,8 +520,9 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
|
|||||||
|
|
||||||
void PathExists(const TF_Filesystem* filesystem, const char* path,
|
void PathExists(const TF_Filesystem* filesystem, const char* path,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -513,8 +537,9 @@ void PathExists(const TF_Filesystem* filesystem, const char* path,
|
|||||||
|
|
||||||
void Stat(const TF_Filesystem* filesystem, const char* path,
|
void Stat(const TF_Filesystem* filesystem, const char* path,
|
||||||
TF_FileStatistics* stats, TF_Status* status) {
|
TF_FileStatistics* stats, TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -532,8 +557,9 @@ void Stat(const TF_Filesystem* filesystem, const char* path,
|
|||||||
|
|
||||||
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
|
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return -1;
|
if (TF_GetCode(status) != TF_OK) return -1;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -553,8 +579,9 @@ int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
|
|||||||
|
|
||||||
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
|
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -568,8 +595,9 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path,
|
|||||||
|
|
||||||
void CreateDir(const TF_Filesystem* filesystem, const char* path,
|
void CreateDir(const TF_Filesystem* filesystem, const char* path,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -583,8 +611,9 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,
|
|||||||
|
|
||||||
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
|
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -619,8 +648,9 @@ void DeleteDir(const TF_Filesystem* filesystem, const char* path,
|
|||||||
|
|
||||||
void RenameFile(const TF_Filesystem* filesystem, const char* src,
|
void RenameFile(const TF_Filesystem* filesystem, const char* src,
|
||||||
const char* dst, TF_Status* status) {
|
const char* dst, TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, src, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, src, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path_src, hdfs_path_dst;
|
std::string scheme, namenode, hdfs_path_src, hdfs_path_dst;
|
||||||
@ -640,8 +670,9 @@ void RenameFile(const TF_Filesystem* filesystem, const char* src,
|
|||||||
|
|
||||||
int GetChildren(const TF_Filesystem* filesystem, const char* path,
|
int GetChildren(const TF_Filesystem* filesystem, const char* path,
|
||||||
char*** entries, TF_Status* status) {
|
char*** entries, TF_Status* status) {
|
||||||
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
|
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
|
||||||
auto fs = Connect(libhdfs, path, status);
|
auto libhdfs = hadoop_file->libhdfs;
|
||||||
|
auto fs = Connect(hadoop_file, path, status);
|
||||||
if (TF_GetCode(status) != TF_OK) return -1;
|
if (TF_GetCode(status) != TF_OK) return -1;
|
||||||
|
|
||||||
std::string scheme, namenode, hdfs_path;
|
std::string scheme, namenode, hdfs_path;
|
||||||
@ -677,7 +708,9 @@ int GetChildren(const TF_Filesystem* filesystem, const char* path,
|
|||||||
return num_entries;
|
return num_entries;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(vnvo2409): Implement later
|
static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
|
||||||
|
return strdup(uri);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tf_hadoop_filesystem
|
} // namespace tf_hadoop_filesystem
|
||||||
|
|
||||||
@ -685,6 +718,42 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
|
|||||||
const char* uri) {
|
const char* uri) {
|
||||||
TF_SetFilesystemVersionMetadata(ops);
|
TF_SetFilesystemVersionMetadata(ops);
|
||||||
ops->scheme = strdup(uri);
|
ops->scheme = strdup(uri);
|
||||||
|
|
||||||
|
ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
|
||||||
|
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
|
||||||
|
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
|
||||||
|
ops->random_access_file_ops->read = tf_random_access_file::Read;
|
||||||
|
|
||||||
|
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
|
||||||
|
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
|
||||||
|
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
|
||||||
|
ops->writable_file_ops->append = tf_writable_file::Append;
|
||||||
|
ops->writable_file_ops->tell = tf_writable_file::Tell;
|
||||||
|
ops->writable_file_ops->flush = tf_writable_file::Flush;
|
||||||
|
ops->writable_file_ops->sync = tf_writable_file::Sync;
|
||||||
|
ops->writable_file_ops->close = tf_writable_file::Close;
|
||||||
|
|
||||||
|
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
|
||||||
|
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
|
||||||
|
ops->filesystem_ops->init = tf_hadoop_filesystem::Init;
|
||||||
|
ops->filesystem_ops->cleanup = tf_hadoop_filesystem::Cleanup;
|
||||||
|
ops->filesystem_ops->new_random_access_file =
|
||||||
|
tf_hadoop_filesystem::NewRandomAccessFile;
|
||||||
|
ops->filesystem_ops->new_writable_file =
|
||||||
|
tf_hadoop_filesystem::NewWritableFile;
|
||||||
|
ops->filesystem_ops->new_appendable_file =
|
||||||
|
tf_hadoop_filesystem::NewAppendableFile;
|
||||||
|
ops->filesystem_ops->new_read_only_memory_region_from_file =
|
||||||
|
tf_hadoop_filesystem::NewReadOnlyMemoryRegionFromFile;
|
||||||
|
ops->filesystem_ops->path_exists = tf_hadoop_filesystem::PathExists;
|
||||||
|
ops->filesystem_ops->stat = tf_hadoop_filesystem::Stat;
|
||||||
|
ops->filesystem_ops->get_file_size = tf_hadoop_filesystem::GetFileSize;
|
||||||
|
ops->filesystem_ops->delete_file = tf_hadoop_filesystem::DeleteFile;
|
||||||
|
ops->filesystem_ops->create_dir = tf_hadoop_filesystem::CreateDir;
|
||||||
|
ops->filesystem_ops->delete_dir = tf_hadoop_filesystem::DeleteDir;
|
||||||
|
ops->filesystem_ops->rename_file = tf_hadoop_filesystem::RenameFile;
|
||||||
|
ops->filesystem_ops->get_children = tf_hadoop_filesystem::GetChildren;
|
||||||
|
ops->filesystem_ops->translate_name = tf_hadoop_filesystem::TranslateName;
|
||||||
}
|
}
|
||||||
|
|
||||||
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
|
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
|
||||||
|
@ -15,10 +15,13 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
|
||||||
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
|
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
|
||||||
|
|
||||||
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/synchronization/mutex.h"
|
||||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
#include "third_party/hadoop/hdfs.h"
|
||||||
|
|
||||||
void ParseHadoopPath(const std::string& fname, std::string* scheme,
|
void ParseHadoopPath(const std::string& fname, std::string* scheme,
|
||||||
std::string* namenode, std::string* path);
|
std::string* namenode, std::string* path);
|
||||||
@ -43,6 +46,14 @@ void Close(const TF_WritableFile* file, TF_Status* status);
|
|||||||
} // namespace tf_writable_file
|
} // namespace tf_writable_file
|
||||||
|
|
||||||
namespace tf_hadoop_filesystem {
|
namespace tf_hadoop_filesystem {
|
||||||
|
typedef struct HadoopFile {
|
||||||
|
LibHDFS* libhdfs;
|
||||||
|
absl::Mutex connection_cache_lock;
|
||||||
|
std::map<std::string, hdfsFS> connection_cache
|
||||||
|
ABSL_GUARDED_BY(connection_cache_lock);
|
||||||
|
HadoopFile(TF_Status* status);
|
||||||
|
} HadoopFile;
|
||||||
|
|
||||||
void Init(TF_Filesystem* filesystem, TF_Status* status);
|
void Init(TF_Filesystem* filesystem, TF_Status* status);
|
||||||
void Cleanup(TF_Filesystem* filesystem);
|
void Cleanup(TF_Filesystem* filesystem);
|
||||||
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
|
||||||
|
@ -352,6 +352,48 @@ TEST_F(HadoopFileSystemTest, WriteWhileReading) {
|
|||||||
EXPECT_TF_OK(status_);
|
EXPECT_TF_OK(status_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(HadoopFileSystemTest, ReadWhileOverwriting) {
|
||||||
|
static char set_disable_var[] = "HDFS_DISABLE_READ_EOF_RETRIED=1";
|
||||||
|
putenv(set_disable_var);
|
||||||
|
|
||||||
|
const std::string path = TmpDir("ReadWhileOverwriting");
|
||||||
|
if (path.find_first_of("hdfs://") != 0) GTEST_SKIP();
|
||||||
|
|
||||||
|
const string content1 = "content1";
|
||||||
|
WriteString(path, content1);
|
||||||
|
ASSERT_TF_OK(status_);
|
||||||
|
|
||||||
|
auto reader = GetReader();
|
||||||
|
tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
|
||||||
|
reader.get(), status_);
|
||||||
|
EXPECT_TF_OK(status_);
|
||||||
|
|
||||||
|
std::string result;
|
||||||
|
result.resize(content1.size());
|
||||||
|
auto read = tf_random_access_file::Read(reader.get(), 0, content1.size(),
|
||||||
|
&result[0], status_);
|
||||||
|
result.resize(read);
|
||||||
|
EXPECT_TF_OK(status_);
|
||||||
|
EXPECT_EQ(content1, result);
|
||||||
|
|
||||||
|
tf_hadoop_filesystem::DeleteFile(filesystem_, path.c_str(), status_);
|
||||||
|
EXPECT_TF_OK(status_);
|
||||||
|
|
||||||
|
string content2 = "overwrite";
|
||||||
|
WriteString(path, content1 + content2);
|
||||||
|
ASSERT_TF_OK(status_);
|
||||||
|
|
||||||
|
result.resize(content2.size());
|
||||||
|
read = tf_random_access_file::Read(reader.get(), content1.size(),
|
||||||
|
content2.size(), &result[0], status_);
|
||||||
|
result.resize(read);
|
||||||
|
EXPECT_TF_OK(status_);
|
||||||
|
EXPECT_EQ(0, result.size());
|
||||||
|
|
||||||
|
static char set_enable_var[] = "HDFS_DISABLE_READ_EOF_RETRIED=0";
|
||||||
|
putenv(set_enable_var);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(HadoopFileSystemTest, HarSplit) {
|
TEST_F(HadoopFileSystemTest, HarSplit) {
|
||||||
const std::string har_path =
|
const std::string har_path =
|
||||||
"har://hdfs-root/user/j.doe/my_archive.har/dir0/dir1/file.txt";
|
"har://hdfs-root/user/j.doe/my_archive.har/dir0/dir1/file.txt";
|
||||||
|
@ -24,6 +24,7 @@ using std::vector;
|
|||||||
using tensorflow::ops::Conj;
|
using tensorflow::ops::Conj;
|
||||||
using tensorflow::ops::MatMul;
|
using tensorflow::ops::MatMul;
|
||||||
using tensorflow::ops::Mul;
|
using tensorflow::ops::Mul;
|
||||||
|
using tensorflow::ops::SqrtGrad;
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace gradients {
|
namespace gradients {
|
||||||
@ -72,6 +73,25 @@ class ExpGradientFunction : public GradientFunction {
|
|||||||
AbstractTensorHandlePtr exp_;
|
AbstractTensorHandlePtr exp_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class SqrtGradientFunction : public GradientFunction {
|
||||||
|
public:
|
||||||
|
explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) {
|
||||||
|
sqrt->Ref();
|
||||||
|
}
|
||||||
|
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||||
|
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||||
|
std::string name = "Sqrt_Grad";
|
||||||
|
grad_outputs->resize(1);
|
||||||
|
TF_RETURN_IF_ERROR(SqrtGrad(ctx->ctx, {sqrt_.get(), grad_inputs[0]},
|
||||||
|
absl::MakeSpan(*grad_outputs), name.c_str()));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
~SqrtGradientFunction() override {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
AbstractTensorHandlePtr sqrt_;
|
||||||
|
};
|
||||||
|
|
||||||
class MatMulGradientFunction : public GradientFunction {
|
class MatMulGradientFunction : public GradientFunction {
|
||||||
public:
|
public:
|
||||||
explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
|
explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
|
||||||
@ -210,5 +230,14 @@ BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
|
|||||||
return new BackwardFunction(gradient_function, default_gradients);
|
return new BackwardFunction(gradient_function, default_gradients);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BackwardFunction* SqrtRegisterer(const ForwardOperation& op) {
|
||||||
|
auto gradient_function = new SqrtGradientFunction(op.outputs[0]);
|
||||||
|
// For ops with a single output, the gradient function is not called if there
|
||||||
|
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||||
|
// grads in this case.
|
||||||
|
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||||
|
return new BackwardFunction(gradient_function, default_gradients);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gradients
|
} // namespace gradients
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -19,9 +19,12 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace gradients {
|
namespace gradients {
|
||||||
|
|
||||||
BackwardFunction* AddRegisterer(const ForwardOperation& op);
|
BackwardFunction* AddRegisterer(const ForwardOperation& op);
|
||||||
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
|
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
|
||||||
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
|
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
|
||||||
|
BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
|
||||||
|
|
||||||
} // namespace gradients
|
} // namespace gradients
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -38,3 +38,29 @@ cc_library(
|
|||||||
"//tensorflow/c/eager:gradients_internal",
|
"//tensorflow/c/eager:gradients_internal",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tape",
|
||||||
|
hdrs = [
|
||||||
|
"tape_context.h",
|
||||||
|
"tape_operation.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":tape_context",
|
||||||
|
":tape_operation",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "pywrap_required_hdrs",
|
||||||
|
srcs = [
|
||||||
|
"tape_context.h",
|
||||||
|
"tape_operation.h",
|
||||||
|
],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -144,5 +144,33 @@ Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
|||||||
return exp_op->Execute(outputs, &num_retvals);
|
return exp_op->Execute(outputs, &num_retvals);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status Sqrt(AbstractContext* ctx,
|
||||||
|
absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
|
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||||
|
AbstractOperationPtr sqrt_op(ctx->CreateOperation());
|
||||||
|
TF_RETURN_IF_ERROR(sqrt_op->Reset("Sqrt", /*raw_device_name=*/nullptr));
|
||||||
|
TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_op.get(), name));
|
||||||
|
TF_RETURN_IF_ERROR(sqrt_op->AddInput(inputs[0]));
|
||||||
|
|
||||||
|
int num_retvals = 1;
|
||||||
|
Status s = sqrt_op->Execute(outputs, &num_retvals);
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status SqrtGrad(AbstractContext* ctx,
|
||||||
|
absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
|
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||||
|
AbstractOperationPtr sqrt_grad_op(ctx->CreateOperation());
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
sqrt_grad_op->Reset("SqrtGrad", /*raw_device_name=*/nullptr));
|
||||||
|
TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_grad_op.get(), name));
|
||||||
|
TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[0]));
|
||||||
|
TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[1]));
|
||||||
|
|
||||||
|
int num_retvals = 1;
|
||||||
|
Status s = sqrt_grad_op->Execute(outputs, &num_retvals);
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -50,6 +50,15 @@ Status DivNoNan(AbstractContext* ctx,
|
|||||||
|
|
||||||
Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||||
|
|
||||||
|
Status Sqrt(AbstractContext* ctx,
|
||||||
|
absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
|
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||||
|
|
||||||
|
Status SqrtGrad(AbstractContext* ctx,
|
||||||
|
absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
|
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||||
|
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -91,15 +91,24 @@ cc_library(
|
|||||||
":signature_def_function_metadata",
|
":signature_def_function_metadata",
|
||||||
"//tensorflow/c/eager:immediate_execution_operation",
|
"//tensorflow/c/eager:immediate_execution_operation",
|
||||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "signature_def_function_metadata",
|
name = "signature_def_function_metadata",
|
||||||
|
srcs = [
|
||||||
|
"signature_def_function_metadata.cc",
|
||||||
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"signature_def_function_metadata.h",
|
"signature_def_function_metadata.h",
|
||||||
],
|
],
|
||||||
|
deps = [
|
||||||
|
":tensor_spec",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
@ -268,6 +277,20 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tensor_spec",
|
||||||
|
srcs = [
|
||||||
|
"tensor_spec.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"tensor_spec.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "tf_concrete_function_loading_test",
|
name = "tf_concrete_function_loading_test",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
@ -92,6 +92,8 @@ cc_library(
|
|||||||
"//tensorflow/c/eager:immediate_execution_context",
|
"//tensorflow/c/eager:immediate_execution_context",
|
||||||
"//tensorflow/c/eager:immediate_execution_operation",
|
"//tensorflow/c/eager:immediate_execution_operation",
|
||||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core:tensor_spec",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/lib/llvm_rtti",
|
"//tensorflow/core/lib/llvm_rtti",
|
||||||
@ -164,6 +166,8 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/common_runtime/eager:context",
|
"//tensorflow/core/common_runtime/eager:context",
|
||||||
|
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||||
|
"//tensorflow/core/lib/llvm_rtti",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -15,7 +15,9 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
@ -30,14 +32,26 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
|
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||||
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/platform/stringpiece.h"
|
||||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||||
|
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using StructuredValueDictEntry =
|
||||||
|
protobuf::MapPair<std::string, StructuredValue>;
|
||||||
|
|
||||||
|
using NamedParamMap =
|
||||||
|
gtl::FlatMap<StringPiece, const TensorSpecProto*, StringPieceHasher>;
|
||||||
|
|
||||||
Status AssertAllCreateResourceFunctionsHaveNoCaptures(
|
Status AssertAllCreateResourceFunctionsHaveNoCaptures(
|
||||||
const PartiallyRevivedObjects& objects) {
|
const PartiallyRevivedObjects& objects) {
|
||||||
for (const auto& id_and_resource : objects.restored_resources) {
|
for (const auto& id_and_resource : objects.restored_resources) {
|
||||||
@ -124,6 +138,142 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<SignatureDefParam> SignatureDefParamsFromNamedParamMap(
|
||||||
|
const NamedParamMap& params) {
|
||||||
|
// The underlying functiondef associated with the SignatureDef has
|
||||||
|
// nest.flattened inputs and outputs, which are sorted by string key.
|
||||||
|
std::vector<SignatureDefParam> result;
|
||||||
|
result.reserve(params.size());
|
||||||
|
for (const auto& named_param : params) {
|
||||||
|
result.push_back(SignatureDefParam(std::string(named_param.first),
|
||||||
|
TensorSpec(*named_param.second)));
|
||||||
|
}
|
||||||
|
std::sort(result.begin(), result.end(),
|
||||||
|
[](const SignatureDefParam& x, const SignatureDefParam& y) {
|
||||||
|
return x.name() < y.name();
|
||||||
|
});
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignatureDefArgsFromInputs takes the "canonicalized_input_signature"
|
||||||
|
// field of a SavedConcreteFunction, ensures it conforms to the structure of
|
||||||
|
// tuple(tuple(), dict<string,TensorSpec>()), and "returns" a list of
|
||||||
|
// SignatureDefParams of the SignatureDefFunction's arguments.
|
||||||
|
Status SignatureDefArgsFromInputs(
|
||||||
|
const StructuredValue& canonicalized_input_signature,
|
||||||
|
std::vector<SignatureDefParam>* out) {
|
||||||
|
// Note(bmzhao): canonicalized_input_signature should be a tuple of
|
||||||
|
// (args, kwargs), where args is an empty tuple, and kwargs is a dictionary of
|
||||||
|
// string keys to TensorSpecs.
|
||||||
|
if (!canonicalized_input_signature.has_tuple_value()) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"SignatureDefFunction's canonicalized_input_signature should be "
|
||||||
|
"of form tuple(tuple(), dict()), but was instead: \n",
|
||||||
|
canonicalized_input_signature.DebugString());
|
||||||
|
}
|
||||||
|
|
||||||
|
const TupleValue& args_kwargs_tuple =
|
||||||
|
canonicalized_input_signature.tuple_value();
|
||||||
|
if (args_kwargs_tuple.values_size() != 2) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"SignatureDefFunction's canonicalized_input_signature should be "
|
||||||
|
"a tuple of two elements (args, kwargs), but was instead: \n",
|
||||||
|
args_kwargs_tuple.DebugString());
|
||||||
|
}
|
||||||
|
|
||||||
|
const StructuredValue& args = args_kwargs_tuple.values(0);
|
||||||
|
if (!args.has_tuple_value() || !args.tuple_value().values().empty()) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"SignatureDefFunction's canonicalized_input_signature's args"
|
||||||
|
"should be an empty tuple, but instead got: \n",
|
||||||
|
args.DebugString());
|
||||||
|
}
|
||||||
|
|
||||||
|
const StructuredValue& kwargs = args_kwargs_tuple.values(1);
|
||||||
|
if (!kwargs.has_dict_value()) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"SignatureDefFunction's canonicalized_input_signature's kwargs"
|
||||||
|
"should be a dictionary, but instead got: \n",
|
||||||
|
kwargs.DebugString());
|
||||||
|
}
|
||||||
|
|
||||||
|
const DictValue& kwargs_dict = kwargs.dict_value();
|
||||||
|
NamedParamMap result;
|
||||||
|
result.reserve(kwargs_dict.fields_size());
|
||||||
|
|
||||||
|
for (const auto& key_value : kwargs_dict.fields()) {
|
||||||
|
const std::string& key = key_value.first;
|
||||||
|
const StructuredValue& value = key_value.second;
|
||||||
|
if (!value.has_tensor_spec_value()) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"SignatureDefFunction's canonicalized_input_signature's kwargs"
|
||||||
|
"dictionary contained a non-tensorspec value for key-value pair: \n",
|
||||||
|
"Key: ", key, "Value: \n", value.DebugString());
|
||||||
|
}
|
||||||
|
result[key] = &value.tensor_spec_value();
|
||||||
|
}
|
||||||
|
|
||||||
|
*out = SignatureDefParamsFromNamedParamMap(result);
|
||||||
|
|
||||||
|
return Status();
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignatureDefReturnsFromOutputs takes the "output_signature" field of a
|
||||||
|
// SavedConcreteFunction, ensures it conforms to the structure of
|
||||||
|
// dict<string,TensorSpec>(), and "returns" a list of SignatureDefParams of the
|
||||||
|
// SignatureDefFunction's returns.
|
||||||
|
Status SignatureDefReturnsFromOutputs(const StructuredValue& output_signature,
|
||||||
|
std::vector<SignatureDefParam>* out) {
|
||||||
|
if (!output_signature.has_dict_value()) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"SignatureDefFunction's output_signature must be a dictionary, but "
|
||||||
|
"instead got: ",
|
||||||
|
output_signature.DebugString());
|
||||||
|
}
|
||||||
|
|
||||||
|
const DictValue& output_dict = output_signature.dict_value();
|
||||||
|
NamedParamMap result;
|
||||||
|
result.reserve(output_dict.fields_size());
|
||||||
|
|
||||||
|
for (const auto& key_value : output_dict.fields()) {
|
||||||
|
const std::string& key = key_value.first;
|
||||||
|
const StructuredValue& value = key_value.second;
|
||||||
|
if (!value.has_tensor_spec_value()) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
"SignatureDefFunction's output_signature dictionary contained a "
|
||||||
|
"non-tensorspec value for key-value pair: \n",
|
||||||
|
"Key: ", key, "Value: \n", value.DebugString());
|
||||||
|
}
|
||||||
|
result[key] = &value.tensor_spec_value();
|
||||||
|
}
|
||||||
|
*out = SignatureDefParamsFromNamedParamMap(result);
|
||||||
|
|
||||||
|
return Status();
|
||||||
|
}
|
||||||
|
|
||||||
|
// The implementation takes advantage of the fact that SignatureDefFunction's
|
||||||
|
// "traced" Signature wrapper function always has inputs/outputs of dictionaries
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L119-L126
|
||||||
|
// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L153-L178
|
||||||
|
// Additionally, we take advantage of the fact that the SignatureDefFunction's
|
||||||
|
// associated functiondef has lexicographically ordered inputs/outputs due to
|
||||||
|
// nest.flatten.
|
||||||
|
Status LoadSignatureDefFunctionMetadata(
|
||||||
|
const SavedConcreteFunction& saved_concrete_function,
|
||||||
|
SignatureDefFunctionMetadata* out) {
|
||||||
|
std::vector<SignatureDefParam> args;
|
||||||
|
TF_RETURN_IF_ERROR(SignatureDefArgsFromInputs(
|
||||||
|
saved_concrete_function.canonicalized_input_signature(), &args));
|
||||||
|
|
||||||
|
std::vector<SignatureDefParam> rets;
|
||||||
|
TF_RETURN_IF_ERROR(SignatureDefReturnsFromOutputs(
|
||||||
|
saved_concrete_function.output_signature(), &rets));
|
||||||
|
|
||||||
|
*out = SignatureDefFunctionMetadata(std::move(args), std::move(rets));
|
||||||
|
return Status();
|
||||||
|
}
|
||||||
|
|
||||||
// This function finds the necessary captures, then forwards to the builder
|
// This function finds the necessary captures, then forwards to the builder
|
||||||
// method
|
// method
|
||||||
Status CreateConcreteFunction(ImmediateExecutionContext* ctx,
|
Status CreateConcreteFunction(ImmediateExecutionContext* ctx,
|
||||||
@ -162,10 +312,14 @@ Status CreateSignatureDefFunction(
|
|||||||
&capture_handle));
|
&capture_handle));
|
||||||
captures.push_back(capture_handle);
|
captures.push_back(capture_handle);
|
||||||
}
|
}
|
||||||
// TODO(bmzhao): Create Metadata here
|
|
||||||
|
SignatureDefFunctionMetadata metadata;
|
||||||
|
TF_RETURN_IF_ERROR(LoadSignatureDefFunctionMetadata(
|
||||||
|
*builder.saved_concrete_func, &metadata));
|
||||||
|
|
||||||
return TFSignatureDefFunction::Create(/*function_def=*/builder.fdef,
|
return TFSignatureDefFunction::Create(/*function_def=*/builder.fdef,
|
||||||
/*captures=*/std::move(captures),
|
/*captures=*/std::move(captures),
|
||||||
/*metadata=*/{},
|
/*metadata=*/std::move(metadata),
|
||||||
/*ctx=*/ctx,
|
/*ctx=*/ctx,
|
||||||
/*out=*/out);
|
/*out=*/out);
|
||||||
}
|
}
|
||||||
@ -378,6 +532,7 @@ Status PartiallyRevivedObjects::Build(ImmediateExecutionContext* ctx,
|
|||||||
revived->variables = std::move(variables);
|
revived->variables = std::move(variables);
|
||||||
revived->assets = std::move(assets);
|
revived->assets = std::move(assets);
|
||||||
revived->constants = std::move(constants);
|
revived->constants = std::move(constants);
|
||||||
|
revived->signatures_map = std::move(signatures_map);
|
||||||
|
|
||||||
// 3b. Move over resources.
|
// 3b. Move over resources.
|
||||||
TF_RETURN_IF_ERROR(BuildResources(ctx, obj_graph, this, revived));
|
TF_RETURN_IF_ERROR(BuildResources(ctx, obj_graph, this, revived));
|
||||||
|
@ -36,7 +36,14 @@ namespace tensorflow {
|
|||||||
// Notably, resources and functions can be in a state where they reference
|
// Notably, resources and functions can be in a state where they reference
|
||||||
// other resources/functions that have not been constructed yet. We collect
|
// other resources/functions that have not been constructed yet. We collect
|
||||||
// *all* objects in a partially valid state here, then properly initialize
|
// *all* objects in a partially valid state here, then properly initialize
|
||||||
// resources and functions.
|
// resources and functions. Implementation-wise, PartiallyRevivedObjects
|
||||||
|
// contains maps keyed by the node number of the SavedObjectGraph, and map to an
|
||||||
|
// object of the corresponding type. So, if node 2 in the object graph is a
|
||||||
|
// variable, PartiallyRevivedObjects.variables[2] exists, and corresponds to a
|
||||||
|
// tensorflow::Variable object. The only exception to this is the
|
||||||
|
// "signatures_map", which is keyed by the "signature" key
|
||||||
|
// (https://github.com/tensorflow/tensorflow/blob/372918decee7f558b3c194b04f77c20dcc679a31/tensorflow/core/protobuf/meta_graph.proto#L89),
|
||||||
|
// and maps to the SignatureDefFunction node in the SavedObjectGraph.
|
||||||
struct PartiallyRevivedObjects {
|
struct PartiallyRevivedObjects {
|
||||||
gtl::FlatMap<int, std::unique_ptr<Variable>> variables;
|
gtl::FlatMap<int, std::unique_ptr<Variable>> variables;
|
||||||
gtl::FlatMap<int, std::unique_ptr<Asset>> assets;
|
gtl::FlatMap<int, std::unique_ptr<Asset>> assets;
|
||||||
@ -44,6 +51,7 @@ struct PartiallyRevivedObjects {
|
|||||||
gtl::FlatMap<int, TFConcreteFunctionRevivalState> concrete_functions;
|
gtl::FlatMap<int, TFConcreteFunctionRevivalState> concrete_functions;
|
||||||
gtl::FlatMap<int, TFSignatureDefFunctionRevivalState> signature_def_functions;
|
gtl::FlatMap<int, TFSignatureDefFunctionRevivalState> signature_def_functions;
|
||||||
gtl::FlatMap<int, RestoredResourceRevivalState> restored_resources;
|
gtl::FlatMap<int, RestoredResourceRevivalState> restored_resources;
|
||||||
|
gtl::FlatMap<std::string, int> signatures_map;
|
||||||
|
|
||||||
Status Build(ImmediateExecutionContext* ctx,
|
Status Build(ImmediateExecutionContext* ctx,
|
||||||
const SavedObjectGraph& obj_graph, RevivedObjects* revived);
|
const SavedObjectGraph& obj_graph, RevivedObjects* revived);
|
||||||
|
@ -44,6 +44,7 @@ struct RevivedObjects {
|
|||||||
gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>>
|
gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>>
|
||||||
signature_def_functions;
|
signature_def_functions;
|
||||||
gtl::FlatMap<int, RestoredResource> restored_resources;
|
gtl::FlatMap<int, RestoredResource> restored_resources;
|
||||||
|
gtl::FlatMap<std::string, int> signatures_map;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -20,8 +20,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
|
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
@ -62,15 +64,53 @@ Status Variable::ReadValue(ImmediateTensorHandlePtr* out) {
|
|||||||
return internal::ReadVariable(ctx_, handle_.get(), dtype_, out);
|
return internal::ReadVariable(ctx_, handle_.get(), dtype_, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx,
|
Status Variable::CreateUninitialized(
|
||||||
DataType dtype, TensorShape shape,
|
ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
|
||||||
absl::optional<std::string> name,
|
absl::optional<std::string> name, const char* raw_device_name,
|
||||||
const char* raw_device_name,
|
const std::vector<std::string>& component_devices,
|
||||||
std::unique_ptr<Variable>* output) {
|
std::unique_ptr<Variable>* output) {
|
||||||
ImmediateTensorHandlePtr handle;
|
ImmediateTensorHandlePtr handle;
|
||||||
|
|
||||||
|
if (component_devices.empty()) {
|
||||||
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
|
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
|
||||||
ctx, dtype, shape, raw_device_name, &handle));
|
ctx, dtype, shape, raw_device_name, &handle));
|
||||||
|
output->reset(
|
||||||
|
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
|
||||||
|
return Status();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!tensorflow::isa<EagerContext>(ctx)) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Can only load distributed variables with EagerContext.");
|
||||||
|
}
|
||||||
|
|
||||||
|
EagerContext* eager_ctx = reinterpret_cast<EagerContext*>(ctx);
|
||||||
|
|
||||||
|
std::vector<TensorHandle*> handles;
|
||||||
|
for (const auto& device : component_devices) {
|
||||||
|
ImmediateTensorHandlePtr handlePtr;
|
||||||
|
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
|
||||||
|
ctx, dtype, shape, device.empty() ? nullptr : device.c_str(),
|
||||||
|
&handlePtr));
|
||||||
|
if (!tensorflow::isa<TensorHandle>(handlePtr.get())) {
|
||||||
|
return errors::Internal("Returned replica handle has unsupported type.");
|
||||||
|
}
|
||||||
|
handles.push_back(reinterpret_cast<TensorHandle*>(handlePtr.release()));
|
||||||
|
}
|
||||||
|
TensorHandle* packed_handle;
|
||||||
|
TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
|
||||||
|
std::move(handles), eager_ctx, &packed_handle));
|
||||||
|
// The call to `CreatePackedHandle` incremented the handles' reference count,
|
||||||
|
// which we must now decrement to make the packed handle the owner of those
|
||||||
|
// handles. We can't loop through the `handles` vector because it was
|
||||||
|
// `std::move`d in the call above.
|
||||||
|
for (int i = 0; i != packed_handle->NumPackedHandles(); ++i) {
|
||||||
|
TensorHandle* component;
|
||||||
|
TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &component));
|
||||||
|
component->Unref();
|
||||||
|
}
|
||||||
|
|
||||||
|
handle.reset(packed_handle);
|
||||||
output->reset(
|
output->reset(
|
||||||
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
|
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
|
||||||
return Status();
|
return Status();
|
||||||
|
@ -34,10 +34,10 @@ class Variable : public TensorHandleConvertible {
|
|||||||
public:
|
public:
|
||||||
// Creates an uninitialized resource variable. Note that a caller must
|
// Creates an uninitialized resource variable. Note that a caller must
|
||||||
// call "assign" to associate a value with the variable.
|
// call "assign" to associate a value with the variable.
|
||||||
static Status CreateUninitialized(ImmediateExecutionContext* ctx,
|
static Status CreateUninitialized(
|
||||||
DataType dtype, TensorShape shape,
|
ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
|
||||||
absl::optional<std::string> name,
|
absl::optional<std::string> name, const char* raw_device_name,
|
||||||
const char* raw_device_name,
|
const std::vector<std::string>& component_devices,
|
||||||
std::unique_ptr<Variable>* output);
|
std::unique_ptr<Variable>* output);
|
||||||
|
|
||||||
// The dtype of the underlying variable.
|
// The dtype of the underlying variable.
|
||||||
|
@ -235,10 +235,17 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx,
|
|||||||
const std::string& name = variable.name();
|
const std::string& name = variable.name();
|
||||||
tensorflow::TensorShape shape(variable.shape());
|
tensorflow::TensorShape shape(variable.shape());
|
||||||
tensorflow::DataType dtype = variable.dtype();
|
tensorflow::DataType dtype = variable.dtype();
|
||||||
|
std::vector<std::string> component_devices;
|
||||||
|
|
||||||
|
for (const auto& component :
|
||||||
|
variable.experimental_distributed_variable_components()) {
|
||||||
|
component_devices.push_back(component.device());
|
||||||
|
}
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(Variable::CreateUninitialized(
|
TF_RETURN_IF_ERROR(Variable::CreateUninitialized(
|
||||||
ctx, dtype, shape, name,
|
ctx, dtype, shape, name,
|
||||||
variable.device().empty() ? nullptr : variable.device().c_str(), output));
|
variable.device().empty() ? nullptr : variable.device().c_str(),
|
||||||
|
component_devices, output));
|
||||||
return Status();
|
return Status();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -519,6 +526,8 @@ Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
objects->signatures_map = std::move(signatures_map);
|
||||||
|
|
||||||
return Status();
|
return Status();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,7 +119,7 @@ TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) {
|
|||||||
Status status;
|
Status status;
|
||||||
std::unique_ptr<Variable> var;
|
std::unique_ptr<Variable> var;
|
||||||
TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
|
TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
|
||||||
absl::nullopt, nullptr, &var));
|
absl::nullopt, nullptr, {}, &var));
|
||||||
|
|
||||||
// Create a TensorHandle
|
// Create a TensorHandle
|
||||||
ImmediateTensorHandlePtr expected_handle =
|
ImmediateTensorHandlePtr expected_handle =
|
||||||
|
@ -0,0 +1,42 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
SignatureDefParam::SignatureDefParam(std::string name, TensorSpec spec)
|
||||||
|
: name_(std::move(name)), spec_(std::move(spec)) {}
|
||||||
|
|
||||||
|
const std::string& SignatureDefParam::name() const { return name_; }
|
||||||
|
|
||||||
|
const TensorSpec& SignatureDefParam::spec() const { return spec_; }
|
||||||
|
|
||||||
|
SignatureDefFunctionMetadata::SignatureDefFunctionMetadata(
|
||||||
|
std::vector<SignatureDefParam> arguments,
|
||||||
|
std::vector<SignatureDefParam> returns)
|
||||||
|
: arguments_(std::move(arguments)), returns_(std::move(returns)) {}
|
||||||
|
|
||||||
|
const std::vector<SignatureDefParam>& SignatureDefFunctionMetadata::arguments()
|
||||||
|
const {
|
||||||
|
return arguments_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<SignatureDefParam>& SignatureDefFunctionMetadata::returns()
|
||||||
|
const {
|
||||||
|
return returns_;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -16,10 +16,42 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
||||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
|
||||||
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// SignatureDefParam represents a named Tensor input or output to a
|
||||||
|
// SignatureDefFunction.
|
||||||
|
class SignatureDefParam {
|
||||||
|
public:
|
||||||
|
SignatureDefParam(std::string name, TensorSpec spec);
|
||||||
|
|
||||||
|
const std::string& name() const;
|
||||||
|
|
||||||
|
const TensorSpec& spec() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string name_;
|
||||||
|
TensorSpec spec_;
|
||||||
|
};
|
||||||
|
|
||||||
class SignatureDefFunctionMetadata {
|
class SignatureDefFunctionMetadata {
|
||||||
// TODO(bmzhao): Fill in with fields as necessary
|
public:
|
||||||
|
SignatureDefFunctionMetadata() = default;
|
||||||
|
SignatureDefFunctionMetadata(std::vector<SignatureDefParam> arguments,
|
||||||
|
std::vector<SignatureDefParam> returns);
|
||||||
|
|
||||||
|
const std::vector<SignatureDefParam>& arguments() const;
|
||||||
|
const std::vector<SignatureDefParam>& returns() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<SignatureDefParam> arguments_;
|
||||||
|
std::vector<SignatureDefParam> returns_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
38
tensorflow/c/experimental/saved_model/core/tensor_spec.cc
Normal file
38
tensorflow/c/experimental/saved_model/core/tensor_spec.cc
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
|
||||||
|
|
||||||
|
#include <initializer_list>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
TensorSpec::TensorSpec()
|
||||||
|
: shape_(std::initializer_list<int64>()), dtype_(DT_FLOAT) {}
|
||||||
|
|
||||||
|
TensorSpec::TensorSpec(PartialTensorShape shape, DataType dtype)
|
||||||
|
: shape_(std::move(shape)), dtype_(dtype) {}
|
||||||
|
|
||||||
|
TensorSpec::TensorSpec(const TensorSpecProto& proto)
|
||||||
|
: shape_(proto.shape()), dtype_(proto.dtype()) {}
|
||||||
|
|
||||||
|
const PartialTensorShape& TensorSpec::shape() const { return shape_; }
|
||||||
|
|
||||||
|
DataType TensorSpec::dtype() const { return dtype_; }
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
51
tensorflow/c/experimental/saved_model/core/tensor_spec.h
Normal file
51
tensorflow/c/experimental/saved_model/core/tensor_spec.h
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Note(bmzhao): TensorSpec deliberately does not store the "name" from a
|
||||||
|
// TensorSpecProto. From edloper@, "Names should really be associated with
|
||||||
|
// parameters, not the tensors inside those parameters. This would be
|
||||||
|
// inconsistent with the corresponding Python class, but I don't think that's
|
||||||
|
// necessarily a problem. If it turns out later that we really need a name
|
||||||
|
// attribute here, we can always add it back in; but let's see how far we can
|
||||||
|
// get without it."
|
||||||
|
class TensorSpec {
|
||||||
|
public:
|
||||||
|
// Constructs a scalar, DT_FLOAT TensorSpec
|
||||||
|
TensorSpec();
|
||||||
|
|
||||||
|
TensorSpec(PartialTensorShape shape, DataType dtype);
|
||||||
|
|
||||||
|
explicit TensorSpec(const TensorSpecProto& proto);
|
||||||
|
|
||||||
|
const PartialTensorShape& shape() const;
|
||||||
|
DataType dtype() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
PartialTensorShape shape_;
|
||||||
|
DataType dtype_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
|
@ -192,9 +192,23 @@ Status TFSavedModelAPI::GetFunction(const std::string& function_path,
|
|||||||
|
|
||||||
Status TFSavedModelAPI::GetSignatureDefFunction(
|
Status TFSavedModelAPI::GetSignatureDefFunction(
|
||||||
const std::string& signature_def_key, SignatureDefFunction** function) {
|
const std::string& signature_def_key, SignatureDefFunction** function) {
|
||||||
// TODO(bmzhao): Add support for retrieving a signaturedef function.
|
auto signatures_iter =
|
||||||
return errors::Unimplemented(
|
revived_objects_.signatures_map.find(signature_def_key);
|
||||||
"Retrieving SignatureDef functions is unimplemented currently");
|
if (signatures_iter == revived_objects_.signatures_map.end()) {
|
||||||
|
return errors::NotFound("No signature with key ", signature_def_key,
|
||||||
|
" was found");
|
||||||
|
}
|
||||||
|
int node = signatures_iter->second;
|
||||||
|
|
||||||
|
auto function_iter = revived_objects_.signature_def_functions.find(node);
|
||||||
|
if (function_iter == revived_objects_.signature_def_functions.end()) {
|
||||||
|
return errors::Internal(
|
||||||
|
"Unable to find SignatureDefFunction associated with key ",
|
||||||
|
signature_def_key, " despite key being valid.");
|
||||||
|
}
|
||||||
|
|
||||||
|
*function = function_iter->second.get();
|
||||||
|
return Status();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
|
std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
|
||||||
|
@ -224,6 +224,8 @@ cc_library(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":signature_def_function_metadata_type",
|
":signature_def_function_metadata_type",
|
||||||
|
":signature_def_param_list",
|
||||||
|
":signature_def_param_list_type",
|
||||||
"//tensorflow/c:c_api_macros",
|
"//tensorflow/c:c_api_macros",
|
||||||
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||||
],
|
],
|
||||||
@ -240,6 +242,104 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "signature_def_param",
|
||||||
|
srcs = [
|
||||||
|
"signature_def_param.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:signature_def_param.h",
|
||||||
|
],
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":signature_def_param_type",
|
||||||
|
":tensor_spec",
|
||||||
|
":tensor_spec_type",
|
||||||
|
"//tensorflow/c:c_api_macros",
|
||||||
|
"//tensorflow/c:tf_shape_internal",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "signature_def_param_type",
|
||||||
|
hdrs = [
|
||||||
|
"signature_def_param_type.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/c:conversion_macros",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "signature_def_param_list",
|
||||||
|
srcs = [
|
||||||
|
"signature_def_param_list.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:signature_def_param_list.h",
|
||||||
|
],
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":signature_def_param",
|
||||||
|
":signature_def_param_list_type",
|
||||||
|
":signature_def_param_type",
|
||||||
|
"//tensorflow/c:c_api_macros",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "signature_def_param_list_type",
|
||||||
|
hdrs = [
|
||||||
|
"signature_def_param_list_type.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/c:conversion_macros",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tensor_spec",
|
||||||
|
srcs = [
|
||||||
|
"tensor_spec.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:tensor_spec.h",
|
||||||
|
],
|
||||||
|
copts = tf_copts(),
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":tensor_spec_type",
|
||||||
|
"//tensorflow/c:c_api_macros",
|
||||||
|
"//tensorflow/c:tf_datatype",
|
||||||
|
"//tensorflow/c:tf_shape",
|
||||||
|
"//tensorflow/c:tf_shape_internal",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core:tensor_spec",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tensor_spec_type",
|
||||||
|
hdrs = [
|
||||||
|
"tensor_spec_type.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/c:conversion_macros",
|
||||||
|
"//tensorflow/c:tf_shape_internal",
|
||||||
|
"//tensorflow/c/experimental/saved_model/core:tensor_spec",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "saved_model_api_test",
|
name = "saved_model_api_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
@ -252,6 +352,8 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":saved_model_api_type",
|
":saved_model_api_type",
|
||||||
|
"//tensorflow/c:tf_datatype",
|
||||||
|
"//tensorflow/c:tf_shape",
|
||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
"//tensorflow/c:tf_tensor",
|
"//tensorflow/c:tf_tensor",
|
||||||
"//tensorflow/c/eager:c_api",
|
"//tensorflow/c/eager:c_api",
|
||||||
@ -260,6 +362,11 @@ tf_cc_test(
|
|||||||
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_api",
|
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_api",
|
||||||
"//tensorflow/c/experimental/saved_model/public:concrete_function",
|
"//tensorflow/c/experimental/saved_model/public:concrete_function",
|
||||||
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
|
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:signature_def_function",
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:signature_def_function_metadata",
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:signature_def_param",
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:signature_def_param_list",
|
||||||
|
"//tensorflow/c/experimental/saved_model/public:tensor_spec",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
@ -24,6 +24,13 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
|
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
|
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
|
||||||
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
|
#include "tensorflow/c/tf_shape.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/c/tf_tensor.h"
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
@ -143,6 +150,146 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) {
|
|||||||
TFE_DeleteContext(ctx);
|
TFE_DeleteContext(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This tests running the "serving_default" SignatureDefFunction from the
|
||||||
|
// VarsAndArithmeticObjectGraph savedmodel. Here's what the signature_defs
|
||||||
|
// protobuf in the metagraph looks like:
|
||||||
|
// signature_def: {
|
||||||
|
// key : "serving_default"
|
||||||
|
// value: {
|
||||||
|
// inputs: {
|
||||||
|
// key : "a"
|
||||||
|
// value: {
|
||||||
|
// name : "serving_default_a:0"
|
||||||
|
// dtype: DT_FLOAT
|
||||||
|
// tensor_shape: {
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// inputs: {
|
||||||
|
// key : "b"
|
||||||
|
// value: {
|
||||||
|
// name : "serving_default_b:0"
|
||||||
|
// dtype: DT_FLOAT
|
||||||
|
// tensor_shape: {
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// outputs: {
|
||||||
|
// key : "output_0"
|
||||||
|
// value: {
|
||||||
|
// name : "StatefulPartitionedCall:0"
|
||||||
|
// dtype: DT_FLOAT
|
||||||
|
// tensor_shape: {
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// method_name: "tensorflow/serving/predict"
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
TEST_P(CSavedModelAPITest, RunsSignatureDefFunction) {
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
bool use_tfrt = GetParam();
|
||||||
|
if (use_tfrt) {
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||||
|
|
||||||
|
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
|
||||||
|
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||||
|
|
||||||
|
TF_SavedModel* saved_model =
|
||||||
|
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
|
||||||
|
|
||||||
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
TF_SignatureDefFunction* serving_default =
|
||||||
|
TF_GetSavedModelSignatureDefFunction(saved_model, "serving_default",
|
||||||
|
status);
|
||||||
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
|
||||||
|
TF_SignatureDefFunctionMetadata* metadata =
|
||||||
|
TF_SignatureDefFunctionGetMetadata(serving_default);
|
||||||
|
|
||||||
|
const TF_SignatureDefParamList* args =
|
||||||
|
TF_SignatureDefFunctionMetadataArgs(metadata);
|
||||||
|
const TF_SignatureDefParamList* returns =
|
||||||
|
TF_SignatureDefFunctionMetadataReturns(metadata);
|
||||||
|
|
||||||
|
EXPECT_EQ(TF_SignatureDefParamListSize(args), 2);
|
||||||
|
const TF_SignatureDefParam* param_a = TF_SignatureDefParamListGet(args, 0);
|
||||||
|
const TF_TensorSpec* tensor_spec_a = TF_SignatureDefParamTensorSpec(param_a);
|
||||||
|
const TF_Shape* shape_a = TF_TensorSpecShape(tensor_spec_a);
|
||||||
|
|
||||||
|
// Input "a" is a scalar, float32 tensor
|
||||||
|
EXPECT_EQ("a", std::string(TF_SignatureDefParamName(param_a)));
|
||||||
|
EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_a));
|
||||||
|
EXPECT_EQ(0, TF_ShapeDims(shape_a));
|
||||||
|
|
||||||
|
const TF_SignatureDefParam* param_b = TF_SignatureDefParamListGet(args, 1);
|
||||||
|
const TF_TensorSpec* tensor_spec_b = TF_SignatureDefParamTensorSpec(param_b);
|
||||||
|
const TF_Shape* shape_b = TF_TensorSpecShape(tensor_spec_b);
|
||||||
|
|
||||||
|
// Input "b" is a scalar, float32 tensor
|
||||||
|
EXPECT_EQ("b", std::string(TF_SignatureDefParamName(param_b)));
|
||||||
|
EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_b));
|
||||||
|
EXPECT_EQ(0, TF_ShapeDims(shape_b));
|
||||||
|
|
||||||
|
EXPECT_EQ(TF_SignatureDefParamListSize(returns), 1);
|
||||||
|
|
||||||
|
const TF_SignatureDefParam* param_out =
|
||||||
|
TF_SignatureDefParamListGet(returns, 0);
|
||||||
|
const TF_TensorSpec* tensor_spec_out =
|
||||||
|
TF_SignatureDefParamTensorSpec(param_out);
|
||||||
|
const TF_Shape* shape_out = TF_TensorSpecShape(tensor_spec_out);
|
||||||
|
|
||||||
|
// Output "output_0" is a scalar, float32 tensor
|
||||||
|
EXPECT_EQ("output_0", std::string(TF_SignatureDefParamName(param_out)));
|
||||||
|
EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_out));
|
||||||
|
EXPECT_EQ(0, TF_ShapeDims(shape_out));
|
||||||
|
|
||||||
|
std::vector<TFE_TensorHandle*> compute_fn_inputs;
|
||||||
|
TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f);
|
||||||
|
TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f);
|
||||||
|
compute_fn_inputs.push_back(input_a);
|
||||||
|
compute_fn_inputs.push_back(input_b);
|
||||||
|
|
||||||
|
TFE_Op* serving_default_op = TF_SignatureDefFunctionMakeCallOp(
|
||||||
|
serving_default, compute_fn_inputs.data(), compute_fn_inputs.size(),
|
||||||
|
status);
|
||||||
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
|
||||||
|
std::vector<TFE_TensorHandle*> compute_fn_outputs(
|
||||||
|
TF_SignatureDefParamListSize(returns));
|
||||||
|
int num_retvals = TF_SignatureDefParamListSize(returns);
|
||||||
|
|
||||||
|
TFE_Execute(serving_default_op, compute_fn_outputs.data(), &num_retvals,
|
||||||
|
status);
|
||||||
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
|
||||||
|
TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status);
|
||||||
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
|
||||||
|
EXPECT_EQ(TF_NumDims(result), 0);
|
||||||
|
float output_value = *static_cast<float*>(TF_TensorData(result));
|
||||||
|
// (1 + 2) * (2 + 1) / 3 + 5 should be 8
|
||||||
|
EXPECT_FLOAT_EQ(output_value, 8.0);
|
||||||
|
|
||||||
|
TF_DeleteTensor(result);
|
||||||
|
TFE_DeleteTensorHandle(compute_fn_outputs[0]);
|
||||||
|
TFE_DeleteTensorHandle(input_a);
|
||||||
|
TFE_DeleteTensorHandle(input_b);
|
||||||
|
TFE_DeleteOp(serving_default_op);
|
||||||
|
TF_DeleteSavedModel(saved_model);
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
TFE_DeleteContext(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
|
TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
|
||||||
TF_Status* status = TF_NewStatus();
|
TF_Status* status = TF_NewStatus();
|
||||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
@ -16,5 +16,18 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
|
||||||
|
|
||||||
#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h"
|
#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h"
|
||||||
|
|
||||||
// TODO(bmzhao): Add getter functions here as necessary.
|
extern "C" {
|
||||||
|
|
||||||
|
extern const TF_SignatureDefParamList* TF_SignatureDefFunctionMetadataArgs(
|
||||||
|
const TF_SignatureDefFunctionMetadata* list) {
|
||||||
|
return tensorflow::wrap(&tensorflow::unwrap(list)->arguments());
|
||||||
|
}
|
||||||
|
|
||||||
|
extern const TF_SignatureDefParamList* TF_SignatureDefFunctionMetadataReturns(
|
||||||
|
const TF_SignatureDefFunctionMetadata* list) {
|
||||||
|
return tensorflow::wrap(&tensorflow::unwrap(list)->returns());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end extern "C"
|
||||||
|
@ -0,0 +1,33 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h"
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
extern const char* TF_SignatureDefParamName(const TF_SignatureDefParam* param) {
|
||||||
|
return tensorflow::unwrap(param)->name().c_str();
|
||||||
|
}
|
||||||
|
|
||||||
|
extern const TF_TensorSpec* TF_SignatureDefParamTensorSpec(
|
||||||
|
const TF_SignatureDefParam* param) {
|
||||||
|
return tensorflow::wrap(&tensorflow::unwrap(param)->spec());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end extern "C"
|
@ -0,0 +1,33 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h"
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
extern size_t TF_SignatureDefParamListSize(
|
||||||
|
const TF_SignatureDefParamList* list) {
|
||||||
|
return tensorflow::unwrap(list)->size();
|
||||||
|
}
|
||||||
|
|
||||||
|
extern const TF_SignatureDefParam* TF_SignatureDefParamListGet(
|
||||||
|
const TF_SignatureDefParamList* list, int i) {
|
||||||
|
return tensorflow::wrap(&tensorflow::unwrap(list)->at(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end extern "C"
|
@ -0,0 +1,33 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/c/conversion_macros.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||||
|
|
||||||
|
typedef struct TF_SignatureDefParamList TF_SignatureDefParamList;
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
DEFINE_CONVERSION_FUNCTIONS(std::vector<SignatureDefParam>,
|
||||||
|
TF_SignatureDefParamList)
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_
|
@ -0,0 +1,30 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/conversion_macros.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
|
||||||
|
|
||||||
|
typedef struct TF_SignatureDefParam TF_SignatureDefParam;
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
DEFINE_CONVERSION_FUNCTIONS(tensorflow::SignatureDefParam, TF_SignatureDefParam)
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_
|
@ -0,0 +1,32 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
|
||||||
|
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h"
|
||||||
|
#include "tensorflow/c/tf_shape_internal.h"
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
TF_DataType TF_TensorSpecDataType(const TF_TensorSpec* spec) {
|
||||||
|
return static_cast<TF_DataType>(tensorflow::unwrap(spec)->dtype());
|
||||||
|
}
|
||||||
|
|
||||||
|
const TF_Shape* TF_TensorSpecShape(const TF_TensorSpec* spec) {
|
||||||
|
return tensorflow::wrap(&tensorflow::unwrap(spec)->shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end extern "C"
|
@ -0,0 +1,30 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/conversion_macros.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
|
||||||
|
|
||||||
|
typedef struct TF_TensorSpec TF_TensorSpec;
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
DEFINE_CONVERSION_FUNCTIONS(tensorflow::TensorSpec, TF_TensorSpec)
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_
|
@ -28,6 +28,9 @@ exports_files(
|
|||||||
"saved_model_api.h",
|
"saved_model_api.h",
|
||||||
"signature_def_function.h",
|
"signature_def_function.h",
|
||||||
"signature_def_function_metadata.h",
|
"signature_def_function_metadata.h",
|
||||||
|
"signature_def_param.h",
|
||||||
|
"signature_def_param_list.h",
|
||||||
|
"tensor_spec.h",
|
||||||
],
|
],
|
||||||
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
|
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
|
||||||
)
|
)
|
||||||
@ -45,6 +48,9 @@ cc_library(
|
|||||||
":saved_model_api",
|
":saved_model_api",
|
||||||
":signature_def_function",
|
":signature_def_function",
|
||||||
":signature_def_function_metadata",
|
":signature_def_function_metadata",
|
||||||
|
":signature_def_param",
|
||||||
|
":signature_def_param_list",
|
||||||
|
":tensor_spec",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -77,3 +83,18 @@ alias(
|
|||||||
name = "signature_def_function_metadata",
|
name = "signature_def_function_metadata",
|
||||||
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_function_metadata",
|
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_function_metadata",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "signature_def_param",
|
||||||
|
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_param",
|
||||||
|
)
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "signature_def_param_list",
|
||||||
|
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_param_list",
|
||||||
|
)
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "tensor_spec",
|
||||||
|
actual = "//tensorflow/c/experimental/saved_model/internal:tensor_spec",
|
||||||
|
)
|
||||||
|
@ -23,6 +23,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
|
||||||
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
|
||||||
// IWYU pragma: end_exports
|
// IWYU pragma: end_exports
|
||||||
|
|
||||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
|
||||||
|
@ -16,6 +16,9 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
||||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api_macros.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif // __cplusplus
|
#endif // __cplusplus
|
||||||
@ -24,6 +27,18 @@ extern "C" {
|
|||||||
// SavedModel.
|
// SavedModel.
|
||||||
typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata;
|
typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata;
|
||||||
|
|
||||||
|
// Retrieves the arguments of the SignatureDefFunction. The caller is not
|
||||||
|
// responsible for freeing the returned pointer.
|
||||||
|
TF_CAPI_EXPORT extern const TF_SignatureDefParamList*
|
||||||
|
TF_SignatureDefFunctionMetadataArgs(
|
||||||
|
const TF_SignatureDefFunctionMetadata* list);
|
||||||
|
|
||||||
|
// Retrieves the returns of the SignatureDefFunction. The caller is not
|
||||||
|
// responsible for freeing the returned pointer.
|
||||||
|
TF_CAPI_EXPORT extern const TF_SignatureDefParamList*
|
||||||
|
TF_SignatureDefFunctionMetadataReturns(
|
||||||
|
const TF_SignatureDefFunctionMetadata* list);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} // end extern "C"
|
} // end extern "C"
|
||||||
#endif // __cplusplus
|
#endif // __cplusplus
|
||||||
|
@ -0,0 +1,44 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api_macros.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif // __cplusplus
|
||||||
|
|
||||||
|
// An opaque type containing metadata of an input/output of a
|
||||||
|
// TF_SignatureDefFunction loaded from a SavedModel.
|
||||||
|
typedef struct TF_SignatureDefParam TF_SignatureDefParam;
|
||||||
|
|
||||||
|
// Returns the name of the given parameter. The caller is not responsible for
|
||||||
|
// freeing the returned char*.
|
||||||
|
TF_CAPI_EXPORT extern const char* TF_SignatureDefParamName(
|
||||||
|
const TF_SignatureDefParam* param);
|
||||||
|
|
||||||
|
// Returns the TensorSpec associated with the given parameter. The caller is
|
||||||
|
// not responsible for freeing the returned TF_TensorSpec*.
|
||||||
|
TF_CAPI_EXPORT extern const TF_TensorSpec* TF_SignatureDefParamTensorSpec(
|
||||||
|
const TF_SignatureDefParam* param);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} // end extern "C"
|
||||||
|
#endif // __cplusplus
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_
|
@ -0,0 +1,44 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api_macros.h"
|
||||||
|
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif // __cplusplus
|
||||||
|
|
||||||
|
// An opaque type containing metadata of an input/output of a
|
||||||
|
// ConcreteFunction loaded from a SavedModel.
|
||||||
|
typedef struct TF_SignatureDefParamList TF_SignatureDefParamList;
|
||||||
|
|
||||||
|
// Returns the size of `list`.
|
||||||
|
TF_CAPI_EXPORT extern size_t TF_SignatureDefParamListSize(
|
||||||
|
const TF_SignatureDefParamList* list);
|
||||||
|
|
||||||
|
// Returns the `i`th TF_SignatureDefParam in the list.
|
||||||
|
TF_CAPI_EXPORT extern const TF_SignatureDefParam* TF_SignatureDefParamListGet(
|
||||||
|
const TF_SignatureDefParamList* list, int i);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} // end extern "C"
|
||||||
|
#endif // __cplusplus
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_
|
46
tensorflow/c/experimental/saved_model/public/tensor_spec.h
Normal file
46
tensorflow/c/experimental/saved_model/public/tensor_spec.h
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_
|
||||||
|
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api_macros.h"
|
||||||
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
|
#include "tensorflow/c/tf_shape.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif // __cplusplus
|
||||||
|
|
||||||
|
// An opaque type corresponding to TensorSpec
|
||||||
|
typedef struct TF_TensorSpec TF_TensorSpec;
|
||||||
|
|
||||||
|
// Returns the dtype associated with the TensorSpec.
|
||||||
|
TF_CAPI_EXPORT extern TF_DataType TF_TensorSpecDataType(
|
||||||
|
const TF_TensorSpec* spec);
|
||||||
|
|
||||||
|
// Returns the shape associated with the TensorSpec. The returned Shape is not
|
||||||
|
// owned by the caller. Caller must not call TF_DeleteShape on the returned
|
||||||
|
// shape.
|
||||||
|
TF_CAPI_EXPORT extern const TF_Shape* TF_TensorSpecShape(
|
||||||
|
const TF_TensorSpec* spec);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} // end extern "C"
|
||||||
|
#endif // __cplusplus
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_
|
@ -22,6 +22,8 @@ cc_library(
|
|||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
"//tensorflow/c:tf_status_helper",
|
"//tensorflow/c:tf_status_helper",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/platform:regexp",
|
||||||
|
"//tensorflow/core/platform:strcat",
|
||||||
"//tensorflow/stream_executor:executor_cache",
|
"//tensorflow/stream_executor:executor_cache",
|
||||||
"//tensorflow/stream_executor:multi_platform_manager",
|
"//tensorflow/stream_executor:multi_platform_manager",
|
||||||
"//tensorflow/stream_executor:platform",
|
"//tensorflow/stream_executor:platform",
|
||||||
|
@ -27,7 +27,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/regexp.h"
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
#include "tensorflow/core/platform/strcat.h"
|
||||||
|
#include "tensorflow/core/platform/stringpiece.h"
|
||||||
#include "tensorflow/stream_executor/executor_cache.h"
|
#include "tensorflow/stream_executor/executor_cache.h"
|
||||||
#include "tensorflow/stream_executor/multi_platform_manager.h"
|
#include "tensorflow/stream_executor/multi_platform_manager.h"
|
||||||
#include "tensorflow/stream_executor/platform.h"
|
#include "tensorflow/stream_executor/platform.h"
|
||||||
@ -39,6 +42,8 @@ limitations under the License.
|
|||||||
using tensorflow::StatusFromTF_Status;
|
using tensorflow::StatusFromTF_Status;
|
||||||
|
|
||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
|
using tensorflow::StringPiece;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \
|
#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \
|
||||||
@ -58,10 +63,35 @@ namespace {
|
|||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
|
port::Status ValidateDeviceType(StringPiece type) {
|
||||||
|
// Validate device type. Device type must start with a capital letter and
|
||||||
|
// consist of capital letters and underscores. Reasoning behind this decision:
|
||||||
|
// * At the minimum we want to disallow '/' and ':' since
|
||||||
|
// these characters are used in device spec, e.g.
|
||||||
|
// /job:foo/replica:12/device:GPU:1.
|
||||||
|
// * Underscores seem useful, e.g. XLA_GPU uses underscores.
|
||||||
|
// * Allowing lowercase might get confusing. For example, say someone
|
||||||
|
// registers a new type called "Gpu". It might be confusing for users that
|
||||||
|
// "Gpu" is not the same device type as "GPU".
|
||||||
|
// Note that lowercase "cpu" and "gpu" are currently supported only for
|
||||||
|
// legacy reasons:
|
||||||
|
// https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/python/framework/device_spec.py;l=46;drc=d3a378f9665d8eee827c74cb9ecbee81e4c288dd
|
||||||
|
static const LazyRE2 kTfDeviceTypeRegEx = {"[A-Z][A-Z_]*"};
|
||||||
|
bool matches = RE2::FullMatch(type, *kTfDeviceTypeRegEx);
|
||||||
|
if (!matches) {
|
||||||
|
return port::FailedPreconditionError(
|
||||||
|
tensorflow::strings::StrCat("Device name/type '", type, "' must match ",
|
||||||
|
kTfDeviceTypeRegEx->pattern(), "."));
|
||||||
|
}
|
||||||
|
return port::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
port::Status ValidateSPPlatform(const SP_Platform& platform) {
|
port::Status ValidateSPPlatform(const SP_Platform& platform) {
|
||||||
VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE);
|
VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE);
|
||||||
VALIDATE_MEMBER(SP_Platform, platform, name);
|
VALIDATE_MEMBER(SP_Platform, platform, name);
|
||||||
VALIDATE_MEMBER(SP_Platform, platform, type);
|
VALIDATE_MEMBER(SP_Platform, platform, type);
|
||||||
|
TF_RETURN_IF_ERROR(ValidateDeviceType(platform.name));
|
||||||
|
TF_RETURN_IF_ERROR(ValidateDeviceType(platform.type));
|
||||||
// `visible_device_count` could be 0 at initialization time.
|
// `visible_device_count` could be 0 at initialization time.
|
||||||
return port::Status::OK();
|
return port::Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -52,7 +52,7 @@ limitations under the License.
|
|||||||
// params.device = &device;
|
// params.device = &device;
|
||||||
//
|
//
|
||||||
// /* Plugin code below */
|
// /* Plugin code below */
|
||||||
// constexpr char DEVICE_NAME[] = "MyDevice";
|
// constexpr char DEVICE_NAME[] = "MY_DEVICE";
|
||||||
// constexpr char DEVICE_TYPE[] = "GPU";
|
// constexpr char DEVICE_TYPE[] = "GPU";
|
||||||
//
|
//
|
||||||
// void create_device(const SP_Platform* platform,
|
// void create_device(const SP_Platform* platform,
|
||||||
@ -416,10 +416,15 @@ typedef struct SP_Platform {
|
|||||||
|
|
||||||
void* ext; // free-form data set by plugin
|
void* ext; // free-form data set by plugin
|
||||||
|
|
||||||
// Platform name. Must be null-terminated.
|
// Platform name (also referred to as subtype), for example MY_DEVICE.
|
||||||
|
// The name must start with a capital letter and consist of
|
||||||
|
// capital letters and underscores.
|
||||||
|
// Must be null-terminated.
|
||||||
const char* name;
|
const char* name;
|
||||||
|
|
||||||
// Device type name, for example GPU. Must be null-terminated.
|
// Device type name, for example GPU. Must be null-terminated.
|
||||||
|
// The name must start with a capital letter and consist of
|
||||||
|
// capital letters and underscores.
|
||||||
const char* type;
|
const char* type;
|
||||||
|
|
||||||
// Number of visible devices
|
// Number of visible devices
|
||||||
|
@ -41,9 +41,9 @@ struct SP_Timer_st {
|
|||||||
|
|
||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr int DEVICE_COUNT = 2;
|
constexpr int kDeviceCount = 2;
|
||||||
constexpr char DEVICE_NAME[] = "MyDevice";
|
constexpr char kDeviceName[] = "MY_DEVICE";
|
||||||
constexpr char DEVICE_TYPE[] = "GPU";
|
constexpr char kDeviceType[] = "GPU";
|
||||||
|
|
||||||
/*** Create SP_StreamExecutor (with empty functions) ***/
|
/*** Create SP_StreamExecutor (with empty functions) ***/
|
||||||
void allocate(const SP_Device* const device, uint64_t size,
|
void allocate(const SP_Device* const device, uint64_t size,
|
||||||
@ -190,9 +190,9 @@ void destroy_device_fns(const SP_Platform* platform, SP_DeviceFns* device_fns) {
|
|||||||
void PopulateDefaultPlatform(SP_Platform* platform,
|
void PopulateDefaultPlatform(SP_Platform* platform,
|
||||||
SP_PlatformFns* platform_fns) {
|
SP_PlatformFns* platform_fns) {
|
||||||
*platform = {SP_PLATFORM_STRUCT_SIZE};
|
*platform = {SP_PLATFORM_STRUCT_SIZE};
|
||||||
platform->name = DEVICE_NAME;
|
platform->name = kDeviceName;
|
||||||
platform->type = DEVICE_TYPE;
|
platform->type = kDeviceType;
|
||||||
platform->visible_device_count = DEVICE_COUNT;
|
platform->visible_device_count = kDeviceCount;
|
||||||
platform_fns->create_device = create_device;
|
platform_fns->create_device = create_device;
|
||||||
platform_fns->destroy_device = destroy_device;
|
platform_fns->destroy_device = destroy_device;
|
||||||
platform_fns->create_device_fns = create_device_fns;
|
platform_fns->create_device_fns = create_device_fns;
|
||||||
@ -218,11 +218,11 @@ TEST(StreamExecutor, SuccessfulRegistration) {
|
|||||||
port::Status status = InitStreamExecutorPlugin(plugin_init);
|
port::Status status = InitStreamExecutorPlugin(plugin_init);
|
||||||
TF_ASSERT_OK(status);
|
TF_ASSERT_OK(status);
|
||||||
port::StatusOr<Platform*> maybe_platform =
|
port::StatusOr<Platform*> maybe_platform =
|
||||||
MultiPlatformManager::PlatformWithName("MyDevice");
|
MultiPlatformManager::PlatformWithName("MY_DEVICE");
|
||||||
TF_ASSERT_OK(maybe_platform.status());
|
TF_ASSERT_OK(maybe_platform.status());
|
||||||
Platform* platform = maybe_platform.ConsumeValueOrDie();
|
Platform* platform = maybe_platform.ConsumeValueOrDie();
|
||||||
ASSERT_EQ(platform->Name(), DEVICE_NAME);
|
ASSERT_EQ(platform->Name(), kDeviceName);
|
||||||
ASSERT_EQ(platform->VisibleDeviceCount(), DEVICE_COUNT);
|
ASSERT_EQ(platform->VisibleDeviceCount(), kDeviceCount);
|
||||||
|
|
||||||
port::StatusOr<StreamExecutor*> maybe_executor =
|
port::StatusOr<StreamExecutor*> maybe_executor =
|
||||||
platform->ExecutorForDevice(0);
|
platform->ExecutorForDevice(0);
|
||||||
@ -244,6 +244,39 @@ TEST(StreamExecutor, NameNotSet) {
|
|||||||
ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set.");
|
ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(StreamExecutor, InvalidNameWithSemicolon) {
|
||||||
|
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
|
||||||
|
TF_Status* const status) -> void {
|
||||||
|
TF_SetStatus(status, TF_OK, "");
|
||||||
|
PopulateDefaultPlatform(params->platform, params->platform_fns);
|
||||||
|
params->platform->name = "INVALID:NAME";
|
||||||
|
params->destroy_platform = destroy_platform;
|
||||||
|
params->destroy_platform_fns = destroy_platform_fns;
|
||||||
|
};
|
||||||
|
|
||||||
|
port::Status status = InitStreamExecutorPlugin(plugin_init);
|
||||||
|
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
|
||||||
|
EXPECT_THAT(
|
||||||
|
status.error_message(),
|
||||||
|
testing::ContainsRegex("Device name/type 'INVALID:NAME' must match"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StreamExecutor, InvalidNameWithSlash) {
|
||||||
|
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
|
||||||
|
TF_Status* const status) -> void {
|
||||||
|
TF_SetStatus(status, TF_OK, "");
|
||||||
|
PopulateDefaultPlatform(params->platform, params->platform_fns);
|
||||||
|
params->platform->name = "INVALID/";
|
||||||
|
params->destroy_platform = destroy_platform;
|
||||||
|
params->destroy_platform_fns = destroy_platform_fns;
|
||||||
|
};
|
||||||
|
|
||||||
|
port::Status status = InitStreamExecutorPlugin(plugin_init);
|
||||||
|
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
|
||||||
|
EXPECT_THAT(status.error_message(),
|
||||||
|
testing::ContainsRegex("Device name/type 'INVALID/' must match"));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(StreamExecutor, CreateDeviceNotSet) {
|
TEST(StreamExecutor, CreateDeviceNotSet) {
|
||||||
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
|
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
|
||||||
TF_Status* const status) -> void {
|
TF_Status* const status) -> void {
|
||||||
|
@ -57,43 +57,7 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
|
|||||||
|
|
||||||
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
|
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
mutex_lock l(graph->mu);
|
TF_UpdateEdge(graph, new_src, dst, status);
|
||||||
tensorflow::shape_inference::InferenceContext* ic =
|
|
||||||
graph->refiner.GetContext(&new_src.oper->node);
|
|
||||||
|
|
||||||
if (ic->num_outputs() <= new_src.index) {
|
|
||||||
status->status = tensorflow::errors::OutOfRange(
|
|
||||||
"Cannot update edge. Output index [", new_src.index,
|
|
||||||
"] is greater than the number of total outputs [", ic->num_outputs(),
|
|
||||||
"].");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
|
|
||||||
|
|
||||||
tensorflow::shape_inference::InferenceContext* ic_dst =
|
|
||||||
graph->refiner.GetContext(&dst.oper->node);
|
|
||||||
if (ic_dst->num_inputs() <= dst.index) {
|
|
||||||
status->status = tensorflow::errors::OutOfRange(
|
|
||||||
"Cannot update edge. Input index [", dst.index,
|
|
||||||
"] is greater than the number of total inputs [", ic_dst->num_inputs(),
|
|
||||||
"].");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (!ic_dst->MergeInput(dst.index, shape)) {
|
|
||||||
status->status = tensorflow::errors::InvalidArgument(
|
|
||||||
"Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
|
|
||||||
" and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
|
|
||||||
&dst.oper->node, dst.index);
|
|
||||||
|
|
||||||
if (TF_GetCode(status) == TF_OK) {
|
|
||||||
// This modification only updates the destination node for
|
|
||||||
// the purposes of running this graph in a session. Thus, we don't
|
|
||||||
// record the source node as being modified.
|
|
||||||
RecordMutation(graph, *dst.oper, "updating input tensor");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
|
void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
|
||||||
@ -136,6 +100,7 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
|
|||||||
auto* out_shape_and_type = handle_data.add_shape_and_type();
|
auto* out_shape_and_type = handle_data.add_shape_and_type();
|
||||||
ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
|
ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
|
||||||
out_shape_and_type->set_dtype(p.dtype);
|
out_shape_and_type->set_dtype(p.dtype);
|
||||||
|
out_shape_and_type->set_specialized_type(p.specialized_type);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
string result;
|
string result;
|
||||||
@ -163,7 +128,8 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
|
|||||||
status->status =
|
status->status =
|
||||||
ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
|
ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
|
||||||
if (TF_GetCode(status) != TF_OK) return;
|
if (TF_GetCode(status) != TF_OK) return;
|
||||||
shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
|
shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(),
|
||||||
|
shape_and_type_proto.specialized_type());
|
||||||
}
|
}
|
||||||
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
|
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
|
||||||
}
|
}
|
||||||
|
39
tensorflow/c/tf_shape.cc
Normal file
39
tensorflow/c/tf_shape.cc
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/tf_shape.h"
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "tensorflow/c/tf_shape_internal.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
TF_Shape* TF_NewShape() {
|
||||||
|
return tensorflow::wrap(new tensorflow::PartialTensorShape());
|
||||||
|
}
|
||||||
|
|
||||||
|
int TF_ShapeDims(const TF_Shape* shape) {
|
||||||
|
return tensorflow::unwrap(shape)->dims();
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t TF_ShapeDimSize(const TF_Shape* shape, int d) {
|
||||||
|
return tensorflow::unwrap(shape)->dim_size(d);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TF_DeleteShape(TF_Shape* shape) { delete tensorflow::unwrap(shape); }
|
||||||
|
|
||||||
|
} // end extern "C"
|
50
tensorflow/c/tf_shape.h
Normal file
50
tensorflow/c/tf_shape.h
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "tensorflow/c/c_api_macros.h"
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_TF_SHAPE_H_
|
||||||
|
#define TENSORFLOW_C_TF_SHAPE_H_
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// An opaque type corresponding to a shape in tensorflow. In the future,
|
||||||
|
// we may expose the ABI of TF_Shape for performance reasons.
|
||||||
|
typedef struct TF_Shape TF_Shape;
|
||||||
|
|
||||||
|
// Return a new, unknown rank shape object. The caller is responsible for
|
||||||
|
// calling TF_DeleteShape to deallocate and destroy the returned shape.
|
||||||
|
TF_CAPI_EXPORT extern TF_Shape* TF_NewShape();
|
||||||
|
|
||||||
|
// Returns the rank of `shape`. If `shape` has unknown rank, returns -1.
|
||||||
|
TF_CAPI_EXPORT extern int TF_ShapeDims(const TF_Shape* shape);
|
||||||
|
|
||||||
|
// Returns the `d`th dimension of `shape`. If `shape` has unknown rank,
|
||||||
|
// invoking this function is undefined behavior. Returns -1 if dimension is
|
||||||
|
// unknown.
|
||||||
|
TF_CAPI_EXPORT extern int64_t TF_ShapeDimSize(const TF_Shape* shape, int d);
|
||||||
|
|
||||||
|
// Deletes `shape`.
|
||||||
|
TF_CAPI_EXPORT extern void TF_DeleteShape(TF_Shape* shape);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} /* end extern "C" */
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_TF_SHAPE_H_
|
30
tensorflow/c/tf_shape_internal.h
Normal file
30
tensorflow/c/tf_shape_internal.h
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_C_TF_SHAPE_INTERNAL_H_
|
||||||
|
#define TENSORFLOW_C_TF_SHAPE_INTERNAL_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/conversion_macros.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
|
||||||
|
typedef struct TF_Shape TF_Shape;
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
DEFINE_CONVERSION_FUNCTIONS(tensorflow::PartialTensorShape, TF_Shape);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_TF_SHAPE_INTERNAL_H_
|
@ -251,7 +251,6 @@ cc_library_with_android_deps(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_experimental",
|
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -266,7 +265,6 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_experimental",
|
|
||||||
"//tensorflow/core:tensorflow",
|
"//tensorflow/core:tensorflow",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
@ -15,13 +15,12 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||||
|
#include "tensorflow/cc/framework/gradients.h"
|
||||||
#include "tensorflow/cc/ops/array_ops_internal.h"
|
#include "tensorflow/cc/ops/array_ops_internal.h"
|
||||||
#include "tensorflow/cc/ops/standard_ops.h"
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
|
||||||
#include "tensorflow/cc/framework/grad_op_registry.h"
|
|
||||||
#include "tensorflow/cc/framework/gradients.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace {
|
namespace {
|
||||||
@ -90,15 +89,25 @@ Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
|
|||||||
}
|
}
|
||||||
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);
|
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);
|
||||||
|
|
||||||
Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op,
|
Status QuantizeAndDequantizeV4GradHelper(const Scope& scope,
|
||||||
|
const Operation& op,
|
||||||
const std::vector<Output>& grad_inputs,
|
const std::vector<Output>& grad_inputs,
|
||||||
std::vector<Output>* grad_outputs) {
|
std::vector<Output>* grad_outputs) {
|
||||||
grad_outputs->push_back(Identity(scope, grad_inputs[0]));
|
Input input = Shape(scope, op.input(0));
|
||||||
grad_outputs->push_back(NoGradient());
|
Input input_min = op.input(1);
|
||||||
grad_outputs->push_back(NoGradient());
|
Input input_max = op.input(2);
|
||||||
|
int64 axis;
|
||||||
|
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
|
||||||
|
auto qdq_v4_grad = QuantizeAndDequantizeV4Grad(
|
||||||
|
scope, grad_inputs[0], input, input_min, input_max,
|
||||||
|
QuantizeAndDequantizeV4Grad::Axis(axis));
|
||||||
|
grad_outputs->push_back(qdq_v4_grad.input_backprop);
|
||||||
|
grad_outputs->push_back(qdq_v4_grad.input_min_backprop);
|
||||||
|
grad_outputs->push_back(qdq_v4_grad.input_max_backprop);
|
||||||
return scope.status();
|
return scope.status();
|
||||||
}
|
}
|
||||||
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad);
|
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV4",
|
||||||
|
QuantizeAndDequantizeV4GradHelper);
|
||||||
|
|
||||||
Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
|
Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
|
||||||
const std::vector<Output>& grad_inputs,
|
const std::vector<Output>& grad_inputs,
|
||||||
|
@ -21,10 +21,7 @@ package(
|
|||||||
licenses = ["notice"], # Apache 2.0
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
exports_files([
|
exports_files(["loader.h"])
|
||||||
"LICENSE",
|
|
||||||
"loader.h",
|
|
||||||
])
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "constants",
|
name = "constants",
|
||||||
@ -45,13 +42,15 @@ cc_library(
|
|||||||
name = "reader",
|
name = "reader",
|
||||||
srcs = ["reader.cc"],
|
srcs = ["reader.cc"],
|
||||||
hdrs = ["reader.h"],
|
hdrs = ["reader.h"],
|
||||||
deps = [":constants"] + if_not_mobile([
|
deps = [
|
||||||
|
":constants",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
] + if_not_mobile([
|
||||||
# TODO(b/111634734): :lib and :protos_all contain dependencies that
|
# TODO(b/111634734): :lib and :protos_all contain dependencies that
|
||||||
# cannot be built on mobile platforms. Instead, include the appropriate
|
# cannot be built on mobile platforms. Instead, include the appropriate
|
||||||
# tf_lib depending on the build platform.
|
# tf_lib depending on the build platform.
|
||||||
"@com_google_absl//absl/memory:memory",
|
"@com_google_absl//absl/memory:memory",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
|
||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -12,8 +12,6 @@ package(
|
|||||||
licenses = ["notice"], # Apache 2.0
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
exports_files(["LICENSE"])
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "freeze_saved_model",
|
name = "freeze_saved_model",
|
||||||
srcs = ["freeze_saved_model.cc"],
|
srcs = ["freeze_saved_model.cc"],
|
||||||
|
@ -75,7 +75,7 @@ cc_library(
|
|||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//llvm:Target",
|
"@llvm-project//llvm:Target",
|
||||||
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
|
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
|
||||||
"//tensorflow/core:regexp_internal",
|
"//tensorflow/core/platform:regexp",
|
||||||
] + if_llvm_system_z_available([
|
] + if_llvm_system_z_available([
|
||||||
"@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep
|
"@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep
|
||||||
]) + if_llvm_aarch64_available([
|
]) + if_llvm_aarch64_available([
|
||||||
|
@ -336,9 +336,9 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
"//tensorflow/compiler/xla/service:hlo_profile_printer",
|
"//tensorflow/compiler/xla/service:hlo_profile_printer",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:regexp_internal",
|
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/platform:regexp",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
@ -559,9 +559,9 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
"//tensorflow/compiler/xla/service:hlo_profile_printer",
|
"//tensorflow/compiler/xla/service:hlo_profile_printer",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:regexp_internal",
|
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/platform:regexp",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
|
@ -127,7 +127,7 @@ def tf_library(
|
|||||||
"$(location " + tfcompile_tool + ")" +
|
"$(location " + tfcompile_tool + ")" +
|
||||||
" --config=$(location " + config + ")" +
|
" --config=$(location " + config + ")" +
|
||||||
" --dump_fetch_nodes > $@"),
|
" --dump_fetch_nodes > $@"),
|
||||||
tools = [tfcompile_tool],
|
exec_tools = [tfcompile_tool],
|
||||||
# Run tfcompile on the build host, rather than forge, since it's
|
# Run tfcompile on the build host, rather than forge, since it's
|
||||||
# typically way faster on the local machine.
|
# typically way faster on the local machine.
|
||||||
local = 1,
|
local = 1,
|
||||||
@ -242,7 +242,7 @@ def tf_library(
|
|||||||
" --out_function_object=$(@D)/" + function_object_file +
|
" --out_function_object=$(@D)/" + function_object_file +
|
||||||
" " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag
|
" " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag
|
||||||
),
|
),
|
||||||
tools = [tfcompile_tool],
|
exec_tools = [tfcompile_tool],
|
||||||
visibility = visibility,
|
visibility = visibility,
|
||||||
testonly = testonly,
|
testonly = testonly,
|
||||||
# Run tfcompile on the build host since it's typically faster on the
|
# Run tfcompile on the build host since it's typically faster on the
|
||||||
@ -281,7 +281,7 @@ def tf_library(
|
|||||||
" --out_session_module=$(@D)/" + session_module_pb +
|
" --out_session_module=$(@D)/" + session_module_pb +
|
||||||
" " + flags
|
" " + flags
|
||||||
),
|
),
|
||||||
tools = [tfcompile_tool],
|
exec_tools = [tfcompile_tool],
|
||||||
visibility = visibility,
|
visibility = visibility,
|
||||||
testonly = testonly,
|
testonly = testonly,
|
||||||
local = 1,
|
local = 1,
|
||||||
|
@ -4,7 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
|||||||
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_test")
|
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_test")
|
||||||
|
|
||||||
# buildifier: disable=same-origin-load
|
# buildifier: disable=same-origin-load
|
||||||
load("//tensorflow:tensorflow.bzl", "if_tpu", "tf_copts")
|
load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts")
|
||||||
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
|
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
|
||||||
|
|
||||||
# buildifier: disable=same-origin-load
|
# buildifier: disable=same-origin-load
|
||||||
@ -77,7 +77,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
|
||||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||||
] + if_tpu(
|
] + if_libtpu(
|
||||||
if_false = ["//tensorflow/compiler/xla/service:cpu_plugin"],
|
if_false = ["//tensorflow/compiler/xla/service:cpu_plugin"],
|
||||||
if_true = [],
|
if_true = [],
|
||||||
),
|
),
|
||||||
@ -114,7 +114,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
] + if_tpu(
|
] + if_libtpu(
|
||||||
if_false = [
|
if_false = [
|
||||||
"//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep
|
"//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep
|
||||||
],
|
],
|
||||||
@ -141,7 +141,7 @@ cc_library(
|
|||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core/common_runtime/gpu:gpu_init",
|
"//tensorflow/core/common_runtime/gpu:gpu_init",
|
||||||
] + if_tpu(
|
] + if_libtpu(
|
||||||
if_false = [
|
if_false = [
|
||||||
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
|
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
|
||||||
],
|
],
|
||||||
@ -204,7 +204,7 @@ XLA_DEVICE_DEPS = [
|
|||||||
"//tensorflow/core:resource_variable_ops_op_lib",
|
"//tensorflow/core:resource_variable_ops_op_lib",
|
||||||
"//tensorflow/core:sendrecv_ops_op_lib",
|
"//tensorflow/core:sendrecv_ops_op_lib",
|
||||||
"//tensorflow/core:state_ops_op_lib",
|
"//tensorflow/core:state_ops_op_lib",
|
||||||
"//tensorflow/core:stream_executor_no_cuda",
|
"//tensorflow/core/platform:stream_executor_no_cuda",
|
||||||
"//tensorflow/core/kernels:constant_op",
|
"//tensorflow/core/kernels:constant_op",
|
||||||
"//tensorflow/core/kernels:fifo_queue",
|
"//tensorflow/core/kernels:fifo_queue",
|
||||||
"//tensorflow/core/kernels:function_ops",
|
"//tensorflow/core/kernels:function_ops",
|
||||||
@ -375,7 +375,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/platform:logging",
|
"//tensorflow/core/platform:logging",
|
||||||
] + if_tpu(
|
] + if_libtpu(
|
||||||
if_false = [
|
if_false = [
|
||||||
"//tensorflow/compiler/mlir:array_container_utils",
|
"//tensorflow/compiler/mlir:array_container_utils",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
||||||
@ -435,6 +435,7 @@ cc_library(
|
|||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core/common_runtime:core_cpu_internal",
|
"//tensorflow/core/common_runtime:core_cpu_internal",
|
||||||
|
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
@ -1022,10 +1023,10 @@ tf_cc_test(
|
|||||||
"//tensorflow/cc:ops",
|
"//tensorflow/cc:ops",
|
||||||
"//tensorflow/core:all_kernels",
|
"//tensorflow/core:all_kernels",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:direct_session_internal",
|
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:ops",
|
"//tensorflow/core:ops",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core/common_runtime:direct_session_internal",
|
||||||
"//tensorflow/core/kernels:cwise_op",
|
"//tensorflow/core/kernels:cwise_op",
|
||||||
"//tensorflow/core/kernels:matmul_op",
|
"//tensorflow/core/kernels:matmul_op",
|
||||||
"//tensorflow/core/kernels:partitioned_function_ops",
|
"//tensorflow/core/kernels:partitioned_function_ops",
|
||||||
|
@ -84,6 +84,23 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
xla::StatusOr<std::vector<NodeDef>> MakeCallNodesFromAttribute(
|
||||||
|
const Node& node, absl::string_view attr_name,
|
||||||
|
absl::string_view call_name) {
|
||||||
|
std::vector<NameAttrList> attr_lists;
|
||||||
|
TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &attr_lists));
|
||||||
|
|
||||||
|
std::vector<NodeDef> out;
|
||||||
|
for (int i = 0; i < attr_lists.size(); i++) {
|
||||||
|
out.emplace_back();
|
||||||
|
NodeDef& inserted = out.back();
|
||||||
|
inserted.set_name(absl::StrCat(call_name, "_", i));
|
||||||
|
inserted.set_op(attr_lists[i].name());
|
||||||
|
*inserted.mutable_attr() = attr_lists[i].attr();
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
// Utility which searches for values in a sorted list by scanning over it once.
|
// Utility which searches for values in a sorted list by scanning over it once.
|
||||||
// No matter how many times ScanForValue is called, the list is scanned at most
|
// No matter how many times ScanForValue is called, the list is scanned at most
|
||||||
// once. However, if a call to ScanForValue skips over a value, that value is
|
// once. However, if a call to ScanForValue skips over a value, that value is
|
||||||
@ -227,6 +244,30 @@ bool RecursiveCompilabilityChecker::IsCompilableIf(
|
|||||||
return is_compilable;
|
return is_compilable;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool RecursiveCompilabilityChecker::IsCompilableCase(
|
||||||
|
const Node& case_node, FunctionLibraryRuntime* lib_runtime,
|
||||||
|
std::vector<StackFrameView>* stack_trace,
|
||||||
|
NameAttrList* encapsulating_function,
|
||||||
|
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
|
||||||
|
const {
|
||||||
|
xla::StatusOr<std::vector<NodeDef>> calls =
|
||||||
|
MakeCallNodesFromAttribute(case_node, "branches", "branch");
|
||||||
|
if (!calls.ok()) {
|
||||||
|
VLOG(2) << "Rejecting node " << case_node.name() << ": "
|
||||||
|
<< "missing attribute 'branches'";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_compilable = true;
|
||||||
|
|
||||||
|
for (const NodeDef& call : *calls) {
|
||||||
|
is_compilable &=
|
||||||
|
IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
|
||||||
|
uncompilable_nodes);
|
||||||
|
}
|
||||||
|
return is_compilable;
|
||||||
|
}
|
||||||
|
|
||||||
// Tests whether 'while_node' is a completely compilable loop.
|
// Tests whether 'while_node' is a completely compilable loop.
|
||||||
// Every operator in the condition and body functions must be compilable for a
|
// Every operator in the condition and body functions must be compilable for a
|
||||||
// while loop to be compilable.
|
// while loop to be compilable.
|
||||||
@ -417,6 +458,13 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (op_filter_.require_always_compilable && node.IsCaseNode() &&
|
||||||
|
!IsCompilableCase(node, lib_runtime, stack_trace, encapsulating_function,
|
||||||
|
uncompilable_nodes)) {
|
||||||
|
LogNotCompilable(node, "unsupported case");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
if (!op_filter_.allow_stateful_rng_ops &&
|
if (!op_filter_.allow_stateful_rng_ops &&
|
||||||
IsStatefulRandomOp(node.type_string())) {
|
IsStatefulRandomOp(node.type_string())) {
|
||||||
absl::string_view uncompilable_reason = "stateful random op";
|
absl::string_view uncompilable_reason = "stateful random op";
|
||||||
|
@ -124,6 +124,10 @@ class RecursiveCompilabilityChecker {
|
|||||||
// Whether ops known to have numerical accuracy issues should be considered
|
// Whether ops known to have numerical accuracy issues should be considered
|
||||||
// compilable..
|
// compilable..
|
||||||
bool allow_inaccurate_ops = false;
|
bool allow_inaccurate_ops = false;
|
||||||
|
|
||||||
|
// Require the function to be always compilable, regardless whether some
|
||||||
|
// control flow branches might be dead for a given input.
|
||||||
|
bool require_always_compilable = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
RecursiveCompilabilityChecker(OperationFilter op_filter,
|
RecursiveCompilabilityChecker(OperationFilter op_filter,
|
||||||
@ -211,6 +215,14 @@ class RecursiveCompilabilityChecker {
|
|||||||
NameAttrList* encapsulating_function,
|
NameAttrList* encapsulating_function,
|
||||||
UncompilableNodesMap* uncompilable_nodes) const;
|
UncompilableNodesMap* uncompilable_nodes) const;
|
||||||
|
|
||||||
|
// Tests whether 'case_node' is compilable. Every operator in all branches
|
||||||
|
// must be compilable.
|
||||||
|
bool IsCompilableCase(const Node& case_node,
|
||||||
|
FunctionLibraryRuntime* lib_runtime,
|
||||||
|
std::vector<StackFrameView>* stack_trace,
|
||||||
|
NameAttrList* encapsulating_function,
|
||||||
|
UncompilableNodesMap* uncompilable_nodes) const;
|
||||||
|
|
||||||
// Returns compilability of node def retrieved from `node`'s attribute with
|
// Returns compilability of node def retrieved from `node`'s attribute with
|
||||||
// name `attr_name`.
|
// name `attr_name`.
|
||||||
bool ExtractNodeDefAndCheckCompilability(
|
bool ExtractNodeDefAndCheckCompilability(
|
||||||
|
@ -34,7 +34,16 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
AttrValue FuncListAttr(const absl::Span<const char* const> names) {
|
||||||
|
AttrValue attr;
|
||||||
|
for (const char* name : names) {
|
||||||
|
attr.mutable_list()->add_func()->set_name(name);
|
||||||
|
}
|
||||||
|
return attr;
|
||||||
|
}
|
||||||
|
|
||||||
constexpr char kFunctionalIfNodeName[] = "If";
|
constexpr char kFunctionalIfNodeName[] = "If";
|
||||||
|
constexpr char kFunctionalCaseNodeName[] = "Case";
|
||||||
constexpr char kFunctionalWhileNodeName[] = "While";
|
constexpr char kFunctionalWhileNodeName[] = "While";
|
||||||
constexpr char kCompilableFunctionName[] = "CompilableFn";
|
constexpr char kCompilableFunctionName[] = "CompilableFn";
|
||||||
constexpr char kCompilableFunctionNodeName[] = "n_c";
|
constexpr char kCompilableFunctionNodeName[] = "n_c";
|
||||||
@ -76,7 +85,11 @@ class CompilabilityCheckUtilTest : public ::testing::Test {
|
|||||||
op_filter_.allow_inaccurate_ops = false;
|
op_filter_.allow_inaccurate_ops = false;
|
||||||
op_filter_.allow_slow_ops = false;
|
op_filter_.allow_slow_ops = false;
|
||||||
|
|
||||||
checker_ = absl::make_unique<RecursiveCompilabilityChecker>(op_filter_,
|
checker_ = CreateCompilabilityChecker();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<RecursiveCompilabilityChecker> CreateCompilabilityChecker() {
|
||||||
|
return absl::make_unique<RecursiveCompilabilityChecker>(op_filter_,
|
||||||
device_type_);
|
device_type_);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -355,6 +368,57 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
|
|||||||
"unsupported op"));
|
"unsupported op"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(CompilabilityCheckUtilTest, CheckFunctionalCaseNode) {
|
||||||
|
FunctionDefLibrary flib;
|
||||||
|
*flib.add_function() = FunctionDefHelper::Define(
|
||||||
|
/*Function*/ kUncompilableFunctionName,
|
||||||
|
/*Inputs*/ {"n_a:float"},
|
||||||
|
/*Outputs*/ {"n_c_uncompilable:float"},
|
||||||
|
/*Attributes*/ {},
|
||||||
|
// Node info
|
||||||
|
{{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});
|
||||||
|
*flib.add_function() = FunctionDefHelper::Define(
|
||||||
|
/*Function*/ kUncompilableFunctionTwoName,
|
||||||
|
/*Inputs*/ {"n_a:float"},
|
||||||
|
/*Outputs*/ {"n_d_uncompilable:float"},
|
||||||
|
/*Attribute*/ {},
|
||||||
|
// Node info
|
||||||
|
{{{kUncompilableFunctionNodeTwoName}, "MissingKernel", {"n_a"}}});
|
||||||
|
|
||||||
|
Scope root = Scope::NewRootScope().ExitOnError();
|
||||||
|
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib));
|
||||||
|
auto branch_index = ops::Placeholder(root.WithOpName("pred"), DT_INT32);
|
||||||
|
auto placeholder = ops::Placeholder(root.WithOpName("A"), DT_INT32);
|
||||||
|
std::vector<NodeBuilder::NodeOut> inputes(
|
||||||
|
{NodeBuilder::NodeOut(placeholder.node())});
|
||||||
|
Node* case_node;
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
NodeBuilder(kFunctionalCaseNodeName, "Case", &root.graph()->flib_def())
|
||||||
|
.Input(branch_index.node())
|
||||||
|
.Input(inputes)
|
||||||
|
.Attr("branches", FuncListAttr({kUncompilableFunctionName,
|
||||||
|
kUncompilableFunctionTwoName}))
|
||||||
|
.Attr("Tout", {DT_INT32})
|
||||||
|
.Finalize(root.graph(), &case_node));
|
||||||
|
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||||
|
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||||
|
|
||||||
|
flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
|
||||||
|
|
||||||
|
auto case_node_it = std::find_if(
|
||||||
|
graph->nodes().begin(), graph->nodes().end(),
|
||||||
|
[&](const Node* n) { return n->name() == kFunctionalCaseNodeName; });
|
||||||
|
EXPECT_NE(case_node_it, graph->nodes().end());
|
||||||
|
auto* flib_runtime = GetFunctionLibraryRuntime();
|
||||||
|
|
||||||
|
op_filter_.require_always_compilable = false;
|
||||||
|
checker_ = CreateCompilabilityChecker();
|
||||||
|
EXPECT_TRUE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
|
||||||
|
op_filter_.require_always_compilable = true;
|
||||||
|
checker_ = CreateCompilabilityChecker();
|
||||||
|
EXPECT_FALSE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) {
|
TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) {
|
||||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||||
Scope root = Scope::NewRootScope().ExitOnError();
|
Scope root = Scope::NewRootScope().ExitOnError();
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/jit/xla_platform_info.h"
|
#include "tensorflow/compiler/jit/xla_platform_info.h"
|
||||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
@ -47,8 +48,8 @@ static xla::StatusOr<xla::LocalExecutable*> GetLocalExecutable(
|
|||||||
|
|
||||||
xla::StatusOr<std::string> GetCompilerIr(
|
xla::StatusOr<std::string> GetCompilerIr(
|
||||||
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
|
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
|
||||||
absl::string_view func_name, Device* dev,
|
absl::string_view func_name, Device* dev, EagerContext* context,
|
||||||
absl::Span<const Tensor* const> inputs) {
|
absl::Span<const TensorHandle* const> inputs_handles) {
|
||||||
NameAttrList function;
|
NameAttrList function;
|
||||||
function.set_name(std::string{func_name});
|
function.set_name(std::string{func_name});
|
||||||
|
|
||||||
@ -65,6 +66,25 @@ xla::StatusOr<std::string> GetCompilerIr(
|
|||||||
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
|
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
|
||||||
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
|
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
|
||||||
|
|
||||||
|
std::deque<Tensor> inputs_storage;
|
||||||
|
std::vector<const Tensor*> inputs;
|
||||||
|
inputs.reserve(inputs_handles.size());
|
||||||
|
for (int i = 0; i < inputs_handles.size(); i++) {
|
||||||
|
const TensorHandle* th = inputs_handles[i];
|
||||||
|
const Tensor* t;
|
||||||
|
// Handle owns the tensor.
|
||||||
|
TF_RETURN_IF_ERROR(th->Tensor(&t));
|
||||||
|
if (absl::c_binary_search(constant_arg_indices, i)) {
|
||||||
|
// Need to make sure it's on the host.
|
||||||
|
inputs_storage.emplace_back(t->dtype(), t->shape());
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back()));
|
||||||
|
inputs.push_back(&inputs_storage.back());
|
||||||
|
} else {
|
||||||
|
inputs.push_back(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<VariableInfo> variable_infos;
|
std::vector<VariableInfo> variable_infos;
|
||||||
TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
|
TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
|
||||||
rmgr, dev, inputs, resource_arg_indices, &variable_infos));
|
rmgr, dev, inputs, resource_arg_indices, &variable_infos));
|
||||||
|
@ -24,6 +24,8 @@ namespace tensorflow {
|
|||||||
class ProcessFunctionLibraryRuntime;
|
class ProcessFunctionLibraryRuntime;
|
||||||
class Device;
|
class Device;
|
||||||
class Tensor;
|
class Tensor;
|
||||||
|
class TensorHandle;
|
||||||
|
class EagerContext;
|
||||||
|
|
||||||
enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
|
enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
|
||||||
|
|
||||||
@ -31,8 +33,8 @@ enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
|
|||||||
// `runtime` on a device `dev` with given `inputs`.
|
// `runtime` on a device `dev` with given `inputs`.
|
||||||
xla::StatusOr<std::string> GetCompilerIr(
|
xla::StatusOr<std::string> GetCompilerIr(
|
||||||
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
|
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
|
||||||
absl::string_view func_name, Device* dev,
|
absl::string_view func_name, Device* dev, EagerContext* context,
|
||||||
absl::Span<const Tensor* const> inputs);
|
absl::Span<const TensorHandle* const> inputs);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ XLA_OPS_DEPS = [
|
|||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:state_ops_op_lib",
|
"//tensorflow/core:state_ops_op_lib",
|
||||||
"//tensorflow/core:stream_executor_no_cuda",
|
"//tensorflow/core/platform:stream_executor_no_cuda",
|
||||||
"//tensorflow/core/profiler/lib:traceme",
|
"//tensorflow/core/profiler/lib:traceme",
|
||||||
"//tensorflow/stream_executor:tf_allocator_adapter",
|
"//tensorflow/stream_executor:tf_allocator_adapter",
|
||||||
]
|
]
|
||||||
|
@ -1196,10 +1196,14 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!RecursiveCompilabilityChecker{
|
RecursiveCompilabilityChecker::OperationFilter filter =
|
||||||
CreateOperationFilter(*registration),
|
CreateOperationFilter(*registration);
|
||||||
DeviceType{registration->compilation_device_name}}
|
filter.require_always_compilable = true;
|
||||||
.IsCompilableNode(*node, lib_runtime)) {
|
|
||||||
|
RecursiveCompilabilityChecker checker(
|
||||||
|
filter, DeviceType{registration->compilation_device_name});
|
||||||
|
|
||||||
|
if (!checker.IsCompilableNode(*node, lib_runtime)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2062,6 +2066,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
|||||||
"XlaSpmdFullToShardShape",
|
"XlaSpmdFullToShardShape",
|
||||||
"XlaSpmdShardToFullShape",
|
"XlaSpmdShardToFullShape",
|
||||||
"XlaSvd",
|
"XlaSvd",
|
||||||
|
"XlaVariadicReduce",
|
||||||
"XlaWhile",
|
"XlaWhile",
|
||||||
"Zeta",
|
"Zeta",
|
||||||
"_Arg",
|
"_Arg",
|
||||||
|
@ -47,7 +47,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/public/version.h"
|
#include "tensorflow/core/public/version.h"
|
||||||
#include "tensorflow/core/util/dump_graph.h"
|
#include "tensorflow/core/util/dump_graph.h"
|
||||||
|
|
||||||
#if !defined(LIBTFTPU)
|
#if !defined(LIBTPU_ON_GCE)
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
||||||
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
|
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
|
||||||
#endif
|
#endif
|
||||||
@ -289,7 +289,7 @@ Status XlaCompilationCache::CompileSingleOp(
|
|||||||
});
|
});
|
||||||
const ConfigProto* config = ctx->function_library()->config_proto();
|
const ConfigProto* config = ctx->function_library()->config_proto();
|
||||||
bool use_mlir = config && config->experimental().enable_mlir_bridge();
|
bool use_mlir = config && config->experimental().enable_mlir_bridge();
|
||||||
#ifdef LIBTFTPU
|
#ifdef LIBTPU_ON_GCE
|
||||||
if (use_mlir && has_tensor_list_arg) {
|
if (use_mlir && has_tensor_list_arg) {
|
||||||
LOG(WARNING) << "MLIR is not supported in this environment.";
|
LOG(WARNING) << "MLIR is not supported in this environment.";
|
||||||
}
|
}
|
||||||
@ -303,8 +303,12 @@ Status XlaCompilationCache::CompileSingleOp(
|
|||||||
}
|
}
|
||||||
|
|
||||||
GraphDebugInfo debug_info;
|
GraphDebugInfo debug_info;
|
||||||
|
std::vector<std::string> control_rets;
|
||||||
|
if (result_dtypes.empty()) {
|
||||||
|
control_rets.push_back(node_def.name());
|
||||||
|
}
|
||||||
return CompileGraphToXlaHlo(
|
return CompileGraphToXlaHlo(
|
||||||
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
|
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), control_rets,
|
||||||
options.device_type.type_string(), compile_options.use_tuple_arg,
|
options.device_type.type_string(), compile_options.use_tuple_arg,
|
||||||
*options.flib_def, debug_info, options.shape_representation_fn, result);
|
*options.flib_def, debug_info, options.shape_representation_fn, result);
|
||||||
#endif
|
#endif
|
||||||
|
@ -9,3 +9,31 @@ dialects and utilities for
|
|||||||
3. TF Lite
|
3. TF Lite
|
||||||
|
|
||||||
See [MLIR's website](https://mlir.llvm.org) for complete documentation.
|
See [MLIR's website](https://mlir.llvm.org) for complete documentation.
|
||||||
|
|
||||||
|
## Getting started
|
||||||
|
|
||||||
|
Building dialects and utilities here follows the standard approach using
|
||||||
|
`bazel` as the rest of TensorFlow.
|
||||||
|
|
||||||
|
### Using local LLVM repo
|
||||||
|
|
||||||
|
To develop across MLIR core and TensorFlow, it is useful to override the repo
|
||||||
|
to use a local version instead of fetching from head. This can be achieved as
|
||||||
|
below, but note that the BUILD files are not automatically generated from CMake
|
||||||
|
(nor is CMake used), so if your change requires a BUILD file change (or you are using a
|
||||||
|
different version of LLVM than set in tensorflow/workspace.bzl's LLVM_COMMIT)
|
||||||
|
then manual BUILD file changes may be required.
|
||||||
|
|
||||||
|
```sh
|
||||||
|
LLVM_SRC=...
|
||||||
|
|
||||||
|
# Create basic workspace file
|
||||||
|
echo 'workspace(name = "llvm-project")' > $LLVM_SRC/WORKSPACE
|
||||||
|
# and copy over the bazel BUILD files.
|
||||||
|
cp third_party/llvm/llvm.autogenerated.BUILD $LLVM_SRC/llvm/BUILD
|
||||||
|
cp third_party/mlir/BUILD $LLVM_SRC/mlir
|
||||||
|
cp third_party/mlir/test.BUILD $LLVM_SRC/mlir/test/BUILD
|
||||||
|
|
||||||
|
bazel build --override_repository=llvm-project=$LLVM_SRC \
|
||||||
|
-c opt tensorflow/compiler/mlir:tf-opt
|
||||||
|
```
|
||||||
|
@ -48,6 +48,7 @@ filegroup(
|
|||||||
"include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td",
|
"include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td",
|
||||||
"include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td",
|
"include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td",
|
||||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td",
|
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td",
|
||||||
|
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td",
|
||||||
"@llvm-project//mlir:OpBaseTdFiles",
|
"@llvm-project//mlir:OpBaseTdFiles",
|
||||||
"@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td",
|
"@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td",
|
||||||
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
||||||
@ -539,6 +540,8 @@ cc_library(
|
|||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
|
"@llvm-project//mlir:Shape",
|
||||||
|
"@llvm-project//mlir:ShapeTransforms",
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
"@llvm-project//mlir:Support",
|
"@llvm-project//mlir:Support",
|
||||||
"@llvm-project//mlir:Transforms",
|
"@llvm-project//mlir:Transforms",
|
||||||
|
@ -360,6 +360,19 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
|
||||||
|
HLO_FpOrComplexTensor> {
|
||||||
|
let summary = "Atan operator";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Returns `Atan(operand)` element-wise.
|
||||||
|
|
||||||
|
$$
|
||||||
|
\atan(x) = \atan2(x, 1)
|
||||||
|
$$
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
|
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
|
||||||
HLO_FpOrComplexTensor> {
|
HLO_FpOrComplexTensor> {
|
||||||
let summary = "Sinh operation";
|
let summary = "Sinh operation";
|
||||||
|
@ -157,6 +157,9 @@ def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
|
|||||||
>];
|
>];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt",
|
||||||
|
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CbrtOp;
|
||||||
|
|
||||||
def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil",
|
def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil",
|
||||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CeilOp;
|
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CeilOp;
|
||||||
|
|
||||||
@ -193,12 +196,10 @@ def HLO_Expm1Op: HLO_UnaryElementwiseOp<"exponential_minus_one",
|
|||||||
def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor",
|
def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor",
|
||||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp;
|
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp;
|
||||||
|
|
||||||
def HLO_ImagOp: HLO_Op<
|
def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag",
|
||||||
"imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp {
|
[NoSideEffect, SameOperandsAndResultShape,
|
||||||
let builders = [OpBuilder<
|
DeclareOpInterfaceMethods<InferTypeOpInterface>],
|
||||||
"OpBuilder &, OperationState &tblgen_state, Value val">];
|
HLO_ComplexTensor>, BASE_HLO_ImagOp {
|
||||||
|
|
||||||
let arguments = (ins HLO_ComplexTensor);
|
|
||||||
let results = (outs HLO_FpTensor);
|
let results = (outs HLO_FpTensor);
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
@ -237,12 +238,10 @@ def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt",
|
|||||||
[NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>,
|
[NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>,
|
||||||
BASE_HLO_PopulationCountOp;
|
BASE_HLO_PopulationCountOp;
|
||||||
|
|
||||||
def HLO_RealOp: HLO_Op<
|
def HLO_RealOp: HLO_UnaryElementwiseOp<"real",
|
||||||
"real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp {
|
[NoSideEffect, SameOperandsAndResultShape,
|
||||||
let builders = [OpBuilder<
|
DeclareOpInterfaceMethods<InferTypeOpInterface>],
|
||||||
"OpBuilder &, OperationState &tblgen_state, Value val">];
|
HLO_ComplexTensor>, BASE_HLO_RealOp {
|
||||||
|
|
||||||
let arguments = (ins HLO_ComplexTensor);
|
|
||||||
let results = (outs HLO_FpTensor);
|
let results = (outs HLO_FpTensor);
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
@ -321,12 +320,10 @@ def HLO_AddOp : HLO_BinaryElementwiseOp<"add",
|
|||||||
def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2",
|
def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2",
|
||||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op;
|
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op;
|
||||||
|
|
||||||
def HLO_ComplexOp: HLO_Op<"complex",
|
def HLO_ComplexOp: HLO_BinaryElementwiseOp<"complex",
|
||||||
[NoSideEffect, SameOperandsAndResultShape]>,
|
[NoSideEffect, SameOperandsAndResultShape,
|
||||||
|
DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
|
||||||
BASE_HLO_ComplexOp {
|
BASE_HLO_ComplexOp {
|
||||||
let builders = [OpBuilder<
|
|
||||||
"OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">];
|
|
||||||
|
|
||||||
let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs);
|
let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs);
|
||||||
let results = (outs HLO_ComplexTensor);
|
let results = (outs HLO_ComplexTensor);
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
@ -356,7 +353,9 @@ def HLO_PowOp : HLO_BinaryElementwiseOp<"power",
|
|||||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp;
|
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp;
|
||||||
|
|
||||||
def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
|
def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
|
||||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp;
|
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp {
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
|
def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
|
||||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp;
|
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp;
|
||||||
@ -913,39 +912,12 @@ def HLO_CollectivePermuteOp: HLO_Op<"collective_permute",
|
|||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(hinsu): Make this struct dialect independent so that it can be shared
|
|
||||||
// between HLO and LHLO dialect.
|
|
||||||
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [
|
|
||||||
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
|
||||||
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
|
||||||
|
|
||||||
let description = "Structure of dimension information for conv op";
|
|
||||||
}
|
|
||||||
|
|
||||||
def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
|
def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
|
||||||
let arguments = (ins
|
let arguments = !con(
|
||||||
|
(ins
|
||||||
HLO_Tensor:$lhs,
|
HLO_Tensor:$lhs,
|
||||||
HLO_Tensor:$rhs,
|
HLO_Tensor:$rhs),
|
||||||
// Default value: one for each of the spatial dimension.
|
ConvolutionAttributes<HLO_Dialect>.attributes);
|
||||||
OptionalAttr<I64ElementsAttr>:$window_strides,
|
|
||||||
// Default value: zero for each of the spatial dimension.
|
|
||||||
OptionalAttr<I64ElementsAttr>:$padding,
|
|
||||||
// Default value: one for each of the spatial dimension.
|
|
||||||
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
|
|
||||||
// Default value: one for each of the spatial dimension.
|
|
||||||
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
|
||||||
ConvDimensionNumbers:$dimension_numbers,
|
|
||||||
I64Attr:$feature_group_count,
|
|
||||||
I64Attr:$batch_group_count,
|
|
||||||
HLO_PrecisionConfigAttr:$precision_config
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
}
|
}
|
||||||
@ -1198,14 +1170,14 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>,
|
|||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects]>, BASE_HLO_SortOp {
|
def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
Variadic<HLO_Tensor>:$operands,
|
Variadic<HLO_Tensor>:$operands,
|
||||||
DefaultValuedAttr<I64Attr, "-1">:$dimension,
|
DefaultValuedAttr<I64Attr, "-1">:$dimension,
|
||||||
DefaultValuedAttr<BoolAttr, "false">:$is_stable
|
DefaultValuedAttr<BoolAttr, "false">:$is_stable
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs HLO_TensorOrTuple);
|
let results = (outs Variadic<HLO_Tensor>);
|
||||||
|
|
||||||
let regions = (region SizedRegion<1>:$comparator);
|
let regions = (region SizedRegion<1>:$comparator);
|
||||||
|
|
||||||
@ -1429,4 +1401,21 @@ def HLO_FusionOp : HLO_Op<"fusion", []> {
|
|||||||
let hasCustomHLOConverter = 1;
|
let hasCustomHLOConverter = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This is an op for purposes internal to XLA/GPU.
|
||||||
|
def HLO_BitcastOp : HLO_Op<"bitcast", [NoSideEffect]>, BASE_HLO_BitcastOp {
|
||||||
|
let arguments = (ins HLO_Tensor:$operand);
|
||||||
|
let results = (outs HLO_Tensor);
|
||||||
|
let hasCustomHLOConverter = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLO_ReducePrecisionOp: HLO_Op<"reduce_precision", [SameOperandsAndResultShape]>,
|
||||||
|
BASE_HLO_ReducePrecisionOp {
|
||||||
|
let arguments = (ins
|
||||||
|
HLO_FpTensor:$operand,
|
||||||
|
I32Attr:$exponent_bits,
|
||||||
|
I32Attr:$mantissa_bits
|
||||||
|
);
|
||||||
|
let results = (outs HLO_FpTensor:$output);
|
||||||
|
}
|
||||||
|
|
||||||
#endif // HLO_OPS
|
#endif // HLO_OPS
|
||||||
|
@ -127,6 +127,17 @@ class BASE_HLO_AbsOp {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class BASE_HLO_CbrtOp {
|
||||||
|
string summary = "Cubic root operator";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns element-wise cubic root of the operand.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
class BASE_HLO_CeilOp {
|
class BASE_HLO_CeilOp {
|
||||||
string summary = "Ceil operator";
|
string summary = "Ceil operator";
|
||||||
|
|
||||||
@ -996,6 +1007,42 @@ class BASE_HLO_ConcatenateOp {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Common convolution attributes
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class ConvDimensionNumbersBase<Dialect dialect>
|
||||||
|
: StructAttr<"ConvDimensionNumbers", dialect, [
|
||||||
|
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
||||||
|
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
||||||
|
|
||||||
|
let description = "Structure of dimension information for conv op";
|
||||||
|
}
|
||||||
|
|
||||||
|
class ConvolutionAttributes<Dialect dialect> {
|
||||||
|
dag attributes = (ins
|
||||||
|
// Default value: one for each of the spatial dimension.
|
||||||
|
OptionalAttr<I64ElementsAttr>:$window_strides,
|
||||||
|
// Default value: zero for each of the spatial dimension.
|
||||||
|
OptionalAttr<I64ElementsAttr>:$padding,
|
||||||
|
// Default value: one for each of the spatial dimension.
|
||||||
|
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
|
||||||
|
// Default value: one for each of the spatial dimension.
|
||||||
|
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
||||||
|
ConvDimensionNumbersBase<dialect>:$dimension_numbers,
|
||||||
|
I64Attr:$feature_group_count,
|
||||||
|
I64Attr:$batch_group_count,
|
||||||
|
HLO_PrecisionConfigAttr:$precision_config
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
class BASE_HLO_ConvOp {
|
class BASE_HLO_ConvOp {
|
||||||
string summary = "Convolution operator";
|
string summary = "Convolution operator";
|
||||||
|
|
||||||
@ -1336,4 +1383,17 @@ class BASE_HLO_WhileOp {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class BASE_HLO_BitcastOp {
|
||||||
|
string summary = "Bitcast operator";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
This op changes the shape of the input in the way that the physical
|
||||||
|
arrangement of elements is unchanged.
|
||||||
|
|
||||||
|
However, the op needs layout information to make sense of "physical
|
||||||
|
arrangement of elements". Layout support in MHLO is currently under
|
||||||
|
exploration.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
#endif // HLO_OPS_BASE
|
#endif // HLO_OPS_BASE
|
||||||
|
@ -37,38 +37,13 @@ include "mlir/IR/OpBase.td"
|
|||||||
include "mlir/Interfaces/CopyOpInterface.td"
|
include "mlir/Interfaces/CopyOpInterface.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
|
||||||
|
|
||||||
def LHLO_Dialect : Dialect {
|
def LHLO_Dialect : Dialect {
|
||||||
let name = "lmhlo";
|
let name = "lmhlo";
|
||||||
let cppNamespace = "::mlir::lmhlo";
|
let cppNamespace = "::mlir::lmhlo";
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// LMHLO type definitions.
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
// Any integer tensor types
|
|
||||||
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
|
|
||||||
|
|
||||||
// Any floating-point tensor types
|
|
||||||
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
|
|
||||||
|
|
||||||
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
|
|
||||||
|
|
||||||
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
|
|
||||||
|
|
||||||
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
|
|
||||||
|
|
||||||
// Any integer or floating-point tensor types
|
|
||||||
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
|
|
||||||
|
|
||||||
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
|
|
||||||
|
|
||||||
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
|
|
||||||
|
|
||||||
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// LMHLO nullary op definitions.
|
// LMHLO nullary op definitions.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -289,6 +264,16 @@ def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
|
|||||||
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def LHLO_CustomCallOp : LHLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$args,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
StrAttr:$call_target_name,
|
||||||
|
DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
|
||||||
|
DefaultValuedAttr<StrAttr, "">:$backend_config
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// LMHLO tuple op definitions.
|
// LMHLO tuple op definitions.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -335,10 +320,11 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
|
|||||||
def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
|
def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
|
||||||
[NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
|
||||||
let summary = [{
|
let summary = [{
|
||||||
"modifies the offset, sizes and strides of a statically shaped memref.
|
modifies the offset, sizes and strides of a statically shaped memref
|
||||||
}];
|
}];
|
||||||
let description = [{
|
let description = [{
|
||||||
Allows to modify the offset, sizes and strides of a statically shaped memref.
|
Casts the statically shaped memref operand to a memref with optionally
|
||||||
|
modified offsets, sizes and strides.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```mlir
|
```mlir
|
||||||
@ -354,11 +340,10 @@ def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
|
|||||||
let arguments = (ins Arg<LHLO_Buffer, "", []>:$operand);
|
let arguments = (ins Arg<LHLO_Buffer, "", []>:$operand);
|
||||||
let results = (outs Res<LHLO_Buffer, "", []>:$result);
|
let results = (outs Res<LHLO_Buffer, "", []>:$result);
|
||||||
|
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<"MemRefType resultType, Value operand",
|
||||||
"OpBuilder &builder, OperationState &result, MemRefType resultType, " #
|
[{
|
||||||
"Value operand", [{
|
$_state.addOperands(operand);
|
||||||
result.addOperands(operand);
|
$_state.types.push_back(resultType);
|
||||||
result.types.push_back(resultType);
|
|
||||||
}]>];
|
}]>];
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
@ -400,13 +385,13 @@ def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
|
|||||||
);
|
);
|
||||||
let results = (outs Res<LHLO_Buffer, "", []>:$result);
|
let results = (outs Res<LHLO_Buffer, "", []>:$result);
|
||||||
|
|
||||||
let builders = [OpBuilder<
|
let builders = [
|
||||||
"OpBuilder &builder, OperationState &result, MemRefType resultType, " #
|
OpBuilder<"MemRefType resultType, Value operand, ValueRange sizes, "
|
||||||
"Value operand, ValueRange sizes, ValueRange strides", [{
|
"ValueRange strides", [{
|
||||||
result.addOperands(operand);
|
$_state.addOperands(operand);
|
||||||
result.addOperands(sizes);
|
$_state.addOperands(sizes);
|
||||||
result.addOperands(strides);
|
$_state.addOperands(strides);
|
||||||
result.types.push_back(resultType);
|
$_state.types.push_back(resultType);
|
||||||
}]>];
|
}]>];
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
@ -582,40 +567,13 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(bondhugula): Make this struct dialect independent so that it can be
|
|
||||||
// shared between the HLO and LHLO dialects.
|
|
||||||
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
|
|
||||||
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
|
||||||
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
|
||||||
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
|
||||||
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
|
||||||
|
|
||||||
let description = "Structure of dimension information for conv op";
|
|
||||||
}
|
|
||||||
|
|
||||||
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
|
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
|
||||||
let arguments = (ins
|
let arguments = !con(
|
||||||
|
(ins
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output),
|
||||||
// Default value: one for each of the spatial dimension.
|
ConvolutionAttributes<LHLO_Dialect>.attributes);
|
||||||
OptionalAttr<I64ElementsAttr>:$window_strides,
|
|
||||||
// Default value: zero for each of the spatial dimension.
|
|
||||||
OptionalAttr<I64ElementsAttr>:$padding,
|
|
||||||
// Default value: one for each of the spatial dimension.
|
|
||||||
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
|
|
||||||
// Default value: one for each of the spatial dimension.
|
|
||||||
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
|
||||||
ConvDimensionNumbers:$dimension_numbers,
|
|
||||||
I64Attr:$feature_group_count,
|
|
||||||
I64Attr:$batch_group_count,
|
|
||||||
HLO_PrecisionConfigAttr:$precision_config
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {
|
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {
|
||||||
@ -856,8 +814,7 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
|
|||||||
|
|
||||||
let skipDefaultBuilders = 1;
|
let skipDefaultBuilders = 1;
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &result, "
|
OpBuilder<"ArrayRef<NamedAttribute> attributes">
|
||||||
"ArrayRef<NamedAttribute> attributes">
|
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -867,9 +824,8 @@ def TerminatorOp :
|
|||||||
let description = [{
|
let description = [{
|
||||||
Terminator operation for the LHLO dialect.
|
Terminator operation for the LHLO dialect.
|
||||||
}];
|
}];
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<"ValueRange operands",
|
||||||
"OpBuilder &b, OperationState &result, ValueRange operands",
|
[{ build($_builder, $_state, llvm::None, operands, llvm::None); }]
|
||||||
[{ build(b, result, llvm::None, operands, llvm::None); }]
|
|
||||||
>];
|
>];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,47 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef LHLO_OPS_BASE
|
||||||
|
#define LHLO_OPS_BASE
|
||||||
|
|
||||||
|
include "mlir/IR/OpBase.td"
|
||||||
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// LMHLO type definitions.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Any integer tensor types
|
||||||
|
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
|
||||||
|
|
||||||
|
// Any floating-point tensor types
|
||||||
|
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
|
||||||
|
|
||||||
|
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
|
||||||
|
|
||||||
|
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
|
||||||
|
|
||||||
|
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
|
||||||
|
|
||||||
|
// Any integer or floating-point tensor types
|
||||||
|
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
|
||||||
|
|
||||||
|
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
|
||||||
|
|
||||||
|
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
|
||||||
|
|
||||||
|
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
|
||||||
|
|
||||||
|
#endif // LHLO_OPS_BASE
|
@ -50,6 +50,7 @@ MAP_HLO_TO_LHLO(ConvOp);
|
|||||||
MAP_HLO_TO_LHLO(ConvertOp);
|
MAP_HLO_TO_LHLO(ConvertOp);
|
||||||
MAP_HLO_TO_LHLO(CopyOp);
|
MAP_HLO_TO_LHLO(CopyOp);
|
||||||
MAP_HLO_TO_LHLO(CosOp);
|
MAP_HLO_TO_LHLO(CosOp);
|
||||||
|
MAP_HLO_TO_LHLO(CustomCallOp);
|
||||||
MAP_HLO_TO_LHLO(DivOp);
|
MAP_HLO_TO_LHLO(DivOp);
|
||||||
MAP_HLO_TO_LHLO(DotOp);
|
MAP_HLO_TO_LHLO(DotOp);
|
||||||
MAP_HLO_TO_LHLO(ExpOp);
|
MAP_HLO_TO_LHLO(ExpOp);
|
||||||
@ -57,11 +58,13 @@ MAP_HLO_TO_LHLO(FloorOp);
|
|||||||
MAP_HLO_TO_LHLO(GatherOp);
|
MAP_HLO_TO_LHLO(GatherOp);
|
||||||
MAP_HLO_TO_LHLO(ImagOp);
|
MAP_HLO_TO_LHLO(ImagOp);
|
||||||
MAP_HLO_TO_LHLO(IotaOp);
|
MAP_HLO_TO_LHLO(IotaOp);
|
||||||
|
MAP_HLO_TO_LHLO(IsFiniteOp);
|
||||||
MAP_HLO_TO_LHLO(LogOp);
|
MAP_HLO_TO_LHLO(LogOp);
|
||||||
MAP_HLO_TO_LHLO(MaxOp);
|
MAP_HLO_TO_LHLO(MaxOp);
|
||||||
MAP_HLO_TO_LHLO(MinOp);
|
MAP_HLO_TO_LHLO(MinOp);
|
||||||
MAP_HLO_TO_LHLO(MulOp);
|
MAP_HLO_TO_LHLO(MulOp);
|
||||||
MAP_HLO_TO_LHLO(NegOp);
|
MAP_HLO_TO_LHLO(NegOp);
|
||||||
|
MAP_HLO_TO_LHLO(NotOp);
|
||||||
MAP_HLO_TO_LHLO(RealOp);
|
MAP_HLO_TO_LHLO(RealOp);
|
||||||
MAP_HLO_TO_LHLO(ReduceOp);
|
MAP_HLO_TO_LHLO(ReduceOp);
|
||||||
MAP_HLO_TO_LHLO(ReshapeOp);
|
MAP_HLO_TO_LHLO(ReshapeOp);
|
||||||
|
@ -149,6 +149,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
|
|||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc,
|
||||||
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::Atan2Op>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename PredicateType>
|
template <typename PredicateType>
|
||||||
inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
|
inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
@ -345,6 +354,22 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc,
|
|||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::IsFiniteOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
if (args[0].getType().isa<FloatType>()) {
|
||||||
|
auto pos_inf = APFloat::getInf(
|
||||||
|
args[0].getType().cast<FloatType>().getFloatSemantics());
|
||||||
|
auto const_pos_inf =
|
||||||
|
b->create<ConstantOp>(loc, b->getFloatAttr(args[0].getType(), pos_inf));
|
||||||
|
Value abs_x = b->create<::mlir::AbsFOp>(loc, args[0]);
|
||||||
|
return b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, abs_x,
|
||||||
|
const_pos_inf);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
/// Implements the conversion of HLO op to scalar op (to use within region of a
|
/// Implements the conversion of HLO op to scalar op (to use within region of a
|
||||||
/// linalg.generic op) for compare-select style operations like min/max.
|
/// linalg.generic op) for compare-select style operations like min/max.
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
@ -431,6 +456,21 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
|
||||||
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
Type element_type = args.front().getType();
|
||||||
|
if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
|
||||||
|
// lmhlo.not(x) -> x ^ -1
|
||||||
|
auto all_ones =
|
||||||
|
b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth());
|
||||||
|
return b->create<::mlir::XOrOp>(loc, all_ones, args[0]);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
@ -454,11 +494,27 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
|
|||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
Type element_type = args.front().getType();
|
Type element_type = args.front().getType();
|
||||||
if (element_type.isa<FloatType>()) {
|
if (auto float_type = element_type.dyn_cast<FloatType>()) {
|
||||||
FloatType float_type = element_type.cast<FloatType>();
|
bool ignored;
|
||||||
APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0);
|
APFloat one_apfloat(1.0f);
|
||||||
Value one = b->create<mlir::ConstantFloatOp>(loc, const_value, float_type);
|
one_apfloat.convert(float_type.getFloatSemantics(),
|
||||||
|
APFloat::rmNearestTiesToEven, &ignored);
|
||||||
|
Value one = b->create<mlir::ConstantFloatOp>(loc, one_apfloat, float_type);
|
||||||
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
|
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
|
||||||
|
} else if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
|
||||||
|
// sign(x) = x == 0 ? 0 : ((x s>> 31) | 1)
|
||||||
|
Value zero =
|
||||||
|
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
|
||||||
|
Value cmp =
|
||||||
|
b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero);
|
||||||
|
Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>(
|
||||||
|
loc, integer_type.getWidth() - 1, integer_type.getWidth());
|
||||||
|
Value ashr =
|
||||||
|
b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
|
||||||
|
Value one =
|
||||||
|
b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth());
|
||||||
|
Value or_op = b->create<::mlir::OrOp>(loc, ashr, one);
|
||||||
|
return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op);
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -15,9 +15,9 @@ limitations under the License.
|
|||||||
|
|
||||||
include "mlir/Pass/PassBase.td"
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
def TestChloLegalizeToHloPass : Pass<"mhlo-test-chlo-legalize-to-hlo", "FuncOp"> {
|
def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "FuncOp"> {
|
||||||
let summary = "Test pass for applying chlo -> hlo legalization patterns.";
|
let summary = "Legalize CHLO to HLO.";
|
||||||
let constructor = "createTestChloLegalizeToHloPass()";
|
let constructor = "createChloLegalizeToHloPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
|
def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
|
||||||
|
@ -44,6 +44,9 @@ std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass();
|
|||||||
/// Lowers from HLO dialect to Standard dialect.
|
/// Lowers from HLO dialect to Standard dialect.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
|
||||||
|
|
||||||
|
/// Lowers from the CHLO dialect to the HLO dialect.
|
||||||
|
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass();
|
||||||
|
|
||||||
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
|
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
|
||||||
/// buffers if necessary. If `results_escape_functions` is set to true,
|
/// buffers if necessary. If `results_escape_functions` is set to true,
|
||||||
/// allocated buffers for function results will be returned and escape the
|
/// allocated buffers for function results will be returned and escape the
|
||||||
@ -63,7 +66,7 @@ std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
|
|||||||
std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
|
std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
|
||||||
|
|
||||||
/// Lowers trigonometric operations from the standard dialect to approximations
|
/// Lowers trigonometric operations from the standard dialect to approximations
|
||||||
// that do not use intrinsics.
|
/// that do not use intrinsics.
|
||||||
std::unique_ptr<OperationPass<FuncOp>>
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
createLegalizeTrigonometricToApproximationPass();
|
createLegalizeTrigonometricToApproximationPass();
|
||||||
|
|
||||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
|||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace mhlo {
|
namespace mhlo {
|
||||||
|
|
||||||
std::unique_ptr<Pass> createTestChloLegalizeToHloPass();
|
|
||||||
std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass();
|
std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass();
|
||||||
std::unique_ptr<Pass> createTestMaterializeBroadcastsPass();
|
std::unique_ptr<Pass> createTestMaterializeBroadcastsPass();
|
||||||
std::unique_ptr<Pass> createTestUnfuseBatchNormPass();
|
std::unique_ptr<Pass> createTestUnfuseBatchNormPass();
|
||||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Transforms/BufferPlacement.h"
|
#include "mlir/Transforms/Bufferize.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
@ -185,8 +185,7 @@ struct GatherSlice : public OpRewritePattern<GatherOp> {
|
|||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
const auto& dnums = gather.dimension_numbers();
|
const auto& dnums = gather.dimension_numbers();
|
||||||
if (dnums.collapsed_slice_dims().getNumElements() != 0 ||
|
if (dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1)
|
||||||
dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1)
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// TODO(tberghammer): Remove when the verifier catches this case what is
|
// TODO(tberghammer): Remove when the verifier catches this case what is
|
||||||
@ -206,11 +205,35 @@ struct GatherSlice : public OpRewritePattern<GatherOp> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallVector<int64_t, 8> slice_stride(slice_end.size(), 1);
|
llvm::SmallVector<int64_t, 8> slice_stride(slice_end.size(), 1);
|
||||||
rewriter.replaceOpWithNewOp<SliceOp>(
|
llvm::SmallVector<int64_t, 8> slice_shape(slice_end.size());
|
||||||
gather, gather.getType(), gather.getOperand(0),
|
for (int64_t i = 0; i < slice_end.size(); ++i) {
|
||||||
|
slice_shape[i] = slice_end[i] - slice_start[i];
|
||||||
|
}
|
||||||
|
Type element_type = gather.getType().cast<TensorType>().getElementType();
|
||||||
|
auto slice_type = RankedTensorType::get(slice_shape, element_type);
|
||||||
|
Value result = rewriter.create<SliceOp>(
|
||||||
|
gather.getLoc(), slice_type, gather.getOperand(0),
|
||||||
GetI64ElementsAttr(slice_start, &rewriter),
|
GetI64ElementsAttr(slice_start, &rewriter),
|
||||||
GetI64ElementsAttr(slice_end, &rewriter),
|
GetI64ElementsAttr(slice_end, &rewriter),
|
||||||
GetI64ElementsAttr(slice_stride, &rewriter));
|
GetI64ElementsAttr(slice_stride, &rewriter));
|
||||||
|
|
||||||
|
if (dnums.collapsed_slice_dims().getNumElements() > 0) {
|
||||||
|
auto collapsed_slice_dims = llvm::to_vector<8>(llvm::map_range(
|
||||||
|
dnums.collapsed_slice_dims().getIntValues(),
|
||||||
|
[](const llvm::APInt& i) { return i.getSExtValue(); }));
|
||||||
|
llvm::SmallVector<int64_t, 8> reshape_shape;
|
||||||
|
for (int64_t i = 0; i < slice_shape.size(); ++i) {
|
||||||
|
if (llvm::count(collapsed_slice_dims, i) == 0) {
|
||||||
|
reshape_shape.push_back(slice_shape[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto reshape_type = RankedTensorType::get(reshape_shape, element_type);
|
||||||
|
result =
|
||||||
|
rewriter.create<ReshapeOp>(gather.getLoc(), reshape_type, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
result.setType(gather.getType());
|
||||||
|
rewriter.replaceOp(gather, result);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -889,9 +912,10 @@ static LogicalResult Verify(ClampOp op) {
|
|||||||
// ComplexOp
|
// ComplexOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs,
|
LogicalResult ComplexOp::inferReturnTypes(
|
||||||
Value rhs) {
|
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
|
||||||
auto type = lhs.getType();
|
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||||
|
auto type = operands[0].getType();
|
||||||
auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
|
auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
|
||||||
Type result_ty;
|
Type result_ty;
|
||||||
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
|
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
|
||||||
@ -901,8 +925,8 @@ void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs,
|
|||||||
} else {
|
} else {
|
||||||
result_ty = element_ty;
|
result_ty = element_ty;
|
||||||
}
|
}
|
||||||
|
inferredReturnTypes.push_back(result_ty);
|
||||||
build(builder, state, result_ty, lhs, rhs);
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
|
||||||
@ -932,8 +956,11 @@ Type CreateRealType(Type type) {
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) {
|
LogicalResult ImagOp::inferReturnTypes(
|
||||||
build(builder, state, CreateRealType(val.getType()), val);
|
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
|
||||||
|
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||||
|
inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
|
||||||
@ -945,8 +972,11 @@ OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
void RealOp::build(OpBuilder& builder, OperationState& state, Value val) {
|
LogicalResult RealOp::inferReturnTypes(
|
||||||
build(builder, state, CreateRealType(val.getType()), val);
|
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
|
||||||
|
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
|
||||||
|
inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
|
||||||
@ -1971,6 +2001,23 @@ struct divide<APInt> {
|
|||||||
APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
|
APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct remainder : std::modulus<T> {};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct remainder<APInt> {
|
||||||
|
APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct remainder<APFloat> {
|
||||||
|
APFloat operator()(const APFloat& a, const APFloat& b) const {
|
||||||
|
APFloat result(a);
|
||||||
|
result.remainder(b);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct max {
|
struct max {
|
||||||
T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
|
T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
|
||||||
@ -2012,6 +2059,7 @@ BINARY_FOLDER(AddOp, std::plus);
|
|||||||
BINARY_FOLDER(SubOp, std::minus);
|
BINARY_FOLDER(SubOp, std::minus);
|
||||||
BINARY_FOLDER(MulOp, std::multiplies);
|
BINARY_FOLDER(MulOp, std::multiplies);
|
||||||
BINARY_FOLDER(DivOp, divide);
|
BINARY_FOLDER(DivOp, divide);
|
||||||
|
BINARY_FOLDER(RemOp, remainder);
|
||||||
BINARY_FOLDER(MaxOp, max);
|
BINARY_FOLDER(MaxOp, max);
|
||||||
BINARY_FOLDER(MinOp, min);
|
BINARY_FOLDER(MinOp, min);
|
||||||
|
|
||||||
@ -2261,10 +2309,7 @@ void SortOp::build(OpBuilder& builder, OperationState& state,
|
|||||||
state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
|
state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
|
||||||
state.addAttribute("is_stable", builder.getBoolAttr(dimension));
|
state.addAttribute("is_stable", builder.getBoolAttr(dimension));
|
||||||
|
|
||||||
SmallVector<Type, 2> element_types;
|
for (Value operand : operands) state.addTypes(operand.getType());
|
||||||
element_types.reserve(operands.size());
|
|
||||||
for (Value operand : operands) element_types.push_back(operand.getType());
|
|
||||||
state.addTypes(builder.getTupleType(element_types));
|
|
||||||
|
|
||||||
state.addRegion();
|
state.addRegion();
|
||||||
}
|
}
|
||||||
|
@ -283,7 +283,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
|||||||
auto if_op = rewriter.create<scf::IfOp>(
|
auto if_op = rewriter.create<scf::IfOp>(
|
||||||
loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
|
loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
|
||||||
OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder();
|
OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder();
|
||||||
Value reshaped_lhs = if_lhs_scalar_builder.create<mhlo::ReshapeOp>(
|
Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
|
||||||
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
|
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
|
||||||
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
|
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
|
||||||
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
|
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
|
||||||
@ -300,7 +300,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
|
|||||||
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
|
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
|
||||||
if_rhs_scalar_op.getResult(0));
|
if_rhs_scalar_op.getResult(0));
|
||||||
OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder();
|
OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder();
|
||||||
Value reshaped_rhs = if_rhs_scalar_builder.create<mhlo::ReshapeOp>(
|
Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
|
||||||
loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
|
loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
|
||||||
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
|
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
|
||||||
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
|
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
|
||||||
@ -516,7 +516,7 @@ struct HloCompareAdaptor {
|
|||||||
|
|
||||||
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
||||||
OwningRewritePatternList *patterns) {
|
OwningRewritePatternList *patterns) {
|
||||||
populateWithGenerated(context, patterns);
|
populateWithGenerated(context, *patterns);
|
||||||
|
|
||||||
// Instantiate conversion templates for conforming binary elementwise ops
|
// Instantiate conversion templates for conforming binary elementwise ops
|
||||||
// that do not have different dtypes between operands and results and do
|
// that do not have different dtypes between operands and results and do
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user