Merge branch 'master' into wsign-compare-semi-final-lite-python-stream-executor
commit 9424fb57d2

.bazelrc (7 changed lines)
@@ -28,6 +28,7 @@
#
# Other build options:
# short_logs: Only log errors during build, skip warnings.
# verbose_logs: Show all compiler warnings during build.
# monolithic: Build all TF C++ code into a single shared object.
# dynamic_kernels: Try to link all kernels dynamically (experimental).
# libc++: Link against libc++ instead of stdlibc++
@@ -331,6 +332,8 @@ build:windows --distinct_host_configuration=false

# Suppress all warning messages.
build:short_logs --output_filter=DONT_MATCH_ANYTHING
build:verbose_logs --output_filter=
build --config=short_logs

# Instruction set optimizations
# TODO(gunan): Create a feature in toolchains for avx/avx2 to
@@ -547,6 +550,7 @@ try-import %workspace%/.bazelrc.user
# Here are bazelrc configs for release builds
build:release_common --config=opt
build:release_common --config=v2
build:release_common --distinct_host_configuration=false
build:release_common --action_env TF_CONFIGURE_IOS="0"

build:release_cpu_linux --config=release_common
@@ -564,9 +568,10 @@ build:release_gpu_common --config=tensorrt
build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1"
build:release_gpu_common --action_env=TF_CUDA_VERSION="10"
build:release_gpu_common --action_env=TF_CUDNN_VERSION="7"
build:release_gpu_common --action_env=TF_NEED_TENSORRT="1"
build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_37,sm_52,sm_60,sm_61,compute_70"
build:release_gpu_common --action_env=TENSORRT_INSTALL_PATH="/usr/local/tensorrt"
build:release_gpu_common --action_env=LD_LIBRARY_PATH="/usr/local/tensorrt/lib"
build:release_gpu_common --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5"

RELEASE.md (131 changed lines)
@@ -31,7 +31,6 @@
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* TF Core:
  * <ADD RELEASE NOTES HERE>
  * `tf.types.experimental.TensorLike` is a new `Union` type that can be used as
    type annotation for variables representing a Tensor or a value that can be
    converted to Tensor by `tf.convert_to_tensor`.
@@ -39,9 +38,18 @@
    tf.convert_to_tensor behavior. This avoids operations like tf.reshape
    truncating inputs such as from int64 to int32.
  * Added `tf.sparse.map_values` to apply a function to the `.value`s of `SparseTensor` arguments.
  * The Python bitwise operators for `Tensor` (`__and__`, `__or__`, `__xor__`
    and `__invert__`) now support non-`bool` arguments and apply the
    corresponding bitwise ops. `bool` arguments continue to be supported and
    dispatch to logical ops. This brings them more in line with Python and NumPy
    behavior.
* `tf.data`:
  * Added new `tf.data.experimental.service.register_dataset` and
    `tf.data.experimental.service.from_dataset_id` APIs to enable one process
    to register a dataset with the tf.data service, and another process to
    consume data from the dataset.
  * Added optional `exclude_cols` parameter to CsvDataset. This parameter is
    the complement of `select_cols`; at most one of these should be specified.
  * We have implemented an optimization which reorders data-discarding
    transformations such as `take` and `shard` to happen earlier in the
    dataset when it is safe to do so. The optimization can be disabled via
@@ -51,11 +59,12 @@
  * <ADD RELEASE NOTES HERE>
* `tf.keras`:
  * <ADD RELEASE NOTES HERE>
* `tf.function`/AutoGraph:
  * <ADD RELEASE NOTES HERE>
* `tf.function` / AutoGraph:
  * Added `experimental_follow_type_hints` argument for `tf.function`. When
    True, the function may use type annotations to optimize the tracing
    performance.
* `tf.lite`:
  * Better support for ops with high-dimensional broadcasting inputs by adding
    `BroadcastTo` ops when necessary.
  * <ADD RELEASE NOTES HERE>
* `tf.random`:
  * <ADD RELEASE NOTES HERE>
* Math and Linear Algebra:
@@ -219,7 +228,7 @@ Coinciding with this change, new releases of [TensorFlow's Docker images](https:
    `Strategy.extended.update` and `Strategy.extended.update_non_slot`.
  * Experimental support for shape invariants has been enabled in
    `tf.function`. See the API docs for
    `tf.autograph.experimental.set_loop_options` for additonal info.
    `tf.autograph.experimental.set_loop_options` for additional info.
  * AutoGraph error messages now exclude frames corresponding to APIs
    internal to AutoGraph.
  * Improve shape inference for `tf.function` input arguments to unlock more
@@ -302,7 +311,7 @@ Coinciding with this change, new releases of [TensorFlow's Docker images](https:
    also deterministic back-prop of bias-addition in Keras layers) to
    include when XLA JIT compilation is enabled.
  * Fix problem, when running on a CUDA GPU and when either environment
    variable `TF_DETERMINSTIC_OPS` or environment variable
    variable `TF_DETERMINISTIC_OPS` or environment variable
    `TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer
    configurations led to an exception with the message "No algorithm
    worked!"
@@ -345,32 +354,86 @@ This release contains contributions from many people at Google, as well as:
TensorFlow 2.1 will be the last TF release supporting Python 2. Python 2 support [officially ends on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update). [As announced earlier](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ), TensorFlow will also stop supporting Python 2 starting January 1, 2020, and no more releases are expected in 2019.

## Major Features and Improvements
* The `tensorflow` pip package now includes GPU support by default (same as `tensorflow-gpu`) for both Linux and Windows. This runs on machines with and without NVIDIA GPUs. `tensorflow-gpu` is still available, and CPU-only packages can be downloaded at `tensorflow-cpu` for users who are concerned about package size.
* **Windows users:** Officially-released `tensorflow` Pip packages are now built with Visual Studio 2019 version 16.4 in order to take advantage of the new `/d2ReducedOptimizeHugeFunctions` compiler flag. To use these new packages, you must install "Microsoft Visual C++ Redistributable for Visual Studio 2015, 2017 and 2019", available from Microsoft's website [here](https://support.microsoft.com/help/2977003/the-latest-supported-visual-c-downloads).
  * This does not change the minimum required version for building TensorFlow from source on Windows, but builds enabling `EIGEN_STRONG_INLINE` can take over 48 hours to compile without this flag. Refer to `configure.py` for more information about `EIGEN_STRONG_INLINE` and `/d2ReducedOptimizeHugeFunctions`.
  * If either of the required DLLs, `msvcp140.dll` (old) or `msvcp140_1.dll` (new), are missing on your machine, `import tensorflow` will print a warning message.
* The `tensorflow` pip package is built with CUDA 10.1 and cuDNN 7.6.
* `tf.keras`
  * Experimental support for mixed precision is available on GPUs and Cloud TPUs. See [usage guide](https://www.tensorflow.org/guide/keras/mixed_precision).
  * Introduced the `TextVectorization` layer, which takes as input raw strings and takes care of text standardization, tokenization, n-gram generation, and vocabulary indexing. See this [end-to-end text classification example](https://colab.research.google.com/drive/1RvCnR7h0_l4Ekn5vINWToI9TNJdpUZB3).
  * Keras `.compile` `.fit` `.evaluate` and `.predict` are allowed to be outside of the DistributionStrategy scope, as long as the model was constructed inside of a scope.
  * Experimental support for Keras `.compile`, `.fit`, `.evaluate`, and `.predict` is available for Cloud TPUs, Cloud TPU, for all types of Keras models (sequential, functional and subclassing models).
  * Automatic outside compilation is now enabled for Cloud TPUs. This allows `tf.summary` to be used more conveniently with Cloud TPUs.
  * Dynamic batch sizes with DistributionStrategy and Keras are supported on Cloud TPUs.
  * Support for `.fit`, `.evaluate`, `.predict` on TPU using numpy data, in addition to `tf.data.Dataset`.
  * Keras reference implementations for many popular models are available in the TensorFlow [Model Garden](https://github.com/tensorflow/models/tree/master/official).
* `tf.data`
  * Changes rebatching for `tf.data datasets` + DistributionStrategy for better performance. Note that the dataset also behaves slightly differently, in that the rebatched dataset cardinality will always be a multiple of the number of replicas.
  * `tf.data.Dataset` now supports automatic data distribution and sharding in distributed environments, including on TPU pods.
  * Distribution policies for `tf.data.Dataset` can now be tuned with 1. `tf.data.experimental.AutoShardPolicy(OFF, AUTO, FILE, DATA)` 2. `tf.data.experimental.ExternalStatePolicy(WARN, IGNORE, FAIL)`
* `tf.debugging`
  * Add `tf.debugging.enable_check_numerics()` and `tf.debugging.disable_check_numerics()` to help debugging the root causes of issues involving infinities and `NaN`s.
* `tf.distribute`
  * Custom training loop support on TPUs and TPU pods is available through `strategy.experimental_distribute_dataset`, `strategy.experimental_distribute_datasets_from_function`, `strategy.experimental_run_v2`, `strategy.reduce`.
  * Support for a global distribution strategy through `tf.distribute.experimental_set_strategy(),` in addition to `strategy.scope()`.
* `TensorRT`
  * [TensorRT 6.0](https://developer.nvidia.com/tensorrt#tensorrt-whats-new) is now supported and enabled by default. This adds support for more TensorFlow ops including Conv3D, Conv3DBackpropInputV2, AvgPool3D, MaxPool3D, ResizeBilinear, and ResizeNearestNeighbor. In addition, the TensorFlow-TensorRT python conversion API is exported as `tf.experimental.tensorrt.Converter`.
* Environment variable `TF_DETERMINISTIC_OPS` has been added. When set to "true" or "1", this environment variable makes `tf.nn.bias_add` operate deterministically (i.e. reproducibly), but currently only when XLA JIT compilation is *not* enabled. Setting `TF_DETERMINISTIC_OPS` to "true" or "1" also makes cuDNN convolution and max-pooling operate deterministically. This makes Keras Conv\*D and MaxPool\*D layers operate deterministically in both the forward and backward directions when running on a CUDA-enabled GPU.

## Breaking Changes
* Deletes `Operation.traceback_with_start_lines` for which we know of no usages.

@@ -262,6 +262,7 @@ cc_library(
    ],
    deps = [
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/platform:refcount",
    ],
)

@@ -18,11 +18,12 @@ limitations under the License.
#include <memory>

#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/refcount.h"
namespace tensorflow {

// Abstract interface to a Tensor handle in either tracing or immediate
// execution mode.
class AbstractTensorHandle {
class AbstractTensorHandle : public core::RefCounted {
 protected:
  enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt };
  explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
@@ -34,14 +35,6 @@ class AbstractTensorHandle {

  AbstractTensorHandleKind getKind() const { return kind_; }

  // Release any underlying resources, including the interface object.
  //
  // WARNING: The destructor of this class is marked as protected to disallow
  // clients from directly destroying this object since it may manage it's own
  // lifetime through ref counting. Thus this must be allocated on the heap and
  // clients MUST call Release() in order to destroy an instance of this class.
  virtual void Release() = 0;

 private:
  const AbstractTensorHandleKind kind_;
};
@@ -50,7 +43,7 @@ namespace internal {
struct AbstractTensorHandleDeleter {
  void operator()(AbstractTensorHandle* p) const {
    if (p != nullptr) {
      p->Release();
      p->Unref();
    }
  }
};
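Taken together, these hunks change the ownership contract: `AbstractTensorHandle` is now `core::RefCounted`, so a producer hands back a handle carrying one reference and consumers drop it with `Unref()` rather than calling `Release()`. A minimal caller-side sketch of that contract, assuming the declarations above (the factory function is hypothetical and only for illustration):

#include <memory>

// Hypothetical producer: any API returning an AbstractTensorHandle* is assumed
// to transfer one reference to the caller.
tensorflow::AbstractTensorHandle* MakeHandleSomehow();

void Example() {
  using tensorflow::AbstractTensorHandle;
  using Deleter = tensorflow::internal::AbstractTensorHandleDeleter;
  using HandlePtr = std::unique_ptr<AbstractTensorHandle, Deleter>;

  HandlePtr owner(MakeHandleSomehow());  // holds the initial reference
  AbstractTensorHandle* raw = owner.get();

  raw->Ref();             // a second owner takes its own reference...
  HandlePtr second(raw);  // ...and the deleter above drops it via Unref()
}  // the refcount reaches zero here and the handle destroys itself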
@@ -147,7 +147,7 @@ TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {

void TF_DeleteAbstractOp(TF_AbstractOp* op) { unwrap(op)->Release(); }

void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Release(); }
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Unref(); }

TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }
@@ -49,7 +49,6 @@ class GraphTensor : public TracingTensorHandle {
 public:
  explicit GraphTensor(TF_Output output)
      : TracingTensorHandle(kGraph), output_(output) {}
  void Release() override { delete this; }

  tensorflow::DataType DataType() const override {
    return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));
@@ -51,25 +51,14 @@ int64 ToId(AbstractTensorHandle* t) {

TapeTensor::TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx)
    : handle_(handle), ctx_(ctx) {
  // TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
  // on the client to keep this tensor live for the duration of the gradient
  // computation.
  // handle_->Ref();
  handle_->Ref();
}
TapeTensor::TapeTensor(const TapeTensor& other) {
  handle_ = other.handle_;
  // TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
  // on the client to keep this tensor live for the duration of the gradient
  // computation.
  // handle_->Ref();
  handle_->Ref();
  ctx_ = other.ctx_;
}
TapeTensor::~TapeTensor() {
  // TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
  // on the client to keep this tensor live for the duration of the gradient
  // computation.
  // handle_->Unref();
}
TapeTensor::~TapeTensor() { handle_->Unref(); }

tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); }

@@ -192,7 +181,7 @@ TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}

void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
  gradient->Release();
  gradient->Unref();
}

// Helper functions which delegate to `AbstractOperation`, update
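A short sketch of what the `Ref`/`Unref` pairing above buys: every `TapeTensor`, including copies, now holds its own reference, so the tape no longer depends on the client keeping the handle alive. The caller and namespace visibility are assumed here purely for illustration:

// Assumes the TapeTensor defined above is visible; `handle` is a live handle
// the caller owns and `ctx` is the matching AbstractContext*.
void TapeTensorLifetimeSketch(tensorflow::AbstractTensorHandle* handle,
                              tensorflow::AbstractContext* ctx) {
  TapeTensor first(handle, ctx);  // constructor calls handle->Ref()
  {
    TapeTensor copy = first;      // copy constructor Refs again
  }                               // copy's destructor Unrefs; handle stays alive
  // `first` is destroyed at the end of this scope and drops its reference;
  // the caller's own reference still keeps `handle` valid afterwards.
}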
@@ -93,7 +93,7 @@ Status AddGradModel(AbstractContext* ctx,
                    source_tensors_that_are_targets,
                    /*output_gradients=*/{}, &out_grads));
  for (auto add_output : add_outputs) {
    add_output->Release();
    add_output->Unref();
  }
  outputs[0] = out_grads[0];
  outputs[1] = out_grads[1];
@@ -144,14 +144,14 @@ Status RunModel(Model model, AbstractContext* ctx,
    TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
                             absl::MakeSpan(output_list.outputs), registry));
    for (auto func_input : func_inputs) {
      func_input->Release();
      func_input->Unref();
    }
    AbstractFunction* func = nullptr;
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
                           ->Finalize(&output_list, &func));
    scoped_func.reset(func);
    output_list.outputs[0]->Release();
    output_list.outputs[1]->Release();
    output_list.outputs[0]->Unref();
    output_list.outputs[1]->Unref();
    TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
  }

@@ -252,7 +252,7 @@ TEST_P(CppGradients, TestAddGrad) {
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
  EXPECT_EQ(*result_value, 1.0);
  outputs[0]->Release();
  outputs[0]->Unref();
  TF_DeleteTensor(result_tensor);
  result_tensor = nullptr;

@@ -260,7 +260,7 @@ TEST_P(CppGradients, TestAddGrad) {
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  result_value = static_cast<float*>(TF_TensorData(result_tensor));
  EXPECT_EQ(*result_value, 1.0);
  outputs[1]->Release();
  outputs[1]->Unref();
  TF_DeleteTensor(result_tensor);
}

@@ -270,7 +270,7 @@ TEST_P(CppGradients, TestAddGrad) {
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, CppGradients,
    ::testing::Combine(::testing::Values("graphdef"),
                       /*tfrt*/ ::testing::Values(false),
                       /*tfrt*/ ::testing::Values(true, false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
@@ -50,6 +50,14 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
  // Return a copy of the handle.
  virtual ImmediateExecutionTensorHandle* Copy() = 0;

  // Release any underlying resources, including the interface object.
  //
  // WARNING: The destructor of this class is marked as protected to disallow
  // clients from directly destroying this object since it may manage it's own
  // lifetime through ref counting. Thus this must be allocated on the heap and
  // clients MUST call Release() in order to destroy an instance of this class.
  virtual void Release() = 0;

  // For LLVM style RTTI.
  static bool classof(const AbstractTensorHandle* ptr) {
    return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
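The `classof` hook above is what makes LLVM-style casting work for these handles. A small sketch of branching on the handle kind; it assumes the same unqualified `dyn_cast` helper already used elsewhere in this diff (for example in the gradients test), and the ownership of the value returned by `Copy()` is an assumption:

// Sketch: only eager/TFRT handles pass classof(), so the cast below is safe.
void MaybeUseImmediate(tensorflow::AbstractTensorHandle* h) {
  auto* imm = dyn_cast<tensorflow::ImmediateExecutionTensorHandle>(h);
  if (imm == nullptr) return;  // a graph/MLIR (tracing) handle
  auto* copy = imm->Copy();    // Copy() is declared in the hunk above
  copy->Unref();               // assumed: Copy() returns a handle owning one reference
}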
@@ -448,6 +448,41 @@ GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client)
      stat_cache_max_age, stat_cache_max_entries);
}

GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client, bool compose,
                 uint64_t block_size, size_t max_bytes, uint64_t max_staleness,
                 uint64_t stat_cache_max_age, size_t stat_cache_max_entries)
    : gcs_client(gcs_client),
      compose(compose),
      block_cache_lock(),
      block_size(block_size) {
  file_block_cache = std::make_unique<RamFileBlockCache>(
      block_size, max_bytes, max_staleness,
      [this](const std::string& filename, size_t offset, size_t buffer_size,
             char* buffer, TF_Status* status) {
        return LoadBufferFromGCS(filename, offset, buffer_size, buffer, this,
                                 status);
      });
  stat_cache = std::make_unique<ExpiringLRUCache<GcsFileStat>>(
      stat_cache_max_age, stat_cache_max_entries);
}

void InitTest(TF_Filesystem* filesystem, bool compose, uint64_t block_size,
              size_t max_bytes, uint64_t max_staleness,
              uint64_t stat_cache_max_age, size_t stat_cache_max_entries,
              TF_Status* status) {
  google::cloud::StatusOr<gcs::Client> client =
      gcs::Client::CreateDefaultClient();
  if (!client) {
    TF_SetStatusFromGCSStatus(client.status(), status);
    return;
  }

  filesystem->plugin_filesystem =
      new GCSFile(std::move(client.value()), compose, block_size, max_bytes,
                  max_staleness, stat_cache_max_age, stat_cache_max_entries);
  TF_SetStatus(status, TF_OK, "");
}

void Init(TF_Filesystem* filesystem, TF_Status* status) {
  google::cloud::StatusOr<gcs::Client> client =
      gcs::Client::CreateDefaultClient();
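For context, a minimal sketch of how the new test-only entry point might be called to get a GCS filesystem with a small read-ahead block cache; the numbers are arbitrary, and teardown mirrors the `Cleanup` declared in the header:

// Sketch only: exercises the InitTest() added above.
void InitWithSmallBlockCache(TF_Filesystem* filesystem, TF_Status* status) {
  tf_gcs_filesystem::InitTest(filesystem,
                              /*compose=*/false,
                              /*block_size=*/16,   // bytes per cached block
                              /*max_bytes=*/64,    // total block-cache budget
                              /*max_staleness=*/0,
                              /*stat_cache_max_age=*/0,
                              /*stat_cache_max_entries=*/0, status);
  // On success, the caller eventually tears everything down with
  // tf_gcs_filesystem::Cleanup(filesystem).
}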
@@ -62,8 +62,19 @@ typedef struct GCSFile {
  // of block_size.
  std::unique_ptr<ExpiringLRUCache<GcsFileStat>> stat_cache;
  GCSFile(google::cloud::storage::Client&& gcs_client);
  // This constructor is used for testing purposes only.
  GCSFile(google::cloud::storage::Client&& gcs_client, bool compose,
          uint64_t block_size, size_t max_bytes, uint64_t max_staleness,
          uint64_t stat_cache_max_age, size_t stat_cache_max_entries);
} GCSFile;

// This function is used to initialize a filesystem without the need of setting
// environment variables manually.
void InitTest(TF_Filesystem* filesystem, bool compose, uint64_t block_size,
              size_t max_bytes, uint64_t max_staleness,
              uint64_t stat_cache_max_age, size_t stat_cache_max_entries,
              TF_Status* status);

void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
@@ -66,6 +66,9 @@ static std::string* GetTmpDir() {
namespace tensorflow {
namespace {

// TODO(vnvo2409): Refactor `gcs_filesystem_test` to remove unnecessary tests
// after porting all tests from
// `//tensorflow/core/platform/cloud:gcs_file_system_test`.
class GCSFilesystemTest : public ::testing::Test {
 public:
  void SetUp() override {
@@ -74,13 +77,14 @@ class GCSFilesystemTest : public ::testing::Test {
        ::testing::UnitTest::GetInstance()->current_test_info()->name());
    status_ = TF_NewStatus();
    filesystem_ = new TF_Filesystem;
    tf_gcs_filesystem::Init(filesystem_, status_);
    ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
                          << TF_Message(status_);
    filesystem_->plugin_filesystem = nullptr;
    // Because different tests require different setups for the filesystem, we
    // initialize the filesystem in each test case.
  }
  void TearDown() override {
    TF_DeleteStatus(status_);
    tf_gcs_filesystem::Cleanup(filesystem_);
    if (filesystem_->plugin_filesystem != nullptr)
      tf_gcs_filesystem::Cleanup(filesystem_);
    delete filesystem_;
  }

@@ -117,6 +121,21 @@ class GCSFilesystemTest : public ::testing::Test {
    }
  }

  ::testing::AssertionResult InsertObject(const std::string& path,
                                          const std::string& content,
                                          gcs::Client* gcs_client,
                                          TF_Status* status) {
    std::string bucket, object;
    ParseGCSPath(path, false, &bucket, &object, status);
    if (TF_GetCode(status) != TF_OK)
      return ::testing::AssertionFailure() << TF_Message(status);
    auto metadata = gcs_client->InsertObject(bucket, object, content);
    if (metadata)
      return ::testing::AssertionSuccess();
    else
      return ::testing::AssertionFailure() << metadata.status().message();
  }

  ::testing::AssertionResult CompareSubString(int64_t offset, size_t length,
                                              absl::string_view result,
                                              size_t read) {
@@ -172,6 +191,9 @@ TEST_F(GCSFilesystemTest, ParseGCSPath) {
}

TEST_F(GCSFilesystemTest, RandomAccessFile) {
  tf_gcs_filesystem::Init(filesystem_, status_);
  ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
                        << TF_Message(status_);
  std::string filepath = GetURIForPath("a_file");
  TF_RandomAccessFile* file = new TF_RandomAccessFile;
  tf_gcs_filesystem::NewRandomAccessFile(filesystem_, filepath.c_str(), file,
@@ -208,6 +230,9 @@ TEST_F(GCSFilesystemTest, RandomAccessFile) {
}

TEST_F(GCSFilesystemTest, WritableFile) {
  tf_gcs_filesystem::Init(filesystem_, status_);
  ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
                        << TF_Message(status_);
  std::string filepath = GetURIForPath("a_file");
  TF_WritableFile* file = new TF_WritableFile;
  tf_gcs_filesystem::NewWritableFile(filesystem_, filepath.c_str(), file,
@@ -273,6 +298,9 @@ TEST_F(GCSFilesystemTest, WritableFile) {
}

TEST_F(GCSFilesystemTest, ReadOnlyMemoryRegion) {
  tf_gcs_filesystem::Init(filesystem_, status_);
  ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
                        << TF_Message(status_);
  std::string path = GetURIForPath("a_file");
  auto gcs_file =
      static_cast<tf_gcs_filesystem::GCSFile*>(filesystem_->plugin_filesystem);
@@ -298,6 +326,131 @@ TEST_F(GCSFilesystemTest, ReadOnlyMemoryRegion) {
  delete region;
}

// These tests below are ported from
// `//tensorflow/core/platform/cloud:gcs_file_system_test`
TEST_F(GCSFilesystemTest, NewRandomAccessFile_NoBlockCache) {
  tf_gcs_filesystem::InitTest(filesystem_, false, 0, 0, 0, 0, 0, status_);
  ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
                        << TF_Message(status_);
  std::string path = GetURIForPath("a_file");
  auto gcs_file =
      static_cast<tf_gcs_filesystem::GCSFile*>(filesystem_->plugin_filesystem);
  ASSERT_TRUE(InsertObject(path, "0123456789", &gcs_file->gcs_client, status_));

  TF_RandomAccessFile* file = new TF_RandomAccessFile;
  tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), file,
                                         status_);
  ASSERT_TF_OK(status_);

  std::string result;
  result.resize(6);
  int64_t read = tf_random_access_file::Read(file, 0, 6, &result[0], status_);
  ASSERT_EQ(read, 6) << "Read: " << read << "\n";
  ASSERT_TF_OK(status_);
  ASSERT_EQ(result, "012345") << "Result: " << result << "\n";

  read = tf_random_access_file::Read(file, 6, 6, &result[0], status_);
  ASSERT_EQ(read, 4) << "Read: " << read << "\n";
  ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_);
  result.resize(read);
  ASSERT_EQ(result, "6789") << "Result: " << result << "\n";
}

TEST_F(GCSFilesystemTest, NewRandomAccessFile_Buffered) {
  tf_gcs_filesystem::InitTest(filesystem_, false, 10, 0, 0, 0, 0, status_);
  ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
                        << TF_Message(status_);
  std::string path = GetURIForPath("a_file");
  auto gcs_file =
      static_cast<tf_gcs_filesystem::GCSFile*>(filesystem_->plugin_filesystem);
  ASSERT_TRUE(InsertObject(path, "0123456789", &gcs_file->gcs_client, status_));

  TF_RandomAccessFile* file = new TF_RandomAccessFile;
  tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), file,
                                         status_);
  ASSERT_TF_OK(status_);

  std::string result;
  result.resize(6);
  int64_t read = tf_random_access_file::Read(file, 0, 6, &result[0], status_);
  ASSERT_EQ(read, 6) << "Read: " << read << "\n";
  ASSERT_TF_OK(status_);
  ASSERT_EQ(result, "012345") << "Result: " << result << "\n";

  read = tf_random_access_file::Read(file, 6, 6, &result[0], status_);
  ASSERT_EQ(read, 4) << "Read: " << read << "\n";
  ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_);
  result.resize(read);
  ASSERT_EQ(result, "6789") << "Result: " << result << "\n";
}

TEST_F(GCSFilesystemTest, NewRandomAccessFile_Buffered_ReadAtEOF) {
  tf_gcs_filesystem::InitTest(filesystem_, false, 10, 0, 0, 0, 0, status_);
  ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
                        << TF_Message(status_);
  std::string path = GetURIForPath("a_file");
  auto gcs_file =
      static_cast<tf_gcs_filesystem::GCSFile*>(filesystem_->plugin_filesystem);
  ASSERT_TRUE(InsertObject(path, "0123456789", &gcs_file->gcs_client, status_));

  TF_RandomAccessFile* file = new TF_RandomAccessFile;
  tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), file,
                                         status_);
  ASSERT_TF_OK(status_);

  std::string result;
  result.resize(10);
  int64_t read = tf_random_access_file::Read(file, 0, result.length(),
                                             &result[0], status_);
  ASSERT_EQ(read, 10) << "Read: " << read << "\n";
  ASSERT_TF_OK(status_);
  ASSERT_EQ(result, "0123456789") << "Result: " << result << "\n";

  read = tf_random_access_file::Read(file, result.length(), result.length(),
                                     &result[0], status_);
  ASSERT_EQ(read, 0) << "Read: " << read << "\n";
  ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_);
  result.resize(read);
  ASSERT_EQ(result, "") << "Result: " << result << "\n";
}

TEST_F(GCSFilesystemTest, NewRandomAccessFile_Buffered_CachedOutOfRange) {
  tf_gcs_filesystem::InitTest(filesystem_, false, 10, 0, 0, 0, 0, status_);
  ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
                        << TF_Message(status_);
  std::string path = GetURIForPath("a_file");
  auto gcs_file =
      static_cast<tf_gcs_filesystem::GCSFile*>(filesystem_->plugin_filesystem);
  ASSERT_TRUE(InsertObject(path, "012345678", &gcs_file->gcs_client, status_));

  TF_RandomAccessFile* file = new TF_RandomAccessFile;
  tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), file,
                                         status_);
  ASSERT_TF_OK(status_);

  std::string result;
  result.resize(5);
  int64_t read = tf_random_access_file::Read(file, 0, result.length(),
                                             &result[0], status_);
  ASSERT_EQ(read, 5) << "Read: " << read << "\n";
  ASSERT_TF_OK(status_);
  ASSERT_EQ(result, "01234") << "Result: " << result << "\n";

  read = tf_random_access_file::Read(file, 4, result.length(), &result[0],
                                     status_);
  ASSERT_EQ(read, 5) << "Read: " << read << "\n";
  ASSERT_TF_OK(status_);
  result.resize(read);
  ASSERT_EQ(result, "45678") << "Result: " << result << "\n";

  read = tf_random_access_file::Read(file, 5, result.length(), &result[0],
                                     status_);
  ASSERT_EQ(read, 4) << "Read: " << read << "\n";
  ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_);
  result.resize(read);
  ASSERT_EQ(result, "5678") << "Result: " << result << "\n";
}

}  // namespace
}  // namespace tensorflow
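These tests lean on the plugin convention that a short read returns the number of bytes actually read and sets `TF_OUT_OF_RANGE`. A caller-side sketch of draining a file under that convention, reusing only the helpers exercised above:

// Sketch: read an entire object through tf_random_access_file::Read.
std::string ReadAll(TF_RandomAccessFile* file, TF_Status* status) {
  std::string out;
  char buffer[64];
  uint64_t offset = 0;
  while (true) {
    int64_t n = tf_random_access_file::Read(file, offset, sizeof(buffer),
                                            buffer, status);
    if (n > 0) {
      out.append(buffer, static_cast<size_t>(n));
      offset += static_cast<uint64_t>(n);
    }
    if (TF_GetCode(status) == TF_OUT_OF_RANGE) break;  // reached end of file
    if (TF_GetCode(status) != TF_OK) break;            // real error
    if (n == 0) break;                                 // defensive: no progress
  }
  return out;
}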
@@ -24,6 +24,7 @@ limitations under the License.
#include <aws/s3/model/CompletedPart.h>
#include <aws/s3/model/CopyObjectRequest.h>
#include <aws/s3/model/CreateMultipartUploadRequest.h>
#include <aws/s3/model/DeleteObjectRequest.h>
#include <aws/s3/model/GetObjectRequest.h>
#include <aws/s3/model/HeadBucketRequest.h>
#include <aws/s3/model/HeadObjectRequest.h>
@@ -44,6 +45,7 @@ limitations under the License.
constexpr char kS3FileSystemAllocationTag[] = "S3FileSystemAllocation";
constexpr char kS3ClientAllocationTag[] = "S3ClientAllocation";
constexpr int64_t kS3TimeoutMsec = 300000;  // 5 min
constexpr int kS3GetChildrenMaxKeys = 100;

constexpr char kExecutorTag[] = "TransferManagerExecutorAllocation";
constexpr int kExecutorPoolSize = 25;
@@ -961,7 +963,157 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
              s3_file, status);
}

// TODO(vnvo2409): Implement later
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
                TF_Status* status) {
  Aws::String bucket, object;
  ParseS3Path(path, false, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;
  auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
  GetS3Client(s3_file);

  Aws::S3::Model::DeleteObjectRequest delete_object_request;
  delete_object_request.WithBucket(bucket).WithKey(object);
  auto delete_object_outcome =
      s3_file->s3_client->DeleteObject(delete_object_request);
  if (!delete_object_outcome.IsSuccess())
    TF_SetStatusFromAWSError(delete_object_outcome.GetError(), status);
  else
    TF_SetStatus(status, TF_OK, "");
}

void CreateDir(const TF_Filesystem* filesystem, const char* path,
               TF_Status* status) {
  Aws::String bucket, object;
  ParseS3Path(path, true, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;
  auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
  GetS3Client(s3_file);

  if (object.empty()) {
    Aws::S3::Model::HeadBucketRequest head_bucket_request;
    head_bucket_request.WithBucket(bucket);
    auto head_bucket_outcome =
        s3_file->s3_client->HeadBucket(head_bucket_request);
    if (!head_bucket_outcome.IsSuccess())
      TF_SetStatusFromAWSError(head_bucket_outcome.GetError(), status);
    else
      TF_SetStatus(status, TF_OK, "");
    return;
  }

  Aws::String dir_path = path;
  if (dir_path.back() != '/') dir_path.push_back('/');

  PathExists(filesystem, dir_path.c_str(), status);
  if (TF_GetCode(status) == TF_OK) {
    std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile * file)> file(
        new TF_WritableFile, [](TF_WritableFile* file) {
          if (file != nullptr) {
            if (file->plugin_file != nullptr) tf_writable_file::Cleanup(file);
            delete file;
          }
        });
    file->plugin_file = nullptr;
    NewWritableFile(filesystem, dir_path.c_str(), file.get(), status);
    if (TF_GetCode(status) != TF_OK) return;
    tf_writable_file::Close(file.get(), status);
    if (TF_GetCode(status) != TF_OK) return;
  }
  TF_SetStatus(status, TF_OK, "");
}

void DeleteDir(const TF_Filesystem* filesystem, const char* path,
               TF_Status* status) {
  Aws::String bucket, object;
  ParseS3Path(path, false, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;
  auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
  GetS3Client(s3_file);

  if (object.back() != '/') object.push_back('/');
  Aws::S3::Model::ListObjectsRequest list_objects_request;
  list_objects_request.WithBucket(bucket).WithPrefix(object).WithMaxKeys(2);
  list_objects_request.SetResponseStreamFactory(
      []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
  auto list_objects_outcome =
      s3_file->s3_client->ListObjects(list_objects_request);
  if (list_objects_outcome.IsSuccess()) {
    auto contents = list_objects_outcome.GetResult().GetContents();
    if (contents.size() > 1 ||
        (contents.size() == 1 && contents[0].GetKey() != object)) {
      TF_SetStatus(status, TF_UNKNOWN,
                   "Cannot delete a non-empty directory. "
                   "This operation will be retried in case this "
                   "is due to S3's eventual consistency.");
    }
    if (contents.size() == 1 && contents[0].GetKey() == object) {
      Aws::String dir_path = path;
      if (dir_path.back() != '/') dir_path.push_back('/');
      DeleteFile(filesystem, dir_path.c_str(), status);
    }
  } else {
    TF_SetStatusFromAWSError(list_objects_outcome.GetError(), status);
  }
}

int GetChildren(const TF_Filesystem* filesystem, const char* path,
                char*** entries, TF_Status* status) {
  Aws::String bucket, prefix;
  ParseS3Path(path, true, &bucket, &prefix, status);
  if (TF_GetCode(status) != TF_OK) return -1;
  if (!prefix.empty() && prefix.back() != '/') prefix.push_back('/');

  auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
  GetS3Client(s3_file);

  Aws::S3::Model::ListObjectsRequest list_objects_request;
  list_objects_request.WithBucket(bucket)
      .WithPrefix(prefix)
      .WithMaxKeys(kS3GetChildrenMaxKeys)
      .WithDelimiter("/");
  list_objects_request.SetResponseStreamFactory(
      []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });

  Aws::S3::Model::ListObjectsResult list_objects_result;
  std::vector<Aws::String> result;
  do {
    auto list_objects_outcome =
        s3_file->s3_client->ListObjects(list_objects_request);
    if (!list_objects_outcome.IsSuccess()) {
      TF_SetStatusFromAWSError(list_objects_outcome.GetError(), status);
      return -1;
    }

    list_objects_result = list_objects_outcome.GetResult();
    for (const auto& object : list_objects_result.GetCommonPrefixes()) {
      Aws::String s = object.GetPrefix();
      s.erase(s.length() - 1);
      Aws::String entry = s.substr(prefix.length());
      if (entry.length() > 0) {
        result.push_back(entry);
      }
    }
    for (const auto& object : list_objects_result.GetContents()) {
      Aws::String s = object.GetKey();
      Aws::String entry = s.substr(prefix.length());
      if (entry.length() > 0) {
        result.push_back(entry);
      }
    }
    list_objects_result.SetMarker(list_objects_result.GetNextMarker());
  } while (list_objects_result.GetIsTruncated());

  int num_entries = result.size();
  *entries = static_cast<char**>(
      plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
  for (int i = 0; i < num_entries; i++)
    (*entries)[i] = strdup(result[i].c_str());
  TF_SetStatus(status, TF_OK, "");
}

static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
  return strdup(uri);
}

}  // namespace tf_s3_filesystem

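A rough sketch of the caller side of `GetChildren`: the entry array comes from `plugin_memory_allocate` and each entry from `strdup`, so the consumer is expected to release both. Two assumptions are made here, namely that a `plugin_memory_free` counterpart exists and that the `int` return value carries the entry count; in the modular filesystem API the TensorFlow core normally does this bookkeeping itself.

// Sketch only: list a "directory" through the S3 plugin ops defined above.
void PrintChildren(const TF_Filesystem* fs, const char* path,
                   TF_Status* status) {
  char** entries = nullptr;
  int n = tf_s3_filesystem::GetChildren(fs, path, &entries, status);
  if (TF_GetCode(status) != TF_OK) return;
  for (int i = 0; i < n; ++i) {
    printf("%s\n", entries[i]);
    plugin_memory_free(entries[i]);  // assumed counterpart of strdup above
  }
  plugin_memory_free(entries);  // assumed counterpart of plugin_memory_allocate
}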
@@ -969,6 +1121,48 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
                                        const char* uri) {
  TF_SetFilesystemVersionMetadata(ops);
  ops->scheme = strdup(uri);

  ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
      plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
  ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
  ops->random_access_file_ops->read = tf_random_access_file::Read;

  ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
      plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
  ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
  ops->writable_file_ops->append = tf_writable_file::Append;
  ops->writable_file_ops->tell = tf_writable_file::Tell;
  ops->writable_file_ops->flush = tf_writable_file::Flush;
  ops->writable_file_ops->sync = tf_writable_file::Sync;
  ops->writable_file_ops->close = tf_writable_file::Close;

  ops->read_only_memory_region_ops = static_cast<TF_ReadOnlyMemoryRegionOps*>(
      plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
  ops->read_only_memory_region_ops->cleanup =
      tf_read_only_memory_region::Cleanup;
  ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data;
  ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length;

  ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
      plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
  ops->filesystem_ops->init = tf_s3_filesystem::Init;
  ops->filesystem_ops->cleanup = tf_s3_filesystem::Cleanup;
  ops->filesystem_ops->new_random_access_file =
      tf_s3_filesystem::NewRandomAccessFile;
  ops->filesystem_ops->new_writable_file = tf_s3_filesystem::NewWritableFile;
  ops->filesystem_ops->new_appendable_file =
      tf_s3_filesystem::NewAppendableFile;
  ops->filesystem_ops->new_read_only_memory_region_from_file =
      tf_s3_filesystem::NewReadOnlyMemoryRegionFromFile;
  ops->filesystem_ops->create_dir = tf_s3_filesystem::CreateDir;
  ops->filesystem_ops->delete_file = tf_s3_filesystem::DeleteFile;
  ops->filesystem_ops->delete_dir = tf_s3_filesystem::DeleteDir;
  ops->filesystem_ops->copy_file = tf_s3_filesystem::CopyFile;
  ops->filesystem_ops->path_exists = tf_s3_filesystem::PathExists;
  ops->filesystem_ops->get_file_size = tf_s3_filesystem::GetFileSize;
  ops->filesystem_ops->stat = tf_s3_filesystem::Stat;
  ops->filesystem_ops->get_children = tf_s3_filesystem::GetChildren;
  ops->filesystem_ops->translate_name = tf_s3_filesystem::TranslateName;
}

void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
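The hunk is cut off at `TF_InitPlugin`. As a rough, non-authoritative sketch, filesystem plugins like this one typically finish registration along the following lines; the `TF_FilesystemPluginInfo` field names are assumptions based on the plugin interface, not part of this diff:

// Sketch only, not the body from this commit.
void TF_InitPluginSketch(TF_FilesystemPluginInfo* info) {
  info->plugin_memory_allocate = plugin_memory_allocate;  // assumed fields
  info->plugin_memory_free = plugin_memory_free;
  info->num_schemes = 1;
  info->ops = static_cast<TF_FilesystemPluginOps*>(
      plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
  ProvideFilesystemSupportFor(&info->ops[0], "s3");
}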
@@ -104,6 +104,12 @@ TF_ShapeHandle* TF_NewShapeHandle() {
  return reinterpret_cast<TF_ShapeHandle*>(new ShapeHandle);
}

TF_ShapeHandle* TF_ShapeInferenceContextScalar(TF_ShapeInferenceContext* ctx) {
  auto* handle = new ShapeHandle;
  *handle = reinterpret_cast<InferenceContext*>(ctx)->Scalar();
  return reinterpret_cast<TF_ShapeHandle*>(handle);
}

TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
    TF_ShapeInferenceContext* ctx, size_t size) {
  auto* handle = new ShapeHandle;
@@ -280,6 +280,11 @@ extern void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx,
                                              int i, TF_ShapeHandle* handle,
                                              TF_Status* status);

// Returns a newly-allocated scalar shape handle. The returned handle should
// be freed with TF_DeleteShapeHandle.
TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextScalar(
    TF_ShapeInferenceContext* ctx);

// Returns a newly-allocated shape handle representing a vector of the given
// size. The returned handle should be freed with TF_DeleteShapeHandle.
TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
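A short sketch of how the new helper is meant to be used inside a C-API shape-inference function; only calls visible in this header are used, and how the function gets registered with an op is out of scope here:

// Sketch: declare the op's single output as a scalar.
static void ScalarOutputShapeFn(TF_ShapeInferenceContext* ctx,
                                TF_Status* status) {
  TF_ShapeHandle* scalar = TF_ShapeInferenceContextScalar(ctx);
  TF_ShapeInferenceContextSetOutput(ctx, /*i=*/0, scalar, status);
  TF_DeleteShapeHandle(scalar);  // handles returned by this API are caller-owned
}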
@@ -316,5 +316,16 @@ TEST(OpsTest, ShapeInferenceSubshape) {
  TF_DeleteShapeHandle(handle);
}

TEST(OpsTest, ShapeInferenceScalarShape) {
  NodeDef def;
  shape_inference::InferenceContext c(0, def, MakeOpDef(0, 0), {S({})}, {}, {},
                                      {});
  TF_ShapeHandle* TF_scalar_shape = TF_ShapeInferenceContextScalar(C_CTX(&c));
  shape_inference::ShapeHandle* scalar_shape =
      reinterpret_cast<shape_inference::ShapeHandle*>(TF_scalar_shape);
  ASSERT_EQ("[]", c.DebugString(*scalar_shape));
  TF_DeleteShapeHandle(TF_scalar_shape);
}

}  // namespace
}  // namespace tensorflow
@@ -52,7 +52,7 @@ def _run_lit_test(name, data, size, tags, driver, features, exec_properties):
    native.py_test(
        name = name,
        srcs = ["@llvm-project//llvm:lit"],
        tags = tags + ["no_windows"],
        tags = tags + ["no_pip", "no_windows"],
        args = [
            "tensorflow/compiler/mlir/" + paths.basename(data[-1]) + " --config-prefix=runlit -v",
        ] + features,
@@ -140,7 +140,7 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
      Here the effective workload shape roughly represents the maximum
      parallelism can be used during the codegen stage. It's used to check
      the shape-compatibility of the operation. During fusion, we only
      try to fuse shape-compatible ops for performace.
      try to fuse shape-compatible ops for performance.
      For example, the effective workload shape of an elementwise op is its
      output shape, while the effective workload shape of a reduction op may
      be its operand shape.
@@ -640,25 +640,25 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
  }
};

class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> {
template <typename OpTy, bool isLHLO = true>
class IotaConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<lmhlo::IotaOp>::OpConversionPattern;
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::IotaOp iotaOp, ArrayRef<Value> args,
      OpTy iotaOp, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto resultMemrefType =
        iotaOp.getOperand().getType().dyn_cast<MemRefType>();
    if (!resultMemrefType) return failure();
    ShapedType resultShapedType = getHloOpResultType<isLHLO>(iotaOp);
    if (!resultShapedType) return failure();

    auto resultElementType = resultMemrefType.getElementType();
    auto resultElementType = resultShapedType.getElementType();
    if (!resultElementType.isSignlessIntOrFloat()) return failure();

    // Construct the indexing maps needed for linalg.generic ops.
    unsigned nloops = resultMemrefType.getRank();
    unsigned nloops = resultShapedType.getRank();

    rewriter.create<linalg::IndexedGenericOp>(
        iotaOp.getLoc(), ArrayRef<Type>{}, args,
    auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
        iotaOp.getLoc(), isLHLO ? ArrayRef<Type>{} : resultShapedType, args,
        0,  // args_in
        1,  // args_out
        llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
@@ -669,14 +669,16 @@ class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> {
          nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()],
          nestedBuilder.getIntegerType(
              resultElementType.getIntOrFloatBitWidth()));
      if (resultElementType.isa<FloatType>()) {
      if (resultElementType.template isa<FloatType>()) {
        castOp = nestedBuilder.create<SIToFPOp>(nestedLoc, castOp,
                                                resultElementType);
      }
      nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
    });

    rewriter.replaceOp(iotaOp, llvm::None);
    if (isLHLO)
      rewriter.replaceOp(iotaOp, llvm::None);
    else
      rewriter.replaceOp(iotaOp, linalgOp.output_tensors());
    return success();
  }
};
@@ -768,7 +770,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
  patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
                   ConstConverter,
                   ConvToLinalgConverter,
                   IotaConverter,
                   IotaConverter<lmhlo::IotaOp>,
                   LhloBroadcastInDimConverter,
                   PointwiseToLinalgConverter<lmhlo::AbsOp>,
                   PointwiseToLinalgConverter<lmhlo::AddOp>,
@@ -870,36 +872,37 @@ namespace mhlo {

void populateHLOToLinalgConversionPattern(MLIRContext* context,
                                          OwningRewritePatternList* patterns) {
  patterns->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
                   HloBroadcastInDimConverter,
                   PointwiseToLinalgConverter<mhlo::AbsOp, false>,
                   PointwiseToLinalgConverter<mhlo::AddOp, false>,
                   PointwiseToLinalgConverter<mhlo::AndOp, false>,
                   PointwiseToLinalgConverter<mhlo::CeilOp, false>,
                   PointwiseToLinalgConverter<mhlo::CompareOp, false>,
                   PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
                   PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
                   PointwiseToLinalgConverter<mhlo::CopyOp, false>,
                   PointwiseToLinalgConverter<mhlo::CosOp, false>,
                   PointwiseToLinalgConverter<mhlo::DivOp, false>,
                   PointwiseToLinalgConverter<mhlo::ExpOp, false>,
                   PointwiseToLinalgConverter<mhlo::ImagOp, false>,
                   PointwiseToLinalgConverter<mhlo::LogOp, false>,
                   PointwiseToLinalgConverter<mhlo::MaxOp, false>,
                   PointwiseToLinalgConverter<mhlo::MinOp, false>,
                   PointwiseToLinalgConverter<mhlo::MulOp, false>,
                   PointwiseToLinalgConverter<mhlo::NegOp, false>,
                   PointwiseToLinalgConverter<mhlo::RealOp, false>,
                   PointwiseToLinalgConverter<mhlo::RemOp, false>,
                   PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
                   PointwiseToLinalgConverter<mhlo::SelectOp, false>,
                   PointwiseToLinalgConverter<mhlo::SinOp, false>,
                   PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
                   PointwiseToLinalgConverter<mhlo::SubOp, false>,
                   PointwiseToLinalgConverter<mhlo::TanhOp, false>,
                   ReshapeOpConverter<mhlo::ReshapeOp, false>,
                   ReverseConverter<mhlo::ReverseOp, false>,
                   TransposeConverter<mhlo::TransposeOp, false>>(context);
  patterns
      ->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
               HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
               PointwiseToLinalgConverter<mhlo::AbsOp, false>,
               PointwiseToLinalgConverter<mhlo::AddOp, false>,
               PointwiseToLinalgConverter<mhlo::AndOp, false>,
               PointwiseToLinalgConverter<mhlo::CeilOp, false>,
               PointwiseToLinalgConverter<mhlo::CompareOp, false>,
               PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
               PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
               PointwiseToLinalgConverter<mhlo::CopyOp, false>,
               PointwiseToLinalgConverter<mhlo::CosOp, false>,
               PointwiseToLinalgConverter<mhlo::DivOp, false>,
               PointwiseToLinalgConverter<mhlo::ExpOp, false>,
               PointwiseToLinalgConverter<mhlo::ImagOp, false>,
               PointwiseToLinalgConverter<mhlo::LogOp, false>,
               PointwiseToLinalgConverter<mhlo::MaxOp, false>,
               PointwiseToLinalgConverter<mhlo::MinOp, false>,
               PointwiseToLinalgConverter<mhlo::MulOp, false>,
               PointwiseToLinalgConverter<mhlo::NegOp, false>,
               PointwiseToLinalgConverter<mhlo::RealOp, false>,
               PointwiseToLinalgConverter<mhlo::RemOp, false>,
               PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
               PointwiseToLinalgConverter<mhlo::SelectOp, false>,
               PointwiseToLinalgConverter<mhlo::SinOp, false>,
               PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
               PointwiseToLinalgConverter<mhlo::SubOp, false>,
               PointwiseToLinalgConverter<mhlo::TanhOp, false>,
               ReshapeOpConverter<mhlo::ReshapeOp, false>,
               ReverseConverter<mhlo::ReverseOp, false>,
               TransposeConverter<mhlo::TransposeOp, false>>(context);
}

std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
@@ -89,12 +89,10 @@ def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs),
// Absolute value is evaluated as:
//   result = sqrt(val.real * val.real + val.imag * val.imag)
def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val),
          (HLO_ComplexOp
            (HLO_SqrtOp
              (HLO_AddOp
                (HLO_MulOp (HLO_RealOp:$real $val), $real),
                (HLO_MulOp (HLO_ImagOp:$imag $val), $imag))),
            (HLO_ConstOp (ConstantSplat<"0"> $real)))>;
                (HLO_MulOp (HLO_ImagOp:$imag $val), $imag)))>;

// Exponential can be lowered to an exponential on the real component and a
// sum of sinusoids of the imaginary component, which equates to a normal
@@ -557,3 +557,18 @@ func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
}
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]

// -----

// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @iota
func @iota() -> tensor<7x10xf32> {
  %result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xf32>)
  return %result : tensor<7x10xf32>
}
// CHECK: linalg.indexed_generic
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index):
// CHECK-NEXT:   %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
// CHECK-NEXT:   %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
// CHECK-NEXT:   linalg.yield %[[FLOAT_CAST]] : f32
@ -182,11 +182,10 @@ func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
|
||||
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg1
|
||||
// CHECK-DAG: [[VAL2:%.+]] = mhlo.add [[VAL0]], [[VAL1]]
|
||||
// CHECK-DAG: [[VAL3:%.+]] = "mhlo.sqrt"([[VAL2]])
|
||||
%1 = "mhlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
|
||||
%2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
%1 = "mhlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
|
||||
|
||||
// CHECK: return [[VAL3]]
|
||||
return %2 : tensor<2xf32>
|
||||
return %1 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @exp
|
||||
|
@ -237,28 +237,6 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "constant_utils",
|
||||
srcs = [
|
||||
"utils/constant_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"utils/constant_utils.h",
|
||||
],
|
||||
copts = ["-std=c++14"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:status",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lstm_utils",
|
||||
srcs = [
|
||||
@ -368,20 +346,23 @@ cc_library(
|
||||
"transforms/passes.h",
|
||||
],
|
||||
deps = [
|
||||
":constant_utils",
|
||||
":lstm_utils",
|
||||
":stateful_ops_utils",
|
||||
":tensorflow_lite",
|
||||
":tftext_utils",
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
|
||||
"//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -943,7 +943,10 @@ def TFL_BatchMatMulOp : TFL_Op<"batch_matmul", [
|
||||
NoSideEffect,
|
||||
TFL_OperandHasAtleastRank<0, 2>,
|
||||
TFL_OperandHasAtleastRank<1, 2>,
|
||||
SameOperandsAndResultElementType]> {
|
||||
PredOpTrait<"x and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
PredOpTrait<"y and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
|
||||
|
||||
let summary = "Batch Matrix Multiply Operator";
|
||||
|
||||
@ -4345,55 +4348,4 @@ def TFL_CustomTfOp : Op<TFL_Dialect, "custom_tf", [
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
}
|
||||
|
||||
def TFL_BroadcastToOp : TFL_Op<"broadcast_to", [
    PredOpTrait<"input and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    TFL_OperandHasRankAtMost<0, 8>,
    TFL_OperandHasRank<1, 1>,
    PredOpTrait<"output dimension count must be at most 8",
      Or<[TFL_OperandIsUnrankedPred<1>,
          TFL_OperandDimIsAtMost<1, 0, 8>]>>,
    NoSideEffect]> {
  let summary = "Broadcast an array for a compatible shape.";

  let description = [{
Broadcasting is the process of making arrays have compatible shapes
for arithmetic operations. Two shapes are compatible if, for each
dimension pair, they are either equal or one of them is one. When trying
to broadcast a Tensor to a shape, it starts with the trailing dimensions
and works its way forward.

For example,

>>> x = tf.constant([1, 2, 3])
>>> y = tf.broadcast_to(x, [3, 3])
>>> print(y)
tf.Tensor(
    [[1 2 3]
     [1 2 3]
     [1 2 3]], shape=(3, 3), dtype=int32)

In the above example, the input Tensor with the shape of `[1, 3]`
is broadcast to an output Tensor with the shape of `[3, 3]`.

When doing broadcasted operations such as multiplying a tensor
by a scalar, broadcasting (usually) confers some time or space
benefit, as the broadcasted tensor is never materialized.

However, `broadcast_to` does not carry with it any such benefits.
The newly-created tensor takes the full memory of the broadcasted
shape. (In a graph context, `broadcast_to` might be fused into a
subsequent operation and then be optimized away, however.)
  }];

  let arguments = (ins
    TFL_TensorOf<[F32, I32, I1, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex<F<32>>]>:$input,
    TFL_I32OrI64Tensor:$shape
  );

  let results = (outs
    TFL_TensorOf<[F32, I32, I1, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex<F<32>>]>:$output
  );
}

#endif // TFL_OPS
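// Editorial usage sketch (not part of this change; operand shapes are made up):
// in TFLite MLIR the new op takes a value tensor and a 1-D shape tensor, e.g.
//   %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
// matching the ops.mlir tests added later in this diff.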
|
||||
|
@ -5,6 +5,8 @@ func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> tensor<3
|
||||
return %0: tensor<3x3xbf16>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_bf16
|
||||
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16>
|
||||
// CHECK: return [[BCT]] : tensor<3x3xbf16>
|
||||
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<bf16>
|
||||
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi64>, tensor<bf16>) -> tensor<3x3xbf16>
|
||||
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
|
||||
// CHECK: return [[MUL]] : tensor<3x3xbf16>
|
||||
}
|
||||
|
@ -1487,8 +1487,10 @@ func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3
|
||||
return %0: tensor<3x3xf32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_f32
|
||||
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
|
||||
// CHECK: return [[BCT]] : tensor<3x3xf32>
|
||||
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<f32>) -> tensor<3x3xf32>
|
||||
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
|
||||
// CHECK: return [[MUL]] : tensor<3x3xf32>
|
||||
}
|
||||
|
||||
func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
|
||||
@ -1496,8 +1498,10 @@ func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3
|
||||
return %0: tensor<3x3xi32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_i32
|
||||
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
|
||||
// CHECK: return [[BCT]] : tensor<3x3xi32>
|
||||
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
|
||||
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<3x3xi32>
|
||||
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
|
||||
// CHECK: return [[MUL]] : tensor<3x3xi32>
|
||||
}
|
||||
|
||||
func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> {
|
||||
|
@ -1248,6 +1248,13 @@ func @testSpaceToBatchND(%arg0 : tensor<1x4x4x3xf32>, %arg1 : tensor<2xi32>, %ar
|
||||
|
||||
// -----
|
||||
|
||||
func @testBatchMatmulQuant(%arg0 : tensor<1x4x384x32x!quant.uniform<i8:f32, 0.06:-2>>, %arg1 : tensor<1x4x384x32x!quant.uniform<i8:f32, 0.11:-16>>) -> tensor<1x4x384x384x!quant.uniform<i8:f32, 1.02:-73>> {
|
||||
// CHECK: "tfl.batch_matmul"(%arg0, %arg1)
|
||||
%0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor<1x4x384x32x!quant.uniform<i8:f32, 0.06:-2>>, tensor<1x4x384x32x!quant.uniform<i8:f32, 0.11:-16>>) -> tensor<1x4x384x384x!quant.uniform<i8:f32, 1.02:-73>>
|
||||
return %0 : tensor<1x4x384x384x!quant.uniform<i8:f32, 1.02:-73>>
|
||||
}
|
||||
// -----
|
||||
|
||||
func @testConcat(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"}
|
||||
%0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>
|
||||
@ -2310,21 +2317,3 @@ func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<i32> {
|
||||
}) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>)
|
||||
return %0#0 : tensor<i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testBroadcastToWithI32ShapeTensor
|
||||
func @testBroadcastToWithI32ShapeTensor(tensor<?x?x?x?x?x?xf32>, tensor<8xi32>) -> tensor<?x?x?x?x?x?x?x?xf32> {
|
||||
^bb0(%arg0: tensor<?x?x?x?x?x?xf32>, %arg1: tensor<8xi32>):
|
||||
// CHECK: "tfl.broadcast_to"(%arg0, %arg1)
|
||||
%0 = "tfl.broadcast_to"(%arg0, %arg1): (tensor<?x?x?x?x?x?xf32>, tensor<8xi32>) -> tensor<?x?x?x?x?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?x?x?x?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testBroadcastToWithI64ShapeTensor
|
||||
func @testBroadcastToWithI64ShapeTensor(tensor<?x?x?x?x?x?xf32>, tensor<8xi64>) -> tensor<?x?x?x?x?x?x?x?xf32> {
|
||||
^bb0(%arg0: tensor<?x?x?x?x?x?xf32>, %arg1: tensor<8xi64>):
|
||||
// CHECK: "tfl.broadcast_to"(%arg0, %arg1)
|
||||
%0 = "tfl.broadcast_to"(%arg0, %arg1): (tensor<?x?x?x?x?x?xf32>, tensor<8xi64>) -> tensor<?x?x?x?x?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?x?x?x?x?xf32>
|
||||
}
|
||||
|
@ -400,6 +400,32 @@ func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<
|
||||
// FOLD: return %[[fc]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @FuseFullyConnectedReshapeAddConstWithActivation
|
||||
// FOLD-LABEL: @FuseFullyConnectedReshapeAddConstWithActivation
|
||||
func @FuseFullyConnectedReshapeAddConstWithActivation(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
%cst = constant dense<3.0> : tensor<40x40xf32>
|
||||
%cst2 = constant dense<2.0> : tensor<40xf32>
|
||||
%shape1 = constant dense<[1, 40, 40]> : tensor<3xi32>
|
||||
%shape2 = constant dense<[40, 40]> : tensor<2xi32>
|
||||
|
||||
%0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>)
|
||||
%1 = "tfl.reshape"(%0, %shape1) : (tensor<40x40xf32>, tensor<3xi32>) -> tensor<1x40x40xf32>
|
||||
%2 = "tfl.add"(%1, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x40x40xf32>, tensor<40xf32>) -> tensor<1x40x40xf32>
|
||||
%3 = "tfl.reshape"(%2, %shape2) : (tensor<1x40x40xf32>, tensor<2xi32>) -> tensor<40x40xf32>
|
||||
|
||||
return %3 : tensor<40x40xf32>
|
||||
|
||||
// CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
|
||||
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
|
||||
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%[[fc]]
|
||||
// CHECK: %[[rs2:.*]] = "tfl.reshape"(%[[rs1]]
|
||||
// CHECK: return %[[rs2]]
|
||||
|
||||
// FOLD: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
|
||||
// FOLD: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
|
||||
// FOLD: return %[[fc]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @NotReorderReshapeAddIfNotBroadcastableAfter
|
||||
func @NotReorderReshapeAddIfNotBroadcastableAfter(%arg0: tensor<40x10x4xf32>) -> tensor<40x40xf32> {
|
||||
%cst = constant dense<2.0> : tensor<40xf32>
|
||||
|
@ -1,5 +1,7 @@
|
||||
// RUN: tf-opt -tfl-prepare-tf %s | FileCheck %s
|
||||
|
||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
|
||||
|
||||
func @conv(tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<256x3x32x32xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x16x30x30xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
|
||||
^bb0(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>, %arg2: tensor<256x3x32x32xf32>) :
|
||||
// OK
|
||||
@ -579,69 +581,18 @@ func @MatrixSetDiagV3Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) ->
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
||||
func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
|
||||
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
|
||||
return %0: tensor<3x3xf32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_f32_low_dim
|
||||
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
|
||||
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
|
||||
// CHECK: return [[MUL]] : tensor<3x3xf32>
|
||||
// CHECK-LABEL: xla_conv
|
||||
func @xla_conv(%arg0: tensor<4x8x8x16xf32>) -> tensor<4x8x8x16xf32> {
|
||||
%0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<3x3x16x16xf32>} : () -> tensor<3x3x16x16xf32> loc("Const_1")
|
||||
%1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> loc("XlaConv/feature_group_count")
|
||||
%2 = "tf.Const"() {value = dense<1> : tensor<2x2xi32>} : () -> tensor<2x2xi32> loc("XlaConv/padding")
|
||||
%3 = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32> loc("XlaConv/window_strides")
|
||||
%4 = "tf.XlaConv"(%arg0, %0, %3, %2, %3, %3, %1) {device = "", dimension_numbers = "\18\02 \032\02\00\01@\03P\03Z\02\01\02b\02\01\02", precision_config = ""} : (tensor<4x8x8x16xf32>, tensor<3x3x16x16xf32>, tensor<2xi32>, tensor<2x2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<i32>) -> tensor<4x8x8x16xf32>
|
||||
return %4 : tensor<4x8x8x16xf32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<16xf32>
|
||||
// CHECK: %[[CST0:.*]] = constant dense<1.000000e+00> : tensor<16x3x3x16xf32>
|
||||
// CHECK: %[[RES:.*]] = "tfl.conv_2d"(%arg0, %[[CST0]], %[[CST]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<4x8x8x16xf32>, tensor<16x3x3x16xf32>, tensor<16xf32>) -> tensor<4x8x8x16xf32>
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
||||
func @broadcast_to_i32_low_dim(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
|
||||
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
|
||||
return %0: tensor<3x3xi32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_i32_low_dim
|
||||
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<3x3xi32>
|
||||
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
|
||||
// CHECK: return [[MUL]] : tensor<3x3xi32>
|
||||
}
|
||||
|
||||
func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: tensor<*xi32>) -> tensor<3x3xf32> {
|
||||
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<*xi32>) -> tensor<3x3xf32>
|
||||
return %0: tensor<3x3xf32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_low_dim_with_unknown_shape
|
||||
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
|
||||
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
|
||||
// CHECK: return [[MUL]] : tensor<3x3xf32>
|
||||
}
|
||||
|
||||
func @broadcast_to_i32_low_dim_with_unknown_output(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<*xi32> {
|
||||
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<*xi32>
|
||||
return %0: tensor<*xi32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_i32_low_dim_with_unknown_output
|
||||
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
|
||||
// CHECK: [[FILL:%.*]] = "tf.Fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[FILL]]) : (tensor<3xi32>, tensor<*xi32>) -> tensor<*xi32>
|
||||
// CHECK: return [[MUL]] : tensor<*xi32>
|
||||
}
|
||||
|
||||
func @broadcast_to_high_dim_with_unknown_shape(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32> {
|
||||
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32>
|
||||
return %0: tensor<7x8x1x2x3x4x5x6xf32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_high_dim_with_unknown_shape
|
||||
// CHECK: [[BCT:%.*]] = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32>
|
||||
// CHECK: return [[BCT]] : tensor<7x8x1x2x3x4x5x6xf32>
|
||||
}
|
||||
|
||||
func @broadcast_to_high_dim_with_unknown_output(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<8xi32>) -> tensor<*xf32> {
|
||||
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<8xi32>) -> tensor<*xf32>
|
||||
return %0: tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_high_dim_with_unknown_output
|
||||
// CHECK: [[BCT:%.*]] = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<8xi32>) -> tensor<*xf32>
|
||||
// CHECK: return [[BCT]] : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @broadcast_to_with_unknown_shape_and_output(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> {
|
||||
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<*xf32>
|
||||
return %0: tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_with_unknown_shape_and_output
|
||||
// CHECK: "tf.BroadcastTo"(%arg0, %arg1)
|
||||
}
|
||||
|
@ -180,6 +180,9 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
// control flow ops (IfOp, CaseOp).
|
||||
pass_manager->addPass(mlir::createInlinerPass());
|
||||
|
||||
// This pass removes the asset file dependencies in hash table use cases.
|
||||
pass_manager->addPass(mlir::TF::CreateInitTextFileToImportPass());
|
||||
|
||||
pass_manager->addPass(
|
||||
mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
|
||||
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
|
||||
|
@ -109,9 +109,6 @@ def LegalizeArgMax : Pat<(TF_ArgMaxOp $input, $dim),
|
||||
def LegalizeArgMin : Pat<(TF_ArgMinOp $input, $dim),
|
||||
(TFL_ArgMinOp $input, $dim)>;
|
||||
|
||||
def LegalizeBroadcastTo : Pat<(TF_BroadcastToOp $input, $dim),
|
||||
(TFL_BroadcastToOp $input, $dim)>;
|
||||
|
||||
def LegalizeCeil : Pat<(TF_CeilOp $arg), (TFL_CeilOp $arg)>;
|
||||
|
||||
def LegalizeCos : Pat<(TF_CosOp $arg), (TFL_CosOp $arg)>;
|
||||
|
@ -45,7 +45,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
|
||||
@ -138,6 +137,7 @@ DECL_CONVERT_OP(StridedSlice);
|
||||
DECL_CONVERT_OP(Unpack);
|
||||
DECL_CONVERT_OP(Reciprocal);
|
||||
DECL_CONVERT_OP(RandomUniform);
|
||||
DECL_CONVERT_OP(BroadcastTo);
|
||||
|
||||
#undef DECL_CONVERT_OP
|
||||
|
||||
@ -483,6 +483,89 @@ LogicalResult ConvertTFAssertOp::matchAndRewrite(
|
||||
return success();
|
||||
}
|
||||
|
||||
StatusOr<ConstantOp> CreateConstOpWithSingleValue(PatternRewriter* rewriter,
|
||||
Location loc,
|
||||
ShapedType shaped_type,
|
||||
int value) {
|
||||
Type element_type = shaped_type.getElementType();
|
||||
ShapedType scalar_type = RankedTensorType::get({}, element_type);
|
||||
Attribute attr;
|
||||
switch (element_type.getKind()) {
|
||||
case mlir::StandardTypes::F16: {
|
||||
auto floatType = mlir::FloatType::getF16(element_type.getContext());
|
||||
auto floatAttr =
|
||||
mlir::FloatAttr::get(floatType, static_cast<float>(value));
|
||||
std::vector<Attribute> floatValues({floatAttr});
|
||||
attr = DenseElementsAttr::get(scalar_type, floatValues);
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::BF16: {
|
||||
auto floatType = mlir::FloatType::getBF16(element_type.getContext());
|
||||
auto floatAttr =
|
||||
mlir::FloatAttr::get(floatType, static_cast<float>(value));
|
||||
std::vector<Attribute> floatValues({floatAttr});
|
||||
attr = DenseElementsAttr::get(scalar_type, floatValues);
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::F32: {
|
||||
attr =
|
||||
DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::Complex: {
|
||||
auto etype = element_type.cast<mlir::ComplexType>().getElementType();
|
||||
if (etype.isF32()) {
|
||||
auto dialect = etype.getContext()->getRegisteredDialect("tf");
|
||||
tensorflow::TensorProto repr;
|
||||
repr.set_dtype(tensorflow::DT_COMPLEX64);
|
||||
|
||||
tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
|
||||
shape->set_unknown_rank(false);
|
||||
shape->add_dim()->set_size(int64_t{1});
|
||||
std::string content;
|
||||
auto complex_value =
|
||||
std::complex<float>(static_cast<float>(value), 0.0f);
|
||||
content.assign(reinterpret_cast<const char*>(&complex_value),
|
||||
sizeof(complex_value));
|
||||
repr.set_tensor_content(content);
|
||||
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
|
||||
|
||||
attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
|
||||
break;
|
||||
}
|
||||
return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type");
|
||||
}
|
||||
case mlir::StandardTypes::Integer: {
|
||||
const auto& itype = element_type.cast<mlir::IntegerType>();
|
||||
switch (itype.getWidth()) {
|
||||
case 8:
|
||||
attr = DenseElementsAttr::get<int8_t>(scalar_type,
|
||||
static_cast<int8_t>(value));
|
||||
break;
|
||||
case 16:
|
||||
attr = DenseElementsAttr::get<int16_t>(scalar_type,
|
||||
static_cast<int16_t>(value));
|
||||
break;
|
||||
case 32:
|
||||
attr = DenseElementsAttr::get<int32_t>(scalar_type,
|
||||
static_cast<int32_t>(value));
|
||||
break;
|
||||
case 64:
|
||||
attr = DenseElementsAttr::get<int64_t>(scalar_type,
|
||||
static_cast<int64_t>(value));
|
||||
break;
|
||||
default:
|
||||
return Status(tensorflow::error::INVALID_ARGUMENT,
|
||||
"Unsupported type");
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type");
|
||||
}
|
||||
return rewriter->create<ConstantOp>(loc, scalar_type, attr);
|
||||
}
|
||||
|
||||
LogicalResult ConvertTFReciprocalOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_reciprocal_op = cast<TF::ReciprocalOp>(op);
|
||||
@ -503,6 +586,31 @@ LogicalResult ConvertTFReciprocalOp::matchAndRewrite(
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertTFBroadcastToOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
  auto element_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
  auto output_type = tf_broadcast_to_op.output().getType();

  auto status_or_const_op =
      CreateConstOpWithSingleValue(&rewriter, op->getLoc(), element_type, 1);
  if (!status_or_const_op.ok()) {
    return failure();
  }

  auto tfl_fill_op = rewriter.create<TFL::FillOp>(
      op->getLoc(), output_type, tf_broadcast_to_op.shape(),
      status_or_const_op.ValueOrDie());

  StringAttr fused_activation_function =
      StringAttr::get("NONE", rewriter.getContext());

  rewriter.replaceOpWithNewOp<TFL::MulOp>(
      op, output_type, tf_broadcast_to_op.input(), tfl_fill_op,
      fused_activation_function);
  return success();
}
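// Editorial sketch (not part of this change) of what the pattern above does,
// with made-up shapes; it mirrors the legalize-tf tests earlier in this diff.
// Input:
//   %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
// After ConvertTFBroadcastToOp:
//   %cst = constant dense<1.000000e+00> : tensor<f32>
//   %fill = "tfl.fill"(%arg1, %cst) : (tensor<2xi32>, tensor<f32>) -> tensor<3x3xf32>
//   %mul = "tfl.mul"(%arg0, %fill) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>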
|
||||
|
||||
// Legalize unidirectional sequence lstm.
|
||||
struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
|
||||
explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context)
|
||||
@ -643,7 +751,7 @@ void LegalizeTF::runOnFunction() {
|
||||
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp,
|
||||
ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp,
|
||||
ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFReciprocalOp,
|
||||
ConvertTFRandomUniformOp>(context);
|
||||
ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context);
|
||||
|
||||
// Ophint python converter converted tf node pattern.
|
||||
patterns.insert<LegalizeUnidirectionalSequenceLstm,
|
||||
|
@ -341,8 +341,8 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
|
||||
// make sure $rhs is the tail shape of $lhs.
|
||||
def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
|
||||
(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
|
||||
(ConstantOp:$rhs $a), TFL_AF_None),
|
||||
(TFL_ReshapeOp (BinaryOp $input, $rhs, TFL_AF_None), $shape),
|
||||
(ConstantOp:$rhs $a), $act_fn),
|
||||
(TFL_ReshapeOp (BinaryOp $input, $rhs, $act_fn), $shape),
|
||||
// The broadcasting of "BinaryOp" only happens in the lower
|
||||
// dimensions, and the higher dimensions are same, so we know the
|
||||
// result and input of the "BinaryOp" in the source pattern have
|
||||
|
@ -41,7 +41,9 @@ limitations under the License.
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
@ -49,16 +51,18 @@ limitations under the License.
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||
|
||||
#define DEBUG_TYPE "tf-tfl-legalization"
|
||||
|
||||
@ -681,46 +685,6 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertTFBroadcastTo : public RewritePattern {
|
||||
explicit ConvertTFBroadcastTo(MLIRContext *context)
|
||||
: RewritePattern(TF::BroadcastToOp::getOperationName(), 1, context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
|
||||
auto input_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
|
||||
auto output_type = tf_broadcast_to_op.output().getType().cast<ShapedType>();
|
||||
auto shape_type = tf_broadcast_to_op.shape().getType().cast<ShapedType>();
|
||||
Type element_type = input_type.getElementType();
|
||||
|
||||
// Allow lowering when low-dimensional inputs are given and their type is F32
// or I32.
|
||||
if (!((output_type.hasRank() && output_type.getRank() <= 4) ||
|
||||
(shape_type.hasStaticShape() && shape_type.getRank() == 1 &&
|
||||
shape_type.getDimSize(0) <= 4)))
|
||||
return failure();
|
||||
if (!((element_type.getKind() == mlir::StandardTypes::F32) ||
|
||||
(element_type.getKind() == mlir::StandardTypes::Integer &&
|
||||
element_type.cast<mlir::IntegerType>().getWidth() == 32)))
|
||||
return failure();
|
||||
|
||||
auto status_or_const_op =
|
||||
CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1);
|
||||
if (!status_or_const_op.ok()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto tf_fill_op = rewriter.create<TF::FillOp>(
|
||||
op->getLoc(), output_type, tf_broadcast_to_op.shape(),
|
||||
status_or_const_op.ValueOrDie());
|
||||
|
||||
auto mul_op = rewriter.create<TF::MulOp>(
|
||||
op->getLoc(), output_type, tf_broadcast_to_op.input(), tf_fill_op);
|
||||
rewriter.replaceOp(op, mul_op.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
|
||||
|
||||
// Returns success if all the operations in the `op`'s regions including `op`
|
||||
@ -737,6 +701,23 @@ LogicalResult ValidateOp(Operation *op) {
|
||||
return failure(has_illegal_ops);
|
||||
}
|
||||
|
||||
// Converts a set of TF2XLA ops into pure TF ops for future legalizations as
|
||||
// TF2XLA ops aren't supported by later stages.
|
||||
LogicalResult ConvertTf2XlaOps(FuncOp func, MLIRContext *context) {
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
target.addLegalDialect<TF::TensorFlowDialect>();
|
||||
target.addLegalOp<ModuleOp>();
|
||||
target.addLegalOp<FuncOp>();
|
||||
target.addIllegalOp<TF::XlaConvOp>();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
mhlo::PopulateLegalizeTfWithTf2XlaPatterns("XLA_CPU_JIT", patterns);
|
||||
TF::PopulateLegalizeHloToTfPatterns(&patterns, context);
|
||||
|
||||
return applyPartialConversion(func, target, patterns);
|
||||
}
|
||||
|
||||
void PrepareTFPass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto func = getFunction();
|
||||
@ -752,6 +733,11 @@ void PrepareTFPass::runOnFunction() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (failed(ConvertTf2XlaOps(func, ctx))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
// This pattern was intended to use TFL QDQs to preserve the quantization
// parameters from the TF Quant ops; thus this pattern should run with the
// first `applyPatternsGreedily` method, which would otherwise remove the
|
||||
@ -780,7 +766,7 @@ void PrepareTFPass::runOnFunction() {
|
||||
patterns.insert<TF::ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
|
||||
TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(ctx);
|
||||
}
|
||||
patterns.insert<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo, ConvertTFConv2D,
|
||||
patterns.insert<TF::ConvertTFEinsumOp, ConvertTFConv2D,
|
||||
ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice>(ctx);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
}
|
||||
|
@ -1,112 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
xla::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
|
||||
PatternRewriter* rewriter, Location loc, ShapedType shaped_type,
|
||||
int value) {
|
||||
Type element_type = shaped_type.getElementType();
|
||||
ShapedType scalar_type = RankedTensorType::get({}, element_type);
|
||||
Attribute attr;
|
||||
switch (element_type.getKind()) {
|
||||
case mlir::StandardTypes::F16: {
|
||||
auto floatType = mlir::FloatType::getF16(element_type.getContext());
|
||||
auto floatAttr =
|
||||
mlir::FloatAttr::get(floatType, static_cast<float>(value));
|
||||
std::vector<Attribute> floatValues({floatAttr});
|
||||
attr = DenseElementsAttr::get(scalar_type, floatValues);
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::BF16: {
|
||||
auto floatType = mlir::FloatType::getBF16(element_type.getContext());
|
||||
auto floatAttr =
|
||||
mlir::FloatAttr::get(floatType, static_cast<float>(value));
|
||||
std::vector<Attribute> floatValues({floatAttr});
|
||||
attr = DenseElementsAttr::get(scalar_type, floatValues);
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::F32: {
|
||||
attr =
|
||||
DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::Complex: {
|
||||
auto etype = element_type.cast<mlir::ComplexType>().getElementType();
|
||||
if (etype.isF32()) {
|
||||
auto dialect = etype.getContext()->getRegisteredDialect("tf");
|
||||
tensorflow::TensorProto repr;
|
||||
repr.set_dtype(tensorflow::DT_COMPLEX64);
|
||||
|
||||
tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
|
||||
shape->set_unknown_rank(false);
|
||||
shape->add_dim()->set_size(int64_t{1});
|
||||
std::string content;
|
||||
auto complex_value =
|
||||
std::complex<float>(static_cast<float>(value), 0.0f);
|
||||
content.assign(reinterpret_cast<const char*>(&complex_value),
|
||||
sizeof(complex_value));
|
||||
repr.set_tensor_content(content);
|
||||
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
|
||||
|
||||
attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
|
||||
break;
|
||||
}
|
||||
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
|
||||
"Unsupported type");
|
||||
}
|
||||
case mlir::StandardTypes::Integer: {
|
||||
const auto& itype = element_type.cast<mlir::IntegerType>();
|
||||
switch (itype.getWidth()) {
|
||||
case 8:
|
||||
attr = DenseElementsAttr::get<int8_t>(scalar_type,
|
||||
static_cast<int8_t>(value));
|
||||
break;
|
||||
case 16:
|
||||
attr = DenseElementsAttr::get<int16_t>(scalar_type,
|
||||
static_cast<int16_t>(value));
|
||||
break;
|
||||
case 32:
|
||||
attr = DenseElementsAttr::get<int32_t>(scalar_type,
|
||||
static_cast<int32_t>(value));
|
||||
break;
|
||||
case 64:
|
||||
attr = DenseElementsAttr::get<int64_t>(scalar_type,
|
||||
static_cast<int64_t>(value));
|
||||
break;
|
||||
default:
|
||||
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
|
||||
"Unsupported type");
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
|
||||
"Unsupported type");
|
||||
}
|
||||
return rewriter->create<ConstantOp>(loc, scalar_type, attr);
|
||||
}
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
@ -1,35 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
// Returns a Constant op with a single value.
|
||||
xla::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
|
||||
PatternRewriter* rewriter, Location loc, ShapedType shaped_type, int value);
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_
|
@ -73,7 +73,8 @@ tool_names = [
|
||||
'mlir-opt', 'mlir-hlo-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
|
||||
'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
|
||||
'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
|
||||
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir'
|
||||
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir',
|
||||
'xla-thunks-opt'
|
||||
]
|
||||
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
|
||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||
|
@ -727,8 +727,10 @@ cc_library(
|
||||
"transforms/generated_optimize.inc",
|
||||
"transforms/gpu_fusion.cc",
|
||||
"transforms/graph_pruning.cc",
|
||||
"transforms/init_text_file_to_import.cc",
|
||||
"transforms/launch_to_device_attribute.cc",
|
||||
"transforms/layout_optimization.cc",
|
||||
"transforms/mark_ops_for_outside_compilation.cc",
|
||||
"transforms/materialize_mlir_passthrough_op.cc",
|
||||
"transforms/optimize.cc",
|
||||
"transforms/parallel_execute_to_islands.cc",
|
||||
@ -825,6 +827,7 @@ cc_library(
|
||||
cc_library(
|
||||
name = "tensorflow_test_passes",
|
||||
srcs = [
|
||||
"transforms/init_text_file_to_import_test_pass.cc",
|
||||
"transforms/lift_variables_test_pass.cc",
|
||||
"transforms/lower_tf_pass.cc",
|
||||
],
|
||||
@ -840,8 +843,10 @@ cc_library(
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/platform:status",
|
||||
"//tensorflow/core/platform:threadpool_options",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -102,8 +102,6 @@ class MlirTensor : public TracingTensorHandle {
|
||||
return type;
|
||||
}
|
||||
|
||||
void Release() override { delete this; }
|
||||
|
||||
Value getValue() { return value_; }
|
||||
|
||||
// For LLVM style RTTI.
|
||||
|
@ -87,7 +87,7 @@ tf.math.acosh(x) ==> [nan nan 0. 0.62236255 5.9914584 9.903487 inf]
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, SameOperandsAndResultElementType]>,
|
||||
def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType, TF_LayoutAgnostic]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns x + y element-wise.";
|
||||
|
||||
@ -136,7 +136,7 @@ Inputs must be of same size and shape.
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, SameOperandsAndResultElementType]>,
|
||||
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType, TF_LayoutAgnostic]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = "Returns x + y element-wise.";
|
||||
|
||||
@ -725,6 +725,30 @@ window in `value`.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_AvgPool3DOp : TF_Op<"AvgPool3D", [NoSideEffect]> {
  let summary = "Performs 3D average pooling on the input.";

  let description = [{
Each entry in `output` is the mean of the corresponding size `ksize` window in
`value`.
  }];

  let arguments = (ins
    TF_FpTensor:$input,

    Confined<I64ArrayAttr, [ArrayMinCount<5>]>:$ksize,
    Confined<I64ArrayAttr, [ArrayMinCount<5>]>:$strides,
    TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding,
    DefaultValuedAttr<TF_AnyStrAttrOf<["NDHWC", "NCDHW"]>, "NDHWC">:$data_format
  );

  let results = (outs
    TF_FpTensor:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
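// Editorial usage sketch (not part of this change); attribute values and
// shapes below are made up for illustration:
//   %0 = "tf.AvgPool3D"(%input) {ksize = [1, 2, 2, 2, 1], strides = [1, 2, 2, 2, 1], padding = "VALID", data_format = "NDHWC"} : (tensor<1x8x8x8x1xf32>) -> tensor<1x4x4x4x1xf32>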
|
||||
|
||||
def TF_AvgPool3DGradOp : TF_Op<"AvgPool3DGrad", [NoSideEffect]> {
|
||||
let summary = "Computes gradients of average pooling function.";
|
||||
|
||||
@ -745,30 +769,6 @@ def TF_AvgPool3DGradOp : TF_Op<"AvgPool3DGrad", [NoSideEffect]> {
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
|
||||
}
|
||||
|
||||
def TF_AvgPool3DOp : TF_Op<"AvgPool3D", [NoSideEffect]> {
|
||||
let summary = "Performs 3D average pooling on the input.";
|
||||
|
||||
let description = [{
|
||||
Each entry in `output` is the mean of the corresponding size `ksize`
|
||||
window in `value`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_FpTensor:$value,
|
||||
|
||||
Confined<I64ArrayAttr, [ArrayMinCount<5>]>:$ksize,
|
||||
Confined<I64ArrayAttr, [ArrayMinCount<5>]>:$strides,
|
||||
TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding,
|
||||
DefaultValuedAttr<TF_AnyStrAttrOf<["NDHWC", "NCDHW"]>, "NDHWC">:$data_format
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_FpTensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_AvgPoolGradOp : TF_Op<"AvgPoolGrad", [NoSideEffect]> {
|
||||
let summary = "Computes gradients of the average pooling function.";
|
||||
|
||||
@ -4231,6 +4231,29 @@ Where to extract the key and value from a line is specified by `key_index` and
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
  let summary = "Updates specified rows 'i' with values 'v'.";

  let description = [{
Computes `x[i, :] = v; return x`.

Originally this function is mutative; however, for compilation we make this
operation create / operate on a copy of `x`.
  }];

  let arguments = (ins
    TF_Tensor:$x,
    I32Tensor:$i,
    TF_Tensor:$v
  );

  let results = (outs
    TF_Tensor:$y
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
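// Editorial usage sketch (not part of this change); shapes are made up:
//   %y = "tf.InplaceUpdate"(%x, %i, %v) : (tensor<4x3xf32>, tensor<1xi32>, tensor<1x3xf32>) -> tensor<4x3xf32>
// i.e. the rows of %x selected by %i are replaced with %v in a copy of %x.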
|
||||
|
||||
def TF_InvOp : TF_Op<"Inv", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Computes the reciprocal of x element-wise.";
|
||||
|
||||
@ -6311,6 +6334,8 @@ This is the opposite of `unpack`.
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_PadOp : TF_Op<"Pad", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
|
||||
@ -9711,6 +9736,22 @@ For internal use only.
|
||||
);
|
||||
}
|
||||
|
||||
def TF_TPUOrdinalSelectorOp : TF_Op<"TPUOrdinalSelector", []> {
|
||||
let summary = "A TPU core selector Op.";
|
||||
|
||||
let description = [{
|
||||
This Op produces a set of TPU cores (for warm-up) or a single TPU core
|
||||
(for regular inference) to execute the TPU program on. The output is
|
||||
consumed by TPUPartitionedCall.
|
||||
}];
|
||||
|
||||
let arguments = (ins);
|
||||
|
||||
let results = (outs
|
||||
I32Tensor:$device_ordinals
|
||||
);
|
||||
}
|
||||
|
||||
def TF_TPUReplicatedInputOp : TF_Op<"TPUReplicatedInput", [NoSideEffect]> {
|
||||
let summary = "Connects N inputs to an N-way replicated TPU computation.";
|
||||
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/OpImplementation.h" // from @llvm-project
|
||||
|
@ -225,12 +225,25 @@ else_branch: A function that takes 'inputs' and returns a list of
|
||||
TF_DerivedOperandTypeAttr Tcond = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
|
||||
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
|
||||
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Get the then branch function.
|
||||
FuncOp then_func() {
|
||||
return getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(then_branch());
|
||||
}
|
||||
|
||||
// Get the else branch function.
|
||||
FuncOp else_func() {
|
||||
return getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(else_branch());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_YieldOp : TF_Op<"Yield",
|
||||
@ -612,7 +625,6 @@ body: A function that takes a list of tensors and returns another
|
||||
|
||||
FlatSymbolRefAttr:$cond,
|
||||
FlatSymbolRefAttr:$body,
|
||||
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
|
||||
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
|
||||
|
||||
// Used to map StatelessWhile and While op defined in TensorFlow to a common
|
||||
@ -625,10 +637,24 @@ body: A function that takes a list of tensors and returns another
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
|
||||
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Get the condition function.
|
||||
FuncOp cond_func() {
|
||||
return getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(cond());
|
||||
}
|
||||
|
||||
// Get the body function.
|
||||
FuncOp body_func() {
|
||||
return getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(body());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_WhileRegionOp : TF_Op<"WhileRegion",
|
||||
@ -1070,31 +1096,6 @@ def TF_TensorSliceDatasetOp : TF_Op<"TensorSliceDataset", []> {
|
||||
TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>;
|
||||
}
|
||||
|
||||
// TODO(b/156507832): Move tf.InplaceUpdate to tf_generated_ops.td once
|
||||
// autogenerated op def matches.
|
||||
def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
|
||||
let summary = "Updates specified rows 'i' with values 'v'.";
|
||||
|
||||
let description = [{
|
||||
Computes `x[i, :] = v; return x`.
|
||||
|
||||
Originally this function is mutative; however, for compilation we make this
operation create / operate on a copy of `x`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$x,
|
||||
I32Tensor:$i,
|
||||
TF_Tensor:$v
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$y
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_BesselI0eOp : TF_Op<"BesselI0e", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Computes the Bessel i0e function of `x` element-wise.";
|
||||
|
||||
@ -1255,4 +1256,74 @@ def TF_FusedBatchNormV3Op : TF_FusedBatchNormOpBase<"FusedBatchNormV3"> {
|
||||
);
|
||||
}
|
||||
|
||||
def TF_BatchFunctionOp : TF_Op<"BatchFunction", [AttrSizedOperandSegments]> {
|
||||
let summary = [{
Batches all the input tensors to the computation done by the function.
}];
|
||||
|
||||
let description = [{
|
||||
So, for example, in the following code
|
||||
|
||||
```python
|
||||
|
||||
# This input will be captured.
|
||||
y = tf.placeholder_with_default(1.0, shape=[])
|
||||
|
||||
@tf.Defun(tf.float32)
|
||||
def computation(a):
|
||||
return tf.matmul(a, a) + y
|
||||
|
||||
b = gen_batch_ops.batch_function(
|
||||
f=computation,
|
||||
in_tensors=[a],
|
||||
captured_tensors=computation.captured_inputs,
|
||||
Tout=[o.type for o in computation.definition.signature.output_arg],
|
||||
num_batch_threads=1,
|
||||
max_batch_size=10,
|
||||
batch_timeout_micros=100000, # 100ms
|
||||
allowed_batch_sizes=[3, 10],
|
||||
batching_queue="")
|
||||
|
||||
If more than one session.run call is simultaneously trying to compute `b`,
the values of `a` will be gathered, non-deterministically concatenated
along the first axis, and only one thread will run the computation.
|
||||
|
||||
Assumes that all arguments of the function are Tensors which will be batched
|
||||
along their first dimension.
|
||||
|
||||
Arguments that are captured are not batched. The session.run call which does
the concatenation will use the values of the captured tensors available to it.
|
||||
Therefore, typical uses of captured tensors should involve values which remain
|
||||
unchanged across session.run calls. Inference is a good example of this.
|
||||
|
||||
SparseTensor is not supported. The return value of the decorated function
|
||||
must be a Tensor or a list/tuple of Tensors.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<TF_Tensor>:$in_tensors,
|
||||
Variadic<TF_Tensor>:$captured_tensors,
|
||||
|
||||
SymbolRefAttr:$f,
|
||||
I64Attr:$num_batch_threads,
|
||||
I64Attr:$max_batch_size,
|
||||
I64Attr:$batch_timeout_micros,
|
||||
DefaultValuedAttr<I64Attr, "10">:$max_enqueued_batches,
|
||||
DefaultValuedAttr<I64ArrayAttr, "{}">:$allowed_batch_sizes,
|
||||
StrAttr:$container,
|
||||
StrAttr:$shared_name,
|
||||
StrAttr:$batching_queue,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$enable_large_batch_splitting,
|
||||
I32ElementsAttr:$operand_segment_sizes
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<TF_Tensor>:$out_tensors
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
|
||||
TF_DerivedOperandTypeListAttr Tcaptured = TF_DerivedOperandTypeListAttr<1>;
|
||||
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
|
||||
}
|
||||
|
||||
#endif // TF_OPS
|
||||
|
@ -1615,6 +1615,10 @@ static LogicalResult Verify(IfOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IfOp canonicalization.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class FoldConstantIfOp : public OpRewritePattern<TF::IfOp> {
|
||||
public:
|
||||
explicit FoldConstantIfOp(MLIRContext *context)
|
||||
@ -1662,7 +1666,7 @@ LogicalResult FoldConstantIfOp::matchAndRewrite(
|
||||
|
||||
void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<FoldConstantIfOp>(context);
|
||||
results.insert<FoldConstantIfOp, DropAttributes<IfOp>>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/OpImplementation.h" // from @llvm-project
|
||||
|
@ -578,3 +578,23 @@ LogicalResult VerifyRegionResults(Operation *op, Region ®ion,
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
// Function control flow canonicalization.
//===----------------------------------------------------------------------===//

// Eliminate attributes that are not needed, but can get attached to Ops
// during import.
template <typename Op>
struct DropAttributes : public OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;

  // Drop the "output_shapes" attribute.
  LogicalResult matchAndRewrite(Op op,
                                PatternRewriter &rewriter) const override {
    bool found = op.removeAttr("output_shapes") ==
                 MutableDictionaryAttr::RemoveResult::Removed;
    return success(found);
  }
};
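// Editorial sketch (not part of this change): for an imported op such as
//   %0 = "tf.If"(%cond, %arg) {then_branch = @then, else_branch = @else, is_stateless = true, output_shapes = [#tf.shape<>]} : ...
// DropAttributes<IfOp> removes the "output_shapes" entry and reports success
// only when the attribute was actually present, so the rewrite fires at most
// once per op.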
|
||||
|
||||
|
@ -217,6 +217,97 @@ static LogicalResult Verify(PackOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult PackOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Fold pack operation if it computes the input tensor shape:
|
||||
//
|
||||
// %shape = tf.Shape(%arg) // [? x ...]
|
||||
// %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim0 value
|
||||
// %pack = tf.Pack(dim0, ...) { axis = 0 } // [? x ...]
|
||||
//
|
||||
// Where `...` are some statically known dimensions. In this case %pack can be
// replaced with %shape. This is a common pattern in models with a dynamic
// batch size.
|
||||
|
||||
// Pack operation should pack at least two values.
|
||||
if (values().size() < 2) return {};
|
||||
|
||||
// Dimensions packed along axis = 0 (pack scalars into vector).
|
||||
if (axis().getSExtValue() != 0) return {};
|
||||
|
||||
// First packed value is defined by a strided slice operation.
|
||||
auto slice_op = dyn_cast_or_null<StridedSliceOp>(values()[0].getDefiningOp());
|
||||
if (!slice_op) return {};
|
||||
|
||||
// Input to the slice op is defined by shape operation.
|
||||
auto shape_op = dyn_cast_or_null<ShapeOp>(slice_op.input().getDefiningOp());
|
||||
if (!shape_op) return {};
|
||||
|
||||
// Input tensor whose shape is reconstructed by the pack operation.
|
||||
Value tensor = shape_op.input();
|
||||
|
||||
// All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing
|
||||
// scalar value from input vector).
|
||||
if (slice_op.begin_mask().getSExtValue() != 0 ||
|
||||
slice_op.ellipsis_mask().getSExtValue() != 0 ||
|
||||
slice_op.end_mask().getSExtValue() != 0 ||
|
||||
slice_op.new_axis_mask().getSExtValue() != 0 ||
|
||||
slice_op.shrink_axis_mask().getSExtValue() != 1)
|
||||
return {};
|
||||
|
||||
// Returns a value if the `value` is defined by a ConstOp with a single
|
||||
// integer element in it and has an expected rank.
|
||||
auto get_const_int = [](Value value, int expected_rank) -> Optional<int64_t> {
|
||||
auto const_op = dyn_cast_or_null<ConstOp>(value.getDefiningOp());
|
||||
if (!const_op) return None;
|
||||
|
||||
auto value_attr = const_op.value().dyn_cast<DenseIntElementsAttr>();
|
||||
if (!value_attr || value_attr.getNumElements() != 1) return None;
|
||||
|
||||
auto value_ty = value_attr.getType();
|
||||
if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None;
|
||||
|
||||
auto splat = value_attr.getSplatValue<IntegerAttr>();
|
||||
return splat.getValue().getSExtValue();
|
||||
};
|
||||
|
||||
// All other packed values are scalar constants.
|
||||
SmallVector<int64_t, 4> packed_dims;
|
||||
packed_dims.reserve(values().size() - 1);
|
||||
for (Value operand : llvm::drop_begin(values(), 1)) {
|
||||
if (auto dim = get_const_int(operand, /*expected_rank=*/0)) {
|
||||
packed_dims.push_back(*dim);
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
// Slice exactly the first shape dimension:
|
||||
// begin = [0] end = [1], strides = [1]
|
||||
auto begin = get_const_int(slice_op.begin(), /*expected_rank=*/1);
|
||||
auto end = get_const_int(slice_op.end(), /*expected_rank=*/1);
|
||||
auto strides = get_const_int(slice_op.strides(), /*expected_rank=*/1);
|
||||
if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() ||
|
||||
*begin != 0 || *end != 1 || *strides != 1)
|
||||
return {};
|
||||
|
||||
// First tensor dimension is dynamic.
|
||||
auto arg_ty = tensor.getType().dyn_cast<ShapedType>();
|
||||
if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 ||
|
||||
!arg_ty.isDynamicDim(0))
|
||||
return {};
|
||||
|
||||
// Argument tensor rank is equal to the number of packed dimensions.
|
||||
if (arg_ty.getRank() != values().size()) return {};
|
||||
|
||||
// All other dimensions are statically known and equal to packed dims.
|
||||
auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1);
|
||||
if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin()))
|
||||
return {};
|
||||
|
||||
// Replace %pack with %shape.
|
||||
return slice_op.input();
|
||||
}
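For context, the following is roughly the Python-level idiom that produces the IR pattern folded above (a sketch only; the 28x28 dimensions, the function name, and the explicit use of tf.strided_slice are illustrative assumptions, not taken from this change):

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, 28, 28], tf.float32)])
def rebuild_shape(x):
    # tf.Shape -> tf.StridedSlice (shrink_axis_mask=1) -> tf.Pack, i.e. the
    # graph that the fold above recognizes and replaces with the Shape result.
    shape = tf.shape(x)                              # dynamic batch dimension
    batch = tf.strided_slice(shape, [0], [1], [1], shrink_axis_mask=1)
    new_shape = tf.stack([batch, 28, 28], axis=0)    # lowered to tf.Pack
    return tf.reshape(x, new_shape)                  # a no-op reshape of x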
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -608,12 +699,11 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,
|
||||
|
||||
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<RedundantReshape>(context);
|
||||
results.insert<RedundantReshape, ReshapeToSelfShape>(context);
|
||||
}
|
||||
|
||||
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
||||
Value tensor = this->tensor();
|
||||
Value shape = this->shape();
|
||||
|
||||
// Fold reshape if operand and result types are the same and all dimensions
|
||||
// are statically known (no-op reshape).
|
||||
@ -624,90 +714,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// Fold reshape if the shape is computed from the input tensor:
|
||||
//
|
||||
// %shape = tf.Shape(%arg) // [? x ...]
|
||||
// %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim value
|
||||
// %new_shape = tf.Pack(dim0, ...) { axis = 0 } // [? x ...]
|
||||
// %reshape = tf.Reshape(%arg, %new_shape) // this is no-op
|
||||
//
|
||||
// Where `...` are some statically known dimensions. In this case reshape is
|
||||
// a no-op and can be replaced by %arg (assuming `...` are equal).
|
||||
auto pack_op = dyn_cast_or_null<PackOp>(shape.getDefiningOp());
|
||||
if (!pack_op || pack_op.values().size() < 2) return {};
|
||||
|
||||
// Dimensions packed along axis = 0 (pack scalars into vector).
|
||||
if (pack_op.axis().getSExtValue() != 0) return {};
|
||||
|
||||
// First packed value is defined by a strided slice operation.
|
||||
auto slice_op =
|
||||
dyn_cast_or_null<StridedSliceOp>(pack_op.values()[0].getDefiningOp());
|
||||
if (!slice_op) return {};
|
||||
|
||||
// Input to the slice op is defined by shape operation.
|
||||
auto shape_op = dyn_cast_or_null<ShapeOp>(slice_op.input().getDefiningOp());
|
||||
if (!shape_op || shape_op.input() != tensor) return {};
|
||||
|
||||
// All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing
|
||||
// scalar value from input vector).
|
||||
if (slice_op.begin_mask().getSExtValue() != 0 ||
|
||||
slice_op.ellipsis_mask().getSExtValue() != 0 ||
|
||||
slice_op.end_mask().getSExtValue() != 0 ||
|
||||
slice_op.new_axis_mask().getSExtValue() != 0 ||
|
||||
slice_op.shrink_axis_mask().getSExtValue() != 1)
|
||||
return {};
|
||||
|
||||
// Returns a value if the `value` is defined by a ConstOp with a single
|
||||
// integer element in it and has an expected rank.
|
||||
auto get_value = [](Value value, int expected_rank) -> Optional<int64_t> {
|
||||
auto const_op = dyn_cast_or_null<ConstOp>(value.getDefiningOp());
|
||||
if (!const_op) return None;
|
||||
|
||||
auto value_attr = const_op.value().dyn_cast<DenseIntElementsAttr>();
|
||||
if (!value_attr || value_attr.getNumElements() != 1) return None;
|
||||
|
||||
auto value_ty = value_attr.getType();
|
||||
if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None;
|
||||
|
||||
auto splat = value_attr.getSplatValue<IntegerAttr>();
|
||||
return splat.getValue().getSExtValue();
|
||||
};
|
||||
|
||||
// All other packed values are scalar constants.
|
||||
SmallVector<int64_t, 4> packed_dims;
|
||||
packed_dims.reserve(pack_op.values().size() - 1);
|
||||
for (Value operand : llvm::drop_begin(pack_op.values(), 1)) {
|
||||
if (auto dim = get_value(operand, /*expected_rank=*/0)) {
|
||||
packed_dims.push_back(*dim);
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
// Slice exactly the first shape dimension:
|
||||
// begin = [0] end = [1], strides = [1]
|
||||
auto begin = get_value(slice_op.begin(), /*expected_rank=*/1);
|
||||
auto end = get_value(slice_op.end(), /*expected_rank=*/1);
|
||||
auto strides = get_value(slice_op.strides(), /*expected_rank=*/1);
|
||||
if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() ||
|
||||
*begin != 0 || *end != 1 || *strides != 1)
|
||||
return {};
|
||||
|
||||
// First tensor dimension is dynamic.
|
||||
auto arg_ty = tensor.getType().dyn_cast<ShapedType>();
|
||||
if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 ||
|
||||
!arg_ty.isDynamicDim(0))
|
||||
return {};
|
||||
|
||||
// Argument tensor rank is equal to the number of packed dimensions.
|
||||
if (arg_ty.getRank() != pack_op.values().size()) return {};
|
||||
|
||||
// All other dimensions are statically known and equal to packed dims.
|
||||
auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1);
|
||||
if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin()))
|
||||
return {};
|
||||
|
||||
return tensor;
|
||||
return {};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -2065,6 +2072,14 @@ static LogicalResult Verify(WhileOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// WhileOp canonicalization.
|
||||
//===----------------------------------------------------------------------===//
|
||||
void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<DropAttributes<WhileOp>>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// WhileRegionOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/OpImplementation.h" // from @llvm-project
|
||||
|
@ -377,6 +377,15 @@ func @testRedundantReshape(%arg0: tensor<4x4xi32>) -> tensor<2x8xi32> {
|
||||
// CHECK: return %1 : tensor<2x8xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testReshapeToSelfShape
|
||||
func @testReshapeToSelfShape(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
|
||||
%0 = "tf.Shape"(%arg0) : (tensor<?x4xf32>) -> tensor<2xi32>
|
||||
%1 = "tf.Reshape"(%arg0, %0) : (tensor<?x4xf32>, tensor<2xi32>) -> tensor<?x4xf32>
|
||||
|
||||
// CHECK: return %arg0 : tensor<?x4xf32>
|
||||
return %1: tensor<?x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @testReshapeNoOp
|
||||
func @testReshapeNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tensor<2x4xf32> {
|
||||
%0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2x4xf32>, tensor<2xi32>) -> tensor<2x4xf32>
|
||||
@ -385,8 +394,8 @@ func @testReshapeNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tensor<2x
|
||||
return %0 : tensor<2x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @testReshapeNoOpShapeComputation
|
||||
func @testReshapeNoOpShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<?x1xf32>, tensor<?x1x2xf32>, tensor<?x1x2xf32>, tensor<?x2x1xf32>, tensor<?x1x2xf32>, tensor<?x1x1xf32>, tensor<*xf32>) {
|
||||
// CHECK-LABEL: func @testPackShapeComputation
|
||||
func @testPackShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) {
|
||||
// Test dimensions sizes.
|
||||
%d1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%d2 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
|
||||
@ -396,65 +405,56 @@ func @testReshapeNoOpShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x
|
||||
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
|
||||
// Fold reshape if the shape is computed from the input tensor:
|
||||
// Fold pack operation if it computes the input tensor shape:
|
||||
//
|
||||
// %shape = tf.Shape(%arg) // [? x ...]
|
||||
// %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim value
|
||||
// %new_shape = tf.Pack(dim0, ...) { axis = 0 } // [? x ...]
|
||||
// %reshape = tf.Reshape(%arg, %new_shape) // this is no-op
|
||||
// %shape = tf.Shape(%arg) // [? x ...]
|
||||
// %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim0 value
|
||||
// %pack = tf.Pack(dim0, ...) { axis = 0 } // [? x ...]
|
||||
//
|
||||
// Where `...` are some statically known dimensions. In this case reshape is
|
||||
// a no-op and can be replaced by %arg (assuming `...` are equal).
|
||||
// Where `...` are some statically known dimensions. In this case %pack can be
|
||||
// replaced with a %shape. This is a common pattern in models with a dynamic
|
||||
// batch size.
|
||||
|
||||
// Test Rank 2
|
||||
// CHECK: %[[SHAPE0:.*]] = "tf.Shape"
|
||||
%3 = "tf.Shape"(%arg0) : (tensor<?x1xf32>) -> tensor<2xi32>
|
||||
%4 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
|
||||
%5 = "tf.Pack"(%4, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
|
||||
%6 = "tf.Reshape"(%arg0, %5) : (tensor<?x1xf32>, tensor<2xi32>) -> tensor<?x1xf32>
|
||||
|
||||
// Test Rank 3.
|
||||
|
||||
// CHECK: %[[SHAPE1:.*]] = "tf.Shape"
|
||||
%7 = "tf.Shape"(%arg1) : (tensor<?x1x2xf32>) -> tensor<3xi32>
|
||||
%8 = "tf.StridedSlice"(%7, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
|
||||
%9 = "tf.Pack"(%8, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
|
||||
%10 = "tf.Reshape"(%arg1, %9) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x1x2xf32>
|
||||
|
||||
// Shape was taken from the op that is not reshaped in the end:
|
||||
// Reshape(%arg1) vs Shape(%arg0)
|
||||
%11 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
|
||||
%12 = "tf.Pack"(%11, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
|
||||
// CHECK: %[[RESHAPE0:.*]] = "tf.Reshape"
|
||||
%13 = "tf.Reshape"(%arg1, %12) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x1x2xf32>
|
||||
|
||||
// Packed dimensions have different order from the reshape operand:
|
||||
// [?, 1, 2] vs [?, 2, 1]
|
||||
%14 = "tf.StridedSlice"(%7, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
|
||||
%15 = "tf.Pack"(%14, %d2, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
|
||||
// CHECK: %[[RESHAPE1:.*]] = "tf.Reshape"
|
||||
%16 = "tf.Reshape"(%arg1, %15) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x2x1xf32>
|
||||
// CHECK: %[[PACK0:.*]] = "tf.Pack"
|
||||
|
||||
// StridedSlice takes second dimension from the shape:
|
||||
// begin = [1], end = [2], stride = [1]
|
||||
%17 = "tf.StridedSlice"(%7, %1, %2, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
|
||||
%18 = "tf.Pack"(%17, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
|
||||
// CHECK: %[[RESHAPE2:.*]] = "tf.Reshape"
|
||||
%19 = "tf.Reshape"(%arg1, %18) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x1x2xf32>
|
||||
// CHECK: %[[PACK1:.*]] = "tf.Pack"
|
||||
|
||||
// Packed dimensions have higher rank than the reshape operand:
|
||||
// [?, 1] vs [?, 1, 1]
|
||||
%20 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
|
||||
%21 = "tf.Pack"(%20, %d1, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
|
||||
// CHECK: %[[RESHAPE3:.*]] = "tf.Reshape"
|
||||
%22 = "tf.Reshape"(%arg0, %21) : (tensor<?x1xf32>, tensor<3xi32>) -> tensor<?x1x1xf32>
|
||||
// CHECK: %[[PACK2:.*]] = "tf.Pack"
|
||||
|
||||
// Make sure a dynamic ranked shape doesn't crash the "canonicalize" pass
|
||||
%23 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32>
|
||||
%24 = "tf.StridedSlice"(%23, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
|
||||
%25 = "tf.Pack"(%24, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%26 = "tf.Reshape"(%arg2, %25) : (tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32>
|
||||
// CHECK: %[[PACK3:.*]] = "tf.Pack"
|
||||
|
||||
// CHECK: return %arg0, %arg1, %[[RESHAPE0]], %[[RESHAPE1]], %[[RESHAPE2]], %[[RESHAPE3]]
|
||||
return %6, %10, %13, %16, %19, %22, %26 : tensor<?x1xf32>, tensor<?x1x2xf32>, tensor<?x1x2xf32>, tensor<?x2x1xf32>, tensor<?x1x2xf32>, tensor<?x1x1xf32>, tensor<*xf32>
|
||||
// CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]], %[[PACK3]]
|
||||
return %5, %9, %15, %18, %21, %25 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSelectScalarPred
|
||||
@ -985,3 +985,36 @@ func @testWhileRegionUnusedValue(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>, %ar
|
||||
// CHECK: return %[[WHILE_OUT]]#0 : tensor<*xf32>
|
||||
return %0#0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// Check that output_shapes attribute is removed for tf.If
|
||||
func @testIfThen(tensor<*xf32>) -> tensor<*xf32>
|
||||
func @testIfElse(tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK-LABEL: func @testIfDropOutputShapes
|
||||
func @testIfDropOutputShapes(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
|
||||
^bb0(%arg0: tensor<i1>, %arg1: tensor<2xf32>):
|
||||
// CHECK: "tf.If"
|
||||
// CHECK-NOT: output_shapes
|
||||
%1 = "tf.If"(%arg0, %arg1) {
|
||||
then_branch = @testIfThen, else_branch = @testIfElse, is_stateless = false, output_shapes = [#tf.shape<>]
|
||||
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
return %1 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// Check that output_shapes attribute is removed for tf.While
|
||||
func @testWhileCond(tensor<*xf32>) -> (tensor<i1>)
|
||||
func @testWhileBody(tensor<*xf32>) -> (tensor<*xf32>)
|
||||
// CHECK-LABEL: func @testWhileDropOutputShapes
|
||||
func @testWhileDropOutputShapes(tensor<*xf32>) -> (tensor<*xf32>) {
|
||||
^bb0(%arg0: tensor<*xf32>):
|
||||
// CHECK: "tf.While"
|
||||
// CHECK-NOT: output_shapes
|
||||
%1 = "tf.While"(%arg0) {
|
||||
cond = @testWhileCond,
|
||||
body = @testWhileBody,
|
||||
is_stateless = false,
|
||||
output_shapes = [#tf.shape<>]
|
||||
} : (tensor<*xf32>) -> (tensor<*xf32>)
|
||||
|
||||
return %1 : tensor<*xf32>
|
||||
}
|
||||
|
@ -0,0 +1,14 @@
|
||||
// RUN: tf-opt -tf-init-text-file-to-import-test %s | FileCheck %s
|
||||
|
||||
// Tests that the tf.InitializeTableFromTextFileV2 op is inlined.
|
||||
|
||||
func @init_all_tables() {
|
||||
%cst = constant dense<"%FILE_PLACEHOLDER"> : tensor<!tf.string>
|
||||
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "hash_table_/tmp/vocab.txt_-2_-1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf.resource>
|
||||
"tf.InitializeTableFromTextFileV2"(%0, %cst) {delimiter = " ", device = "", key_index = -2 : i64, value_index = -1 : i64, vocab_size = -1 : i64} : (tensor<!tf.resource>, tensor<!tf.string>) -> ()
|
||||
return
|
||||
// CHECK: [[CST:%.*]] = constant dense<["apple", "banana", "grape"]> : tensor<3x!tf.string>
|
||||
// CHECK: [[CST_0:%.*]] = constant dense<[0, 1, 2]> : tensor<3xi64>
|
||||
// CHECK: [[VAL:%.*]] = "tf.HashTableV2"()
|
||||
// CHECK: "tf.LookupTableImportV2"([[VAL]], [[CST]], [[CST_0]])
|
||||
}
|
@ -0,0 +1,53 @@
|
||||
// RUN: tf-opt -split-input-file -verify-diagnostics -tf-init-text-file-to-import %s | FileCheck %s
|
||||
|
||||
// Tests that the given vocabulary file does not exist.
|
||||
|
||||
func @init_all_tables() {
|
||||
%cst = constant dense<"vocab_file_does_not_exist.txt"> : tensor<!tf.string>
|
||||
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "hash_table_/tmp/vocab.txt_-2_-1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf.resource>
|
||||
// expected-error @+1 {{'tf.InitializeTableFromTextFileV2' op failed to open vocabulary file (vocab_file_does_not_exist.txt): cannot open input file 'vocab_file_does_not_exist.txt': No such file or directory}}
|
||||
"tf.InitializeTableFromTextFileV2"(%0, %cst) {delimiter = " ", device = "", key_index = -2 : i64, value_index = -1 : i64, vocab_size = -1 : i64} : (tensor<!tf.resource>, tensor<!tf.string>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the tf.InitializeTableFromTextFileV2 op is not converted since
|
||||
// unsupported key_index, -1.
|
||||
|
||||
func @init_all_tables() {
|
||||
%cst = constant dense<"vocab_file_does_not_exist.txt"> : tensor<!tf.string>
|
||||
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "hash_table_/tmp/vocab.txt_-2_-1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf.resource>
|
||||
"tf.InitializeTableFromTextFileV2"(%0, %cst) {delimiter = " ", device = "", key_index = -1 : i64, value_index = -1 : i64, vocab_size = -1 : i64} : (tensor<!tf.resource>, tensor<!tf.string>) -> ()
|
||||
return
|
||||
// CHECK: [[VAL:%.*]] = "tf.HashTableV2"()
|
||||
// CHECK: tf.InitializeTableFromTextFileV2"
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the tf.InitializeTableFromTextFileV2 op is not converted since
|
||||
// unsupported value_index, 0.
|
||||
|
||||
func @init_all_tables() {
|
||||
%cst = constant dense<"vocab_file_does_not_exist.txt"> : tensor<!tf.string>
|
||||
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "hash_table_/tmp/vocab.txt_-2_-1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf.resource>
|
||||
"tf.InitializeTableFromTextFileV2"(%0, %cst) {delimiter = " ", device = "", key_index = -2 : i64, value_index = 0 : i64, vocab_size = -1 : i64} : (tensor<!tf.resource>, tensor<!tf.string>) -> ()
|
||||
return
|
||||
// CHECK: [[VAL:%.*]] = "tf.HashTableV2"()
|
||||
// CHECK: tf.InitializeTableFromTextFileV2"
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the tf.InitializeTableFromTextFileV2 op is not converted since
|
||||
// unsupported vocab_size, 1.
|
||||
|
||||
func @init_all_tables() {
|
||||
%cst = constant dense<"vocab_file_does_not_exist.txt"> : tensor<!tf.string>
|
||||
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "hash_table_/tmp/vocab.txt_-2_-1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf.resource>
|
||||
"tf.InitializeTableFromTextFileV2"(%0, %cst) {delimiter = " ", device = "", key_index = -2 : i64, value_index = -1 : i64, vocab_size = 1 : i64} : (tensor<!tf.resource>, tensor<!tf.string>) -> ()
|
||||
return
|
||||
// CHECK: [[VAL:%.*]] = "tf.HashTableV2"()
|
||||
// CHECK: tf.InitializeTableFromTextFileV2"
|
||||
}
|
@ -1,13 +1,13 @@
|
||||
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
|
||||
|
||||
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
|
||||
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<4xf32>, %arg3: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
|
||||
%0:2 = tf_executor.graph {
|
||||
%outputs_2, %control_3 = tf_executor.island wraps "tf.Less"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
%outputs_4, %control_5 = tf_executor.island wraps "tf.If"(%outputs_2, %arg0, %arg1) {else_branch = @cond_false, is_stateless = false, then_branch = @cond_true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> loc("StatefulIf")
|
||||
%outputs_6, %control_7 = tf_executor.island wraps "tf.If"(%outputs_2, %arg0, %arg1) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> loc("StatelessIf")
|
||||
tf_executor.fetch %outputs_4, %outputs_6 : tensor<f32>, tensor<f32>
|
||||
%outputs_4, %control_5 = tf_executor.island wraps "tf.If"(%outputs_2, %arg2, %arg3) {else_branch = @cond_false, is_stateless = false, then_branch = @cond_true} : (tensor<i1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("StatefulIf")
|
||||
%outputs_6, %control_7 = tf_executor.island wraps "tf.If"(%outputs_2, %arg2, %arg3) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true} : (tensor<i1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("StatelessIf")
|
||||
tf_executor.fetch %outputs_4, %outputs_6 : tensor<4xf32>, tensor<4xf32>
|
||||
}
|
||||
return %0#0, %0#1 : tensor<f32>, tensor<f32>
|
||||
return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32>
|
||||
}
|
||||
|
||||
func @cond_true(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
@ -34,8 +34,32 @@ func @cond_false(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK-NOT: name:
|
||||
// CHECK: op: "If"
|
||||
// CHECK-NOT: is_stateless
|
||||
// CHECK: attr {
|
||||
// CHECK: key: "output_shapes"
|
||||
// CHECK: value {
|
||||
// CHECK: list {
|
||||
// CHECK: shape {
|
||||
// CHECK: dim {
|
||||
// CHECK: size: 4
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
|
||||
// CHECK: name: "StatelessIf"
|
||||
// CHECK-NOT: name:
|
||||
// CHECK: op: "StatelessIf"
|
||||
// CHECK-NOT: is_stateless
|
||||
// CHECK: attr {
|
||||
// CHECK: key: "output_shapes"
|
||||
// CHECK: value {
|
||||
// CHECK: list {
|
||||
// CHECK: shape {
|
||||
// CHECK: dim {
|
||||
// CHECK: size: 4
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
|
@ -1,12 +1,12 @@
|
||||
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
|
||||
|
||||
func @main(%arg0: tensor<i32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
|
||||
func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
|
||||
%0:2 = tf_executor.graph {
|
||||
%outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>) loc("StatefulWhile")
|
||||
%outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true} : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>) loc("StatelessWhile")
|
||||
tf_executor.fetch %outputs_2#1, %outputs_4#1 : tensor<f32>, tensor<f32>
|
||||
%outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
|
||||
%outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
|
||||
tf_executor.fetch %outputs_2#1, %outputs_4#1 : tensor<5xf32>, tensor<5xf32>
|
||||
}
|
||||
return %0#0, %0#1 : tensor<f32>, tensor<f32>
|
||||
return %0#0, %0#1 : tensor<5xf32>, tensor<5xf32>
|
||||
}
|
||||
|
||||
func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
|
||||
@ -36,8 +36,34 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
|
||||
// CHECK-NOT: name:
|
||||
// CHECK: op: "While"
|
||||
// CHECK-NOT: is_stateless
|
||||
// CHECK: attr {
|
||||
// CHECK: key: "output_shapes"
|
||||
// CHECK: value {
|
||||
// CHECK: list {
|
||||
// CHECK: shape {
|
||||
// CHECK: dim {
|
||||
// CHECK: size: 5
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
|
||||
|
||||
// CHECK: name: "StatelessWhile"
|
||||
// CHECK-NOT: name:
|
||||
// CHECK: op: "StatelessWhile"
|
||||
// CHECK-NOT: is_stateless
|
||||
// CHECK: attr {
|
||||
// CHECK: key: "output_shapes"
|
||||
// CHECK: value {
|
||||
// CHECK: list {
|
||||
// CHECK: shape {
|
||||
// CHECK: dim {
|
||||
// CHECK: size: 5
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
|
||||
|
@ -56,7 +56,7 @@ func @propagate_if_op(
|
||||
"tf.If"(%arg1, %id0, %var_handle) {
|
||||
then_branch = @if_then,
|
||||
else_branch = @if_else,
|
||||
output_shapes = [], is_stateless = false}
|
||||
is_stateless = false}
|
||||
: (tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>) -> ()
|
||||
tf_executor.yield
|
||||
@ -128,8 +128,7 @@ func @propagate_while_op(
|
||||
// CHECK-NEXT: "tf.While"
|
||||
"tf.While"(%arg1, %id0, %var_handle) {
|
||||
body = @while_body,
|
||||
cond = @while_cond,
|
||||
output_shapes = [], is_stateless = false}
|
||||
cond = @while_cond, is_stateless = false}
|
||||
: (tensor<i32>, tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>) ->
|
||||
(tensor<i32>, tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
@ -209,8 +208,7 @@ func @error_on_conflict_multiple_callers(
|
||||
: () -> tensor<*x!tf.resource<tensor<32xf32>>>
|
||||
"tf.If"(%arg1, %id0, %var_handle) {
|
||||
then_branch = @if_then_and_else,
|
||||
else_branch = @if_then_and_else,
|
||||
output_shapes = [], is_stateless = false}
|
||||
else_branch = @if_then_and_else, is_stateless = false}
|
||||
: (tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<32xf32>>>) -> ()
|
||||
"tf.If"(%arg1, %var_handle, %id0) {
|
||||
|
@ -147,8 +147,7 @@ func @cluster_with_loop() -> () {
|
||||
"tf_device.cluster"() ( {
|
||||
// CHECK: %[[WHILE:.*]]:2 = "tf.While"(%[[COUNT]], %[[READ]])
|
||||
%2:3 = "tf.While"(%0, %1, %unused)
|
||||
{body = @while_body, cond = @while_cond, device = "", is_stateless = false,
|
||||
output_shapes = [#tf.shape<>, #tf.shape<>]}
|
||||
{body = @while_body, cond = @while_cond, device = "", is_stateless = false}
|
||||
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
|
||||
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
|
||||
// CHECK: tf_device.return %[[WHILE]]#1 : tensor<f32>
|
||||
@ -197,8 +196,7 @@ func @cluster_with_loop() -> () {
|
||||
"tf_device.cluster"() ( {
|
||||
// CHECK: %[[WHILE:.*]] = "tf.While"(%[[READ]])
|
||||
%1 = "tf.While"(%0) {
|
||||
body = @while_body, cond = @while_cond, device = "", is_stateless = false,
|
||||
output_shapes = [#tf.shape<>]}
|
||||
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>)
|
||||
-> (tensor<*x!tf.resource<tensor<f32>>>)
|
||||
// CHECK: tf_device.return %[[WHILE]] : tensor<f32>
|
||||
@ -239,8 +237,7 @@ func @cluster_with_loop() -> () {
|
||||
"tf_device.cluster"() ( {
|
||||
// CHECK: %[[WHILE:.*]] = "tf.While"(%[[READ]])
|
||||
%1 = "tf.While"(%0) {
|
||||
body = @while_body, cond = @while_cond, device = "", is_stateless = false,
|
||||
output_shapes = [#tf.shape<>]}
|
||||
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>)
|
||||
-> (tensor<*x!tf.resource<tensor<f32>>>)
|
||||
// CHECK: tf_device.return
|
||||
@ -278,8 +275,7 @@ func @cluster_with_nested_loop() -> () {
|
||||
"tf_device.cluster"() ( {
|
||||
// CHECK: %[[WHILE:.*]] = "tf.While"(%[[READ]])
|
||||
%2:2 = "tf.While"(%0, %1) {
|
||||
body = @while_body, cond = @while_cond, device = "", is_stateless = false,
|
||||
output_shapes = [#tf.shape<>, #tf.shape<>]}
|
||||
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
|
||||
-> (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
|
||||
// CHECK: tf_device.return %[[WHILE]] : tensor<f32>
|
||||
@ -295,8 +291,7 @@ func @while_body(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf
|
||||
-> (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>) {
|
||||
// CHECK: %[[WHILE:.*]] = "tf.While"(%[[BARG0]])
|
||||
%0:2 = "tf.While"(%arg0, %arg1) {
|
||||
body = @while_body1, cond = @while_cond1, device = "", is_stateless = false,
|
||||
output_shapes = [#tf.shape<>, #tf.shape<>]}
|
||||
body = @while_body1, cond = @while_cond1, device = "", is_stateless = false}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
|
||||
-> (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
|
||||
// CHECK-NEXT: return %[[WHILE]]
|
||||
@ -334,8 +329,7 @@ func @cluster_with_loop() -> () {
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
"tf_device.cluster"() ( {
|
||||
%1 = "tf.While"(%0) {
|
||||
body = @while_body, cond = @while_cond, device = "", is_stateless = false,
|
||||
output_shapes = [#tf.shape<>]}
|
||||
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>)
|
||||
tf_device.return
|
||||
}) {cluster_attr = "cluster_attr"} : () -> ()
|
||||
@ -359,8 +353,7 @@ func @cluster_with_loop() -> () {
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
"tf_device.cluster"() ( {
|
||||
%1 = "tf.While"(%0) {
|
||||
body = @while_body, cond = @while_cond, device = "", is_stateless = false,
|
||||
output_shapes = [#tf.shape<>]}
|
||||
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>)
|
||||
tf_device.return
|
||||
}) {cluster_attr = "cluster_attr"} : () -> ()
|
||||
@ -384,8 +377,7 @@ func @cluster_with_loop() -> () {
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
"tf_device.cluster"() ( {
|
||||
%1 = "tf.While"(%0) {
|
||||
body = @while_body, cond = @while_cond, device = "", is_stateless = false,
|
||||
output_shapes = [#tf.shape<>]}
|
||||
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>)
|
||||
tf_device.return
|
||||
}) {cluster_attr = "cluster_attr"} : () -> ()
|
||||
|
@ -100,10 +100,11 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
|
||||
return %1 : tensor<?x?x?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shape_from_if_to_branch_functions
|
||||
func @shape_from_if_to_branch_functions(%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
|
||||
%0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @if_else_branch, is_stateless = true, name = "if", then_branch = @if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
|
||||
return %0 : tensor<1x2x3xf32>
|
||||
// CHECK-LABEL: func @shape_from_if_to_branch_functions_to_results
|
||||
// CHECK-SAME: (%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
|
||||
func @shape_from_if_to_branch_functions_to_results(%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<*xf32> {
|
||||
%0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], else_branch = @if_else_branch, is_stateless = true, name = "if", then_branch = @if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @if_then_branch
|
||||
|
@ -20,8 +20,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
{T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE",
|
||||
"tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE",
|
||||
"tfdtype$DT_RESOURCE"], body = @while_body_7560,
|
||||
cond = @while_cond_7550, device = "", is_stateless = false,
|
||||
output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>]}
|
||||
cond = @while_cond_7550, device = "", is_stateless = false}
|
||||
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
@ -217,8 +216,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
{T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE",
|
||||
"tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE",
|
||||
"tfdtype$DT_RESOURCE"], body = @while_body_7560,
|
||||
cond = @while_cond_7550, device = "", is_stateless = false,
|
||||
output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>]}
|
||||
cond = @while_cond_7550, device = "", is_stateless = false}
|
||||
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
@ -305,8 +303,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
{T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE",
|
||||
"tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE"],
|
||||
body = @while_body_7560,
|
||||
cond = @while_cond_7550, device = "", is_stateless = false,
|
||||
output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>]}
|
||||
cond = @while_cond_7550, device = "", is_stateless = false}
|
||||
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
|
@ -7,7 +7,7 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
%0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
|
||||
%2 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
%3:10 = "tf.While"(%2, %1, %2, %0, %1, %arg2, %arg4, %arg5, %arg6, %arg7) {_lower_using_switch_merge = true, _num_original_outputs = 10 : i64, _read_only_resource_inputs = [], body = @while_body_2710, cond = @while_cond_2700, device = "", is_stateless = false, output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>], parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.resource>, tensor<!tf.resource<tensor<7x7x3x64xf32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<i64>>>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.resource>, tensor<!tf.resource<tensor<7x7x3x64xf32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<i64>>>)
|
||||
%3:10 = "tf.While"(%2, %1, %2, %0, %1, %arg2, %arg4, %arg5, %arg6, %arg7) {_lower_using_switch_merge = true, _num_original_outputs = 10 : i64, _read_only_resource_inputs = [], body = @while_body_2710, cond = @while_cond_2700, device = "", is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.resource>, tensor<!tf.resource<tensor<7x7x3x64xf32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<i64>>>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.resource>, tensor<!tf.resource<tensor<7x7x3x64xf32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<i64>>>)
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @while_body_2710
|
||||
|
@ -209,6 +209,11 @@ def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)),
def RedundantReshape : Pat<(TF_ReshapeOp (TF_ReshapeOp $arg, $unused), $shape),
                           (TF_ReshapeOp $arg, $shape)>;

def IsSame : Constraint<CPred<"$0 == $1">>;
def ReshapeToSelfShape : Pat<(TF_ReshapeOp $arg0, (TF_ShapeOp $arg1)),
                             (replaceWithValue $arg0),
                             [(IsSame $arg0, $arg1)]>;
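As a rough Python-level illustration of what ReshapeToSelfShape removes (an editorial sketch, not part of this change): reshaping a tensor to its own shape is an identity, so the reshape can simply be replaced by its input.

import tensorflow as tf

x = tf.random.normal([3, 5])
y = tf.reshape(x, tf.shape(x))  # tf.Reshape(%arg, tf.Shape(%arg)) folds to %arg
assert y.shape == x.shape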
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Select op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -19,15 +19,62 @@ limitations under the License.
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/UseDefLists.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_executor {
|
||||
|
||||
// Visits an op's operand if it is an output of an Operation in the same
|
||||
// tf_executor.graph.
|
||||
void VisitOpOperand(GraphOp graph, Value operand,
|
||||
llvm::SmallPtrSetImpl<Operation*>* reachable_ops,
|
||||
llvm::SmallVectorImpl<Operation*>* ops_to_visit) {
|
||||
Operation* def = operand.getDefiningOp();
|
||||
if (def && def->getParentOp() == graph && reachable_ops->insert(def).second) {
|
||||
// Op has not been visited, add to queue to visit later.
|
||||
ops_to_visit->push_back(def);
|
||||
}
|
||||
}
|
||||
|
||||
// Visits all operands of an op where each operand is an output of an Operation
|
||||
// in the same tf_executor.graph.
|
||||
void VisitOpOperands(GraphOp graph, Operation* op,
|
||||
llvm::SmallPtrSetImpl<Operation*>* reachable_ops,
|
||||
llvm::SmallVectorImpl<Operation*>* ops_to_visit) {
|
||||
for (Value operand : op->getOperands())
|
||||
VisitOpOperand(graph, operand, reachable_ops, ops_to_visit);
|
||||
}
|
||||
|
||||
// Visits an op and its associated operands. IslandOps are handled differently:
// the operands of ops in their regions are also visited, as values may be
// implicitly captured within. A NextIterationSourceOp will also visit its
// associated NextIterationSinkOp.
|
||||
void VisitOp(GraphOp graph, Operation* op,
|
||||
llvm::SmallPtrSetImpl<Operation*>* reachable_ops,
|
||||
llvm::SmallVectorImpl<Operation*>* ops_to_visit) {
|
||||
if (auto island = llvm::dyn_cast<IslandOp>(op)) {
|
||||
mlir::visitUsedValuesDefinedAbove(
|
||||
island.body(), island.body(), [&](OpOperand* operand) {
|
||||
VisitOpOperand(graph, operand->get(), reachable_ops, ops_to_visit);
|
||||
});
|
||||
}
|
||||
|
||||
VisitOpOperands(graph, op, reachable_ops, ops_to_visit);
|
||||
|
||||
// If op is a `tf_executor.NextIteration.Source`, visit its associated
|
||||
// `tf_executor.NextIteration.Sink` op.
|
||||
if (auto source_op = llvm::dyn_cast<NextIterationSourceOp>(op)) {
|
||||
Operation* sink_op = source_op.GetSink().getOperation();
|
||||
if (reachable_ops->insert(sink_op).second) ops_to_visit->push_back(sink_op);
|
||||
}
|
||||
}
|
||||
|
||||
// Prunes unreachable operations of a tf_executor.graph operation.
|
||||
void PruneGraph(GraphOp graph) {
|
||||
// A graph has a single block which forms a DAG: operations that aren't
|
||||
@ -36,49 +83,23 @@ void PruneGraph(GraphOp graph) {
|
||||
llvm::SmallPtrSet<Operation*, 8> reachable_ops;
|
||||
llvm::SmallVector<Operation*, 8> ops_to_visit;
|
||||
|
||||
// Visit an op's operands if it is output of an Operation in same graph.
|
||||
auto visit_op = [&](Operation* op) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
Operation* def = operand.getDefiningOp();
|
||||
if (def && def->getParentOp() == graph &&
|
||||
reachable_ops.insert(def).second) {
|
||||
// Op has not been visited, add to queue to visit later.
|
||||
ops_to_visit.push_back(def);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Visit `fetch` operands.
|
||||
visit_op(graph.GetFetch());
|
||||
// Visit fetches first to create a starting point for ops that are reachable.
|
||||
reachable_ops.insert(graph.GetFetch());
|
||||
VisitOpOperands(graph, graph.GetFetch(), &reachable_ops, &ops_to_visit);
|
||||
|
||||
// Visit transitive ops until there are no reachable ops left that have not
// been visited.
|
||||
while (!ops_to_visit.empty()) {
|
||||
Operation* op = ops_to_visit.pop_back_val();
|
||||
if (llvm::isa<IslandOp>(op)) {
|
||||
// Visit island and island inner ops operands.
|
||||
op->walk([&](Operation* inner_op) { visit_op(inner_op); });
|
||||
continue;
|
||||
} else {
|
||||
// Op is not an island, only visit its operands.
|
||||
visit_op(op);
|
||||
}
|
||||
|
||||
// If op is a `tf_executor.NextIteration.Source`, visit its associated
|
||||
// `tf_executor.NextIteration.Sink` op.
|
||||
if (auto source_op = llvm::dyn_cast<NextIterationSourceOp>(op)) {
|
||||
Operation* sink_op = source_op.GetSink().getOperation();
|
||||
if (reachable_ops.insert(sink_op).second) {
|
||||
ops_to_visit.push_back(sink_op);
|
||||
}
|
||||
}
|
||||
VisitOp(graph, op, &reachable_ops, &ops_to_visit);
|
||||
}
|
||||
|
||||
// Erase unreachable ops in reverse order.
|
||||
for (Operation& op : llvm::make_early_inc_range(
|
||||
llvm::drop_begin(llvm::reverse(graph.GetBody()), 1))) {
|
||||
if (reachable_ops.find(&op) == reachable_ops.end()) {
|
||||
op.erase();
|
||||
}
|
||||
}
|
||||
// Erase unreachable ops in reverse order so references don't need to be
|
||||
// dropped before removing an op. Going in reverse order will guarantee that
|
||||
// when an op to be erased is reached, there are no users left.
|
||||
for (Operation& op :
|
||||
llvm::make_early_inc_range(llvm::reverse(graph.GetBody())))
|
||||
if (!reachable_ops.contains(&op)) op.erase();
|
||||
}
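The traversal above is a standard reachability prune. A minimal Python sketch of the same idea follows (the graph representation and node API here are hypothetical, not the MLIR API):

def prune_graph(ops, fetch):
    # ops: nodes in program order; each node exposes .operands (defining nodes)
    # and .erase(). Mark everything transitively reachable from the fetch, then
    # erase the rest in reverse order so no erased node still has live users.
    reachable, to_visit = {fetch}, [fetch]
    while to_visit:
        op = to_visit.pop()
        for d in op.operands:
            if d not in reachable:
                reachable.add(d)
                to_visit.append(d)
    for op in reversed(ops):
        if op not in reachable:
            op.erase()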
|
||||
|
||||
namespace {
|
||||
|
@ -0,0 +1,134 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
namespace {
|
||||
|
||||
static constexpr int kTextFileIndex_WholeLine = -2;
|
||||
static constexpr int kTextFileIndex_LineNumber = -1;
|
||||
|
||||
// InitTextFileToImportPass converts InitializeTableFromTextFileV2Op to the
|
||||
// corresponding LookupTableImportV2Op if possible.
|
||||
class InitTextFileToImportPass
|
||||
: public mlir::PassWrapper<InitTextFileToImportPass, FunctionPass> {
|
||||
public:
|
||||
explicit InitTextFileToImportPass() {}
|
||||
|
||||
private:
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
class ConvertInitializeTableFromTextFileV2
|
||||
: public OpRewritePattern<InitializeTableFromTextFileV2Op> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(InitializeTableFromTextFileV2Op op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
// Now, this pattern matching only supports the following case, which is
|
||||
// commonly used among inference use cases:
|
||||
//
|
||||
// tf.lookup.TextFileInitializer(
|
||||
// "test.txt", tf.string, tf.lookup.TextFileIndex.WHOLE_LINE,
|
||||
// tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, delimiter=" ")
|
||||
//
|
||||
// In the above case, the delimiter will not be used since the key is the
// whole line and the value is the line number.
|
||||
if (op.key_index() != kTextFileIndex_WholeLine ||
|
||||
op.value_index() != kTextFileIndex_LineNumber ||
|
||||
op.vocab_size() != -1) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Try to find filename from constant op.
|
||||
DenseStringElementsAttr filename_attr;
|
||||
if (!matchPattern(op.filename().getDefiningOp(),
|
||||
m_Constant(&filename_attr))) {
|
||||
return failure();
|
||||
}
|
||||
StringRef filename = filename_attr.getRawStringData()[0];
|
||||
|
||||
// Read the content of the file.
|
||||
std::string error_message;
|
||||
auto file = openInputFile(filename, &error_message);
|
||||
if (!file) {
|
||||
return op.emitOpError("failed to open vocabulary file")
|
||||
<< " (" << filename.str() << "): " << error_message;
|
||||
}
|
||||
|
||||
// Splits into lines.
|
||||
SmallVector<StringRef, 8> lines;
|
||||
file->getBuffer().split(lines, "\n", -1, false);
|
||||
|
||||
// Map each line to line number, starting from zero.
|
||||
SmallVector<int64_t, 8> line_nums;
|
||||
line_nums.resize(lines.size());
|
||||
std::iota(line_nums.begin(), line_nums.end(), 0);
|
||||
|
||||
// Create constant ops for keys and values.
|
||||
Value key_constant_tensor = rewriter.create<ConstantOp>(
|
||||
op.getLoc(),
|
||||
DenseStringElementsAttr::get(
|
||||
RankedTensorType::get(static_cast<int64_t>(lines.size()),
|
||||
StringType::get(rewriter.getContext())),
|
||||
lines));
|
||||
|
||||
Value value_constant_tensor = rewriter.create<ConstantOp>(
|
||||
op.getLoc(), rewriter.getI64TensorAttr(line_nums));
|
||||
|
||||
// Replace the given op with LookupTableImportV2Op.
|
||||
rewriter.create<LookupTableImportV2Op>(op.getLoc(), op.table_handle(),
|
||||
key_constant_tensor,
|
||||
value_constant_tensor);
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
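For reference, the Python construct this pattern targets looks roughly like the following (a sketch using the whole-line/line-number indices described in the comment above; the file name and default value are illustrative assumptions):

import tensorflow as tf

initializer = tf.lookup.TextFileInitializer(
    "vocab.txt", tf.string, tf.lookup.TextFileIndex.WHOLE_LINE,
    tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, delimiter=" ")
table = tf.lookup.StaticHashTable(initializer, default_value=-1)
# At import time this shows up as tf.InitializeTableFromTextFileV2, which the
# pass rewrites into tf.LookupTableImportV2 with the keys/values inlined.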
|
||||
|
||||
void InitTextFileToImportPass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
MLIRContext* context = &getContext();
|
||||
FuncOp func = getFunction();
|
||||
|
||||
patterns.insert<ConvertInitializeTableFromTextFileV2>(context);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Replace InitializeTableFromTextFileV2Ops with LookupTableImportV2Ops.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateInitTextFileToImportPass() {
|
||||
return std::make_unique<InitTextFileToImportPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<InitTextFileToImportPass> pass(
|
||||
"tf-init-text-file-to-import",
|
||||
"convert InitializeTableFromTextFileV2 ops to LookupTableImportV2Op to "
|
||||
"remove the dependency on asset files");
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
@ -0,0 +1,99 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
namespace {
|
||||
|
||||
// InitTextFileToImportTestPass generates a temporary file and runs the
// InitTextFileToImportPass on it for testing purposes.
|
||||
class InitTextFileToImportTestPass
|
||||
: public mlir::PassWrapper<InitTextFileToImportTestPass,
|
||||
OperationPass<ModuleOp>> {
|
||||
public:
|
||||
explicit InitTextFileToImportTestPass() {}
|
||||
|
||||
private:
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
void InitTextFileToImportTestPass::runOnOperation() {
|
||||
ModuleOp module = getOperation();
|
||||
|
||||
// Create a temporary vocab file.
|
||||
int fd;
|
||||
SmallString<256> filename;
|
||||
std::error_code error_code =
|
||||
llvm::sys::fs::createTemporaryFile("text", "vocab", fd, filename);
|
||||
if (error_code) return signalPassFailure();
|
||||
|
||||
llvm::ToolOutputFile temp_file(filename, fd);
|
||||
const char* dictionary_in_lines =
|
||||
"apple\n"
|
||||
"banana\n"
|
||||
"grape";
|
||||
temp_file.os() << dictionary_in_lines;
|
||||
temp_file.os().flush();
|
||||
|
||||
// Replace filename constant ops to use the temporary file.
|
||||
MLIRContext* context = &getContext();
|
||||
|
||||
for (FuncOp func : module.getOps<FuncOp>()) {
|
||||
llvm::SmallVector<ConstantOp, 4> constant_ops(func.getOps<ConstantOp>());
|
||||
for (auto op : constant_ops) {
|
||||
ShapedType shaped_type =
|
||||
RankedTensorType::get({1}, StringType::get(context));
|
||||
|
||||
DenseStringElementsAttr attr;
|
||||
if (!matchPattern(op.getOperation(), m_Constant(&attr))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ArrayRef<StringRef> values = attr.getRawStringData();
|
||||
if (values.size() != 1 || values[0] != "%FILE_PLACEHOLDER") {
|
||||
continue;
|
||||
}
|
||||
|
||||
op.valueAttr(DenseStringElementsAttr::get(shaped_type, {filename}));
|
||||
}
|
||||
}
|
||||
|
||||
// Run the lowering pass.
|
||||
PassManager pm(context);
|
||||
pm.addPass(CreateInitTextFileToImportPass());
|
||||
if (failed(pm.run(module))) return signalPassFailure();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
static PassRegistration<InitTextFileToImportTestPass> pass(
|
||||
"tf-init-text-file-to-import-test",
|
||||
"generate a temporary file and invoke InitTextFileToImportPass");
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
@ -41,6 +41,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/core/framework/kernel_shape_util.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -744,9 +745,7 @@ void LegalizeHloToTf::runOnFunction() {
|
||||
|
||||
// Add legalization patterns to the list.
|
||||
OwningRewritePatternList patterns;
|
||||
populateWithGenerated(&context, &patterns);
|
||||
patterns.insert<ConvertConvOp, ConvertSliceOp, ConvertReduceOpToTfMax,
|
||||
ConvertReduceOpToTfMin, ConvertReduceOpToTfSum>(&context);
|
||||
PopulateLegalizeHloToTfPatterns(&patterns, &context);
|
||||
|
||||
ConversionTarget target(context);
|
||||
target.addLegalDialect<TensorFlowDialect>();
|
||||
@ -762,6 +761,13 @@ static PassRegistration<LegalizeHloToTf> pass(
|
||||
|
||||
} // end namespace
|
||||
|
||||
void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList *patterns,
|
||||
MLIRContext *context) {
|
||||
populateWithGenerated(context, patterns);
|
||||
patterns->insert<ConvertConvOp, ConvertSliceOp, ConvertReduceOpToTfMax,
|
||||
ConvertReduceOpToTfMin, ConvertReduceOpToTfSum>(context);
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass() {
|
||||
return std::make_unique<LegalizeHloToTf>();
|
||||
}
|
||||
|
@ -344,12 +344,56 @@ class LowerPackOp : public OpRewritePattern<TF::PackOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Lowers `TF::SparseMatMulOp` to `TF::MatMulOp`, ignoring the sparseness hints,
|
||||
// since we currently don't have an implementation that can use this
|
||||
// information. Adds appropriate casts where necessary to align element types
|
||||
// of operands and result for `TF::MatMulOp`.
|
||||
class LowerSparseMatMulOp : public OpRewritePattern<TF::SparseMatMulOp> {
|
||||
public:
|
||||
using OpRewritePattern<TF::SparseMatMulOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(TF::SparseMatMulOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Result type must be f32 for applying the pattern (currently this is
|
||||
// required by the op anyway but this might change).
|
||||
if (!op.product().getType().cast<TensorType>().getElementType().isF32()) {
|
||||
return failure();
|
||||
}
|
||||
MLIRContext *context = rewriter.getContext();
|
||||
llvm::SmallVector<Value, 2> operands{op.a(), op.b()};
|
||||
for (Value &operand : operands) {
|
||||
TensorType tensor_type = operand.getType().cast<TensorType>();
|
||||
Type element_type = tensor_type.getElementType();
|
||||
if (element_type.isF32()) continue;
|
||||
// Element type can either be f32 or bf16 for `SparseMatMulOp` so it
|
||||
// must be bf16 here.
|
||||
assert(element_type.isBF16());
|
||||
Type tensor_type_f32;
|
||||
if (tensor_type.hasRank()) {
|
||||
tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(),
|
||||
FloatType::getF32(context));
|
||||
} else {
|
||||
tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context));
|
||||
}
|
||||
// Add cast to f32 to conform with element type of result.
|
||||
operand =
|
||||
rewriter.create<TF::CastOp>(op.getLoc(), tensor_type_f32, operand);
|
||||
}
|
||||
Value result = rewriter.create<TF::MatMulOp>(
|
||||
op.getLoc(), op.product().getType(), operands[0], operands[1],
|
||||
op.transpose_a(), op.transpose_b());
|
||||
|
||||
rewriter.replaceOp(op, {result});
|
||||
return success();
|
||||
}
|
||||
};
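The cast step above is the whole trick: a `bf16` operand is widened to `f32` so both matmul inputs agree with the `f32` result type. A standalone sketch of that widening in plain C++, with a toy `BF16` stand-in rather than the real TensorFlow types:

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Toy bf16: the upper 16 bits of an IEEE-754 float.
struct BF16 { uint16_t bits; };

float ToF32(BF16 v) {
  uint32_t w = static_cast<uint32_t>(v.bits) << 16;
  float f;
  std::memcpy(&f, &w, sizeof(f));
  return f;
}

// Widen a bf16 buffer to f32 so it can feed an f32-only matmul, mirroring the
// tf.Cast ops the lowering inserts in front of tf.MatMul.
std::vector<float> WidenIfNeeded(const std::vector<BF16>& operand) {
  std::vector<float> widened;
  widened.reserve(operand.size());
  for (BF16 v : operand) widened.push_back(ToF32(v));
  return widened;  // now element-type-compatible with the f32 result
}
```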
|
||||
|
||||
} // namespace
|
||||
|
||||
void PopulateLoweringTFPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns) {
|
||||
patterns->insert<LowerAddNOp, LowerDynamicStitchOp, LowerInvertPermutationOp,
|
||||
LowerPackOp>(context);
|
||||
LowerPackOp, LowerSparseMatMulOp>(context);
|
||||
populateWithGenerated(context, patterns);
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,58 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFDevice {
|
||||
|
||||
namespace {
|
||||
|
||||
// This pass marks unsupported ops in a device cluster with
|
||||
// `_xla_outside_compilation` attribute so the operations will run on the host
|
||||
// instead of the device. Unsupported ops are ops that cannot be code
|
||||
// generated to run on the device for the cluster.
|
||||
struct MarkOpsForOutsideCompilation
|
||||
: public PassWrapper<MarkOpsForOutsideCompilation,
|
||||
OperationPass<ModuleOp>> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
void MarkOpsForOutsideCompilation::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
|
||||
module.walk([&](tf_device::ClusterOp cluster) {});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
CreateMarkOpsForOutsideCompilationPass() {
|
||||
return std::make_unique<MarkOpsForOutsideCompilation>();
|
||||
}
|
||||
|
||||
static PassRegistration<MarkOpsForOutsideCompilation> pass(
|
||||
"tf-mark-ops-for-outside-compilation",
|
||||
"Marks unsupported ops a device cluster for outside compilation.");
|
||||
|
||||
} // namespace TFDevice
|
||||
} // namespace mlir
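The cluster walk above is still a stub. As a rough standalone sketch of the marking step it is meant to perform, assuming a simple allow-list of device-supported op names (toy types and names, not the MLIR API):

```cpp
#include <map>
#include <set>
#include <string>
#include <vector>

struct ToyOp {
  std::string name;
  std::map<std::string, std::string> attributes;
};

// Tag every op that cannot be code-generated for the device so a later pass
// can extract it to run on the host.
void MarkUnsupportedOps(std::vector<ToyOp>& cluster_ops,
                        const std::set<std::string>& supported_on_device) {
  for (ToyOp& op : cluster_ops) {
    if (supported_on_device.count(op.name) == 0) {
      op.attributes["_xla_outside_compilation"] = "cluster_0";
    }
  }
}
```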
|
@ -18,6 +18,8 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
@ -148,6 +150,10 @@ CreateTensorArrayOpsDecompositionPass();
|
||||
// Create a pass that legalize HLO to TF dialect.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();
|
||||
|
||||
// Adds the HLO to TF rewrite patterns to the specified pattern list.
|
||||
void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList* patterns,
|
||||
MLIRContext* context);
|
||||
|
||||
// Matches sequence of ops to TensorFlow fused kernels. This pass should not be
|
||||
// generally used beyond exporting to runtimes that supports these ops. In the
|
||||
// future these fusions may be codegen'd automatically.
|
||||
@ -155,6 +161,10 @@ std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass();
|
||||
|
||||
// Creates function pass to select device index/fold tf.DeviceIndex.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
|
||||
|
||||
// Creates function pass to replace InitializeTableFromTextFileV2Ops with
|
||||
// LookupTableImportV2Op ops.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateInitTextFileToImportPass();
|
||||
} // namespace TF
|
||||
|
||||
namespace tf_executor {
|
||||
@ -237,6 +247,11 @@ std::unique_ptr<OperationPass<FuncOp>> CreateParallelExecuteToIslandsPass();
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
CreateAnnotateParameterReplicationPass();
|
||||
|
||||
// Creates a pass that marks unsupported ops in a device cluster for outside
|
||||
// compilation.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
CreateMarkOpsForOutsideCompilationPass();
|
||||
|
||||
// Creates a pass that hoists a `tf_device.launch` body and assigns a `device`
|
||||
// attribute to each TensorFlow dialect op in the body based on the `device`
|
||||
// attribute on the `tf_device.launch`.
|
||||
|
@ -373,8 +373,7 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp(
|
||||
OpBuilder builder(while_region);
|
||||
auto while_op = builder.create<WhileOp>(
|
||||
while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name,
|
||||
builder.getArrayAttr({}), while_region.parallel_iterations(),
|
||||
while_region.is_stateless());
|
||||
while_region.parallel_iterations(), while_region.is_stateless());
|
||||
|
||||
// Redirect old results to new results.
|
||||
for (auto it : llvm::zip(
|
||||
|
@ -627,8 +627,6 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
|
||||
});
|
||||
// Recreate the while op.
|
||||
OpBuilder builder(while_op);
|
||||
auto new_output_shapes = FilterRange<Attribute, ArrayRef<Attribute>>(
|
||||
while_op.output_shapes().getValue(), resource_arg_uses);
|
||||
// Now use the filtered original operands, which will be replaced by
|
||||
// AddLoadsStoresOutsideControlFlowOp().
|
||||
auto new_while = builder.create<TF::WhileOp>(
|
||||
@ -636,8 +634,7 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
|
||||
FilterRange<Value, OperandRange>(while_op.getOperands(),
|
||||
resource_arg_uses),
|
||||
while_op.getAttrs());
|
||||
// Prepare for AddLoadsStoresOutsideControlFlowOp() and update
|
||||
// new_output_shapes.
|
||||
// Prepare for AddLoadsStoresOutsideControlFlowOp().
|
||||
llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
|
||||
arg_data_type_and_updated_output_index;
|
||||
for (const auto& entry : remaining_resource_data_types) {
|
||||
@ -647,14 +644,9 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
|
||||
: entry.getFirst();
|
||||
arg_data_type_and_updated_output_index[entry.getFirst()] = {
|
||||
entry.getSecond(), update_index};
|
||||
if (!new_output_shapes.empty()) {
|
||||
new_output_shapes[entry.getFirst()] =
|
||||
tensorflow::ConvertTypeToTensorShapeAttr(entry.getSecond());
|
||||
}
|
||||
}
|
||||
AddLoadsStoresOutsideControlFlowOp(new_while,
|
||||
arg_data_type_and_updated_output_index);
|
||||
new_while.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes));
|
||||
// Replace uses.
|
||||
for (int64_t i = 0; i < old_to_new_indices.size(); ++i) {
|
||||
if (old_to_new_indices[i] >= 0) {
|
||||
|
@ -262,22 +262,6 @@ bool InferShapeForCall(Operation* op) {
|
||||
return changed;
|
||||
}
|
||||
|
||||
// Infer the shape IfRegion outputs based on the shapes of the then and else
|
||||
// yields.
|
||||
bool InferShapeForIfRegion(IfRegionOp op) {
|
||||
bool changed = false;
|
||||
|
||||
Operation* then_yield = op.then_branch().front().getTerminator();
|
||||
Operation* else_yield = op.else_branch().front().getTerminator();
|
||||
for (auto result : zip(op.getResults(), then_yield->getOperandTypes(),
|
||||
else_yield->getOperandTypes())) {
|
||||
// If then and else types do not match, skip refinement for that result.
|
||||
if (std::get<1>(result) != std::get<2>(result)) continue;
|
||||
changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
|
||||
changed;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
bool InferShapeForCast(CastOp op, Dialect* tf_dialect) {
|
||||
Value result = op.getResult();
|
||||
if (!CanBeRefined(result.getType())) return false;
|
||||
@ -306,6 +290,37 @@ bool InferShapeForCast(CastOp op, Dialect* tf_dialect) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Infer the shape of IfOp outputs based on the shapes of the then and else
|
||||
// function result types.
|
||||
bool InferShapeForIf(IfOp op) {
|
||||
bool changed = false;
|
||||
auto then_results = op.then_func().getType().getResults();
|
||||
auto else_results = op.else_func().getType().getResults();
|
||||
for (auto it : llvm::zip(op.getResults(), then_results, else_results)) {
|
||||
// If then and else types do not match, skip refinement for that result.
|
||||
if (std::get<1>(it) != std::get<2>(it)) continue;
|
||||
changed = RefineResultType(op, std::get<0>(it), std::get<1>(it)) || changed;
|
||||
}
|
||||
return changed;
|
||||
}
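A simplified standalone sketch of the refinement rule used above: a result type is tightened only when the then and else branches agree on it (the type strings here are illustrative placeholders for real tensor types, not the MLIR shape-inference API):

```cpp
#include <string>
#include <vector>

// Refine result "types" in place when the branches agree and the current type
// is still the unranked/unknown one; report whether anything changed.
bool RefineIfResults(std::vector<std::string>& result_types,
                     const std::vector<std::string>& then_types,
                     const std::vector<std::string>& else_types) {
  bool changed = false;
  for (size_t i = 0; i < result_types.size(); ++i) {
    if (then_types[i] != else_types[i]) continue;  // branches disagree: skip
    if (result_types[i] == "tensor<*xf32>" && then_types[i] != result_types[i]) {
      result_types[i] = then_types[i];  // e.g. refine to tensor<2x3xf32>
      changed = true;
    }
  }
  return changed;
}
```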
|
||||
|
||||
// Infer the shape of IfRegion outputs based on the shapes of the then and else
|
||||
// yields.
|
||||
bool InferShapeForIfRegion(IfRegionOp op) {
|
||||
bool changed = false;
|
||||
|
||||
Operation* then_yield = op.then_branch().front().getTerminator();
|
||||
Operation* else_yield = op.else_branch().front().getTerminator();
|
||||
for (auto result : zip(op.getResults(), then_yield->getOperandTypes(),
|
||||
else_yield->getOperandTypes())) {
|
||||
// If then and else types do not match, skip refinement for that result.
|
||||
if (std::get<1>(result) != std::get<2>(result)) continue;
|
||||
changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
|
||||
changed;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti,
|
||||
Dialect* tf_dialect) {
|
||||
Operation* op = infer_ti.getOperation();
|
||||
@ -768,17 +783,23 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
|
||||
op))
|
||||
return InferShapeForCall(op);
|
||||
|
||||
// Handle IfRegion operations by infering return shape from the then and else
|
||||
// branches.
|
||||
if (auto if_region = dyn_cast<IfRegionOp>(op))
|
||||
return InferShapeForIfRegion(if_region);
|
||||
|
||||
// tf.Cast are only inferred if they have at least one user in the TF dialect
|
||||
// or feeding into the function return. This is necessary to avoid inserting
|
||||
// casts which cannot be refined.
|
||||
if (auto cast_op = dyn_cast<CastOp>(op))
|
||||
return InferShapeForCast(cast_op, tf_dialect_);
|
||||
|
||||
// Handle IfOp here by inferring the shape from the else/then function
|
||||
// results. Since `output_shapes` is a derived attribute, avoid going down the
|
||||
// TF InferenceContext path as IfOp shape inference is implemented as just
|
||||
// a lookup of the output_shapes attribute.
|
||||
if (auto if_op = dyn_cast<IfOp>(op)) return InferShapeForIf(if_op);
|
||||
|
||||
// Handle IfRegion operations by inferring return shape from the then and else
|
||||
// branches.
|
||||
if (auto if_region = dyn_cast<IfRegionOp>(op))
|
||||
return InferShapeForIfRegion(if_region);
|
||||
|
||||
StringRef op_name = op->getName().getStringRef();
|
||||
// Drop the `tf.` prefix to query TF registry.
|
||||
auto node_name =
|
||||
|
@ -197,24 +197,16 @@ LogicalResult HandleWhileOp(
|
||||
if (!signature_change) return success();
|
||||
// Create the new while op.
|
||||
auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
|
||||
auto new_output_shapes =
|
||||
llvm::to_vector<8>(while_op.output_shapes().getValue());
|
||||
OpBuilder builder(while_op);
|
||||
assert(while_op.getNumOperands() == while_op.getNumResults());
|
||||
for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
|
||||
auto it = data_var_to_size_var.find(while_op.getOperand(i));
|
||||
if (it == data_var_to_size_var.end()) continue;
|
||||
new_while_operands.push_back(it->getSecond());
|
||||
if (!new_output_shapes.empty()) {
|
||||
// Size is a scalar shape.
|
||||
new_output_shapes.push_back(
|
||||
mlir::TF::ShapeAttr::get(builder.getContext(), ArrayRef<int64_t>()));
|
||||
}
|
||||
}
|
||||
auto new_while =
|
||||
builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
|
||||
new_while_operands, while_op.getAttrs());
|
||||
new_while.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes));
|
||||
for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
|
||||
if (!getElementTypeOrSelf(while_op.getOperand(i).getType())
|
||||
.isa<TF::ResourceType>()) {
|
||||
|
@ -595,8 +595,6 @@ LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module,
|
||||
auto new_while =
|
||||
builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
|
||||
operands, while_op.getAttrs());
|
||||
// Clear the output shapes as it is not needed for XLA lowering.
|
||||
new_while.setAttr("output_shapes", builder.getArrayAttr({}));
|
||||
for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
|
||||
if (ta_arg_buffer_type(i)) {
|
||||
while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
|
||||
@ -663,8 +661,6 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module,
|
||||
auto new_if = builder.create<TF::IfOp>(if_op.getLoc(),
|
||||
then_branch.getType().getResults(),
|
||||
operands, if_op.getAttrs());
|
||||
// Clear the output shapes as it is not needed for XLA lowering.
|
||||
new_if.setAttr("output_shapes", builder.getArrayAttr({}));
|
||||
auto ret_forwards_input = [](FuncOp f, int64_t ret_ind) -> int64_t {
|
||||
auto retval = f.front().getTerminator()->getOperand(ret_ind);
|
||||
auto arg = retval.dyn_cast<BlockArgument>();
|
||||
|
@ -190,22 +190,14 @@ LogicalResult HandleWhileOp(
|
||||
}
|
||||
// Create the new while op.
|
||||
auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
|
||||
auto new_output_shapes =
|
||||
llvm::to_vector<8>(while_op.output_shapes().getValue());
|
||||
for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
|
||||
auto it = buffer_to_size->find(while_op.getOperand(i));
|
||||
if (it == buffer_to_size->end()) continue;
|
||||
new_while_operands.push_back(it->getSecond().size);
|
||||
if (!new_output_shapes.empty()) {
|
||||
// Size is a scalar shape.
|
||||
new_output_shapes.push_back(
|
||||
mlir::TF::ShapeAttr::get(builder.getContext(), ArrayRef<int64_t>()));
|
||||
}
|
||||
}
|
||||
auto new_while =
|
||||
builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
|
||||
new_while_operands, while_op.getAttrs());
|
||||
new_while.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes));
|
||||
for (const auto& entry : output_buffer_to_size) {
|
||||
(*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = {
|
||||
new_while.getResult(std::get<1>(entry)), std::get<2>(entry)};
|
||||
|
@ -365,16 +365,6 @@ TF::WhileOp AddStateVarsToWhileOp(TF::WhileOp while_op, FuncOp body,
|
||||
while_op.getLoc(),
|
||||
append_types(llvm::to_vector<4>(while_op.getResultTypes())),
|
||||
new_while_operands, while_op.getAttrs());
|
||||
if (new_while_op.output_shapes().size() != 0) {
|
||||
auto new_output_shapes = llvm::to_vector<4>(new_while_op.output_shapes());
|
||||
// VarHandleOp is a scalar shape resource.
|
||||
for (int64_t i = 0; i < state_vars.size(); ++i) {
|
||||
new_output_shapes.push_back(
|
||||
mlir::TF::ShapeAttr::get(builder.getContext(), ArrayRef<int64_t>()));
|
||||
}
|
||||
new_while_op.setAttr("output_shapes",
|
||||
builder.getArrayAttr(new_output_shapes));
|
||||
}
|
||||
while_op.replaceAllUsesWith(
|
||||
new_while_op.getResults().take_front(while_op.getNumResults()));
|
||||
while_op.erase();
|
||||
|
@ -1836,6 +1836,9 @@ Operation *AvgPoolDivideByCount(
|
||||
return result;
|
||||
}
|
||||
|
||||
Value GetAvgPoolInput(TF::AvgPoolOp op) { return op.value(); }
|
||||
Value GetAvgPoolInput(TF::AvgPool3DOp op) { return op.input(); }
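These two overloads let a single templated pattern serve both pooling ops even though their input accessors have different names. A minimal standalone sketch of the same overload trick with toy types:

```cpp
#include <string>

struct ToyAvgPool2D { std::string value; };   // accessor would be value()
struct ToyAvgPool3D { std::string input; };   // accessor would be input()

std::string GetInput(const ToyAvgPool2D& op) { return op.value; }
std::string GetInput(const ToyAvgPool3D& op) { return op.input; }

// One template handles both op types; overload resolution picks the right
// accessor, so the shared body never mentions value() or input() directly.
template <typename OpTy>
std::string DescribePoolInput(const OpTy& op) {
  return "pool input: " + GetInput(op);
}
```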
|
||||
|
||||
// Converts AvgPool op to HLO ReduceWindow op by setting appropriate window
|
||||
// dimensions with add as the reduction function. The reduction result is
|
||||
// then divided by the number of elements in the window.
|
||||
@ -1846,8 +1849,9 @@ class ConvertAvgPoolOp : public OpRewritePattern<OpTy> {
|
||||
|
||||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value input_value = GetAvgPoolInput(op);
|
||||
auto input_type =
|
||||
op.value().getType().template dyn_cast<RankedTensorType>();
|
||||
input_value.getType().template dyn_cast<RankedTensorType>();
|
||||
if (!input_type) return failure();
|
||||
|
||||
// We will do accumulation first; use a larger bitwidth if suitable.
|
||||
@ -1862,8 +1866,6 @@ class ConvertAvgPoolOp : public OpRewritePattern<OpTy> {
|
||||
else
|
||||
result_type = UnrankedTensorType::get(sum_element_type);
|
||||
|
||||
Value input_value = op.value();
|
||||
|
||||
// Convert if we need enlarge the element type's bitwidth.
|
||||
if (input_element_type != sum_element_type)
|
||||
input_value = rewriter.create<ConvertOp>(op.getLoc(), input_value,
|
||||
@ -5398,50 +5400,6 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Converts `TF::SparseMatMulOp` to `TF::MatMulOp`, ignoring the sparseness
|
||||
// hints, since we currently don't have an implementation that can use this
|
||||
// information. Adds appropriate casts where necessary to align element types
|
||||
// of operands and result for `TF::MatMulOp`.
|
||||
class ConvertSparseMatMulOp : public OpRewritePattern<TF::SparseMatMulOp> {
|
||||
public:
|
||||
using OpRewritePattern<TF::SparseMatMulOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(TF::SparseMatMulOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Result type must be f32 for applying the pattern (currently this is
|
||||
// required by the op anyway but this might change).
|
||||
if (!op.product().getType().cast<TensorType>().getElementType().isF32()) {
|
||||
return failure();
|
||||
}
|
||||
MLIRContext *context = rewriter.getContext();
|
||||
llvm::SmallVector<Value, 2> operands{op.a(), op.b()};
|
||||
for (Value &operand : operands) {
|
||||
TensorType tensor_type = operand.getType().cast<TensorType>();
|
||||
Type element_type = tensor_type.getElementType();
|
||||
if (element_type.isF32()) continue;
|
||||
// Element type can either be f32 or bf16 for `SparseMatMulOp` so it
|
||||
// must be bf16 here.
|
||||
assert(element_type.isBF16());
|
||||
Type tensor_type_f32;
|
||||
if (tensor_type.hasRank()) {
|
||||
tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(),
|
||||
FloatType::getF32(context));
|
||||
} else {
|
||||
tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context));
|
||||
}
|
||||
// Add cast to f32 to conform with element type of result.
|
||||
operand =
|
||||
rewriter.create<TF::CastOp>(op.getLoc(), tensor_type_f32, operand);
|
||||
}
|
||||
Value result = rewriter.create<TF::MatMulOp>(
|
||||
op.getLoc(), op.product().getType(), operands[0], operands[1],
|
||||
op.transpose_a(), op.transpose_b());
|
||||
|
||||
rewriter.replaceOp(op, {result});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Emits debug information which includes the number of ops of each type that
|
||||
// failed to legalize.
|
||||
void EmitLegalizationErrors(Operation *op,
|
||||
@ -5531,11 +5489,10 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
|
||||
ConvertDynamicRangeOp, ConvertRangeOp, ConvertSelectV2Op,
|
||||
ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
|
||||
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
|
||||
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSparseMatMulOp,
|
||||
ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp,
|
||||
ConvertStridedSliceGradOp, ConvertSumOp, ConvertTensorScatterUpdateOp,
|
||||
ConvertTileOp, ConvertTopKV2Op, ConvertUnpackOp,
|
||||
ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
|
||||
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
|
||||
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
|
||||
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
|
||||
ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
|
||||
ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp,
|
||||
ConvertRandomShuffleOp, ConvertXlaShardingOp,
|
||||
ConvertXlaDynamicUpdateSliceOp>(op->getContext());
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // from @llvm-project
|
||||
@ -223,18 +224,20 @@ static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
|
||||
|
||||
class Tf2XlaRewriter {
|
||||
public:
|
||||
static LogicalResult RewriteOp(Operation* op, OpBuilder& builder,
|
||||
static LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter,
|
||||
const std::string& device_type) {
|
||||
Tf2XlaRewriter rewriter(op, builder, device_type);
|
||||
return rewriter.LegalizeOp();
|
||||
Tf2XlaRewriter tf2xla_rewriter(op, rewriter, device_type);
|
||||
return tf2xla_rewriter.LegalizeOp();
|
||||
}
|
||||
|
||||
private:
|
||||
Tf2XlaRewriter(Operation* op, OpBuilder builder,
|
||||
Tf2XlaRewriter(Operation* op, PatternRewriter& rewriter,
|
||||
const std::string& device_type)
|
||||
: op_(op),
|
||||
device_type_(device_type),
|
||||
hlo_builder_(op->getName().getStringRef().str(), builder, op->getLoc()),
|
||||
rewriter_(rewriter),
|
||||
hlo_builder_(op->getName().getStringRef().str(), rewriter_,
|
||||
op->getLoc()),
|
||||
context_(nullptr) {}
|
||||
|
||||
~Tf2XlaRewriter() {
|
||||
@ -259,6 +262,7 @@ class Tf2XlaRewriter {
|
||||
Operation* op_;
|
||||
std::string device_type_;
|
||||
|
||||
PatternRewriter& rewriter_;
|
||||
::xla::MlirHloBuilder hlo_builder_;
|
||||
tensorflow::OpOrArgLocNameMapper name_mapper_;
|
||||
|
||||
@ -429,6 +433,8 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() {
|
||||
|
||||
// Replace uses of old results using the corresponding value after the
|
||||
// lowering.
|
||||
llvm::SmallVector<Value, 2> values;
|
||||
values.reserve(op_->getNumResults());
|
||||
for (int i = 0, e = op_->getNumResults(); i < e; i++) {
|
||||
tensorflow::Tensor* output = op_context.mutable_output(i);
|
||||
const tensorflow::XlaExpression* expr =
|
||||
@ -442,10 +448,9 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() {
|
||||
value =
|
||||
hlo_builder_.create<mlir::TensorCastOp>(value, old_result.getType());
|
||||
}
|
||||
old_result.replaceAllUsesWith(value);
|
||||
values.push_back(value);
|
||||
}
|
||||
|
||||
op_->erase();
|
||||
rewriter_.replaceOp(op_, values);
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -529,6 +534,11 @@ static PassRegistration<LegalizeTF> pass(
|
||||
|
||||
} // end namespace
|
||||
|
||||
void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
|
||||
OwningRewritePatternList& patterns) {
|
||||
patterns.insert<Tf2XlaRewritePattern>(device_type.str());
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
|
||||
llvm::StringRef device_type) {
|
||||
return std::make_unique<LegalizeTF>(device_type);
|
||||
|
@ -18,6 +18,9 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
@ -41,6 +44,10 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
|
||||
llvm::StringRef device_type);
|
||||
|
||||
/// Adds the TF to XLA via TF2XLA rewrite patterns to the pattern list.
|
||||
void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
|
||||
OwningRewritePatternList& patterns);
|
||||
|
||||
/// Lowers from TF dialect's control flow to HLO dialect's control flow.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeTFControlFlowPass();
|
||||
|
||||
|
@ -1333,6 +1333,7 @@ tf_xla_py_test(
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
"notap", # b/162025277
|
||||
],
|
||||
deps = [
|
||||
":xla_test",
|
||||
|
@ -8,6 +8,7 @@ load(
|
||||
"tf_cuda_tests_tags",
|
||||
"tf_exec_properties",
|
||||
)
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
def all_backends():
|
||||
b = ["cpu"] + plugins.keys()
|
||||
@ -121,7 +122,7 @@ def tf_xla_py_test(
|
||||
updated_name = updated_name[:-5]
|
||||
updated_name += "_mlir_bridge_test"
|
||||
|
||||
native.py_test(
|
||||
py_test(
|
||||
name = updated_name,
|
||||
srcs = srcs,
|
||||
srcs_version = "PY2AND3",
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for while loops in XLA."""
|
||||
"""Tests for case statements in XLA."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -45,6 +45,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h" // NOLINT
|
||||
@ -332,7 +333,6 @@ void UpdateToEngineNode(const std::vector<EngineInfo>& infos,
|
||||
Status CreateTRTNode(const ConversionParams& params,
|
||||
const std::vector<EngineInfo>& infos, int pos,
|
||||
int max_batch_size, Graph* graph,
|
||||
nvinfer1::IGpuAllocator* alloc,
|
||||
std::vector<Node*>* engine_nodes) {
|
||||
const auto& info = infos.at(pos);
|
||||
std::vector<tensorflow::TensorShapeProto> input_shape_protos;
|
||||
@ -428,16 +428,30 @@ Status CreateTRTNode(const ConversionParams& params,
|
||||
// Build the engine and get its serialized representation.
|
||||
string segment_string;
|
||||
if (info.engine_type == EngineInfo::EngineType::TRTStatic) {
|
||||
std::pair<int, Allocator*> device_allocator =
|
||||
GetDeviceAndAllocator(params, info);
|
||||
int cuda_device_id = 0;
|
||||
std::unique_ptr<TRTBaseAllocator> trt_allocator;
|
||||
if (device_allocator.first >= 0) {
|
||||
cuda_device_id = device_allocator.first;
|
||||
trt_allocator.reset(new TRTDeviceAllocator(device_allocator.second));
|
||||
} else {
|
||||
// The value in trt_allocator is a nullptr and cudamalloc will be used.
|
||||
LOG_WARNING_WITH_PREFIX << "Can't identify the cuda device. Running on "
|
||||
"device 0 and use cudamalloc as an allocator";
|
||||
}
|
||||
cudaSetDevice(cuda_device_id);
|
||||
|
||||
auto trt_logger = GetLoggerRegistry()->LookUp(params.trt_logger_name);
|
||||
// Create static engine for fp32/fp16 mode.
|
||||
// Create static engines with precision_mode fp32/fp16.
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
|
||||
// TODO(sami): What happens if 1st dim is not batch?
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
|
||||
info.segment_graph_def,
|
||||
calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode,
|
||||
max_batch_size, info.max_workspace_size_bytes, input_shapes, trt_logger,
|
||||
alloc, /*calibrator=*/nullptr, &engine, info.use_calibration,
|
||||
params.use_implicit_batch, /*convert_successfully=*/nullptr,
|
||||
trt_allocator.get(), /*calibrator=*/nullptr, &engine,
|
||||
info.use_calibration, params.use_implicit_batch,
|
||||
/*convert_successfully=*/nullptr,
|
||||
/*profile=*/nullptr));
|
||||
TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
|
||||
segment_string = string(static_cast<const char*>(engine_data->data()),
|
||||
@ -793,13 +807,27 @@ Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
}
|
||||
}
|
||||
|
||||
// Create a TRT node for each segment using its EngineInfo.
|
||||
int old_cuda_device = 0;
|
||||
auto err = cudaGetDevice(&old_cuda_device);
|
||||
if (err != cudaSuccess) {
|
||||
LOG(ERROR) << "Couldn't get current device: " << cudaGetErrorString(err);
|
||||
// Save the cuda device if we may need to switch to another cuda device to
|
||||
// build static engines.
|
||||
absl::optional<int> old_cuda_device = absl::nullopt;
|
||||
if (!params.is_dyn_op) {
|
||||
int cuda_device_id;
|
||||
cudaError_t cuda_error = cudaGetDevice(&cuda_device_id);
|
||||
if (cuda_error != cudaSuccess) {
|
||||
LOG_WARNING_WITH_PREFIX << "Couldn't get current device: "
|
||||
<< cudaGetErrorString(cuda_error);
|
||||
} else {
|
||||
VLOG(1) << "Current cuda device is " << cuda_device_id;
|
||||
old_cuda_device = cuda_device_id;
|
||||
}
|
||||
}
|
||||
VLOG(1) << "Current cuda device is " << old_cuda_device;
|
||||
|
||||
auto restore_cuda_device = gtl::MakeCleanup([old_cuda_device] {
|
||||
if (old_cuda_device.has_value()) {
|
||||
cudaSetDevice(old_cuda_device.value());
|
||||
}
|
||||
});
|
||||
|
||||
std::vector<Node*> engine_nodes;
|
||||
engine_nodes.resize(engine_segments.size());
|
||||
for (int i = 0; i < engine_segments.size(); ++i) {
|
||||
@ -813,24 +841,8 @@ Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
2.0;
|
||||
VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to "
|
||||
<< engine.engine_name;
|
||||
// The allocator is used to build the engine. The build and the built engine
|
||||
// will be destroyed after we get the serialized engine string, so it's fine
|
||||
// to use unique_ptr here.
|
||||
std::unique_ptr<TRTBaseAllocator> alloc;
|
||||
auto device_alloc = GetDeviceAndAllocator(params, engine);
|
||||
int cuda_device_id = 0;
|
||||
if (device_alloc.first >= 0) {
|
||||
cuda_device_id = device_alloc.first;
|
||||
alloc.reset(new TRTDeviceAllocator(device_alloc.second));
|
||||
} else {
|
||||
// Setting allocator as nullptr should get revert to the cudamalloc
|
||||
LOG_WARNING_WITH_PREFIX
|
||||
<< "Can't identify the cuda device. Running on device 0 ";
|
||||
}
|
||||
cudaSetDevice(cuda_device_id);
|
||||
auto status =
|
||||
CreateTRTNode(params, engine_segments, i, params.max_batch_size, &graph,
|
||||
alloc.get(), &engine_nodes);
|
||||
auto status = CreateTRTNode(params, engine_segments, i,
|
||||
params.max_batch_size, &graph, &engine_nodes);
|
||||
|
||||
string msg = StrCat("segment ", i, " consisting of ",
|
||||
converted_segments.at(i).size(), " nodes by ",
|
||||
@ -859,7 +871,6 @@ Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
}
|
||||
}
|
||||
}
|
||||
cudaSetDevice(old_cuda_device);
|
||||
graph.ToGraphDef(params.output_graph_def);
|
||||
VLOG(1) << "Returning from conversion";
|
||||
return Status::OK();
|
||||
|
@ -1309,7 +1309,8 @@ std::vector<float> GetDataAsFloat(InputOutputData& data) {
|
||||
class OpConverterTest : public ::testing::Test {
|
||||
public:
|
||||
OpConverterTest()
|
||||
: scope_(Scope::NewRootScope()), allocator_(new GpuManagedAllocator()) {
|
||||
: tensor_buffer_allocator_(new GpuManagedAllocator()),
|
||||
scope_(Scope::NewRootScope()) {
|
||||
QCHECK_EQ(0, cudaStreamCreate(&stream_));
|
||||
Reset();
|
||||
}
|
||||
@ -1341,7 +1342,7 @@ class OpConverterTest : public ::testing::Test {
|
||||
// Constructs a flat tensor with 'vals' in Unified Memory.
|
||||
template <typename T>
|
||||
Tensor AsTensor(gtl::ArraySlice<T> vals) { // non-absl ok
|
||||
Tensor ret(allocator_.get(), DataTypeToEnum<T>::value,
|
||||
Tensor ret(tensor_buffer_allocator_.get(), DataTypeToEnum<T>::value,
|
||||
{static_cast<int64>(vals.size())});
|
||||
std::copy_n(vals.data(), vals.size(), ret.flat<T>().data());
|
||||
return ret;
|
||||
@ -1351,7 +1352,7 @@ class OpConverterTest : public ::testing::Test {
|
||||
template <typename T>
|
||||
Tensor AsTensor(gtl::ArraySlice<T> vals, // non-absl ok
|
||||
const TensorShape& shape) {
|
||||
Tensor ret(allocator_.get(), DataTypeToEnum<T>::value,
|
||||
Tensor ret(tensor_buffer_allocator_.get(), DataTypeToEnum<T>::value,
|
||||
{static_cast<int64>(vals.size())});
|
||||
CHECK(ret.CopyFrom(AsTensor(vals), shape));
|
||||
return ret;
|
||||
@ -1363,7 +1364,8 @@ class OpConverterTest : public ::testing::Test {
|
||||
template <typename T>
|
||||
Tensor AsTensor(std::vector<T> vals, const std::vector<int> input_dims,
|
||||
DataType tf_type) {
|
||||
Tensor ret(allocator_.get(), tf_type, {static_cast<int64>(vals.size())});
|
||||
Tensor ret(tensor_buffer_allocator_.get(), tf_type,
|
||||
{static_cast<int64>(vals.size())});
|
||||
if (tf_type == DT_FLOAT) {
|
||||
auto conv_vals = CastTestVector<T, float>(vals);
|
||||
std::copy_n(conv_vals.data(), conv_vals.size(), ret.flat<float>().data());
|
||||
@ -1646,13 +1648,15 @@ class OpConverterTest : public ::testing::Test {
|
||||
Logger logger_;
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
|
||||
cudaStream_t stream_;
|
||||
// Used to create placeholders with shape and data type information. The
|
||||
// created placeholders will be used as inputs to the node to be verified,
|
||||
// thus we need the shape and data type information to get a non-empty
|
||||
// GraphProperties.
|
||||
std::unique_ptr<Allocator> tensor_buffer_allocator_;
|
||||
// The scope that contains the graph being converted. Because
|
||||
// tensor_buffer_allocator_ provides the storage for tensor contents that are
|
||||
// represented as attributes for graph nodes within scope_,
|
||||
// tensor_buffer_allocator_ needs to be available when destructing scope_.
|
||||
// Therefore, scope_ comes after tensor_buffer_allocator_ in the class member
|
||||
// field list.
|
||||
Scope scope_;
|
||||
std::unordered_map<string, Output> node_inputs_;
|
||||
std::unique_ptr<Allocator> allocator_;
|
||||
};
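The ordering comment relies on a basic C++ guarantee: non-static data members are destroyed in reverse declaration order, so a member declared earlier outlives those declared after it. A tiny standalone demonstration with illustrative names:

```cpp
#include <cstdio>

struct Allocator {
  ~Allocator() { std::puts("allocator destroyed (last)"); }
};

struct Graph {
  // Pretend the graph still touches allocator-owned buffers while tearing down.
  ~Graph() { std::puts("graph destroyed (first)"); }
};

struct Fixture {
  Allocator allocator;  // declared first  => destroyed last
  Graph graph;          // declared second => destroyed first, allocator still alive
};

int main() {
  Fixture f;
  return 0;  // destruction order on exit: graph, then allocator
}
```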
|
||||
|
||||
// General test parameters to be used with ops that take a single input tensor.
|
||||
|
@ -160,17 +160,15 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
|
||||
XlaCompiler* compiler = ctx->compiler();
|
||||
|
||||
std::vector<XlaCompiler::CompilationResult> branch_results(num_branches);
|
||||
std::vector<XlaCompiler::CompilationResult*> branch_results_p(num_branches);
|
||||
for (int j = 0; j < num_branches; ++j) {
|
||||
OP_REQUIRES_OK(ctx,
|
||||
compiler->CompileFunction(options, branches[j], arguments,
|
||||
&branch_results[j]));
|
||||
branch_results_p[j] = &branch_results[j];
|
||||
}
|
||||
|
||||
bool has_tensor_array_gradients = false;
|
||||
for (XlaCompiler::CompilationResult* result : branch_results_p) {
|
||||
for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) {
|
||||
for (XlaCompiler::CompilationResult& result : branch_results) {
|
||||
for (const XlaCompiler::ResourceUpdate& update : result.resource_updates) {
|
||||
XlaResource* resource;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->GetResourceInput(update.input_index + 1, &resource));
|
||||
|
@ -47,6 +47,122 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
// Populates tensor array gradients for compiled branches and returns whether the
|
||||
// set of found tensor array gradients is non-empty.
|
||||
static xla::StatusOr<bool> PopulateTensorArrayGradients(
|
||||
XlaOpKernelContext* ctx, xla::XlaBuilder* b,
|
||||
absl::Span<XlaCompiler::Argument> arguments,
|
||||
XlaCompiler::CompilationResult* then_result,
|
||||
XlaCompiler::CompilationResult* else_result) {
|
||||
bool has_tensor_array_gradients = false;
|
||||
for (XlaCompiler::CompilationResult* result : {then_result, else_result}) {
|
||||
for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) {
|
||||
XlaResource* resource;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->GetResourceInput(update.input_index + 1, &resource));
|
||||
XlaCompiler::Argument& arg = arguments[update.input_index];
|
||||
|
||||
// Add any TensorArray gradients touched by the then/else computation to
|
||||
// the enclosing graph.
|
||||
for (const string& grad_source : update.tensor_array_gradients_accessed) {
|
||||
VLOG(5) << "TensorArray " << resource->name() << " accessed gradient "
|
||||
<< grad_source;
|
||||
XlaResource* gradient;
|
||||
TF_RETURN_IF_ERROR(resource->GetOrCreateTensorArrayGradient(
|
||||
grad_source, b, &gradient));
|
||||
}
|
||||
// Add all of the TensorArray gradients to the argument. For simplicity,
|
||||
// we always pass all known gradients.
|
||||
for (const auto& gradient : resource->tensor_array_gradients()) {
|
||||
arg.tensor_array_gradients.insert(gradient.first);
|
||||
}
|
||||
if (!resource->tensor_array_gradients().empty())
|
||||
has_tensor_array_gradients = true;
|
||||
}
|
||||
}
|
||||
return has_tensor_array_gradients;
|
||||
}
|
||||
|
||||
// Checks that shapes match on both sides of the conditional.
|
||||
static Status ValidateShapes(
|
||||
XlaOpKernelContext* ctx, const XlaCompiler::CompilationResult& then_result,
|
||||
const XlaCompiler::CompilationResult& else_result) {
|
||||
// Check that both branches have identical input shapes.
|
||||
if (then_result.xla_input_shapes.size() != 1) {
|
||||
return errors::FailedPrecondition("Expected one input shape");
|
||||
}
|
||||
|
||||
xla::Shape then_input_shape = then_result.xla_input_shapes[0];
|
||||
if (!then_input_shape.IsTuple()) {
|
||||
return errors::FailedPrecondition("Expected tuple shape");
|
||||
}
|
||||
|
||||
if (else_result.xla_input_shapes.size() != 1) {
|
||||
return errors::FailedPrecondition("Expected one input shape");
|
||||
}
|
||||
xla::Shape else_input_shape = else_result.xla_input_shapes[0];
|
||||
if (!else_input_shape.IsTuple()) {
|
||||
return errors::FailedPrecondition("Expected tuple shape");
|
||||
}
|
||||
if (!xla::ShapeUtil::Compatible(then_input_shape, else_input_shape)) {
|
||||
return errors::InvalidArgument(
|
||||
"Input shapes of then and else branches do not match: ",
|
||||
xla::ShapeUtil::HumanString(then_input_shape), " vs. ",
|
||||
xla::ShapeUtil::HumanString(else_input_shape));
|
||||
}
|
||||
|
||||
// Check that both branches have identical output shapes.
|
||||
if (!xla::ShapeUtil::Compatible(then_result.xla_output_shape,
|
||||
else_result.xla_output_shape)) {
|
||||
return errors::InvalidArgument(
|
||||
"Output shapes of then and else branches do not match: ",
|
||||
xla::ShapeUtil::HumanString(then_result.xla_output_shape), " vs. ",
|
||||
xla::ShapeUtil::HumanString(else_result.xla_output_shape));
|
||||
}
|
||||
|
||||
// Check that both branches have same TensorList output indices.
|
||||
for (int output_index = 0; output_index < then_result.outputs.size();
|
||||
output_index++) {
|
||||
bool is_tensor_list_in_then_branch =
|
||||
then_result.outputs[output_index].is_tensor_list;
|
||||
bool is_tensor_list_in_else_branch =
|
||||
else_result.outputs[output_index].is_tensor_list;
|
||||
if (is_tensor_list_in_then_branch != is_tensor_list_in_else_branch) {
|
||||
return errors::FailedPrecondition(
|
||||
"Output #", output_index, " is ",
|
||||
(is_tensor_list_in_then_branch ? "" : "not"),
|
||||
" a TensorList in then branch, but is ",
|
||||
(is_tensor_list_in_else_branch ? "" : "not"),
|
||||
" a TensorList in else branch");
|
||||
}
|
||||
}
|
||||
|
||||
VLOG(2) << "Input shape: " << xla::ShapeUtil::HumanString(then_input_shape);
|
||||
VLOG(2) << "Output shape: "
|
||||
<< xla::ShapeUtil::HumanString(then_result.xla_output_shape);
|
||||
|
||||
// We set return_updated_values_for_all_resources=true and we pass the same
|
||||
// arguments to both computations, so the resource update count must match.
|
||||
if (then_result.resource_updates.size() !=
|
||||
else_result.resource_updates.size()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Different number of resources in then and else branch");
|
||||
}
|
||||
|
||||
for (int i = 0; i < then_result.resource_updates.size(); ++i) {
|
||||
const auto& lhs = then_result.resource_updates[i];
|
||||
const auto& rhs = else_result.resource_updates[i];
|
||||
bool equal = lhs.input_index == rhs.input_index && lhs.shape == rhs.shape &&
|
||||
lhs.tensor_array_gradients_accessed ==
|
||||
rhs.tensor_array_gradients_accessed;
|
||||
if (!equal) {
|
||||
return errors::FailedPrecondition(
|
||||
"Mismatch in resource of then and else branch for resource ", i);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
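The helper above replaces a long run of inline `OP_REQUIRES` checks with a single status-returning function that the caller tests once. A standalone sketch of that shape using a minimal status type (not the TensorFlow `Status` API):

```cpp
#include <cstdio>
#include <string>
#include <vector>

struct SimpleStatus {
  bool ok = true;
  std::string message;
  static SimpleStatus Ok() { return {}; }
  static SimpleStatus Error(std::string msg) { return {false, std::move(msg)}; }
};

// All precondition checks live in one place and share one error channel.
SimpleStatus ValidateBranchOutputs(const std::vector<int>& then_shapes,
                                   const std::vector<int>& else_shapes) {
  if (then_shapes.size() != else_shapes.size())
    return SimpleStatus::Error("branch output counts differ");
  for (size_t i = 0; i < then_shapes.size(); ++i)
    if (then_shapes[i] != else_shapes[i])
      return SimpleStatus::Error("output " + std::to_string(i) + " shape mismatch");
  return SimpleStatus::Ok();
}

int main() {
  SimpleStatus s = ValidateBranchOutputs({2, 3}, {2, 4});
  if (!s.ok) std::printf("validation failed: %s\n", s.message.c_str());
  return 0;
}
```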
|
||||
|
||||
// TODO(b/35949885): There is duplication here with the handling of the
|
||||
// while_op. Refactor the common code out/rework.
|
||||
void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
|
||||
@ -137,35 +253,12 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
|
||||
OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_,
|
||||
arguments, &else_result));
|
||||
|
||||
bool has_tensor_array_gradients = false;
|
||||
for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) {
|
||||
for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) {
|
||||
XlaResource* resource;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->GetResourceInput(update.input_index + 1, &resource));
|
||||
XlaCompiler::Argument& arg = arguments[update.input_index];
|
||||
|
||||
// Add any TensorArray gradients touched by the then/else computation to
|
||||
// the enclosing graph.
|
||||
for (const string& grad_source : update.tensor_array_gradients_accessed) {
|
||||
VLOG(5) << "TensorArray " << resource->name() << " accessed gradient "
|
||||
<< grad_source;
|
||||
XlaResource* gradient;
|
||||
OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient(
|
||||
grad_source, b, &gradient));
|
||||
}
|
||||
// Add all of the TensorArray gradients to the argument. For simplicity,
|
||||
// we always pass all known gradients.
|
||||
for (const auto& gradient : resource->tensor_array_gradients()) {
|
||||
arg.tensor_array_gradients.insert(gradient.first);
|
||||
}
|
||||
if (!resource->tensor_array_gradients().empty())
|
||||
has_tensor_array_gradients = true;
|
||||
}
|
||||
}
|
||||
xla::StatusOr<bool> has_tensor_array_gradients = PopulateTensorArrayGradients(
|
||||
ctx, b, absl::MakeSpan(arguments), &then_result, &else_result);
|
||||
OP_REQUIRES_OK(ctx, has_tensor_array_gradients.status());
|
||||
|
||||
// Recompile the functions to update the argument shapes for tensor arrays.
|
||||
if (has_tensor_array_gradients) {
|
||||
if (*has_tensor_array_gradients) {
|
||||
then_result = {};
|
||||
OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_,
|
||||
arguments, &then_result));
|
||||
@ -174,72 +267,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
|
||||
arguments, &else_result));
|
||||
}
|
||||
|
||||
// Check that both branches have identical input shapes.
|
||||
OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1,
|
||||
errors::FailedPrecondition("Expected one input shape"));
|
||||
xla::Shape then_input_shape = then_result.xla_input_shapes[0];
|
||||
OP_REQUIRES(ctx, then_input_shape.IsTuple(),
|
||||
errors::FailedPrecondition("Expected tuple shape"));
|
||||
OP_REQUIRES(ctx, else_result.xla_input_shapes.size() == 1,
|
||||
errors::FailedPrecondition("Expected one input shape"));
|
||||
xla::Shape else_input_shape = else_result.xla_input_shapes[0];
|
||||
OP_REQUIRES(ctx, else_input_shape.IsTuple(),
|
||||
errors::FailedPrecondition("Expected tuple shape"));
|
||||
OP_REQUIRES(ctx,
|
||||
xla::ShapeUtil::Compatible(then_input_shape, else_input_shape),
|
||||
errors::InvalidArgument(
|
||||
"Input shapes of then and else branches do not match: ",
|
||||
xla::ShapeUtil::HumanString(then_input_shape), " vs. ",
|
||||
xla::ShapeUtil::HumanString(else_input_shape)));
|
||||
|
||||
// Check that both branches have identical output shapes.
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
xla::ShapeUtil::Compatible(then_result.xla_output_shape,
|
||||
else_result.xla_output_shape),
|
||||
errors::InvalidArgument(
|
||||
"Output shapes of then and else branches do not match: ",
|
||||
xla::ShapeUtil::HumanString(then_result.xla_output_shape), " vs. ",
|
||||
xla::ShapeUtil::HumanString(else_result.xla_output_shape)));
|
||||
|
||||
// Check that both branches have same TensorList output indices.
|
||||
for (int output_index = 0; output_index < then_result.outputs.size();
|
||||
output_index++) {
|
||||
bool is_tensor_list_in_then_branch =
|
||||
then_result.outputs[output_index].is_tensor_list;
|
||||
bool is_tensor_list_in_else_branch =
|
||||
else_result.outputs[output_index].is_tensor_list;
|
||||
OP_REQUIRES(
|
||||
ctx, is_tensor_list_in_then_branch == is_tensor_list_in_else_branch,
|
||||
errors::FailedPrecondition("Output #", output_index, " is ",
|
||||
(is_tensor_list_in_then_branch ? "" : "not"),
|
||||
" a TensorList in then branch, but is ",
|
||||
(is_tensor_list_in_else_branch ? "" : "not"),
|
||||
" a TensorList in else branch"));
|
||||
}
|
||||
|
||||
VLOG(2) << "Input shape: " << xla::ShapeUtil::HumanString(then_input_shape);
|
||||
VLOG(2) << "Output shape: "
|
||||
<< xla::ShapeUtil::HumanString(then_result.xla_output_shape);
|
||||
|
||||
// We set return_updated_values_for_all_resources=true and we pass the same
|
||||
// arguments to both computations, so the resource update count must match.
|
||||
OP_REQUIRES(ctx,
|
||||
then_result.resource_updates.size() ==
|
||||
else_result.resource_updates.size(),
|
||||
errors::FailedPrecondition(
|
||||
"Different number of resources in then and else branch"));
|
||||
for (int i = 0; i < then_result.resource_updates.size(); ++i) {
|
||||
const auto& lhs = then_result.resource_updates[i];
|
||||
const auto& rhs = else_result.resource_updates[i];
|
||||
bool equal = lhs.input_index == rhs.input_index && lhs.shape == rhs.shape &&
|
||||
lhs.tensor_array_gradients_accessed ==
|
||||
rhs.tensor_array_gradients_accessed;
|
||||
OP_REQUIRES(
|
||||
ctx, equal,
|
||||
errors::FailedPrecondition(
|
||||
"Mismatch in resource of then and else branch for resource ", i));
|
||||
}
|
||||
OP_REQUIRES_OK(ctx, ValidateShapes(ctx, then_result, else_result));
|
||||
|
||||
int num_inputs = then_result.input_mapping.size();
|
||||
std::vector<xla::XlaOp> inputs(num_inputs);
|
||||
@ -263,22 +291,18 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
auto input_tuple = xla::Tuple(b, inputs);
|
||||
xla::XlaOp input_tuple = xla::Tuple(b, inputs);
|
||||
xla::XlaOp outputs =
|
||||
xla::Conditional(ctx->Input(0), input_tuple, *then_result.computation,
|
||||
input_tuple, *else_result.computation);
|
||||
|
||||
// Sets non-variable outputs.
|
||||
for (int i = 0; i < output_types_.size(); ++i) {
|
||||
xla::XlaOp output_handle = xla::GetTupleElement(outputs, i);
|
||||
if (VLOG_IS_ON(2)) {
|
||||
LOG(INFO) << "Setting output " << i;
|
||||
auto shape_or = b->GetShape(output_handle);
|
||||
if (shape_or.ok()) {
|
||||
LOG(INFO) << "Shape for output " << i << ": "
|
||||
<< xla::ShapeUtil::HumanString(shape_or.ValueOrDie());
|
||||
} else {
|
||||
LOG(INFO) << "Shape unknown for output " << i;
|
||||
}
|
||||
xla::StatusOr<xla::Shape> shape = b->GetShape(output_handle);
|
||||
VLOG(2) << "Setting output " << i << " with shape "
|
||||
<< (shape.ok() ? shape->ToString() : "<unknown>");
|
||||
}
|
||||
// We have checked that both branches have same TensorList output indices.
|
||||
if (then_result.outputs[i].is_tensor_list) {
|
||||
@ -287,6 +311,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
|
||||
ctx->SetOutput(i, output_handle);
|
||||
}
|
||||
}
|
||||
|
||||
if (has_token_input_output_) {
|
||||
// Set token output for this "If" op. Token output is the last output of
|
||||
// XLA computation, which comes after all "normal" TF outputs and resource
|
||||
|
@ -30,8 +30,15 @@ class ShardingOp : public XlaOpKernel {
|
||||
~ShardingOp() override = default;
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::XlaOp input = ctx->Input(0);
|
||||
auto shape_or = ctx->InputXlaShape(0);
|
||||
xla::XlaOp input;
|
||||
{
|
||||
// The builder might create a broadcast from a constant, so we clear
|
||||
// sharding for the input.
|
||||
xla::XlaScopedShardingAssignment no_sharding(ctx->builder(),
|
||||
absl::nullopt);
|
||||
input = ctx->Input(0);
|
||||
}
|
||||
auto shape_or = ctx->builder()->GetShape(input);
|
||||
OP_REQUIRES_OK(ctx, shape_or.status());
|
||||
|
||||
ctx->SetOutput(
|
||||
|
@ -28,6 +28,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -415,8 +416,11 @@ sharding = gen_xla_ops.xla_sharding
|
||||
|
||||
@ops.RegisterGradient("XlaSharding")
|
||||
def _sharding_grad(op, grad):
|
||||
del op # Unused
|
||||
return [grad]
|
||||
grad_sharding = gen_xla_ops.xla_sharding(grad)
|
||||
# pylint: disable=protected-access
|
||||
grad_sharding.op._set_attr(
|
||||
"_XlaSharding", attr_value_pb2.AttrValue(s=op.get_attr("_XlaSharding")))
|
||||
return [grad_sharding]
|
||||
|
||||
|
||||
spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
|
||||
|
@ -1030,6 +1030,11 @@ Status XlaCompiler::BuildArguments(
|
||||
xla::XlaScopedShardingAssignment assign_sharding(
|
||||
builder, it == arg_shardings.end() ? absl::optional<xla::OpSharding>()
|
||||
: it->second);
|
||||
auto& arg = args[input_to_args->at(i)];
|
||||
|
||||
xla::OpMetadata arg_metadata;
|
||||
arg_metadata.set_op_name(arg.node_name);
|
||||
builder->SetOneShotOpMetadata(arg_metadata);
|
||||
arg_handles[i] = xla::GetTupleElement(tuple, i);
|
||||
}
|
||||
} else {
|
||||
|
@ -3021,7 +3021,12 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
|
||||
instr.add_operand_ids(operand.handle());
|
||||
}
|
||||
|
||||
*instr.mutable_metadata() = metadata_;
|
||||
if (one_shot_metadata_.has_value()) {
|
||||
*instr.mutable_metadata() = one_shot_metadata_.value();
|
||||
one_shot_metadata_.reset();
|
||||
} else {
|
||||
*instr.mutable_metadata() = metadata_;
|
||||
}
|
||||
if (sharding_) {
|
||||
*instr.mutable_sharding() = *sharding_;
|
||||
}
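The builder now prefers `one_shot_metadata_` for exactly one instruction and then clears it, falling back to the sticky `metadata_` otherwise. A standalone sketch of that one-shot override pattern with `std::optional` and toy names (not the XLA builder API):

```cpp
#include <optional>
#include <string>
#include <vector>

class ToyBuilder {
 public:
  void SetMetadata(std::string m) { metadata_ = std::move(m); }          // sticky
  void SetOneShotMetadata(std::string m) { one_shot_ = std::move(m); }   // next op only

  void AddInstruction(const std::string& opcode) {
    std::string meta = metadata_;
    if (one_shot_.has_value()) {
      meta = *one_shot_;   // override applies to exactly one instruction
      one_shot_.reset();
    }
    log_.push_back(opcode + " [" + meta + "]");
  }

  const std::vector<std::string>& log() const { return log_; }

 private:
  std::string metadata_;
  std::optional<std::string> one_shot_;
  std::vector<std::string> log_;
};
```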
|
||||
|
@ -153,6 +153,11 @@ class XlaBuilder {
|
||||
// OpMetadata attached until a call to ClearOpMetadata.
|
||||
void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); }
|
||||
|
||||
// Similar to SetOpMetadata, but only sets the metadata for the next op.
|
||||
void SetOneShotOpMetadata(OpMetadata metadata) {
|
||||
one_shot_metadata_ = std::move(metadata);
|
||||
}
|
||||
|
||||
// Clears the HloMetadata state.
|
||||
void ClearOpMetadata() { metadata_.Clear(); }
|
||||
|
||||
@ -842,6 +847,9 @@ class XlaBuilder {
|
||||
// throughout the TensorFlow op kernel implementations).
|
||||
OpMetadata metadata_;
|
||||
|
||||
// A temporary metadata that will only be applied to the next op created.
|
||||
absl::optional<OpMetadata> one_shot_metadata_;
|
||||
|
||||
// Sharding for this operator. This is structured as a "model"-like operation,
|
||||
// in order to simplify client code, similar to metadata_.
|
||||
absl::optional<OpSharding> sharding_;
|
||||
|
@ -17,6 +17,8 @@ upper_tabs:
|
||||
path: /xla
|
||||
- title: XLA architecture
|
||||
path: /xla/architecture
|
||||
- title: Known issues
|
||||
path: /xla/known_issues
|
||||
- title: Broadcasting semantics
|
||||
path: /xla/broadcasting
|
||||
- title: Develop a new backend for XLA
|
||||
|
@ -177,30 +177,6 @@ a bug to a single XLA program by using the
|
||||
[`replay_computation`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/tools/run_hlo_module_main.cc)
|
||||
and iteratively running it on generated programs.
|
||||
|
||||
## Known Issues
|
||||
|
||||
Compilation with XLA can greatly improve the performance of your programs, but
|
||||
the TensorFlow interop has a number of known sharp corners.
|
||||
|
||||
### TensorArray TF/XLA Interconversion
|
||||
|
||||
The problem manifests itself as an error message
|
||||
`Support for TensorList crossing the XLA/TF boundary is not implemented`.
|
||||
|
||||
XLA supports `tf.TensorArray`. However, the _interconversion_ between TF and
|
||||
XLA representations is not implemented yet.
|
||||
This error often arises when the `TensorArray` is used inside the compiled
|
||||
block, but the derivative is taken outside.
|
||||
|
||||
Workaround: compile the outermost scope which is taking the derivative.
|
||||
|
||||
### Random Number Generation
|
||||
|
||||
XLA currently ignores TF seeds to random operations. This affects stateful TF
|
||||
random operations, such as `tf.random.normal`, or `tf.nn.dropout`. XLA will
|
||||
behave as if the compilation was seeded with a new unique seed at each run. This
|
||||
limitation does not apply to stateless random ops.
|
||||
|
||||
## XLA Frontends
|
||||
|
||||
Apart from TensorFlow, XLA programs can be generated by:
|
||||
|
32
tensorflow/compiler/xla/g3doc/known_issues.md
Normal file
@ -0,0 +1,32 @@
|
||||
# Known Issues
|
||||
|
||||
Compilation with XLA can greatly improve the performance of your programs, but
|
||||
the TensorFlow interop has a number of known sharp corners.
|
||||
|
||||
## TensorArray TF/XLA interconversion
|
||||
|
||||
The problem manifests itself as an error message
|
||||
`Support for TensorList crossing the XLA/TF boundary is not implemented`.
|
||||
|
||||
XLA supports `tf.TensorArray`. However, the _interconversion_ between TF and
|
||||
XLA representations is not implemented yet.
|
||||
This error often arises when the `TensorArray` is used inside the compiled
|
||||
block, but the derivative is taken outside.
|
||||
|
||||
Workaround: compile the outermost scope which is taking the derivative.
|
||||
|
||||
## Dynamic `tf.TensorArray` is not supported
|
||||
|
||||
Writes into `tf.TensorArray(..., dynamic_size=True)` are not compilable with
|
||||
XLA, as such writes require an unknown number of reallocations when the array
|
||||
exceeds the original bound.
|
||||
|
||||
Workaround: provide a statically known bound to your arrays.
|
||||
|
||||
## Random number generation
|
||||
|
||||
XLA currently ignores TF seeds to random operations. This affects stateful TF
|
||||
random operations, such as `tf.random.normal`, or `tf.nn.dropout`. XLA will
|
||||
behave as if the compilation was seeded with a new unique seed at each run. This
|
||||
limitation does not apply to stateless random ops.
|
||||
|
@ -114,24 +114,26 @@ void BuildOpsSubmodule(py::module* m) {
|
||||
"CustomCall",
|
||||
[](XlaBuilder* builder, const py::bytes& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const Shape& shape,
|
||||
const py::bytes& opaque) -> XlaOp {
|
||||
return CustomCall(builder, call_target_name, operands, shape, opaque);
|
||||
const py::bytes& opaque, bool has_side_effect) -> XlaOp {
|
||||
return CustomCall(builder, call_target_name, operands, shape, opaque,
|
||||
has_side_effect);
|
||||
},
|
||||
py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
|
||||
py::arg("shape"), py::arg("opaque") = py::bytes(""));
|
||||
py::arg("shape"), py::arg("opaque") = py::bytes(""),
|
||||
py::arg("has_side_effect") = false);
|
||||
ops.def(
|
||||
"CustomCallWithLayout",
|
||||
[](XlaBuilder* builder, const py::bytes& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
|
||||
absl::Span<const Shape> operand_shapes_with_layout,
|
||||
const py::bytes& opaque) -> XlaOp {
|
||||
return CustomCallWithLayout(builder, call_target_name, operands,
|
||||
shape_with_layout,
|
||||
operand_shapes_with_layout, opaque);
|
||||
const py::bytes& opaque, bool has_side_effect) -> XlaOp {
|
||||
return CustomCallWithLayout(
|
||||
builder, call_target_name, operands, shape_with_layout,
|
||||
operand_shapes_with_layout, opaque, has_side_effect);
|
||||
},
|
||||
py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
|
||||
py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
|
||||
py::arg("opaque") = py::bytes(""));
|
||||
py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false);
|
||||
ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
|
||||
py::arg("precision_config") = nullptr);
|
||||
ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
|
||||
|
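This hunk threads a new `has_side_effect` flag (default `False`) through the
Python `CustomCall`/`CustomCallWithLayout` bindings. Roughly how a client could
use it — a sketch only: the `xla_client` import path, the shape/parameter
helpers, and the `my_logging_target` registration are assumptions, while the
keyword names mirror the `py::arg` list above.

```python
import numpy as np
from tensorflow.compiler.xla.python import xla_client  # assumed import path

builder = xla_client.XlaBuilder("side_effecting_call")
x = xla_client.ops.Parameter(
    builder, 0, xla_client.shape_from_pyval(np.float32(0.0)))
out_shape = xla_client.Shape.array_shape(np.dtype(np.float32), ())

# has_side_effect=True asks XLA to keep the call even if its result is unused;
# omitting it keeps the old behavior, since the binding defaults it to False.
y = xla_client.ops.CustomCall(
    builder, b"my_logging_target",   # hypothetical, pre-registered call target
    operands=[x], shape=out_shape, opaque=b"", has_side_effect=True)
```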
@ -2475,6 +2475,33 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
HloInstruction *a, *b, *c1, *c2;
|
||||
// Mul(Mul(x, constant1), Mul(y, constant2)) => Mul(Mul(x, y),
|
||||
// constant1*constant2)
|
||||
if (Match(multiply,
|
||||
m::Multiply(
|
||||
m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)),
|
||||
m::MultiplyAnyOrder(m::NonConstant(&b), m::Constant(&c2))))) {
|
||||
TF_ASSIGN_OR_RETURN(auto* product_of_constants,
|
||||
MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
|
||||
if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
|
||||
!ShapeUtil::IsScalar(multiply->shape())) {
|
||||
product_of_constants =
|
||||
computation_->AddInstruction(HloInstruction::CreateBroadcast(
|
||||
multiply->shape(), product_of_constants, {}));
|
||||
}
|
||||
|
||||
return ReplaceWithNewInstruction(
|
||||
multiply,
|
||||
HloInstruction::CreateBinary(
|
||||
multiply->shape(), HloOpcode::kMultiply,
|
||||
computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||
multiply->shape(), HloOpcode::kMultiply, a, b)),
|
||||
product_of_constants));
|
||||
}
|
||||
}
|
||||
|
||||
VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]";
|
||||
HloInstruction *a, *c1, *c2;
|
||||
if (Match(multiply,
|
||||
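The new pattern relies on the identity (a·c1)·(b·c2) = (a·b)·(c1·c2): the two
constants are folded into a single (broadcast, if needed) constant multiply. A
quick numeric sanity check with illustrative values:

```python
import numpy as np

a, b = np.random.rand(4), np.random.rand(4)
c1, c2 = 2.0, 4.0                        # the constants the simplifier folds together
before = (a * c1) * (b * c2)             # Mul(Mul(a, c1), Mul(b, c2))
after = (a * b) * (c1 * c2)              # Mul(Mul(a, b), constant1*constant2)
np.testing.assert_allclose(before, after)  # exact over reals, up to rounding in floats
```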
@ -3088,6 +3115,17 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
|
||||
HloOpcode::kMultiply, lhs, lhs));
|
||||
}
|
||||
|
||||
// Pow(A, 3) is used in GELU.
|
||||
VLOG(10) << "trying transform [pow(A, 3) => A*A*A]: " << power->ToString();
|
||||
if (IsAll(rhs, 3)) {
|
||||
HloInstruction* tmp =
|
||||
computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||
power->shape(), HloOpcode::kMultiply, lhs, lhs));
|
||||
return ReplaceWithNewInstruction(
|
||||
power, HloInstruction::CreateBinary(power->shape(),
|
||||
HloOpcode::kMultiply, lhs, tmp));
|
||||
}
|
||||
|
||||
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
|
||||
if (IsAll(rhs, -1)) {
|
||||
return ReplaceWithNewInstruction(
|
||||
|
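The `pow(A, 3)` rule targets the cube in the tanh approximation of GELU,
lowering it to two multiplies. A rough numpy illustration; the GELU constants
are the commonly published ones, not taken from this diff:

```python
import numpy as np

def gelu_tanh_approx(x):
  cube = x * (x * x)  # what pow(x, 3) becomes after the rewrite: two multiplies
  return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * cube)))

x = np.linspace(-3.0, 3.0, 7)
np.testing.assert_allclose(x * (x * x), x ** 3)  # the rewrite is value-preserving
print(gelu_tanh_approx(x))
```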
@ -117,6 +117,29 @@ TEST_F(AlgebraicSimplifierTest, FactorFpAddition) {
|
||||
m::ConstantScalar(0.125))));
|
||||
}
|
||||
|
||||
// (A*C1) * (B*C2) => (A*B)*(C1*C2)
|
||||
TEST_F(AlgebraicSimplifierTest, MultiplyChain) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p0 = f32[] parameter(0)
|
||||
p1 = f32[] parameter(1)
|
||||
c = f32[] constant(2)
|
||||
d = f32[] constant(4)
|
||||
x = f32[] multiply(p0, c)
|
||||
y = f32[] multiply(p1, d)
|
||||
ROOT z = f32[] multiply(x, y)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(
|
||||
m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::MultiplyAnyOrder(
|
||||
m::MultiplyAnyOrder(m::Parameter(0), m::Parameter(1)),
|
||||
m::MultiplyAnyOrder(m::ConstantScalar(2), m::ConstantScalar(4)))));
|
||||
}
|
||||
|
||||
// A*C + B*C => (A+B)*C if C is a broadcast of a floating-point power of 2.
|
||||
TEST_F(AlgebraicSimplifierTest, FactorFpAdditionWithBroadcast) {
|
||||
const char* kModuleStr = R"(
|
||||
@ -1568,6 +1591,32 @@ TEST_F(AlgebraicSimplifierTest, Pow2) {
|
||||
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
|
||||
}
|
||||
|
||||
// Test that pow(A, 3) is simplified to A*A*A.
|
||||
TEST_F(AlgebraicSimplifierTest, Pow3) {
|
||||
auto m = CreateNewVerifiedModule();
|
||||
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
|
||||
HloComputation::Builder builder(TestName());
|
||||
HloInstruction* param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, r0f32, "param0"));
|
||||
HloInstruction* three = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3)));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, three));
|
||||
|
||||
auto computation = m->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_THAT(computation->root_instruction(),
|
||||
GmockMatch(m::Power(m::Parameter(0), m::Op().Is(three))));
|
||||
|
||||
AlgebraicSimplifier simplifier(default_options_);
|
||||
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
|
||||
|
||||
EXPECT_THAT(
|
||||
computation->root_instruction(),
|
||||
GmockMatch(m::Multiply(m::Parameter(0),
|
||||
m::Multiply(m::Parameter(0), m::Parameter(0)))));
|
||||
}
|
||||
|
||||
// Test that pow(A, -1) is simplified to 1/A.
|
||||
TEST_F(AlgebraicSimplifierTest, PowNegative1) {
|
||||
auto m = CreateNewVerifiedModule();
|
||||
|
@ -270,11 +270,48 @@ Status DotOpEmitter::EmitLinalgMatmul() {
|
||||
return EmitMlirFuncAndCall(
|
||||
mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr,
|
||||
operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) {
|
||||
CHECK_EQ(dot_info_.dim_nums.lhs_contracting_dimensions_size(), 1);
|
||||
CHECK_EQ(dot_info_.dim_nums.rhs_contracting_dimensions_size(), 1);
|
||||
mlir::MLIRContext* context = builder->getContext();
|
||||
mlir::edsc::ScopedContext scope(*builder, function.getLoc());
|
||||
mlir::Value a = function.getArgument(0), b = function.getArgument(1),
|
||||
c = function.getArgument(2);
|
||||
mlir::edsc::intrinsics::linalg_matmul(mlir::TypeRange{},
|
||||
mlir::ValueRange{b, c, a});
|
||||
|
||||
llvm::SmallVector<mlir::AffineExpr, 2> b_exprs(
|
||||
dot_info_.lhs_shape.rank());
|
||||
llvm::SmallVector<mlir::AffineExpr, 2> c_exprs(
|
||||
dot_info_.rhs_shape.rank());
|
||||
|
||||
llvm::SmallVector<mlir::AffineExpr, 2> parallel_exprs;
|
||||
mlir::AffineExpr reduce_expr;
|
||||
for (int i = 0; i != dot_info_.result_shape.rank(); ++i) {
|
||||
parallel_exprs.push_back(mlir::getAffineDimExpr(i, context));
|
||||
}
|
||||
reduce_expr =
|
||||
mlir::getAffineDimExpr(dot_info_.result_shape.rank(), context);
|
||||
|
||||
// The reduction expr is shared for both inputs.
|
||||
b_exprs[dot_info_.dim_nums.lhs_contracting_dimensions(0)] = reduce_expr;
|
||||
c_exprs[dot_info_.dim_nums.rhs_contracting_dimensions(0)] = reduce_expr;
|
||||
|
||||
// Fill in the remaining parallel exprs.
|
||||
int par_expr_num = 0;
|
||||
for (auto* v : {&b_exprs, &c_exprs}) {
|
||||
for (auto& e : *v) {
|
||||
if (!e) {
|
||||
e = parallel_exprs[par_expr_num++];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::IteratorType, 4> types(
|
||||
parallel_exprs.size(), mlir::IteratorType::Parallel);
|
||||
types.push_back(mlir::IteratorType::Reduction);
|
||||
|
||||
mlir::edsc::StructuredIndexed s_a(a), s_b(b), s_c(c);
|
||||
mlir::edsc::makeGenericLinalgOp(types, {s_b(b_exprs), s_c(c_exprs)},
|
||||
{s_a(parallel_exprs)},
|
||||
mlir::edsc::ops::macRegionBuilder);
|
||||
mlir::edsc::intrinsics::std_ret();
|
||||
|
||||
mlir::linalg::LinalgTilingOptions tilingOptions;
|
||||
@ -283,13 +320,13 @@ Status DotOpEmitter::EmitLinalgMatmul() {
|
||||
target_machine_features_.minimum_alignment_for_allocation(
|
||||
ShapeUtil::ByteSizeOf(dot_info_.result_shape));
|
||||
mlir_strategy::MatmulCodegenStrategy strategy;
|
||||
strategy.tile<mlir::linalg::MatmulOp>(tilingOptions)
|
||||
.promote<mlir::linalg::MatmulOp>(
|
||||
strategy.tile<mlir::linalg::GenericOp>(tilingOptions)
|
||||
.promote<mlir::linalg::GenericOp>(
|
||||
mlir::linalg::LinalgPromotionOptions()
|
||||
.setAlignment(alignment)
|
||||
.setUseFullTileBuffersByDefault(true)
|
||||
.setUseAlloca(true))
|
||||
.vectorize<mlir::linalg::MatmulOp>()
|
||||
.vectorize<mlir::linalg::GenericOp>()
|
||||
.setVectorTransformsOptions(
|
||||
mlir::vector::VectorTransformsOptions()
|
||||
.setVectorTransformsOptions(
|
||||
|
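Replacing the fixed `linalg.matmul` with a `linalg.generic` built from affine
maps is, in effect, the einsum view of the product: the contracting dimension
gets the shared reduction expression and all other dimensions stay parallel. A
loose numpy analogy (index letters are illustrative, not tied to the MLIR code;
`a` is the output and `b`/`c` the operands, matching the argument order above):

```python
import numpy as np

b = np.random.rand(3, 5)   # lhs operand, indexing map (i, k)
c = np.random.rand(5, 4)   # rhs operand, indexing map (k, j)

# i and j are parallel iterators taken from the result shape; k is the single
# shared reduction iterator, mirroring reduce_expr in the emitter above.
a = np.einsum("ik,kj->ij", b, c)
np.testing.assert_allclose(a, b @ c)
```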
@ -148,15 +148,12 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
|
||||
Status HandleDomain(HloInstruction* hlo) override;
|
||||
|
||||
private:
|
||||
using DimensionConstraint = DynamicDimensionInference::DimensionConstraint;
|
||||
using OperandDynamicDimensionFn = std::function<Status(
|
||||
HloInstruction* operand, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint)>;
|
||||
int64 operand_index, HloInstruction* dynamic_size)>;
|
||||
|
||||
using DynamicDimensionFn = std::function<Status(
|
||||
ShapeIndex index, int64 dimension, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint)>;
|
||||
ShapeIndex index, int64 dimension, HloInstruction* dynamic_size)>;
|
||||
|
||||
Status ForEachOperandDynamicDimension(HloInstruction* inst,
|
||||
const OperandDynamicDimensionFn&);
|
||||
@ -184,8 +181,7 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
|
||||
Status DynamicDimensionInferenceVisitor::DefaultAction(HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) {
|
||||
int64 operand_index, HloInstruction* dynamic_size) {
|
||||
return UnimplementedStrCat(
|
||||
"Asked to propagate a dynamic dimension from hlo ", operand->name(),
|
||||
"@", index.ToString(), "@", dimension, " to hlo ", hlo->ToString(),
|
||||
@ -197,13 +193,11 @@ Status DynamicDimensionInferenceVisitor::HandleGetTupleElement(
|
||||
HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) {
|
||||
int64 operand_index, HloInstruction* dynamic_size) {
|
||||
if (hlo->tuple_index() == index[0]) {
|
||||
ShapeIndex new_index =
|
||||
ShapeIndexView(index).ConsumeFront().ToShapeIndex();
|
||||
parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size,
|
||||
constraint);
|
||||
parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size);
|
||||
}
|
||||
return Status::OK();
|
||||
});
|
||||
@ -212,11 +206,9 @@ Status DynamicDimensionInferenceVisitor::HandleGetTupleElement(
|
||||
Status DynamicDimensionInferenceVisitor::HandleTuple(HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction*, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) {
|
||||
int64 operand_index, HloInstruction* dynamic_size) {
|
||||
index.push_front(operand_index);
|
||||
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size,
|
||||
constraint);
|
||||
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
|
||||
return Status::OK();
|
||||
});
|
||||
}
|
||||
@ -224,11 +216,9 @@ Status DynamicDimensionInferenceVisitor::HandleTuple(HloInstruction* hlo) {
|
||||
Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) {
|
||||
int64 operand_index, HloInstruction* dynamic_size) {
|
||||
int64 broadcast_dim = hlo->dimensions(dimension);
|
||||
parent_->SetDynamicSize(hlo, {}, broadcast_dim, dynamic_size,
|
||||
constraint);
|
||||
parent_->SetDynamicSize(hlo, {}, broadcast_dim, dynamic_size);
|
||||
return Status::OK();
|
||||
});
|
||||
}
|
||||
@ -244,8 +234,7 @@ Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
|
||||
// returns the padded data output and the dynamic sizes of input
|
||||
// dimensions.
|
||||
ShapeIndex data_output = {0};
|
||||
parent_->SetDynamicSize(hlo, data_output, i, dynamic_size,
|
||||
DimensionConstraint(1, 1));
|
||||
parent_->SetDynamicSize(hlo, data_output, i, dynamic_size);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
@ -255,15 +244,14 @@ Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
|
||||
}
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) {
|
||||
int64 operand_index, HloInstruction* dynamic_size) {
|
||||
// Resize custom call should propagate dynamic batch (0) and channel (3)
|
||||
// dimensions.
|
||||
if (hlo->custom_call_target() == "SliceToDynamic" ||
|
||||
hlo->custom_call_target() == "Sharding" ||
|
||||
(absl::StartsWith(hlo->custom_call_target(), "Resize") &&
|
||||
(dimension == 0 || dimension == 3))) {
|
||||
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
|
||||
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
|
||||
return Status::OK();
|
||||
}
|
||||
return Unimplemented(
|
||||
@ -274,16 +262,15 @@ Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
|
||||
|
||||
Status DynamicDimensionInferenceVisitor::HandleSort(HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* operand, ShapeIndex index,
|
||||
int64 dynamic_dimension, int64 operand_index,
|
||||
HloInstruction* dynamic_size, DimensionConstraint constraint) {
|
||||
hlo,
|
||||
[&](HloInstruction* operand, ShapeIndex index, int64 dynamic_dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size) {
|
||||
HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
|
||||
if (sort->values_count() == 0) {
|
||||
parent_->SetDynamicSize(hlo, {}, dynamic_dimension, dynamic_size,
|
||||
constraint);
|
||||
parent_->SetDynamicSize(hlo, {}, dynamic_dimension, dynamic_size);
|
||||
} else {
|
||||
parent_->SetDynamicSize(hlo, {operand_index}, dynamic_dimension,
|
||||
dynamic_size, constraint);
|
||||
dynamic_size);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -293,8 +280,7 @@ Status DynamicDimensionInferenceVisitor::HandleSort(HloInstruction* hlo) {
|
||||
Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) {
|
||||
int64 operand_index, HloInstruction* dynamic_size) {
|
||||
if (operand_index != 0) {
|
||||
return Unimplemented(
|
||||
"Dynamic dimension on padding value is not supported");
|
||||
@ -311,8 +297,7 @@ Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
|
||||
hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
|
||||
dynamic_size_adjusted->shape(), HloOpcode::kAdd,
|
||||
dynamic_size_adjusted, adjustment));
|
||||
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted,
|
||||
constraint);
|
||||
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted);
|
||||
return Status::OK();
|
||||
} else {
|
||||
return Unimplemented(
|
||||
@ -327,8 +312,7 @@ Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
|
||||
Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) {
|
||||
int64 operand_index, HloInstruction* dynamic_size) {
|
||||
HloInstruction* reduce = hlo;
|
||||
int64 operand_count = reduce->operand_count();
|
||||
bool is_variadic_reduce = operand_count > 2;
|
||||
@ -354,13 +338,12 @@ Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
|
||||
// reduce has a dynamic dimension, we set all outputs to use the
|
||||
// same dynamic size in corresponding dimensions.
|
||||
for (int64 i = 0; i < operand_count / 2; ++i) {
|
||||
parent_->SetDynamicSize(reduce, {i},
|
||||
dimensions_not_reduced_count,
|
||||
dynamic_size, constraint);
|
||||
parent_->SetDynamicSize(
|
||||
reduce, {i}, dimensions_not_reduced_count, dynamic_size);
|
||||
}
|
||||
} else {
|
||||
parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count,
|
||||
dynamic_size, constraint);
|
||||
dynamic_size);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -378,7 +361,7 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* operand, ShapeIndex operand_shape_index,
|
||||
int64 operand_dimension, int64 operand_index,
|
||||
HloInstruction* dynamic_size, DimensionConstraint constraint) {
|
||||
HloInstruction* dynamic_size) {
|
||||
// There are three types of dimensions in a dot:
|
||||
// A. batch dims
|
||||
// B. contracting dims
|
||||
@ -451,8 +434,7 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
|
||||
// work item to trace that dimension.
|
||||
auto iter = result_dim_mapping.find(operand_dimension);
|
||||
if (iter != result_dim_mapping.end()) {
|
||||
parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size,
|
||||
constraint);
|
||||
parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -463,8 +445,7 @@ Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo,
|
||||
[&](HloInstruction* operand, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) -> Status {
|
||||
int64 operand_index, HloInstruction* dynamic_size) -> Status {
|
||||
int64 permuted_dim = -1;
|
||||
for (int64 i = 0; i < hlo->dimensions().size(); ++i) {
|
||||
if (hlo->dimensions()[i] == dimension) {
|
||||
@ -472,8 +453,7 @@ Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
|
||||
permuted_dim = i;
|
||||
}
|
||||
}
|
||||
parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size,
|
||||
constraint);
|
||||
parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size);
|
||||
return Status::OK();
|
||||
});
|
||||
}
|
||||
@ -482,8 +462,7 @@ Status DynamicDimensionInferenceVisitor::HandleConvolution(
|
||||
HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) {
|
||||
int64 operand_index, HloInstruction* dynamic_size) {
|
||||
HloInstruction* conv = hlo;
|
||||
const ConvolutionDimensionNumbers& dimension_numbers =
|
||||
conv->convolution_dimension_numbers();
|
||||
@ -492,7 +471,7 @@ Status DynamicDimensionInferenceVisitor::HandleConvolution(
|
||||
if (dimension == dimension_numbers.input_batch_dimension()) {
|
||||
parent_->SetDynamicSize(conv, {},
|
||||
dimension_numbers.output_batch_dimension(),
|
||||
dynamic_size, constraint);
|
||||
dynamic_size);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -542,20 +521,18 @@ Status DynamicDimensionInferenceVisitor::HandleConcatenate(
|
||||
dim_size_total, dynamic_dim));
|
||||
}
|
||||
parent_->SetDynamicSize(hlo, {}, hlo->concatenate_dimension(),
|
||||
dim_size_total, DimensionConstraint(1, 1));
|
||||
dim_size_total);
|
||||
}
|
||||
|
||||
// Simply pass through non-concat dynamic dimensions.
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) {
|
||||
int64 operand_index, HloInstruction* dynamic_size) {
|
||||
int64 concatenate_dimension = hlo->concatenate_dimension();
|
||||
if (concatenate_dimension == dimension) {
|
||||
return Status::OK();
|
||||
}
|
||||
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size,
|
||||
constraint);
|
||||
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
|
||||
return Status::OK();
|
||||
});
|
||||
}
|
||||
@ -596,18 +573,15 @@ Status DynamicDimensionInferenceVisitor::HandleSetDimensionSize(
|
||||
if (!dimension_is_static) {
|
||||
// Propagate dynamic dimension indicated by this set dimension size
|
||||
// instruction.
|
||||
parent_->SetDynamicSize(hlo, {}, hlo->dimension(), hlo->mutable_operand(1),
|
||||
DimensionConstraint(1, 1));
|
||||
parent_->SetDynamicSize(hlo, {}, hlo->dimension(), hlo->mutable_operand(1));
|
||||
}
|
||||
|
||||
// Also propagate dynamic dimensions already set by operands.
|
||||
TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) {
|
||||
int64 operand_index, HloInstruction* dynamic_size) {
|
||||
if (dimension != hlo->dimension()) {
|
||||
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size,
|
||||
constraint);
|
||||
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
@ -619,10 +593,8 @@ Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension(
|
||||
HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) {
|
||||
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size,
|
||||
constraint);
|
||||
int64 operand_index, HloInstruction* dynamic_size) {
|
||||
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
|
||||
return Status::OK();
|
||||
});
|
||||
}
|
||||
@ -654,8 +626,7 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
|
||||
hlo,
|
||||
[&](HloInstruction* operand, ShapeIndex index,
|
||||
int64 input_dynamic_dimension, int64 operand_index,
|
||||
HloInstruction* operand_dynamic_size,
|
||||
DimensionConstraint constraint) -> Status {
|
||||
HloInstruction* operand_dynamic_size) -> Status {
|
||||
HloInstruction* reshape = hlo;
|
||||
if (reshape->shape().rank() == 0) {
|
||||
VLOG(0) << "Reshaping a dynamic dimension into a scalar, which has "
|
||||
@ -751,9 +722,6 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
|
||||
|
||||
if (output_dynamic_dimension == -1 &&
|
||||
output_dim_end - output_dim_start > 1) {
|
||||
// TODO(yunxing): We now have a better way to decide output dimension
|
||||
// in the bridge. No need for this constraint propagation logic.
|
||||
//
|
||||
// One input dimension is split into multiple output dimensions.
|
||||
// Output dimension is decomposed from input most major dimension.
|
||||
// In this case, we don't know which one is dynamic, e.g., when we
|
||||
@ -770,61 +738,17 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
|
||||
// We use the following logics to disambiguate:
|
||||
// 1. If the user sets "inferred_dimension", then use that as
|
||||
// dynamic dimension.
|
||||
// 2. If the one dimension in the reshape is dynamic, use that as
|
||||
// dynamic dimension.
|
||||
// E.g.:
|
||||
// [<=4]
|
||||
// |
|
||||
// reshape
|
||||
// |
|
||||
// [1, <=2, 2]
|
||||
// We use second dim as dynamic dimension.
|
||||
//
|
||||
// 2. Use the "multiple_of" constraint, e.g, :
|
||||
// [<=2, 4]
|
||||
// | Reshape
|
||||
// [<=8]
|
||||
// | Reshape
|
||||
// [2, 4] // Which is dynamic?
|
||||
//
|
||||
// If the dynamic value has to be multiple of 4 (constraint
|
||||
// created by the first reshape), then 2 must be the dynamic
|
||||
// dimension.
|
||||
//
|
||||
// But this logic doesn't help with the case where two
|
||||
// dimensions are the same:
|
||||
//
|
||||
// [<=3, 3]
|
||||
// | Reshape
|
||||
// [<=9]
|
||||
// | Reshape
|
||||
// [3, 3] // Which is dynamic?
|
||||
//
|
||||
// Both dynamic dimensions can be multiples of 3.
|
||||
//
|
||||
// We then need the next constraint to disambiguate this case:
|
||||
//
|
||||
// 3. Use the "stride" constraint (also see the comment at the
|
||||
// definition):
|
||||
//
|
||||
// [<=3, 3]
|
||||
// | Reshape
|
||||
// [<=9] // constraint.stride = 1
|
||||
// | Reshape
|
||||
// [3, 3]
|
||||
// ^ ^
|
||||
// | |
|
||||
// stride= 1 3
|
||||
//
|
||||
// Each dimension will have different strides, only one will
|
||||
// satisfy the stride constraint.
|
||||
//
|
||||
// Note that the stride constraint itself is not enough:
|
||||
//
|
||||
//
|
||||
// [<=128]
|
||||
// | Reshape
|
||||
// [1, 128]
|
||||
// ^ ^
|
||||
// | |
|
||||
// stride= 1 1
|
||||
//
|
||||
// In this case, both dimensions have the same stride, which is
|
||||
// ambiguous. That's why we need the "multiple_of" constraint
|
||||
// as used above.
|
||||
//
|
||||
// 4. If all logics above cannot disambiguate, e.g.,:
|
||||
// 3. If all logics above cannot disambiguate, e.g.,:
|
||||
//
|
||||
// [<=1]
|
||||
// |
|
||||
@ -833,68 +757,15 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
|
||||
// [1, 1, 1]
|
||||
//
|
||||
// We bail out and return an error.
|
||||
// TODO(yunxing): Further simplify this, remove 1. and fully rely
|
||||
// on 2.
|
||||
output_dynamic_dimension = reshape->inferred_dimension();
|
||||
if (output_dynamic_dimension == -1) {
|
||||
// The user of XLA didn't specify a dynamic dimension, try infer
|
||||
// it from the current constraint.
|
||||
//
|
||||
// Find all output dimensions that are decomposed from the first
|
||||
// dimension. Among those dimensions, find all dimensions that
|
||||
// satisfy the constraint of the dynamic dimension. In the
|
||||
// previous example, if `a` is 9 and constraint is a multiple of
|
||||
// `3', then in the output shape both a/c and c can be dynamic.
|
||||
int64 current_product = 1;
|
||||
int64 dimension_iter = output_dim_start;
|
||||
|
||||
// compatible_dimensions are dimensions that satisfies
|
||||
// "multiple_of" constraints.
|
||||
std::vector<int64> compatible_dimensions;
|
||||
while (current_product <
|
||||
operand->shape().dimensions(input_dynamic_dimension)) {
|
||||
current_product *= reshape->shape().dimensions(dimension_iter);
|
||||
if (operand->shape().dimensions(input_dynamic_dimension) /
|
||||
reshape->shape().dimensions(dimension_iter) ==
|
||||
constraint.multiple_of) {
|
||||
compatible_dimensions.push_back(dimension_iter);
|
||||
// Try find dynamic dimension from the result shape.
|
||||
for (int64 i = 0; i < reshape->shape().rank(); ++i) {
|
||||
if (reshape->shape().is_dynamic_dimension(i)) {
|
||||
output_dynamic_dimension = i;
|
||||
}
|
||||
dimension_iter++;
|
||||
}
|
||||
CHECK_EQ(current_product,
|
||||
operand->shape().dimensions(input_dynamic_dimension))
|
||||
<< "Not a valid reshape: " << hlo->ToString();
|
||||
// If there is only one compatible dimension, it must be the
|
||||
// dynamic one in the output.
|
||||
if (compatible_dimensions.size() == 1) {
|
||||
output_dynamic_dimension = compatible_dimensions[0];
|
||||
}
|
||||
|
||||
// When there are multiple compatible dimensions, e.g:
|
||||
// [<=9]
|
||||
// | Reshape
|
||||
// [3, 3]
|
||||
// Use stride constraint to figure out which one is the true
|
||||
// dynamic one.
|
||||
//
|
||||
// [<=9]
|
||||
// | Reshape
|
||||
// [3, 3]
|
||||
// ^ ^
|
||||
// | |
|
||||
// stride= 1 3
|
||||
//
|
||||
std::vector<int64> compatible_dimensions_with_stride;
|
||||
absl::c_copy_if(
|
||||
compatible_dimensions,
|
||||
std::back_inserter(compatible_dimensions_with_stride),
|
||||
[&](int64 dimension) {
|
||||
int64 stride_total = 1;
|
||||
for (int64 i = 0; i < dimension + 1; ++i) {
|
||||
stride_total *= reshape->shape().dimensions(dimension);
|
||||
}
|
||||
return stride_total == constraint.stride;
|
||||
});
|
||||
if (compatible_dimensions_with_stride.size() == 1) {
|
||||
output_dynamic_dimension = compatible_dimensions_with_stride[0];
|
||||
}
|
||||
}
|
||||
|
||||
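With the constraint machinery gone, disambiguating which output dimension of a
split reshape is dynamic reduces to two checks: honor the user-supplied
`inferred_dimension` if present, otherwise take whichever output dimension the
result shape already marks as dynamic. A hedged Python paraphrase of that
control flow (helper names and types are stand-ins, not the real HLO API):

```python
def pick_output_dynamic_dim(inferred_dimension, output_is_dynamic):
  # inferred_dimension: int, or -1 when the user gave no hint.
  # output_is_dynamic: per-output-dimension flags from the result shape.
  chosen = inferred_dimension                # 1. the explicit hint wins
  if chosen == -1:
    for i, is_dynamic in enumerate(output_is_dynamic):
      if is_dynamic:
        chosen = i                           # 2. fall back to the result shape
  return chosen                              # -1 means still ambiguous -> error

assert pick_output_dynamic_dim(-1, [False, True, False]) == 1
assert pick_output_dynamic_dim(0, [False, True, False]) == 0
```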
@ -914,9 +785,8 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
|
||||
return InvalidArgument(
|
||||
"Reshape's input dynamic dimension is decomposed into "
|
||||
"multiple output dynamic dimensions, but the constraint is "
|
||||
"ambiguous and XLA can't infer the output dimension %s. "
|
||||
"Constraint: multiple_of: %lld, stride: %lld",
|
||||
hlo->ToString(), constraint.multiple_of, constraint.stride);
|
||||
"ambiguous and XLA can't infer the output dimension %s. ",
|
||||
hlo->ToString());
|
||||
}
|
||||
}
|
||||
|
||||
@ -931,7 +801,7 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
|
||||
if (input_dim_size == output_dim_size) {
|
||||
// Simply forward dynamic dimension.
|
||||
parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
|
||||
operand_dynamic_size, constraint);
|
||||
operand_dynamic_size);
|
||||
}
|
||||
|
||||
if (input_dim_size > output_dim_size) {
|
||||
@ -946,9 +816,8 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
|
||||
operand_dynamic_size->shape(), HloOpcode::kDivide,
|
||||
operand_dynamic_size, divisor_hlo));
|
||||
|
||||
parent_->SetDynamicSize(
|
||||
reshape, {}, output_dynamic_dimension, new_dynamic_size,
|
||||
DimensionConstraint(1, constraint.multiple_of / divisor));
|
||||
parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
|
||||
new_dynamic_size);
|
||||
}
|
||||
|
||||
if (input_dim_size < output_dim_size) {
|
||||
@ -985,12 +854,8 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
|
||||
hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
|
||||
output_dynamic_size->shape(), HloOpcode::kMultiply,
|
||||
new_dynamic_size, operand_dynamic_size));
|
||||
int64 new_multiple_of_constraint =
|
||||
constraint.multiple_of * output_dim_size /
|
||||
operand->shape().dimensions(input_dynamic_dimension);
|
||||
parent_->SetDynamicSize(
|
||||
reshape, {}, output_dynamic_dimension, new_dynamic_size,
|
||||
DimensionConstraint(1, new_multiple_of_constraint));
|
||||
parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
|
||||
new_dynamic_size);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -1001,8 +866,7 @@ Status DynamicDimensionInferenceVisitor::HandleReduceWindow(
|
||||
HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) {
|
||||
int64 operand_index, HloInstruction* dynamic_size) {
|
||||
HloInstruction* reduce_window = hlo;
|
||||
const WindowDimension& window_dimension =
|
||||
reduce_window->window().dimensions(dimension);
|
||||
@ -1013,8 +877,7 @@ Status DynamicDimensionInferenceVisitor::HandleReduceWindow(
|
||||
reduce_window->ToString());
|
||||
}
|
||||
|
||||
parent_->SetDynamicSize(reduce_window, {}, dimension, dynamic_size,
|
||||
constraint);
|
||||
parent_->SetDynamicSize(reduce_window, {}, dimension, dynamic_size);
|
||||
|
||||
return Status::OK();
|
||||
});
|
||||
@ -1024,8 +887,7 @@ Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter(
|
||||
HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) {
|
||||
int64 operand_index, HloInstruction* dynamic_size) {
|
||||
HloInstruction* select_and_scatter = hlo;
|
||||
const WindowDimension& window_dimension =
|
||||
select_and_scatter->window().dimensions(dimension);
|
||||
@ -1036,8 +898,8 @@ Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter(
|
||||
select_and_scatter->ToString());
|
||||
}
|
||||
|
||||
parent_->SetDynamicSize(select_and_scatter, {}, dimension, dynamic_size,
|
||||
constraint);
|
||||
parent_->SetDynamicSize(select_and_scatter, {}, dimension,
|
||||
dynamic_size);
|
||||
|
||||
return Status::OK();
|
||||
});
|
||||
@ -1046,8 +908,7 @@ Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter(
|
||||
Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* operand, ShapeIndex /*index*/, int64 dimension,
|
||||
int64 /*operand_index*/, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) {
|
||||
int64 /*operand_index*/, HloInstruction* dynamic_size) {
|
||||
if (hlo->slice_starts(dimension) != 0 ||
|
||||
hlo->slice_strides(dimension) != 1 ||
|
||||
hlo->slice_limits(dimension) !=
|
||||
@ -1056,7 +917,7 @@ Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
|
||||
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
|
||||
|
||||
return Status::OK();
|
||||
});
|
||||
@ -1066,8 +927,7 @@ Status DynamicDimensionInferenceVisitor::HandleDynamicSlice(
|
||||
HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction*, ShapeIndex /*index*/, int64 dimension,
|
||||
int64 /*operand_index*/, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) {
|
||||
int64 /*operand_index*/, HloInstruction* dynamic_size) {
|
||||
if (hlo->shape().dimensions(dimension) !=
|
||||
hlo->operand(0)->shape().dimensions(dimension)) {
|
||||
// Slicing a single element out kills the dynamic dimension.
|
||||
@ -1080,7 +940,7 @@ Status DynamicDimensionInferenceVisitor::HandleDynamicSlice(
|
||||
hlo->ToString());
|
||||
}
|
||||
|
||||
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
|
||||
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
|
||||
|
||||
return Status::OK();
|
||||
});
|
||||
@ -1089,9 +949,9 @@ Status DynamicDimensionInferenceVisitor::HandleDynamicSlice(
|
||||
Status DynamicDimensionInferenceVisitor::HandleDynamicUpdateSlice(
|
||||
HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* /*operand*/, ShapeIndex /*index*/,
|
||||
int64 dimension, int64 /*operand_index*/,
|
||||
HloInstruction* dynamic_size, DimensionConstraint constraint) {
|
||||
hlo,
|
||||
[&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64 dimension,
|
||||
int64 /*operand_index*/, HloInstruction* dynamic_size) {
|
||||
if (hlo->shape().dimensions(dimension) !=
|
||||
hlo->operand(0)->shape().dimensions(dimension)) {
|
||||
return Unimplemented(
|
||||
@ -1100,7 +960,7 @@ Status DynamicDimensionInferenceVisitor::HandleDynamicUpdateSlice(
|
||||
hlo->ToString());
|
||||
}
|
||||
|
||||
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
|
||||
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
|
||||
|
||||
return Status::OK();
|
||||
});
|
||||
@ -1108,16 +968,16 @@ Status DynamicDimensionInferenceVisitor::HandleDynamicUpdateSlice(
|
||||
|
||||
Status DynamicDimensionInferenceVisitor::HandleReverse(HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* /*operand*/, ShapeIndex /*index*/,
|
||||
int64 dimension, int64 /*operand_index*/,
|
||||
HloInstruction* dynamic_size, DimensionConstraint constraint) {
|
||||
hlo,
|
||||
[&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64 dimension,
|
||||
int64 /*operand_index*/, HloInstruction* dynamic_size) {
|
||||
if (absl::c_linear_search(hlo->dimensions(), dimension)) {
|
||||
return Unimplemented(
|
||||
"Dynamic dimension propagation on reversed dimension is not "
|
||||
"supported %s",
|
||||
hlo->ToString());
|
||||
}
|
||||
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
|
||||
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
|
||||
|
||||
return Status::OK();
|
||||
});
|
||||
@ -1127,7 +987,7 @@ Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction* operand, ShapeIndex /*index*/,
|
||||
int64 input_dynamic_dimension, int64 operand_index,
|
||||
HloInstruction* dynamic_size, DimensionConstraint constraint) {
|
||||
HloInstruction* dynamic_size) {
|
||||
const GatherDimensionNumbers& gather_dims =
|
||||
hlo->gather_dimension_numbers();
|
||||
if (operand_index != 1) {
|
||||
@ -1147,8 +1007,7 @@ Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) {
|
||||
output_dimension--;
|
||||
}
|
||||
}
|
||||
parent_->SetDynamicSize(hlo, {}, output_dimension, dynamic_size,
|
||||
constraint);
|
||||
parent_->SetDynamicSize(hlo, {}, output_dimension, dynamic_size);
|
||||
return Status::OK();
|
||||
}
|
||||
return Unimplemented(
|
||||
@ -1171,8 +1030,7 @@ Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) {
|
||||
indices_dim++;
|
||||
}
|
||||
if (indices_dim++ == input_dynamic_dimension) {
|
||||
parent_->SetDynamicSize(hlo, {}, output_dim, dynamic_size,
|
||||
constraint);
|
||||
parent_->SetDynamicSize(hlo, {}, output_dim, dynamic_size);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
@ -1220,8 +1078,7 @@ Status DynamicDimensionInferenceVisitor::HandleConditional(
|
||||
TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
|
||||
hlo, operand_index,
|
||||
[&](HloInstruction*, ShapeIndex, int64, int64,
|
||||
HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) -> Status {
|
||||
HloInstruction* dynamic_size) -> Status {
|
||||
TF_RET_CHECK(hlo->operand(operand_index)->shape().IsTuple())
|
||||
<< "Only tuple typed inputs can have dynamic dimension. Please "
|
||||
"file a bug against XLA team.";
|
||||
@ -1263,8 +1120,7 @@ Status DynamicDimensionInferenceVisitor::HandleConditional(
|
||||
TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
|
||||
hlo, operand_index,
|
||||
[&](HloInstruction*, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) {
|
||||
int64 operand_index, HloInstruction* dynamic_size) {
|
||||
DynamicParameterBinding::DynamicParameter dynamic_parameter{
|
||||
0, {dynamic_size_to_operand_id_index_map[dynamic_size]}};
|
||||
DynamicParameterBinding::DynamicDimension dynamic_dimension{
|
||||
@ -1284,8 +1140,8 @@ Status DynamicDimensionInferenceVisitor::HandleConditional(
|
||||
// that into the root instruction as additional tuple elements.
|
||||
TF_RETURN_IF_ERROR(ForEachDynamicDimension(
|
||||
new_computation->root_instruction(),
|
||||
[&](ShapeIndex index, int64 dim, HloInstruction* dynamic_size,
|
||||
DimensionConstraint) -> Status {
|
||||
[&](ShapeIndex index, int64 dim,
|
||||
HloInstruction* dynamic_size) -> Status {
|
||||
TF_RET_CHECK(hlo->shape().IsTuple())
|
||||
<< "Only tuple typed conditionals can have dynamic dimension. "
|
||||
"Please file a bug against XLA team.";
|
||||
@ -1347,11 +1203,9 @@ Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) {
|
||||
return ForEachOperandDynamicDimension(
|
||||
hlo,
|
||||
[&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64 dimension,
|
||||
int64 operand_index, HloInstruction* operand_dynamic_size,
|
||||
DimensionConstraint constraint) {
|
||||
int64 operand_index, HloInstruction* operand_dynamic_size) {
|
||||
if (operand_index == 0) {
|
||||
parent_->SetDynamicSize(hlo, {}, dimension, operand_dynamic_size,
|
||||
constraint);
|
||||
parent_->SetDynamicSize(hlo, {}, dimension, operand_dynamic_size);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1385,7 +1239,7 @@ Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
|
||||
int64 operand_count = original_tuple_count;
|
||||
TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
|
||||
hlo, [&](HloInstruction*, ShapeIndex index, int64 dim, int64,
|
||||
HloInstruction* dynamic_size, DimensionConstraint constraint) {
|
||||
HloInstruction* dynamic_size) {
|
||||
operands_to_add.push_back(dynamic_size);
|
||||
dynamic_output_mapping.mutable_element(index)->emplace(dim,
|
||||
operand_count++);
|
||||
@ -1413,8 +1267,7 @@ Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
|
||||
TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
|
||||
hlo,
|
||||
[&](HloInstruction*, ShapeIndex index, int64 dimension,
|
||||
int64 operand_index, HloInstruction* dynamic_size,
|
||||
DimensionConstraint constraint) -> Status {
|
||||
int64 operand_index, HloInstruction* dynamic_size) -> Status {
|
||||
TF_RET_CHECK(!operands_to_add.empty());
|
||||
const int64 output_dynamic_size_index =
|
||||
dynamic_output_mapping.element(index).at(dimension);
|
||||
@ -1431,7 +1284,7 @@ Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
|
||||
ShapeUtil::MakeScalarShape(S32), hlo,
|
||||
output_dynamic_size_index));
|
||||
parent_->SetDynamicSize(result.replacement_instr, index, dimension,
|
||||
output_dynamic_size, constraint);
|
||||
output_dynamic_size);
|
||||
return Status::OK();
|
||||
}));
|
||||
// Set the replacement instruction as visited to avoid visiting it again.
|
||||
@ -1465,8 +1318,7 @@ Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
|
||||
// Add dynamic dimension size as new parameters.
|
||||
TF_RETURN_IF_ERROR(ForEachDynamicDimension(
|
||||
hlo->while_body()->root_instruction(),
|
||||
[&](ShapeIndex index, int64 dim, HloInstruction* dynamic_size,
|
||||
DimensionConstraint) -> Status {
|
||||
[&](ShapeIndex index, int64 dim, HloInstruction* dynamic_size) -> Status {
|
||||
const int64 output_index =
|
||||
dynamic_output_mapping.element(index).at(dim);
|
||||
new_root_operands[output_index] = dynamic_size;
|
||||
@ -1503,8 +1355,7 @@ Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) {
|
||||
|
||||
parent_->SetDynamicSize(target_parameter,
|
||||
dynamic_dimension.parameter_index,
|
||||
dynamic_dimension.dimension, dynamic_size,
|
||||
DimensionConstraint(1, 1));
|
||||
dynamic_dimension.dimension, dynamic_size);
|
||||
return Status::OK();
|
||||
});
|
||||
}
|
||||
@ -1517,10 +1368,8 @@ Status DynamicDimensionInferenceVisitor::ForEachDynamicDimension(
|
||||
HloInstruction* dynamic_size = parent_->GetDynamicSize(
|
||||
dynamic_dimension.inst, dynamic_dimension.index,
|
||||
dynamic_dimension.dim);
|
||||
CHECK_NE(parent_->constraint_mapping_.count(dynamic_dimension), 0);
|
||||
TF_RETURN_IF_ERROR(fn(dynamic_dimension.index, dynamic_dimension.dim,
|
||||
dynamic_size,
|
||||
parent_->constraint_mapping_[dynamic_dimension]));
|
||||
TF_RETURN_IF_ERROR(
|
||||
fn(dynamic_dimension.index, dynamic_dimension.dim, dynamic_size));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
@ -1536,10 +1385,9 @@ Status DynamicDimensionInferenceVisitor::ForEachDynamicDimensionInOperand(
|
||||
HloInstruction* dynamic_size = parent_->GetDynamicSize(
|
||||
dynamic_dimension.inst, dynamic_dimension.index,
|
||||
dynamic_dimension.dim);
|
||||
CHECK_NE(parent_->constraint_mapping_.count(dynamic_dimension), 0);
|
||||
TF_RETURN_IF_ERROR(fn(dynamic_dimension.inst, dynamic_dimension.index,
|
||||
dynamic_dimension.dim, operand_index, dynamic_size,
|
||||
parent_->constraint_mapping_[dynamic_dimension]));
|
||||
dynamic_dimension.dim, operand_index,
|
||||
dynamic_size));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
@ -1555,6 +1403,24 @@ Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void DynamicDimensionInference::SetDynamicSize(HloInstruction* inst,
|
||||
const ShapeIndex& index,
|
||||
int64 dim,
|
||||
HloInstruction* size) {
|
||||
VLOG(1) << "Set dimension inst " << inst->ToString() << " index "
|
||||
<< index.ToString() << "@" << dim << " to " << size->ToShortString();
|
||||
Shape subshape = ShapeUtil::GetSubshape(inst->shape(), index);
|
||||
CHECK(!subshape.IsTuple()) << "Can't set a tuple shape to dynamic dimension";
|
||||
CHECK(dim < subshape.rank() && dim >= 0)
|
||||
<< "Asked to set invalid dynamic dimension. Shape: "
|
||||
<< subshape.ToString() << ", Dimension: " << dim;
|
||||
DynamicDimension dynamic_dimension{inst, index, dim};
|
||||
// Updating a dynamic dimension twice overwrites the previous one.
|
||||
dynamic_mapping_[dynamic_dimension] = size;
|
||||
auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst);
|
||||
iter.first->second.emplace(dynamic_dimension);
|
||||
}
|
||||
|
||||
void DynamicDimensionInference::CopyMapping(HloInstruction* from,
|
||||
HloInstruction* to) {
|
||||
auto iter = per_hlo_dynamic_dimensions_.find(from);
|
||||
@ -1564,7 +1430,7 @@ void DynamicDimensionInference::CopyMapping(HloInstruction* from,
|
||||
GetDynamicSize(dynamic_dimension.inst, dynamic_dimension.index,
|
||||
dynamic_dimension.dim);
|
||||
SetDynamicSize(to, dynamic_dimension.index, dynamic_dimension.dim,
|
||||
dynamic_size, constraint_mapping_[dynamic_dimension]);
|
||||
dynamic_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1624,8 +1490,6 @@ Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst,
|
||||
auto iter = dynamic_mapping_.find(dynamic_dimension);
|
||||
if (iter != dynamic_mapping_.end()) {
|
||||
dynamic_mapping_.insert({dynamic_dimension_new, iter->second});
|
||||
constraint_mapping_.insert(
|
||||
{dynamic_dimension_new, constraint_mapping_[dynamic_dimension]});
|
||||
auto iter = per_hlo_dynamic_dimensions_.try_emplace(new_inst);
|
||||
iter.first->second.emplace(dynamic_dimension_new);
|
||||
}
|
||||
|
@ -55,8 +55,7 @@ class DynamicDimensionInference {
|
||||
// go into tuples.
|
||||
bool HasDynamicDimension(HloInstruction* inst) const;
|
||||
|
||||
// Forward dynamic dimension size at `dim` and its constraint from `inst` to
|
||||
// `new_inst`.
|
||||
// Forward dynamic dimension size at `dim` from `inst` to `new_inst`.
|
||||
Status ForwardDynamicSize(HloInstruction* inst, HloInstruction* new_inst,
|
||||
const ShapeIndex& index);
|
||||
|
||||
@ -64,9 +63,7 @@ class DynamicDimensionInference {
|
||||
// `inst` at `index` has a dynamic size, and its runtime size is represented
|
||||
// by a scalar instruction `size`.
|
||||
void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim,
|
||||
HloInstruction* size) {
|
||||
SetDynamicSize(inst, index, dim, size, DimensionConstraint(1, 1));
|
||||
}
|
||||
HloInstruction* size);
|
||||
|
||||
// For all tensors whose dynamic dimension is `replace`, replace them with
|
||||
// `with`.
|
||||
@ -106,116 +103,6 @@ class DynamicDimensionInference {
|
||||
}
|
||||
};
|
||||
|
||||
// DimensionConstraint is attached to each dynamic dimension and describes the
|
||||
// constraint of each dimension. This is used to disambiguate the index of
|
||||
// dynamic dimension for reshapes that "splits" a dimension into two.
|
||||
//
|
||||
// As an example, consider the following reshapes:
|
||||
// [<=3, 3] <- Assume first dimension is dynamic.
|
||||
// |
|
||||
// Reshape.1
|
||||
// |
|
||||
// [<=9] <- Dimension 9 is dynamic
|
||||
// |
|
||||
// Reshape.2
|
||||
// |
|
||||
// [3, 3] <- Ambiguous dimension after splitting 9 into [3, 3]
|
||||
//
|
||||
// There is no way to know which dimension is dynamic by looking at the second
|
||||
// reshape locally.
|
||||
//
|
||||
// However, if we look at the dynamic dimension 9, since it comes from
|
||||
// collapsing a major dynamic dimension of 3 (the dynamic size can be 0, 1, 2,
|
||||
// 3, denoted as i in the diagram below) and a minor static dimension of 3, we
|
||||
// know it has certain constraints that the reshape can only be one of the 4
|
||||
// forms:
|
||||
//
|
||||
// o: Padded Data
|
||||
// x: Effective Data
|
||||
//
|
||||
// [<=3, 3] to [9]
|
||||
//
|
||||
// +---+ +---+ +---+ +---+
|
||||
// |ooo| |ooo| |ooo| |xxx|
|
||||
// |ooo| |ooo| |xxx| |xxx|
|
||||
// |ooo| |xxx| |xxx| |xxx|
|
||||
// +---+ +---+ +---+ +---+
|
||||
//
|
||||
// Reshape Reshape Reshape Reshape
|
||||
//
|
||||
// +-----------+ +-----------+ +-----------+ +-----------+
|
||||
// |ooo|ooo|ooo| or |xxx|ooo|ooo| or |xxx|xxx|ooo| or |xxx|xxx|xxx| stride=1
|
||||
// +-----------+ +-----------+ +-----------+ +-----------+
|
||||
// i = 0 i = 1 i = 2 i = 3
|
||||
//
|
||||
// On the other hand, if the minor dimension 3 is dynamic and major dimension
|
||||
// is static, we will have the following form:
|
||||
//
|
||||
// [3, <=3] to [9]
|
||||
//
|
||||
// +---+ +---+ +---+ +---+
|
||||
// |ooo| |xoo| |xxo| |xxx|
|
||||
// |ooo| |xoo| |xxo| |xxx|
|
||||
// |ooo| |xoo| |xxo| |xxx|
|
||||
// +---+ +---+ +---+ +---+
|
||||
//
|
||||
// Reshape Reshape Reshape Reshape
|
||||
//
|
||||
// +-----------+ +-----------+ +-----------+ +-----------+
|
||||
// |ooo|ooo|ooo| or |xoo|xoo|xoo| or |xxo|xxo|xxo| or |xxo|xxo|xxo| stride=3
|
||||
// +-----------+ +-----------+ +-----------+ +-----------+
|
||||
// i = 0 i = 1 i = 2 i = 3
|
||||
//
|
||||
// By encoding constraint as a stride of elements we can recover this
|
||||
// information later when we reshape from [9] to [3, 3]. We know which form
|
||||
// ([3, i] or [i,3]) we should reshape the [9] into.
|
||||
//
|
||||
//
|
||||
struct DimensionConstraint {
|
||||
explicit DimensionConstraint(int64 s, int64 m)
|
||||
: stride(s), multiple_of(m) {}
|
||||
DimensionConstraint() : stride(1), multiple_of(1) {}
|
||||
// Stride represents the distance of a newly placed element and the previous
|
||||
// placed element on this dynamic dimension.
|
||||
int64 stride;
|
||||
|
||||
// multiple_of represents the constraints that
|
||||
//
|
||||
// `dynamic_size` % `multiple_of` == 0
|
||||
int64 multiple_of;
|
||||
};
|
||||
|
||||
using ConstraintMapping =
|
||||
absl::flat_hash_map<DynamicDimension, DimensionConstraint>;
|
||||
|
||||
ConstraintMapping constraint_mapping_;
|
||||
|
||||
// Update the dynamic mapping so that we know dimension `dim` of instruction
|
||||
// `inst` at `index` has a dynamic size, and its runtime size is represented
|
||||
// by a scalar instruction `size`.
|
||||
void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim,
|
||||
HloInstruction* size, DimensionConstraint constraint) {
|
||||
VLOG(1) << "Set dimension inst " << inst->ToString() << " index "
|
||||
<< index.ToString() << "@" << dim << " to " << size->ToShortString()
|
||||
<< " constraint: " << constraint.multiple_of;
|
||||
Shape subshape = ShapeUtil::GetSubshape(inst->shape(), index);
|
||||
CHECK(!subshape.IsTuple())
|
||||
<< "Can't set a tuple shape to dynamic dimension";
|
||||
CHECK(dim < subshape.rank() && dim >= 0)
|
||||
<< "Asked to set invalid dynamic dimension. Shape: "
|
||||
<< subshape.ToString() << ", Dimension: " << dim;
|
||||
DynamicDimension dynamic_dimension{inst, index, dim};
|
||||
// Updating a dynamic dimension twice overwrites the previous one.
|
||||
dynamic_mapping_[dynamic_dimension] = size;
|
||||
if (constraint_mapping_.count(dynamic_dimension) != 0) {
|
||||
CHECK_EQ(constraint_mapping_[dynamic_dimension].stride,
|
||||
constraint.stride);
|
||||
}
|
||||
constraint_mapping_[dynamic_dimension] = constraint;
|
||||
auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst);
|
||||
iter.first->second.emplace(dynamic_dimension);
|
||||
}
|
||||
|
||||
// Copies the internal mapping from instruction `from` to instruction `to`.
|
||||
// This is useful when an instruction is replaced by the other during the
|
||||
// inferencing process.
|
||||
|
@ -27,6 +27,7 @@ load(
|
||||
"if_cuda_is_configured",
|
||||
)
|
||||
load("//tensorflow:tensorflow.bzl", "if_nccl")
|
||||
load("//third_party/mlir:tblgen.bzl", "gentbl")
|
||||
|
||||
package(
|
||||
default_visibility = [":friends"],
|
||||
@ -686,7 +687,7 @@ cc_library(
|
||||
":gpu_autotuning_proto_cc",
|
||||
":gpu_conv_runner",
|
||||
":gpu_executable",
|
||||
":hlo_algorithm_blacklist",
|
||||
":hlo_algorithm_denylist",
|
||||
":ir_emission_utils",
|
||||
":stream_executor_util",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
@ -1660,9 +1661,9 @@ tf_proto_library_cc(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_algorithm_blacklist",
|
||||
srcs = ["hlo_algorithm_blacklist.cc"],
|
||||
hdrs = ["hlo_algorithm_blacklist.h"],
|
||||
name = "hlo_algorithm_denylist",
|
||||
srcs = ["hlo_algorithm_denylist.cc"],
|
||||
hdrs = ["hlo_algorithm_denylist.h"],
|
||||
deps = [
|
||||
":gpu_autotuning_proto_cc",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
@ -1673,12 +1674,12 @@ cc_library(
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "hlo_algorithm_blacklist_test",
|
||||
srcs = ["hlo_algorithm_blacklist_test.cc"],
|
||||
name = "hlo_algorithm_denylist_test",
|
||||
srcs = ["hlo_algorithm_denylist_test.cc"],
|
||||
data = ["data/hlo_algorithm_denylist.pbtxt"],
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
":hlo_algorithm_blacklist",
|
||||
":hlo_algorithm_denylist",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
@ -1875,3 +1876,49 @@ cc_library(
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "xla_thunks_ops_inc_gen",
|
||||
tbl_outs = [
|
||||
("-gen-op-decls", "ir/xla_thunks_ops.h.inc"),
|
||||
("-gen-op-defs", "ir/xla_thunks_ops.cc.inc"),
|
||||
("-gen-struct-attr-decls", "ir/xla_thunks_structs.h.inc"),
|
||||
("-gen-struct-attr-defs", "ir/xla_thunks_structs.cc.inc"),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "ir/xla_thunks_ops.td",
|
||||
td_srcs = [
|
||||
"@llvm-project//mlir:LLVMOpsTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_thunks_ops",
|
||||
srcs = [
|
||||
"ir/xla_thunks_ops.cc",
|
||||
"ir/xla_thunks_ops.cc.inc",
|
||||
"ir/xla_thunks_ops.h.inc",
|
||||
],
|
||||
hdrs = [
|
||||
"ir/xla_thunks_ops.h",
|
||||
],
|
||||
deps = [
|
||||
":xla_thunks_ops_inc_gen",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LLVMDialect",
|
||||
],
|
||||
)
|
||||
|
||||
# Library with XLA thunks dialect static initialization.
|
||||
cc_library(
|
||||
name = "xla_thunks_dialect_registration",
|
||||
srcs = [
|
||||
"ir/dialect_registration.cc",
|
||||
],
|
||||
deps = [
|
||||
":xla_thunks_ops",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist.h"
|
||||
|
||||
#include <string>
|
||||
|