Merge branch 'master' into wsign-compare-semi-final-lite-python-stream-executor

tg-at-google 2020-07-26 20:12:17 -04:00 committed by GitHub
commit 9424fb57d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
807 changed files with 22129 additions and 12180 deletions

View File

@ -28,6 +28,7 @@
#
# Other build options:
# short_logs: Only log errors during build, skip warnings.
# verbose_logs: Show all compiler warnings during build.
# monolithic: Build all TF C++ code into a single shared object.
# dynamic_kernels: Try to link all kernels dynamically (experimental).
# libc++: Link against libc++ instead of stdlibc++
@ -331,6 +332,8 @@ build:windows --distinct_host_configuration=false
# Suppress all warning messages.
build:short_logs --output_filter=DONT_MATCH_ANYTHING
build:verbose_logs --output_filter=
build --config=short_logs
# Instruction set optimizations
# TODO(gunan): Create a feature in toolchains for avx/avx2 to
@ -547,6 +550,7 @@ try-import %workspace%/.bazelrc.user
# Here are bazelrc configs for release builds
build:release_common --config=opt
build:release_common --config=v2
build:release_common --distinct_host_configuration=false
build:release_common --action_env TF_CONFIGURE_IOS="0"
build:release_cpu_linux --config=release_common
@ -564,9 +568,10 @@ build:release_gpu_common --config=tensorrt
build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1"
build:release_gpu_common --action_env=TF_CUDA_VERSION="10"
build:release_gpu_common --action_env=TF_CUDNN_VERSION="7"
build:release_gpu_common --action_env=TF_NEED_TENSORRT="1"
build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_37,sm_52,sm_60,sm_61,compute_70"
build:release_gpu_common --action_env=TENSORRT_INSTALL_PATH="/usr/local/tensorrt"
build:release_gpu_common --action_env=LD_LIBRARY_PATH="/usr/local/tensorrt/lib"
build:release_gpu_common --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5"

View File

@ -31,7 +31,6 @@
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* TF Core:
* <ADD RELEASE NOTES HERE>
* `tf.types.experimental.TensorLike` is a new `Union` type that can be used as
type annotation for variables representing a Tensor or a value that can be
converted to Tensor by `tf.convert_to_tensor`.
@ -39,9 +38,18 @@
tf.convert_to_tensor behavior. This avoids operations like tf.reshape
truncating inputs such as from int64 to int32.
* Added `tf.sparse.map_values` to apply a function to the `.value`s of `SparseTensor` arguments.
* The Python bitwise operators for `Tensor` (`__and__`, `__or__`, `__xor__`
and `__invert__`) now support non-`bool` arguments and apply the
corresponding bitwise ops. `bool` arguments continue to be supported and
dispatch to logical ops. This brings them more in line with Python and NumPy
behavior (see the usage sketch below).
* `tf.data`:
* Added new `tf.data.experimental.service.register_dataset` and
`tf.data.experimental.service.from_dataset_id` APIs to enable one process
to register a dataset with the tf.data service, and another process to
consume data from the dataset.
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
the complement of `select_cols`; at most one of these should be specified.
* We have implemented an optimization which reorders data-discarding
transformations such as `take` and `shard` to happen earlier in the
dataset when it is safe to do so. The optimization can be disabled via
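Below is a minimal sketch, assuming a TensorFlow build that already contains the additions above, of the non-`bool` bitwise operators and `tf.sparse.map_values`; the tensor values are illustrative only.

```python
import tensorflow as tf

# Integer tensors now dispatch `&`, `|`, `^` and `~` to tf.bitwise ops.
a = tf.constant([0b1100, 0b1010], dtype=tf.int32)
b = tf.constant([0b1010, 0b0110], dtype=tf.int32)
print(a & b)  # bitwise_and -> [0b1000, 0b0010]
print(a | b)  # bitwise_or  -> [0b1110, 0b1110]
print(~a)     # bitwise invert

# Boolean tensors keep dispatching to the logical ops.
print(tf.constant([True, False]) & tf.constant([True, True]))  # [True, False]

# tf.sparse.map_values transforms only the `.values` of a SparseTensor,
# leaving the sparsity pattern untouched.
st = tf.sparse.from_dense([[1, 0], [0, 3]])
doubled = tf.sparse.map_values(tf.multiply, st, 2)
print(tf.sparse.to_dense(doubled))  # [[2, 0], [0, 6]]
```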
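A sketch of the new `tf.data` service registration APIs mentioned above; it assumes a tf.data service dispatcher is already running at the (hypothetical) address in `service` and that both processes can reach it.

```python
import tensorflow as tf

service = "grpc://dispatcher.example:5000"  # hypothetical dispatcher address

# Process A: register a dataset with the tf.data service.
dataset = tf.data.Dataset.range(10)
dataset_id = tf.data.experimental.service.register_dataset(service, dataset)

# Process B: consume the registered dataset by its id.
consumed = tf.data.experimental.service.from_dataset_id(
    processing_mode="parallel_epochs",
    service=service,
    dataset_id=dataset_id,
    element_spec=dataset.element_spec)
for element in consumed:
    print(element)
```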
@ -51,11 +59,12 @@
* <ADD RELEASE NOTES HERE>
* `tf.keras`:
* <ADD RELEASE NOTES HERE>
* `tf.function`/AutoGraph:
* <ADD RELEASE NOTES HERE>
* `tf.function` / AutoGraph:
* Added `experimental_follow_type_hints` argument for `tf.function`. When
True, the function may use type annotations to optimize the tracing
performance.
* `tf.lite`:
* Better support for ops with high-dimensional broadcasting inputs by adding
`BroadcastTo` ops when necessary.
* <ADD RELEASE NOTES HERE>
* `tf.random`:
* <ADD RELEASE NOTES HERE>
* Math and Linear Algebra:
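A sketch of the `experimental_follow_type_hints` flag described in the `tf.function` bullet above, assuming a build that includes it; the function body is arbitrary.

```python
import tensorflow as tf

@tf.function(experimental_follow_type_hints=True)
def scale(x: tf.Tensor, factor: tf.Tensor) -> tf.Tensor:
  # With the flag on, annotated Python scalars are converted to tensors
  # before tracing, so different scalar values can reuse a single trace.
  return x * factor

x = tf.constant([1.0, 2.0, 3.0])
print(scale(x, 2.0))
print(scale(x, 3.0))  # should not trigger a retrace for the new Python float
```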
@ -219,7 +228,7 @@ Coinciding with this change, new releases of [TensorFlow's Docker images](https:
`Strategy.extended.update` and `Strategy.extended.update_non_slot`.
* Experimental support for shape invariants has been enabled in
`tf.function`. See the API docs for
`tf.autograph.experimental.set_loop_options` for additonal info.
`tf.autograph.experimental.set_loop_options` for additional info.
* AutoGraph error messages now exclude frames corresponding to APIs
internal to AutoGraph.
* Improve shape inference for `tf.function` input arguments to unlock more
@ -302,7 +311,7 @@ Coinciding with this change, new releases of [TensorFlow's Docker images](https:
also deterministic back-prop of bias-addition in Keras layers) to
include when XLA JIT compilation is enabled.
* Fix problem, when running on a CUDA GPU and when either environment
variable `TF_DETERMINSTIC_OPS` or environment variable
variable `TF_DETERMINISTIC_OPS` or environment variable
`TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer
configurations led to an exception with the message "No algorithm
worked!"
@ -345,32 +354,86 @@ This release contains contributions from many people at Google, as well as:
TensorFlow 2.1 will be the last TF release supporting Python 2. Python 2 support [officially ends on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update). [As announced earlier](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ), TensorFlow will also stop supporting Python 2 starting January 1, 2020, and no more releases are expected in 2019.
## Major Features and Improvements
* The `tensorflow` pip package now includes GPU support by default (same as `tensorflow-gpu`) for both Linux and Windows. This runs on machines with and without NVIDIA GPUs. `tensorflow-gpu` is still available, and CPU-only packages can be downloaded at `tensorflow-cpu` for users who are concerned about package size.
* **Windows users:** Officially-released `tensorflow` Pip packages are now built with Visual Studio 2019 version 16.4 in order to take advantage of the new `/d2ReducedOptimizeHugeFunctions` compiler flag. To use these new packages, you must install "Microsoft Visual C++ Redistributable for Visual Studio 2015, 2017 and 2019", available from Microsoft's website [here](https://support.microsoft.com/help/2977003/the-latest-supported-visual-c-downloads).
* This does not change the minimum required version for building TensorFlow from source on Windows, but builds enabling `EIGEN_STRONG_INLINE` can take over 48 hours to compile without this flag. Refer to `configure.py` for more information about `EIGEN_STRONG_INLINE` and `/d2ReducedOptimizeHugeFunctions`.
* If either of the required DLLs, `msvcp140.dll` (old) or `msvcp140_1.dll` (new), are missing on your machine, `import tensorflow` will print a warning message.
* The `tensorflow` pip package is built with CUDA 10.1 and cuDNN 7.6.
* `tf.keras`
* Experimental support for mixed precision is available on GPUs and Cloud TPUs. See [usage guide](https://www.tensorflow.org/guide/keras/mixed_precision).
* Introduced the `TextVectorization` layer, which takes as input raw strings and takes care of text standardization, tokenization, n-gram generation, and vocabulary indexing. See this [end-to-end text classification example](https://colab.research.google.com/drive/1RvCnR7h0_l4Ekn5vINWToI9TNJdpUZB3).
* Keras `.compile` `.fit` `.evaluate` and `.predict` are allowed to be outside of the DistributionStrategy scope, as long as the model was constructed inside of a scope.
* Experimental support for Keras `.compile`, `.fit`, `.evaluate`, and `.predict` is available for Cloud TPUs, Cloud TPU, for all types of Keras models (sequential, functional and subclassing models).
* Automatic outside compilation is now enabled for Cloud TPUs. This allows `tf.summary` to be used more conveniently with Cloud TPUs.
* Dynamic batch sizes with DistributionStrategy and Keras are supported on Cloud TPUs.
* Support for `.fit`, `.evaluate`, `.predict` on TPU using numpy data, in addition to `tf.data.Dataset`.
* Keras reference implementations for many popular models are available in the TensorFlow [Model Garden](https://github.com/tensorflow/models/tree/master/official).
* `tf.data`
* Changes rebatching for `tf.data datasets` + DistributionStrategy for better performance. Note that the dataset also behaves slightly differently, in that the rebatched dataset cardinality will always be a multiple of the number of replicas.
* `tf.data.Dataset` now supports automatic data distribution and sharding in distributed environments, including on TPU pods.
* Distribution policies for `tf.data.Dataset` can now be tuned with 1. `tf.data.experimental.AutoShardPolicy(OFF, AUTO, FILE, DATA)` 2. `tf.data.experimental.ExternalStatePolicy(WARN, IGNORE, FAIL)`
* `tf.debugging`
* Add `tf.debugging.enable_check_numerics()` and `tf.debugging.disable_check_numerics()` to help debugging the root causes of issues involving infinities and `NaN`s.
* `tf.distribute`
* Custom training loop support on TPUs and TPU pods is avaiable through `strategy.experimental_distribute_dataset`, `strategy.experimental_distribute_datasets_from_function`, `strategy.experimental_run_v2`, `strategy.reduce`.
* Support for a global distribution strategy through `tf.distribute.experimental_set_strategy(),` in addition to `strategy.scope()`.
* `TensorRT`
* [TensorRT 6.0](https://developer.nvidia.com/tensorrt#tensorrt-whats-new) is now supported and enabled by default. This adds support for more TensorFlow ops including Conv3D, Conv3DBackpropInputV2, AvgPool3D, MaxPool3D, ResizeBilinear, and ResizeNearestNeighbor. In addition, the TensorFlow-TensorRT python conversion API is exported as `tf.experimental.tensorrt.Converter`.
* Environment variable `TF_DETERMINISTIC_OPS` has been added. When set to "true" or "1", this environment variable makes `tf.nn.bias_add` operate deterministically (i.e. reproducibly), but currently only when XLA JIT compilation is *not* enabled. Setting `TF_DETERMINISTIC_OPS` to "true" or "1" also makes cuDNN convolution and max-pooling operate deterministically. This makes Keras Conv\*D and MaxPool\*D layers operate deterministically in both the forward and backward directions when running on a CUDA-enabled GPU.
* The `tensorflow` pip package now includes GPU support by default (same as
`tensorflow-gpu`) for both Linux and Windows. This runs on machines with and
without NVIDIA GPUs. `tensorflow-gpu` is still available, and CPU-only
packages can be downloaded at `tensorflow-cpu` for users who are concerned
about package size.
* **Windows users:** Officially-released `tensorflow` Pip packages are now
built with Visual Studio 2019 version 16.4 in order to take advantage of the
new `/d2ReducedOptimizeHugeFunctions` compiler flag. To use these new
packages, you must install "Microsoft Visual C++ Redistributable for Visual
Studio 2015, 2017 and 2019", available from Microsoft's website
[here](https://support.microsoft.com/help/2977003/the-latest-supported-visual-c-downloads).
* This does not change the minimum required version for building
TensorFlow from source on Windows, but builds enabling
`EIGEN_STRONG_INLINE` can take over 48 hours to compile without this
flag. Refer to `configure.py` for more information about
`EIGEN_STRONG_INLINE` and `/d2ReducedOptimizeHugeFunctions`.
* If either of the required DLLs, `msvcp140.dll` (old) or `msvcp140_1.dll`
(new), are missing on your machine, `import tensorflow` will print a
warning message.
* The `tensorflow` pip package is built with CUDA 10.1 and cuDNN 7.6.
* `tf.keras`
* Experimental support for mixed precision is available on GPUs and Cloud
TPUs. See
[usage guide](https://www.tensorflow.org/guide/keras/mixed_precision).
* Introduced the `TextVectorization` layer, which takes as input raw
strings and takes care of text standardization, tokenization, n-gram
generation, and vocabulary indexing. See this
[end-to-end text classification example](https://colab.research.google.com/drive/1RvCnR7h0_l4Ekn5vINWToI9TNJdpUZB3).
* Keras `.compile` `.fit` `.evaluate` and `.predict` are allowed to be
outside of the DistributionStrategy scope, as long as the model was
constructed inside of a scope.
* Experimental support for Keras `.compile`, `.fit`, `.evaluate`, and
`.predict` is available for Cloud TPUs, Cloud TPU, for all types of
Keras models (sequential, functional and subclassing models).
* Automatic outside compilation is now enabled for Cloud TPUs. This allows
`tf.summary` to be used more conveniently with Cloud TPUs.
* Dynamic batch sizes with DistributionStrategy and Keras are supported on
Cloud TPUs.
* Support for `.fit`, `.evaluate`, `.predict` on TPU using numpy data, in
addition to `tf.data.Dataset`.
* Keras reference implementations for many popular models are available in
the TensorFlow
[Model Garden](https://github.com/tensorflow/models/tree/master/official).
* `tf.data`
* Changes rebatching for `tf.data datasets` + DistributionStrategy for
better performance. Note that the dataset also behaves slightly
differently, in that the rebatched dataset cardinality will always be a
multiple of the number of replicas.
* `tf.data.Dataset` now supports automatic data distribution and sharding
in distributed environments, including on TPU pods.
* Distribution policies for `tf.data.Dataset` can now be tuned with 1.
`tf.data.experimental.AutoShardPolicy(OFF, AUTO, FILE, DATA)` 2.
`tf.data.experimental.ExternalStatePolicy(WARN, IGNORE, FAIL)`
* `tf.debugging`
* Add `tf.debugging.enable_check_numerics()` and
`tf.debugging.disable_check_numerics()` to help debugging the root
causes of issues involving infinities and `NaN`s.
* `tf.distribute`
* Custom training loop support on TPUs and TPU pods is available through
`strategy.experimental_distribute_dataset`,
`strategy.experimental_distribute_datasets_from_function`,
`strategy.experimental_run_v2`, `strategy.reduce`.
* Support for a global distribution strategy through
`tf.distribute.experimental_set_strategy(),` in addition to
`strategy.scope()`.
* `TensorRT`
* [TensorRT 6.0](https://developer.nvidia.com/tensorrt#tensorrt-whats-new)
is now supported and enabled by default. This adds support for more
TensorFlow ops including Conv3D, Conv3DBackpropInputV2, AvgPool3D,
MaxPool3D, ResizeBilinear, and ResizeNearestNeighbor. In addition, the
TensorFlow-TensorRT python conversion API is exported as
`tf.experimental.tensorrt.Converter`.
* Environment variable `TF_DETERMINISTIC_OPS` has been added. When set to
"true" or "1", this environment variable makes `tf.nn.bias_add` operate
deterministically (i.e. reproducibly), but currently only when XLA JIT
compilation is *not* enabled. Setting `TF_DETERMINISTIC_OPS` to "true" or
"1" also makes cuDNN convolution and max-pooling operate deterministically.
This makes Keras Conv\*D and MaxPool\*D layers operate deterministically in
both the forward and backward directions when running on a CUDA-enabled GPU.
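A short sketch of the `tf.debugging.enable_check_numerics()` switch listed under `tf.debugging` above; the division by zero is a contrived trigger.

```python
import tensorflow as tf

tf.debugging.enable_check_numerics()  # flag any op that produces Inf or NaN

x = tf.constant([1.0, 0.0])
try:
  y = tf.constant(1.0) / x  # 1/0 -> Inf, which trips the numerics check
except tf.errors.InvalidArgumentError:
  print("check_numerics caught an invalid value")

tf.debugging.disable_check_numerics()
```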
## Breaking Changes
* Deletes `Operation.traceback_with_start_lines` for which we know of no usages.

View File

@ -262,6 +262,7 @@ cc_library(
],
deps = [
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:refcount",
],
)

View File

@ -18,11 +18,12 @@ limitations under the License.
#include <memory>
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/refcount.h"
namespace tensorflow {
// Abstract interface to a Tensor handle in either tracing or immediate
// execution mode.
class AbstractTensorHandle {
class AbstractTensorHandle : public core::RefCounted {
protected:
enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt };
explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
@ -34,14 +35,6 @@ class AbstractTensorHandle {
AbstractTensorHandleKind getKind() const { return kind_; }
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage it's own
// lifetime through ref counting. Thus this must be allocated on the heap and
// clients MUST call Release() in order to destroy an instance of this class.
virtual void Release() = 0;
private:
const AbstractTensorHandleKind kind_;
};
@ -50,7 +43,7 @@ namespace internal {
struct AbstractTensorHandleDeleter {
void operator()(AbstractTensorHandle* p) const {
if (p != nullptr) {
p->Release();
p->Unref();
}
}
};

View File

@ -147,7 +147,7 @@ TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
void TF_DeleteAbstractOp(TF_AbstractOp* op) { unwrap(op)->Release(); }
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Release(); }
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Unref(); }
TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }

View File

@ -49,7 +49,6 @@ class GraphTensor : public TracingTensorHandle {
public:
explicit GraphTensor(TF_Output output)
: TracingTensorHandle(kGraph), output_(output) {}
void Release() override { delete this; }
tensorflow::DataType DataType() const override {
return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));

View File

@ -51,25 +51,14 @@ int64 ToId(AbstractTensorHandle* t) {
TapeTensor::TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx)
: handle_(handle), ctx_(ctx) {
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
// on the client to keep this tensor live for the duration of the gradient
// computation.
// handle_->Ref();
handle_->Ref();
}
TapeTensor::TapeTensor(const TapeTensor& other) {
handle_ = other.handle_;
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
// on the client to keep this tensor live for the duration of the gradient
// computation.
// handle_->Ref();
handle_->Ref();
ctx_ = other.ctx_;
}
TapeTensor::~TapeTensor() {
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
// on the client to keep this tensor live for the duration of the gradient
// computation.
// handle_->Unref();
}
TapeTensor::~TapeTensor() { handle_->Unref(); }
tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); }
@ -192,7 +181,7 @@ TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}
void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
gradient->Release();
gradient->Unref();
}
// Helper functions which delegate to `AbstractOperation`, update

View File

@ -93,7 +93,7 @@ Status AddGradModel(AbstractContext* ctx,
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads));
for (auto add_output : add_outputs) {
add_output->Release();
add_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
@ -144,14 +144,14 @@ Status RunModel(Model model, AbstractContext* ctx,
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
absl::MakeSpan(output_list.outputs), registry));
for (auto func_input : func_inputs) {
func_input->Release();
func_input->Unref();
}
AbstractFunction* func = nullptr;
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
->Finalize(&output_list, &func));
scoped_func.reset(func);
output_list.outputs[0]->Release();
output_list.outputs[1]->Release();
output_list.outputs[0]->Unref();
output_list.outputs[1]->Unref();
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
}
@ -252,7 +252,7 @@ TEST_P(CppGradients, TestAddGrad) {
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[0]->Release();
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
@ -260,7 +260,7 @@ TEST_P(CppGradients, TestAddGrad) {
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[1]->Release();
outputs[1]->Unref();
TF_DeleteTensor(result_tensor);
}
@ -270,7 +270,7 @@ TEST_P(CppGradients, TestAddGrad) {
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*tfrt*/ ::testing::Values(true, false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(

View File

@ -50,6 +50,14 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
// Return a copy of the handle.
virtual ImmediateExecutionTensorHandle* Copy() = 0;
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage its own
// lifetime through ref counting. Thus this must be allocated on the heap and
// clients MUST call Release() in order to destroy an instance of this class.
virtual void Release() = 0;
// For LLVM style RTTI.
static bool classof(const AbstractTensorHandle* ptr) {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;

View File

@ -448,6 +448,41 @@ GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client)
stat_cache_max_age, stat_cache_max_entries);
}
GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client, bool compose,
uint64_t block_size, size_t max_bytes, uint64_t max_staleness,
uint64_t stat_cache_max_age, size_t stat_cache_max_entries)
: gcs_client(gcs_client),
compose(compose),
block_cache_lock(),
block_size(block_size) {
file_block_cache = std::make_unique<RamFileBlockCache>(
block_size, max_bytes, max_staleness,
[this](const std::string& filename, size_t offset, size_t buffer_size,
char* buffer, TF_Status* status) {
return LoadBufferFromGCS(filename, offset, buffer_size, buffer, this,
status);
});
stat_cache = std::make_unique<ExpiringLRUCache<GcsFileStat>>(
stat_cache_max_age, stat_cache_max_entries);
}
void InitTest(TF_Filesystem* filesystem, bool compose, uint64_t block_size,
size_t max_bytes, uint64_t max_staleness,
uint64_t stat_cache_max_age, size_t stat_cache_max_entries,
TF_Status* status) {
google::cloud::StatusOr<gcs::Client> client =
gcs::Client::CreateDefaultClient();
if (!client) {
TF_SetStatusFromGCSStatus(client.status(), status);
return;
}
filesystem->plugin_filesystem =
new GCSFile(std::move(client.value()), compose, block_size, max_bytes,
max_staleness, stat_cache_max_age, stat_cache_max_entries);
TF_SetStatus(status, TF_OK, "");
}
void Init(TF_Filesystem* filesystem, TF_Status* status) {
google::cloud::StatusOr<gcs::Client> client =
gcs::Client::CreateDefaultClient();

View File

@ -62,8 +62,19 @@ typedef struct GCSFile {
// of block_size.
std::unique_ptr<ExpiringLRUCache<GcsFileStat>> stat_cache;
GCSFile(google::cloud::storage::Client&& gcs_client);
// This constructor is used for testing purposes only.
GCSFile(google::cloud::storage::Client&& gcs_client, bool compose,
uint64_t block_size, size_t max_bytes, uint64_t max_staleness,
uint64_t stat_cache_max_age, size_t stat_cache_max_entries);
} GCSFile;
// This function is used to initialize a filesystem without the need to set
// environment variables manually.
void InitTest(TF_Filesystem* filesystem, bool compose, uint64_t block_size,
size_t max_bytes, uint64_t max_staleness,
uint64_t stat_cache_max_age, size_t stat_cache_max_entries,
TF_Status* status);
void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,

View File

@ -66,6 +66,9 @@ static std::string* GetTmpDir() {
namespace tensorflow {
namespace {
// TODO(vnvo2409): Refactor `gcs_filesystem_test` to remove unnecessary tests
// after porting all tests from
// `//tensorflow/core/platform/cloud:gcs_file_system_test`.
class GCSFilesystemTest : public ::testing::Test {
public:
void SetUp() override {
@ -74,13 +77,14 @@ class GCSFilesystemTest : public ::testing::Test {
::testing::UnitTest::GetInstance()->current_test_info()->name());
status_ = TF_NewStatus();
filesystem_ = new TF_Filesystem;
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
filesystem_->plugin_filesystem = nullptr;
// Because different tests require different filesystem setups, we
// initialize the filesystem in each test case.
}
void TearDown() override {
TF_DeleteStatus(status_);
tf_gcs_filesystem::Cleanup(filesystem_);
if (filesystem_->plugin_filesystem != nullptr)
tf_gcs_filesystem::Cleanup(filesystem_);
delete filesystem_;
}
@ -117,6 +121,21 @@ class GCSFilesystemTest : public ::testing::Test {
}
}
::testing::AssertionResult InsertObject(const std::string& path,
const std::string& content,
gcs::Client* gcs_client,
TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK)
return ::testing::AssertionFailure() << TF_Message(status);
auto metadata = gcs_client->InsertObject(bucket, object, content);
if (metadata)
return ::testing::AssertionSuccess();
else
return ::testing::AssertionFailure() << metadata.status().message();
}
::testing::AssertionResult CompareSubString(int64_t offset, size_t length,
absl::string_view result,
size_t read) {
@ -172,6 +191,9 @@ TEST_F(GCSFilesystemTest, ParseGCSPath) {
}
TEST_F(GCSFilesystemTest, RandomAccessFile) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
std::string filepath = GetURIForPath("a_file");
TF_RandomAccessFile* file = new TF_RandomAccessFile;
tf_gcs_filesystem::NewRandomAccessFile(filesystem_, filepath.c_str(), file,
@ -208,6 +230,9 @@ TEST_F(GCSFilesystemTest, RandomAccessFile) {
}
TEST_F(GCSFilesystemTest, WritableFile) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
std::string filepath = GetURIForPath("a_file");
TF_WritableFile* file = new TF_WritableFile;
tf_gcs_filesystem::NewWritableFile(filesystem_, filepath.c_str(), file,
@ -273,6 +298,9 @@ TEST_F(GCSFilesystemTest, WritableFile) {
}
TEST_F(GCSFilesystemTest, ReadOnlyMemoryRegion) {
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
std::string path = GetURIForPath("a_file");
auto gcs_file =
static_cast<tf_gcs_filesystem::GCSFile*>(filesystem_->plugin_filesystem);
@ -298,6 +326,131 @@ TEST_F(GCSFilesystemTest, ReadOnlyMemoryRegion) {
delete region;
}
// The tests below are ported from
// `//tensorflow/core/platform/cloud:gcs_file_system_test`
TEST_F(GCSFilesystemTest, NewRandomAccessFile_NoBlockCache) {
tf_gcs_filesystem::InitTest(filesystem_, false, 0, 0, 0, 0, 0, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
std::string path = GetURIForPath("a_file");
auto gcs_file =
static_cast<tf_gcs_filesystem::GCSFile*>(filesystem_->plugin_filesystem);
ASSERT_TRUE(InsertObject(path, "0123456789", &gcs_file->gcs_client, status_));
TF_RandomAccessFile* file = new TF_RandomAccessFile;
tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), file,
status_);
ASSERT_TF_OK(status_);
std::string result;
result.resize(6);
int64_t read = tf_random_access_file::Read(file, 0, 6, &result[0], status_);
ASSERT_EQ(read, 6) << "Read: " << read << "\n";
ASSERT_TF_OK(status_);
ASSERT_EQ(result, "012345") << "Result: " << result << "\n";
read = tf_random_access_file::Read(file, 6, 6, &result[0], status_);
ASSERT_EQ(read, 4) << "Read: " << read << "\n";
ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_);
result.resize(read);
ASSERT_EQ(result, "6789") << "Result: " << result << "\n";
}
TEST_F(GCSFilesystemTest, NewRandomAccessFile_Buffered) {
tf_gcs_filesystem::InitTest(filesystem_, false, 10, 0, 0, 0, 0, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
std::string path = GetURIForPath("a_file");
auto gcs_file =
static_cast<tf_gcs_filesystem::GCSFile*>(filesystem_->plugin_filesystem);
ASSERT_TRUE(InsertObject(path, "0123456789", &gcs_file->gcs_client, status_));
TF_RandomAccessFile* file = new TF_RandomAccessFile;
tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), file,
status_);
ASSERT_TF_OK(status_);
std::string result;
result.resize(6);
int64_t read = tf_random_access_file::Read(file, 0, 6, &result[0], status_);
ASSERT_EQ(read, 6) << "Read: " << read << "\n";
ASSERT_TF_OK(status_);
ASSERT_EQ(result, "012345") << "Result: " << result << "\n";
read = tf_random_access_file::Read(file, 6, 6, &result[0], status_);
ASSERT_EQ(read, 4) << "Read: " << read << "\n";
ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_);
result.resize(read);
ASSERT_EQ(result, "6789") << "Result: " << result << "\n";
}
TEST_F(GCSFilesystemTest, NewRandomAccessFile_Buffered_ReadAtEOF) {
tf_gcs_filesystem::InitTest(filesystem_, false, 10, 0, 0, 0, 0, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
std::string path = GetURIForPath("a_file");
auto gcs_file =
static_cast<tf_gcs_filesystem::GCSFile*>(filesystem_->plugin_filesystem);
ASSERT_TRUE(InsertObject(path, "0123456789", &gcs_file->gcs_client, status_));
TF_RandomAccessFile* file = new TF_RandomAccessFile;
tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), file,
status_);
ASSERT_TF_OK(status_);
std::string result;
result.resize(10);
int64_t read = tf_random_access_file::Read(file, 0, result.length(),
&result[0], status_);
ASSERT_EQ(read, 10) << "Read: " << read << "\n";
ASSERT_TF_OK(status_);
ASSERT_EQ(result, "0123456789") << "Result: " << result << "\n";
read = tf_random_access_file::Read(file, result.length(), result.length(),
&result[0], status_);
ASSERT_EQ(read, 0) << "Read: " << read << "\n";
ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_);
result.resize(read);
ASSERT_EQ(result, "") << "Result: " << result << "\n";
}
TEST_F(GCSFilesystemTest, NewRandomAccessFile_Buffered_CachedOutOfRange) {
tf_gcs_filesystem::InitTest(filesystem_, false, 10, 0, 0, 0, 0, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
std::string path = GetURIForPath("a_file");
auto gcs_file =
static_cast<tf_gcs_filesystem::GCSFile*>(filesystem_->plugin_filesystem);
ASSERT_TRUE(InsertObject(path, "012345678", &gcs_file->gcs_client, status_));
TF_RandomAccessFile* file = new TF_RandomAccessFile;
tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), file,
status_);
ASSERT_TF_OK(status_);
std::string result;
result.resize(5);
int64_t read = tf_random_access_file::Read(file, 0, result.length(),
&result[0], status_);
ASSERT_EQ(read, 5) << "Read: " << read << "\n";
ASSERT_TF_OK(status_);
ASSERT_EQ(result, "01234") << "Result: " << result << "\n";
read = tf_random_access_file::Read(file, 4, result.length(), &result[0],
status_);
ASSERT_EQ(read, 5) << "Read: " << read << "\n";
ASSERT_TF_OK(status_);
result.resize(read);
ASSERT_EQ(result, "45678") << "Result: " << result << "\n";
read = tf_random_access_file::Read(file, 5, result.length(), &result[0],
status_);
ASSERT_EQ(read, 4) << "Read: " << read << "\n";
ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_);
result.resize(read);
ASSERT_EQ(result, "5678") << "Result: " << result << "\n";
}
} // namespace
} // namespace tensorflow

View File

@ -24,6 +24,7 @@ limitations under the License.
#include <aws/s3/model/CompletedPart.h>
#include <aws/s3/model/CopyObjectRequest.h>
#include <aws/s3/model/CreateMultipartUploadRequest.h>
#include <aws/s3/model/DeleteObjectRequest.h>
#include <aws/s3/model/GetObjectRequest.h>
#include <aws/s3/model/HeadBucketRequest.h>
#include <aws/s3/model/HeadObjectRequest.h>
@ -44,6 +45,7 @@ limitations under the License.
constexpr char kS3FileSystemAllocationTag[] = "S3FileSystemAllocation";
constexpr char kS3ClientAllocationTag[] = "S3ClientAllocation";
constexpr int64_t kS3TimeoutMsec = 300000; // 5 min
constexpr int kS3GetChildrenMaxKeys = 100;
constexpr char kExecutorTag[] = "TransferManagerExecutorAllocation";
constexpr int kExecutorPoolSize = 25;
@ -961,7 +963,157 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
s3_file, status);
}
// TODO(vnvo2409): Implement later
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
Aws::String bucket, object;
ParseS3Path(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
GetS3Client(s3_file);
Aws::S3::Model::DeleteObjectRequest delete_object_request;
delete_object_request.WithBucket(bucket).WithKey(object);
auto delete_object_outcome =
s3_file->s3_client->DeleteObject(delete_object_request);
if (!delete_object_outcome.IsSuccess())
TF_SetStatusFromAWSError(delete_object_outcome.GetError(), status);
else
TF_SetStatus(status, TF_OK, "");
}
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
Aws::String bucket, object;
ParseS3Path(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
GetS3Client(s3_file);
if (object.empty()) {
Aws::S3::Model::HeadBucketRequest head_bucket_request;
head_bucket_request.WithBucket(bucket);
auto head_bucket_outcome =
s3_file->s3_client->HeadBucket(head_bucket_request);
if (!head_bucket_outcome.IsSuccess())
TF_SetStatusFromAWSError(head_bucket_outcome.GetError(), status);
else
TF_SetStatus(status, TF_OK, "");
return;
}
Aws::String dir_path = path;
if (dir_path.back() != '/') dir_path.push_back('/');
PathExists(filesystem, dir_path.c_str(), status);
if (TF_GetCode(status) == TF_OK) {
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile * file)> file(
new TF_WritableFile, [](TF_WritableFile* file) {
if (file != nullptr) {
if (file->plugin_file != nullptr) tf_writable_file::Cleanup(file);
delete file;
}
});
file->plugin_file = nullptr;
NewWritableFile(filesystem, dir_path.c_str(), file.get(), status);
if (TF_GetCode(status) != TF_OK) return;
tf_writable_file::Close(file.get(), status);
if (TF_GetCode(status) != TF_OK) return;
}
TF_SetStatus(status, TF_OK, "");
}
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
Aws::String bucket, object;
ParseS3Path(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
GetS3Client(s3_file);
if (object.back() != '/') object.push_back('/');
Aws::S3::Model::ListObjectsRequest list_objects_request;
list_objects_request.WithBucket(bucket).WithPrefix(object).WithMaxKeys(2);
list_objects_request.SetResponseStreamFactory(
[]() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
auto list_objects_outcome =
s3_file->s3_client->ListObjects(list_objects_request);
if (list_objects_outcome.IsSuccess()) {
auto contents = list_objects_outcome.GetResult().GetContents();
if (contents.size() > 1 ||
(contents.size() == 1 && contents[0].GetKey() != object)) {
TF_SetStatus(status, TF_UNKNOWN,
"Cannot delete a non-empty directory. "
"This operation will be retried in case this "
"is due to S3's eventual consistency.");
}
if (contents.size() == 1 && contents[0].GetKey() == object) {
Aws::String dir_path = path;
if (dir_path.back() != '/') dir_path.push_back('/');
DeleteFile(filesystem, dir_path.c_str(), status);
}
} else {
TF_SetStatusFromAWSError(list_objects_outcome.GetError(), status);
}
}
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status) {
Aws::String bucket, prefix;
ParseS3Path(path, true, &bucket, &prefix, status);
if (TF_GetCode(status) != TF_OK) return -1;
if (!prefix.empty() && prefix.back() != '/') prefix.push_back('/');
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
GetS3Client(s3_file);
Aws::S3::Model::ListObjectsRequest list_objects_request;
list_objects_request.WithBucket(bucket)
.WithPrefix(prefix)
.WithMaxKeys(kS3GetChildrenMaxKeys)
.WithDelimiter("/");
list_objects_request.SetResponseStreamFactory(
[]() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
Aws::S3::Model::ListObjectsResult list_objects_result;
std::vector<Aws::String> result;
do {
auto list_objects_outcome =
s3_file->s3_client->ListObjects(list_objects_request);
if (!list_objects_outcome.IsSuccess()) {
TF_SetStatusFromAWSError(list_objects_outcome.GetError(), status);
return -1;
}
list_objects_result = list_objects_outcome.GetResult();
for (const auto& object : list_objects_result.GetCommonPrefixes()) {
Aws::String s = object.GetPrefix();
s.erase(s.length() - 1);
Aws::String entry = s.substr(prefix.length());
if (entry.length() > 0) {
result.push_back(entry);
}
}
for (const auto& object : list_objects_result.GetContents()) {
Aws::String s = object.GetKey();
Aws::String entry = s.substr(prefix.length());
if (entry.length() > 0) {
result.push_back(entry);
}
}
list_objects_request.SetMarker(list_objects_result.GetNextMarker());
} while (list_objects_result.GetIsTruncated());
int num_entries = result.size();
*entries = static_cast<char**>(
plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
for (int i = 0; i < num_entries; i++)
(*entries)[i] = strdup(result[i].c_str());
TF_SetStatus(status, TF_OK, "");
return num_entries;
}
static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
return strdup(uri);
}
} // namespace tf_s3_filesystem
@ -969,6 +1121,48 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
ops->random_access_file_ops->read = tf_random_access_file::Read;
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->writable_file_ops->append = tf_writable_file::Append;
ops->writable_file_ops->tell = tf_writable_file::Tell;
ops->writable_file_ops->flush = tf_writable_file::Flush;
ops->writable_file_ops->sync = tf_writable_file::Sync;
ops->writable_file_ops->close = tf_writable_file::Close;
ops->read_only_memory_region_ops = static_cast<TF_ReadOnlyMemoryRegionOps*>(
plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
ops->read_only_memory_region_ops->cleanup =
tf_read_only_memory_region::Cleanup;
ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data;
ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_s3_filesystem::Init;
ops->filesystem_ops->cleanup = tf_s3_filesystem::Cleanup;
ops->filesystem_ops->new_random_access_file =
tf_s3_filesystem::NewRandomAccessFile;
ops->filesystem_ops->new_writable_file = tf_s3_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_s3_filesystem::NewAppendableFile;
ops->filesystem_ops->new_read_only_memory_region_from_file =
tf_s3_filesystem::NewReadOnlyMemoryRegionFromFile;
ops->filesystem_ops->create_dir = tf_s3_filesystem::CreateDir;
ops->filesystem_ops->delete_file = tf_s3_filesystem::DeleteFile;
ops->filesystem_ops->delete_dir = tf_s3_filesystem::DeleteDir;
ops->filesystem_ops->copy_file = tf_s3_filesystem::CopyFile;
ops->filesystem_ops->path_exists = tf_s3_filesystem::PathExists;
ops->filesystem_ops->get_file_size = tf_s3_filesystem::GetFileSize;
ops->filesystem_ops->stat = tf_s3_filesystem::Stat;
ops->filesystem_ops->get_children = tf_s3_filesystem::GetChildren;
ops->filesystem_ops->translate_name = tf_s3_filesystem::TranslateName;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {

View File

@ -104,6 +104,12 @@ TF_ShapeHandle* TF_NewShapeHandle() {
return reinterpret_cast<TF_ShapeHandle*>(new ShapeHandle);
}
TF_ShapeHandle* TF_ShapeInferenceContextScalar(TF_ShapeInferenceContext* ctx) {
auto* handle = new ShapeHandle;
*handle = reinterpret_cast<InferenceContext*>(ctx)->Scalar();
return reinterpret_cast<TF_ShapeHandle*>(handle);
}
TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
TF_ShapeInferenceContext* ctx, size_t size) {
auto* handle = new ShapeHandle;

View File

@ -280,6 +280,11 @@ extern void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx,
int i, TF_ShapeHandle* handle,
TF_Status* status);
// Returns a newly-allocated scalar shape handle. The returned handle should
// be freed with TF_DeleteShapeHandle.
TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextScalar(
TF_ShapeInferenceContext* ctx);
// Returns a newly-allocated shape handle representing a vector of the given
// size. The returned handle should be freed with TF_DeleteShapeHandle.
TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(

View File

@ -316,5 +316,16 @@ TEST(OpsTest, ShapeInferenceSubshape) {
TF_DeleteShapeHandle(handle);
}
TEST(OpsTest, ShapeInferenceScalarShape) {
NodeDef def;
shape_inference::InferenceContext c(0, def, MakeOpDef(0, 0), {S({})}, {}, {},
{});
TF_ShapeHandle* TF_scalar_shape = TF_ShapeInferenceContextScalar(C_CTX(&c));
shape_inference::ShapeHandle* scalar_shape =
reinterpret_cast<shape_inference::ShapeHandle*>(TF_scalar_shape);
ASSERT_EQ("[]", c.DebugString(*scalar_shape));
TF_DeleteShapeHandle(TF_scalar_shape);
}
} // namespace
} // namespace tensorflow

View File

@ -52,7 +52,7 @@ def _run_lit_test(name, data, size, tags, driver, features, exec_properties):
native.py_test(
name = name,
srcs = ["@llvm-project//llvm:lit"],
tags = tags + ["no_windows"],
tags = tags + ["no_pip", "no_windows"],
args = [
"tensorflow/compiler/mlir/" + paths.basename(data[-1]) + " --config-prefix=runlit -v",
] + features,

View File

@ -140,7 +140,7 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
Here the effective workload shape roughly represents the maximum
parallelism can be used during the codegen stage. It's used to check
the shape-compatibility of the operation. During fusion, we only
try to fuse shape-compatible ops for performace.
try to fuse shape-compatible ops for performance.
For example, the effective workload shape of an elementwise op is its
output shape, while the effective workload shape of a reduction op may
be its operand shape.

View File

@ -640,25 +640,25 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
}
};
class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> {
template <typename OpTy, bool isLHLO = true>
class IotaConverter : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern<lmhlo::IotaOp>::OpConversionPattern;
using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
lmhlo::IotaOp iotaOp, ArrayRef<Value> args,
OpTy iotaOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto resultMemrefType =
iotaOp.getOperand().getType().dyn_cast<MemRefType>();
if (!resultMemrefType) return failure();
ShapedType resultShapedType = getHloOpResultType<isLHLO>(iotaOp);
if (!resultShapedType) return failure();
auto resultElementType = resultMemrefType.getElementType();
auto resultElementType = resultShapedType.getElementType();
if (!resultElementType.isSignlessIntOrFloat()) return failure();
// Construct the indexing maps needed for linalg.generic ops.
unsigned nloops = resultMemrefType.getRank();
unsigned nloops = resultShapedType.getRank();
rewriter.create<linalg::IndexedGenericOp>(
iotaOp.getLoc(), ArrayRef<Type>{}, args,
auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
iotaOp.getLoc(), isLHLO ? ArrayRef<Type>{} : resultShapedType, args,
0, // args_in
1, // args_out
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
@ -669,14 +669,16 @@ class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> {
nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()],
nestedBuilder.getIntegerType(
resultElementType.getIntOrFloatBitWidth()));
if (resultElementType.isa<FloatType>()) {
if (resultElementType.template isa<FloatType>()) {
castOp = nestedBuilder.create<SIToFPOp>(nestedLoc, castOp,
resultElementType);
}
nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
});
rewriter.replaceOp(iotaOp, llvm::None);
if (isLHLO)
rewriter.replaceOp(iotaOp, llvm::None);
else
rewriter.replaceOp(iotaOp, linalgOp.output_tensors());
return success();
}
};
@ -768,7 +770,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
ConstConverter,
ConvToLinalgConverter,
IotaConverter,
IotaConverter<lmhlo::IotaOp>,
LhloBroadcastInDimConverter,
PointwiseToLinalgConverter<lmhlo::AbsOp>,
PointwiseToLinalgConverter<lmhlo::AddOp>,
@ -870,36 +872,37 @@ namespace mhlo {
void populateHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
patterns->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
HloBroadcastInDimConverter,
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
PointwiseToLinalgConverter<mhlo::AddOp, false>,
PointwiseToLinalgConverter<mhlo::AndOp, false>,
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
PointwiseToLinalgConverter<mhlo::CopyOp, false>,
PointwiseToLinalgConverter<mhlo::CosOp, false>,
PointwiseToLinalgConverter<mhlo::DivOp, false>,
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
PointwiseToLinalgConverter<mhlo::LogOp, false>,
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
PointwiseToLinalgConverter<mhlo::MinOp, false>,
PointwiseToLinalgConverter<mhlo::MulOp, false>,
PointwiseToLinalgConverter<mhlo::NegOp, false>,
PointwiseToLinalgConverter<mhlo::RealOp, false>,
PointwiseToLinalgConverter<mhlo::RemOp, false>,
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SelectOp, false>,
PointwiseToLinalgConverter<mhlo::SinOp, false>,
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SubOp, false>,
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
ReshapeOpConverter<mhlo::ReshapeOp, false>,
ReverseConverter<mhlo::ReverseOp, false>,
TransposeConverter<mhlo::TransposeOp, false>>(context);
patterns
->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>,
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
PointwiseToLinalgConverter<mhlo::AddOp, false>,
PointwiseToLinalgConverter<mhlo::AndOp, false>,
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
PointwiseToLinalgConverter<mhlo::CopyOp, false>,
PointwiseToLinalgConverter<mhlo::CosOp, false>,
PointwiseToLinalgConverter<mhlo::DivOp, false>,
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
PointwiseToLinalgConverter<mhlo::LogOp, false>,
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
PointwiseToLinalgConverter<mhlo::MinOp, false>,
PointwiseToLinalgConverter<mhlo::MulOp, false>,
PointwiseToLinalgConverter<mhlo::NegOp, false>,
PointwiseToLinalgConverter<mhlo::RealOp, false>,
PointwiseToLinalgConverter<mhlo::RemOp, false>,
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SelectOp, false>,
PointwiseToLinalgConverter<mhlo::SinOp, false>,
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SubOp, false>,
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
ReshapeOpConverter<mhlo::ReshapeOp, false>,
ReverseConverter<mhlo::ReverseOp, false>,
TransposeConverter<mhlo::TransposeOp, false>>(context);
}
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {

View File

@ -89,12 +89,10 @@ def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs),
// Absolute value is evaluated as:
// result = sqrt(val.real * val.real + val.imag * val.imag)
def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val),
(HLO_ComplexOp
(HLO_SqrtOp
(HLO_AddOp
(HLO_MulOp (HLO_RealOp:$real $val), $real),
(HLO_MulOp (HLO_ImagOp:$imag $val), $imag))),
(HLO_ConstOp (ConstantSplat<"0"> $real)))>;
(HLO_MulOp (HLO_ImagOp:$imag $val), $imag)))>;
// Exponential can be lowered to an exponential on the real component and a
// sum of sinusoids of the imaginary component, which equates to a normal

View File

@ -557,3 +557,18 @@ func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
}
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// -----
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @iota
func @iota() -> tensor<7x10xf32> {
%result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xf32>)
return %result : tensor<7x10xf32>
}
// CHECK: linalg.indexed_generic
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index):
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32

View File

@ -182,11 +182,10 @@ func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg1
// CHECK-DAG: [[VAL2:%.+]] = mhlo.add [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = "mhlo.sqrt"([[VAL2]])
%1 = "mhlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%1 = "mhlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL3]]
return %2 : tensor<2xf32>
return %1 : tensor<2xf32>
}
// CHECK-LABEL: @exp

View File

@ -237,28 +237,6 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "constant_utils",
srcs = [
"utils/constant_utils.cc",
],
hdrs = [
"utils/constant_utils.h",
],
copts = ["-std=c++14"],
deps = [
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:status",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
cc_library(
name = "lstm_utils",
srcs = [
@ -368,20 +346,23 @@ cc_library(
"transforms/passes.h",
],
deps = [
":constant_utils",
":lstm_utils",
":stateful_ops_utils",
":tensorflow_lite",
":tftext_utils",
":validators",
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
"//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",

View File

@ -943,7 +943,10 @@ def TFL_BatchMatMulOp : TFL_Op<"batch_matmul", [
NoSideEffect,
TFL_OperandHasAtleastRank<0, 2>,
TFL_OperandHasAtleastRank<1, 2>,
SameOperandsAndResultElementType]> {
PredOpTrait<"x and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
PredOpTrait<"y and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
let summary = "Batch Matrix Multiply Operator";
@ -4345,55 +4348,4 @@ def TFL_CustomTfOp : Op<TFL_Dialect, "custom_tf", [
let regions = (region SizedRegion<1>:$body);
}
def TFL_BroadcastToOp : TFL_Op<"broadcast_to", [
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_OperandHasRankAtMost<0, 8>,
TFL_OperandHasRank<1, 1>,
PredOpTrait<"output dimension count must be at most 8",
Or<[TFL_OperandIsUnrankedPred<1>,
TFL_OperandDimIsAtMost<1, 0, 8>]>>,
NoSideEffect]> {
let summary = "Broadcast an array for a compatible shape.";
let description = [{
Broadcasting is the process of making arrays to have compatible shapes
for arithmetic operations. Two shapes are compatible if for each
dimension pair they are either equal or one of them is one. When trying
to broadcast a Tensor to a shape, it starts with the trailing dimensions,
and works its way forward.
For example,
>>> x = tf.constant([1, 2, 3])
>>> y = tf.broadcast_to(x, [3, 3])
>>> print(y)
tf.Tensor(
[[1 2 3]
[1 2 3]
[1 2 3]], shape=(3, 3), dtype=int32)
In the above example, the input Tensor with the shape of `[1, 3]`
is broadcasted to output Tensor with shape of `[3, 3]`.
When doing broadcasted operations such as multiplying a tensor
by a scalar, broadcasting (usually) confers some time or space
benefit, as the broadcasted tensor is never materialized.
However, `broadcast_to` does not carry with it any such benefits.
The newly-created tensor takes the full memory of the broadcasted
shape. (In a graph context, `broadcast_to` might be fused to
subsequent operation and then be optimized away, however.)
}];
let arguments = (ins
TFL_TensorOf<[F32, I32, I1, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex<F<32>>]>:$input,
TFL_I32OrI64Tensor:$shape
);
let results = (outs
TFL_TensorOf<[F32, I32, I1, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex<F<32>>]>:$output
);
}
#endif // TFL_OPS

View File

@ -5,6 +5,8 @@ func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> tensor<3
return %0: tensor<3x3xbf16>
// CHECK-LABEL: broadcast_to_bf16
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16>
// CHECK: return [[BCT]] : tensor<3x3xbf16>
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<bf16>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi64>, tensor<bf16>) -> tensor<3x3xbf16>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
// CHECK: return [[MUL]] : tensor<3x3xbf16>
}

View File

@ -1487,8 +1487,10 @@ func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3
return %0: tensor<3x3xf32>
// CHECK-LABEL: broadcast_to_f32
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
// CHECK: return [[BCT]] : tensor<3x3xf32>
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<f32>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<f32>) -> tensor<3x3xf32>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK: return [[MUL]] : tensor<3x3xf32>
}
func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
@ -1496,8 +1498,10 @@ func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3
return %0: tensor<3x3xi32>
// CHECK-LABEL: broadcast_to_i32
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
// CHECK: return [[BCT]] : tensor<3x3xi32>
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<3x3xi32>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
// CHECK: return [[MUL]] : tensor<3x3xi32>
}
func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> {

View File

@ -1248,6 +1248,13 @@ func @testSpaceToBatchND(%arg0 : tensor<1x4x4x3xf32>, %arg1 : tensor<2xi32>, %ar
// -----
func @testBatchMatmulQuant(%arg0 : tensor<1x4x384x32x!quant.uniform<i8:f32, 0.06:-2>>, %arg1 : tensor<1x4x384x32x!quant.uniform<i8:f32, 0.11:-16>>) -> tensor<1x4x384x384x!quant.uniform<i8:f32, 1.02:-73>> {
// CHECK: "tfl.batch_matmul"(%arg0, %arg1)
%0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor<1x4x384x32x!quant.uniform<i8:f32, 0.06:-2>>, tensor<1x4x384x32x!quant.uniform<i8:f32, 0.11:-16>>) -> tensor<1x4x384x384x!quant.uniform<i8:f32, 1.02:-73>>
return %0 : tensor<1x4x384x384x!quant.uniform<i8:f32, 1.02:-73>>
}
// -----
func @testConcat(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xi32>) -> tensor<2x2xi32> {
// CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"}
%0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>
@ -2310,21 +2317,3 @@ func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<i32> {
}) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>)
return %0#0 : tensor<i32>
}
// -----
// CHECK-LABEL: testBroadcastToWithI32ShapeTensor
func @testBroadcastToWithI32ShapeTensor(tensor<?x?x?x?x?x?xf32>, tensor<8xi32>) -> tensor<?x?x?x?x?x?x?x?xf32> {
^bb0(%arg0: tensor<?x?x?x?x?x?xf32>, %arg1: tensor<8xi32>):
// CHECK: "tfl.broadcast_to"(%arg0, %arg1)
%0 = "tfl.broadcast_to"(%arg0, %arg1): (tensor<?x?x?x?x?x?xf32>, tensor<8xi32>) -> tensor<?x?x?x?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?x?x?x?xf32>
}
// CHECK-LABEL: testBroadcastToWithI64ShapeTensor
func @testBroadcastToWithI64ShapeTensor(tensor<?x?x?x?x?x?xf32>, tensor<8xi64>) -> tensor<?x?x?x?x?x?x?x?xf32> {
^bb0(%arg0: tensor<?x?x?x?x?x?xf32>, %arg1: tensor<8xi64>):
// CHECK: "tfl.broadcast_to"(%arg0, %arg1)
%0 = "tfl.broadcast_to"(%arg0, %arg1): (tensor<?x?x?x?x?x?xf32>, tensor<8xi64>) -> tensor<?x?x?x?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?x?x?x?xf32>
}

View File

@ -400,6 +400,32 @@ func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<
// FOLD: return %[[fc]]
}
// CHECK-LABEL: @FuseFullyConnectedReshapeAddConstWithActivation
// FOLD-LABEL: @FuseFullyConnectedReshapeAddConstWithActivation
func @FuseFullyConnectedReshapeAddConstWithActivation(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%cst = constant dense<3.0> : tensor<40x40xf32>
%cst2 = constant dense<2.0> : tensor<40xf32>
%shape1 = constant dense<[1, 40, 40]> : tensor<3xi32>
%shape2 = constant dense<[40, 40]> : tensor<2xi32>
%0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>)
%1 = "tfl.reshape"(%0, %shape1) : (tensor<40x40xf32>, tensor<3xi32>) -> tensor<1x40x40xf32>
%2 = "tfl.add"(%1, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x40x40xf32>, tensor<40xf32>) -> tensor<1x40x40xf32>
%3 = "tfl.reshape"(%2, %shape2) : (tensor<1x40x40xf32>, tensor<2xi32>) -> tensor<40x40xf32>
return %3 : tensor<40x40xf32>
// CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%[[fc]]
// CHECK: %[[rs2:.*]] = "tfl.reshape"(%[[rs1]]
// CHECK: return %[[rs2]]
// FOLD: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
// FOLD: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
// FOLD: return %[[fc]]
}
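The test above exercises bias folding through reshapes: adding a constant after the fully-connected output, even with a fused activation on the add, is equivalent to folding that constant into the fully-connected bias and moving the activation onto the fully-connected op, which is why the folded bias becomes 5.0 with RELU6. A hedged NumPy sketch of the identity (the `relu6` helper and the random inputs are illustrative):

```python
import numpy as np

def relu6(t):
    return np.clip(t, 0.0, 6.0)

rng = np.random.default_rng(0)
x = rng.standard_normal((40, 37)).astype(np.float32)
w = rng.standard_normal((40, 37)).astype(np.float32)

fc_bias = np.full((40, 40), 3.0, dtype=np.float32)   # bias on fully_connected
add_bias = np.full((40,), 2.0, dtype=np.float32)     # constant on tfl.add

# Original graph: fully_connected (NONE) -> reshape -> add (RELU6) -> reshape.
y0 = x @ w.T + fc_bias
y0 = relu6(y0.reshape(1, 40, 40) + add_bias).reshape(40, 40)

# Fused graph: fully_connected with bias 5.0 and RELU6 fused.
y1 = relu6(x @ w.T + (fc_bias + add_bias))

assert np.allclose(y0, y1)
```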
// CHECK-LABEL: @NotReorderReshapeAddIfNotBroadcastableAfter
func @NotReorderReshapeAddIfNotBroadcastableAfter(%arg0: tensor<40x10x4xf32>) -> tensor<40x40xf32> {
%cst = constant dense<2.0> : tensor<40xf32>

View File

@ -1,5 +1,7 @@
// RUN: tf-opt -tfl-prepare-tf %s | FileCheck %s
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} {
func @conv(tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<256x3x32x32xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x16x30x30xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
^bb0(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>, %arg2: tensor<256x3x32x32xf32>) :
// OK
@ -579,69 +581,18 @@ func @MatrixSetDiagV3Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) ->
// CHECK: return %[[RES]]
}
func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
return %0: tensor<3x3xf32>
// CHECK-LABEL: broadcast_to_f32_low_dim
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK: return [[MUL]] : tensor<3x3xf32>
// CHECK-LABEL: xla_conv
func @xla_conv(%arg0: tensor<4x8x8x16xf32>) -> tensor<4x8x8x16xf32> {
%0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<3x3x16x16xf32>} : () -> tensor<3x3x16x16xf32> loc("Const_1")
%1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> loc("XlaConv/feature_group_count")
%2 = "tf.Const"() {value = dense<1> : tensor<2x2xi32>} : () -> tensor<2x2xi32> loc("XlaConv/padding")
%3 = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32> loc("XlaConv/window_strides")
%4 = "tf.XlaConv"(%arg0, %0, %3, %2, %3, %3, %1) {device = "", dimension_numbers = "\18\02 \032\02\00\01@\03P\03Z\02\01\02b\02\01\02", precision_config = ""} : (tensor<4x8x8x16xf32>, tensor<3x3x16x16xf32>, tensor<2xi32>, tensor<2x2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<i32>) -> tensor<4x8x8x16xf32>
return %4 : tensor<4x8x8x16xf32>
// CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<16xf32>
// CHECK: %[[CST0:.*]] = constant dense<1.000000e+00> : tensor<16x3x3x16xf32>
// CHECK: %[[RES:.*]] = "tfl.conv_2d"(%arg0, %[[CST0]], %[[CST]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<4x8x8x16xf32>, tensor<16x3x3x16xf32>, tensor<16xf32>) -> tensor<4x8x8x16xf32>
// CHECK: return %[[RES]]
}
func @broadcast_to_i32_low_dim(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
return %0: tensor<3x3xi32>
// CHECK-LABEL: broadcast_to_i32_low_dim
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<3x3xi32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
// CHECK: return [[MUL]] : tensor<3x3xi32>
}
func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: tensor<*xi32>) -> tensor<3x3xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<*xi32>) -> tensor<3x3xf32>
return %0: tensor<3x3xf32>
// CHECK-LABEL: broadcast_to_low_dim_with_unknown_shape
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK: return [[MUL]] : tensor<3x3xf32>
}
func @broadcast_to_i32_low_dim_with_unknown_output(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<*xi32> {
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<*xi32>
return %0: tensor<*xi32>
// CHECK-LABEL: broadcast_to_i32_low_dim_with_unknown_output
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
// CHECK: [[FILL:%.*]] = "tf.Fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<*xi32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[FILL]]) : (tensor<3xi32>, tensor<*xi32>) -> tensor<*xi32>
// CHECK: return [[MUL]] : tensor<*xi32>
}
func @broadcast_to_high_dim_with_unknown_shape(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32>
return %0: tensor<7x8x1x2x3x4x5x6xf32>
// CHECK-LABEL: broadcast_to_high_dim_with_unknown_shape
// CHECK: [[BCT:%.*]] = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32>
// CHECK: return [[BCT]] : tensor<7x8x1x2x3x4x5x6xf32>
}
func @broadcast_to_high_dim_with_unknown_output(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<8xi32>) -> tensor<*xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<8xi32>) -> tensor<*xf32>
return %0: tensor<*xf32>
// CHECK-LABEL: broadcast_to_high_dim_with_unknown_output
// CHECK: [[BCT:%.*]] = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<8xi32>) -> tensor<*xf32>
// CHECK: return [[BCT]] : tensor<*xf32>
}
func @broadcast_to_with_unknown_shape_and_output(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<*xf32>
return %0: tensor<*xf32>
// CHECK-LABEL: broadcast_to_with_unknown_shape_and_output
// CHECK: "tf.BroadcastTo"(%arg0, %arg1)
}

View File

@ -180,6 +180,9 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
// control flow ops (IfOp, CaseOp).
pass_manager->addPass(mlir::createInlinerPass());
// This pass removes the asset file dependencies in hash table use cases.
pass_manager->addPass(mlir::TF::CreateInitTextFileToImportPass());
pass_manager->addPass(
mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
pass_manager->addPass(mlir::TFL::CreateOptimizePass());

View File

@ -109,9 +109,6 @@ def LegalizeArgMax : Pat<(TF_ArgMaxOp $input, $dim),
def LegalizeArgMin : Pat<(TF_ArgMinOp $input, $dim),
(TFL_ArgMinOp $input, $dim)>;
def LegalizeBroadcastTo : Pat<(TF_BroadcastToOp $input, $dim),
(TFL_BroadcastToOp $input, $dim)>;
def LegalizeCeil : Pat<(TF_CeilOp $arg), (TFL_CeilOp $arg)>;
def LegalizeCos : Pat<(TF_CosOp $arg), (TFL_CosOp $arg)>;

View File

@ -45,7 +45,6 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
@ -138,6 +137,7 @@ DECL_CONVERT_OP(StridedSlice);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(Reciprocal);
DECL_CONVERT_OP(RandomUniform);
DECL_CONVERT_OP(BroadcastTo);
#undef DECL_CONVERT_OP
@ -483,6 +483,89 @@ LogicalResult ConvertTFAssertOp::matchAndRewrite(
return success();
}
StatusOr<ConstantOp> CreateConstOpWithSingleValue(PatternRewriter* rewriter,
Location loc,
ShapedType shaped_type,
int value) {
Type element_type = shaped_type.getElementType();
ShapedType scalar_type = RankedTensorType::get({}, element_type);
Attribute attr;
switch (element_type.getKind()) {
case mlir::StandardTypes::F16: {
auto floatType = mlir::FloatType::getF16(element_type.getContext());
auto floatAttr =
mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
case mlir::StandardTypes::BF16: {
auto floatType = mlir::FloatType::getBF16(element_type.getContext());
auto floatAttr =
mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
case mlir::StandardTypes::F32: {
attr =
DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
break;
}
case mlir::StandardTypes::Complex: {
auto etype = element_type.cast<mlir::ComplexType>().getElementType();
if (etype.isF32()) {
auto dialect = etype.getContext()->getRegisteredDialect("tf");
tensorflow::TensorProto repr;
repr.set_dtype(tensorflow::DT_COMPLEX64);
tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
shape->set_unknown_rank(false);
shape->add_dim()->set_size(int64_t{1});
std::string content;
auto complex_value =
std::complex<float>(static_cast<float>(value), 0.0f);
content.assign(reinterpret_cast<const char*>(&complex_value),
sizeof(complex_value));
repr.set_tensor_content(content);
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
break;
}
return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type");
}
case mlir::StandardTypes::Integer: {
const auto& itype = element_type.cast<mlir::IntegerType>();
switch (itype.getWidth()) {
case 8:
attr = DenseElementsAttr::get<int8_t>(scalar_type,
static_cast<int8_t>(value));
break;
case 16:
attr = DenseElementsAttr::get<int16_t>(scalar_type,
static_cast<int16_t>(value));
break;
case 32:
attr = DenseElementsAttr::get<int32_t>(scalar_type,
static_cast<int32_t>(value));
break;
case 64:
attr = DenseElementsAttr::get<int64_t>(scalar_type,
static_cast<int64_t>(value));
break;
default:
return Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
}
break;
}
default:
return Status(tensorflow::error::INVALID_ARGUMENT, "Unsupported type");
}
return rewriter->create<ConstantOp>(loc, scalar_type, attr);
}
LogicalResult ConvertTFReciprocalOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_reciprocal_op = cast<TF::ReciprocalOp>(op);
@ -503,6 +586,31 @@ LogicalResult ConvertTFReciprocalOp::matchAndRewrite(
return success();
}
LogicalResult ConvertTFBroadcastToOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
auto element_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
auto output_type = tf_broadcast_to_op.output().getType();
auto status_or_const_op =
CreateConstOpWithSingleValue(&rewriter, op->getLoc(), element_type, 1);
if (!status_or_const_op.ok()) {
return failure();
}
auto tfl_fill_op = rewriter.create<TFL::FillOp>(
op->getLoc(), output_type, tf_broadcast_to_op.shape(),
status_or_const_op.ValueOrDie());
StringAttr fused_activation_function =
StringAttr::get("NONE", rewriter.getContext());
rewriter.replaceOpWithNewOp<TFL::MulOp>(
op, output_type, tf_broadcast_to_op.input(), tfl_fill_op,
fused_activation_function);
return success();
}
// Legalize unidirectional sequence lstm.
struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
explicit LegalizeUnidirectionalSequenceLstm(MLIRContext* context)
@ -643,7 +751,7 @@ void LegalizeTF::runOnFunction() {
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp,
ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp,
ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFReciprocalOp,
ConvertTFRandomUniformOp>(context);
ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context);
// Ophint python converter converted tf node pattern.
patterns.insert<LegalizeUnidirectionalSequenceLstm,

View File

@ -341,8 +341,8 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
// make sure $rhs is the tail shape of $lhs.
def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
(ConstantOp:$rhs $a), TFL_AF_None),
(TFL_ReshapeOp (BinaryOp $input, $rhs, TFL_AF_None), $shape),
(ConstantOp:$rhs $a), $act_fn),
(TFL_ReshapeOp (BinaryOp $input, $rhs, $act_fn), $shape),
// The broadcasting of "BinaryOp" only happens in the lower
// dimensions, and the higher dimensions are same, so we know the
// result and input of the "BinaryOp" in the source pattern have
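The pattern is relaxed here from requiring `TFL_AF_None` to accepting any fused activation `$act_fn`: because `$rhs` matches the tail shape on both sides of the reshape, applying the binary op and its activation before or after the reshape produces the same values. A small NumPy sketch of that invariance (the shapes and the ReLU activation are illustrative):

```python
import numpy as np

def relu(t):
    return np.maximum(t, 0.0)

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4) - 12.0
rhs = np.array([1.0, -2.0, 3.0, -4.0], dtype=np.float32)   # tail-shape constant

# Source pattern order: reshape first, then binary op with fused activation.
a = relu(x.reshape(6, 4) + rhs)

# Rewritten order: binary op with fused activation first, then reshape.
b = relu(x + rhs).reshape(6, 4)

assert np.allclose(a, b)
```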

View File

@ -41,7 +41,9 @@ limitations under the License.
#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
@ -49,16 +51,18 @@ limitations under the License.
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#define DEBUG_TYPE "tf-tfl-legalization"
@ -681,46 +685,6 @@ struct ConvertTFStridedSlice : public RewritePattern {
}
};
struct ConvertTFBroadcastTo : public RewritePattern {
explicit ConvertTFBroadcastTo(MLIRContext *context)
: RewritePattern(TF::BroadcastToOp::getOperationName(), 1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto tf_broadcast_to_op = cast<TF::BroadcastToOp>(op);
auto input_type = tf_broadcast_to_op.input().getType().cast<ShapedType>();
auto output_type = tf_broadcast_to_op.output().getType().cast<ShapedType>();
auto shape_type = tf_broadcast_to_op.shape().getType().cast<ShapedType>();
Type element_type = input_type.getElementType();
// Allow lowering when low dimension inputs are given and its type is F32 or
// I32.
if (!((output_type.hasRank() && output_type.getRank() <= 4) ||
(shape_type.hasStaticShape() && shape_type.getRank() == 1 &&
shape_type.getDimSize(0) <= 4)))
return failure();
if (!((element_type.getKind() == mlir::StandardTypes::F32) ||
(element_type.getKind() == mlir::StandardTypes::Integer &&
element_type.cast<mlir::IntegerType>().getWidth() == 32)))
return failure();
auto status_or_const_op =
CreateConstOpWithSingleValue(&rewriter, op->getLoc(), input_type, 1);
if (!status_or_const_op.ok()) {
return failure();
}
auto tf_fill_op = rewriter.create<TF::FillOp>(
op->getLoc(), output_type, tf_broadcast_to_op.shape(),
status_or_const_op.ValueOrDie());
auto mul_op = rewriter.create<TF::MulOp>(
op->getLoc(), output_type, tf_broadcast_to_op.input(), tf_fill_op);
rewriter.replaceOp(op, mul_op.getResult());
return success();
}
};
#include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
// Returns success if all the operations in the `op`'s regions including `op`
@ -737,6 +701,23 @@ LogicalResult ValidateOp(Operation *op) {
return failure(has_illegal_ops);
}
// Converts a set of TF2XLA ops into pure TF ops for future legalizations as
// TF2XLA ops aren't supported by later stages.
LogicalResult ConvertTf2XlaOps(FuncOp func, MLIRContext *context) {
ConversionTarget target(*context);
target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<TF::TensorFlowDialect>();
target.addLegalOp<ModuleOp>();
target.addLegalOp<FuncOp>();
target.addIllegalOp<TF::XlaConvOp>();
OwningRewritePatternList patterns;
mhlo::PopulateLegalizeTfWithTf2XlaPatterns("XLA_CPU_JIT", patterns);
TF::PopulateLegalizeHloToTfPatterns(&patterns, context);
return applyPartialConversion(func, target, patterns);
}
void PrepareTFPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
@ -752,6 +733,11 @@ void PrepareTFPass::runOnFunction() {
return;
}
if (failed(ConvertTf2XlaOps(func, ctx))) {
signalPassFailure();
return;
}
// This pattern was intended to use TFL QDQs to preserve the quantization
// parameters from the TF Quant ops, thus this pattern should run with the
// first `applyPatternsGreedily` method, which would otherwise remove the
@ -780,7 +766,7 @@ void PrepareTFPass::runOnFunction() {
patterns.insert<TF::ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(ctx);
}
patterns.insert<TF::ConvertTFEinsumOp, ConvertTFBroadcastTo, ConvertTFConv2D,
patterns.insert<TF::ConvertTFEinsumOp, ConvertTFConv2D,
ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice>(ctx);
applyPatternsAndFoldGreedily(func, patterns);
}

View File

@ -1,112 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/status.h"
namespace mlir {
namespace TFL {
xla::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
PatternRewriter* rewriter, Location loc, ShapedType shaped_type,
int value) {
Type element_type = shaped_type.getElementType();
ShapedType scalar_type = RankedTensorType::get({}, element_type);
Attribute attr;
switch (element_type.getKind()) {
case mlir::StandardTypes::F16: {
auto floatType = mlir::FloatType::getF16(element_type.getContext());
auto floatAttr =
mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
case mlir::StandardTypes::BF16: {
auto floatType = mlir::FloatType::getBF16(element_type.getContext());
auto floatAttr =
mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
case mlir::StandardTypes::F32: {
attr =
DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
break;
}
case mlir::StandardTypes::Complex: {
auto etype = element_type.cast<mlir::ComplexType>().getElementType();
if (etype.isF32()) {
auto dialect = etype.getContext()->getRegisteredDialect("tf");
tensorflow::TensorProto repr;
repr.set_dtype(tensorflow::DT_COMPLEX64);
tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape();
shape->set_unknown_rank(false);
shape->add_dim()->set_size(int64_t{1});
std::string content;
auto complex_value =
std::complex<float>(static_cast<float>(value), 0.0f);
content.assign(reinterpret_cast<const char*>(&complex_value),
sizeof(complex_value));
repr.set_tensor_content(content);
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled);
break;
}
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
}
case mlir::StandardTypes::Integer: {
const auto& itype = element_type.cast<mlir::IntegerType>();
switch (itype.getWidth()) {
case 8:
attr = DenseElementsAttr::get<int8_t>(scalar_type,
static_cast<int8_t>(value));
break;
case 16:
attr = DenseElementsAttr::get<int16_t>(scalar_type,
static_cast<int16_t>(value));
break;
case 32:
attr = DenseElementsAttr::get<int32_t>(scalar_type,
static_cast<int32_t>(value));
break;
case 64:
attr = DenseElementsAttr::get<int64_t>(scalar_type,
static_cast<int64_t>(value));
break;
default:
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
}
break;
}
default:
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
"Unsupported type");
}
return rewriter->create<ConstantOp>(loc, scalar_type, attr);
}
} // namespace TFL
} // namespace mlir

View File

@ -1,35 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/xla/statusor.h"
namespace mlir {
namespace TFL {
// Returns a Constant op with a single value.
xla::StatusOr<ConstantOp> CreateConstOpWithSingleValue(
PatternRewriter* rewriter, Location loc, ShapedType shaped_type, int value);
} // namespace TFL
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_

View File

@ -73,7 +73,8 @@ tool_names = [
'mlir-opt', 'mlir-hlo-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir'
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir',
'xla-thunks-opt'
]
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
llvm_config.add_tool_substitutions(tools, tool_dirs)

View File

@ -727,8 +727,10 @@ cc_library(
"transforms/generated_optimize.inc",
"transforms/gpu_fusion.cc",
"transforms/graph_pruning.cc",
"transforms/init_text_file_to_import.cc",
"transforms/launch_to_device_attribute.cc",
"transforms/layout_optimization.cc",
"transforms/mark_ops_for_outside_compilation.cc",
"transforms/materialize_mlir_passthrough_op.cc",
"transforms/optimize.cc",
"transforms/parallel_execute_to_islands.cc",
@ -825,6 +827,7 @@ cc_library(
cc_library(
name = "tensorflow_test_passes",
srcs = [
"transforms/init_text_file_to_import_test_pass.cc",
"transforms/lift_variables_test_pass.cc",
"transforms/lower_tf_pass.cc",
],
@ -840,8 +843,10 @@ cc_library(
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
"//tensorflow/core/platform:threadpool_options",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
alwayslink = 1,

View File

@ -102,8 +102,6 @@ class MlirTensor : public TracingTensorHandle {
return type;
}
void Release() override { delete this; }
Value getValue() { return value_; }
// For LLVM style RTTI.

View File

@ -87,7 +87,7 @@ tf.math.acosh(x) ==> [nan nan 0. 0.62236255 5.9914584 9.903487 inf]
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, SameOperandsAndResultElementType]>,
def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType, TF_LayoutAgnostic]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns x + y element-wise.";
@ -136,7 +136,7 @@ Inputs must be of same size and shape.
let hasFolder = 1;
}
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, SameOperandsAndResultElementType]>,
def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType, TF_LayoutAgnostic]>,
WithBroadcastableBinOpBuilder {
let summary = "Returns x + y element-wise.";
@ -725,6 +725,30 @@ window in `value`.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_AvgPool3DOp : TF_Op<"AvgPool3D", [NoSideEffect]> {
let summary = "Performs 3D average pooling on the input.";
let description = [{
Each entry in `output` is the mean of the corresponding size `ksize` window in
`value`.
}];
let arguments = (ins
TF_FpTensor:$input,
Confined<I64ArrayAttr, [ArrayMinCount<5>]>:$ksize,
Confined<I64ArrayAttr, [ArrayMinCount<5>]>:$strides,
TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding,
DefaultValuedAttr<TF_AnyStrAttrOf<["NDHWC", "NCDHW"]>, "NDHWC">:$data_format
);
let results = (outs
TF_FpTensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
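For orientation, a minimal sketch of the op's behaviour through the public Python API (assuming the standard `tf.nn.avg_pool3d` wrapper); each output entry is the mean of one `ksize` window of the NDHWC input:

```python
import tensorflow as tf

# NDHWC input: batch=1, depth=4, height=4, width=4, channels=1.
x = tf.reshape(tf.range(64, dtype=tf.float32), [1, 4, 4, 4, 1])

y = tf.nn.avg_pool3d(x, ksize=[1, 2, 2, 2, 1], strides=[1, 2, 2, 2, 1],
                     padding="VALID", data_format="NDHWC")

print(y.shape)              # (1, 2, 2, 2, 1)
print(y[0, 0, 0, 0, 0])     # 10.5, the mean of the first 2x2x2 window
```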
def TF_AvgPool3DGradOp : TF_Op<"AvgPool3DGrad", [NoSideEffect]> {
let summary = "Computes gradients of average pooling function.";
@ -745,30 +769,6 @@ def TF_AvgPool3DGradOp : TF_Op<"AvgPool3DGrad", [NoSideEffect]> {
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
}
def TF_AvgPool3DOp : TF_Op<"AvgPool3D", [NoSideEffect]> {
let summary = "Performs 3D average pooling on the input.";
let description = [{
Each entry in `output` is the mean of the corresponding size `ksize`
window in `value`.
}];
let arguments = (ins
TF_FpTensor:$value,
Confined<I64ArrayAttr, [ArrayMinCount<5>]>:$ksize,
Confined<I64ArrayAttr, [ArrayMinCount<5>]>:$strides,
TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding,
DefaultValuedAttr<TF_AnyStrAttrOf<["NDHWC", "NCDHW"]>, "NDHWC">:$data_format
);
let results = (outs
TF_FpTensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_AvgPoolGradOp : TF_Op<"AvgPoolGrad", [NoSideEffect]> {
let summary = "Computes gradients of the average pooling function.";
@ -4231,6 +4231,29 @@ Where to extract the key and value from a line is specified by `key_index` and
let results = (outs);
}
def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
let summary = "Updates specified rows 'i' with values 'v'.";
let description = [{
Computes `x[i, :] = v; return x`.
Originally this function is mutative however for compilation we make this
operation create / operate on a copy of `x`.
}];
let arguments = (ins
TF_Tensor:$x,
I32Tensor:$i,
TF_Tensor:$v
);
let results = (outs
TF_Tensor:$y
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
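A hedged NumPy sketch of the copy-then-assign semantics described above (the `inplace_update` helper is illustrative, not the TF kernel itself):

```python
import numpy as np

def inplace_update(x, i, v):
    # Functional form used for compilation: operate on a copy of x.
    y = np.array(x, copy=True)
    y[i, :] = v
    return y

x = np.zeros((3, 4), dtype=np.float32)
v = np.ones((2, 4), dtype=np.float32)
y = inplace_update(x, np.array([0, 2]), v)
# Rows 0 and 2 of y are now ones; x itself is left untouched.
```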
def TF_InvOp : TF_Op<"Inv", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the reciprocal of x element-wise.";
@ -6311,6 +6334,8 @@ This is the opposite of `unpack`.
let verifier = [{
return Verify(*this);
}];
let hasFolder = 1;
}
def TF_PadOp : TF_Op<"Pad", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
@ -9711,6 +9736,22 @@ For internal use only.
);
}
def TF_TPUOrdinalSelectorOp : TF_Op<"TPUOrdinalSelector", []> {
let summary = "A TPU core selector Op.";
let description = [{
This Op produces a set of TPU cores (for warm-up) or a single TPU core
(for regular inference) to execute the TPU program on. The output is
consumed by TPUPartitionedCall.
}];
let arguments = (ins);
let results = (outs
I32Tensor:$device_ordinals
);
}
def TF_TPUReplicatedInputOp : TF_Op<"TPUReplicatedInput", [NoSideEffect]> {
let summary = "Connects N inputs to an N-way replicated TPU computation.";

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project

View File

@ -225,12 +225,25 @@ else_branch: A function that takes 'inputs' and returns a list of
TF_DerivedOperandTypeAttr Tcond = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
let verifier = [{
return Verify(*this);
}];
let hasCanonicalizer = 1;
let extraClassDeclaration = [{
// Get the then branch function.
FuncOp then_func() {
return getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(then_branch());
}
// Get the else branch function.
FuncOp else_func() {
return getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(else_branch());
}
}];
}
def TF_YieldOp : TF_Op<"Yield",
@ -612,7 +625,6 @@ body: A function that takes a list of tensors and returns another
FlatSymbolRefAttr:$cond,
FlatSymbolRefAttr:$body,
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
// Used to map StatelessWhile and While op defined in TensorFlow to a common
@ -625,10 +637,24 @@ body: A function that takes a list of tensors and returns another
);
TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
let verifier = [{
return Verify(*this);
}];
let hasCanonicalizer = 1;
let extraClassDeclaration = [{
// Get the condition function.
FuncOp cond_func() {
return getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(cond());
}
// Get the body function.
FuncOp body_func() {
return getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(body());
}
}];
}
def TL_WhileRegionOp : TF_Op<"WhileRegion",
@ -1070,31 +1096,6 @@ def TF_TensorSliceDatasetOp : TF_Op<"TensorSliceDataset", []> {
TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>;
}
// TODO(b/156507832): Move tf.InplaceUpdate to tf_generated_ops.td once
// autogenerated op def matches.
def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
let summary = "Updates specified rows 'i' with values 'v'.";
let description = [{
Computes `x[i, :] = v; return x`.
Originally this function is mutative however for compilation we make this
operation create / operate on a copy of `x`.
}];
let arguments = (ins
TF_Tensor:$x,
I32Tensor:$i,
TF_Tensor:$v
);
let results = (outs
TF_Tensor:$y
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_BesselI0eOp : TF_Op<"BesselI0e", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the Bessel i0e function of `x` element-wise.";
@ -1255,4 +1256,74 @@ def TF_FusedBatchNormV3Op : TF_FusedBatchNormOpBase<"FusedBatchNormV3"> {
);
}
def TF_BatchFunctionOp : TF_Op<"BatchFunction", [AttrSizedOperandSegments]> {
let summary = [{
Batches all the input tensors to the computation done by the function.
}];
let description = [{
So, for example, in the following code
```python
# This input will be captured.
y = tf.placeholder_with_default(1.0, shape=[])
@tf.Defun(tf.float32)
def computation(a):
return tf.matmul(a, a) + y
b = gen_batch_ops.batch_function(
f=computation,
in_tensors=[a],
captured_tensors=computation.captured_inputs,
Tout=[o.type for o in computation.definition.signature.output_arg],
num_batch_threads=1,
max_batch_size=10,
batch_timeout_micros=100000, # 100ms
allowed_batch_sizes=[3, 10],
batching_queue="")
```
If more than one session.run call is simultaneously trying to compute `b`,
the values of `a` will be gathered, non-deterministically concatenated
along the first axis, and only one thread will run the computation.
Assumes that all arguments of the function are Tensors which will be batched
along their first dimension.
Arguments that are captured are not batched. The session.run call which does
the concatenation will use the values of the captured tensors available to it.
Therefore, typical uses of captured tensors should involve values which remain
unchanged across session.run calls. Inference is a good example of this.
SparseTensor is not supported. The return value of the decorated function
must be a Tensor or a list/tuple of Tensors.
}];
let arguments = (ins
Variadic<TF_Tensor>:$in_tensors,
Variadic<TF_Tensor>:$captured_tensors,
SymbolRefAttr:$f,
I64Attr:$num_batch_threads,
I64Attr:$max_batch_size,
I64Attr:$batch_timeout_micros,
DefaultValuedAttr<I64Attr, "10">:$max_enqueued_batches,
DefaultValuedAttr<I64ArrayAttr, "{}">:$allowed_batch_sizes,
StrAttr:$container,
StrAttr:$shared_name,
StrAttr:$batching_queue,
DefaultValuedAttr<BoolAttr, "false">:$enable_large_batch_splitting,
I32ElementsAttr:$operand_segment_sizes
);
let results = (outs
Variadic<TF_Tensor>:$out_tensors
);
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedOperandTypeListAttr Tcaptured = TF_DerivedOperandTypeListAttr<1>;
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
}
#endif // TF_OPS

View File

@ -1615,6 +1615,10 @@ static LogicalResult Verify(IfOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// IfOp canonicalization.
//===----------------------------------------------------------------------===//
class FoldConstantIfOp : public OpRewritePattern<TF::IfOp> {
public:
explicit FoldConstantIfOp(MLIRContext *context)
@ -1662,7 +1666,7 @@ LogicalResult FoldConstantIfOp::matchAndRewrite(
void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<FoldConstantIfOp>(context);
results.insert<FoldConstantIfOp, DropAttributes<IfOp>>(context);
}
//===----------------------------------------------------------------------===//

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project

View File

@ -578,3 +578,23 @@ LogicalResult VerifyRegionResults(Operation *op, Region &region,
}
return success();
}
//===----------------------------------------------------------------------===//
// Function control flow canonicalization.
//===----------------------------------------------------------------------===//
// Eliminate attributes that are not needed, but can get attached to Ops
// during import.
template <typename Op>
struct DropAttributes : public OpRewritePattern<Op> {
using OpRewritePattern<Op>::OpRewritePattern;
// Drop the "output_shapes" attribute.
LogicalResult matchAndRewrite(Op op,
PatternRewriter &rewriter) const override {
bool found = op.removeAttr("output_shapes") ==
MutableDictionaryAttr::RemoveResult::Removed;
return success(found);
}
};

View File

@ -217,6 +217,97 @@ static LogicalResult Verify(PackOp op) {
return success();
}
OpFoldResult PackOp::fold(ArrayRef<Attribute> operands) {
// Fold pack operation if it computes the input tensor shape:
//
// %shape = tf.Shape(%arg) // [? x ...]
// %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim0 value
// %pack = tf.Pack(dim0, ...) { axis = 0 } // [? x ...]
//
// Where `...` are some statically known dimensions. In this case %pack can be
// replaced with a %shape. This is a common pattern in models with a dynamic
// batch size.
// Pack operation should pack at least two values.
if (values().size() < 2) return {};
// Dimensions packed along axis = 0 (pack scalars into vector).
if (axis().getSExtValue() != 0) return {};
// First packed value is defined by a strided slice operation.
auto slice_op = dyn_cast_or_null<StridedSliceOp>(values()[0].getDefiningOp());
if (!slice_op) return {};
// Input to the slice op is defined by shape operation.
auto shape_op = dyn_cast_or_null<ShapeOp>(slice_op.input().getDefiningOp());
if (!shape_op) return {};
// Input tensor, which shape is reconstructed by the pack operation.
Value tensor = shape_op.input();
// All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing
// scalar value from input vector).
if (slice_op.begin_mask().getSExtValue() != 0 ||
slice_op.ellipsis_mask().getSExtValue() != 0 ||
slice_op.end_mask().getSExtValue() != 0 ||
slice_op.new_axis_mask().getSExtValue() != 0 ||
slice_op.shrink_axis_mask().getSExtValue() != 1)
return {};
// Returns a value if the `value` is defined by a ConstOp with a single
// integer element in it and has an expected rank.
auto get_const_int = [](Value value, int expected_rank) -> Optional<int64_t> {
auto const_op = dyn_cast_or_null<ConstOp>(value.getDefiningOp());
if (!const_op) return None;
auto value_attr = const_op.value().dyn_cast<DenseIntElementsAttr>();
if (!value_attr || value_attr.getNumElements() != 1) return None;
auto value_ty = value_attr.getType();
if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None;
auto splat = value_attr.getSplatValue<IntegerAttr>();
return splat.getValue().getSExtValue();
};
// All other packed values are scalar constants.
SmallVector<int64_t, 4> packed_dims;
packed_dims.reserve(values().size() - 1);
for (Value operand : llvm::drop_begin(values(), 1)) {
if (auto dim = get_const_int(operand, /*expected_rank=*/0)) {
packed_dims.push_back(*dim);
} else {
return {};
}
}
// Slice exactly the first shape dimension:
// begin = [0] end = [1], strides = [1]
auto begin = get_const_int(slice_op.begin(), /*expected_rank=*/1);
auto end = get_const_int(slice_op.end(), /*expected_rank=*/1);
auto strides = get_const_int(slice_op.strides(), /*expected_rank=*/1);
if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() ||
*begin != 0 || *end != 1 || *strides != 1)
return {};
// First tensor dimension is dynamic.
auto arg_ty = tensor.getType().dyn_cast<ShapedType>();
if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 ||
!arg_ty.isDynamicDim(0))
return {};
// Argument tensor rank is equal to the number of packed dimensions.
if (arg_ty.getRank() != values().size()) return {};
// All other dimensions are statically known and equal to packed dims.
auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1);
if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin()))
return {};
// Replace %pack with %shape.
return slice_op.input();
}
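In eager terms, the fold recognizes a shape that was taken apart and reassembled: the dynamic leading dimension is sliced out of `tf.Shape` and re-packed together with the statically known trailing dimensions, so the packed vector always equals the original shape. A sketch of the graph pattern being matched (illustrative; the fold itself operates on the MLIR ops above):

```python
import tensorflow as tf

x = tf.zeros([5, 1, 2])                     # leading dim is dynamic in the model

shape = tf.shape(x)                         # [? x 1 x 2] at trace time
dim0 = tf.strided_slice(shape, [0], [1], [1], shrink_axis_mask=1)
packed = tf.stack([dim0, 1, 2], axis=0)     # what tf.Pack(dim0, 1, 2) builds

# The fold replaces `packed` with `shape`: the two are equal by construction.
tf.debugging.assert_equal(packed, shape)
```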
//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//
@ -608,12 +699,11 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<RedundantReshape>(context);
results.insert<RedundantReshape, ReshapeToSelfShape>(context);
}
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
Value tensor = this->tensor();
Value shape = this->shape();
// Fold reshape if operand and result types are the same and all dimensions
// are statically known (no-op reshape).
@ -624,90 +714,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
return tensor;
}
// Fold reshape if the shape is computed from the input tensor:
//
// %shape = tf.Shape(%arg) // [? x ...]
// %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim value
// %new_shape = tf.Pack(dim0, ...) { axis = 0 } // [? x ...]
// %reshape = tf.Reshape(%arg, %new_shape) // this is no-op
//
// Where `...` are some statically known dimensions. In this case reshape is
// a no-op and can be replaced by %arg (assuming `...` are equal).
auto pack_op = dyn_cast_or_null<PackOp>(shape.getDefiningOp());
if (!pack_op || pack_op.values().size() < 2) return {};
// Dimensions packed along axis = 0 (pack scalars into vector).
if (pack_op.axis().getSExtValue() != 0) return {};
// First packed value is defined by a strided slice operation.
auto slice_op =
dyn_cast_or_null<StridedSliceOp>(pack_op.values()[0].getDefiningOp());
if (!slice_op) return {};
// Input to the slice op is defined by shape operation.
auto shape_op = dyn_cast_or_null<ShapeOp>(slice_op.input().getDefiningOp());
if (!shape_op || shape_op.input() != tensor) return {};
// All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing
// scalar value from input vector).
if (slice_op.begin_mask().getSExtValue() != 0 ||
slice_op.ellipsis_mask().getSExtValue() != 0 ||
slice_op.end_mask().getSExtValue() != 0 ||
slice_op.new_axis_mask().getSExtValue() != 0 ||
slice_op.shrink_axis_mask().getSExtValue() != 1)
return {};
// Returns a value if the `value` is defined by a ConstOp with a single
// integer element in it and has an expected rank.
auto get_value = [](Value value, int expected_rank) -> Optional<int64_t> {
auto const_op = dyn_cast_or_null<ConstOp>(value.getDefiningOp());
if (!const_op) return None;
auto value_attr = const_op.value().dyn_cast<DenseIntElementsAttr>();
if (!value_attr || value_attr.getNumElements() != 1) return None;
auto value_ty = value_attr.getType();
if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None;
auto splat = value_attr.getSplatValue<IntegerAttr>();
return splat.getValue().getSExtValue();
};
// All other packed values are scalar constants.
SmallVector<int64_t, 4> packed_dims;
packed_dims.reserve(pack_op.values().size() - 1);
for (Value operand : llvm::drop_begin(pack_op.values(), 1)) {
if (auto dim = get_value(operand, /*expected_rank=*/0)) {
packed_dims.push_back(*dim);
} else {
return {};
}
}
// Slice exactly the first shape dimension:
// begin = [0] end = [1], strides = [1]
auto begin = get_value(slice_op.begin(), /*expected_rank=*/1);
auto end = get_value(slice_op.end(), /*expected_rank=*/1);
auto strides = get_value(slice_op.strides(), /*expected_rank=*/1);
if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() ||
*begin != 0 || *end != 1 || *strides != 1)
return {};
// First tensor dimension is dynamic.
auto arg_ty = tensor.getType().dyn_cast<ShapedType>();
if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 ||
!arg_ty.isDynamicDim(0))
return {};
// Argument tensor rank is equal to the number of packed dimensions.
if (arg_ty.getRank() != pack_op.values().size()) return {};
// All other dimensions are statically known and equal to packed dims.
auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1);
if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin()))
return {};
return tensor;
return {};
}
//===----------------------------------------------------------------------===//
@ -2065,6 +2072,14 @@ static LogicalResult Verify(WhileOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// WhileOp canonicalization.
//===----------------------------------------------------------------------===//
void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<DropAttributes<WhileOp>>(context);
}
//===----------------------------------------------------------------------===//
// WhileRegionOp
//===----------------------------------------------------------------------===//

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project

View File

@ -377,6 +377,15 @@ func @testRedundantReshape(%arg0: tensor<4x4xi32>) -> tensor<2x8xi32> {
// CHECK: return %1 : tensor<2x8xi32>
}
// CHECK-LABEL: testReshapeToSelfShape
func @testReshapeToSelfShape(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
%0 = "tf.Shape"(%arg0) : (tensor<?x4xf32>) -> tensor<2xi32>
%1 = "tf.Reshape"(%arg0, %0) : (tensor<?x4xf32>, tensor<2xi32>) -> tensor<?x4xf32>
// CHECK: return %arg0 : tensor<?x4xf32>
return %1: tensor<?x4xf32>
}
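The `ReshapeToSelfShape` canonicalization covered by this test removes a reshape whose target shape is the operand's own `tf.Shape`. An eager sketch of the identity (illustrative):

```python
import tensorflow as tf

x = tf.zeros([5, 4])               # stands in for the ?x4 operand
y = tf.reshape(x, tf.shape(x))     # reshape to the tensor's own shape

# The canonicalizer replaces y with x; the values are identical by construction.
tf.debugging.assert_equal(x, y)
```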
// CHECK-LABEL: func @testReshapeNoOp
func @testReshapeNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tensor<2x4xf32> {
%0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2x4xf32>, tensor<2xi32>) -> tensor<2x4xf32>
@ -385,8 +394,8 @@ func @testReshapeNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tensor<2x
return %0 : tensor<2x4xf32>
}
// CHECK-LABEL: func @testReshapeNoOpShapeComputation
func @testReshapeNoOpShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<?x1xf32>, tensor<?x1x2xf32>, tensor<?x1x2xf32>, tensor<?x2x1xf32>, tensor<?x1x2xf32>, tensor<?x1x1xf32>, tensor<*xf32>) {
// CHECK-LABEL: func @testPackShapeComputation
func @testPackShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x2xf32>, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) {
// Test dimensions sizes.
%d1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%d2 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
@ -396,65 +405,56 @@ func @testReshapeNoOpShapeComputation(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1x
%1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
%2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
// Fold reshape if the shape is computed from the input tensor:
// Fold pack operation if it computes the input tensor shape:
//
// %shape = tf.Shape(%arg) // [? x ...]
// %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim value
// %new_shape = tf.Pack(dim0, ...) { axis = 0 } // [? x ...]
// %reshape = tf.Reshape(%arg, %new_shape) // this is no-op
// %shape = tf.Shape(%arg) // [? x ...]
// %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim0 value
// %pack = tf.Pack(dim0, ...) { axis = 0 } // [? x ...]
//
// Where `...` are some statically known dimensions. In this case reshape is
// a no-op and can be replaced by %arg (assuming `...` are equal).
// Where `...` are some statically known dimensions. In this case %pack can be
// replaced with a %shape. This is a common pattern in models with a dynamic
// batch size.
// Test Rank 2
// CHECK: %[[SHAPE0:.*]] = "tf.Shape"
%3 = "tf.Shape"(%arg0) : (tensor<?x1xf32>) -> tensor<2xi32>
%4 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
%5 = "tf.Pack"(%4, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
%6 = "tf.Reshape"(%arg0, %5) : (tensor<?x1xf32>, tensor<2xi32>) -> tensor<?x1xf32>
// Test Rank 3.
// CHECK: %[[SHAPE1:.*]] = "tf.Shape"
%7 = "tf.Shape"(%arg1) : (tensor<?x1x2xf32>) -> tensor<3xi32>
%8 = "tf.StridedSlice"(%7, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
%9 = "tf.Pack"(%8, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
%10 = "tf.Reshape"(%arg1, %9) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x1x2xf32>
// Shape was taken from the op that is not reshaped in the end:
// Reshape(%arg1) vs Shape(%arg0)
%11 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
%12 = "tf.Pack"(%11, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
// CHECK: %[[RESHAPE0:.*]] = "tf.Reshape"
%13 = "tf.Reshape"(%arg1, %12) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x1x2xf32>
// Packed dimensions have different order from the reshape operand:
// [?, 1, 2] vs [?, 2, 1]
%14 = "tf.StridedSlice"(%7, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
%15 = "tf.Pack"(%14, %d2, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
// CHECK: %[[RESHAPE1:.*]] = "tf.Reshape"
%16 = "tf.Reshape"(%arg1, %15) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x2x1xf32>
// CHECK: %[[PACK0:.*]] = "tf.Pack"
// StridedSlice takes second dimension from the shape:
// begin = [1], end = [2], stride = [1]
%17 = "tf.StridedSlice"(%7, %1, %2, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
%18 = "tf.Pack"(%17, %d1, %d2) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
// CHECK: %[[RESHAPE2:.*]] = "tf.Reshape"
%19 = "tf.Reshape"(%arg1, %18) : (tensor<?x1x2xf32>, tensor<3xi32>) -> tensor<?x1x2xf32>
// CHECK: %[[PACK1:.*]] = "tf.Pack"
// Packed dimensions have higher rank than the reshape operand:
// [?, 1] vs [?, 1, 1]
%20 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
%21 = "tf.Pack"(%20, %d1, %d1) {axis = 0 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<3xi32>
// CHECK: %[[RESHAPE3:.*]] = "tf.Reshape"
%22 = "tf.Reshape"(%arg0, %21) : (tensor<?x1xf32>, tensor<3xi32>) -> tensor<?x1x1xf32>
// CHECK: %[[PACK2:.*]] = "tf.Pack"
// Make sure a dynamic ranked shape doesn't crash the "canonicalize" pass
%23 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32>
%24 = "tf.StridedSlice"(%23, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
%25 = "tf.Pack"(%24, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%26 = "tf.Reshape"(%arg2, %25) : (tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32>
// CHECK: %[[PACK3:.*]] = "tf.Pack"
// CHECK: return %arg0, %arg1, %[[RESHAPE0]], %[[RESHAPE1]], %[[RESHAPE2]], %[[RESHAPE3]]
return %6, %10, %13, %16, %19, %22, %26 : tensor<?x1xf32>, tensor<?x1x2xf32>, tensor<?x1x2xf32>, tensor<?x2x1xf32>, tensor<?x1x2xf32>, tensor<?x1x1xf32>, tensor<*xf32>
// CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]], %[[PACK3]]
return %5, %9, %15, %18, %21, %25 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>
}
// CHECK-LABEL: testSelectScalarPred
@ -985,3 +985,36 @@ func @testWhileRegionUnusedValue(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>, %ar
// CHECK: return %[[WHILE_OUT]]#0 : tensor<*xf32>
return %0#0 : tensor<*xf32>
}
// Check that output_shapes attribute is removed for tf.If
func @testIfThen(tensor<*xf32>) -> tensor<*xf32>
func @testIfElse(tensor<*xf32>) -> tensor<*xf32>
// CHECK-LABEL: func @testIfDropOutputShapes
func @testIfDropOutputShapes(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
^bb0(%arg0: tensor<i1>, %arg1: tensor<2xf32>):
// CHECK: "tf.If"
// CHECK-NOT: output_shapes
%1 = "tf.If"(%arg0, %arg1) {
then_branch = @testIfThen, else_branch = @testIfElse, is_stateless = false, output_shapes = [#tf.shape<>]
} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
return %1 : tensor<2xf32>
}
// Check that output_shapes attribute is removed for tf.While
func @testWhileCond(tensor<*xf32>) -> (tensor<i1>)
func @testWhileBody(tensor<*xf32>) -> (tensor<*xf32>)
// CHECK-LABEL: func @testWhileDropOutputShapes
func @testWhileDropOutputShapes(tensor<*xf32>) -> (tensor<*xf32>) {
^bb0(%arg0: tensor<*xf32>):
// CHECK: "tf.While"
// CHECK-NOT: output_shapes
%1 = "tf.While"(%arg0) {
cond = @testWhileCond,
body = @testWhileBody,
is_stateless = false,
output_shapes = [#tf.shape<>]
} : (tensor<*xf32>) -> (tensor<*xf32>)
return %1 : tensor<*xf32>
}

View File

@ -0,0 +1,14 @@
// RUN: tf-opt -tf-init-text-file-to-import-test %s | FileCheck %s
// Tests that the tf.InitializeTableFromTextFileV2 op is inlined.
func @init_all_tables() {
%cst = constant dense<"%FILE_PLACEHOLDER"> : tensor<!tf.string>
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "hash_table_/tmp/vocab.txt_-2_-1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf.resource>
"tf.InitializeTableFromTextFileV2"(%0, %cst) {delimiter = " ", device = "", key_index = -2 : i64, value_index = -1 : i64, vocab_size = -1 : i64} : (tensor<!tf.resource>, tensor<!tf.string>) -> ()
return
// CHECK: [[CST:%.*]] = constant dense<["apple", "banana", "grape"]> : tensor<3x!tf.string>
// CHECK: [[CST_0:%.*]] = constant dense<[0, 1, 2]> : tensor<3xi64>
// CHECK: [[VAL:%.*]] = "tf.HashTableV2"()
// CHECK: "tf.LookupTableImportV2"([[VAL]], [[CST]], [[CST_0]])
}
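For context, these ops are what the `tf.lookup` initializers emit when a vocabulary file is used; the case handled by the pass is key_index = -2 (the whole line) with value_index = -1 (the line number). A hedged sketch of the Python side that produces this pattern (the temporary file path is illustrative):

```python
import tensorflow as tf

# Write a small vocabulary file, one token per line.
with open("/tmp/vocab.txt", "w") as f:
    f.write("apple\nbanana\ngrape\n")

# WHOLE_LINE == -2 as the key, LINE_NUMBER == -1 as the value, matching the
# key_index / value_index attributes checked above.
init = tf.lookup.TextFileInitializer(
    "/tmp/vocab.txt",
    key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
    value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER,
    delimiter=" ")
table = tf.lookup.StaticHashTable(init, default_value=-1)

print(table.lookup(tf.constant(["banana", "kiwi"])).numpy())   # [ 1 -1]
```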

View File

@ -0,0 +1,53 @@
// RUN: tf-opt -split-input-file -verify-diagnostics -tf-init-text-file-to-import %s | FileCheck %s
// Tests the error reported when the given vocabulary file does not exist.
func @init_all_tables() {
%cst = constant dense<"vocab_file_does_not_exist.txt"> : tensor<!tf.string>
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "hash_table_/tmp/vocab.txt_-2_-1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf.resource>
// expected-error @+1 {{'tf.InitializeTableFromTextFileV2' op failed to open vocabulary file (vocab_file_does_not_exist.txt): cannot open input file 'vocab_file_does_not_exist.txt': No such file or directory}}
"tf.InitializeTableFromTextFileV2"(%0, %cst) {delimiter = " ", device = "", key_index = -2 : i64, value_index = -1 : i64, vocab_size = -1 : i64} : (tensor<!tf.resource>, tensor<!tf.string>) -> ()
return
}
// -----
// Tests that the tf.InitializeTableFromTextFileV2 op is not converted since
// it uses an unsupported key_index, -1.
func @init_all_tables() {
%cst = constant dense<"vocab_file_does_not_exist.txt"> : tensor<!tf.string>
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "hash_table_/tmp/vocab.txt_-2_-1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf.resource>
"tf.InitializeTableFromTextFileV2"(%0, %cst) {delimiter = " ", device = "", key_index = -1 : i64, value_index = -1 : i64, vocab_size = -1 : i64} : (tensor<!tf.resource>, tensor<!tf.string>) -> ()
return
// CHECK: [[VAL:%.*]] = "tf.HashTableV2"()
// CHECK: tf.InitializeTableFromTextFileV2"
}
// -----
// Tests that the tf.InitializeTableFromTextFileV2 op is not converted since
// it uses an unsupported value_index, 0.
func @init_all_tables() {
%cst = constant dense<"vocab_file_does_not_exist.txt"> : tensor<!tf.string>
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "hash_table_/tmp/vocab.txt_-2_-1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf.resource>
"tf.InitializeTableFromTextFileV2"(%0, %cst) {delimiter = " ", device = "", key_index = -2 : i64, value_index = 0 : i64, vocab_size = -1 : i64} : (tensor<!tf.resource>, tensor<!tf.string>) -> ()
return
// CHECK: [[VAL:%.*]] = "tf.HashTableV2"()
// CHECK: tf.InitializeTableFromTextFileV2"
}
// -----
// Tests that the tf.InitializeTableFromTextFileV2 op is not converted since
// it uses an unsupported vocab_size, 1.
func @init_all_tables() {
%cst = constant dense<"vocab_file_does_not_exist.txt"> : tensor<!tf.string>
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "hash_table_/tmp/vocab.txt_-2_-1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf.resource>
"tf.InitializeTableFromTextFileV2"(%0, %cst) {delimiter = " ", device = "", key_index = -2 : i64, value_index = -1 : i64, vocab_size = 1 : i64} : (tensor<!tf.resource>, tensor<!tf.string>) -> ()
return
// CHECK: [[VAL:%.*]] = "tf.HashTableV2"()
// CHECK: tf.InitializeTableFromTextFileV2"
}
View File
@ -1,13 +1,13 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<4xf32>, %arg3: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
%0:2 = tf_executor.graph {
%outputs_2, %control_3 = tf_executor.island wraps "tf.Less"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<i1>
%outputs_4, %control_5 = tf_executor.island wraps "tf.If"(%outputs_2, %arg0, %arg1) {else_branch = @cond_false, is_stateless = false, then_branch = @cond_true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> loc("StatefulIf")
%outputs_6, %control_7 = tf_executor.island wraps "tf.If"(%outputs_2, %arg0, %arg1) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32> loc("StatelessIf")
tf_executor.fetch %outputs_4, %outputs_6 : tensor<f32>, tensor<f32>
%outputs_4, %control_5 = tf_executor.island wraps "tf.If"(%outputs_2, %arg2, %arg3) {else_branch = @cond_false, is_stateless = false, then_branch = @cond_true} : (tensor<i1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("StatefulIf")
%outputs_6, %control_7 = tf_executor.island wraps "tf.If"(%outputs_2, %arg2, %arg3) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true} : (tensor<i1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("StatelessIf")
tf_executor.fetch %outputs_4, %outputs_6 : tensor<4xf32>, tensor<4xf32>
}
return %0#0, %0#1 : tensor<f32>, tensor<f32>
return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32>
}
func @cond_true(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
@ -34,8 +34,32 @@ func @cond_false(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NOT: name:
// CHECK: op: "If"
// CHECK-NOT: is_stateless
// CHECK: attr {
// CHECK: key: "output_shapes"
// CHECK: value {
// CHECK: list {
// CHECK: shape {
// CHECK: dim {
// CHECK: size: 4
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: name: "StatelessIf"
// CHECK-NOT: name:
// CHECK: op: "StatelessIf"
// CHECK-NOT: is_stateless
// CHECK: attr {
// CHECK: key: "output_shapes"
// CHECK: value {
// CHECK: list {
// CHECK: shape {
// CHECK: dim {
// CHECK: size: 4
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
View File
@ -1,12 +1,12 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
func @main(%arg0: tensor<i32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
%0:2 = tf_executor.graph {
%outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>) loc("StatefulWhile")
%outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true} : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>) loc("StatelessWhile")
tf_executor.fetch %outputs_2#1, %outputs_4#1 : tensor<f32>, tensor<f32>
%outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
%outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
tf_executor.fetch %outputs_2#1, %outputs_4#1 : tensor<5xf32>, tensor<5xf32>
}
return %0#0, %0#1 : tensor<f32>, tensor<f32>
return %0#0, %0#1 : tensor<5xf32>, tensor<5xf32>
}
func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
@ -36,8 +36,34 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
// CHECK-NOT: name:
// CHECK: op: "While"
// CHECK-NOT: is_stateless
// CHECK: attr {
// CHECK: key: "output_shapes"
// CHECK: value {
// CHECK: list {
// CHECK: shape {
// CHECK: dim {
// CHECK: size: 5
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: name: "StatelessWhile"
// CHECK-NOT: name:
// CHECK: op: "StatelessWhile"
// CHECK-NOT: is_stateless
// CHECK: attr {
// CHECK: key: "output_shapes"
// CHECK: value {
// CHECK: list {
// CHECK: shape {
// CHECK: dim {
// CHECK: size: 5
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
View File
@ -56,7 +56,7 @@ func @propagate_if_op(
"tf.If"(%arg1, %id0, %var_handle) {
then_branch = @if_then,
else_branch = @if_else,
output_shapes = [], is_stateless = false}
is_stateless = false}
: (tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
tensor<*x!tf.resource<tensor<32xf32>>>) -> ()
tf_executor.yield
@ -128,8 +128,7 @@ func @propagate_while_op(
// CHECK-NEXT: "tf.While"
"tf.While"(%arg1, %id0, %var_handle) {
body = @while_body,
cond = @while_cond,
output_shapes = [], is_stateless = false}
cond = @while_cond, is_stateless = false}
: (tensor<i32>, tensor<*x!tf.resource<tensor<32xf32>>>,
tensor<*x!tf.resource<tensor<32xf32>>>) ->
(tensor<i32>, tensor<*x!tf.resource<tensor<32xf32>>>,
@ -209,8 +208,7 @@ func @error_on_conflict_multiple_callers(
: () -> tensor<*x!tf.resource<tensor<32xf32>>>
"tf.If"(%arg1, %id0, %var_handle) {
then_branch = @if_then_and_else,
else_branch = @if_then_and_else,
output_shapes = [], is_stateless = false}
else_branch = @if_then_and_else, is_stateless = false}
: (tensor<i1>, tensor<*x!tf.resource<tensor<32xf32>>>,
tensor<*x!tf.resource<tensor<32xf32>>>) -> ()
"tf.If"(%arg1, %var_handle, %id0) {
View File
@ -147,8 +147,7 @@ func @cluster_with_loop() -> () {
"tf_device.cluster"() ( {
// CHECK: %[[WHILE:.*]]:2 = "tf.While"(%[[COUNT]], %[[READ]])
%2:3 = "tf.While"(%0, %1, %unused)
{body = @while_body, cond = @while_cond, device = "", is_stateless = false,
output_shapes = [#tf.shape<>, #tf.shape<>]}
{body = @while_body, cond = @while_cond, device = "", is_stateless = false}
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
// CHECK: tf_device.return %[[WHILE]]#1 : tensor<f32>
@ -197,8 +196,7 @@ func @cluster_with_loop() -> () {
"tf_device.cluster"() ( {
// CHECK: %[[WHILE:.*]] = "tf.While"(%[[READ]])
%1 = "tf.While"(%0) {
body = @while_body, cond = @while_cond, device = "", is_stateless = false,
output_shapes = [#tf.shape<>]}
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
: (tensor<*x!tf.resource<tensor<f32>>>)
-> (tensor<*x!tf.resource<tensor<f32>>>)
// CHECK: tf_device.return %[[WHILE]] : tensor<f32>
@ -239,8 +237,7 @@ func @cluster_with_loop() -> () {
"tf_device.cluster"() ( {
// CHECK: %[[WHILE:.*]] = "tf.While"(%[[READ]])
%1 = "tf.While"(%0) {
body = @while_body, cond = @while_cond, device = "", is_stateless = false,
output_shapes = [#tf.shape<>]}
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
: (tensor<*x!tf.resource<tensor<f32>>>)
-> (tensor<*x!tf.resource<tensor<f32>>>)
// CHECK: tf_device.return
@ -278,8 +275,7 @@ func @cluster_with_nested_loop() -> () {
"tf_device.cluster"() ( {
// CHECK: %[[WHILE:.*]] = "tf.While"(%[[READ]])
%2:2 = "tf.While"(%0, %1) {
body = @while_body, cond = @while_cond, device = "", is_stateless = false,
output_shapes = [#tf.shape<>, #tf.shape<>]}
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
-> (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
// CHECK: tf_device.return %[[WHILE]] : tensor<f32>
@ -295,8 +291,7 @@ func @while_body(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf
-> (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>) {
// CHECK: %[[WHILE:.*]] = "tf.While"(%[[BARG0]])
%0:2 = "tf.While"(%arg0, %arg1) {
body = @while_body1, cond = @while_cond1, device = "", is_stateless = false,
output_shapes = [#tf.shape<>, #tf.shape<>]}
body = @while_body1, cond = @while_cond1, device = "", is_stateless = false}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
-> (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
// CHECK-NEXT: return %[[WHILE]]
@ -334,8 +329,7 @@ func @cluster_with_loop() -> () {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
"tf_device.cluster"() ( {
%1 = "tf.While"(%0) {
body = @while_body, cond = @while_cond, device = "", is_stateless = false,
output_shapes = [#tf.shape<>]}
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
: (tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>)
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
@ -359,8 +353,7 @@ func @cluster_with_loop() -> () {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
"tf_device.cluster"() ( {
%1 = "tf.While"(%0) {
body = @while_body, cond = @while_cond, device = "", is_stateless = false,
output_shapes = [#tf.shape<>]}
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
: (tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>)
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
@ -384,8 +377,7 @@ func @cluster_with_loop() -> () {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
"tf_device.cluster"() ( {
%1 = "tf.While"(%0) {
body = @while_body, cond = @while_cond, device = "", is_stateless = false,
output_shapes = [#tf.shape<>]}
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
: (tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>)
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
View File
@ -100,10 +100,11 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
return %1 : tensor<?x?x?x?xf32>
}
// CHECK-LABEL: func @shape_from_if_to_branch_functions
func @shape_from_if_to_branch_functions(%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
%0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @if_else_branch, is_stateless = true, name = "if", then_branch = @if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
return %0 : tensor<1x2x3xf32>
// CHECK-LABEL: func @shape_from_if_to_branch_functions_to_results
// CHECK-SAME: (%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
func @shape_from_if_to_branch_functions_to_results(%arg0: tensor<i1>, %arg1: tensor<1x2x3xf32>) -> tensor<*xf32> {
%0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], else_branch = @if_else_branch, is_stateless = true, name = "if", then_branch = @if_then_branch} : (tensor<i1>, tensor<1x2x3xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @if_then_branch
View File
@ -20,8 +20,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
{T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE",
"tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE",
"tfdtype$DT_RESOURCE"], body = @while_body_7560,
cond = @while_cond_7550, device = "", is_stateless = false,
output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>]}
cond = @while_cond_7550, device = "", is_stateless = false}
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
@ -217,8 +216,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
{T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE",
"tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE",
"tfdtype$DT_RESOURCE"], body = @while_body_7560,
cond = @while_cond_7550, device = "", is_stateless = false,
output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>]}
cond = @while_cond_7550, device = "", is_stateless = false}
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
@ -305,8 +303,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
{T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE",
"tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE"],
body = @while_body_7560,
cond = @while_cond_7550, device = "", is_stateless = false,
output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>]}
cond = @while_cond_7550, device = "", is_stateless = false}
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
View File
@ -7,7 +7,7 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0"
%0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
%2 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%3:10 = "tf.While"(%2, %1, %2, %0, %1, %arg2, %arg4, %arg5, %arg6, %arg7) {_lower_using_switch_merge = true, _num_original_outputs = 10 : i64, _read_only_resource_inputs = [], body = @while_body_2710, cond = @while_cond_2700, device = "", is_stateless = false, output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>], parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.resource>, tensor<!tf.resource<tensor<7x7x3x64xf32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<i64>>>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.resource>, tensor<!tf.resource<tensor<7x7x3x64xf32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<i64>>>)
%3:10 = "tf.While"(%2, %1, %2, %0, %1, %arg2, %arg4, %arg5, %arg6, %arg7) {_lower_using_switch_merge = true, _num_original_outputs = 10 : i64, _read_only_resource_inputs = [], body = @while_body_2710, cond = @while_cond_2700, device = "", is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.resource>, tensor<!tf.resource<tensor<7x7x3x64xf32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<i64>>>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.resource>, tensor<!tf.resource<tensor<7x7x3x64xf32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<i64>>>)
return
}
// CHECK-LABEL: func @while_body_2710
View File
@ -209,6 +209,11 @@ def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)),
def RedundantReshape : Pat<(TF_ReshapeOp (TF_ReshapeOp $arg, $unused), $shape),
(TF_ReshapeOp $arg, $shape)>;
def IsSame : Constraint<CPred<"$0 == $1">>;
def ReshapeToSelfShape : Pat<(TF_ReshapeOp $arg0, (TF_ShapeOp $arg1)),
(replaceWithValue $arg0),
[(IsSame $arg0, $arg1)]>;
//===----------------------------------------------------------------------===//
// Select op patterns.
//===----------------------------------------------------------------------===//
View File
@ -19,15 +19,62 @@ limitations under the License.
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/UseDefLists.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
namespace mlir {
namespace tf_executor {
// Visits an op's operand if it is an output of an Operation in the same
// tf_executor.graph.
void VisitOpOperand(GraphOp graph, Value operand,
llvm::SmallPtrSetImpl<Operation*>* reachable_ops,
llvm::SmallVectorImpl<Operation*>* ops_to_visit) {
Operation* def = operand.getDefiningOp();
if (def && def->getParentOp() == graph && reachable_ops->insert(def).second) {
// Op has not been visited, add to queue to visit later.
ops_to_visit->push_back(def);
}
}
// Visits all operands of an op where each operand is an output of an Operation
// in the same tf_executor.graph.
void VisitOpOperands(GraphOp graph, Operation* op,
llvm::SmallPtrSetImpl<Operation*>* reachable_ops,
llvm::SmallVectorImpl<Operation*>* ops_to_visit) {
for (Value operand : op->getOperands())
VisitOpOperand(graph, operand, reachable_ops, ops_to_visit);
}
// Visits an op and its associated operands. IslandOps are handled differently:
// the operands of ops in an island's region are also visited, as values may be
// implicitly captured within. For a NextIterationSourceOp, its associated
// NextIterationSinkOp is visited as well.
void VisitOp(GraphOp graph, Operation* op,
llvm::SmallPtrSetImpl<Operation*>* reachable_ops,
llvm::SmallVectorImpl<Operation*>* ops_to_visit) {
if (auto island = llvm::dyn_cast<IslandOp>(op)) {
mlir::visitUsedValuesDefinedAbove(
island.body(), island.body(), [&](OpOperand* operand) {
VisitOpOperand(graph, operand->get(), reachable_ops, ops_to_visit);
});
}
VisitOpOperands(graph, op, reachable_ops, ops_to_visit);
// If op is a `tf_executor.NextIteration.Source`, visit its associated
// `tf_executor.NextIteration.Sink` op.
if (auto source_op = llvm::dyn_cast<NextIterationSourceOp>(op)) {
Operation* sink_op = source_op.GetSink().getOperation();
if (reachable_ops->insert(sink_op).second) ops_to_visit->push_back(sink_op);
}
}
// Prunes unreachable operations of a tf_executor.graph operation.
void PruneGraph(GraphOp graph) {
// A graph has a single block which forms a DAG: operations that aren't
@ -36,49 +83,23 @@ void PruneGraph(GraphOp graph) {
llvm::SmallPtrSet<Operation*, 8> reachable_ops;
llvm::SmallVector<Operation*, 8> ops_to_visit;
// Visit an op's operands if it is output of an Operation in same graph.
auto visit_op = [&](Operation* op) {
for (Value operand : op->getOperands()) {
Operation* def = operand.getDefiningOp();
if (def && def->getParentOp() == graph &&
reachable_ops.insert(def).second) {
// Op has not been visited, add to queue to visit later.
ops_to_visit.push_back(def);
}
}
};
// Visit `fetch` operands.
visit_op(graph.GetFetch());
// Visit fetches first to create a starting point for ops that are reachable.
reachable_ops.insert(graph.GetFetch());
VisitOpOperands(graph, graph.GetFetch(), &reachable_ops, &ops_to_visit);
// Visit transitive ops until there are no reachable ops left that have not
// been visited.
while (!ops_to_visit.empty()) {
Operation* op = ops_to_visit.pop_back_val();
if (llvm::isa<IslandOp>(op)) {
// Visit island and island inner ops operands.
op->walk([&](Operation* inner_op) { visit_op(inner_op); });
continue;
} else {
// Op is not an island, only visit its operands.
visit_op(op);
}
// If op is a `tf_executor.NextIteration.Source`, visit its associated
// `tf_executor.NextIteration.Sink` op.
if (auto source_op = llvm::dyn_cast<NextIterationSourceOp>(op)) {
Operation* sink_op = source_op.GetSink().getOperation();
if (reachable_ops.insert(sink_op).second) {
ops_to_visit.push_back(sink_op);
}
}
VisitOp(graph, op, &reachable_ops, &ops_to_visit);
}
// Erase unreachable ops in reverse order.
for (Operation& op : llvm::make_early_inc_range(
llvm::drop_begin(llvm::reverse(graph.GetBody()), 1))) {
if (reachable_ops.find(&op) == reachable_ops.end()) {
op.erase();
}
}
// Erase unreachable ops in reverse order so references don't need to be
// dropped before removing an op. Going in reverse order will guarantee that
// when an op to be erased is reached, there are no users left.
for (Operation& op :
llvm::make_early_inc_range(llvm::reverse(graph.GetBody())))
if (!reachable_ops.contains(&op)) op.erase();
}
namespace {
View File
@ -0,0 +1,134 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <numeric>
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TF {
namespace {
static constexpr int kTextFileIndex_WholeLine = -2;
static constexpr int kTextFileIndex_LineNumber = -1;
// InitTextFileToImportPass converts InitializeTableFromTextFileV2Op to the
// corresponding LookupTableImportV2Op if possible.
class InitTextFileToImportPass
: public mlir::PassWrapper<InitTextFileToImportPass, FunctionPass> {
public:
explicit InitTextFileToImportPass() {}
private:
void runOnFunction() override;
};
class ConvertInitializeTableFromTextFileV2
: public OpRewritePattern<InitializeTableFromTextFileV2Op> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(InitializeTableFromTextFileV2Op op,
PatternRewriter& rewriter) const override {
// Currently, this pattern matching only supports the following case, which is
// commonly used in inference use cases:
//
// tf.lookup.TextFileInitializer(
// "test.txt", tf.string, tf.lookup.TextFileIndex.WHOLE_LINE,
// tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, delimiter=" ")
//
// In the above case, the delimiter will not be used since the key is just the
// whole line and the value is the line number.
if (op.key_index() != kTextFileIndex_WholeLine ||
op.value_index() != kTextFileIndex_LineNumber ||
op.vocab_size() != -1) {
return failure();
}
// Try to find filename from constant op.
DenseStringElementsAttr filename_attr;
if (!matchPattern(op.filename().getDefiningOp(),
m_Constant(&filename_attr))) {
return failure();
}
StringRef filename = filename_attr.getRawStringData()[0];
// Read the content of the file.
std::string error_message;
auto file = openInputFile(filename, &error_message);
if (!file) {
return op.emitOpError("failed to open vocabulary file")
<< " (" << filename.str() << "): " << error_message;
}
// Splits into lines.
SmallVector<StringRef, 8> lines;
file->getBuffer().split(lines, "\n", -1, false);
// Map each line to line number, starting from zero.
SmallVector<int64_t, 8> line_nums;
line_nums.resize(lines.size());
std::iota(line_nums.begin(), line_nums.end(), 0);
// Create constant ops for keys and values.
Value key_constant_tensor = rewriter.create<ConstantOp>(
op.getLoc(),
DenseStringElementsAttr::get(
RankedTensorType::get(static_cast<int64_t>(lines.size()),
StringType::get(rewriter.getContext())),
lines));
Value value_constant_tensor = rewriter.create<ConstantOp>(
op.getLoc(), rewriter.getI64TensorAttr(line_nums));
// Replace the given op with LookupTableImportV2Op.
rewriter.create<LookupTableImportV2Op>(op.getLoc(), op.table_handle(),
key_constant_tensor,
value_constant_tensor);
rewriter.eraseOp(op);
return success();
}
};
void InitTextFileToImportPass::runOnFunction() {
OwningRewritePatternList patterns;
MLIRContext* context = &getContext();
FuncOp func = getFunction();
patterns.insert<ConvertInitializeTableFromTextFileV2>(context);
applyPatternsAndFoldGreedily(func, patterns);
}
} // namespace
// Replace InitializeTableFromTextFileV2Ops with LookupTableImportV2Ops.
std::unique_ptr<OperationPass<FuncOp>> CreateInitTextFileToImportPass() {
return std::make_unique<InitTextFileToImportPass>();
}
static PassRegistration<InitTextFileToImportPass> pass(
"tf-init-text-file-to-import",
"convert InitializeTableFromTextFileV2 ops to LookupTableImportV2Op to "
"remove the dependency on asset files");
} // namespace TF
} // namespace mlir
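A minimal sketch of how the new pass could be wired into a PassManager, mirroring the usage in the test pass that follows. The helper name and the assumption that a loaded ModuleOp and its MLIRContext are available come from this sketch, not from the change itself:
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
// Sketch only: run the new pass over an already-loaded module. `module` and
// `context` are assumed to be provided by the surrounding tool.
mlir::LogicalResult RunInitTextFileToImport(mlir::ModuleOp module,
                                            mlir::MLIRContext* context) {
  mlir::PassManager pm(context);
  // Rewrites InitializeTableFromTextFileV2 ops into LookupTableImportV2 ops.
  pm.addPass(mlir::TF::CreateInitTextFileToImportPass());
  return pm.run(module);
}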
View File
@ -0,0 +1,99 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/Support/Casting.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
namespace mlir {
namespace TF {
namespace {
// InitTextFileToImportTestPass generates a temporary file and runs
// InitTextFileToImportPass on it for testing purposes.
class InitTextFileToImportTestPass
: public mlir::PassWrapper<InitTextFileToImportTestPass,
OperationPass<ModuleOp>> {
public:
explicit InitTextFileToImportTestPass() {}
private:
void runOnOperation() override;
};
void InitTextFileToImportTestPass::runOnOperation() {
ModuleOp module = getOperation();
// Create a temporary vocab file.
int fd;
SmallString<256> filename;
std::error_code error_code =
llvm::sys::fs::createTemporaryFile("text", "vocab", fd, filename);
if (error_code) return signalPassFailure();
llvm::ToolOutputFile temp_file(filename, fd);
const char* dictionary_in_lines =
"apple\n"
"banana\n"
"grape";
temp_file.os() << dictionary_in_lines;
temp_file.os().flush();
// Replace filename constant ops to use the temporary file.
MLIRContext* context = &getContext();
for (FuncOp func : module.getOps<FuncOp>()) {
llvm::SmallVector<ConstantOp, 4> constant_ops(func.getOps<ConstantOp>());
for (auto op : constant_ops) {
ShapedType shaped_type =
RankedTensorType::get({1}, StringType::get(context));
DenseStringElementsAttr attr;
if (!matchPattern(op.getOperation(), m_Constant(&attr))) {
continue;
}
ArrayRef<StringRef> values = attr.getRawStringData();
if (values.size() != 1 || values[0] != "%FILE_PLACEHOLDER") {
continue;
}
op.valueAttr(DenseStringElementsAttr::get(shaped_type, {filename}));
}
}
// Run the lowering pass.
PassManager pm(context);
pm.addPass(CreateInitTextFileToImportPass());
if (failed(pm.run(module))) return signalPassFailure();
}
} // namespace
static PassRegistration<InitTextFileToImportTestPass> pass(
"tf-init-text-file-to-import-test",
"generate a temporary file and invoke InitTextFileToImportPass");
} // namespace TF
} // namespace mlir
View File
@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
namespace mlir {
@ -744,9 +745,7 @@ void LegalizeHloToTf::runOnFunction() {
// Add legalization patterns to the list.
OwningRewritePatternList patterns;
populateWithGenerated(&context, &patterns);
patterns.insert<ConvertConvOp, ConvertSliceOp, ConvertReduceOpToTfMax,
ConvertReduceOpToTfMin, ConvertReduceOpToTfSum>(&context);
PopulateLegalizeHloToTfPatterns(&patterns, &context);
ConversionTarget target(context);
target.addLegalDialect<TensorFlowDialect>();
@ -762,6 +761,13 @@ static PassRegistration<LegalizeHloToTf> pass(
} // end namespace
void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList *patterns,
MLIRContext *context) {
populateWithGenerated(context, patterns);
patterns->insert<ConvertConvOp, ConvertSliceOp, ConvertReduceOpToTfMax,
ConvertReduceOpToTfMin, ConvertReduceOpToTfSum>(context);
}
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass() {
return std::make_unique<LegalizeHloToTf>();
}
View File
@ -344,12 +344,56 @@ class LowerPackOp : public OpRewritePattern<TF::PackOp> {
}
};
// Lowers `TF::SparseMatMulOp` to `TF::MatMulOp`, ignoring the sparseness hints,
// since we currently don't have an implementation that can use this
// information. Adds appropriate casts where necessary to align element types
// of operands and result for `TF::MatMulOp`.
class LowerSparseMatMulOp : public OpRewritePattern<TF::SparseMatMulOp> {
public:
using OpRewritePattern<TF::SparseMatMulOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TF::SparseMatMulOp op,
PatternRewriter &rewriter) const override {
// Result type must be f32 for applying the pattern (currently this is
// required by the op anyway but this might change).
if (!op.product().getType().cast<TensorType>().getElementType().isF32()) {
return failure();
}
MLIRContext *context = rewriter.getContext();
llvm::SmallVector<Value, 2> operands{op.a(), op.b()};
for (Value &operand : operands) {
TensorType tensor_type = operand.getType().cast<TensorType>();
Type element_type = tensor_type.getElementType();
if (element_type.isF32()) continue;
// Element type can either be f32 or bf16 for `SparseMatMulOp` so it
// must be bf16 here.
assert(element_type.isBF16());
Type tensor_type_f32;
if (tensor_type.hasRank()) {
tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(),
FloatType::getF32(context));
} else {
tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context));
}
// Add cast to f32 to conform with element type of result.
operand =
rewriter.create<TF::CastOp>(op.getLoc(), tensor_type_f32, operand);
}
Value result = rewriter.create<TF::MatMulOp>(
op.getLoc(), op.product().getType(), operands[0], operands[1],
op.transpose_a(), op.transpose_b());
rewriter.replaceOp(op, {result});
return success();
}
};
} // namespace
void PopulateLoweringTFPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
patterns->insert<LowerAddNOp, LowerDynamicStitchOp, LowerInvertPermutationOp,
LowerPackOp>(context);
LowerPackOp, LowerSparseMatMulOp>(context);
populateWithGenerated(context, patterns);
}
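As a hedged usage sketch, the extended pattern list would typically be consumed with the same greedy-rewrite idiom used elsewhere in this change; the helper name and the FuncOp variable are illustrative, and the snippet assumes it sits where the declarations above are in scope:
// Sketch only: populate the TF lowering patterns (now including
// LowerSparseMatMulOp) and apply them greedily to a function.
void LowerTFOpsInFunc(mlir::FuncOp func) {
  mlir::OwningRewritePatternList patterns;
  PopulateLoweringTFPatterns(func.getContext(), &patterns);
  applyPatternsAndFoldGreedily(func, patterns);
}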
View File
@ -0,0 +1,58 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <utility>
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
namespace mlir {
namespace TFDevice {
namespace {
// This pass marks unsupported ops in a device cluster with the
// `_xla_outside_compilation` attribute so that the operations run on the host
// instead of the device. Unsupported ops are ops that cannot be code generated
// to run on the device for the cluster.
struct MarkOpsForOutsideCompilation
: public PassWrapper<MarkOpsForOutsideCompilation,
OperationPass<ModuleOp>> {
void runOnOperation() override;
};
void MarkOpsForOutsideCompilation::runOnOperation() {
auto module = getOperation();
module.walk([&](tf_device::ClusterOp cluster) {});
}
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
CreateMarkOpsForOutsideCompilationPass() {
return std::make_unique<MarkOpsForOutsideCompilation>();
}
static PassRegistration<MarkOpsForOutsideCompilation> pass(
"tf-mark-ops-for-outside-compilation",
"Marks unsupported ops a device cluster for outside compilation.");
} // namespace TFDevice
} // namespace mlir
View File
@ -18,6 +18,8 @@ limitations under the License.
#include <memory>
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
namespace mlir {
@ -148,6 +150,10 @@ CreateTensorArrayOpsDecompositionPass();
// Create a pass that legalize HLO to TF dialect.
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();
// Adds the HLO to TF rewrite patterns to the specified pattern list.
void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList* patterns,
MLIRContext* context);
// Matches sequence of ops to TensorFlow fused kernels. This pass should not be
// generally used beyond exporting to runtimes that supports these ops. In the
// future these fusions may be codegen'd automatically.
@ -155,6 +161,10 @@ std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass();
// Creates function pass to select device index/fold tf.DeviceIndex.
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
// Creates function pass to replace InitializeTableFromTextFileV2Ops with
// LookupTableImportV2Op ops.
std::unique_ptr<OperationPass<FuncOp>> CreateInitTextFileToImportPass();
} // namespace TF
namespace tf_executor {
@ -237,6 +247,11 @@ std::unique_ptr<OperationPass<FuncOp>> CreateParallelExecuteToIslandsPass();
std::unique_ptr<OperationPass<ModuleOp>>
CreateAnnotateParameterReplicationPass();
// Creates a pass that marks unsupported ops in a device cluster for outside
// compilation.
std::unique_ptr<OperationPass<ModuleOp>>
CreateMarkOpsForOutsideCompilationPass();
// Creates a pass that hoists a `tf_device.launch` body and assigns a `device`
// attribute to each TensorFlow dialect op in the body based on the `device`
// attribute on the `tf_device.launch`.
View File
@ -373,8 +373,7 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp(
OpBuilder builder(while_region);
auto while_op = builder.create<WhileOp>(
while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name,
builder.getArrayAttr({}), while_region.parallel_iterations(),
while_region.is_stateless());
while_region.parallel_iterations(), while_region.is_stateless());
// Redirect old results to new results.
for (auto it : llvm::zip(
View File
@ -627,8 +627,6 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
});
// Recreate the while op.
OpBuilder builder(while_op);
auto new_output_shapes = FilterRange<Attribute, ArrayRef<Attribute>>(
while_op.output_shapes().getValue(), resource_arg_uses);
// Now use the filtered original operands, which will be replaced by
// AddLoadsStoresOutsideControlFlowOp().
auto new_while = builder.create<TF::WhileOp>(
@ -636,8 +634,7 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
FilterRange<Value, OperandRange>(while_op.getOperands(),
resource_arg_uses),
while_op.getAttrs());
// Prepare for AddLoadsStoresOutsideControlFlowOp() and update
// new_output_shapes.
// Prepare for AddLoadsStoresOutsideControlFlowOp().
llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
arg_data_type_and_updated_output_index;
for (const auto& entry : remaining_resource_data_types) {
@ -647,14 +644,9 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
: entry.getFirst();
arg_data_type_and_updated_output_index[entry.getFirst()] = {
entry.getSecond(), update_index};
if (!new_output_shapes.empty()) {
new_output_shapes[entry.getFirst()] =
tensorflow::ConvertTypeToTensorShapeAttr(entry.getSecond());
}
}
AddLoadsStoresOutsideControlFlowOp(new_while,
arg_data_type_and_updated_output_index);
new_while.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes));
// Replace uses.
for (int64_t i = 0; i < old_to_new_indices.size(); ++i) {
if (old_to_new_indices[i] >= 0) {
View File
@ -262,22 +262,6 @@ bool InferShapeForCall(Operation* op) {
return changed;
}
// Infer the shape IfRegion outputs based on the shapes of the then and else
// yields.
bool InferShapeForIfRegion(IfRegionOp op) {
bool changed = false;
Operation* then_yield = op.then_branch().front().getTerminator();
Operation* else_yield = op.else_branch().front().getTerminator();
for (auto result : zip(op.getResults(), then_yield->getOperandTypes(),
else_yield->getOperandTypes())) {
// If then and else types do not match, skip refinement for that result.
if (std::get<1>(result) != std::get<2>(result)) continue;
changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
changed;
}
return changed;
}
bool InferShapeForCast(CastOp op, Dialect* tf_dialect) {
Value result = op.getResult();
if (!CanBeRefined(result.getType())) return false;
@ -306,6 +290,37 @@ bool InferShapeForCast(CastOp op, Dialect* tf_dialect) {
return true;
}
// Infer the shape of IfOp outputs based on the shapes of the then and else
// function result types.
bool InferShapeForIf(IfOp op) {
bool changed = false;
auto then_results = op.then_func().getType().getResults();
auto else_results = op.else_func().getType().getResults();
for (auto it : llvm::zip(op.getResults(), then_results, else_results)) {
// If then and else types do not match, skip refinement for that result.
if (std::get<1>(it) != std::get<2>(it)) continue;
changed = RefineResultType(op, std::get<0>(it), std::get<1>(it)) || changed;
}
return changed;
}
// Infer the shape of IfRegion outputs based on the shapes of the then and else
// yields.
bool InferShapeForIfRegion(IfRegionOp op) {
bool changed = false;
Operation* then_yield = op.then_branch().front().getTerminator();
Operation* else_yield = op.else_branch().front().getTerminator();
for (auto result : zip(op.getResults(), then_yield->getOperandTypes(),
else_yield->getOperandTypes())) {
// If then and else types do not match, skip refinement for that result.
if (std::get<1>(result) != std::get<2>(result)) continue;
changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
changed;
}
return changed;
}
bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti,
Dialect* tf_dialect) {
Operation* op = infer_ti.getOperation();
@ -768,17 +783,23 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
op))
return InferShapeForCall(op);
// Handle IfRegion operations by infering return shape from the then and else
// branches.
if (auto if_region = dyn_cast<IfRegionOp>(op))
return InferShapeForIfRegion(if_region);
// tf.Cast are only inferred if they have at least one user in the TF dialect
// or feeding into the function return. This is necessary to avoid inserting
// casts which cannot be refined.
if (auto cast_op = dyn_cast<CastOp>(op))
return InferShapeForCast(cast_op, tf_dialect_);
// Handle IfOp here by inferring the shape from the else/then function
// results. Since `output_shapes` is a derived attribute, avoid going down the
// TF InferenceContext path as IfOp shape inference is implemented as just
// a lookup of the output_shapes attribute.
if (auto if_op = dyn_cast<IfOp>(op)) return InferShapeForIf(if_op);
// Handle IfRegion operations by inferring return shape from the then and else
// branches.
if (auto if_region = dyn_cast<IfRegionOp>(op))
return InferShapeForIfRegion(if_region);
StringRef op_name = op->getName().getStringRef();
// Drop the `tf.` prefix to query TF registry.
auto node_name =
View File
@ -197,24 +197,16 @@ LogicalResult HandleWhileOp(
if (!signature_change) return success();
// Create the new while op.
auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
auto new_output_shapes =
llvm::to_vector<8>(while_op.output_shapes().getValue());
OpBuilder builder(while_op);
assert(while_op.getNumOperands() == while_op.getNumResults());
for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
auto it = data_var_to_size_var.find(while_op.getOperand(i));
if (it == data_var_to_size_var.end()) continue;
new_while_operands.push_back(it->getSecond());
if (!new_output_shapes.empty()) {
// Size is a scalar shape.
new_output_shapes.push_back(
mlir::TF::ShapeAttr::get(builder.getContext(), ArrayRef<int64_t>()));
}
}
auto new_while =
builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
new_while_operands, while_op.getAttrs());
new_while.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes));
for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
if (!getElementTypeOrSelf(while_op.getOperand(i).getType())
.isa<TF::ResourceType>()) {
View File
@ -595,8 +595,6 @@ LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module,
auto new_while =
builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
operands, while_op.getAttrs());
// Clear the output shapes as it is not needed for XLA lowering.
new_while.setAttr("output_shapes", builder.getArrayAttr({}));
for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
if (ta_arg_buffer_type(i)) {
while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
@ -663,8 +661,6 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module,
auto new_if = builder.create<TF::IfOp>(if_op.getLoc(),
then_branch.getType().getResults(),
operands, if_op.getAttrs());
// Clear the output shapes as it is not needed for XLA lowering.
new_if.setAttr("output_shapes", builder.getArrayAttr({}));
auto ret_forwards_input = [](FuncOp f, int64_t ret_ind) -> int64_t {
auto retval = f.front().getTerminator()->getOperand(ret_ind);
auto arg = retval.dyn_cast<BlockArgument>();
View File
@ -190,22 +190,14 @@ LogicalResult HandleWhileOp(
}
// Create the new while op.
auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
auto new_output_shapes =
llvm::to_vector<8>(while_op.output_shapes().getValue());
for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
auto it = buffer_to_size->find(while_op.getOperand(i));
if (it == buffer_to_size->end()) continue;
new_while_operands.push_back(it->getSecond().size);
if (!new_output_shapes.empty()) {
// Size is a scalar shape.
new_output_shapes.push_back(
mlir::TF::ShapeAttr::get(builder.getContext(), ArrayRef<int64_t>()));
}
}
auto new_while =
builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
new_while_operands, while_op.getAttrs());
new_while.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes));
for (const auto& entry : output_buffer_to_size) {
(*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = {
new_while.getResult(std::get<1>(entry)), std::get<2>(entry)};
View File
@ -365,16 +365,6 @@ TF::WhileOp AddStateVarsToWhileOp(TF::WhileOp while_op, FuncOp body,
while_op.getLoc(),
append_types(llvm::to_vector<4>(while_op.getResultTypes())),
new_while_operands, while_op.getAttrs());
if (new_while_op.output_shapes().size() != 0) {
auto new_output_shapes = llvm::to_vector<4>(new_while_op.output_shapes());
// VarHandleOp is a scalar shape resource.
for (int64_t i = 0; i < state_vars.size(); ++i) {
new_output_shapes.push_back(
mlir::TF::ShapeAttr::get(builder.getContext(), ArrayRef<int64_t>()));
}
new_while_op.setAttr("output_shapes",
builder.getArrayAttr(new_output_shapes));
}
while_op.replaceAllUsesWith(
new_while_op.getResults().take_front(while_op.getNumResults()));
while_op.erase();
View File
@ -1836,6 +1836,9 @@ Operation *AvgPoolDivideByCount(
return result;
}
Value GetAvgPoolInput(TF::AvgPoolOp op) { return op.value(); }
Value GetAvgPoolInput(TF::AvgPool3DOp op) { return op.input(); }
// Converts AvgPool op to HLO ReduceWindow op by setting appropriate window
// dimensions with add as the reduction function. The reduction result is
// then divided by the number of elements in the window.
@ -1846,8 +1849,9 @@ class ConvertAvgPoolOp : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
Value input_value = GetAvgPoolInput(op);
auto input_type =
op.value().getType().template dyn_cast<RankedTensorType>();
input_value.getType().template dyn_cast<RankedTensorType>();
if (!input_type) return failure();
// We will do accumulation first; use a larger bitwidth if suitable.
@ -1862,8 +1866,6 @@ class ConvertAvgPoolOp : public OpRewritePattern<OpTy> {
else
result_type = UnrankedTensorType::get(sum_element_type);
Value input_value = op.value();
// Convert if we need enlarge the element type's bitwidth.
if (input_element_type != sum_element_type)
input_value = rewriter.create<ConvertOp>(op.getLoc(), input_value,
@ -5398,50 +5400,6 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
}
};
// Converts `TF::SparseMatMulOp` to `TF::MatMulOp`, ignoring the sparseness
// hints, since we currently don't have an implementation that can use this
// information. Adds appropriate casts where necessary to align element types
// of operands and result for `TF::MatMulOp`.
class ConvertSparseMatMulOp : public OpRewritePattern<TF::SparseMatMulOp> {
public:
using OpRewritePattern<TF::SparseMatMulOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TF::SparseMatMulOp op,
PatternRewriter &rewriter) const override {
// Result type must be f32 for applying the pattern (currently this is
// required by the op anyway but this might change).
if (!op.product().getType().cast<TensorType>().getElementType().isF32()) {
return failure();
}
MLIRContext *context = rewriter.getContext();
llvm::SmallVector<Value, 2> operands{op.a(), op.b()};
for (Value &operand : operands) {
TensorType tensor_type = operand.getType().cast<TensorType>();
Type element_type = tensor_type.getElementType();
if (element_type.isF32()) continue;
// Element type can either be f32 or bf16 for `SparseMatMulOp` so it
// must be bf16 here.
assert(element_type.isBF16());
Type tensor_type_f32;
if (tensor_type.hasRank()) {
tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(),
FloatType::getF32(context));
} else {
tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context));
}
// Add cast to f32 to conform with element type of result.
operand =
rewriter.create<TF::CastOp>(op.getLoc(), tensor_type_f32, operand);
}
Value result = rewriter.create<TF::MatMulOp>(
op.getLoc(), op.product().getType(), operands[0], operands[1],
op.transpose_a(), op.transpose_b());
rewriter.replaceOp(op, {result});
return success();
}
};
// Emits debug information which includes the number of ops of each type which
// failed to legalize.
void EmitLegalizationErrors(Operation *op,
@ -5531,11 +5489,10 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
ConvertDynamicRangeOp, ConvertRangeOp, ConvertSelectV2Op,
ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSparseMatMulOp,
ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp,
ConvertStridedSliceGradOp, ConvertSumOp, ConvertTensorScatterUpdateOp,
ConvertTileOp, ConvertTopKV2Op, ConvertUnpackOp,
ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp,
ConvertRandomShuffleOp, ConvertXlaShardingOp,
ConvertXlaDynamicUpdateSliceOp>(op->getContext());
View File
@ -24,6 +24,7 @@ limitations under the License.
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
@ -223,18 +224,20 @@ static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
class Tf2XlaRewriter {
public:
static LogicalResult RewriteOp(Operation* op, OpBuilder& builder,
static LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter,
const std::string& device_type) {
Tf2XlaRewriter rewriter(op, builder, device_type);
return rewriter.LegalizeOp();
Tf2XlaRewriter tf2xla_rewriter(op, rewriter, device_type);
return tf2xla_rewriter.LegalizeOp();
}
private:
Tf2XlaRewriter(Operation* op, OpBuilder builder,
Tf2XlaRewriter(Operation* op, PatternRewriter& rewriter,
const std::string& device_type)
: op_(op),
device_type_(device_type),
hlo_builder_(op->getName().getStringRef().str(), builder, op->getLoc()),
rewriter_(rewriter),
hlo_builder_(op->getName().getStringRef().str(), rewriter_,
op->getLoc()),
context_(nullptr) {}
~Tf2XlaRewriter() {
@ -259,6 +262,7 @@ class Tf2XlaRewriter {
Operation* op_;
std::string device_type_;
PatternRewriter& rewriter_;
::xla::MlirHloBuilder hlo_builder_;
tensorflow::OpOrArgLocNameMapper name_mapper_;
@ -429,6 +433,8 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() {
// Replace uses of old results using the corresponding value after the
// lowering.
llvm::SmallVector<Value, 2> values;
values.reserve(op_->getNumResults());
for (int i = 0, e = op_->getNumResults(); i < e; i++) {
tensorflow::Tensor* output = op_context.mutable_output(i);
const tensorflow::XlaExpression* expr =
@ -442,10 +448,9 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() {
value =
hlo_builder_.create<mlir::TensorCastOp>(value, old_result.getType());
}
old_result.replaceAllUsesWith(value);
values.push_back(value);
}
op_->erase();
rewriter_.replaceOp(op_, values);
return success();
}
@ -529,6 +534,11 @@ static PassRegistration<LegalizeTF> pass(
} // end namespace
void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
OwningRewritePatternList& patterns) {
patterns.insert<Tf2XlaRewritePattern>(device_type.str());
}
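For reference, a hedged sketch of the intended call site: a caller fills its own pattern list before running a conversion. The device type string is a placeholder rather than a value from this change, and the declaration above is assumed to be in scope:
// Sketch only: mix the TF-to-XLA fallback patterns into a caller-owned list.
// "PLACEHOLDER_DEVICE" stands in for a real device type string.
mlir::OwningRewritePatternList patterns;
PopulateLegalizeTfWithTf2XlaPatterns(/*device_type=*/"PLACEHOLDER_DEVICE",
                                     patterns);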
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
llvm::StringRef device_type) {
return std::make_unique<LegalizeTF>(device_type);
View File
@ -18,6 +18,9 @@ limitations under the License.
#include <memory>
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
namespace mlir {
@ -41,6 +44,10 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
llvm::StringRef device_type);
/// Adds the TF-to-XLA-via-TF2XLA rewrite patterns to the pattern list.
void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
OwningRewritePatternList& patterns);
/// Lowers from TF dialect's control flow to HLO dialect's control flow.
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeTFControlFlowPass();
View File
@ -1333,6 +1333,7 @@ tf_xla_py_test(
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
"notap", # b/162025277
],
deps = [
":xla_test",
View File
@ -8,6 +8,7 @@ load(
"tf_cuda_tests_tags",
"tf_exec_properties",
)
load("//tensorflow:tensorflow.bzl", "py_test")
def all_backends():
b = ["cpu"] + plugins.keys()
@ -121,7 +122,7 @@ def tf_xla_py_test(
updated_name = updated_name[:-5]
updated_name += "_mlir_bridge_test"
native.py_test(
py_test(
name = updated_name,
srcs = srcs,
srcs_version = "PY2AND3",
View File
@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for while loops in XLA."""
"""Tests for case statements in XLA."""
from __future__ import absolute_import
from __future__ import division
View File
@ -45,6 +45,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/config.pb.h" // NOLINT
@ -332,7 +333,6 @@ void UpdateToEngineNode(const std::vector<EngineInfo>& infos,
Status CreateTRTNode(const ConversionParams& params,
const std::vector<EngineInfo>& infos, int pos,
int max_batch_size, Graph* graph,
nvinfer1::IGpuAllocator* alloc,
std::vector<Node*>* engine_nodes) {
const auto& info = infos.at(pos);
std::vector<tensorflow::TensorShapeProto> input_shape_protos;
@ -428,16 +428,30 @@ Status CreateTRTNode(const ConversionParams& params,
// Build the engine and get its serialized representation.
string segment_string;
if (info.engine_type == EngineInfo::EngineType::TRTStatic) {
std::pair<int, Allocator*> device_allocator =
GetDeviceAndAllocator(params, info);
int cuda_device_id = 0;
std::unique_ptr<TRTBaseAllocator> trt_allocator;
if (device_allocator.first >= 0) {
cuda_device_id = device_allocator.first;
trt_allocator.reset(new TRTDeviceAllocator(device_allocator.second));
} else {
// trt_allocator stays a nullptr, so cudamalloc will be used.
LOG_WARNING_WITH_PREFIX << "Can't identify the cuda device. Running on "
"device 0 and using cudamalloc as an allocator";
}
cudaSetDevice(cuda_device_id);
auto trt_logger = GetLoggerRegistry()->LookUp(params.trt_logger_name);
// Create static engine for fp32/fp16 mode.
// Create static engines with precision_mode fp32/fp16.
TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
// TODO(sami): What happens if 1st dim is not batch?
TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
info.segment_graph_def,
calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode,
max_batch_size, info.max_workspace_size_bytes, input_shapes, trt_logger,
alloc, /*calibrator=*/nullptr, &engine, info.use_calibration,
params.use_implicit_batch, /*convert_successfully=*/nullptr,
trt_allocator.get(), /*calibrator=*/nullptr, &engine,
info.use_calibration, params.use_implicit_batch,
/*convert_successfully=*/nullptr,
/*profile=*/nullptr));
TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
segment_string = string(static_cast<const char*>(engine_data->data()),
@ -793,13 +807,27 @@ Status ConvertAfterShapes(const ConversionParams& params) {
}
}
// Create a TRT node for each segment using its EngineInfo.
int old_cuda_device = 0;
auto err = cudaGetDevice(&old_cuda_device);
if (err != cudaSuccess) {
LOG(ERROR) << "Couldn't get current device: " << cudaGetErrorString(err);
// Save the current CUDA device in case we need to switch to another device to
// build static engines.
absl::optional<int> old_cuda_device = absl::nullopt;
if (!params.is_dyn_op) {
int cuda_device_id;
cudaError_t cuda_error = cudaGetDevice(&cuda_device_id);
if (cuda_error != cudaSuccess) {
LOG_WARNING_WITH_PREFIX << "Couldn't get current device: "
<< cudaGetErrorString(cuda_error);
} else {
VLOG(1) << "Current cuda device is " << cuda_device_id;
old_cuda_device = cuda_device_id;
}
}
VLOG(1) << "Current cuda device is " << old_cuda_device;
auto restore_cuda_device = gtl::MakeCleanup([old_cuda_device] {
if (old_cuda_device.has_value()) {
cudaSetDevice(old_cuda_device.value());
}
});
std::vector<Node*> engine_nodes;
engine_nodes.resize(engine_segments.size());
for (int i = 0; i < engine_segments.size(); ++i) {
@ -813,24 +841,8 @@ Status ConvertAfterShapes(const ConversionParams& params) {
2.0;
VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to "
<< engine.engine_name;
// The allocator is used to build the engine. The build and the built engine
// will be destroyed after we get the serialized engine string, so it's fine
// to use unique_ptr here.
std::unique_ptr<TRTBaseAllocator> alloc;
auto device_alloc = GetDeviceAndAllocator(params, engine);
int cuda_device_id = 0;
if (device_alloc.first >= 0) {
cuda_device_id = device_alloc.first;
alloc.reset(new TRTDeviceAllocator(device_alloc.second));
} else {
// Setting allocator as nullptr should get revert to the cudamalloc
LOG_WARNING_WITH_PREFIX
<< "Can't identify the cuda device. Running on device 0 ";
}
cudaSetDevice(cuda_device_id);
auto status =
CreateTRTNode(params, engine_segments, i, params.max_batch_size, &graph,
alloc.get(), &engine_nodes);
auto status = CreateTRTNode(params, engine_segments, i,
params.max_batch_size, &graph, &engine_nodes);
string msg = StrCat("segment ", i, " consisting of ",
converted_segments.at(i).size(), " nodes by ",
@ -859,7 +871,6 @@ Status ConvertAfterShapes(const ConversionParams& params) {
}
}
}
cudaSetDevice(old_cuda_device);
graph.ToGraphDef(params.output_graph_def);
VLOG(1) << "Returning from conversion";
return Status::OK();

View File

@ -1309,7 +1309,8 @@ std::vector<float> GetDataAsFloat(InputOutputData& data) {
class OpConverterTest : public ::testing::Test {
public:
OpConverterTest()
: scope_(Scope::NewRootScope()), allocator_(new GpuManagedAllocator()) {
: tensor_buffer_allocator_(new GpuManagedAllocator()),
scope_(Scope::NewRootScope()) {
QCHECK_EQ(0, cudaStreamCreate(&stream_));
Reset();
}
@ -1341,7 +1342,7 @@ class OpConverterTest : public ::testing::Test {
// Constructs a flat tensor with 'vals' in Unified Memory.
template <typename T>
Tensor AsTensor(gtl::ArraySlice<T> vals) { // non-absl ok
Tensor ret(allocator_.get(), DataTypeToEnum<T>::value,
Tensor ret(tensor_buffer_allocator_.get(), DataTypeToEnum<T>::value,
{static_cast<int64>(vals.size())});
std::copy_n(vals.data(), vals.size(), ret.flat<T>().data());
return ret;
@ -1351,7 +1352,7 @@ class OpConverterTest : public ::testing::Test {
template <typename T>
Tensor AsTensor(gtl::ArraySlice<T> vals, // non-absl ok
const TensorShape& shape) {
Tensor ret(allocator_.get(), DataTypeToEnum<T>::value,
Tensor ret(tensor_buffer_allocator_.get(), DataTypeToEnum<T>::value,
{static_cast<int64>(vals.size())});
CHECK(ret.CopyFrom(AsTensor(vals), shape));
return ret;
@ -1363,7 +1364,8 @@ class OpConverterTest : public ::testing::Test {
template <typename T>
Tensor AsTensor(std::vector<T> vals, const std::vector<int> input_dims,
DataType tf_type) {
Tensor ret(allocator_.get(), tf_type, {static_cast<int64>(vals.size())});
Tensor ret(tensor_buffer_allocator_.get(), tf_type,
{static_cast<int64>(vals.size())});
if (tf_type == DT_FLOAT) {
auto conv_vals = CastTestVector<T, float>(vals);
std::copy_n(conv_vals.data(), conv_vals.size(), ret.flat<float>().data());
@ -1646,13 +1648,15 @@ class OpConverterTest : public ::testing::Test {
Logger logger_;
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
cudaStream_t stream_;
// Used to create placeholders with shape and data type information. The
// created placeholders will be used as inputs to the node to be verified,
// thus we need the shape and data type information to get a non-empty
// GraphProperties.
std::unique_ptr<Allocator> tensor_buffer_allocator_;
// The scope that contains the graph being converted. Because
// tensor_buffer_allocator_ provides the storage for tensor contents that are
// represented as attributes for graph nodes within scope_,
// tensor_buffer_allocator_ needs to be available when destructing scope_.
// Therefore, scope_ comes after tensor_buffer_allocator_ in the class member
// field list.
Scope scope_;
std::unordered_map<string, Output> node_inputs_;
std::unique_ptr<Allocator> allocator_;
};
// General test parameters to be used with ops that take a single input tensor.

View File

@ -160,17 +160,15 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
XlaCompiler* compiler = ctx->compiler();
std::vector<XlaCompiler::CompilationResult> branch_results(num_branches);
std::vector<XlaCompiler::CompilationResult*> branch_results_p(num_branches);
for (int j = 0; j < num_branches; ++j) {
OP_REQUIRES_OK(ctx,
compiler->CompileFunction(options, branches[j], arguments,
&branch_results[j]));
branch_results_p[j] = &branch_results[j];
}
bool has_tensor_array_gradients = false;
for (XlaCompiler::CompilationResult* result : branch_results_p) {
for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) {
for (XlaCompiler::CompilationResult& result : branch_results) {
for (const XlaCompiler::ResourceUpdate& update : result.resource_updates) {
XlaResource* resource;
OP_REQUIRES_OK(ctx,
ctx->GetResourceInput(update.input_index + 1, &resource));

View File

@ -47,6 +47,122 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
}
}
// Populates tensor array gradients for the compiled branches and returns
// whether the set of found tensor array gradients is non-empty.
static xla::StatusOr<bool> PopulateTensorArrayGradients(
XlaOpKernelContext* ctx, xla::XlaBuilder* b,
absl::Span<XlaCompiler::Argument> arguments,
XlaCompiler::CompilationResult* then_result,
XlaCompiler::CompilationResult* else_result) {
bool has_tensor_array_gradients = false;
for (XlaCompiler::CompilationResult* result : {then_result, else_result}) {
for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) {
XlaResource* resource;
TF_RETURN_IF_ERROR(
ctx->GetResourceInput(update.input_index + 1, &resource));
XlaCompiler::Argument& arg = arguments[update.input_index];
// Add any TensorArray gradients touched by the then/else computation to
// the enclosing graph.
for (const string& grad_source : update.tensor_array_gradients_accessed) {
VLOG(5) << "TensorArray " << resource->name() << " accessed gradient "
<< grad_source;
XlaResource* gradient;
TF_RETURN_IF_ERROR(resource->GetOrCreateTensorArrayGradient(
grad_source, b, &gradient));
}
// Add all of the TensorArray gradients to the argument. For simplicity,
// we always pass all known gradients.
for (const auto& gradient : resource->tensor_array_gradients()) {
arg.tensor_array_gradients.insert(gradient.first);
}
if (!resource->tensor_array_gradients().empty())
has_tensor_array_gradients = true;
}
}
return has_tensor_array_gradients;
}
// Checks that shapes match on both sides of the conditional.
static Status ValidateShapes(
XlaOpKernelContext* ctx, const XlaCompiler::CompilationResult& then_result,
const XlaCompiler::CompilationResult& else_result) {
// Check that both branches have identical input shapes.
if (then_result.xla_input_shapes.size() != 1) {
return errors::FailedPrecondition("Expected one input shape");
}
xla::Shape then_input_shape = then_result.xla_input_shapes[0];
if (!then_input_shape.IsTuple()) {
return errors::FailedPrecondition("Expected tuple shape");
}
if (else_result.xla_input_shapes.size() != 1) {
return errors::FailedPrecondition("Expected one input shape");
}
xla::Shape else_input_shape = else_result.xla_input_shapes[0];
if (!else_input_shape.IsTuple()) {
return errors::FailedPrecondition("Expected tuple shape");
}
if (!xla::ShapeUtil::Compatible(then_input_shape, else_input_shape)) {
return errors::InvalidArgument(
"Input shapes of then and else branches do not match: ",
xla::ShapeUtil::HumanString(then_input_shape), " vs. ",
xla::ShapeUtil::HumanString(else_input_shape));
}
// Check that both branches have identical output shapes.
if (!xla::ShapeUtil::Compatible(then_result.xla_output_shape,
else_result.xla_output_shape)) {
return errors::InvalidArgument(
"Output shapes of then and else branches do not match: ",
xla::ShapeUtil::HumanString(then_result.xla_output_shape), " vs. ",
xla::ShapeUtil::HumanString(else_result.xla_output_shape));
}
// Check that both branches have the same TensorList output indices.
for (int output_index = 0; output_index < then_result.outputs.size();
output_index++) {
bool is_tensor_list_in_then_branch =
then_result.outputs[output_index].is_tensor_list;
bool is_tensor_list_in_else_branch =
else_result.outputs[output_index].is_tensor_list;
if (is_tensor_list_in_then_branch != is_tensor_list_in_else_branch) {
return errors::FailedPrecondition(
"Output #", output_index, " is ",
(is_tensor_list_in_then_branch ? "" : "not"),
" a TensorList in then branch, but is ",
(is_tensor_list_in_else_branch ? "" : "not"),
" a TensorList in else branch");
}
}
VLOG(2) << "Input shape: " << xla::ShapeUtil::HumanString(then_input_shape);
VLOG(2) << "Output shape: "
<< xla::ShapeUtil::HumanString(then_result.xla_output_shape);
// We set return_updated_values_for_all_resources=true and we pass the same
// arguments to both computations, so the resource update count must match.
if (then_result.resource_updates.size() !=
else_result.resource_updates.size()) {
return errors::FailedPrecondition(
"Different number of resources in then and else branch");
}
for (int i = 0; i < then_result.resource_updates.size(); ++i) {
const auto& lhs = then_result.resource_updates[i];
const auto& rhs = else_result.resource_updates[i];
bool equal = lhs.input_index == rhs.input_index && lhs.shape == rhs.shape &&
lhs.tensor_array_gradients_accessed ==
rhs.tensor_array_gradients_accessed;
if (!equal) {
return errors::FailedPrecondition(
"Mismatch in resource of then and else branch for resource ", i);
}
}
return Status::OK();
}
// TODO(b/35949885): There is duplication here with the handling of the
// while_op. Refactor the common code out/rework.
void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
@ -137,35 +253,12 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_,
arguments, &else_result));
bool has_tensor_array_gradients = false;
for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) {
for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) {
XlaResource* resource;
OP_REQUIRES_OK(ctx,
ctx->GetResourceInput(update.input_index + 1, &resource));
XlaCompiler::Argument& arg = arguments[update.input_index];
// Add any TensorArray gradients touched by the then/else computation to
// the enclosing graph.
for (const string& grad_source : update.tensor_array_gradients_accessed) {
VLOG(5) << "TensorArray " << resource->name() << " accessed gradient "
<< grad_source;
XlaResource* gradient;
OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient(
grad_source, b, &gradient));
}
// Add all of the TensorArray gradients to the argument. For simplicity,
// we always pass all known gradients.
for (const auto& gradient : resource->tensor_array_gradients()) {
arg.tensor_array_gradients.insert(gradient.first);
}
if (!resource->tensor_array_gradients().empty())
has_tensor_array_gradients = true;
}
}
xla::StatusOr<bool> has_tensor_array_gradients = PopulateTensorArrayGradients(
ctx, b, absl::MakeSpan(arguments), &then_result, &else_result);
OP_REQUIRES_OK(ctx, has_tensor_array_gradients.status());
// Recompile the functions to update the argument shapes for tensor arrays.
if (has_tensor_array_gradients) {
if (*has_tensor_array_gradients) {
then_result = {};
OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_,
arguments, &then_result));
@ -174,72 +267,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
arguments, &else_result));
}
// Check that both branches have identical input shapes.
OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1,
errors::FailedPrecondition("Expected one input shape"));
xla::Shape then_input_shape = then_result.xla_input_shapes[0];
OP_REQUIRES(ctx, then_input_shape.IsTuple(),
errors::FailedPrecondition("Expected tuple shape"));
OP_REQUIRES(ctx, else_result.xla_input_shapes.size() == 1,
errors::FailedPrecondition("Expected one input shape"));
xla::Shape else_input_shape = else_result.xla_input_shapes[0];
OP_REQUIRES(ctx, else_input_shape.IsTuple(),
errors::FailedPrecondition("Expected tuple shape"));
OP_REQUIRES(ctx,
xla::ShapeUtil::Compatible(then_input_shape, else_input_shape),
errors::InvalidArgument(
"Input shapes of then and else branches do not match: ",
xla::ShapeUtil::HumanString(then_input_shape), " vs. ",
xla::ShapeUtil::HumanString(else_input_shape)));
// Check that both branches have identical output shapes.
OP_REQUIRES(
ctx,
xla::ShapeUtil::Compatible(then_result.xla_output_shape,
else_result.xla_output_shape),
errors::InvalidArgument(
"Output shapes of then and else branches do not match: ",
xla::ShapeUtil::HumanString(then_result.xla_output_shape), " vs. ",
xla::ShapeUtil::HumanString(else_result.xla_output_shape)));
// Check that both branches have same TensorList output indices.
for (int output_index = 0; output_index < then_result.outputs.size();
output_index++) {
bool is_tensor_list_in_then_branch =
then_result.outputs[output_index].is_tensor_list;
bool is_tensor_list_in_else_branch =
else_result.outputs[output_index].is_tensor_list;
OP_REQUIRES(
ctx, is_tensor_list_in_then_branch == is_tensor_list_in_else_branch,
errors::FailedPrecondition("Output #", output_index, " is ",
(is_tensor_list_in_then_branch ? "" : "not"),
" a TensorList in then branch, but is ",
(is_tensor_list_in_else_branch ? "" : "not"),
" a TensorList in else branch"));
}
VLOG(2) << "Input shape: " << xla::ShapeUtil::HumanString(then_input_shape);
VLOG(2) << "Output shape: "
<< xla::ShapeUtil::HumanString(then_result.xla_output_shape);
// We set return_updated_values_for_all_resources=true and we pass the same
// arguments to both computations, so the resource update count must match.
OP_REQUIRES(ctx,
then_result.resource_updates.size() ==
else_result.resource_updates.size(),
errors::FailedPrecondition(
"Different number of resources in then and else branch"));
for (int i = 0; i < then_result.resource_updates.size(); ++i) {
const auto& lhs = then_result.resource_updates[i];
const auto& rhs = else_result.resource_updates[i];
bool equal = lhs.input_index == rhs.input_index && lhs.shape == rhs.shape &&
lhs.tensor_array_gradients_accessed ==
rhs.tensor_array_gradients_accessed;
OP_REQUIRES(
ctx, equal,
errors::FailedPrecondition(
"Mismatch in resource of then and else branch for resource ", i));
}
OP_REQUIRES_OK(ctx, ValidateShapes(ctx, then_result, else_result));
int num_inputs = then_result.input_mapping.size();
std::vector<xla::XlaOp> inputs(num_inputs);
@ -263,22 +291,18 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
}
}
auto input_tuple = xla::Tuple(b, inputs);
xla::XlaOp input_tuple = xla::Tuple(b, inputs);
xla::XlaOp outputs =
xla::Conditional(ctx->Input(0), input_tuple, *then_result.computation,
input_tuple, *else_result.computation);
// Sets non-variable outputs.
for (int i = 0; i < output_types_.size(); ++i) {
xla::XlaOp output_handle = xla::GetTupleElement(outputs, i);
if (VLOG_IS_ON(2)) {
LOG(INFO) << "Setting output " << i;
auto shape_or = b->GetShape(output_handle);
if (shape_or.ok()) {
LOG(INFO) << "Shape for output " << i << ": "
<< xla::ShapeUtil::HumanString(shape_or.ValueOrDie());
} else {
LOG(INFO) << "Shape unknown for output " << i;
}
xla::StatusOr<xla::Shape> shape = b->GetShape(output_handle);
VLOG(2) << "Setting output " << i << " with shape "
<< (shape.ok() ? shape->ToString() : "<unknown>");
}
// We have checked that both branches have the same TensorList output indices.
if (then_result.outputs[i].is_tensor_list) {
@ -287,6 +311,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
ctx->SetOutput(i, output_handle);
}
}
if (has_token_input_output_) {
// Set token output for this "If" op. Token output is the last output of
// XLA computation, which comes after all "normal" TF outputs and resource

View File

@ -30,8 +30,15 @@ class ShardingOp : public XlaOpKernel {
~ShardingOp() override = default;
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp input = ctx->Input(0);
auto shape_or = ctx->InputXlaShape(0);
xla::XlaOp input;
{
// The builder might create a broadcast from a constant, so we clear
// sharding for the input.
xla::XlaScopedShardingAssignment no_sharding(ctx->builder(),
absl::nullopt);
input = ctx->Input(0);
}
auto shape_or = ctx->builder()->GetShape(input);
OP_REQUIRES_OK(ctx, shape_or.status());
ctx->SetOutput(

View File

@ -28,6 +28,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -415,8 +416,11 @@ sharding = gen_xla_ops.xla_sharding
@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
del op # Unused
return [grad]
grad_sharding = gen_xla_ops.xla_sharding(grad)
# pylint: disable=protected-access
grad_sharding.op._set_attr(
"_XlaSharding", attr_value_pb2.AttrValue(s=op.get_attr("_XlaSharding")))
return [grad_sharding]
spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape

View File

@ -1030,6 +1030,11 @@ Status XlaCompiler::BuildArguments(
xla::XlaScopedShardingAssignment assign_sharding(
builder, it == arg_shardings.end() ? absl::optional<xla::OpSharding>()
: it->second);
auto& arg = args[input_to_args->at(i)];
xla::OpMetadata arg_metadata;
arg_metadata.set_op_name(arg.node_name);
builder->SetOneShotOpMetadata(arg_metadata);
arg_handles[i] = xla::GetTupleElement(tuple, i);
}
} else {

View File

@ -3021,7 +3021,12 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
instr.add_operand_ids(operand.handle());
}
*instr.mutable_metadata() = metadata_;
if (one_shot_metadata_.has_value()) {
*instr.mutable_metadata() = one_shot_metadata_.value();
one_shot_metadata_.reset();
} else {
*instr.mutable_metadata() = metadata_;
}
if (sharding_) {
*instr.mutable_sharding() = *sharding_;
}

View File

@ -153,6 +153,11 @@ class XlaBuilder {
// OpMetadata attached until a call to ClearOpMetadata.
void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); }
// Similar to SetOpMetadata, but sets the metadata only for the next op.
void SetOneShotOpMetadata(OpMetadata metadata) {
one_shot_metadata_ = std::move(metadata);
}
// Clears the HloMetadata state.
void ClearOpMetadata() { metadata_.Clear(); }
@ -842,6 +847,9 @@ class XlaBuilder {
// throughout the TensorFlow op kernel implementations).
OpMetadata metadata_;
// A temporary metadata that will only be applied to the next op created.
absl::optional<OpMetadata> one_shot_metadata_;
// Sharding for this operator. This is structured as a "model"-like operation,
// in order to simplify client code, similar to metadata_.
absl::optional<OpSharding> sharding_;

View File

@ -17,6 +17,8 @@ upper_tabs:
path: /xla
- title: XLA architecture
path: /xla/architecture
- title: Known issues
path: /xla/known_issues
- title: Broadcasting semantics
path: /xla/broadcasting
- title: Develop a new backend for XLA

View File

@ -177,30 +177,6 @@ a bug to a single XLA program by using the
[`replay_computation`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/tools/run_hlo_module_main.cc)
and iteratively running it on generated programs.
## Known Issues
Compilation with XLA can greatly improve the performance of your programs, but
the TensorFlow interop has a number of known sharp corners.
### TensorArray TF/XLA Interconversion
The problem manifests itself as an error message
`Support for TensorList crossing the XLA/TF boundary is not implemented`.
XLA supports `tf.TensorArray`. However, the _interconversion_ between TF and
XLA representations is not implemented yet.
This error often arises when the `TensorArray` is used inside the compiled
block, but the derivative is taken outside.
Workaround: compile the outermost scope which is taking the derivative.
### Random Number Generation
XLA currently ignores TF seeds to random operations. This affects stateful TF
random operations, such as `tf.random.normal`, or `tf.nn.dropout`. XLA will
behave as if the compilation was seeded with a new unique seed at each run. This
limitation does not apply to stateless random ops.
## XLA Frontends
Apart from TensorFlow, XLA programs can be generated by:

View File

@ -0,0 +1,32 @@
# Known Issues
Compilation with XLA can greatly improve the performance of your programs, but
the TensorFlow interop has a number of known sharp corners.
## TensorArray TF/XLA interconversion
The problem manifests itself as an error message
`Support for TensorList crossing the XLA/TF boundary is not implemented`.
XLA supports `tf.TensorArray`. However, the _interconversion_ between TF and
XLA representations is not implemented yet.
This error often arises when the `TensorArray` is used inside the compiled
block, but the derivative is taken outside.
Workaround: compile the outermost scope which is taking the derivative.
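A minimal sketch of this workaround, assuming TF 2.x and the `experimental_compile=True` flag on `tf.function` (the flag name and exact behavior may vary between releases); the whole gradient computation is placed inside the compiled scope so the `TensorArray` never crosses the XLA/TF boundary:
```python
import tensorflow as tf

@tf.function(experimental_compile=True)  # compile the outermost scope
def loss_and_grad(x):
  with tf.GradientTape() as tape:
    tape.watch(x)
    ta = tf.TensorArray(tf.float32, size=3)
    for i in tf.range(3):
      ta = ta.write(i, x * tf.cast(i, tf.float32))
    loss = tf.reduce_sum(ta.stack())
  # The derivative is taken inside the compiled function, not outside it.
  return loss, tape.gradient(loss, x)

loss, grad = loss_and_grad(tf.constant(2.0))
```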
## Dynamic `tf.TensorArray` is not supported
Writes into `tf.TensorArray(..., dynamic_size=True)` are not compilable with
XLA, as such writes require an unknown number of reallocations when the array
exceeds the original bound.
Workaround: provide a statically known bound to your arrays.
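As an illustration, a hedged sketch of this workaround; `BOUND` is a made-up example value standing in for whatever static bound your program can guarantee:
```python
import tensorflow as tf

BOUND = 16  # statically known upper bound (example value)

@tf.function(experimental_compile=True)
def fill(x):
  # dynamic_size=False (the default) keeps the array XLA-compilable.
  ta = tf.TensorArray(tf.float32, size=BOUND, dynamic_size=False)
  for i in tf.range(BOUND):
    ta = ta.write(i, x + tf.cast(i, tf.float32))
  return ta.stack()

print(fill(tf.constant(1.0)).shape)  # (16,)
```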
## Random number generation
XLA currently ignores TF seeds to random operations. This affects stateful TF
random operations, such as `tf.random.normal`, or `tf.nn.dropout`. XLA will
behave as if the compilation was seeded with a new unique seed at each run. This
limitation does not apply to stateless random ops.
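Where run-to-run reproducibility matters under XLA, a sketch of switching to a stateless op, whose result is fully determined by an explicit seed pair (the seed values below are arbitrary):
```python
import tensorflow as tf

@tf.function(experimental_compile=True)
def noisy(x):
  # tf.random.normal(..., seed=42) would be stateful; XLA ignores its seed.
  # The stateless variant takes the seed explicitly and is reproducible.
  noise = tf.random.stateless_normal(tf.shape(x), seed=[1, 2])
  return x + noise

print(noisy(tf.zeros([3])))
```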

View File

@ -114,24 +114,26 @@ void BuildOpsSubmodule(py::module* m) {
"CustomCall",
[](XlaBuilder* builder, const py::bytes& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
const py::bytes& opaque) -> XlaOp {
return CustomCall(builder, call_target_name, operands, shape, opaque);
const py::bytes& opaque, bool has_side_effect) -> XlaOp {
return CustomCall(builder, call_target_name, operands, shape, opaque,
has_side_effect);
},
py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
py::arg("shape"), py::arg("opaque") = py::bytes(""));
py::arg("shape"), py::arg("opaque") = py::bytes(""),
py::arg("has_side_effect") = false);
ops.def(
"CustomCallWithLayout",
[](XlaBuilder* builder, const py::bytes& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
absl::Span<const Shape> operand_shapes_with_layout,
const py::bytes& opaque) -> XlaOp {
return CustomCallWithLayout(builder, call_target_name, operands,
shape_with_layout,
operand_shapes_with_layout, opaque);
const py::bytes& opaque, bool has_side_effect) -> XlaOp {
return CustomCallWithLayout(
builder, call_target_name, operands, shape_with_layout,
operand_shapes_with_layout, opaque, has_side_effect);
},
py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
py::arg("opaque") = py::bytes(""));
py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false);
ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
py::arg("precision_config") = nullptr);
ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),

View File

@ -2475,6 +2475,33 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
}
}
{
HloInstruction *a, *b, *c1, *c2;
// Mul(Mul(x, constant1), Mul(y, constant2)) => Mul(Mul(x, y),
// constant1*constant2)
if (Match(multiply,
m::Multiply(
m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)),
m::MultiplyAnyOrder(m::NonConstant(&b), m::Constant(&c2))))) {
TF_ASSIGN_OR_RETURN(auto* product_of_constants,
MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
!ShapeUtil::IsScalar(multiply->shape())) {
product_of_constants =
computation_->AddInstruction(HloInstruction::CreateBroadcast(
multiply->shape(), product_of_constants, {}));
}
return ReplaceWithNewInstruction(
multiply,
HloInstruction::CreateBinary(
multiply->shape(), HloOpcode::kMultiply,
computation_->AddInstruction(HloInstruction::CreateBinary(
multiply->shape(), HloOpcode::kMultiply, a, b)),
product_of_constants));
}
}
VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]";
HloInstruction *a, *c1, *c2;
if (Match(multiply,
@ -3088,6 +3115,17 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
HloOpcode::kMultiply, lhs, lhs));
}
// Pow(A, 3) is used in GELU.
VLOG(10) << "trying transform [pow(A, 3) => A*A*A]: " << power->ToString();
if (IsAll(rhs, 3)) {
HloInstruction* tmp =
computation_->AddInstruction(HloInstruction::CreateBinary(
power->shape(), HloOpcode::kMultiply, lhs, lhs));
return ReplaceWithNewInstruction(
power, HloInstruction::CreateBinary(power->shape(),
HloOpcode::kMultiply, lhs, tmp));
}
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
if (IsAll(rhs, -1)) {
return ReplaceWithNewInstruction(

View File

@ -117,6 +117,29 @@ TEST_F(AlgebraicSimplifierTest, FactorFpAddition) {
m::ConstantScalar(0.125))));
}
// (A*C1) * (B*C2) => (A*B)*(C1*C2)
TEST_F(AlgebraicSimplifierTest, MultiplyChain) {
const char* kModuleStr = R"(
HloModule m
test {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
c = f32[] constant(2)
d = f32[] constant(4)
x = f32[] multiply(p0, c)
y = f32[] multiply(p1, d)
ROOT z = f32[] multiply(x, y)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(
m->entry_computation()->root_instruction(),
GmockMatch(m::MultiplyAnyOrder(
m::MultiplyAnyOrder(m::Parameter(0), m::Parameter(1)),
m::MultiplyAnyOrder(m::ConstantScalar(2), m::ConstantScalar(4)))));
}
// A*C + B*C => (A+B)*C if C is a broadcast of a floating-point power of 2.
TEST_F(AlgebraicSimplifierTest, FactorFpAdditionWithBroadcast) {
const char* kModuleStr = R"(
@ -1568,6 +1591,32 @@ TEST_F(AlgebraicSimplifierTest, Pow2) {
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0))));
}
// Test that pow(A, 3) is simplified to A*A*A.
TEST_F(AlgebraicSimplifierTest, Pow3) {
auto m = CreateNewVerifiedModule();
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
HloComputation::Builder builder(TestName());
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* three = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, three));
auto computation = m->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
GmockMatch(m::Power(m::Parameter(0), m::Op().Is(three))));
AlgebraicSimplifier simplifier(default_options_);
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
EXPECT_THAT(
computation->root_instruction(),
GmockMatch(m::Multiply(m::Parameter(0),
m::Multiply(m::Parameter(0), m::Parameter(0)))));
}
// Test that pow(A, -1) is simplified to 1/A.
TEST_F(AlgebraicSimplifierTest, PowNegative1) {
auto m = CreateNewVerifiedModule();

View File

@ -270,11 +270,48 @@ Status DotOpEmitter::EmitLinalgMatmul() {
return EmitMlirFuncAndCall(
mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr,
operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) {
CHECK_EQ(dot_info_.dim_nums.lhs_contracting_dimensions_size(), 1);
CHECK_EQ(dot_info_.dim_nums.rhs_contracting_dimensions_size(), 1);
mlir::MLIRContext* context = builder->getContext();
mlir::edsc::ScopedContext scope(*builder, function.getLoc());
mlir::Value a = function.getArgument(0), b = function.getArgument(1),
c = function.getArgument(2);
mlir::edsc::intrinsics::linalg_matmul(mlir::TypeRange{},
mlir::ValueRange{b, c, a});
llvm::SmallVector<mlir::AffineExpr, 2> b_exprs(
dot_info_.lhs_shape.rank());
llvm::SmallVector<mlir::AffineExpr, 2> c_exprs(
dot_info_.rhs_shape.rank());
llvm::SmallVector<mlir::AffineExpr, 2> parallel_exprs;
mlir::AffineExpr reduce_expr;
for (int i = 0; i != dot_info_.result_shape.rank(); ++i) {
parallel_exprs.push_back(mlir::getAffineDimExpr(i, context));
}
reduce_expr =
mlir::getAffineDimExpr(dot_info_.result_shape.rank(), context);
// The reduction expr is shared for both inputs.
b_exprs[dot_info_.dim_nums.lhs_contracting_dimensions(0)] = reduce_expr;
c_exprs[dot_info_.dim_nums.rhs_contracting_dimensions(0)] = reduce_expr;
// Fill in the remaining parallel exprs.
int par_expr_num = 0;
for (auto* v : {&b_exprs, &c_exprs}) {
for (auto& e : *v) {
if (!e) {
e = parallel_exprs[par_expr_num++];
}
}
}
llvm::SmallVector<mlir::IteratorType, 4> types(
parallel_exprs.size(), mlir::IteratorType::Parallel);
types.push_back(mlir::IteratorType::Reduction);
mlir::edsc::StructuredIndexed s_a(a), s_b(b), s_c(c);
mlir::edsc::makeGenericLinalgOp(types, {s_b(b_exprs), s_c(c_exprs)},
{s_a(parallel_exprs)},
mlir::edsc::ops::macRegionBuilder);
mlir::edsc::intrinsics::std_ret();
mlir::linalg::LinalgTilingOptions tilingOptions;
@ -283,13 +320,13 @@ Status DotOpEmitter::EmitLinalgMatmul() {
target_machine_features_.minimum_alignment_for_allocation(
ShapeUtil::ByteSizeOf(dot_info_.result_shape));
mlir_strategy::MatmulCodegenStrategy strategy;
strategy.tile<mlir::linalg::MatmulOp>(tilingOptions)
.promote<mlir::linalg::MatmulOp>(
strategy.tile<mlir::linalg::GenericOp>(tilingOptions)
.promote<mlir::linalg::GenericOp>(
mlir::linalg::LinalgPromotionOptions()
.setAlignment(alignment)
.setUseFullTileBuffersByDefault(true)
.setUseAlloca(true))
.vectorize<mlir::linalg::MatmulOp>()
.vectorize<mlir::linalg::GenericOp>()
.setVectorTransformsOptions(
mlir::vector::VectorTransformsOptions()
.setVectorTransformsOptions(

View File

@ -148,15 +148,12 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
Status HandleDomain(HloInstruction* hlo) override;
private:
using DimensionConstraint = DynamicDimensionInference::DimensionConstraint;
using OperandDynamicDimensionFn = std::function<Status(
HloInstruction* operand, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint)>;
int64 operand_index, HloInstruction* dynamic_size)>;
using DynamicDimensionFn = std::function<Status(
ShapeIndex index, int64 dimension, HloInstruction* dynamic_size,
DimensionConstraint constraint)>;
ShapeIndex index, int64 dimension, HloInstruction* dynamic_size)>;
Status ForEachOperandDynamicDimension(HloInstruction* inst,
const OperandDynamicDimensionFn&);
@ -184,8 +181,7 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
Status DynamicDimensionInferenceVisitor::DefaultAction(HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
int64 operand_index, HloInstruction* dynamic_size) {
return UnimplementedStrCat(
"Asked to propagate a dynamic dimension from hlo ", operand->name(),
"@", index.ToString(), "@", dimension, " to hlo ", hlo->ToString(),
@ -197,13 +193,11 @@ Status DynamicDimensionInferenceVisitor::HandleGetTupleElement(
HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
int64 operand_index, HloInstruction* dynamic_size) {
if (hlo->tuple_index() == index[0]) {
ShapeIndex new_index =
ShapeIndexView(index).ConsumeFront().ToShapeIndex();
parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size,
constraint);
parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size);
}
return Status::OK();
});
@ -212,11 +206,9 @@ Status DynamicDimensionInferenceVisitor::HandleGetTupleElement(
Status DynamicDimensionInferenceVisitor::HandleTuple(HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction*, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
int64 operand_index, HloInstruction* dynamic_size) {
index.push_front(operand_index);
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size,
constraint);
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
return Status::OK();
});
}
@ -224,11 +216,9 @@ Status DynamicDimensionInferenceVisitor::HandleTuple(HloInstruction* hlo) {
Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
int64 operand_index, HloInstruction* dynamic_size) {
int64 broadcast_dim = hlo->dimensions(dimension);
parent_->SetDynamicSize(hlo, {}, broadcast_dim, dynamic_size,
constraint);
parent_->SetDynamicSize(hlo, {}, broadcast_dim, dynamic_size);
return Status::OK();
});
}
@ -244,8 +234,7 @@ Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
// returns the padded data output and the dynamic sizes of input
// dimensions.
ShapeIndex data_output = {0};
parent_->SetDynamicSize(hlo, data_output, i, dynamic_size,
DimensionConstraint(1, 1));
parent_->SetDynamicSize(hlo, data_output, i, dynamic_size);
}
}
return Status::OK();
@ -255,15 +244,14 @@ Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
}
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
int64 operand_index, HloInstruction* dynamic_size) {
// Resize custom call should propagate dynamic batch (0) and channel (3)
// dimensions.
if (hlo->custom_call_target() == "SliceToDynamic" ||
hlo->custom_call_target() == "Sharding" ||
(absl::StartsWith(hlo->custom_call_target(), "Resize") &&
(dimension == 0 || dimension == 3))) {
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
return Status::OK();
}
return Unimplemented(
@ -274,16 +262,15 @@ Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
Status DynamicDimensionInferenceVisitor::HandleSort(HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index,
int64 dynamic_dimension, int64 operand_index,
HloInstruction* dynamic_size, DimensionConstraint constraint) {
hlo,
[&](HloInstruction* operand, ShapeIndex index, int64 dynamic_dimension,
int64 operand_index, HloInstruction* dynamic_size) {
HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
if (sort->values_count() == 0) {
parent_->SetDynamicSize(hlo, {}, dynamic_dimension, dynamic_size,
constraint);
parent_->SetDynamicSize(hlo, {}, dynamic_dimension, dynamic_size);
} else {
parent_->SetDynamicSize(hlo, {operand_index}, dynamic_dimension,
dynamic_size, constraint);
dynamic_size);
}
return Status::OK();
@ -293,8 +280,7 @@ Status DynamicDimensionInferenceVisitor::HandleSort(HloInstruction* hlo) {
Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
int64 operand_index, HloInstruction* dynamic_size) {
if (operand_index != 0) {
return Unimplemented(
"Dynamic dimension on padding value is not supported");
@ -311,8 +297,7 @@ Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
dynamic_size_adjusted->shape(), HloOpcode::kAdd,
dynamic_size_adjusted, adjustment));
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted,
constraint);
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted);
return Status::OK();
} else {
return Unimplemented(
@ -327,8 +312,7 @@ Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
int64 operand_index, HloInstruction* dynamic_size) {
HloInstruction* reduce = hlo;
int64 operand_count = reduce->operand_count();
bool is_variadic_reduce = operand_count > 2;
@ -354,13 +338,12 @@ Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
// reduce has a dynamic dimension, we set all outputs to use the
// same dynamic size in corresponding dimensions.
for (int64 i = 0; i < operand_count / 2; ++i) {
parent_->SetDynamicSize(reduce, {i},
dimensions_not_reduced_count,
dynamic_size, constraint);
parent_->SetDynamicSize(
reduce, {i}, dimensions_not_reduced_count, dynamic_size);
}
} else {
parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count,
dynamic_size, constraint);
dynamic_size);
}
return Status::OK();
@ -378,7 +361,7 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex operand_shape_index,
int64 operand_dimension, int64 operand_index,
HloInstruction* dynamic_size, DimensionConstraint constraint) {
HloInstruction* dynamic_size) {
// There are three types of dimensions in a dot:
// A. batch dims
// B. contracting dims
@ -451,8 +434,7 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
// work item to trace that dimension.
auto iter = result_dim_mapping.find(operand_dimension);
if (iter != result_dim_mapping.end()) {
parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size,
constraint);
parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size);
}
return Status::OK();
@ -463,8 +445,7 @@ Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo,
[&](HloInstruction* operand, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) -> Status {
int64 operand_index, HloInstruction* dynamic_size) -> Status {
int64 permuted_dim = -1;
for (int64 i = 0; i < hlo->dimensions().size(); ++i) {
if (hlo->dimensions()[i] == dimension) {
@ -472,8 +453,7 @@ Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
permuted_dim = i;
}
}
parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size,
constraint);
parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size);
return Status::OK();
});
}
@ -482,8 +462,7 @@ Status DynamicDimensionInferenceVisitor::HandleConvolution(
HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
int64 operand_index, HloInstruction* dynamic_size) {
HloInstruction* conv = hlo;
const ConvolutionDimensionNumbers& dimension_numbers =
conv->convolution_dimension_numbers();
@ -492,7 +471,7 @@ Status DynamicDimensionInferenceVisitor::HandleConvolution(
if (dimension == dimension_numbers.input_batch_dimension()) {
parent_->SetDynamicSize(conv, {},
dimension_numbers.output_batch_dimension(),
dynamic_size, constraint);
dynamic_size);
return Status::OK();
}
@ -542,20 +521,18 @@ Status DynamicDimensionInferenceVisitor::HandleConcatenate(
dim_size_total, dynamic_dim));
}
parent_->SetDynamicSize(hlo, {}, hlo->concatenate_dimension(),
dim_size_total, DimensionConstraint(1, 1));
dim_size_total);
}
// Simply pass through non-concat dynamic dimensions.
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
int64 operand_index, HloInstruction* dynamic_size) {
int64 concatenate_dimension = hlo->concatenate_dimension();
if (concatenate_dimension == dimension) {
return Status::OK();
}
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size,
constraint);
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
return Status::OK();
});
}
@ -596,18 +573,15 @@ Status DynamicDimensionInferenceVisitor::HandleSetDimensionSize(
if (!dimension_is_static) {
// Propagate dynamic dimension indicated by this set dimension size
// instruction.
parent_->SetDynamicSize(hlo, {}, hlo->dimension(), hlo->mutable_operand(1),
DimensionConstraint(1, 1));
parent_->SetDynamicSize(hlo, {}, hlo->dimension(), hlo->mutable_operand(1));
}
// Also Propagate dynamic dimension already set by operands.
TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
int64 operand_index, HloInstruction* dynamic_size) {
if (dimension != hlo->dimension()) {
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size,
constraint);
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
}
return Status::OK();
}));
@ -619,10 +593,8 @@ Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension(
HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size,
constraint);
int64 operand_index, HloInstruction* dynamic_size) {
parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
return Status::OK();
});
}
@ -654,8 +626,7 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
hlo,
[&](HloInstruction* operand, ShapeIndex index,
int64 input_dynamic_dimension, int64 operand_index,
HloInstruction* operand_dynamic_size,
DimensionConstraint constraint) -> Status {
HloInstruction* operand_dynamic_size) -> Status {
HloInstruction* reshape = hlo;
if (reshape->shape().rank() == 0) {
VLOG(0) << "Reshaping a dynamic dimension into a scalar, which has "
@ -751,9 +722,6 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
if (output_dynamic_dimension == -1 &&
output_dim_end - output_dim_start > 1) {
// TODO(yunxing): We now have a better way to decide output dimension
// in the bridge. No need for this constraint propagation logic.
//
// One input dimension is splitted into multiple output dimensions.
// Output dimension is decomposed from input most major dimension.
// In this case, we don't know which one is dynamic, e.g., when we
@ -770,61 +738,17 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
// We use the following logics to disambiguate:
// 1. If the user sets "inferred_dimension", then use that as
// dynamic dimension.
// 2. If the one dimension in the reshape is dynamic, use that as
// dynamic dimension.
// E.g.:
// [<=4]
// |
// reshape
// |
// [1, <=2, 2]
// We use second dim as dynamic dimension.
//
// 2. Use the "multiple_of" constraint, e.g, :
// [<=2, 4]
// | Reshape
// [<=8]
// | Reshape
// [2, 4] // Which is dynamic?
//
// If the dynamic value has to be multiple of 4 (constraint
// created by the first reshape), then 2 must be the dynamic
// dimension.
//
// But this logic doesn't help with the case where two
// dimensions are the same:
//
// [<=3, 3]
// | Reshape
// [<=9]
// | Reshape
// [3, 3] // Which is dynamic?
//
// Both dynamic dimension can be multiple of 3.
//
// We then need the next constraint to disambiguate this case:
//
// 3. Use the "stride" constraint (also see the comment at the
// definition):
//
// [<=3, 3]
// | Reshape
// [<=9] // constraint.stride = 1
// | Reshape
// [3, 3]
// ^ ^
// | |
// stride= 1 3
//
// Each dimension will have different strides, only one will
// satisfy the stride constraint.
//
// Note that the stride constrint itself is not enough:
//
//
// [<=128]
// | Reshape
// [1, 128]
// ^ ^
// | |
// stride= 1 1
//
// In this case, both dimensions have the same stride, which is
// ambiguous. That's why we need the "multiple_of" constraint
// as used above.
//
// 4. If all logics above cannot disambiguate, e.g.,:
// 3. If all logics above cannot disambiguate, e.g.,:
//
// [<=1]
// |
@ -833,68 +757,15 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
// [1, 1, 1]
//
// We bail out and return an error.
// TODO(yunxing): Further simplify this, remove 1. and fully rely
// on 2.
output_dynamic_dimension = reshape->inferred_dimension();
if (output_dynamic_dimension == -1) {
// The user of XLA didn't specify a dynamic dimension, try infer
// it from the current constraint.
//
// Find all output dimensions that are decomposed from the first
// dimension. Among those dimensions, find all dimensions that
// satisfy the constraint of the dynamic dimension. In the
// previous example, if `a` is 9 and constraint is a multiple of
// `3', then in the output shape both a/c and c can be dynamic.
int64 current_product = 1;
int64 dimension_iter = output_dim_start;
// compatible_dimensions are dimensions that satisfies
// "multiple_of" constraints.
std::vector<int64> compatible_dimensions;
while (current_product <
operand->shape().dimensions(input_dynamic_dimension)) {
current_product *= reshape->shape().dimensions(dimension_iter);
if (operand->shape().dimensions(input_dynamic_dimension) /
reshape->shape().dimensions(dimension_iter) ==
constraint.multiple_of) {
compatible_dimensions.push_back(dimension_iter);
// Try to find the dynamic dimension from the result shape.
for (int64 i = 0; i < reshape->shape().rank(); ++i) {
if (reshape->shape().is_dynamic_dimension(i)) {
output_dynamic_dimension = i;
}
dimension_iter++;
}
CHECK_EQ(current_product,
operand->shape().dimensions(input_dynamic_dimension))
<< "Not a valid reshape: " << hlo->ToString();
// If there is only one compatible dimension, it must be the
// dynamic one in the output.
if (compatible_dimensions.size() == 1) {
output_dynamic_dimension = compatible_dimensions[0];
}
// When there are multiple compatible dimensions, e.g:
// [<=9]
// | Reshape
// [3, 3]
// Use stride constraint to figure out which one is the true
// dynamic one.
//
// [<=9]
// | Reshape
// [3, 3]
// ^ ^
// | |
// stride= 1 3
//
std::vector<int64> compatible_dimensions_with_stride;
absl::c_copy_if(
compatible_dimensions,
std::back_inserter(compatible_dimensions_with_stride),
[&](int64 dimension) {
int64 stride_total = 1;
for (int64 i = 0; i < dimension + 1; ++i) {
stride_total *= reshape->shape().dimensions(dimension);
}
return stride_total == constraint.stride;
});
if (compatible_dimensions_with_stride.size() == 1) {
output_dynamic_dimension = compatible_dimensions_with_stride[0];
}
}
@ -914,9 +785,8 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
return InvalidArgument(
"Reshape's input dynamic dimension is decomposed into "
"multiple output dynamic dimensions, but the constraint is "
"ambiguous and XLA can't infer the output dimension %s. "
"Constraint: multiple_of: %lld, stride: %lld",
hlo->ToString(), constraint.multiple_of, constraint.stride);
"ambiguous and XLA can't infer the output dimension %s. ",
hlo->ToString());
}
}
@ -931,7 +801,7 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
if (input_dim_size == output_dim_size) {
// Simply forward dynamic dimension.
parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
operand_dynamic_size, constraint);
operand_dynamic_size);
}
if (input_dim_size > output_dim_size) {
@ -946,9 +816,8 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
operand_dynamic_size->shape(), HloOpcode::kDivide,
operand_dynamic_size, divisor_hlo));
parent_->SetDynamicSize(
reshape, {}, output_dynamic_dimension, new_dynamic_size,
DimensionConstraint(1, constraint.multiple_of / divisor));
parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
new_dynamic_size);
}
if (input_dim_size < output_dim_size) {
@ -985,12 +854,8 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
output_dynamic_size->shape(), HloOpcode::kMultiply,
new_dynamic_size, operand_dynamic_size));
int64 new_multiple_of_constraint =
constraint.multiple_of * output_dim_size /
operand->shape().dimensions(input_dynamic_dimension);
parent_->SetDynamicSize(
reshape, {}, output_dynamic_dimension, new_dynamic_size,
DimensionConstraint(1, new_multiple_of_constraint));
parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
new_dynamic_size);
}
return Status::OK();
@ -1001,8 +866,7 @@ Status DynamicDimensionInferenceVisitor::HandleReduceWindow(
HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
int64 operand_index, HloInstruction* dynamic_size) {
HloInstruction* reduce_window = hlo;
const WindowDimension& window_dimension =
reduce_window->window().dimensions(dimension);
@ -1013,8 +877,7 @@ Status DynamicDimensionInferenceVisitor::HandleReduceWindow(
reduce_window->ToString());
}
parent_->SetDynamicSize(reduce_window, {}, dimension, dynamic_size,
constraint);
parent_->SetDynamicSize(reduce_window, {}, dimension, dynamic_size);
return Status::OK();
});
@ -1024,8 +887,7 @@ Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter(
HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
int64 operand_index, HloInstruction* dynamic_size) {
HloInstruction* select_and_scatter = hlo;
const WindowDimension& window_dimension =
select_and_scatter->window().dimensions(dimension);
@ -1036,8 +898,8 @@ Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter(
select_and_scatter->ToString());
}
parent_->SetDynamicSize(select_and_scatter, {}, dimension, dynamic_size,
constraint);
parent_->SetDynamicSize(select_and_scatter, {}, dimension,
dynamic_size);
return Status::OK();
});
@ -1046,8 +908,7 @@ Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter(
Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex /*index*/, int64 dimension,
int64 /*operand_index*/, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
int64 /*operand_index*/, HloInstruction* dynamic_size) {
if (hlo->slice_starts(dimension) != 0 ||
hlo->slice_strides(dimension) != 1 ||
hlo->slice_limits(dimension) !=
@ -1056,7 +917,7 @@ Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) {
return Status::OK();
}
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
return Status::OK();
});
@ -1066,8 +927,7 @@ Status DynamicDimensionInferenceVisitor::HandleDynamicSlice(
HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction*, ShapeIndex /*index*/, int64 dimension,
int64 /*operand_index*/, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
int64 /*operand_index*/, HloInstruction* dynamic_size) {
if (hlo->shape().dimensions(dimension) !=
hlo->operand(0)->shape().dimensions(dimension)) {
// Slicing a single element out kills the dynamic dimension.
@ -1080,7 +940,7 @@ Status DynamicDimensionInferenceVisitor::HandleDynamicSlice(
hlo->ToString());
}
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
return Status::OK();
});
@ -1089,9 +949,9 @@ Status DynamicDimensionInferenceVisitor::HandleDynamicSlice(
Status DynamicDimensionInferenceVisitor::HandleDynamicUpdateSlice(
HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* /*operand*/, ShapeIndex /*index*/,
int64 dimension, int64 /*operand_index*/,
HloInstruction* dynamic_size, DimensionConstraint constraint) {
hlo,
[&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64 dimension,
int64 /*operand_index*/, HloInstruction* dynamic_size) {
if (hlo->shape().dimensions(dimension) !=
hlo->operand(0)->shape().dimensions(dimension)) {
return Unimplemented(
@ -1100,7 +960,7 @@ Status DynamicDimensionInferenceVisitor::HandleDynamicUpdateSlice(
hlo->ToString());
}
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
return Status::OK();
});
@ -1108,16 +968,16 @@ Status DynamicDimensionInferenceVisitor::HandleDynamicUpdateSlice(
Status DynamicDimensionInferenceVisitor::HandleReverse(HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* /*operand*/, ShapeIndex /*index*/,
int64 dimension, int64 /*operand_index*/,
HloInstruction* dynamic_size, DimensionConstraint constraint) {
hlo,
[&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64 dimension,
int64 /*operand_index*/, HloInstruction* dynamic_size) {
if (absl::c_linear_search(hlo->dimensions(), dimension)) {
return Unimplemented(
"Dynamic dimension propagation on reversed dimension is not "
"supported %s",
hlo->ToString());
}
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint);
parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
return Status::OK();
});
@ -1127,7 +987,7 @@ Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo, [&](HloInstruction* operand, ShapeIndex /*index*/,
int64 input_dynamic_dimension, int64 operand_index,
HloInstruction* dynamic_size, DimensionConstraint constraint) {
HloInstruction* dynamic_size) {
const GatherDimensionNumbers& gather_dims =
hlo->gather_dimension_numbers();
if (operand_index != 1) {
@ -1147,8 +1007,7 @@ Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) {
output_dimension--;
}
}
parent_->SetDynamicSize(hlo, {}, output_dimension, dynamic_size,
constraint);
parent_->SetDynamicSize(hlo, {}, output_dimension, dynamic_size);
return Status::OK();
}
return Unimplemented(
@ -1171,8 +1030,7 @@ Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) {
indices_dim++;
}
if (indices_dim++ == input_dynamic_dimension) {
parent_->SetDynamicSize(hlo, {}, output_dim, dynamic_size,
constraint);
parent_->SetDynamicSize(hlo, {}, output_dim, dynamic_size);
return Status::OK();
}
}
@ -1220,8 +1078,7 @@ Status DynamicDimensionInferenceVisitor::HandleConditional(
TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
hlo, operand_index,
[&](HloInstruction*, ShapeIndex, int64, int64,
HloInstruction* dynamic_size,
DimensionConstraint constraint) -> Status {
HloInstruction* dynamic_size) -> Status {
TF_RET_CHECK(hlo->operand(operand_index)->shape().IsTuple())
<< "Only tuple typed inputs can have dynamic dimension. Please "
"file a bug against XLA team.";
@ -1263,8 +1120,7 @@ Status DynamicDimensionInferenceVisitor::HandleConditional(
TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
hlo, operand_index,
[&](HloInstruction*, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) {
int64 operand_index, HloInstruction* dynamic_size) {
DynamicParameterBinding::DynamicParameter dynamic_parameter{
0, {dynamic_size_to_operand_id_index_map[dynamic_size]}};
DynamicParameterBinding::DynamicDimension dynamic_dimension{
@ -1284,8 +1140,8 @@ Status DynamicDimensionInferenceVisitor::HandleConditional(
// that into the root instruction as additional tuple elements.
TF_RETURN_IF_ERROR(ForEachDynamicDimension(
new_computation->root_instruction(),
[&](ShapeIndex index, int64 dim, HloInstruction* dynamic_size,
DimensionConstraint) -> Status {
[&](ShapeIndex index, int64 dim,
HloInstruction* dynamic_size) -> Status {
TF_RET_CHECK(hlo->shape().IsTuple())
<< "Only tuple typed conditionals can have dynamic dimension. "
"Please file a bug against XLA team.";
@ -1347,11 +1203,9 @@ Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) {
return ForEachOperandDynamicDimension(
hlo,
[&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64 dimension,
int64 operand_index, HloInstruction* operand_dynamic_size,
DimensionConstraint constraint) {
int64 operand_index, HloInstruction* operand_dynamic_size) {
if (operand_index == 0) {
parent_->SetDynamicSize(hlo, {}, dimension, operand_dynamic_size,
constraint);
parent_->SetDynamicSize(hlo, {}, dimension, operand_dynamic_size);
return Status::OK();
}
@ -1385,7 +1239,7 @@ Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
int64 operand_count = original_tuple_count;
TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
hlo, [&](HloInstruction*, ShapeIndex index, int64 dim, int64,
HloInstruction* dynamic_size, DimensionConstraint constraint) {
HloInstruction* dynamic_size) {
operands_to_add.push_back(dynamic_size);
dynamic_output_mapping.mutable_element(index)->emplace(dim,
operand_count++);
@ -1413,8 +1267,7 @@ Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
hlo,
[&](HloInstruction*, ShapeIndex index, int64 dimension,
int64 operand_index, HloInstruction* dynamic_size,
DimensionConstraint constraint) -> Status {
int64 operand_index, HloInstruction* dynamic_size) -> Status {
TF_RET_CHECK(!operands_to_add.empty());
const int64 output_dynamic_size_index =
dynamic_output_mapping.element(index).at(dimension);
@ -1431,7 +1284,7 @@ Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
ShapeUtil::MakeScalarShape(S32), hlo,
output_dynamic_size_index));
parent_->SetDynamicSize(result.replacement_instr, index, dimension,
output_dynamic_size, constraint);
output_dynamic_size);
return Status::OK();
}));
// Set the replacement instruction as visited to avoid visiting it again.
@ -1465,8 +1318,7 @@ Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
// Add dynamic dimension size as new parameters.
TF_RETURN_IF_ERROR(ForEachDynamicDimension(
hlo->while_body()->root_instruction(),
[&](ShapeIndex index, int64 dim, HloInstruction* dynamic_size,
DimensionConstraint) -> Status {
[&](ShapeIndex index, int64 dim, HloInstruction* dynamic_size) -> Status {
const int64 output_index =
dynamic_output_mapping.element(index).at(dim);
new_root_operands[output_index] = dynamic_size;
@ -1503,8 +1355,7 @@ Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) {
parent_->SetDynamicSize(target_parameter,
dynamic_dimension.parameter_index,
dynamic_dimension.dimension, dynamic_size,
DimensionConstraint(1, 1));
dynamic_dimension.dimension, dynamic_size);
return Status::OK();
});
}
@ -1517,10 +1368,8 @@ Status DynamicDimensionInferenceVisitor::ForEachDynamicDimension(
HloInstruction* dynamic_size = parent_->GetDynamicSize(
dynamic_dimension.inst, dynamic_dimension.index,
dynamic_dimension.dim);
CHECK_NE(parent_->constraint_mapping_.count(dynamic_dimension), 0);
TF_RETURN_IF_ERROR(fn(dynamic_dimension.index, dynamic_dimension.dim,
dynamic_size,
parent_->constraint_mapping_[dynamic_dimension]));
TF_RETURN_IF_ERROR(
fn(dynamic_dimension.index, dynamic_dimension.dim, dynamic_size));
}
}
return Status::OK();
@ -1536,10 +1385,9 @@ Status DynamicDimensionInferenceVisitor::ForEachDynamicDimensionInOperand(
HloInstruction* dynamic_size = parent_->GetDynamicSize(
dynamic_dimension.inst, dynamic_dimension.index,
dynamic_dimension.dim);
CHECK_NE(parent_->constraint_mapping_.count(dynamic_dimension), 0);
TF_RETURN_IF_ERROR(fn(dynamic_dimension.inst, dynamic_dimension.index,
dynamic_dimension.dim, operand_index, dynamic_size,
parent_->constraint_mapping_[dynamic_dimension]));
dynamic_dimension.dim, operand_index,
dynamic_size));
}
}
return Status::OK();
@ -1555,6 +1403,24 @@ Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension(
return Status::OK();
}
void DynamicDimensionInference::SetDynamicSize(HloInstruction* inst,
const ShapeIndex& index,
int64 dim,
HloInstruction* size) {
VLOG(1) << "Set dimension inst " << inst->ToString() << " index "
<< index.ToString() << "@" << dim << " to " << size->ToShortString();
Shape subshape = ShapeUtil::GetSubshape(inst->shape(), index);
CHECK(!subshape.IsTuple()) << "Can't set a tuple shape to dynamic dimension";
CHECK(dim < subshape.rank() && dim >= 0)
<< "Asked to set invalid dynamic dimension. Shape: "
<< subshape.ToString() << ", Dimension: " << dim;
DynamicDimension dynamic_dimension{inst, index, dim};
// Updating a dynamic dimension twice overwrites the previous one.
dynamic_mapping_[dynamic_dimension] = size;
auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst);
iter.first->second.emplace(dynamic_dimension);
}
void DynamicDimensionInference::CopyMapping(HloInstruction* from,
HloInstruction* to) {
auto iter = per_hlo_dynamic_dimensions_.find(from);
@ -1564,7 +1430,7 @@ void DynamicDimensionInference::CopyMapping(HloInstruction* from,
GetDynamicSize(dynamic_dimension.inst, dynamic_dimension.index,
dynamic_dimension.dim);
SetDynamicSize(to, dynamic_dimension.index, dynamic_dimension.dim,
dynamic_size, constraint_mapping_[dynamic_dimension]);
dynamic_size);
}
}
}
@ -1624,8 +1490,6 @@ Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst,
auto iter = dynamic_mapping_.find(dynamic_dimension);
if (iter != dynamic_mapping_.end()) {
dynamic_mapping_.insert({dynamic_dimension_new, iter->second});
constraint_mapping_.insert(
{dynamic_dimension_new, constraint_mapping_[dynamic_dimension]});
auto iter = per_hlo_dynamic_dimensions_.try_emplace(new_inst);
iter.first->second.emplace(dynamic_dimension_new);
}
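A side note on the SetDynamicSize definition added above: dynamic_mapping_[dynamic_dimension] = size overwrites any size previously recorded for the same dimension, while try_emplace only creates the per-instruction set the first time an instruction is seen. A minimal standalone sketch of that idiom, using std::unordered_map/std::set as stand-ins for the absl containers and simplified string keys in place of the real DynamicDimension/HloInstruction types (names here are illustrative only, not XLA code):

#include <cstdio>
#include <set>
#include <string>
#include <unordered_map>

int main() {
  // Stand-in for dynamic_mapping_: dimension key -> recorded dynamic size.
  std::unordered_map<std::string, int> dynamic_mapping;
  // Stand-in for per_hlo_dynamic_dimensions_: instruction -> set of dim keys.
  std::unordered_map<std::string, std::set<std::string>> per_hlo_dims;

  auto set_dynamic_size = [&](const std::string& inst, const std::string& dim,
                              int size) {
    // Assigning through operator[] overwrites: updating a dynamic dimension
    // twice keeps only the latest size, as the comment in the pass states.
    dynamic_mapping[dim] = size;
    // try_emplace inserts an empty set only if `inst` is not tracked yet;
    // the dimension key is then added to that set.
    auto iter = per_hlo_dims.try_emplace(inst);
    iter.first->second.emplace(dim);
  };

  set_dynamic_size("reshape.1", "reshape.1{}@0", 7);
  set_dynamic_size("reshape.1", "reshape.1{}@0", 9);  // overwrites the 7
  std::printf("size=%d tracked_dims=%zu\n",
              dynamic_mapping["reshape.1{}@0"],
              per_hlo_dims["reshape.1"].size());  // size=9 tracked_dims=1
  return 0;
}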

View File

@ -55,8 +55,7 @@ class DynamicDimensionInference {
// go into tuples.
bool HasDynamicDimension(HloInstruction* inst) const;
// Forward dynamic dimension size at `dim` and its constraint from `inst` to
// `new_inst`.
// Forward dynamic dimension size at `dim` from `inst` to `new_inst`.
Status ForwardDynamicSize(HloInstruction* inst, HloInstruction* new_inst,
const ShapeIndex& index);
@ -64,9 +63,7 @@ class DynamicDimensionInference {
// `inst` at `index` has a dynamic size, and its runtime size is represented
// by a scalar instruction `size`.
void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim,
HloInstruction* size) {
SetDynamicSize(inst, index, dim, size, DimensionConstraint(1, 1));
}
HloInstruction* size);
// For all tensors whose dynamic dimension is `replace`, replace them with
// `with`.
@ -106,116 +103,6 @@ class DynamicDimensionInference {
}
};
// DimensionConstraint is attached to each dynamic dimension and describes the
// constraint on that dimension. It is used to disambiguate which output
// dimension is dynamic for reshapes that "split" a dimension into two.
//
// As an example, consider the following reshapes:
// [<=3, 3] <- Assume first dimension is dynamic.
// |
// Reshape.1
// |
// [<=9] <- The collapsed dimension (bound 9) is dynamic
// |
// Reshape.2
// |
// [3, 3] <- Ambiguous dimension after splitting 9 into [3, 3]
//
// There is no way to know which dimension is dynamic by looking at the second
// reshape locally.
//
// However, if we look at the dynamic dimension of bound 9, since it comes from
// collapsing a major dynamic dimension of bound 3 (the dynamic size can be 0,
// 1, 2, or 3, denoted as i in the diagrams below) with a minor static
// dimension of 3, we know it is constrained: the reshape can only produce one
// of the 4 forms:
//
// o: Padded Data
// x: Effective Data
//
// [<=3, 3] to [9]
//
// +---+ +---+ +---+ +---+
// |ooo| |ooo| |ooo| |xxx|
// |ooo| |ooo| |xxx| |xxx|
// |ooo| |xxx| |xxx| |xxx|
// +---+ +---+ +---+ +---+
//
// Reshape Reshape Reshape Reshape
//
// +-----------+ +-----------+ +-----------+ +-----------+
// |ooo|ooo|ooo| or |xxx|ooo|ooo| or |xxx|xxx|ooo| or |xxx|xxx|xxx| stride=1
// +-----------+ +-----------+ +-----------+ +-----------+
// i = 0 i = 1 i = 2 i = 3
//
// On the other hand, if the minor dimension of bound 3 is dynamic and the
// major dimension is static, we get the following forms:
//
// [3, <=3] to [9]
//
// +---+ +---+ +---+ +---+
// |ooo| |xoo| |xxo| |xxx|
// |ooo| |xoo| |xxo| |xxx|
// |ooo| |xoo| |xxo| |xxx|
// +---+ +---+ +---+ +---+
//
// Reshape Reshape Reshape Reshape
//
// +-----------+ +-----------+ +-----------+ +-----------+
// |ooo|ooo|ooo| or |xoo|xoo|xoo| or |xxo|xxo|xxo| or |xxx|xxx|xxx| stride=3
// +-----------+ +-----------+ +-----------+ +-----------+
// i = 0 i = 1 i = 2 i = 3
//
// By encoding the constraint as an element stride we can recover this
// information later when we reshape from [9] back to [3, 3]: we know which
// form ([3, i] or [i, 3]) the [9] should be reshaped into.
//
//
struct DimensionConstraint {
explicit DimensionConstraint(int64 s, int64 m)
: stride(s), multiple_of(m) {}
DimensionConstraint() : stride(1), multiple_of(1) {}
// Stride represents the distance between a newly placed element and the
// previously placed element along this dynamic dimension.
int64 stride;
// multiple_of represents the constraint that
//
//   `dynamic_size` % `multiple_of` == 0
int64 multiple_of;
};
using ConstraintMapping =
absl::flat_hash_map<DynamicDimension, DimensionConstraint>;
ConstraintMapping constraint_mapping_;
// Update the dynamic mapping so that we know dimension `dim` of instruction
// `inst` at `index` has a dynamic size, and its runtime size is represented
// by a scalar instruction `size`.
void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim,
HloInstruction* size, DimensionConstraint constraint) {
VLOG(1) << "Set dimension inst " << inst->ToString() << " index "
<< index.ToString() << "@" << dim << " to " << size->ToShortString()
<< " constraint: " << constraint.multiple_of;
Shape subshape = ShapeUtil::GetSubshape(inst->shape(), index);
CHECK(!subshape.IsTuple())
<< "Can't set a tuple shape to dynamic dimension";
CHECK(dim < subshape.rank() && dim >= 0)
<< "Asked to set invalid dynamic dimension. Shape: "
<< subshape.ToString() << ", Dimension: " << dim;
DynamicDimension dynamic_dimension{inst, index, dim};
// Updating a dynamic dimension twice overwrites the previous one.
dynamic_mapping_[dynamic_dimension] = size;
if (constraint_mapping_.count(dynamic_dimension) != 0) {
CHECK_EQ(constraint_mapping_[dynamic_dimension].stride,
constraint.stride);
}
constraint_mapping_[dynamic_dimension] = constraint;
auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst);
iter.first->second.emplace(dynamic_dimension);
}
// Copies the internal mapping from instruction `from` to instruction `to`.
// This is useful when an instruction is replaced by the other during the
// inferencing process.
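To make the stride picture in the removed DimensionConstraint comment above concrete: when the dynamic size i of the reshaped [9] grows by one, the newly placed elements are contiguous (stride 1) if the original major dimension was dynamic, and 3 apart (stride 3) if the minor dimension was dynamic. A minimal standalone sketch (not XLA code) that just reproduces the two diagrams numerically:

#include <cstdio>

int main() {
  const int kRows = 3, kCols = 3;  // the padded [3, 3] buffer, flattened to [9]
  for (int i = 1; i <= 3; ++i) {
    // Flat positions of the elements added when the dynamic size reaches i.
    std::printf("i=%d  [<=3,3] (major dynamic) adds:", i);
    for (int c = 0; c < kCols; ++c)
      std::printf(" %d", (i - 1) * kCols + c);  // e.g. 6 7 8 for i=3 -> stride 1
    std::printf("   [3,<=3] (minor dynamic) adds:");
    for (int r = 0; r < kRows; ++r)
      std::printf(" %d", r * kCols + (i - 1));  // e.g. 2 5 8 for i=3 -> stride 3
    std::printf("\n");
  }
  return 0;
}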

View File

@ -27,6 +27,7 @@ load(
"if_cuda_is_configured",
)
load("//tensorflow:tensorflow.bzl", "if_nccl")
load("//third_party/mlir:tblgen.bzl", "gentbl")
package(
default_visibility = [":friends"],
@ -686,7 +687,7 @@ cc_library(
":gpu_autotuning_proto_cc",
":gpu_conv_runner",
":gpu_executable",
":hlo_algorithm_blacklist",
":hlo_algorithm_denylist",
":ir_emission_utils",
":stream_executor_util",
"@com_google_absl//absl/algorithm:container",
@ -1660,9 +1661,9 @@ tf_proto_library_cc(
)
cc_library(
name = "hlo_algorithm_blacklist",
srcs = ["hlo_algorithm_blacklist.cc"],
hdrs = ["hlo_algorithm_blacklist.h"],
name = "hlo_algorithm_denylist",
srcs = ["hlo_algorithm_denylist.cc"],
hdrs = ["hlo_algorithm_denylist.h"],
deps = [
":gpu_autotuning_proto_cc",
"//tensorflow/compiler/xla:debug_options_flags",
@ -1673,12 +1674,12 @@ cc_library(
)
tf_cc_test(
name = "hlo_algorithm_blacklist_test",
srcs = ["hlo_algorithm_blacklist_test.cc"],
name = "hlo_algorithm_denylist_test",
srcs = ["hlo_algorithm_denylist_test.cc"],
data = ["data/hlo_algorithm_denylist.pbtxt"],
tags = ["no_pip"],
deps = [
":hlo_algorithm_blacklist",
":hlo_algorithm_denylist",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
@ -1875,3 +1876,49 @@ cc_library(
"@com_google_absl//absl/types:span",
],
)
gentbl(
name = "xla_thunks_ops_inc_gen",
tbl_outs = [
("-gen-op-decls", "ir/xla_thunks_ops.h.inc"),
("-gen-op-defs", "ir/xla_thunks_ops.cc.inc"),
("-gen-struct-attr-decls", "ir/xla_thunks_structs.h.inc"),
("-gen-struct-attr-defs", "ir/xla_thunks_structs.cc.inc"),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/xla_thunks_ops.td",
td_srcs = [
"@llvm-project//mlir:LLVMOpsTdFiles",
],
)
cc_library(
name = "xla_thunks_ops",
srcs = [
"ir/xla_thunks_ops.cc",
"ir/xla_thunks_ops.cc.inc",
"ir/xla_thunks_ops.h.inc",
],
hdrs = [
"ir/xla_thunks_ops.h",
],
deps = [
":xla_thunks_ops_inc_gen",
"//tensorflow/compiler/mlir/hlo",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMDialect",
],
)
# Library with XLA thunks dialect static initialization.
cc_library(
name = "xla_thunks_dialect_registration",
srcs = [
"ir/dialect_registration.cc",
],
deps = [
":xla_thunks_ops",
"@llvm-project//mlir:IR",
],
alwayslink = 1,
)

View File

@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist.h"
#include <string>

Some files were not shown because too many files have changed in this diff.