Merge branch 'master' into pluggable_device_load

This commit is contained in:
Zhoulong Jiang 2020-10-23 16:03:48 +00:00
commit 22d22ff5b3
1932 changed files with 56329 additions and 31099 deletions

View File

@ -174,6 +174,12 @@ build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_opensource_only --define=build_with_mkl_opensource=true
build:mkl_opensource_only -c opt
# Config setting to build with oneDNN for Arm.
build:mkl_aarch64 --define=build_with_mkl_aarch64=true --define=enable_mkl=true
build:mkl_aarch64 --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_aarch64 --define=build_with_mkl_opensource=true
build:mkl_aarch64 -c opt
# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true

View File

@ -12,12 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
#
# THIS IS A GENERATED DOCKERFILE.
#
# This file was assembled from multiple pieces, whose use is documented
# throughout. Please refer to the TensorFlow dockerfiles documentation
# for more information.
# A list of assignees
assignees:

28
.github/workflows/update-nightly.yml vendored Normal file
View File

@ -0,0 +1,28 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
on:
workflow_dispatch: # Allow manual triggers
schedule:
- cron: 0 4 * * * # 4am UTC is 9pm PDT and 8pm PST
name: Set nightly branch to master HEAD
jobs:
master-to-nightly:
runs-on: ubuntu-latest
steps:
- uses: zofrex/mirror-branch@v1
name: Set nightly branch to master HEAD
with:
target-branch: 'nightly'

View File

@ -4,7 +4,7 @@
/tensorflow/core/common_runtime/eager @qqfish @kkimdev
/tenosrflow/core/debug @caisq
/tensorflow/core/nccl/ @azaks2 @chsigg
/tensorflow/core/platform/windows/ @gunan @mihaimaruseac
/tensorflow/core/platform/windows/ @mihaimaruseac
/tensorflow/lite/experimental/micro @petewarden @advaitjain
/tensorflow/python/autograph/ @mdanatg @kkimdev
/tensorflow/python/debug @caisq

View File

@ -34,6 +34,7 @@
shape assumptions (note that you can pass shapes with `None` entries for axes
that are meant to be dynamic). You can also disable the input checking
entirely by setting `model.input_spec = None`.
* TF pip packages now use CUDA11 and cuDNN 8.0.2.
* XLA:CPU and XLA:GPU devices are no longer registered by default. Use
`TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be
removed).
@ -46,6 +47,13 @@
* `tf.data.experimental.service.WorkerServer` now takes a config tuple
instead of individual arguments. Usages should be updated to
`tf.data.experimental.service.WorkerServer(worker_config)`.
* `tf.quantization.quantize_and_dequantize_v2` has been introduced, which
updates the gradient definition for quantization which is outside the range
to be 0. To simulate the V1 the behavior of
tf.quantization.quantize_and_dequantize(...) use
tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...).
* `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please
use `tf.data.Dataset.from_tensor_slices` instead.
## Known Caveats
@ -63,143 +71,168 @@
## Bug Fixes and Other Changes
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* Security:
* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
* Fixes three vulnerabilities in conversion to DLPack format
([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
[CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
[CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
[CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
* Fixes several vulnerabilities in `RaggedCountSparseOutput` and
`SparseCountSparseOutput` operations
([CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196),
[CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197),
[CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198),
[CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199),
[CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200),
[CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201))
* Fixes an integer truncation vulnerability in code using the work sharder API
([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
* Fixes a format string vulnerability in `tf.strings.as_string`
([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
* Fixes segfault raised by calling session-only ops in eager mode
([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
* Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams`
([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
* Fixes segfaults caused by incomplete `SavedModel` validation
([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
* Fixes a data corruption due to a bug in negative indexing support in TFLite
([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
* Fixes a data corruption due to dimension mismatch in TFLite
([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
* Fixes several vulnerabilities in TFLite saved model format
([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
[CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
[CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
* Fixes several vulnerabilities in TFLite implementation of segment sum
([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
[CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
[CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
* TF Core:
* `tf.types.experimental.TensorLike` is a new `Union` type that can be used as
type annotation for variables representing a Tensor or a value that can be
converted to Tensor by `tf.convert_to_tensor`.
* Calling ops with a python constants or numpy values is now consistent with
tf.convert_to_tensor behavior. This avoids operations like tf.reshape
truncating inputs such as from int64 to int32.
* Added `tf.sparse.map_values` to apply a function to the `.value`s of `SparseTensror` arguments.
* The Python bitwise operators for `Tensor` (`__and__`, `__or__`, `__xor__`
and `__invert__` now support non-`bool` arguments and apply the
corresponding bitwise ops. `bool` arguments continue to be supported and
dispatch to logical ops. This brings them more in line with Python and NumPy
benavior.
* Added `tf.SparseTensor.with_values`. This returns a new SparseTensor with
the same sparsity pattern, but with new provided values. It is similar to
the `with_values` function of `RaggedTensor`.
* Added `StatelessCase` op, and uses it if none of case branches has stateful ops.
* Added `tf.config.experimental.get_memory_usage` to return total memory usage
of the device.
* `tf.data`:
* tf.data service:
* Added new `tf.data.experimental.service.register_dataset` and
`tf.data.experimental.service.from_dataset_id` APIs to enable one process
to register a dataset with the tf.data service, and another process to
consume data from the dataset.
* Added support for dispatcher fault tolerance. To enable fault tolerance,
configure a `work_dir` when running your dispatcher server and set
`dispatcher_fault_tolerance=True`. The dispatcher will store its state to
`work_dir`, so that on restart it can continue from its previous state
after restart.
* Added support for sharing dataset graphs via shared filesystem instead of
over RPC. This reduces load on the dispatcher, improving performance of
distributing datasets. For this to work, the dispatcher's `work_dir` must
be accessible from workers. If the worker fails to read from the
`work_dir`, it falls back to using RPC for dataset graph transfer.
* Added support for a new "distributed_epoch" processing mode. This
processing mode distributes a dataset across all tf.data workers, instead
of having each worker process the full dataset. See
[the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode)
to learn more.
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
the complement of `select_cols`; at most one of these should be specified.
* We have implemented an optimization which reorders data-discarding
transformations such as `take` and `shard` to happen earlier in the
dataset when it is safe to do so. The optimization can be disabled via
the `experimental_optimization.reorder_data_discarding_ops` dataset
option.
* `tf.data.Options` were previously immutable and can now be overriden.
* `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors
with a new `output_signature` argument, which allows `from_generator` to
produce any type describable by a `tf.TypeSpec`.
* `tf.data.experimental.AUTOTUNE` is now available in the core API as
`tf.data.AUTOTUNE`.
* `tf.image`:
* Added deterministic `tf.image.stateless_random_*` functions for each
`tf.image.random_*` function. Added a new op
`stateless_sample_distorted_bounding_box` which is a determinstic
version of `sample_distorted_bounding_box` op. Given the same seed, these
stateless functions/ops produce the same results independent of how many
times the function is called, and independent of global seed settings.
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* Security:
* Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`
([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190))
* Fixes three vulnerabilities in conversion to DLPack format
([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191),
[CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192),
[CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193))
* Fixes two vulnerabilities in `SparseFillEmptyRowsGrad`
([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194),
[CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195))
* Fixes several vulnerabilities in `RaggedCountSparseOutput` and
`SparseCountSparseOutput` operations
([CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196),
[CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197),
[CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198),
[CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199),
[CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200),
[CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201))
* Fixes an integer truncation vulnerability in code using the work sharder
API
([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202))
* Fixes a format string vulnerability in `tf.strings.as_string`
([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203))
* Fixes segfault raised by calling session-only ops in eager mode
([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204))
* Fixes data leak and potential ASLR violation from
`tf.raw_ops.StringNGrams`
([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205))
* Fixes segfaults caused by incomplete `SavedModel` validation
([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206))
* Fixes a data corruption due to a bug in negative indexing support in
TFLite
([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207))
* Fixes a data corruption due to dimension mismatch in TFLite
([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208))
* Fixes several vulnerabilities in TFLite saved model format
([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209),
[CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210),
[CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211))
* Fixes several vulnerabilities in TFLite implementation of segment sum
([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212),
[CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213),
[CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214))
* TF Core:
* `tf.types.experimental.TensorLike` is a new `Union` type that can be
used as type annotation for variables representing a Tensor or a value
that can be converted to Tensor by `tf.convert_to_tensor`.
* Calling ops with a python constants or numpy values is now consistent
with tf.convert_to_tensor behavior. This avoids operations like
tf.reshape truncating inputs such as from int64 to int32.
* Added `tf.sparse.map_values` to apply a function to the `.value`s of
`SparseTensor` arguments.
* The Python bitwise operators for `Tensor` (`__and__`, `__or__`,
`__xor__` and `__invert__` now support non-`bool` arguments and apply
the corresponding bitwise ops. `bool` arguments continue to be supported
and dispatch to logical ops. This brings them more in line with Python
and NumPy behavior.
* Added `tf.SparseTensor.with_values`. This returns a new SparseTensor
with the same sparsity pattern, but with new provided values. It is
similar to the `with_values` function of `RaggedTensor`.
* Added `StatelessCase` op, and uses it if none of case branches has
stateful ops.
* Added `tf.config.experimental.get_memory_usage` to return total memory
usage of the device.
* `tf.data`:
* tf.data service:
* Added new `tf.data.experimental.service.register_dataset` and
`tf.data.experimental.service.from_dataset_id` APIs to enable one
process to register a dataset with the tf.data service, and another
process to consume data from the dataset.
* Added support for dispatcher fault tolerance. To enable fault tolerance,
configure a `work_dir` when running your dispatcher server and set
`dispatcher_fault_tolerance=True`. The dispatcher will store its state
to `work_dir`, so that on restart it can continue from its previous
state after restart.
* Added support for sharing dataset graphs via shared filesystem instead
of over RPC. This reduces load on the dispatcher, improving performance
of distributing datasets. For this to work, the dispatcher's `work_dir`
must be accessible from workers. If the worker fails to read from the
`work_dir`, it falls back to using RPC for dataset graph transfer.
* Added support for a new "distributed_epoch" processing mode. This
processing mode distributes a dataset across all tf.data workers,
instead of having each worker process the full dataset. See
[the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode)
to learn more.
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
the complement of `select_cols`; at most one of these should be
specified.
* We have implemented an optimization which reorders data-discarding
transformations such as `take` and `shard` to happen earlier in the
dataset when it is safe to do so. The optimization can be disabled via
the `experimental_optimization.reorder_data_discarding_ops` dataset
option.
* `tf.data.Options` were previously immutable and can now be overridden.
* `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors
with a new `output_signature` argument, which allows `from_generator` to
produce any type describable by a `tf.TypeSpec`.
* `tf.data.experimental.AUTOTUNE` is now available in the core API as
`tf.data.AUTOTUNE`.
* `tf.image`:
* Added deterministic `tf.image.stateless_random_*` functions for each
`tf.image.random_*` function. Added a new op
`stateless_sample_distorted_bounding_box` which is a deterministic
version of `sample_distorted_bounding_box` op. Given the same seed,
these stateless functions/ops produce the same results independent of
how many times the function is called, and independent of global seed
settings.
* `tf.distribute`:
* <ADD RELEASE NOTES HERE>
* `tf.keras`:
* Improvements from the functional API refactoring:
* Functional model construction does not need to maintain a global workspace graph, removing memory leaks especially when building many models or very large models.
* Functional model construction should be ~8-10% faster on average.
* Functional models can now contain non-symbolic values in their call inputs inside of the first positional argument.
* Several classes of TF ops that were not reliably converted to Keras layers during functional API construction should now work, e.g. `tf.image.ssim_multiscale`
* Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
as an alternative to accepting a `callable` loss.
* Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
to match FTRL paper (https://research.google.com/pubs/archive/41159.pdf).
* Added `mobilenet_v3` to keras application model.
* `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
customization of how gradients are aggregated across devices, as well as
`gradients_transformers` to allow for custom gradient transformations
(such as gradient clipping).
* The `steps_per_execution` argument in `compile()` is no longer
experimental; if you were passing `experimental_steps_per_execution`,
rename it to `steps_per_execution` in your code. This argument controls
the number of batches to run during each `tf.function` call when calling
`fit()`. Running multiple batches inside a single `tf.function` call can
greatly improve performance on TPUs or small models with a large Python
overhead.
* `tf.function` / AutoGraph:
* Added `experimental_follow_type_hints` argument for `tf.function`. When
True, the function may use type annotations to optimize the tracing
performance.
* Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
* AutoGraph now allows creating new symbols inside a TensorFLow loop, if
the values of these symbols at an iteration does not depend on the previous
iteration. These types of loops must run at least one iteration, and will
raise a runtime error otherwise.
* <ADD RELEASE NOTES HERE>
* `tf.keras`:
* Improvements from the functional API refactoring:
* Functional model construction does not need to maintain a global
workspace graph, removing memory leaks especially when building many
models or very large models.
* Functional model construction should be ~8-10% faster on average.
* Functional models can now contain non-symbolic values in their call
inputs inside of the first positional argument.
* Several classes of TF ops that were not reliably converted to Keras
layers during functional API construction should now work, e.g.
`tf.image.ssim_multiscale`
* Error messages when Functional API construction goes wrong (and when
ops cannot be converted to Keras layers automatically) should be
clearer and easier to understand.
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
as an alternative to accepting a `callable` loss.
* Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
to match FTRL paper
(https://research.google.com/pubs/archive/41159.pdf).
* Added `mobilenet_v3` to keras application model.
* `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
customization of how gradients are aggregated across devices, as well as
`gradients_transformers` to allow for custom gradient transformations
(such as gradient clipping).
* The `steps_per_execution` argument in `compile()` is no longer
experimental; if you were passing `experimental_steps_per_execution`,
rename it to `steps_per_execution` in your code. This argument controls
the number of batches to run during each `tf.function` call when calling
`fit()`. Running multiple batches inside a single `tf.function` call can
greatly improve performance on TPUs or small models with a large Python
overhead.
* Improvements to Keras preprocessing layers:
* TextVectorization can now accept a vocabulary list or file as an
init arg.
* Normalization can now accept mean and variance values as init args.
* In `Attention` and `AdditiveAttention` layers, the `call()` method now
accepts a `return_attention_scores` argument. When set to
True, the layer returns the attention scores as an additional output
argument.
* Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints
with the same implementation as their `tf.losses` equivalent.
* `tf.function` / AutoGraph:
* Added `experimental_follow_type_hints` argument for `tf.function`. When
True, the function may use type annotations to optimize the tracing
performance.
* Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
* AutoGraph now allows creating new symbols inside a TensorFLow loop, if
the values of these symbols at an iteration does not depend on the
previous iteration. These types of loops must run at least one
iteration, and will raise a runtime error otherwise.
Example:
@ -208,51 +241,97 @@
outputs = train_step(batch)
tf.print('final outputs', outputs)
```
See tensorflow/python/autograph/g3doc/reference/limitations.md for more
info.
* `tf.lite`:
* `DynamicBuffer::AddJoinedString()` will now add a separator if the first
string to be joined is empty.
* `TFLiteConverter`:
* Support optional flags `inference_input_type` and `inference_output_type` for full integer quantized models. This allows users to modify the model input and output type to integer types (`tf.int8`, `tf.uint8`) instead of defaulting to float type (`tf.float32`).
* Deprecate `Interpreter::UseNNAPI(bool)` C++ API
* Prefer using `NnApiDelegate()` and related delegate configuration methods directly.
* Add NNAPI Delegation support for requantization use cases by converting the operation into a dequantize-quantize pair.
* TFLite Profiler for Android is available. See the detailed
[guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android).
* <ADD RELEASE NOTES HERE>
* `TFLiteConverter`:
* Support optional flags `inference_input_type` and
`inference_output_type` for full integer quantized models. This
allows users to modify the model input and output type to integer
types (`tf.int8`, `tf.uint8`) instead of defaulting to float type
(`tf.float32`).
* TFLite Profiler for Android is available. See the detailed
[guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android).
* NNAPI
* Added NNAPI Delegation support for requantization use cases by
converting the operation into a dequantize-quantize pair.
* Removed deprecated `Interpreter.setUseNNAPI(boolean)` Java API.
* Use `Interpreter.Options.setUseNNAPI` instead.
* Deprecate `Interpreter::UseNNAPI(bool)` C++ API.
* Use `NnApiDelegate()` and related delegate configuration methods
directly.
* Deprecate `Interpreter::SetAllowFp16PrecisionForFp32(bool)` C++ API
* Prefer controlling this via delegate options, e.g.
`tflite::StatefulNnApiDelegate::Options::allow_fp16' or
`TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
* `DynamicBuffer::AddJoinedString()` will now add a separator if the first
string to be joined is empty.
* <ADD RELEASE NOTES HERE>
* `tf.random`:
* <ADD RELEASE NOTES HERE>
* <ADD RELEASE NOTES HERE>
* Math and Linear Algebra:
* <ADD RELEASE NOTES HERE>
* <ADD RELEASE NOTES HERE>
* TPU Enhancements:
* Added support for the `beta` parameter of the FTRL optimizer for TPU
embeddings. Users of other TensorFlow platforms can implement equivalent
behavior by adjusting the `l2` parameter.
* <ADD RELEASE NOTES HERE>
* Added support for the `beta` parameter of the FTRL optimizer for TPU
embeddings. Users of other TensorFlow platforms can implement equivalent
behavior by adjusting the `l2` parameter.
* <ADD RELEASE NOTES HERE>
* XLA Support:
* xla.experimental.compile is deprecated, use
`tf.function(experimental_compile=True)` instead
* Added `tf.function.experimental_get_compiler_ir` which returns compiler IR
(currently 'hlo' and 'optimized_hlo') for given input for given function.
* <ADD RELEASE NOTES HERE>
* xla.experimental.compile is deprecated, use
`tf.function(experimental_compile=True)` instead
* Added `tf.function.experimental_get_compiler_ir` which returns compiler
IR (currently 'hlo' and 'optimized_hlo') for given input for given
function.
* <ADD RELEASE NOTES HERE>
* Tracing and Debugging:
* <ADD RELEASE NOTES HERE>
* <ADD RELEASE NOTES HERE>
* `tf.train.Checkpoint`:
* Now accepts a `root` argument in the initialization, which generates a
checkpoint with a root object. This allows users to create a `Checkpoint`
object that is compatible with Keras `model.save_weights()` and
`model.load_weights`. The checkpoint is also compatible with the
checkpoint saved in the `variables/` folder in the SavedModel.
* When restoring, `save_path` can be a path to a SavedModel. The function
will automatically find the checkpoint in the SavedModel.
* Now accepts a `root` argument in the initialization, which generates a
checkpoint with a root object. This allows users to create a
`Checkpoint` object that is compatible with Keras `model.save_weights()`
and `model.load_weights`. The checkpoint is also compatible with the
checkpoint saved in the `variables/` folder in the SavedModel.
* When restoring, `save_path` can be a path to a SavedModel. The function
will automatically find the checkpoint in the SavedModel.
* `tf.nn`:
* `tf.nn.max_pool2d` now supports explicit padding.
* `tf.nn.max_pool2d` now supports explicit padding.
* `tf.debugging`:
* `tf.debugging.assert_shapes()` now works on `SparseTensor`s (#36268).
* `tf.print`:
* Bug fix in `tf.print()` with `OrderedDict` where if an `OrderedDict`
didn't have the keys sorted, the keys and values were not being printed
in accordance with their correct mapping.
* Other:
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
and "denylist" where possible. Please see
https://developers.google.com/style/word-list#blacklist for more context.
<ADD RELEASE NOTES HERE>
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
and "denylist" where possible. Please see
https://developers.google.com/style/word-list#blacklist for more
context.
* Add `tf.config.experimental.mlir_bridge_rollout` which will help us
rollout the new MLIR TPU bridge.
* <ADD RELEASE NOTES HERE>
## Thanks to our Contributors
@ -500,42 +579,87 @@ stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
# Release 2.3.0
## Major Features and Improvements
* `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and save resources:
* [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot)
* [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service).
In addition checkout the detailed [guide](https://www.tensorflow.org/guide/data_performance_analysis) for analyzing input pipeline performance with TF Profiler.
* `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and
save resources:
* [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy) is now a stable API and no longer considered experimental for TensorFlow. (earlier `tf.distribute.experimental.TPUStrategy`).
* [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot)
* [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service).
* [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new tools: a memory profiler to visualize your models memory usage over time and a [python tracer](https://www.tensorflow.org/guide/profiler#events) which allows you to trace python function calls in your model. Usability improvements include better diagnostic messages and [profile options](https://tensorflow.org/guide/profiler#collect_performance_data) to customize the host and device trace verbosity level.
In addition checkout the detailed
[guide](https://www.tensorflow.org/guide/data_performance_analysis) for
analyzing input pipeline performance with TF Profiler.
* Introduces experimental support for Keras Preprocessing Layers API ([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly)) to handle data preprocessing operations, with support for composite tensor inputs. Please see below for additional details on these layers.
* [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy)
is now a stable API and no longer considered experimental for TensorFlow.
(earlier `tf.distribute.experimental.TPUStrategy`).
* TFLite now properly supports dynamic shapes during conversion and inference. Weve also added opt-in support on Android and iOS for [XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack), a highly optimized set of CPU kernels, as well as opt-in support for [executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental).
* [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new
tools: a memory profiler to visualize your models memory usage over time
and a [python tracer](https://www.tensorflow.org/guide/profiler#events)
which allows you to trace python function calls in your model. Usability
improvements include better diagnostic messages and
[profile options](https://tensorflow.org/guide/profiler#collect_performance_data)
to customize the host and device trace verbosity level.
* Libtensorflow packages are available in GCS starting this release. We have also started to [release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds).
* Introduces experimental support for Keras Preprocessing Layers API
([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly))
to handle data preprocessing operations, with support for composite tensor
inputs. Please see below for additional details on these layers.
* The experimental Python API [`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info) now allows you to instrument a TensorFlow program and dump debugging information to a directory on the file system. The directory can be read and visualized by a new interactive dashboard in TensorBoard 2.3 called [Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which reveals the details of the TensorFlow program including graph structures, history of op executions at the Python (eager) and intra-graph levels, the runtime dtype, shape, and numerical composistion of tensors, as well as their code locations.
* TFLite now properly supports dynamic shapes during conversion and inference.
Weve also added opt-in support on Android and iOS for
[XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack),
a highly optimized set of CPU kernels, as well as opt-in support for
[executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental).
* Libtensorflow packages are available in GCS starting this release. We have
also started to
[release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds).
* The experimental Python API
[`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info)
now allows you to instrument a TensorFlow program and dump debugging
information to a directory on the file system. The directory can be read and
visualized by a new interactive dashboard in TensorBoard 2.3 called
[Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which
reveals the details of the TensorFlow program including graph structures,
history of op executions at the Python (eager) and intra-graph levels, the
runtime dtype, shape, and numerical composition of tensors, as well as their
code locations.
## Breaking Changes
* Increases the **minimum bazel version** required to build TF to **3.1.0**.
* `tf.data`
* Makes the following (breaking) changes to the `tf.data`.
* C++ API: - `IteratorBase::RestoreInternal`, `IteratorBase::SaveInternal`, and `DatasetBase::CheckExternalState` become pure-virtual and subclasses are now expected to provide an implementation.
* The deprecated `DatasetBase::IsStateful` method is removed in favor of `DatasetBase::CheckExternalState`.
* Deprecated overrides of `DatasetBase::MakeIterator` and `MakeIteratorFromInputElement` are removed.
* The signature of `tensorflow::data::IteratorBase::SaveInternal` and `tensorflow::data::IteratorBase::SaveInput` has been extended with `SerializationContext` argument to enable overriding the default policy for the handling external state during iterator checkpointing. This is not a backwards compatible change and all subclasses of `IteratorBase` *need to be updated* accordingly.
* `tf.keras`
* Add a new `BackupAndRestore` callback for handling distributed training failures & restarts. Please take a look at this [tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for details on how to use the callback.
* `tf.image.extract_glimpse` has been updated to correctly process the case
where `centered=False` and `normalized=False`. This is a breaking change as
the output is different from (incorrect) previous versions. Note this
breaking change only impacts `tf.image.extract_glimpse` and
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
exsiting C++ kernel `ExtractGlimpse` does not change either, so saved
models using `tf.raw_ops.ExtractGlimpse` will not be impacted.
* Increases the **minimum bazel version** required to build TF to **3.1.0**.
* `tf.data`
* Makes the following (breaking) changes to the `tf.data`.
* C++ API: - `IteratorBase::RestoreInternal`,
`IteratorBase::SaveInternal`, and `DatasetBase::CheckExternalState`
become pure-virtual and subclasses are now expected to provide an
implementation.
* The deprecated `DatasetBase::IsStateful` method is removed in favor of
`DatasetBase::CheckExternalState`.
* Deprecated overrides of `DatasetBase::MakeIterator` and
`MakeIteratorFromInputElement` are removed.
* The signature of `tensorflow::data::IteratorBase::SaveInternal` and
`tensorflow::data::IteratorBase::SaveInput` has been extended with
`SerializationContext` argument to enable overriding the default policy
for the handling external state during iterator checkpointing. This is
not a backwards compatible change and all subclasses of `IteratorBase`
*need to be updated* accordingly.
* `tf.keras`
* Add a new `BackupAndRestore` callback for handling distributed training
failures & restarts. Please take a look at this
[tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
for details on how to use the callback.
* `tf.image.extract_glimpse` has been updated to correctly process the case
where `centered=False` and `normalized=False`. This is a breaking change as
the output is different from (incorrect) previous versions. Note this
breaking change only impacts `tf.image.extract_glimpse` and
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
existing C++ kernel `ExtractGlimpse` does not change either, so saved models
using `tf.raw_ops.ExtractGlimpse` will not be impacted.
## Known Caveats
* `tf.lite`
@ -1105,7 +1229,7 @@ This release contains contributions from many people at Google, as well as:
8bitmp3, Aaron Ma, AbdüLhamit Yilmaz, Abhai Kollara, aflc, Ag Ramesh, Albert Z. Guo, Alex Torres, amoitra, Andrii Prymostka, angeliand, Anshuman Tripathy, Anthony Barbier, Anton Kachatkou, Anubh-V, Anuja Jakhade, Artem Ryabov, autoih, Bairen Yi, Bas Aarts, Basit Ayantunde, Ben Barsdell, Bhavani Subramanian, Brett Koonce, candy.dc, Captain-Pool, caster, cathy, Chong Yan, Choong Yin Thong, Clayne Robison, Colle, Dan Ganea, David Norman, David Refaeli, dengziming, Diego Caballero, Divyanshu, djshen, Douman, Duncan Riach, EFanZh, Elena Zhelezina, Eric Schweitz, Evgenii Zheltonozhskii, Fei Hu, fo40225, Fred Reiss, Frederic Bastien, Fredrik Knutsson, fsx950223, fwcore, George Grzegorz Pawelczak, George Sterpu, Gian Marco Iodice, Giorgio Arena, giuros01, Gomathi Ramamurthy, Guozhong Zhuang, Haifeng Jin, Haoyu Wu, HarikrishnanBalagopal, HJYOO, Huang Chen-Yi, Ilham Firdausi Putra, Imran Salam, Jared Nielsen, Jason Zaman, Jasper Vicenti, Jeff Daily, Jeff Poznanovic, Jens Elofsson, Jerry Shih, jerryyin, Jesper Dramsch, jim.meyer, Jongwon Lee, Jun Wan, Junyuan Xie, Kaixi Hou, kamalkraj, Kan Chen, Karthik Muthuraman, Keiji Ariyama, Kevin Rose, Kevin Wang, Koan-Sin Tan, kstuedem, Kwabena W. Agyeman, Lakshay Tokas, latyas, Leslie-Fang-Intel, Li, Guizi, Luciano Resende, Lukas Folle, Lukas Geiger, Mahmoud Abuzaina, Manuel Freiberger, Mark Ryan, Martin Mlostek, Masaki Kozuki, Matthew Bentham, Matthew Denton, mbhuiyan, mdfaijul, Muhwan Kim, Nagy Mostafa, nammbash, Nathan Luehr, Nathan Wells, Niranjan Hasabnis, Oleksii Volkovskyi, Olivier Moindrot, olramde, Ouyang Jin, OverLordGoldDragon, Pallavi G, Paul Andrey, Paul Wais, pkanwar23, Pooya Davoodi, Prabindh Sundareson, Rajeshwar Reddy T, Ralovich, Kristof, Refraction-Ray, Richard Barnes, richardbrks, Robert Herbig, Romeo Kienzler, Ryan Mccormick, saishruthi, Saket Khandelwal, Sami Kama, Sana Damani, Satoshi Tanaka, Sergey Mironov, Sergii Khomenko, Shahid, Shawn Presser, ShengYang1, Siddhartha Bagaria, Simon Plovyt, skeydan, srinivasan.narayanamoorthy, Stephen Mugisha, sunway513, Takeshi Watanabe, Taylor Jakobson, TengLu, TheMindVirus, ThisIsIsaac, Tim Gates, Timothy Liu, Tomer Gafner, Trent Lo, Trevor Hickey, Trevor Morris, vcarpani, Wei Wang, Wen-Heng (Jack) Chung, wenshuai, Wenshuai-Xiaomi, wenxizhu, william, William D. Irons, Xinan Jiang, Yannic, Yasir Modak, Yasuhiro Matsumoto, Yong Tang, Yongfeng Gu, Youwei Song, Zaccharie Ramzi, Zhang, Zhenyu Guo, 王振华 (Zhenhua Wang), 韩董, 이중건 Isaac Lee
# Release 1.15.0
This is the last 1.x release for TensorFlow. We do not expect to update the 1.x branch with features, although we will issue patch releases to fix vulnerabilities for at least one year.
This is the last 1.x release for TensorFlow. We do not expect to update the 1.x branch with features, although we will issue patch releases to fix vulnerabilities for at least one year.
## Major Features and Improvements
* As [announced](https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0), `tensorflow` pip package will by default include GPU support (same as `tensorflow-gpu` now) for the platforms we currently have GPU support (Linux and Windows). It will work on machines with and without Nvidia GPUs. `tensorflow-gpu` will still be available, and CPU-only packages can be downloaded at `tensorflow-cpu` for users who are concerned about package size.
@ -1115,7 +1239,7 @@ This enables writing forward compatible code: by explicitly importing either `te
* Add toggles `tf.enable_control_flow_v2()` and `tf.disable_control_flow_v2()` for enabling/disabling v2 control flow.
* Enable v2 control flow as part of `tf.enable_v2_behavior()` and `TF2_BEHAVIOR=1`.
* AutoGraph translates Python control flow into TensorFlow expressions, allowing users to write regular Python inside `tf.function`-decorated functions. AutoGraph is also applied in functions used with `tf.data`, `tf.distribute` and `tf.keras` APIS.
* Adds `enable_tensor_equality()`, which switches the behavior such that:
* Adds `enable_tensor_equality()`, which switches the behavior such that:
* Tensors are no longer hashable.
* Tensors can be compared with `==` and `!=`, yielding a Boolean Tensor with element-wise comparison results. This will be the default behavior in 2.0.
@ -1271,12 +1395,12 @@ For information on upgrading your existing TensorFlow 1.x models, please refer t
* TensorFlow 2.0.0 is built using devtoolset7 (GCC7) on Ubuntu 16. This may lead to ABI incompatibilities with extensions built against earlier versions of TensorFlow.
* Tensorflow code now produces 2 different pip packages: tensorflow_core containing all the code (in the future it will contain only the private implementation) and tensorflow which is a virtual pip package doing forwarding to tensorflow_core (and in the future will contain only the public API of tensorflow). We don't expect this to be breaking, unless you were importing directly from the implementation.
Removed the `freeze_graph` command line tool; `SavedModel` should be used in place of frozen graphs.
* `tf.contrib`:
* `tf.contrib` has been deprecated, and functionality has been either migrated to the core TensorFlow API, to an ecosystem project such as [tensorflow/addons](https://www.github.com/tensorflow/addons) or [tensorflow/io](https://www.github.com/tensorflow/io), or removed entirely.
* Remove `tf.contrib.timeseries` dependency on TF distributions.
* Replace contrib references with `tf.estimator.experimental.*` for apis in `early_stopping.py`.
* `tf.estimator`:
* Premade estimators in the tf.estimator.DNN/Linear/DNNLinearCombined family have been updated to use `tf.keras.optimizers` instead of the `tf.compat.v1.train.Optimizer`s. If you do not pass in an `optimizer=` arg or if you use a string, the premade estimator will use the Keras optimizer. This is checkpoint breaking, as the optimizers have separate variables. A checkpoint converter tool for converting optimizers is included with the release, but if you want to avoid any change, switch to the v1 version of the estimator: `tf.compat.v1.estimator.DNN/Linear/DNNLinearCombined*`.
* Default aggregation for canned Estimators is now `SUM_OVER_BATCH_SIZE`. To maintain previous default behavior, please pass `SUM` as the loss aggregation method.
@ -1284,13 +1408,13 @@ For information on upgrading your existing TensorFlow 1.x models, please refer t
* `Estimator.export_savedmodel` has been renamed to `export_saved_model`.
* When saving to SavedModel, Estimators will strip default op attributes. This is almost always the correct behavior, as it is more forwards compatible, but if you require that default attributes to be saved with the model, please use `tf.compat.v1.Estimator`.
* Feature Columns have been upgraded to be more Eager-friendly and to work with Keras. As a result, `tf.feature_column.input_layer` has been deprecated in favor of `tf.keras.layers.DenseFeatures`. v1 feature columns have direct analogues in v2 except for `shared_embedding_columns`, which are not cross-compatible with v1 and v2. Use `tf.feature_column.shared_embeddings` instead.
* `tf.keras`:
* `OMP_NUM_THREADS` is no longer used by the default Keras config. To configure the number of threads, use `tf.config.threading` APIs.
* `tf.keras.model.save_model` and `model.save` now defaults to saving a TensorFlow SavedModel. HDF5 files are still supported.
* Deprecated `tf.keras.experimental.export_saved_model` and `tf.keras.experimental.function`. Please use `tf.keras.models.save_model(..., save_format='tf')` and `tf.keras.models.load_model` instead.
* Layers now default to float32, and automatically cast their inputs to the layer's dtype. If you had a model that used float64, it will probably silently use float32 in TensorFlow 2, and a warning will be issued that starts with `Layer <layer-name>` is casting an input tensor from dtype float64 to the layer's dtype of float32. To fix, either set the default dtype to float64 with `tf.keras.backend.set_floatx('float64')`, or pass `dtype='float64'` to each of the Layer constructors. See `tf.keras.layers.Layer` for more information.
* `tf.lite`:
* Removed `lite.OpHint`, `lite.experimental`, and `lite.constant` from 2.0 API.
* Tensors are no longer hashable, but instead compare element-wise with `==` and `!=`. Use `tf.compat.v1.disable_tensor_equality()` to return to the previous behavior.
@ -1525,8 +1649,8 @@ If you experience any snags when using TF 2.0, please let us know at the [TF 2.0
conversion. TensorRT initialization arguments are now passed wrapped in
a named-tuple, `TrtConversionParams`, rather than as separate arguments
as in `TrtGraphConverter`.
* Changed API to optimize TensorRT enginges during graph optimization.
This is now done by calling `converter.build()` where previously
* Changed API to optimize TensorRT engines during graph optimization. This
is now done by calling `converter.build()` where previously
`is_dynamic_op=False` would be set.
* `converter.convert()` no longer returns a `tf.function`. Now the
function must be accessed from the saved model.
@ -2536,7 +2660,7 @@ Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, A
* [`tf.contrib.estimator.RNNEstimator`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/estimator/RNNClassifier)
* The [distributions.Bijector](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/distributions/bijectors/Bijector)
API supports broadcasting for Bijectors with new API changes.
## Breaking Changes
* If you're opening empty variable scopes; replace `variable_scope('', ...)` by
`variable_scope(tf.get_variable_scope(), ...)`.
@ -3015,7 +3139,7 @@ Samuel He, Sandeep Dcunha, sandipmgiri, Sang Han, scott, Scott Mudge, Se-Won Kim
Simone Cirillo, Steffen Schmitz, Suvojit Manna, Sylvus, Taehoon Lee, Ted Chang, Thomas Deegan,
Till Hoffmann, Tim, Toni Kunic, Toon Verstraelen, Tristan Rice, Urs KöSter, Utkarsh Upadhyay,
Vish (Ishaya) Abrams, Winnie Tsang, Yan Chen, Yan Facai (颜发才), Yi Yang, Yong Tang,
Youssef Hesham, Yuan (Terry) Tang, Zhengsheng Wei, zxcqwe4906, 张志豪, 田传武
Youssef Hesham, Yuan (Terry) Tang, Zhengsheng Wei, zxcqwe4906, 张志豪, 田传武
We are also grateful to all who filed issues or helped resolve them, asked and
answered questions, and were part of inspiring discussions.

View File

@ -1485,6 +1485,7 @@ def main():
'adding "--config=<>" to your build command. See .bazelrc for more '
'details.')
config_info_line('mkl', 'Build with MKL support.')
config_info_line('mkl_aarch64', 'Build with oneDNN support for Aarch64.')
config_info_line('monolithic', 'Config for mostly static monolithic build.')
config_info_line('ngraph', 'Build with Intel nGraph support.')
config_info_line('numa', 'Build with NUMA support.')

View File

@ -568,17 +568,7 @@ selects.config_setting_group(
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
package_group(
name = "internal",
packages = [
"//learning/brain/distribute/...",
"//learning/brain/swift/x10/...",
"//perftools/accelerators/xprof/api/...",
"//tensorflow/...",
"//tensorflow_estimator/python/estimator/...",
"//tensorflow_models/official/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
"//third_party/swift/tensorflow_apis/...",
],
packages = ["//tensorflow/..."],
)
package_group(
@ -588,10 +578,8 @@ package_group(
# Packages that use private types symbols, until they are exported.
# TODO(b/154650521) Remove.
package_group(
name = "types_whitelist",
packages = ["//learning/deepmind/tensorflow/replicator/..."],
)
# If this is modified, then copy.bara.sky must also be modified.
package_group(name = "types_whitelist")
# Packages that use StructuredTensors.
# TODO(b/159007891) Remove this package once StructuredTensor is exported.
@ -719,7 +707,7 @@ tf_cc_shared_object(
deps = [
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"//tensorflow/cc/saved_model:loader_lite_impl",
"//tensorflow/core:core_cpu_impl",
"//tensorflow/core/common_runtime:core_cpu_impl",
"//tensorflow/core:framework_internal_impl",
"//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",

View File

@ -217,6 +217,8 @@ tf_cuda_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/kernels:logging_ops",
"//tensorflow/compiler/mlir/tfr:node_expansion_pass",
"//tensorflow/compiler/mlir/tfr:graph_decompose_pass",
],
}),
alwayslink = 1,
@ -254,6 +256,30 @@ tf_cuda_library(
}),
)
cc_library(
name = "tf_shape",
srcs = ["tf_shape.cc"],
hdrs = ["tf_shape.h"],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":c_api_macros",
":tf_shape_internal",
"//tensorflow/core:framework",
],
)
cc_library(
name = "tf_shape_internal",
hdrs = ["tf_shape_internal.h"],
copts = tf_copts(),
visibility = ["//tensorflow:internal"],
deps = [
":conversion_macros",
"//tensorflow/core:framework",
],
)
cc_library(
name = "tf_status",
srcs = ["tf_status.cc"],

View File

@ -2488,6 +2488,48 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
return ret;
}
void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
TF_Status* status) {
using tensorflow::RecordMutation;
mutex_lock l(graph->mu);
tensorflow::shape_inference::InferenceContext* ic =
graph->refiner.GetContext(&new_src.oper->node);
if (ic->num_outputs() <= new_src.index) {
status->status = tensorflow::errors::OutOfRange(
"Cannot update edge. Output index [", new_src.index,
"] is greater than the number of total outputs [", ic->num_outputs(),
"].");
return;
}
tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
tensorflow::shape_inference::InferenceContext* ic_dst =
graph->refiner.GetContext(&dst.oper->node);
if (ic_dst->num_inputs() <= dst.index) {
status->status = tensorflow::errors::OutOfRange(
"Cannot update edge. Input index [", dst.index,
"] is greater than the number of total inputs [", ic_dst->num_inputs(),
"].");
return;
}
if (!ic_dst->MergeInput(dst.index, shape)) {
status->status = tensorflow::errors::InvalidArgument(
"Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
" and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
return;
}
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
&dst.oper->node, dst.index);
if (TF_GetCode(status) == TF_OK) {
// This modification only updates the destination node for
// the purposes of running this graph in a session. Thus, we don't
// record the source node as being modified.
RecordMutation(graph, *dst.oper, "updating input tensor");
}
}
// TF_Server functions ----------------------------------------------
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)

View File

@ -1524,6 +1524,10 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status);
TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp(
const char* name, TF_Status* status);
// Update edge, switch input/ output in a node
TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src,
TF_Input dst, TF_Status* status);
// --------------------------------------------------------------------------
// In-process TensorFlow server functionality, for use in distributed training.
// A Server instance encapsulates a set of devices and a Session target that

View File

@ -634,6 +634,40 @@ TEST(CAPI, Graph) {
TF_DeleteStatus(s);
}
TEST(CAPI, UpdateEdge) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
// Make two scalar constants.
TF_Operation* one = ScalarConst(1, graph, s, "one");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Operation* two = ScalarConst(2, graph, s, "two");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Add oper.
TF_Operation* add = Add(one, two, graph, s, "add");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Add another oper to the graph.
TF_Operation* neg = Neg(add, graph, s, "neg");
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
NodeDef node_def_neg;
ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
EXPECT_EQ(string("add"), node_def_neg.input(0));
// update edge of neg
TF_UpdateEdge(graph, TF_Output{one, 0}, TF_Input{neg, 0}, s);
ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
EXPECT_EQ(string("one:0"), node_def_neg.input(0));
// Clean up
TF_DeleteGraph(graph);
TF_DeleteStatus(s);
}
/*
TODO(skyewm): this test currently DCHECKs, change to bad status

View File

@ -3,7 +3,7 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load(
"//tensorflow:tensorflow.bzl",
"if_tpu",
"if_libtpu",
"tf_cc_test",
"tf_copts",
"tf_cuda_cc_test",
@ -116,7 +116,6 @@ filegroup(
"immediate_execution_context.h",
"immediate_execution_operation.h",
"immediate_execution_tensor_handle.h",
"mnist_gradients_testutil.h",
"tape.h",
"tfe_cancellation_manager_internal.h",
"tfe_context_internal.h",
@ -290,7 +289,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti",
] + if_tpu(
] + if_libtpu(
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
if_true = [],
),
@ -314,6 +313,7 @@ cc_library(
":gradients_internal",
":gradients_util",
":tape",
"//tensorflow/c/experimental/gradients/tape:tape_context",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
@ -354,7 +354,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti",
] + if_tpu(
] + if_libtpu(
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
if_true = [],
),

View File

@ -39,7 +39,7 @@ limitations under the License.
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_tensor_internal.h"
#if defined(PLATFORM_GOOGLE) && !defined(LIBTFTPU)
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
#endif
#include "tensorflow/core/common_runtime/device.h"
@ -729,7 +729,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
if (opts->use_tfrt) {
#if defined(PLATFORM_GOOGLE) && !defined(LIBTFTPU)
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
#else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
@ -904,9 +904,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetThreadLocalDevicePlacementPolicy(
tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}
@ -915,10 +913,8 @@ void TFE_ContextSetThreadLocalDevicePlacementPolicy(
// safe to call this function from the async EagerExecutor threads.
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return static_cast<TFE_ContextDevicePlacementPolicy>(
context->GetDevicePlacementPolicy());
tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
}
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
@ -1429,21 +1425,15 @@ void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
}
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return context->FindFunctionDef(name) != nullptr;
return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
}
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(true);
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
}
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(false);
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
}
} // extern "C"

View File

@ -74,7 +74,7 @@ typedef enum TFE_ContextDevicePlacementPolicy {
// Placement policy which silently copies int32 tensors but not other dtypes.
TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
} TFE_ContextDevicePlacementPolicy;
// LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h)
// LINT.ThenChange(//tensorflow/c/eager/immediate_execution_context.h)
// Sets the default execution mode (sync/async). Note that this can be
// overridden per thread using TFE_ContextSetExecutorForThread.

View File

@ -545,7 +545,9 @@ TEST(CAPI, DistributedFunctionNoError) {
TestDistributedFunctionCancellation(false);
}
TEST(CAPI, DistributedFunctionCancelledOnError) {
// TODO(b/170399182): Update test once an alternative to using the function
// optimization hook is in place.
TEST(CAPI, DISABLED_DistributedFunctionCancelledOnError) {
TestDistributedFunctionCancellation(true);
}

View File

@ -49,15 +49,11 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
}
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(true);
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
}
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetShouldStoreGraphs(false);
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
}
uint64_t TFE_GetContextId(TFE_Context* ctx) {
@ -544,22 +540,16 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
}
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetExecutorForThread(executor->executor());
tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor());
}
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return new TFE_Executor(&context->Executor());
return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor());
}
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
context->HostCPU()->parsed_name());
tensorflow::unwrap(ctx)->HostCPUParsedName());
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0);
@ -572,9 +562,7 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
TF_Buffer* buf, TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto* function_def = context->FindFunctionDef(function_name);
auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name);
if (function_def == nullptr) {
status->status = tensorflow::errors::NotFound(
"Unable to find FunctionDef with name: ", function_name);
@ -643,14 +631,10 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetAllowSoftPlacement(enable);
tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable);
}
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetLogDevicePlacement(enable);
tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
}

View File

@ -191,7 +191,7 @@ Status TapeVSpace::CallBackwardFunction(
&ctx, incoming_gradients, result);
}
Status TapeVSpace::BuildOnesLike(TapeTensor t,
Status TapeVSpace::BuildOnesLike(const TapeTensor& t,
AbstractTensorHandle** result) const {
AbstractOperationPtr op(ctx_->CreateOperation());
TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));

View File

@ -180,10 +180,6 @@ int64 ToId(AbstractTensorHandle* t);
// allow us to trace the data dependencies between operations and hence compute
// gradients.
//
// This also implements `OnesLike` to create the default
// incoming gradients for tensors which do not already have an incoming
// gradient.
//
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
// of default zeros grads is handled by the `DefaultGradientFunction` registered
// for each op.
@ -233,7 +229,7 @@ class TapeVSpace
std::vector<AbstractTensorHandle*>* result) const override;
// Builds a tensor filled with ones with the same shape and dtype as `t`.
Status BuildOnesLike(TapeTensor t,
Status BuildOnesLike(const TapeTensor& t,
AbstractTensorHandle** result) const override;
// Looks up the ID of a Gradient.

View File

@ -61,6 +61,7 @@ Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
return Status::OK();
}
@ -131,6 +132,37 @@ Status ExpGradModel(AbstractContext* ctx,
return Status::OK();
}
// Computes
// y = sqrt(inputs[0])
// return grad(y, {inputs[0]})
Status SqrtGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
std::vector<AbstractTensorHandle*> sqrt_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(
ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(sqrt_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto sqrt_output : sqrt_outputs) {
sqrt_output->Unref();
}
outputs[0] = out_grads[0];
delete tape;
return Status::OK();
}
// Computes
// ignored, y = IdentityN(inputs[0], inputs[1])
// return grad(y, {inputs[0], inputs[1]})
@ -401,6 +433,50 @@ TEST_P(CppGradients, TestExpGrad) {
result_tensor = nullptr;
}
TEST_P(CppGradients, TestSqrtGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// y = sqrt(x)
// outputs = tape.gradient(y, x)
std::vector<AbstractTensorHandle*> outputs(1);
s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_NEAR(*result_value, 0.5, 0.001);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
TEST_P(CppGradients, TestIdentityNGrad) {
// Pseudo-code:
//

View File

@ -29,8 +29,25 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
class EagerExecutor;
// LINT.IfChange
// Note: Keep in sync with exported copy of enum in eager/c_api.h.
enum ContextDevicePlacementPolicy {
// Running operations with input tensors on the wrong device will fail.
DEVICE_PLACEMENT_EXPLICIT = 0,
// Copy the tensor to the right device but log a warning.
DEVICE_PLACEMENT_WARN = 1,
// Silently copy the tensor, which has a performance cost since the operation
// will be blocked till the copy completes. This is the default policy.
DEVICE_PLACEMENT_SILENT = 2,
// Placement policy which silently copies int32 tensors but not other dtypes.
DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
};
// LINT.ThenChange(//tensorflow/c/eager/c_api.h)
// Abstract interface to a context.
//
@ -81,14 +98,6 @@ class ImmediateExecutionContext : public AbstractContext {
// List attributes of available devices
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
virtual void ClearCachesAndThreadExecutors() = 0;
// Initialize the step resource container for a training step. This is used
// in current TF runtime. For tfrt, it is used by fallback op handler.
virtual void StartStep() = 0;
// Destroy the step resource container for a training step.
virtual void EndStep() = 0;
// Block until all pending nodes are finished.
virtual Status AsyncWait() = 0;
@ -97,11 +106,52 @@ class ImmediateExecutionContext : public AbstractContext {
// already exists.
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
// Find and return a added function by its name.
virtual const FunctionDef* FindFunctionDef(const string& name) const = 0;
// Return the ParsedName of Host CPU device.
virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
// Configure soft device placement policy.
virtual void SetAllowSoftPlacement(bool enable) = 0;
// Configure device placement policy logging.
virtual void SetLogDevicePlacement(bool enable) = 0;
// Sets the device placement policy for the current thread.
virtual void SetThreadLocalDevicePlacementPolicy(
ContextDevicePlacementPolicy policy) = 0;
// Returns the device placement policy for the current thread.
virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0;
// For LLVM style RTTI.
static bool classof(const AbstractContext* ptr) {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
}
//===--------------------------------------------------------------------===//
// Following are legacy features in TF Eager Runtime.
// TODO(tf-runtime): Figure out a way to deprecate following features after
// migrated to TFRT.
//===--------------------------------------------------------------------===//
// Clear pending nodes in thread executors and kernel caches.
virtual void ClearCachesAndThreadExecutors() = 0;
// Initialize the step resource container for a training step. This is used
// in current TF runtime. For tfrt, it is used by fallback op handler.
virtual void StartStep() = 0;
// Destroy the step resource container for a training step.
virtual void EndStep() = 0;
// Return the Eager Executor for current thread. Please note that Eager
// Executor is only used in current TF but not in TFRT.
virtual EagerExecutor& Executor() = 0;
// Update the Eager Executor for current thread.
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
// Configure graph collection in RunMetadata.
virtual void SetShouldStoreGraphs(bool value) = 0;
protected:
explicit ImmediateExecutionContext(AbstractContextKind kind)
: AbstractContext(kind) {}

View File

@ -25,133 +25,18 @@ limitations under the License.
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
// ========================== Tape Ops ==============================
namespace tensorflow {
namespace gradients {
namespace internal {
using std::vector;
using tensorflow::tracing::TracingOperation;
// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractOperationPtr add_op(ctx->CreateOperation());
ForwardOperation forward_op;
TF_RETURN_IF_ERROR(
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(add_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
}
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
int num_retvals = 1;
return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
Status MatMul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
bool transpose_a, bool transpose_b,
const GradientRegistry& registry) {
AbstractOperationPtr matmul_op(ctx->CreateOperation());
ForwardOperation forward_op;
TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul",
/*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(matmul_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(matmul_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op));
TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
matmul_op.get(), "transpose_a", transpose_a, &forward_op));
TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
matmul_op.get(), "transpose_b", transpose_b, &forward_op));
int num_retvals = 1;
return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
Status Mul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry) {
AbstractOperationPtr mul_op(ctx->CreateOperation());
ForwardOperation forward_op;
TF_RETURN_IF_ERROR(
Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(mul_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op));
int num_retvals = 1;
return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `Relu(inputs[0])` and records it on the tape.
Status Relu(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry) {
AbstractOperationPtr relu_op(ctx->CreateOperation());
ForwardOperation forward_op;
TF_RETURN_IF_ERROR(
Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(relu_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(relu_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op));
int num_retvals = 1;
return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `SoftmaxLoss(scores, labels)` where labels are categorical (not
// one-hot) and records it on the tape.
Status SparseSoftmaxCrossEntropyWithLogits(
AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry) {
AbstractTensorHandle* scores = inputs[0];
AbstractTensorHandle* labels = inputs[1];
AbstractOperationPtr sm_op(ctx->CreateOperation());
ForwardOperation forward_op;
TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits",
/*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(sm_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(sm_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op));
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op));
int num_retvals = 2; // returns loss values and backprop
return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
//===================== Test Models to run =========================
@ -167,8 +52,9 @@ Status AddGradModel(AbstractContext* ctx,
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
std::vector<AbstractTensorHandle*> add_outputs(1);
TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
registry)); // Compute x+y.
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(
ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
@ -200,9 +86,11 @@ Status MatMulGradModel(AbstractContext* ctx,
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
vector<AbstractTensorHandle*> mm_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs),
"matmul0", /*transpose_a=*/false,
/*transpose_b=*/false, registry)); // Compute x*y.
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs,
absl::MakeSpan(mm_outputs), "matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false)); // Compute x*y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
@ -256,25 +144,27 @@ Status MNISTForwardModel(AbstractContext* ctx,
tape->Watch(ToId(W2)); // Watch W2.
vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
"matmul0", /*transpose_a=*/false,
/*transpose_b=*/false, registry)); // Compute X*W1
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
absl::MakeSpan(temp_outputs), "matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false)); // Compute X*W1
TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]},
absl::MakeSpan(temp_outputs), "relu",
registry)); // Compute Relu(X*W1)
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {temp_outputs[0]},
absl::MakeSpan(temp_outputs),
"relu")); // Compute Relu(X*W1)
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2},
absl::MakeSpan(temp_outputs), "matmul1",
/*transpose_a=*/false, /*transpose_b=*/false,
registry)); // Compute W2*Relu(X*W1)
TF_RETURN_IF_ERROR(ops::MatMul(
tape_ctx.get(), {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs),
"matmul1",
/*transpose_a=*/false, /*transpose_b=*/false)); // Compute W2*Relu(X*W1)
AbstractTensorHandle* scores = temp_outputs[0];
temp_outputs.resize(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmax_loss", registry)); // Compute Softmax(Scores,labels)
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmax_loss")); // Compute Softmax(Scores,labels)
AbstractTensorHandle* loss_vals = temp_outputs[0];
@ -297,9 +187,11 @@ Status MatMulTransposeModel(AbstractContext* ctx,
tape->Watch(ToId(W1));
vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
"matmul0", /*transpose_a=*/true,
/*transpose_b=*/false, registry)); // Compute X*W1
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
absl::MakeSpan(temp_outputs), "matmul0",
/*transpose_a=*/true,
/*transpose_b=*/false)); // Compute X*W1
outputs[0] = temp_outputs[0];
@ -315,8 +207,10 @@ Status ReluGradModel(AbstractContext* ctx,
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch X
vector<AbstractTensorHandle*> relu_outputs(1);
TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs),
"relu0", registry)); // Relu(X)
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
absl::MakeSpan(relu_outputs),
"relu0")); // Relu(X)
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
@ -346,8 +240,9 @@ Status SoftmaxLossGradModel(AbstractContext* ctx,
tape->Watch(ToId(inputs[0])); // Watch scores.
tape->Watch(ToId(inputs[1])); // Watch labels.
vector<AbstractTensorHandle*> sm_outputs(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry));
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
@ -381,29 +276,30 @@ Status MNISTGradModel(AbstractContext* ctx,
tape->Watch(ToId(W1)); // Watch W1.
tape->Watch(ToId(W2)); // Watch W1.
vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
"matmul0", /*transpose_a=*/false,
/*transpose_b=*/false, registry)); // Compute X*W1
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
absl::MakeSpan(temp_outputs), "matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false)); // Compute X*W1
AbstractTensorHandle* mm = temp_outputs[0];
TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm},
absl::MakeSpan(temp_outputs), // Relu(X*W1)
"relu0", registry));
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm},
absl::MakeSpan(temp_outputs), // Relu(X*W1)
"relu0"));
AbstractTensorHandle* hidden = temp_outputs[0];
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2},
absl::MakeSpan(temp_outputs), "matmul1",
/*transpose_a=*/false, /*transpose_b=*/false,
registry)); // W2*Relu(X*W1)
TF_RETURN_IF_ERROR(ops::MatMul(
tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1",
/*transpose_a=*/false, /*transpose_b=*/false)); // W2*Relu(X*W1)
AbstractTensorHandle* scores = temp_outputs[0];
temp_outputs.resize(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmaxloss", registry)); // W2*Relu(X*W1)
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmaxloss")); // W2*Relu(X*W1)
AbstractTensorHandle* loss = temp_outputs[0];
@ -440,8 +336,10 @@ Status ScalarMulModel(AbstractContext* ctx,
auto tape = new Tape(/*persistent=*/false);
vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs),
"scalarMul0", registry)); // Compute eta*A
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {eta, A},
absl::MakeSpan(temp_outputs),
"scalarMul0")); // Compute eta*A
outputs[0] = temp_outputs[0];
@ -459,9 +357,11 @@ Status MatMulModel(AbstractContext* ctx,
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
std::vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
"matmul0", /*transpose_a=*/false,
/*transpose_b=*/false, registry)); // Compute X*W1
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
absl::MakeSpan(temp_outputs), "matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false)); // Compute X*W1
outputs[0] = temp_outputs[0];
delete tape;
@ -478,8 +378,10 @@ Status MulModel(AbstractContext* ctx,
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
std::vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(Mul(ctx, tape, {x, y}, absl::MakeSpan(temp_outputs),
"mul0", registry)); // Compute x*y
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {x, y},
absl::MakeSpan(temp_outputs),
"mul0")); // Compute x*y
outputs[0] = temp_outputs[0];
delete tape;
@ -496,9 +398,9 @@ Status SoftmaxModel(AbstractContext* ctx,
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
std::vector<AbstractTensorHandle*> temp_outputs(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
ctx, tape, {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss",
registry));
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
tape_ctx.get(), {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss"));
outputs[0] = temp_outputs[0]; // loss values

View File

@ -29,45 +29,10 @@ limitations under the License.
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/status.h"
// ========================== Tape Ops ==============================
namespace tensorflow {
namespace gradients {
namespace internal {
// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
Status MatMul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
bool transpose_a, bool transpose_b,
const GradientRegistry& registry);
// Computes `inputs[0] * inputs[1]` and records it on the tape.
Status Mul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry);
// Computes `Relu(inputs[0])` and records it on the tape.
Status Relu(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry);
// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
// tape.
Status SparseSoftmaxCrossEntropyWithLogits(
AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry);
// ====================== End Tape Ops ============================
// Computes
// y = inputs[0] + inputs[1]

View File

@ -100,7 +100,8 @@ class VSpace {
std::vector<Gradient*>* result) const = 0;
// Builds a tensor filled with ones with the same shape and dtype as `t`.
virtual Status BuildOnesLike(TapeTensor t, Gradient** result) const = 0;
virtual Status BuildOnesLike(const TapeTensor& t,
Gradient** result) const = 0;
// Looks up the ID of a Gradient.
virtual int64 TensorId(Gradient* tensor) const = 0;

View File

@ -29,6 +29,7 @@ cc_library(
}),
deps = [
"//tensorflow/c:env",
"//tensorflow/c:logging",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"//third_party/hadoop:hdfs",

View File

@ -22,11 +22,10 @@ limitations under the License.
#include <sstream>
#include <string>
#include "absl/synchronization/mutex.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/logging.h"
#include "tensorflow/c/tf_status.h"
#include "third_party/hadoop/hdfs.h"
// Implementation of a filesystem for HADOOP environments.
// This filesystem will support `hdfs://`, `viewfs://` and `har://` URI schemes.
@ -149,15 +148,20 @@ class LibHDFS {
char* hdfs_home = getenv("HADOOP_HDFS_HOME");
if (hdfs_home != nullptr) {
auto JoinPath = [](std::string home, std::string lib) {
#if defined(_WIN32)
if (home.back() != '\\') home.push_back('\\');
return home + "lib\\native\\" + lib;
#else
if (home.back() != '/') home.push_back('/');
return home + "lib/native/" + lib;
#endif
};
std::string path = JoinPath(hdfs_home, kLibHdfsDso);
TryLoadAndBind(path.c_str(), &handle_, status);
if (TF_GetCode(status) == TF_OK) {
return;
} else {
std::cerr << "HadoopFileSystem load error: " << TF_Message(status);
TF_Log(TF_FATAL, "HadoopFileSystem load error: %s", TF_Message(status));
}
}
@ -169,13 +173,15 @@ class LibHDFS {
void* handle_;
};
// We rely on HDFS connection caching here. The HDFS client calls
// org.apache.hadoop.fs.FileSystem.get(), which caches the connection
// internally.
hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) {
// We implement connection caching in Tensorflow, which can significantly
// improve performance. Fixes #43187
hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file,
const std::string& path, TF_Status* status) {
auto libhdfs = hadoop_file->libhdfs;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
std::string cacheKey(scheme);
hdfsBuilder* builder = libhdfs->hdfsNewBuilder();
if (scheme == "file") {
libhdfs->hdfsBuilderSetNameNode(builder, nullptr);
@ -200,15 +206,24 @@ hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) {
SplitArchiveNameAndPath(&path_har, &namenode, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str());
cacheKey += namenode;
} else {
libhdfs->hdfsBuilderSetNameNode(
builder, namenode.empty() ? "default" : namenode.c_str());
cacheKey += namenode;
}
auto fs = libhdfs->hdfsBuilderConnect(builder);
if (fs == nullptr)
TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno));
else
TF_SetStatus(status, TF_OK, "");
absl::MutexLock l(&hadoop_file->connection_cache_lock);
if (hadoop_file->connection_cache.find(cacheKey) ==
hadoop_file->connection_cache.end()) {
auto cacheFs = libhdfs->hdfsBuilderConnect(builder);
if (cacheFs == nullptr) {
TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno));
return cacheFs;
}
hadoop_file->connection_cache[cacheKey] = cacheFs;
}
auto fs = hadoop_file->connection_cache[cacheKey];
TF_SetStatus(status, TF_OK, "");
return fs;
}
@ -409,30 +424,36 @@ void Close(const TF_WritableFile* file, TF_Status* status) {
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
// ----------------------------------------------------------------------------
namespace tf_read_only_memory_region {
// TODO(vnvo2409): Implement later
// Hadoop doesn't support Readonly Memory Region
} // namespace tf_read_only_memory_region
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_hadoop_filesystem {
HadoopFile::HadoopFile(TF_Status* status)
: libhdfs(new LibHDFS(status)),
connection_cache_lock(),
connection_cache() {}
void Init(TF_Filesystem* filesystem, TF_Status* status) {
filesystem->plugin_filesystem = new LibHDFS(status);
filesystem->plugin_filesystem = new HadoopFile(status);
if (TF_GetCode(status) != TF_OK) return;
TF_SetStatus(status, TF_OK, "");
}
void Cleanup(TF_Filesystem* filesystem) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
delete libhdfs;
delete hadoop_file;
}
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
TF_RandomAccessFile* file, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
@ -448,8 +469,9 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
@ -465,8 +487,9 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
@ -497,8 +520,9 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
@ -513,8 +537,9 @@ void PathExists(const TF_Filesystem* filesystem, const char* path,
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
@ -532,8 +557,9 @@ void Stat(const TF_Filesystem* filesystem, const char* path,
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return -1;
std::string scheme, namenode, hdfs_path;
@ -553,8 +579,9 @@ int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
@ -568,8 +595,9 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path,
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
@ -583,8 +611,9 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path,
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
@ -619,8 +648,9 @@ void DeleteDir(const TF_Filesystem* filesystem, const char* path,
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, src, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, src, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path_src, hdfs_path_dst;
@ -640,8 +670,9 @@ void RenameFile(const TF_Filesystem* filesystem, const char* src,
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
auto hadoop_file = static_cast<HadoopFile*>(filesystem->plugin_filesystem);
auto libhdfs = hadoop_file->libhdfs;
auto fs = Connect(hadoop_file, path, status);
if (TF_GetCode(status) != TF_OK) return -1;
std::string scheme, namenode, hdfs_path;
@ -677,7 +708,9 @@ int GetChildren(const TF_Filesystem* filesystem, const char* path,
return num_entries;
}
// TODO(vnvo2409): Implement later
static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
return strdup(uri);
}
} // namespace tf_hadoop_filesystem
@ -685,6 +718,42 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
ops->random_access_file_ops->read = tf_random_access_file::Read;
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->writable_file_ops->append = tf_writable_file::Append;
ops->writable_file_ops->tell = tf_writable_file::Tell;
ops->writable_file_ops->flush = tf_writable_file::Flush;
ops->writable_file_ops->sync = tf_writable_file::Sync;
ops->writable_file_ops->close = tf_writable_file::Close;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_hadoop_filesystem::Init;
ops->filesystem_ops->cleanup = tf_hadoop_filesystem::Cleanup;
ops->filesystem_ops->new_random_access_file =
tf_hadoop_filesystem::NewRandomAccessFile;
ops->filesystem_ops->new_writable_file =
tf_hadoop_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_hadoop_filesystem::NewAppendableFile;
ops->filesystem_ops->new_read_only_memory_region_from_file =
tf_hadoop_filesystem::NewReadOnlyMemoryRegionFromFile;
ops->filesystem_ops->path_exists = tf_hadoop_filesystem::PathExists;
ops->filesystem_ops->stat = tf_hadoop_filesystem::Stat;
ops->filesystem_ops->get_file_size = tf_hadoop_filesystem::GetFileSize;
ops->filesystem_ops->delete_file = tf_hadoop_filesystem::DeleteFile;
ops->filesystem_ops->create_dir = tf_hadoop_filesystem::CreateDir;
ops->filesystem_ops->delete_dir = tf_hadoop_filesystem::DeleteDir;
ops->filesystem_ops->rename_file = tf_hadoop_filesystem::RenameFile;
ops->filesystem_ops->get_children = tf_hadoop_filesystem::GetChildren;
ops->filesystem_ops->translate_name = tf_hadoop_filesystem::TranslateName;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {

View File

@ -15,10 +15,13 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_
#include <map>
#include <string>
#include "absl/synchronization/mutex.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
#include "third_party/hadoop/hdfs.h"
void ParseHadoopPath(const std::string& fname, std::string* scheme,
std::string* namenode, std::string* path);
@ -43,6 +46,14 @@ void Close(const TF_WritableFile* file, TF_Status* status);
} // namespace tf_writable_file
namespace tf_hadoop_filesystem {
typedef struct HadoopFile {
LibHDFS* libhdfs;
absl::Mutex connection_cache_lock;
std::map<std::string, hdfsFS> connection_cache
ABSL_GUARDED_BY(connection_cache_lock);
HadoopFile(TF_Status* status);
} HadoopFile;
void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,

View File

@ -352,6 +352,48 @@ TEST_F(HadoopFileSystemTest, WriteWhileReading) {
EXPECT_TF_OK(status_);
}
TEST_F(HadoopFileSystemTest, ReadWhileOverwriting) {
static char set_disable_var[] = "HDFS_DISABLE_READ_EOF_RETRIED=1";
putenv(set_disable_var);
const std::string path = TmpDir("ReadWhileOverwriting");
if (path.find_first_of("hdfs://") != 0) GTEST_SKIP();
const string content1 = "content1";
WriteString(path, content1);
ASSERT_TF_OK(status_);
auto reader = GetReader();
tf_hadoop_filesystem::NewRandomAccessFile(filesystem_, path.c_str(),
reader.get(), status_);
EXPECT_TF_OK(status_);
std::string result;
result.resize(content1.size());
auto read = tf_random_access_file::Read(reader.get(), 0, content1.size(),
&result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(content1, result);
tf_hadoop_filesystem::DeleteFile(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
string content2 = "overwrite";
WriteString(path, content1 + content2);
ASSERT_TF_OK(status_);
result.resize(content2.size());
read = tf_random_access_file::Read(reader.get(), content1.size(),
content2.size(), &result[0], status_);
result.resize(read);
EXPECT_TF_OK(status_);
EXPECT_EQ(0, result.size());
static char set_enable_var[] = "HDFS_DISABLE_READ_EOF_RETRIED=0";
putenv(set_enable_var);
}
TEST_F(HadoopFileSystemTest, HarSplit) {
const std::string har_path =
"har://hdfs-root/user/j.doe/my_archive.har/dir0/dir1/file.txt";

View File

@ -24,6 +24,7 @@ using std::vector;
using tensorflow::ops::Conj;
using tensorflow::ops::MatMul;
using tensorflow::ops::Mul;
using tensorflow::ops::SqrtGrad;
namespace tensorflow {
namespace gradients {
@ -72,6 +73,25 @@ class ExpGradientFunction : public GradientFunction {
AbstractTensorHandlePtr exp_;
};
class SqrtGradientFunction : public GradientFunction {
public:
explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) {
sqrt->Ref();
}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
std::string name = "Sqrt_Grad";
grad_outputs->resize(1);
TF_RETURN_IF_ERROR(SqrtGrad(ctx->ctx, {sqrt_.get(), grad_inputs[0]},
absl::MakeSpan(*grad_outputs), name.c_str()));
return Status::OK();
}
~SqrtGradientFunction() override {}
private:
AbstractTensorHandlePtr sqrt_;
};
class MatMulGradientFunction : public GradientFunction {
public:
explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
@ -210,5 +230,14 @@ BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
return new BackwardFunction(gradient_function, default_gradients);
}
BackwardFunction* SqrtRegisterer(const ForwardOperation& op) {
auto gradient_function = new SqrtGradientFunction(op.outputs[0]);
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
} // namespace gradients
} // namespace tensorflow

View File

@ -19,10 +19,13 @@ limitations under the License.
namespace tensorflow {
namespace gradients {
BackwardFunction* AddRegisterer(const ForwardOperation& op);
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_

View File

@ -38,3 +38,29 @@ cc_library(
"//tensorflow/c/eager:gradients_internal",
],
)
cc_library(
name = "tape",
hdrs = [
"tape_context.h",
"tape_operation.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":tape_context",
":tape_operation",
],
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"tape_context.h",
"tape_operation.h",
],
visibility = [
"//tensorflow:internal",
],
)

View File

@ -144,5 +144,33 @@ Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
return exp_op->Execute(outputs, &num_retvals);
}
Status Sqrt(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr sqrt_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(sqrt_op->Reset("Sqrt", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_op.get(), name));
TF_RETURN_IF_ERROR(sqrt_op->AddInput(inputs[0]));
int num_retvals = 1;
Status s = sqrt_op->Execute(outputs, &num_retvals);
return s;
}
Status SqrtGrad(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr sqrt_grad_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
sqrt_grad_op->Reset("SqrtGrad", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_grad_op.get(), name));
TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[1]));
int num_retvals = 1;
Status s = sqrt_grad_op->Execute(outputs, &num_retvals);
return s;
}
} // namespace ops
} // namespace tensorflow

View File

@ -50,6 +50,15 @@ Status DivNoNan(AbstractContext* ctx,
Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Sqrt(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status SqrtGrad(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow

View File

@ -91,15 +91,24 @@ cc_library(
":signature_def_function_metadata",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "signature_def_function_metadata",
srcs = [
"signature_def_function_metadata.cc",
],
hdrs = [
"signature_def_function_metadata.h",
],
deps = [
":tensor_spec",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
@ -268,6 +277,20 @@ tf_cc_test(
],
)
cc_library(
name = "tensor_spec",
srcs = [
"tensor_spec.cc",
],
hdrs = [
"tensor_spec.h",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
],
)
tf_cc_test(
name = "tf_concrete_function_loading_test",
srcs = [

View File

@ -92,6 +92,8 @@ cc_library(
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
"//tensorflow/c/experimental/saved_model/core:tensor_spec",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti",
@ -164,6 +166,8 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include "absl/types/span.h"
@ -30,14 +32,26 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
namespace {
using StructuredValueDictEntry =
protobuf::MapPair<std::string, StructuredValue>;
using NamedParamMap =
gtl::FlatMap<StringPiece, const TensorSpecProto*, StringPieceHasher>;
Status AssertAllCreateResourceFunctionsHaveNoCaptures(
const PartiallyRevivedObjects& objects) {
for (const auto& id_and_resource : objects.restored_resources) {
@ -124,6 +138,142 @@ Status TensorHandleFromNode(int node_id, const SavedObjectGraph& obj_graph,
}
}
std::vector<SignatureDefParam> SignatureDefParamsFromNamedParamMap(
const NamedParamMap& params) {
// The underlying functiondef associated with the SignatureDef has
// nest.flattened inputs and outputs, which are sorted by string key.
std::vector<SignatureDefParam> result;
result.reserve(params.size());
for (const auto& named_param : params) {
result.push_back(SignatureDefParam(std::string(named_param.first),
TensorSpec(*named_param.second)));
}
std::sort(result.begin(), result.end(),
[](const SignatureDefParam& x, const SignatureDefParam& y) {
return x.name() < y.name();
});
return result;
}
// SignatureDefArgsFromInputs takes the "canonicalized_input_signature"
// field of a SavedConcreteFunction, ensures it conforms to the structure of
// tuple(tuple(), dict<string,TensorSpec>()), and "returns" a list of
// SignatureDefParams of the SignatureDefFunction's arguments.
Status SignatureDefArgsFromInputs(
const StructuredValue& canonicalized_input_signature,
std::vector<SignatureDefParam>* out) {
// Note(bmzhao): canonicalized_input_signature should be a tuple of
// (args, kwargs), where args is an empty tuple, and kwargs is a dictionary of
// string keys to TensorSpecs.
if (!canonicalized_input_signature.has_tuple_value()) {
return errors::FailedPrecondition(
"SignatureDefFunction's canonicalized_input_signature should be "
"of form tuple(tuple(), dict()), but was instead: \n",
canonicalized_input_signature.DebugString());
}
const TupleValue& args_kwargs_tuple =
canonicalized_input_signature.tuple_value();
if (args_kwargs_tuple.values_size() != 2) {
return errors::FailedPrecondition(
"SignatureDefFunction's canonicalized_input_signature should be "
"a tuple of two elements (args, kwargs), but was instead: \n",
args_kwargs_tuple.DebugString());
}
const StructuredValue& args = args_kwargs_tuple.values(0);
if (!args.has_tuple_value() || !args.tuple_value().values().empty()) {
return errors::FailedPrecondition(
"SignatureDefFunction's canonicalized_input_signature's args"
"should be an empty tuple, but instead got: \n",
args.DebugString());
}
const StructuredValue& kwargs = args_kwargs_tuple.values(1);
if (!kwargs.has_dict_value()) {
return errors::FailedPrecondition(
"SignatureDefFunction's canonicalized_input_signature's kwargs"
"should be a dictionary, but instead got: \n",
kwargs.DebugString());
}
const DictValue& kwargs_dict = kwargs.dict_value();
NamedParamMap result;
result.reserve(kwargs_dict.fields_size());
for (const auto& key_value : kwargs_dict.fields()) {
const std::string& key = key_value.first;
const StructuredValue& value = key_value.second;
if (!value.has_tensor_spec_value()) {
return errors::FailedPrecondition(
"SignatureDefFunction's canonicalized_input_signature's kwargs"
"dictionary contained a non-tensorspec value for key-value pair: \n",
"Key: ", key, "Value: \n", value.DebugString());
}
result[key] = &value.tensor_spec_value();
}
*out = SignatureDefParamsFromNamedParamMap(result);
return Status();
}
// SignatureDefReturnsFromOutputs takes the "output_signature" field of a
// SavedConcreteFunction, ensures it conforms to the structure of
// dict<string,TensorSpec>(), and "returns" a list of SignatureDefParams of the
// SignatureDefFunction's returns.
Status SignatureDefReturnsFromOutputs(const StructuredValue& output_signature,
std::vector<SignatureDefParam>* out) {
if (!output_signature.has_dict_value()) {
return errors::FailedPrecondition(
"SignatureDefFunction's output_signature must be a dictionary, but "
"instead got: ",
output_signature.DebugString());
}
const DictValue& output_dict = output_signature.dict_value();
NamedParamMap result;
result.reserve(output_dict.fields_size());
for (const auto& key_value : output_dict.fields()) {
const std::string& key = key_value.first;
const StructuredValue& value = key_value.second;
if (!value.has_tensor_spec_value()) {
return errors::FailedPrecondition(
"SignatureDefFunction's output_signature dictionary contained a "
"non-tensorspec value for key-value pair: \n",
"Key: ", key, "Value: \n", value.DebugString());
}
result[key] = &value.tensor_spec_value();
}
*out = SignatureDefParamsFromNamedParamMap(result);
return Status();
}
// The implementation takes advantage of the fact that SignatureDefFunction's
// "traced" Signature wrapper function always has inputs/outputs of dictionaries
// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L119-L126
// https://github.com/tensorflow/tensorflow/blob/53cdd5e87c423b195f33775753273286fd5a1a65/tensorflow/python/saved_model/signature_serialization.py#L153-L178
// Additionally, we take advantage of the fact that the SignatureDefFunction's
// associated functiondef has lexicographically ordered inputs/outputs due to
// nest.flatten.
Status LoadSignatureDefFunctionMetadata(
const SavedConcreteFunction& saved_concrete_function,
SignatureDefFunctionMetadata* out) {
std::vector<SignatureDefParam> args;
TF_RETURN_IF_ERROR(SignatureDefArgsFromInputs(
saved_concrete_function.canonicalized_input_signature(), &args));
std::vector<SignatureDefParam> rets;
TF_RETURN_IF_ERROR(SignatureDefReturnsFromOutputs(
saved_concrete_function.output_signature(), &rets));
*out = SignatureDefFunctionMetadata(std::move(args), std::move(rets));
return Status();
}
// This function finds the necessary captures, then forwards to the builder
// method
Status CreateConcreteFunction(ImmediateExecutionContext* ctx,
@ -162,10 +312,14 @@ Status CreateSignatureDefFunction(
&capture_handle));
captures.push_back(capture_handle);
}
// TODO(bmzhao): Create Metadata here
SignatureDefFunctionMetadata metadata;
TF_RETURN_IF_ERROR(LoadSignatureDefFunctionMetadata(
*builder.saved_concrete_func, &metadata));
return TFSignatureDefFunction::Create(/*function_def=*/builder.fdef,
/*captures=*/std::move(captures),
/*metadata=*/{},
/*metadata=*/std::move(metadata),
/*ctx=*/ctx,
/*out=*/out);
}
@ -378,6 +532,7 @@ Status PartiallyRevivedObjects::Build(ImmediateExecutionContext* ctx,
revived->variables = std::move(variables);
revived->assets = std::move(assets);
revived->constants = std::move(constants);
revived->signatures_map = std::move(signatures_map);
// 3b. Move over resources.
TF_RETURN_IF_ERROR(BuildResources(ctx, obj_graph, this, revived));

View File

@ -36,7 +36,14 @@ namespace tensorflow {
// Notably, resources and functions can be in a state where they reference
// other resources/functions that have not been constructed yet. We collect
// *all* objects in a partially valid state here, then properly initialize
// resources and functions.
// resources and functions. Implementation-wise, PartiallyRevivedObjects
// contains maps keyed by the node number of the SavedObjectGraph, and map to an
// object of the corresponding type. So, if node 2 in the object graph is a
// variable, PartiallyRevivedObjects.variables[2] exists, and corresponds to a
// tensorflow::Variable object. The only exception to this is the
// "signatures_map", which is keyed by the "signature" key
// (https://github.com/tensorflow/tensorflow/blob/372918decee7f558b3c194b04f77c20dcc679a31/tensorflow/core/protobuf/meta_graph.proto#L89),
// and maps to the SignatureDefFunction node in the SavedObjectGraph.
struct PartiallyRevivedObjects {
gtl::FlatMap<int, std::unique_ptr<Variable>> variables;
gtl::FlatMap<int, std::unique_ptr<Asset>> assets;
@ -44,6 +51,7 @@ struct PartiallyRevivedObjects {
gtl::FlatMap<int, TFConcreteFunctionRevivalState> concrete_functions;
gtl::FlatMap<int, TFSignatureDefFunctionRevivalState> signature_def_functions;
gtl::FlatMap<int, RestoredResourceRevivalState> restored_resources;
gtl::FlatMap<std::string, int> signatures_map;
Status Build(ImmediateExecutionContext* ctx,
const SavedObjectGraph& obj_graph, RevivedObjects* revived);

View File

@ -44,6 +44,7 @@ struct RevivedObjects {
gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>>
signature_def_functions;
gtl::FlatMap<int, RestoredResource> restored_resources;
gtl::FlatMap<std::string, int> signatures_map;
};
} // namespace tensorflow

View File

@ -20,8 +20,10 @@ limitations under the License.
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
@ -62,15 +64,53 @@ Status Variable::ReadValue(ImmediateTensorHandlePtr* out) {
return internal::ReadVariable(ctx_, handle_.get(), dtype_, out);
}
Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
absl::optional<std::string> name,
const char* raw_device_name,
std::unique_ptr<Variable>* output) {
Status Variable::CreateUninitialized(
ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
absl::optional<std::string> name, const char* raw_device_name,
const std::vector<std::string>& component_devices,
std::unique_ptr<Variable>* output) {
ImmediateTensorHandlePtr handle;
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
ctx, dtype, shape, raw_device_name, &handle));
if (component_devices.empty()) {
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
ctx, dtype, shape, raw_device_name, &handle));
output->reset(
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
return Status();
}
if (!tensorflow::isa<EagerContext>(ctx)) {
return errors::InvalidArgument(
"Can only load distributed variables with EagerContext.");
}
EagerContext* eager_ctx = reinterpret_cast<EagerContext*>(ctx);
std::vector<TensorHandle*> handles;
for (const auto& device : component_devices) {
ImmediateTensorHandlePtr handlePtr;
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
ctx, dtype, shape, device.empty() ? nullptr : device.c_str(),
&handlePtr));
if (!tensorflow::isa<TensorHandle>(handlePtr.get())) {
return errors::Internal("Returned replica handle has unsupported type.");
}
handles.push_back(reinterpret_cast<TensorHandle*>(handlePtr.release()));
}
TensorHandle* packed_handle;
TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
std::move(handles), eager_ctx, &packed_handle));
// The call to `CreatePackedHandle` incremented the handles' reference count,
// which we must now decrement to make the packed handle the owner of those
// handles. We can't loop through the `handles` vector because it was
// `std::move`d in the call above.
for (int i = 0; i != packed_handle->NumPackedHandles(); ++i) {
TensorHandle* component;
TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &component));
component->Unref();
}
handle.reset(packed_handle);
output->reset(
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
return Status();

View File

@ -34,11 +34,11 @@ class Variable : public TensorHandleConvertible {
public:
// Creates an uninitialized resource variable. Note that a caller must
// call "assign" to associate a value with the variable.
static Status CreateUninitialized(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
absl::optional<std::string> name,
const char* raw_device_name,
std::unique_ptr<Variable>* output);
static Status CreateUninitialized(
ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
absl::optional<std::string> name, const char* raw_device_name,
const std::vector<std::string>& component_devices,
std::unique_ptr<Variable>* output);
// The dtype of the underlying variable.
DataType dtype();

View File

@ -235,10 +235,17 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx,
const std::string& name = variable.name();
tensorflow::TensorShape shape(variable.shape());
tensorflow::DataType dtype = variable.dtype();
std::vector<std::string> component_devices;
for (const auto& component :
variable.experimental_distributed_variable_components()) {
component_devices.push_back(component.device());
}
TF_RETURN_IF_ERROR(Variable::CreateUninitialized(
ctx, dtype, shape, name,
variable.device().empty() ? nullptr : variable.device().c_str(), output));
variable.device().empty() ? nullptr : variable.device().c_str(),
component_devices, output));
return Status();
}
@ -519,6 +526,8 @@ Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph,
}
}
objects->signatures_map = std::move(signatures_map);
return Status();
}

View File

@ -119,7 +119,7 @@ TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) {
Status status;
std::unique_ptr<Variable> var;
TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
absl::nullopt, nullptr, &var));
absl::nullopt, nullptr, {}, &var));
// Create a TensorHandle
ImmediateTensorHandlePtr expected_handle =

View File

@ -0,0 +1,42 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
namespace tensorflow {
SignatureDefParam::SignatureDefParam(std::string name, TensorSpec spec)
: name_(std::move(name)), spec_(std::move(spec)) {}
const std::string& SignatureDefParam::name() const { return name_; }
const TensorSpec& SignatureDefParam::spec() const { return spec_; }
SignatureDefFunctionMetadata::SignatureDefFunctionMetadata(
std::vector<SignatureDefParam> arguments,
std::vector<SignatureDefParam> returns)
: arguments_(std::move(arguments)), returns_(std::move(returns)) {}
const std::vector<SignatureDefParam>& SignatureDefFunctionMetadata::arguments()
const {
return arguments_;
}
const std::vector<SignatureDefParam>& SignatureDefFunctionMetadata::returns()
const {
return returns_;
}
} // namespace tensorflow

View File

@ -16,10 +16,42 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_
#include <string>
#include <vector>
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
// SignatureDefParam represents a named Tensor input or output to a
// SignatureDefFunction.
class SignatureDefParam {
public:
SignatureDefParam(std::string name, TensorSpec spec);
const std::string& name() const;
const TensorSpec& spec() const;
private:
std::string name_;
TensorSpec spec_;
};
class SignatureDefFunctionMetadata {
// TODO(bmzhao): Fill in with fields as necessary
public:
SignatureDefFunctionMetadata() = default;
SignatureDefFunctionMetadata(std::vector<SignatureDefParam> arguments,
std::vector<SignatureDefParam> returns);
const std::vector<SignatureDefParam>& arguments() const;
const std::vector<SignatureDefParam>& returns() const;
private:
std::vector<SignatureDefParam> arguments_;
std::vector<SignatureDefParam> returns_;
};
} // namespace tensorflow

View File

@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
#include <initializer_list>
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
namespace tensorflow {
TensorSpec::TensorSpec()
: shape_(std::initializer_list<int64>()), dtype_(DT_FLOAT) {}
TensorSpec::TensorSpec(PartialTensorShape shape, DataType dtype)
: shape_(std::move(shape)), dtype_(dtype) {}
TensorSpec::TensorSpec(const TensorSpecProto& proto)
: shape_(proto.shape()), dtype_(proto.dtype()) {}
const PartialTensorShape& TensorSpec::shape() const { return shape_; }
DataType TensorSpec::dtype() const { return dtype_; }
} // namespace tensorflow

View File

@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
// Note(bmzhao): TensorSpec deliberately does not store the "name" from a
// TensorSpecProto. From edloper@, "Names should really be associated with
// parameters, not the tensors inside those parameters. This would be
// inconsistent with the corresponding Python class, but I don't think that's
// necessarily a problem. If it turns out later that we really need a name
// attribute here, we can always add it back in; but let's see how far we can
// get without it."
class TensorSpec {
public:
// Constructs a scalar, DT_FLOAT TensorSpec
TensorSpec();
TensorSpec(PartialTensorShape shape, DataType dtype);
explicit TensorSpec(const TensorSpecProto& proto);
const PartialTensorShape& shape() const;
DataType dtype() const;
private:
PartialTensorShape shape_;
DataType dtype_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_

View File

@ -192,9 +192,23 @@ Status TFSavedModelAPI::GetFunction(const std::string& function_path,
Status TFSavedModelAPI::GetSignatureDefFunction(
const std::string& signature_def_key, SignatureDefFunction** function) {
// TODO(bmzhao): Add support for retrieving a signaturedef function.
return errors::Unimplemented(
"Retrieving SignatureDef functions is unimplemented currently");
auto signatures_iter =
revived_objects_.signatures_map.find(signature_def_key);
if (signatures_iter == revived_objects_.signatures_map.end()) {
return errors::NotFound("No signature with key ", signature_def_key,
" was found");
}
int node = signatures_iter->second;
auto function_iter = revived_objects_.signature_def_functions.find(node);
if (function_iter == revived_objects_.signature_def_functions.end()) {
return errors::Internal(
"Unable to find SignatureDefFunction associated with key ",
signature_def_key, " despite key being valid.");
}
*function = function_iter->second.get();
return Status();
}
std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {

View File

@ -224,6 +224,8 @@ cc_library(
],
deps = [
":signature_def_function_metadata_type",
":signature_def_param_list",
":signature_def_param_list_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
],
@ -240,6 +242,104 @@ cc_library(
],
)
cc_library(
name = "signature_def_param",
srcs = [
"signature_def_param.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:signature_def_param.h",
],
copts = tf_copts(),
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":signature_def_param_type",
":tensor_spec",
":tensor_spec_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_shape_internal",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
],
)
cc_library(
name = "signature_def_param_type",
hdrs = [
"signature_def_param_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
],
)
cc_library(
name = "signature_def_param_list",
srcs = [
"signature_def_param_list.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:signature_def_param_list.h",
],
copts = tf_copts(),
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":signature_def_param",
":signature_def_param_list_type",
":signature_def_param_type",
"//tensorflow/c:c_api_macros",
],
)
cc_library(
name = "signature_def_param_list_type",
hdrs = [
"signature_def_param_list_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata",
],
)
cc_library(
name = "tensor_spec",
srcs = [
"tensor_spec.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:tensor_spec.h",
],
copts = tf_copts(),
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":tensor_spec_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_datatype",
"//tensorflow/c:tf_shape",
"//tensorflow/c:tf_shape_internal",
"//tensorflow/c/experimental/saved_model/core:tensor_spec",
],
)
cc_library(
name = "tensor_spec_type",
hdrs = [
"tensor_spec_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c:tf_shape_internal",
"//tensorflow/c/experimental/saved_model/core:tensor_spec",
],
)
tf_cc_test(
name = "saved_model_api_test",
size = "small",
@ -252,6 +352,8 @@ tf_cc_test(
],
deps = [
":saved_model_api_type",
"//tensorflow/c:tf_datatype",
"//tensorflow/c:tf_shape",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_tensor",
"//tensorflow/c/eager:c_api",
@ -260,6 +362,11 @@ tf_cc_test(
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_api",
"//tensorflow/c/experimental/saved_model/public:concrete_function",
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
"//tensorflow/c/experimental/saved_model/public:signature_def_function",
"//tensorflow/c/experimental/saved_model/public:signature_def_function_metadata",
"//tensorflow/c/experimental/saved_model/public:signature_def_param",
"//tensorflow/c/experimental/saved_model/public:signature_def_param_list",
"//tensorflow/c/experimental/saved_model/public:tensor_spec",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",

View File

@ -24,6 +24,13 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_shape.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/io/path.h"
@ -143,6 +150,146 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) {
TFE_DeleteContext(ctx);
}
// This tests running the "serving_default" SignatureDefFunction from the
// VarsAndArithmeticObjectGraph savedmodel. Here's what the signature_defs
// protobuf in the metagraph looks like:
// signature_def: {
// key : "serving_default"
// value: {
// inputs: {
// key : "a"
// value: {
// name : "serving_default_a:0"
// dtype: DT_FLOAT
// tensor_shape: {
// }
// }
// }
// inputs: {
// key : "b"
// value: {
// name : "serving_default_b:0"
// dtype: DT_FLOAT
// tensor_shape: {
// }
// }
// }
// outputs: {
// key : "output_0"
// value: {
// name : "StatefulPartitionedCall:0"
// dtype: DT_FLOAT
// tensor_shape: {
// }
// }
// }
// method_name: "tensorflow/serving/predict"
// }
// }
TEST_P(CSavedModelAPITest, RunsSignatureDefFunction) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
bool use_tfrt = GetParam();
if (use_tfrt) {
TFE_DeleteContextOptions(opts);
TF_DeleteStatus(status);
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
TF_SavedModel* saved_model =
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_SignatureDefFunction* serving_default =
TF_GetSavedModelSignatureDefFunction(saved_model, "serving_default",
status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_SignatureDefFunctionMetadata* metadata =
TF_SignatureDefFunctionGetMetadata(serving_default);
const TF_SignatureDefParamList* args =
TF_SignatureDefFunctionMetadataArgs(metadata);
const TF_SignatureDefParamList* returns =
TF_SignatureDefFunctionMetadataReturns(metadata);
EXPECT_EQ(TF_SignatureDefParamListSize(args), 2);
const TF_SignatureDefParam* param_a = TF_SignatureDefParamListGet(args, 0);
const TF_TensorSpec* tensor_spec_a = TF_SignatureDefParamTensorSpec(param_a);
const TF_Shape* shape_a = TF_TensorSpecShape(tensor_spec_a);
// Input "a" is a scalar, float32 tensor
EXPECT_EQ("a", std::string(TF_SignatureDefParamName(param_a)));
EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_a));
EXPECT_EQ(0, TF_ShapeDims(shape_a));
const TF_SignatureDefParam* param_b = TF_SignatureDefParamListGet(args, 1);
const TF_TensorSpec* tensor_spec_b = TF_SignatureDefParamTensorSpec(param_b);
const TF_Shape* shape_b = TF_TensorSpecShape(tensor_spec_b);
// Input "b" is a scalar, float32 tensor
EXPECT_EQ("b", std::string(TF_SignatureDefParamName(param_b)));
EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_b));
EXPECT_EQ(0, TF_ShapeDims(shape_b));
EXPECT_EQ(TF_SignatureDefParamListSize(returns), 1);
const TF_SignatureDefParam* param_out =
TF_SignatureDefParamListGet(returns, 0);
const TF_TensorSpec* tensor_spec_out =
TF_SignatureDefParamTensorSpec(param_out);
const TF_Shape* shape_out = TF_TensorSpecShape(tensor_spec_out);
// Output "output_0" is a scalar, float32 tensor
EXPECT_EQ("output_0", std::string(TF_SignatureDefParamName(param_out)));
EXPECT_EQ(TF_FLOAT, TF_TensorSpecDataType(tensor_spec_out));
EXPECT_EQ(0, TF_ShapeDims(shape_out));
std::vector<TFE_TensorHandle*> compute_fn_inputs;
TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f);
TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f);
compute_fn_inputs.push_back(input_a);
compute_fn_inputs.push_back(input_b);
TFE_Op* serving_default_op = TF_SignatureDefFunctionMakeCallOp(
serving_default, compute_fn_inputs.data(), compute_fn_inputs.size(),
status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
std::vector<TFE_TensorHandle*> compute_fn_outputs(
TF_SignatureDefParamListSize(returns));
int num_retvals = TF_SignatureDefParamListSize(returns);
TFE_Execute(serving_default_op, compute_fn_outputs.data(), &num_retvals,
status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TF_NumDims(result), 0);
float output_value = *static_cast<float*>(TF_TensorData(result));
// (1 + 2) * (2 + 1) / 3 + 5 should be 8
EXPECT_FLOAT_EQ(output_value, 8.0);
TF_DeleteTensor(result);
TFE_DeleteTensorHandle(compute_fn_outputs[0]);
TFE_DeleteTensorHandle(input_a);
TFE_DeleteTensorHandle(input_b);
TFE_DeleteOp(serving_default_op);
TF_DeleteSavedModel(saved_model);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
}
TEST_P(CSavedModelAPITest, LoadsAssetSavedModel) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();

View File

@ -16,5 +16,18 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h"
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h"
// TODO(bmzhao): Add getter functions here as necessary.
extern "C" {
extern const TF_SignatureDefParamList* TF_SignatureDefFunctionMetadataArgs(
const TF_SignatureDefFunctionMetadata* list) {
return tensorflow::wrap(&tensorflow::unwrap(list)->arguments());
}
extern const TF_SignatureDefParamList* TF_SignatureDefFunctionMetadataReturns(
const TF_SignatureDefFunctionMetadata* list) {
return tensorflow::wrap(&tensorflow::unwrap(list)->returns());
}
} // end extern "C"

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h"
#include "tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h"
extern "C" {
extern const char* TF_SignatureDefParamName(const TF_SignatureDefParam* param) {
return tensorflow::unwrap(param)->name().c_str();
}
extern const TF_TensorSpec* TF_SignatureDefParamTensorSpec(
const TF_SignatureDefParam* param) {
return tensorflow::wrap(&tensorflow::unwrap(param)->spec());
}
} // end extern "C"

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h"
extern "C" {
extern size_t TF_SignatureDefParamListSize(
const TF_SignatureDefParamList* list) {
return tensorflow::unwrap(list)->size();
}
extern const TF_SignatureDefParam* TF_SignatureDefParamListGet(
const TF_SignatureDefParamList* list, int i) {
return tensorflow::wrap(&tensorflow::unwrap(list)->at(i));
}
} // end extern "C"

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_
#include <vector>
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
typedef struct TF_SignatureDefParamList TF_SignatureDefParamList;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(std::vector<SignatureDefParam>,
TF_SignatureDefParamList)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h"
typedef struct TF_SignatureDefParam TF_SignatureDefParam;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::SignatureDefParam, TF_SignatureDefParam)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_

View File

@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
#include "tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h"
#include "tensorflow/c/tf_shape_internal.h"
extern "C" {
TF_DataType TF_TensorSpecDataType(const TF_TensorSpec* spec) {
return static_cast<TF_DataType>(tensorflow::unwrap(spec)->dtype());
}
const TF_Shape* TF_TensorSpecShape(const TF_TensorSpec* spec) {
return tensorflow::wrap(&tensorflow::unwrap(spec)->shape());
}
} // end extern "C"

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h"
typedef struct TF_TensorSpec TF_TensorSpec;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::TensorSpec, TF_TensorSpec)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_

View File

@ -28,6 +28,9 @@ exports_files(
"saved_model_api.h",
"signature_def_function.h",
"signature_def_function_metadata.h",
"signature_def_param.h",
"signature_def_param_list.h",
"tensor_spec.h",
],
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
)
@ -45,6 +48,9 @@ cc_library(
":saved_model_api",
":signature_def_function",
":signature_def_function_metadata",
":signature_def_param",
":signature_def_param_list",
":tensor_spec",
],
)
@ -77,3 +83,18 @@ alias(
name = "signature_def_function_metadata",
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_function_metadata",
)
alias(
name = "signature_def_param",
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_param",
)
alias(
name = "signature_def_param_list",
actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_param_list",
)
alias(
name = "tensor_spec",
actual = "//tensorflow/c/experimental/saved_model/internal:tensor_spec",
)

View File

@ -23,6 +23,9 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
// IWYU pragma: end_exports
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_

View File

@ -16,6 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
@ -24,6 +27,18 @@ extern "C" {
// SavedModel.
typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata;
// Retrieves the arguments of the SignatureDefFunction. The caller is not
// responsible for freeing the returned pointer.
TF_CAPI_EXPORT extern const TF_SignatureDefParamList*
TF_SignatureDefFunctionMetadataArgs(
const TF_SignatureDefFunctionMetadata* list);
// Retrieves the returns of the SignatureDefFunction. The caller is not
// responsible for freeing the returned pointer.
TF_CAPI_EXPORT extern const TF_SignatureDefParamList*
TF_SignatureDefFunctionMetadataReturns(
const TF_SignatureDefFunctionMetadata* list);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus

View File

@ -0,0 +1,44 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type that containing metadata of an input/output of a
// TF_SignatureDefFunction loaded from a SavedModel.
typedef struct TF_SignatureDefParam TF_SignatureDefParam;
// Returns the name of the given parameter. The caller is not responsible for
// freeing the returned char*.
TF_CAPI_EXPORT extern const char* TF_SignatureDefParamName(
const TF_SignatureDefParam* param);
// Returns the TensorSpec associated with the given parameter. The caller is
// not reponsible for freeing the returned TF_TensorSpec*.
TF_CAPI_EXPORT extern const TF_TensorSpec* TF_SignatureDefParamTensorSpec(
const TF_SignatureDefParam* param);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_

View File

@ -0,0 +1,44 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_
#include <stddef.h>
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type that containing metadata of an input/output of a
// ConcreteFunction loaded from a SavedModel.
typedef struct TF_SignatureDefParamList TF_SignatureDefParamList;
// Returns the size of `list`.
TF_CAPI_EXPORT extern size_t TF_SignatureDefParamListSize(
const TF_SignatureDefParamList* list);
// Returns the `i`th TF_SignatureDefParam in the list.
TF_CAPI_EXPORT extern const TF_SignatureDefParam* TF_SignatureDefParamListGet(
const TF_SignatureDefParamList* list, int i);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_

View File

@ -0,0 +1,46 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_
#include <stddef.h>
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_shape.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type corresponding to TensorSpec
typedef struct TF_TensorSpec TF_TensorSpec;
// Returns the dtype associated with the TensorSpec.
TF_CAPI_EXPORT extern TF_DataType TF_TensorSpecDataType(
const TF_TensorSpec* spec);
// Returns the shape associated with the TensorSpec. The returned Shape is not
// owned by the caller. Caller must not call TF_DeleteShape on the returned
// shape.
TF_CAPI_EXPORT extern const TF_Shape* TF_TensorSpecShape(
const TF_TensorSpec* spec);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_

View File

@ -22,6 +22,8 @@ cc_library(
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:lib",
"//tensorflow/core/platform:regexp",
"//tensorflow/core/platform:strcat",
"//tensorflow/stream_executor:executor_cache",
"//tensorflow/stream_executor:multi_platform_manager",
"//tensorflow/stream_executor:platform",

View File

@ -27,7 +27,10 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/platform.h"
@ -39,6 +42,8 @@ limitations under the License.
using tensorflow::StatusFromTF_Status;
namespace stream_executor {
using tensorflow::StringPiece;
namespace {
#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \
@ -58,10 +63,35 @@ namespace {
} \
} while (0)
port::Status ValidateDeviceType(StringPiece type) {
// Validate device type. Device type must start with a capital letter and
// consist of capital letters and underscores. Reasoning behind this decision:
// * At the minimum we want to disallow '/' and ':' since
// these characters are used in device spec, for e.g.
// /job:foo/replica:12/device:GPU:1.
// * Underscores seem useful, for e.g. XLA_GPU uses underscores.
// * Allowing lowercase might get confusing. For example, say someone
// registers a new type called "Gpu". It might be confusing for users that
// "Gpu" is not the same device type as "GPU".
// Note that lowercase "cpu" and "gpu" are currently supported only for
// legacy reasons:
// https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/python/framework/device_spec.py;l=46;drc=d3a378f9665d8eee827c74cb9ecbee81e4c288dd
static const LazyRE2 kTfDeviceTypeRegEx = {"[A-Z][A-Z_]*"};
bool matches = RE2::FullMatch(type, *kTfDeviceTypeRegEx);
if (!matches) {
return port::FailedPreconditionError(
tensorflow::strings::StrCat("Device name/type '", type, "' must match ",
kTfDeviceTypeRegEx->pattern(), "."));
}
return port::Status::OK();
}
port::Status ValidateSPPlatform(const SP_Platform& platform) {
VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE);
VALIDATE_MEMBER(SP_Platform, platform, name);
VALIDATE_MEMBER(SP_Platform, platform, type);
TF_RETURN_IF_ERROR(ValidateDeviceType(platform.name));
TF_RETURN_IF_ERROR(ValidateDeviceType(platform.type));
// `visible_device_count` could be 0 at initialization time.
return port::Status::OK();
}

View File

@ -52,7 +52,7 @@ limitations under the License.
// params.device = &device;
//
// /* Plugin code below */
// constexpr char DEVICE_NAME[] = "MyDevice";
// constexpr char DEVICE_NAME[] = "MY_DEVICE";
// constexpr char DEVICE_TYPE[] = "GPU";
//
// void create_device(const SP_Platform* platform,
@ -416,10 +416,15 @@ typedef struct SP_Platform {
void* ext; // free-form data set by plugin
// Platform name. Must be null-terminated.
// Platform name (also referred to as subtype), for example MY_DEVICE.
// The name must start with a capital letter and consist of
// capital letters and underscores.
// Must be null-terminated.
const char* name;
// Device type name, for example GPU. Must be null-terminated.
// The name must start with a capital letter and consist of
// capital letters and underscores.
const char* type;
// Number of visible devices

View File

@ -41,9 +41,9 @@ struct SP_Timer_st {
namespace stream_executor {
namespace {
constexpr int DEVICE_COUNT = 2;
constexpr char DEVICE_NAME[] = "MyDevice";
constexpr char DEVICE_TYPE[] = "GPU";
constexpr int kDeviceCount = 2;
constexpr char kDeviceName[] = "MY_DEVICE";
constexpr char kDeviceType[] = "GPU";
/*** Create SP_StreamExecutor (with empty functions) ***/
void allocate(const SP_Device* const device, uint64_t size,
@ -190,9 +190,9 @@ void destroy_device_fns(const SP_Platform* platform, SP_DeviceFns* device_fns) {
void PopulateDefaultPlatform(SP_Platform* platform,
SP_PlatformFns* platform_fns) {
*platform = {SP_PLATFORM_STRUCT_SIZE};
platform->name = DEVICE_NAME;
platform->type = DEVICE_TYPE;
platform->visible_device_count = DEVICE_COUNT;
platform->name = kDeviceName;
platform->type = kDeviceType;
platform->visible_device_count = kDeviceCount;
platform_fns->create_device = create_device;
platform_fns->destroy_device = destroy_device;
platform_fns->create_device_fns = create_device_fns;
@ -218,11 +218,11 @@ TEST(StreamExecutor, SuccessfulRegistration) {
port::Status status = InitStreamExecutorPlugin(plugin_init);
TF_ASSERT_OK(status);
port::StatusOr<Platform*> maybe_platform =
MultiPlatformManager::PlatformWithName("MyDevice");
MultiPlatformManager::PlatformWithName("MY_DEVICE");
TF_ASSERT_OK(maybe_platform.status());
Platform* platform = maybe_platform.ConsumeValueOrDie();
ASSERT_EQ(platform->Name(), DEVICE_NAME);
ASSERT_EQ(platform->VisibleDeviceCount(), DEVICE_COUNT);
ASSERT_EQ(platform->Name(), kDeviceName);
ASSERT_EQ(platform->VisibleDeviceCount(), kDeviceCount);
port::StatusOr<StreamExecutor*> maybe_executor =
platform->ExecutorForDevice(0);
@ -244,6 +244,39 @@ TEST(StreamExecutor, NameNotSet) {
ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set.");
}
TEST(StreamExecutor, InvalidNameWithSemicolon) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultPlatform(params->platform, params->platform_fns);
params->platform->name = "INVALID:NAME";
params->destroy_platform = destroy_platform;
params->destroy_platform_fns = destroy_platform_fns;
};
port::Status status = InitStreamExecutorPlugin(plugin_init);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
EXPECT_THAT(
status.error_message(),
testing::ContainsRegex("Device name/type 'INVALID:NAME' must match"));
}
TEST(StreamExecutor, InvalidNameWithSlash) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
PopulateDefaultPlatform(params->platform, params->platform_fns);
params->platform->name = "INVALID/";
params->destroy_platform = destroy_platform;
params->destroy_platform_fns = destroy_platform_fns;
};
port::Status status = InitStreamExecutorPlugin(plugin_init);
ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION);
EXPECT_THAT(status.error_message(),
testing::ContainsRegex("Device name/type 'INVALID/' must match"));
}
TEST(StreamExecutor, CreateDeviceNotSet) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {

View File

@ -57,43 +57,7 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
TF_Status* status) {
mutex_lock l(graph->mu);
tensorflow::shape_inference::InferenceContext* ic =
graph->refiner.GetContext(&new_src.oper->node);
if (ic->num_outputs() <= new_src.index) {
status->status = tensorflow::errors::OutOfRange(
"Cannot update edge. Output index [", new_src.index,
"] is greater than the number of total outputs [", ic->num_outputs(),
"].");
return;
}
tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
tensorflow::shape_inference::InferenceContext* ic_dst =
graph->refiner.GetContext(&dst.oper->node);
if (ic_dst->num_inputs() <= dst.index) {
status->status = tensorflow::errors::OutOfRange(
"Cannot update edge. Input index [", dst.index,
"] is greater than the number of total inputs [", ic_dst->num_inputs(),
"].");
return;
}
if (!ic_dst->MergeInput(dst.index, shape)) {
status->status = tensorflow::errors::InvalidArgument(
"Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
" and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
return;
}
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
&dst.oper->node, dst.index);
if (TF_GetCode(status) == TF_OK) {
// This modification only updates the destination node for
// the purposes of running this graph in a session. Thus, we don't
// record the source node as being modified.
RecordMutation(graph, *dst.oper, "updating input tensor");
}
TF_UpdateEdge(graph, new_src, dst, status);
}
void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
@ -136,6 +100,7 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
auto* out_shape_and_type = handle_data.add_shape_and_type();
ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
out_shape_and_type->set_dtype(p.dtype);
out_shape_and_type->set_specialized_type(p.specialized_type);
}
}
string result;
@ -163,7 +128,8 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
status->status =
ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
if (TF_GetCode(status) != TF_OK) return;
shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(),
shape_and_type_proto.specialized_type());
}
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
}

39
tensorflow/c/tf_shape.cc Normal file
View File

@ -0,0 +1,39 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/tf_shape.h"
#include <stdint.h>
#include "tensorflow/c/tf_shape_internal.h"
#include "tensorflow/core/framework/tensor_shape.h"
extern "C" {
TF_Shape* TF_NewShape() {
return tensorflow::wrap(new tensorflow::PartialTensorShape());
}
int TF_ShapeDims(const TF_Shape* shape) {
return tensorflow::unwrap(shape)->dims();
}
int64_t TF_ShapeDimSize(const TF_Shape* shape, int d) {
return tensorflow::unwrap(shape)->dim_size(d);
}
void TF_DeleteShape(TF_Shape* shape) { delete tensorflow::unwrap(shape); }
} // end extern "C"

50
tensorflow/c/tf_shape.h Normal file
View File

@ -0,0 +1,50 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>
#include "tensorflow/c/c_api_macros.h"
#ifndef TENSORFLOW_C_TF_SHAPE_H_
#define TENSORFLOW_C_TF_SHAPE_H_
#ifdef __cplusplus
extern "C" {
#endif
// An opaque type corresponding to a shape in tensorflow. In the future,
// we may expose the ABI of TF_Shape for performance reasons.
typedef struct TF_Shape TF_Shape;
// Return a new, unknown rank shape object. The caller is responsible for
// calling TF_DeleteShape to deallocate and destroy the returned shape.
TF_CAPI_EXPORT extern TF_Shape* TF_NewShape();
// Returns the rank of `shape`. If `shape` has unknown rank, returns -1.
TF_CAPI_EXPORT extern int TF_ShapeDims(const TF_Shape* shape);
// Returns the `d`th dimension of `shape`. If `shape` has unknown rank,
// invoking this function is undefined behavior. Returns -1 if dimension is
// unknown.
TF_CAPI_EXPORT extern int64_t TF_ShapeDimSize(const TF_Shape* shape, int d);
// Deletes `shape`.
TF_CAPI_EXPORT extern void TF_DeleteShape(TF_Shape* shape);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_TF_SHAPE_H_

View File

@ -0,0 +1,30 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_TF_SHAPE_INTERNAL_H_
#define TENSORFLOW_C_TF_SHAPE_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/core/framework/tensor_shape.h"
typedef struct TF_Shape TF_Shape;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::PartialTensorShape, TF_Shape);
}
#endif // TENSORFLOW_C_TF_SHAPE_INTERNAL_H_

View File

@ -251,7 +251,6 @@ cc_library_with_android_deps(
deps = [
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"//tensorflow/core:lib_experimental",
"//tensorflow/core:protos_all_cc",
],
)
@ -266,7 +265,6 @@ tf_cc_test(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_experimental",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",

View File

@ -15,13 +15,12 @@ limitations under the License.
#include <vector>
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
namespace tensorflow {
namespace ops {
namespace {
@ -90,15 +89,25 @@ Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);
Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
grad_outputs->push_back(Identity(scope, grad_inputs[0]));
grad_outputs->push_back(NoGradient());
grad_outputs->push_back(NoGradient());
Status QuantizeAndDequantizeV4GradHelper(const Scope& scope,
const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
Input input = Shape(scope, op.input(0));
Input input_min = op.input(1);
Input input_max = op.input(2);
int64 axis;
TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
auto qdq_v4_grad = QuantizeAndDequantizeV4Grad(
scope, grad_inputs[0], input, input_min, input_max,
QuantizeAndDequantizeV4Grad::Axis(axis));
grad_outputs->push_back(qdq_v4_grad.input_backprop);
grad_outputs->push_back(qdq_v4_grad.input_min_backprop);
grad_outputs->push_back(qdq_v4_grad.input_max_backprop);
return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad);
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV4",
QuantizeAndDequantizeV4GradHelper);
Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,

View File

@ -21,10 +21,7 @@ package(
licenses = ["notice"], # Apache 2.0
)
exports_files([
"LICENSE",
"loader.h",
])
exports_files(["loader.h"])
cc_library(
name = "constants",
@ -45,13 +42,15 @@ cc_library(
name = "reader",
srcs = ["reader.cc"],
hdrs = ["reader.h"],
deps = [":constants"] + if_not_mobile([
deps = [
":constants",
"//tensorflow/core:protos_all_cc",
] + if_not_mobile([
# TODO(b/111634734): :lib and :protos_all contain dependencies that
# cannot be built on mobile platforms. Instead, include the appropriate
# tf_lib depending on the build platform.
"@com_google_absl//absl/memory:memory",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
]),
)

View File

@ -12,8 +12,6 @@ package(
licenses = ["notice"], # Apache 2.0
)
exports_files(["LICENSE"])
cc_library(
name = "freeze_saved_model",
srcs = ["freeze_saved_model.cc"],

View File

@ -75,7 +75,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
"//tensorflow/core:regexp_internal",
"//tensorflow/core/platform:regexp",
] + if_llvm_system_z_available([
"@llvm-project//llvm:SystemZCodeGen", # fixdeps: keep
]) + if_llvm_aarch64_available([

View File

@ -336,9 +336,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo_profile_printer",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:regexp",
"//third_party/eigen3",
"@com_google_absl//absl/strings",
],
@ -559,9 +559,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo_profile_printer",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:regexp",
"//third_party/eigen3",
"@com_google_absl//absl/strings",
],

View File

@ -127,7 +127,7 @@ def tf_library(
"$(location " + tfcompile_tool + ")" +
" --config=$(location " + config + ")" +
" --dump_fetch_nodes > $@"),
tools = [tfcompile_tool],
exec_tools = [tfcompile_tool],
# Run tfcompile on the build host, rather than forge, since it's
# typically way faster on the local machine.
local = 1,
@ -242,7 +242,7 @@ def tf_library(
" --out_function_object=$(@D)/" + function_object_file +
" " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag
),
tools = [tfcompile_tool],
exec_tools = [tfcompile_tool],
visibility = visibility,
testonly = testonly,
# Run tfcompile on the build host since it's typically faster on the
@ -281,7 +281,7 @@ def tf_library(
" --out_session_module=$(@D)/" + session_module_pb +
" " + flags
),
tools = [tfcompile_tool],
exec_tools = [tfcompile_tool],
visibility = visibility,
testonly = testonly,
local = 1,

View File

@ -4,7 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_test")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "if_tpu", "tf_copts")
load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts")
load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
# buildifier: disable=same-origin-load
@ -77,7 +77,7 @@ cc_library(
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
] + if_tpu(
] + if_libtpu(
if_false = ["//tensorflow/compiler/xla/service:cpu_plugin"],
if_true = [],
),
@ -114,7 +114,7 @@ cc_library(
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
] + if_tpu(
] + if_libtpu(
if_false = [
"//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep
],
@ -141,7 +141,7 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/gpu:gpu_init",
] + if_tpu(
] + if_libtpu(
if_false = [
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
],
@ -204,7 +204,7 @@ XLA_DEVICE_DEPS = [
"//tensorflow/core:resource_variable_ops_op_lib",
"//tensorflow/core:sendrecv_ops_op_lib",
"//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/platform:stream_executor_no_cuda",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:fifo_queue",
"//tensorflow/core/kernels:function_ops",
@ -375,7 +375,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:logging",
] + if_tpu(
] + if_libtpu(
if_false = [
"//tensorflow/compiler/mlir:array_container_utils",
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
@ -435,6 +435,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime:core_cpu_internal",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@ -1022,10 +1023,10 @@ tf_cc_test(
"//tensorflow/cc:ops",
"//tensorflow/core:all_kernels",
"//tensorflow/core:core_cpu",
"//tensorflow/core:direct_session_internal",
"//tensorflow/core:framework",
"//tensorflow/core:ops",
"//tensorflow/core:test",
"//tensorflow/core/common_runtime:direct_session_internal",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:matmul_op",
"//tensorflow/core/kernels:partitioned_function_ops",

View File

@ -84,6 +84,23 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
return Status::OK();
}
xla::StatusOr<std::vector<NodeDef>> MakeCallNodesFromAttribute(
const Node& node, absl::string_view attr_name,
absl::string_view call_name) {
std::vector<NameAttrList> attr_lists;
TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &attr_lists));
std::vector<NodeDef> out;
for (int i = 0; i < attr_lists.size(); i++) {
out.emplace_back();
NodeDef& inserted = out.back();
inserted.set_name(absl::StrCat(call_name, "_", i));
inserted.set_op(attr_lists[i].name());
*inserted.mutable_attr() = attr_lists[i].attr();
}
return out;
}
// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
@ -227,6 +244,30 @@ bool RecursiveCompilabilityChecker::IsCompilableIf(
return is_compilable;
}
bool RecursiveCompilabilityChecker::IsCompilableCase(
const Node& case_node, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
NameAttrList* encapsulating_function,
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
const {
xla::StatusOr<std::vector<NodeDef>> calls =
MakeCallNodesFromAttribute(case_node, "branches", "branch");
if (!calls.ok()) {
VLOG(2) << "Rejecting node " << case_node.name() << ": "
<< "missing attribute 'branches'";
return false;
}
bool is_compilable = true;
for (const NodeDef& call : *calls) {
is_compilable &=
IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
uncompilable_nodes);
}
return is_compilable;
}
// Tests whether 'while_node' is a completely compilable loop.
// Every operator in the condition and body functions must be compilable for a
// while loop to be compilable.
@ -417,6 +458,13 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
return false;
}
if (op_filter_.require_always_compilable && node.IsCaseNode() &&
!IsCompilableCase(node, lib_runtime, stack_trace, encapsulating_function,
uncompilable_nodes)) {
LogNotCompilable(node, "unsupported case");
return false;
}
if (!op_filter_.allow_stateful_rng_ops &&
IsStatefulRandomOp(node.type_string())) {
absl::string_view uncompilable_reason = "stateful random op";

View File

@ -124,6 +124,10 @@ class RecursiveCompilabilityChecker {
// Whether ops known to have numerical accuracy issues should be considered
// compilable..
bool allow_inaccurate_ops = false;
// Require the function to be always compilable, regardless whether some
// control flow branches might be dead for a given input.
bool require_always_compilable = false;
};
RecursiveCompilabilityChecker(OperationFilter op_filter,
@ -211,6 +215,14 @@ class RecursiveCompilabilityChecker {
NameAttrList* encapsulating_function,
UncompilableNodesMap* uncompilable_nodes) const;
// Tests whether 'case_node' is compilable. Every operator in all branches
// must be compilable.
bool IsCompilableCase(const Node& case_node,
FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
NameAttrList* encapsulating_function,
UncompilableNodesMap* uncompilable_nodes) const;
// Returns compilability of node def retrieved from `node`'s attribute with
// name `attr_name`.
bool ExtractNodeDefAndCheckCompilability(

View File

@ -34,7 +34,16 @@ limitations under the License.
namespace tensorflow {
namespace {
AttrValue FuncListAttr(const absl::Span<const char* const> names) {
AttrValue attr;
for (const char* name : names) {
attr.mutable_list()->add_func()->set_name(name);
}
return attr;
}
constexpr char kFunctionalIfNodeName[] = "If";
constexpr char kFunctionalCaseNodeName[] = "Case";
constexpr char kFunctionalWhileNodeName[] = "While";
constexpr char kCompilableFunctionName[] = "CompilableFn";
constexpr char kCompilableFunctionNodeName[] = "n_c";
@ -76,8 +85,12 @@ class CompilabilityCheckUtilTest : public ::testing::Test {
op_filter_.allow_inaccurate_ops = false;
op_filter_.allow_slow_ops = false;
checker_ = absl::make_unique<RecursiveCompilabilityChecker>(op_filter_,
device_type_);
checker_ = CreateCompilabilityChecker();
}
std::unique_ptr<RecursiveCompilabilityChecker> CreateCompilabilityChecker() {
return absl::make_unique<RecursiveCompilabilityChecker>(op_filter_,
device_type_);
}
FunctionLibraryRuntime* GetFunctionLibraryRuntime() {
@ -355,6 +368,57 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
"unsupported op"));
}
TEST_F(CompilabilityCheckUtilTest, CheckFunctionalCaseNode) {
FunctionDefLibrary flib;
*flib.add_function() = FunctionDefHelper::Define(
/*Function*/ kUncompilableFunctionName,
/*Inputs*/ {"n_a:float"},
/*Outputs*/ {"n_c_uncompilable:float"},
/*Attributes*/ {},
// Node info
{{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});
*flib.add_function() = FunctionDefHelper::Define(
/*Function*/ kUncompilableFunctionTwoName,
/*Inputs*/ {"n_a:float"},
/*Outputs*/ {"n_d_uncompilable:float"},
/*Attribute*/ {},
// Node info
{{{kUncompilableFunctionNodeTwoName}, "MissingKernel", {"n_a"}}});
Scope root = Scope::NewRootScope().ExitOnError();
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib));
auto branch_index = ops::Placeholder(root.WithOpName("pred"), DT_INT32);
auto placeholder = ops::Placeholder(root.WithOpName("A"), DT_INT32);
std::vector<NodeBuilder::NodeOut> inputes(
{NodeBuilder::NodeOut(placeholder.node())});
Node* case_node;
TF_ASSERT_OK(
NodeBuilder(kFunctionalCaseNodeName, "Case", &root.graph()->flib_def())
.Input(branch_index.node())
.Input(inputes)
.Attr("branches", FuncListAttr({kUncompilableFunctionName,
kUncompilableFunctionTwoName}))
.Attr("Tout", {DT_INT32})
.Finalize(root.graph(), &case_node));
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
auto case_node_it = std::find_if(
graph->nodes().begin(), graph->nodes().end(),
[&](const Node* n) { return n->name() == kFunctionalCaseNodeName; });
EXPECT_NE(case_node_it, graph->nodes().end());
auto* flib_runtime = GetFunctionLibraryRuntime();
op_filter_.require_always_compilable = false;
checker_ = CreateCompilabilityChecker();
EXPECT_TRUE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
op_filter_.require_always_compilable = true;
checker_ = CreateCompilabilityChecker();
EXPECT_FALSE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
}
TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) {
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Scope root = Scope::NewRootScope().ExitOnError();

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/status.h"
@ -47,8 +48,8 @@ static xla::StatusOr<xla::LocalExecutable*> GetLocalExecutable(
xla::StatusOr<std::string> GetCompilerIr(
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
absl::string_view func_name, Device* dev,
absl::Span<const Tensor* const> inputs) {
absl::string_view func_name, Device* dev, EagerContext* context,
absl::Span<const TensorHandle* const> inputs_handles) {
NameAttrList function;
function.set_name(std::string{func_name});
@ -65,6 +66,25 @@ xla::StatusOr<std::string> GetCompilerIr(
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
std::deque<Tensor> inputs_storage;
std::vector<const Tensor*> inputs;
inputs.reserve(inputs_handles.size());
for (int i = 0; i < inputs_handles.size(); i++) {
const TensorHandle* th = inputs_handles[i];
const Tensor* t;
// Handle owns the tensor.
TF_RETURN_IF_ERROR(th->Tensor(&t));
if (absl::c_binary_search(constant_arg_indices, i)) {
// Need to make sure it's on the host.
inputs_storage.emplace_back(t->dtype(), t->shape());
TF_RETURN_IF_ERROR(
th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back()));
inputs.push_back(&inputs_storage.back());
} else {
inputs.push_back(t);
}
}
std::vector<VariableInfo> variable_infos;
TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
rmgr, dev, inputs, resource_arg_indices, &variable_infos));

View File

@ -24,6 +24,8 @@ namespace tensorflow {
class ProcessFunctionLibraryRuntime;
class Device;
class Tensor;
class TensorHandle;
class EagerContext;
enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
@ -31,8 +33,8 @@ enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
// `runtime` on a device `dev` with given `inputs`.
xla::StatusOr<std::string> GetCompilerIr(
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
absl::string_view func_name, Device* dev,
absl::Span<const Tensor* const> inputs);
absl::string_view func_name, Device* dev, EagerContext* context,
absl::Span<const TensorHandle* const> inputs);
} // namespace tensorflow

View File

@ -34,7 +34,7 @@ XLA_OPS_DEPS = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/platform:stream_executor_no_cuda",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor:tf_allocator_adapter",
]

View File

@ -1196,10 +1196,14 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
continue;
}
if (!RecursiveCompilabilityChecker{
CreateOperationFilter(*registration),
DeviceType{registration->compilation_device_name}}
.IsCompilableNode(*node, lib_runtime)) {
RecursiveCompilabilityChecker::OperationFilter filter =
CreateOperationFilter(*registration);
filter.require_always_compilable = true;
RecursiveCompilabilityChecker checker(
filter, DeviceType{registration->compilation_device_name});
if (!checker.IsCompilableNode(*node, lib_runtime)) {
continue;
}
@ -2062,6 +2066,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"XlaSpmdFullToShardShape",
"XlaSpmdShardToFullShape",
"XlaSvd",
"XlaVariadicReduce",
"XlaWhile",
"Zeta",
"_Arg",

View File

@ -47,7 +47,7 @@ limitations under the License.
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"
#if !defined(LIBTFTPU)
#if !defined(LIBTPU_ON_GCE)
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
#endif
@ -289,7 +289,7 @@ Status XlaCompilationCache::CompileSingleOp(
});
const ConfigProto* config = ctx->function_library()->config_proto();
bool use_mlir = config && config->experimental().enable_mlir_bridge();
#ifdef LIBTFTPU
#ifdef LIBTPU_ON_GCE
if (use_mlir && has_tensor_list_arg) {
LOG(WARNING) << "MLIR is not supported in this environment.";
}
@ -303,8 +303,12 @@ Status XlaCompilationCache::CompileSingleOp(
}
GraphDebugInfo debug_info;
std::vector<std::string> control_rets;
if (result_dtypes.empty()) {
control_rets.push_back(node_def.name());
}
return CompileGraphToXlaHlo(
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), control_rets,
options.device_type.type_string(), compile_options.use_tuple_arg,
*options.flib_def, debug_info, options.shape_representation_fn, result);
#endif

View File

@ -9,3 +9,31 @@ dialects and utilities for
3. TF Lite
See [MLIR's website](https://mlir.llvm.org) for complete documentation.
## Getting started
Building dialects and utilities here follow the standard approach using
`bazel` as the rest of TensorFlow.
### Using local LLVM repo
To develop across MLIR core and TensorFlow, it is useful to override the repo
to use a local version instead of fetching from head. This can be achieved as
below but note, the BUILD files are not automatically generated from or CMake
used, so if your change requires a BUILD file change (or you are using a
different version of LLVM than set in tensorflow/workspace.bzl's LLVM_COMMIT)
then manual BUILD file changes may be required.
```sh
LLVM_SRC=...
# Create basic workspace file
echo 'workspace(name = "llvm-project")' > $LLVM_SRC/WORKSPACE
# and copy over the bazel BUILD files.
cp third_party/llvm/llvm.autogenerated.BUILD $LLVM_SRC/llvm/BUILD
cp third_party/mlir/BUILD $LLVM_SRC/mlir
cp third_party/mlir/test.BUILD $LLVM_SRC/mlir/test/BUILD
bazel build --override_repository=llvm-project=$LLVM_SRC \
-c opt tensorflow/compiler/mlir:tf-opt
```

View File

@ -48,6 +48,7 @@ filegroup(
"include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td",
"include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
@ -539,6 +540,8 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:ShapeTransforms",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",

View File

@ -360,6 +360,19 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
}];
}
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
HLO_FpOrComplexTensor> {
let summary = "Atan operator";
let description = [{
Returns `Atan(operand)` element-wise.
$$
\atan(x) = \atan2(x, 1)
$$
}];
}
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
HLO_FpOrComplexTensor> {
let summary = "Sinh operation";

View File

@ -157,6 +157,9 @@ def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
>];
}
def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt",
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CbrtOp;
def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil",
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CeilOp;
@ -193,12 +196,10 @@ def HLO_Expm1Op: HLO_UnaryElementwiseOp<"exponential_minus_one",
def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor",
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp;
def HLO_ImagOp: HLO_Op<
"imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp {
let builders = [OpBuilder<
"OpBuilder &, OperationState &tblgen_state, Value val">];
let arguments = (ins HLO_ComplexTensor);
def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag",
[NoSideEffect, SameOperandsAndResultShape,
DeclareOpInterfaceMethods<InferTypeOpInterface>],
HLO_ComplexTensor>, BASE_HLO_ImagOp {
let results = (outs HLO_FpTensor);
let hasFolder = 1;
}
@ -237,12 +238,10 @@ def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt",
[NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>,
BASE_HLO_PopulationCountOp;
def HLO_RealOp: HLO_Op<
"real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp {
let builders = [OpBuilder<
"OpBuilder &, OperationState &tblgen_state, Value val">];
let arguments = (ins HLO_ComplexTensor);
def HLO_RealOp: HLO_UnaryElementwiseOp<"real",
[NoSideEffect, SameOperandsAndResultShape,
DeclareOpInterfaceMethods<InferTypeOpInterface>],
HLO_ComplexTensor>, BASE_HLO_RealOp {
let results = (outs HLO_FpTensor);
let hasFolder = 1;
}
@ -321,12 +320,10 @@ def HLO_AddOp : HLO_BinaryElementwiseOp<"add",
def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op;
def HLO_ComplexOp: HLO_Op<"complex",
[NoSideEffect, SameOperandsAndResultShape]>,
def HLO_ComplexOp: HLO_BinaryElementwiseOp<"complex",
[NoSideEffect, SameOperandsAndResultShape,
DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
BASE_HLO_ComplexOp {
let builders = [OpBuilder<
"OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">];
let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs);
let results = (outs HLO_ComplexTensor);
let hasFolder = 1;
@ -356,7 +353,9 @@ def HLO_PowOp : HLO_BinaryElementwiseOp<"power",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp;
def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp;
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp {
let hasFolder = 1;
}
def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp;
@ -913,39 +912,12 @@ def HLO_CollectivePermuteOp: HLO_Op<"collective_permute",
let results = (outs HLO_Tensor);
}
// TODO(hinsu): Make this struct dialect independent so that it can be shared
// between HLO and LHLO dialect.
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
let arguments = (ins
HLO_Tensor:$lhs,
HLO_Tensor:$rhs,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$window_strides,
// Default value: zero for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
ConvDimensionNumbers:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
HLO_PrecisionConfigAttr:$precision_config
);
let arguments = !con(
(ins
HLO_Tensor:$lhs,
HLO_Tensor:$rhs),
ConvolutionAttributes<HLO_Dialect>.attributes);
let results = (outs HLO_Tensor);
}
@ -1198,14 +1170,14 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>,
let results = (outs HLO_Tensor);
}
def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects]>, BASE_HLO_SortOp {
def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp {
let arguments = (ins
Variadic<HLO_Tensor>:$operands,
DefaultValuedAttr<I64Attr, "-1">:$dimension,
DefaultValuedAttr<BoolAttr, "false">:$is_stable
);
let results = (outs HLO_TensorOrTuple);
let results = (outs Variadic<HLO_Tensor>);
let regions = (region SizedRegion<1>:$comparator);
@ -1429,4 +1401,21 @@ def HLO_FusionOp : HLO_Op<"fusion", []> {
let hasCustomHLOConverter = 1;
}
// This is an op for purposes internal to XLA/GPU.
def HLO_BitcastOp : HLO_Op<"bitcast", [NoSideEffect]>, BASE_HLO_BitcastOp {
let arguments = (ins HLO_Tensor:$operand);
let results = (outs HLO_Tensor);
let hasCustomHLOConverter = 1;
}
def HLO_ReducePrecisionOp: HLO_Op<"reduce_precision", [SameOperandsAndResultShape]>,
BASE_HLO_ReducePrecisionOp {
let arguments = (ins
HLO_FpTensor:$operand,
I32Attr:$exponent_bits,
I32Attr:$mantissa_bits
);
let results = (outs HLO_FpTensor:$output);
}
#endif // HLO_OPS

View File

@ -127,6 +127,17 @@ class BASE_HLO_AbsOp {
}];
}
class BASE_HLO_CbrtOp {
string summary = "Cubic root operator";
string description = [{
Returns element-wise cubic root of the operand.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
class BASE_HLO_CeilOp {
string summary = "Ceil operator";
@ -996,6 +1007,42 @@ class BASE_HLO_ConcatenateOp {
}];
}
//===----------------------------------------------------------------------===//
// Common convolution attributes
//===----------------------------------------------------------------------===//
class ConvDimensionNumbersBase<Dialect dialect>
: StructAttr<"ConvDimensionNumbers", dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
class ConvolutionAttributes<Dialect dialect> {
dag attributes = (ins
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$window_strides,
// Default value: zero for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
ConvDimensionNumbersBase<dialect>:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
HLO_PrecisionConfigAttr:$precision_config
);
}
class BASE_HLO_ConvOp {
string summary = "Convolution operator";
@ -1336,4 +1383,17 @@ class BASE_HLO_WhileOp {
}];
}
class BASE_HLO_BitcastOp {
string summary = "Bitcast operator";
string description = [{
This op changes the shape of the input in the way that the physical
arranggment of elements are unchanged.
However, the op needs layout information to make sense of "physical
arrangement of elements". Layout support in MHLO is currently under
exploration.
}];
}
#endif // HLO_OPS_BASE

View File

@ -37,38 +37,13 @@ include "mlir/IR/OpBase.td"
include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
def LHLO_Dialect : Dialect {
let name = "lmhlo";
let cppNamespace = "::mlir::lmhlo";
}
//===----------------------------------------------------------------------===//
// LMHLO type definitions.
//===----------------------------------------------------------------------===//
// Any integer tensor types
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
// Any floating-point tensor types
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
// Any integer or floating-point tensor types
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
//===----------------------------------------------------------------------===//
// LMHLO nullary op definitions.
//===----------------------------------------------------------------------===//
@ -289,6 +264,16 @@ def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
}
def LHLO_CustomCallOp : LHLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$args,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
StrAttr:$call_target_name,
DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
DefaultValuedAttr<StrAttr, "">:$backend_config
);
}
//===----------------------------------------------------------------------===//
// LMHLO tuple op definitions.
//===----------------------------------------------------------------------===//
@ -335,10 +320,11 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
[NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
let summary = [{
"modifies the offset, sizes and strides of a statically shaped memref.
modifies the offset, sizes and strides of a statically shaped memref
}];
let description = [{
Allows to modify the offset, sizes and strides of a statically shaped memref.
Casts the statically shaped memref operand to a memref with optionally
modified offsets, sizes and strides.
Example:
```mlir
@ -354,12 +340,11 @@ def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
let arguments = (ins Arg<LHLO_Buffer, "", []>:$operand);
let results = (outs Res<LHLO_Buffer, "", []>:$result);
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, MemRefType resultType, " #
"Value operand", [{
result.addOperands(operand);
result.types.push_back(resultType);
}]>];
let builders = [OpBuilder<"MemRefType resultType, Value operand",
[{
$_state.addOperands(operand);
$_state.types.push_back(resultType);
}]>];
let extraClassDeclaration = [{
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
@ -400,13 +385,13 @@ def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
);
let results = (outs Res<LHLO_Buffer, "", []>:$result);
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, MemRefType resultType, " #
"Value operand, ValueRange sizes, ValueRange strides", [{
result.addOperands(operand);
result.addOperands(sizes);
result.addOperands(strides);
result.types.push_back(resultType);
let builders = [
OpBuilder<"MemRefType resultType, Value operand, ValueRange sizes, "
"ValueRange strides", [{
$_state.addOperands(operand);
$_state.addOperands(sizes);
$_state.addOperands(strides);
$_state.types.push_back(resultType);
}]>];
let extraClassDeclaration = [{
@ -582,40 +567,13 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
);
}
// TODO(bondhugula): Make this struct dialect independent so that it can be
// shared between the HLO and LHLO dialects.
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$window_strides,
// Default value: zero for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
ConvDimensionNumbers:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
HLO_PrecisionConfigAttr:$precision_config
);
let arguments = !con(
(ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$output),
ConvolutionAttributes<LHLO_Dialect>.attributes);
}
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {
@ -856,9 +814,8 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &result, "
"ArrayRef<NamedAttribute> attributes">
];
OpBuilder<"ArrayRef<NamedAttribute> attributes">
];
}
def TerminatorOp :
@ -867,9 +824,8 @@ def TerminatorOp :
let description = [{
Terminator operation for the LHLO dialect.
}];
let builders = [OpBuilder<
"OpBuilder &b, OperationState &result, ValueRange operands",
[{ build(b, result, llvm::None, operands, llvm::None); }]
let builders = [OpBuilder<"ValueRange operands",
[{ build($_builder, $_state, llvm::None, operands, llvm::None); }]
>];
}

View File

@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef LHLO_OPS_BASE
#define LHLO_OPS_BASE
include "mlir/IR/OpBase.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
//===----------------------------------------------------------------------===//
// LMHLO type definitions.
//===----------------------------------------------------------------------===//
// Any integer tensor types
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
// Any floating-point tensor types
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
// Any integer or floating-point tensor types
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
#endif // LHLO_OPS_BASE

View File

@ -50,6 +50,7 @@ MAP_HLO_TO_LHLO(ConvOp);
MAP_HLO_TO_LHLO(ConvertOp);
MAP_HLO_TO_LHLO(CopyOp);
MAP_HLO_TO_LHLO(CosOp);
MAP_HLO_TO_LHLO(CustomCallOp);
MAP_HLO_TO_LHLO(DivOp);
MAP_HLO_TO_LHLO(DotOp);
MAP_HLO_TO_LHLO(ExpOp);
@ -57,11 +58,13 @@ MAP_HLO_TO_LHLO(FloorOp);
MAP_HLO_TO_LHLO(GatherOp);
MAP_HLO_TO_LHLO(ImagOp);
MAP_HLO_TO_LHLO(IotaOp);
MAP_HLO_TO_LHLO(IsFiniteOp);
MAP_HLO_TO_LHLO(LogOp);
MAP_HLO_TO_LHLO(MaxOp);
MAP_HLO_TO_LHLO(MinOp);
MAP_HLO_TO_LHLO(MulOp);
MAP_HLO_TO_LHLO(NegOp);
MAP_HLO_TO_LHLO(NotOp);
MAP_HLO_TO_LHLO(RealOp);
MAP_HLO_TO_LHLO(ReduceOp);
MAP_HLO_TO_LHLO(ReshapeOp);

View File

@ -149,6 +149,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::Atan2Op>{}(
loc, result_types, args, b);
}
template <typename PredicateType>
inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
return llvm::None;
@ -345,6 +354,22 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc,
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::IsFiniteOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
if (args[0].getType().isa<FloatType>()) {
auto pos_inf = APFloat::getInf(
args[0].getType().cast<FloatType>().getFloatSemantics());
auto const_pos_inf =
b->create<ConstantOp>(loc, b->getFloatAttr(args[0].getType(), pos_inf));
Value abs_x = b->create<::mlir::AbsFOp>(loc, args[0]);
return b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, abs_x,
const_pos_inf);
}
return nullptr;
}
/// Implements the conversion of HLO op to scalar op (to use within region of a
/// linalg.generic op) for compare-select style operations like min/max.
template <typename... Args>
@ -431,6 +456,21 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
// lmhlo.not(x) -> x ^ -1
auto all_ones =
b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth());
return b->create<::mlir::XOrOp>(loc, all_ones, args[0]);
}
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
ArrayRef<Type> result_types,
@ -454,11 +494,27 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
if (element_type.isa<FloatType>()) {
FloatType float_type = element_type.cast<FloatType>();
APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0);
Value one = b->create<mlir::ConstantFloatOp>(loc, const_value, float_type);
if (auto float_type = element_type.dyn_cast<FloatType>()) {
bool ignored;
APFloat one_apfloat(1.0f);
one_apfloat.convert(float_type.getFloatSemantics(),
APFloat::rmNearestTiesToEven, &ignored);
Value one = b->create<mlir::ConstantFloatOp>(loc, one_apfloat, float_type);
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
} else if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
// sign(x) = x == 0 ? 0 : ((x s>> 31) | 1)
Value zero =
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
Value cmp =
b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero);
Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>(
loc, integer_type.getWidth() - 1, integer_type.getWidth());
Value ashr =
b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
Value one =
b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth());
Value or_op = b->create<::mlir::OrOp>(loc, ashr, one);
return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op);
}
return nullptr;
}

View File

@ -15,9 +15,9 @@ limitations under the License.
include "mlir/Pass/PassBase.td"
def TestChloLegalizeToHloPass : Pass<"mhlo-test-chlo-legalize-to-hlo", "FuncOp"> {
let summary = "Test pass for applying chlo -> hlo legalization patterns.";
let constructor = "createTestChloLegalizeToHloPass()";
def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "FuncOp"> {
let summary = "Legalize CHLO to HLO.";
let constructor = "createChloLegalizeToHloPass()";
}
def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {

View File

@ -44,6 +44,9 @@ std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass();
/// Lowers from HLO dialect to Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
/// Lowers from the CHLO dialect to the HLO dialect.
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass();
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
/// buffers if necessary. If `results_escape_functions` is set to true,
/// allocated buffers for function results will be returned and escape the
@ -63,7 +66,7 @@ std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
/// Lowers trigonometric operations from the standard dialect to approximations
// that do not use intrinsics.
/// that do not use intrinsics.
std::unique_ptr<OperationPass<FuncOp>>
createLegalizeTrigonometricToApproximationPass();

View File

@ -22,7 +22,6 @@ limitations under the License.
namespace mlir {
namespace mhlo {
std::unique_ptr<Pass> createTestChloLegalizeToHloPass();
std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass();
std::unique_ptr<Pass> createTestMaterializeBroadcastsPass();
std::unique_ptr<Pass> createTestUnfuseBatchNormPass();

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/BufferPlacement.h"
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {

View File

@ -185,8 +185,7 @@ struct GatherSlice : public OpRewritePattern<GatherOp> {
return failure();
const auto& dnums = gather.dimension_numbers();
if (dnums.collapsed_slice_dims().getNumElements() != 0 ||
dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1)
if (dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1)
return failure();
// TODO(tberghammer): Remove when the verifier catches this case what is
@ -206,11 +205,35 @@ struct GatherSlice : public OpRewritePattern<GatherOp> {
}
llvm::SmallVector<int64_t, 8> slice_stride(slice_end.size(), 1);
rewriter.replaceOpWithNewOp<SliceOp>(
gather, gather.getType(), gather.getOperand(0),
llvm::SmallVector<int64_t, 8> slice_shape(slice_end.size());
for (int64_t i = 0; i < slice_end.size(); ++i) {
slice_shape[i] = slice_end[i] - slice_start[i];
}
Type element_type = gather.getType().cast<TensorType>().getElementType();
auto slice_type = RankedTensorType::get(slice_shape, element_type);
Value result = rewriter.create<SliceOp>(
gather.getLoc(), slice_type, gather.getOperand(0),
GetI64ElementsAttr(slice_start, &rewriter),
GetI64ElementsAttr(slice_end, &rewriter),
GetI64ElementsAttr(slice_stride, &rewriter));
if (dnums.collapsed_slice_dims().getNumElements() > 0) {
auto collapsed_slice_dims = llvm::to_vector<8>(llvm::map_range(
dnums.collapsed_slice_dims().getIntValues(),
[](const llvm::APInt& i) { return i.getSExtValue(); }));
llvm::SmallVector<int64_t, 8> reshape_shape;
for (int64_t i = 0; i < slice_shape.size(); ++i) {
if (llvm::count(collapsed_slice_dims, i) == 0) {
reshape_shape.push_back(slice_shape[i]);
}
}
auto reshape_type = RankedTensorType::get(reshape_shape, element_type);
result =
rewriter.create<ReshapeOp>(gather.getLoc(), reshape_type, result);
}
result.setType(gather.getType());
rewriter.replaceOp(gather, result);
return success();
}
};
@ -889,9 +912,10 @@ static LogicalResult Verify(ClampOp op) {
// ComplexOp
//===----------------------------------------------------------------------===//
void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs,
Value rhs) {
auto type = lhs.getType();
LogicalResult ComplexOp::inferReturnTypes(
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
auto type = operands[0].getType();
auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
Type result_ty;
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
@ -901,8 +925,8 @@ void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs,
} else {
result_ty = element_ty;
}
build(builder, state, result_ty, lhs, rhs);
inferredReturnTypes.push_back(result_ty);
return success();
}
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
@ -932,8 +956,11 @@ Type CreateRealType(Type type) {
}
} // namespace
void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) {
build(builder, state, CreateRealType(val.getType()), val);
LogicalResult ImagOp::inferReturnTypes(
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
return success();
}
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
@ -945,8 +972,11 @@ OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
return {};
}
void RealOp::build(OpBuilder& builder, OperationState& state, Value val) {
build(builder, state, CreateRealType(val.getType()), val);
LogicalResult RealOp::inferReturnTypes(
MLIRContext*, Optional<Location>, ValueRange operands, DictionaryAttr,
RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(CreateRealType(operands[0].getType()));
return success();
}
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
@ -1971,6 +2001,23 @@ struct divide<APInt> {
APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
};
template <typename T>
struct remainder : std::modulus<T> {};
template <>
struct remainder<APInt> {
APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); }
};
template <>
struct remainder<APFloat> {
APFloat operator()(const APFloat& a, const APFloat& b) const {
APFloat result(a);
result.remainder(b);
return result;
}
};
template <typename T>
struct max {
T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
@ -2012,6 +2059,7 @@ BINARY_FOLDER(AddOp, std::plus);
BINARY_FOLDER(SubOp, std::minus);
BINARY_FOLDER(MulOp, std::multiplies);
BINARY_FOLDER(DivOp, divide);
BINARY_FOLDER(RemOp, remainder);
BINARY_FOLDER(MaxOp, max);
BINARY_FOLDER(MinOp, min);
@ -2261,10 +2309,7 @@ void SortOp::build(OpBuilder& builder, OperationState& state,
state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
state.addAttribute("is_stable", builder.getBoolAttr(dimension));
SmallVector<Type, 2> element_types;
element_types.reserve(operands.size());
for (Value operand : operands) element_types.push_back(operand.getType());
state.addTypes(builder.getTupleType(element_types));
for (Value operand : operands) state.addTypes(operand.getType());
state.addRegion();
}

View File

@ -283,7 +283,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
auto if_op = rewriter.create<scf::IfOp>(
loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder();
Value reshaped_lhs = if_lhs_scalar_builder.create<mhlo::ReshapeOp>(
Value reshaped_lhs = if_lhs_scalar_builder.create<TensorCastOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
Value if_lhs_scalar_result = if_lhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{reshaped_lhs, rhs},
@ -300,7 +300,7 @@ struct ConvertUnrankedDynamicBroadcastBinaryOp
else_lhs_scalar_builder.create<scf::YieldOp>(loc,
if_rhs_scalar_op.getResult(0));
OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder();
Value reshaped_rhs = if_rhs_scalar_builder.create<mhlo::ReshapeOp>(
Value reshaped_rhs = if_rhs_scalar_builder.create<TensorCastOp>(
loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
Value if_rhs_scalar_result = if_rhs_scalar_builder.create<ChloOpTy>(
loc, ArrayRef<Type>{result_type}, ArrayRef<Value>{lhs, reshaped_rhs},
@ -516,7 +516,7 @@ struct HloCompareAdaptor {
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
populateWithGenerated(context, patterns);
populateWithGenerated(context, *patterns);
// Instantiate conversion templates for conforming binary elementwise ops
// that do not have different dtypes between operands and results and do

Some files were not shown because too many files have changed in this diff Show More