Resolved merge conflicts

commit f5762da2e7
.bazelrc (2 changes)

@@ -323,6 +323,8 @@ build:windows --copt=/experimental:preprocessor
 build:windows --host_copt=/experimental:preprocessor
 
 # Misc build options we need for windows.
+build:windows --linkopt=/DEBUG
+build:windows --host_linkopt=/DEBUG
 build:windows --linkopt=/OPT:REF
 build:windows --host_linkopt=/OPT:REF
 build:windows --linkopt=/OPT:ICF
.github/bot_config.yml (vendored, 6 changes)

@@ -12,12 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-#
-# THIS IS A GENERATED DOCKERFILE.
-#
-# This file was assembled from multiple pieces, whose use is documented
-# throughout. Please refer to the TensorFlow dockerfiles documentation
-# for more information.
 
 # A list of assignees
 assignees:
.github/workflows/update-nightly.yml (vendored, new file, 28 lines)

@@ -0,0 +1,28 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+on:
+  workflow_dispatch:  # Allow manual triggers
+  schedule:
+    - cron: 0 4 * * *  # 4am UTC is 9pm PDT and 8pm PST
+name: Set nightly branch to master HEAD
+jobs:
+  master-to-nightly:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: zofrex/mirror-branch@v1
+        name: Set nightly branch to master HEAD
+        with:
+          target-branch: 'nightly'
RELEASE.md (22 changes)

@@ -1,4 +1,4 @@
-# Release 2.4.0
+h# Release 2.4.0
 
 <INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
 
@@ -209,8 +209,13 @@
 * Improvements to Keras preprocessing layers:
   * TextVectorization can now accept a vocabulary list or file as an
     init arg.
+  * In `Attention` and `AdditiveAttention` layers, the `call()` method now
+    accepts a `return_attention_scores` argument. When set to
+    True, the layer returns the attention scores as an additional output
+    argument.
+  * Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints
+    with the same implementation as their `tf.losses` equivalent.
 * `tf.function` / AutoGraph:
 
   * Added `experimental_follow_type_hints` argument for `tf.function`. When
    True, the function may use type annotations to optimize the tracing
    performance.
@@ -296,16 +301,21 @@
 
 * `tf.debugging.assert_shapes()` now works on `SparseTensor`s (#36268).
 
-* `TensorRT`
-  * Add parameter allow_mixed_precision_on_unconverted_ops to
-    TrtConversionParams.
+* `tf.print`:
+  * Bug fix in `tf.print()` with `OrderedDict` where if an `OrderedDict`
+    didn't have the keys sorted, the keys and values were not being printed
+    in accordance with their correct mapping.
 
 * Other:
 
   * We have replaced uses of "whitelist" and "blacklist" with "allowlist"
    and "denylist" where possible. Please see
    https://developers.google.com/style/word-list#blacklist for more
-    context. <ADD RELEASE NOTES HERE>
+    context.
+  * Add `tf.config.experimental.mlir_bridge_rollout` which will help us
+    rollout the new MLIR TPU bridge.
+  * <ADD RELEASE NOTES HERE>
 
 ## Thanks to our Contributors
 
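The Keras and `tf.function` entries recorded above are easiest to see in Python. The following is a minimal illustrative sketch, not taken from this commit, assuming the TF 2.4 behavior the release notes describe (`return_attention_scores` on `Attention`/`AdditiveAttention` and `experimental_follow_type_hints` on `tf.function`):

```python
import tensorflow as tf

# Attention layers can now also return the attention score matrix.
query = tf.random.normal([2, 4, 8])   # [batch, Tq, dim]
value = tf.random.normal([2, 6, 8])   # [batch, Tv, dim]
layer = tf.keras.layers.Attention()
output, scores = layer([query, value], return_attention_scores=True)
print(output.shape, scores.shape)     # (2, 4, 8) (2, 4, 6)

# With experimental_follow_type_hints, a tf.Tensor annotation lets plain
# Python values be converted to tensors instead of forcing a retrace.
@tf.function(experimental_follow_type_hints=True)
def double(x: tf.Tensor):
  return x * 2

print(double(tf.constant(3)))  # traced once with a Tensor argument
print(double(4))               # 4 is converted to a Tensor, reusing the trace
```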
@@ -545,7 +545,9 @@ TEST(CAPI, DistributedFunctionNoError) {
   TestDistributedFunctionCancellation(false);
 }
 
-TEST(CAPI, DistributedFunctionCancelledOnError) {
+// TODO(b/170399182): Update test once an alternative to using the function
+// optimization hook is in place.
+TEST(CAPI, DISABLED_DistributedFunctionCancelledOnError) {
   TestDistributedFunctionCancellation(true);
 }
 
@@ -61,6 +61,7 @@ Status RegisterGradients(GradientRegistry* registry) {
   TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
   TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
   TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
+  TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
   return Status::OK();
 }
 
@@ -131,6 +132,37 @@ Status ExpGradModel(AbstractContext* ctx,
   return Status::OK();
 }
 
+// Computes
+// y = sqrt(inputs[0])
+// return grad(y, {inputs[0]})
+Status SqrtGradModel(AbstractContext* ctx,
+                     absl::Span<AbstractTensorHandle* const> inputs,
+                     absl::Span<AbstractTensorHandle*> outputs,
+                     const GradientRegistry& registry) {
+  TapeVSpace vspace(ctx);
+  auto tape = new Tape(/*persistent=*/false);
+  tape->Watch(ToId(inputs[0]));  // Watch x.
+  std::vector<AbstractTensorHandle*> sqrt_outputs(1);
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(
+      ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
+  std::unordered_map<tensorflow::int64, TapeTensor>
+      source_tensors_that_are_targets;
+
+  std::vector<AbstractTensorHandle*> out_grads;
+  TF_RETURN_IF_ERROR(tape->ComputeGradient(
+      vspace, /*target_tensor_ids=*/{ToId(sqrt_outputs[0])},
+      /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
+      /*output_gradients=*/{}, &out_grads,
+      /*build_default_zeros_grads=*/false));
+  for (auto sqrt_output : sqrt_outputs) {
+    sqrt_output->Unref();
+  }
+  outputs[0] = out_grads[0];
+  delete tape;
+  return Status::OK();
+}
+
 // Computes
 // ignored, y = IdentityN(inputs[0], inputs[1])
 // return grad(y, {inputs[0], inputs[1]})
@@ -401,6 +433,50 @@ TEST_P(CppGradients, TestExpGrad) {
   result_tensor = nullptr;
 }
 
+TEST_P(CppGradients, TestSqrtGrad) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  AbstractContextPtr ctx;
+  {
+    AbstractContext* ctx_raw = nullptr;
+    Status s =
+        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    ctx.reset(ctx_raw);
+  }
+
+  AbstractTensorHandlePtr x;
+  {
+    AbstractTensorHandle* x_raw = nullptr;
+    Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    x.reset(x_raw);
+  }
+
+  GradientRegistry registry;
+  Status s = RegisterGradients(&registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  // Pseudo-code:
+  //
+  // tape.watch(x)
+  // y = sqrt(x)
+  // outputs = tape.gradient(y, x)
+  std::vector<AbstractTensorHandle*> outputs(1);
+  s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
+               /*use_function=*/!std::get<2>(GetParam()), registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  TF_Tensor* result_tensor;
+  s = getValue(outputs[0], &result_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+  auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
+  EXPECT_NEAR(*result_value, 0.5, 0.001);
+
+  outputs[0]->Unref();
+  TF_DeleteTensor(result_tensor);
+  result_tensor = nullptr;
+}
+
 TEST_P(CppGradients, TestIdentityNGrad) {
   // Pseudo-code:
   //
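The value asserted by the new `TestSqrtGrad` test above follows directly from the derivative of the square root, evaluated at the scalar input `1.0f` created by `TestScalarTensorHandle`:

$$\frac{d}{dx}\sqrt{x} = \frac{1}{2\sqrt{x}}, \qquad \left.\frac{d}{dx}\sqrt{x}\right|_{x=1} = 0.5,$$

which is why the test checks `EXPECT_NEAR(*result_value, 0.5, 0.001)`.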
@@ -29,6 +29,7 @@ cc_library(
     }),
     deps = [
         "//tensorflow/c:env",
+        "//tensorflow/c:logging",
        "//tensorflow/c:tf_status",
        "//tensorflow/c/experimental/filesystem:filesystem_interface",
        "//third_party/hadoop:hdfs",
@@ -22,9 +22,9 @@ limitations under the License.
 #include <sstream>
 #include <string>
 
-#include "absl/synchronization/mutex.h"
 #include "tensorflow/c/env.h"
 #include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
+#include "tensorflow/c/logging.h"
 #include "tensorflow/c/tf_status.h"
 
 // Implementation of a filesystem for HADOOP environments.
@@ -148,15 +148,20 @@ class LibHDFS {
     char* hdfs_home = getenv("HADOOP_HDFS_HOME");
     if (hdfs_home != nullptr) {
       auto JoinPath = [](std::string home, std::string lib) {
+#if defined(_WIN32)
+        if (home.back() != '\\') home.push_back('\\');
+        return home + "lib\\native\\" + lib;
+#else
         if (home.back() != '/') home.push_back('/');
        return home + "lib/native/" + lib;
+#endif
       };
       std::string path = JoinPath(hdfs_home, kLibHdfsDso);
       TryLoadAndBind(path.c_str(), &handle_, status);
       if (TF_GetCode(status) == TF_OK) {
         return;
       } else {
-        std::cerr << "HadoopFileSystem load error: " << TF_Message(status);
+        TF_Log(TF_FATAL, "HadoopFileSystem load error: %s", TF_Message(status));
       }
     }
 
@@ -207,6 +212,7 @@ hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file,
        builder, namenode.empty() ? "default" : namenode.c_str());
     cacheKey += namenode;
   }
+  absl::MutexLock l(&hadoop_file->connection_cache_lock);
   if (hadoop_file->connection_cache.find(cacheKey) ==
       hadoop_file->connection_cache.end()) {
     auto cacheFs = libhdfs->hdfsBuilderConnect(builder);
@@ -418,17 +424,20 @@ void Close(const TF_WritableFile* file, TF_Status* status) {
 // SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
 // ----------------------------------------------------------------------------
 namespace tf_read_only_memory_region {
-// TODO(vnvo2409): Implement later
+// Hadoop doesn't support Readonly Memory Region
 
 }  // namespace tf_read_only_memory_region
 
 // SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
 // ----------------------------------------------------------------------------
 namespace tf_hadoop_filesystem {
 
+HadoopFile::HadoopFile(TF_Status* status)
+    : libhdfs(new LibHDFS(status)),
+      connection_cache_lock(),
+      connection_cache() {}
+
 void Init(TF_Filesystem* filesystem, TF_Status* status) {
-  filesystem->plugin_filesystem = new HadoopFile({new LibHDFS(status), {}});
+  filesystem->plugin_filesystem = new HadoopFile(status);
   if (TF_GetCode(status) != TF_OK) return;
   TF_SetStatus(status, TF_OK, "");
 }
@@ -699,7 +708,9 @@ int GetChildren(const TF_Filesystem* filesystem, const char* path,
   return num_entries;
 }
 
-// TODO(vnvo2409): Implement later
+static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
+  return strdup(uri);
+}
 
 }  // namespace tf_hadoop_filesystem
 
@@ -707,6 +718,42 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
                                         const char* uri) {
   TF_SetFilesystemVersionMetadata(ops);
   ops->scheme = strdup(uri);
+
+  ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
+      plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
+  ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
+  ops->random_access_file_ops->read = tf_random_access_file::Read;
+
+  ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
+      plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
+  ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
+  ops->writable_file_ops->append = tf_writable_file::Append;
+  ops->writable_file_ops->tell = tf_writable_file::Tell;
+  ops->writable_file_ops->flush = tf_writable_file::Flush;
+  ops->writable_file_ops->sync = tf_writable_file::Sync;
+  ops->writable_file_ops->close = tf_writable_file::Close;
+
+  ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
+      plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
+  ops->filesystem_ops->init = tf_hadoop_filesystem::Init;
+  ops->filesystem_ops->cleanup = tf_hadoop_filesystem::Cleanup;
+  ops->filesystem_ops->new_random_access_file =
+      tf_hadoop_filesystem::NewRandomAccessFile;
+  ops->filesystem_ops->new_writable_file =
+      tf_hadoop_filesystem::NewWritableFile;
+  ops->filesystem_ops->new_appendable_file =
+      tf_hadoop_filesystem::NewAppendableFile;
+  ops->filesystem_ops->new_read_only_memory_region_from_file =
+      tf_hadoop_filesystem::NewReadOnlyMemoryRegionFromFile;
+  ops->filesystem_ops->path_exists = tf_hadoop_filesystem::PathExists;
+  ops->filesystem_ops->stat = tf_hadoop_filesystem::Stat;
+  ops->filesystem_ops->get_file_size = tf_hadoop_filesystem::GetFileSize;
+  ops->filesystem_ops->delete_file = tf_hadoop_filesystem::DeleteFile;
+  ops->filesystem_ops->create_dir = tf_hadoop_filesystem::CreateDir;
+  ops->filesystem_ops->delete_dir = tf_hadoop_filesystem::DeleteDir;
+  ops->filesystem_ops->rename_file = tf_hadoop_filesystem::RenameFile;
+  ops->filesystem_ops->get_children = tf_hadoop_filesystem::GetChildren;
+  ops->filesystem_ops->translate_name = tf_hadoop_filesystem::TranslateName;
 }
 
 void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
@@ -18,6 +18,7 @@ limitations under the License.
 #include <map>
 #include <string>
 
+#include "absl/synchronization/mutex.h"
 #include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
 #include "tensorflow/c/tf_status.h"
 #include "third_party/hadoop/hdfs.h"
@@ -47,7 +48,10 @@ void Close(const TF_WritableFile* file, TF_Status* status);
 namespace tf_hadoop_filesystem {
 typedef struct HadoopFile {
   LibHDFS* libhdfs;
-  std::map<std::string, hdfsFS> connection_cache;
+  absl::Mutex connection_cache_lock;
+  std::map<std::string, hdfsFS> connection_cache
+      ABSL_GUARDED_BY(connection_cache_lock);
+  HadoopFile(TF_Status* status);
 } HadoopFile;
 
 void Init(TF_Filesystem* filesystem, TF_Status* status);
@@ -24,6 +24,7 @@ using std::vector;
 using tensorflow::ops::Conj;
 using tensorflow::ops::MatMul;
 using tensorflow::ops::Mul;
+using tensorflow::ops::SqrtGrad;
 
 namespace tensorflow {
 namespace gradients {
@@ -72,6 +73,25 @@ class ExpGradientFunction : public GradientFunction {
   AbstractTensorHandlePtr exp_;
 };
 
+class SqrtGradientFunction : public GradientFunction {
+ public:
+  explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) {
+    sqrt->Ref();
+  }
+  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
+                 vector<AbstractTensorHandle*>* grad_outputs) override {
+    std::string name = "Sqrt_Grad";
+    grad_outputs->resize(1);
+    TF_RETURN_IF_ERROR(SqrtGrad(ctx->ctx, {sqrt_.get(), grad_inputs[0]},
+                                absl::MakeSpan(*grad_outputs), name.c_str()));
+    return Status::OK();
+  }
+  ~SqrtGradientFunction() override {}
+
+ private:
+  AbstractTensorHandlePtr sqrt_;
+};
+
 class MatMulGradientFunction : public GradientFunction {
  public:
  explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
@@ -210,5 +230,14 @@ BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
   return new BackwardFunction(gradient_function, default_gradients);
 }
 
+BackwardFunction* SqrtRegisterer(const ForwardOperation& op) {
+  auto gradient_function = new SqrtGradientFunction(op.outputs[0]);
+  // For ops with a single output, the gradient function is not called if there
+  // is no incoming gradient. So we do not need to worry about creating zeros
+  // grads in this case.
+  auto default_gradients = new PassThroughDefaultGradients(op);
+  return new BackwardFunction(gradient_function, default_gradients);
+}
+
 }  // namespace gradients
 }  // namespace tensorflow
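A note on the gradient added above: `SqrtGradientFunction` holds a reference to the forward output (`sqrt_`) rather than the op input because, with y = sqrt(x), the backward value only needs y and the incoming gradient. TensorFlow's `SqrtGrad` kernel computes

$$\mathrm{SqrtGrad}(y,\,dy) = \frac{dy}{2y}, \qquad y = \sqrt{x},$$

so passing `{sqrt_.get(), grad_inputs[0]}` is sufficient to produce dx.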
@@ -19,10 +19,13 @@ limitations under the License.
 
 namespace tensorflow {
 namespace gradients {
 
 BackwardFunction* AddRegisterer(const ForwardOperation& op);
 BackwardFunction* ExpRegisterer(const ForwardOperation& op);
 BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
+BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
 
 }  // namespace gradients
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
@@ -144,5 +144,33 @@ Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
   return exp_op->Execute(outputs, &num_retvals);
 }
 
+Status Sqrt(AbstractContext* ctx,
+            absl::Span<AbstractTensorHandle* const> inputs,
+            absl::Span<AbstractTensorHandle*> outputs, const char* name) {
+  AbstractOperationPtr sqrt_op(ctx->CreateOperation());
+  TF_RETURN_IF_ERROR(sqrt_op->Reset("Sqrt", /*raw_device_name=*/nullptr));
+  TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_op.get(), name));
+  TF_RETURN_IF_ERROR(sqrt_op->AddInput(inputs[0]));
+
+  int num_retvals = 1;
+  Status s = sqrt_op->Execute(outputs, &num_retvals);
+  return s;
+}
+
+Status SqrtGrad(AbstractContext* ctx,
+                absl::Span<AbstractTensorHandle* const> inputs,
+                absl::Span<AbstractTensorHandle*> outputs, const char* name) {
+  AbstractOperationPtr sqrt_grad_op(ctx->CreateOperation());
+  TF_RETURN_IF_ERROR(
+      sqrt_grad_op->Reset("SqrtGrad", /*raw_device_name=*/nullptr));
+  TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_grad_op.get(), name));
+  TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[0]));
+  TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[1]));
+
+  int num_retvals = 1;
+  Status s = sqrt_grad_op->Execute(outputs, &num_retvals);
+  return s;
+}
+
 }  // namespace ops
 }  // namespace tensorflow
@@ -50,6 +50,15 @@ Status DivNoNan(AbstractContext* ctx,
 
 Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name);
+
+Status Sqrt(AbstractContext* ctx,
+            absl::Span<AbstractTensorHandle* const> inputs,
+            absl::Span<AbstractTensorHandle*> outputs, const char* name);
+
+Status SqrtGrad(AbstractContext* ctx,
+                absl::Span<AbstractTensorHandle* const> inputs,
+                absl::Span<AbstractTensorHandle*> outputs, const char* name);
+
 }  // namespace ops
 }  // namespace tensorflow
 
@@ -166,6 +166,8 @@ cc_library(
         "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/common_runtime/eager:context",
+        "//tensorflow/core/common_runtime/eager:tensor_handle",
+        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/types:optional",
     ],
 )
@@ -20,8 +20,10 @@ limitations under the License.
 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
 #include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/logging.h"
 
@@ -62,15 +64,53 @@ Status Variable::ReadValue(ImmediateTensorHandlePtr* out) {
   return internal::ReadVariable(ctx_, handle_.get(), dtype_, out);
 }
 
-Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx,
-                                     DataType dtype, TensorShape shape,
-                                     absl::optional<std::string> name,
-                                     const char* raw_device_name,
-                                     std::unique_ptr<Variable>* output) {
+Status Variable::CreateUninitialized(
+    ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
+    absl::optional<std::string> name, const char* raw_device_name,
+    const std::vector<std::string>& component_devices,
+    std::unique_ptr<Variable>* output) {
   ImmediateTensorHandlePtr handle;
-  TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
-      ctx, dtype, shape, raw_device_name, &handle));
+
+  if (component_devices.empty()) {
+    TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
+        ctx, dtype, shape, raw_device_name, &handle));
+    output->reset(
+        new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
+    return Status();
+  }
+
+  if (!tensorflow::isa<EagerContext>(ctx)) {
+    return errors::InvalidArgument(
+        "Can only load distributed variables with EagerContext.");
+  }
+
+  EagerContext* eager_ctx = reinterpret_cast<EagerContext*>(ctx);
+
+  std::vector<TensorHandle*> handles;
+  for (const auto& device : component_devices) {
+    ImmediateTensorHandlePtr handlePtr;
+    TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
+        ctx, dtype, shape, device.empty() ? nullptr : device.c_str(),
+        &handlePtr));
+    if (!tensorflow::isa<TensorHandle>(handlePtr.get())) {
+      return errors::Internal("Returned replica handle has unsupported type.");
+    }
+    handles.push_back(reinterpret_cast<TensorHandle*>(handlePtr.release()));
+  }
+  TensorHandle* packed_handle;
+  TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
+      std::move(handles), eager_ctx, &packed_handle));
+  // The call to `CreatePackedHandle` incremented the handles' reference count,
+  // which we must now decrement to make the packed handle the owner of those
+  // handles. We can't loop through the `handles` vector because it was
+  // `std::move`d in the call above.
+  for (int i = 0; i != packed_handle->NumPackedHandles(); ++i) {
+    TensorHandle* component;
+    TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &component));
+    component->Unref();
+  }
+
+  handle.reset(packed_handle);
   output->reset(
       new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
   return Status();
@@ -34,11 +34,11 @@ class Variable : public TensorHandleConvertible {
  public:
   // Creates an uninitialized resource variable. Note that a caller must
   // call "assign" to associate a value with the variable.
-  static Status CreateUninitialized(ImmediateExecutionContext* ctx,
-                                    DataType dtype, TensorShape shape,
-                                    absl::optional<std::string> name,
-                                    const char* raw_device_name,
+  static Status CreateUninitialized(
+      ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
+      absl::optional<std::string> name, const char* raw_device_name,
+      const std::vector<std::string>& component_devices,
      std::unique_ptr<Variable>* output);
 
   // The dtype of the underlying variable.
   DataType dtype();
@@ -235,10 +235,17 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx,
   const std::string& name = variable.name();
   tensorflow::TensorShape shape(variable.shape());
   tensorflow::DataType dtype = variable.dtype();
+  std::vector<std::string> component_devices;
+
+  for (const auto& component :
+       variable.experimental_distributed_variable_components()) {
+    component_devices.push_back(component.device());
+  }
+
   TF_RETURN_IF_ERROR(Variable::CreateUninitialized(
       ctx, dtype, shape, name,
-      variable.device().empty() ? nullptr : variable.device().c_str(), output));
+      variable.device().empty() ? nullptr : variable.device().c_str(),
+      component_devices, output));
   return Status();
 }
 
@@ -119,7 +119,7 @@ TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) {
   Status status;
   std::unique_ptr<Variable> var;
   TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
-                                             absl::nullopt, nullptr, &var));
+                                             absl::nullopt, nullptr, {}, &var));
 
   // Create a TensorHandle
   ImmediateTensorHandlePtr expected_handle =
@@ -127,7 +127,7 @@ def tf_library(
             "$(location " + tfcompile_tool + ")" +
             " --config=$(location " + config + ")" +
             " --dump_fetch_nodes > $@"),
-        tools = [tfcompile_tool],
+        exec_tools = [tfcompile_tool],
         # Run tfcompile on the build host, rather than forge, since it's
         # typically way faster on the local machine.
         local = 1,
@@ -242,7 +242,7 @@ def tf_library(
             " --out_function_object=$(@D)/" + function_object_file +
             " " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag
         ),
-        tools = [tfcompile_tool],
+        exec_tools = [tfcompile_tool],
         visibility = visibility,
        testonly = testonly,
         # Run tfcompile on the build host since it's typically faster on the
@@ -281,7 +281,7 @@ def tf_library(
             " --out_session_module=$(@D)/" + session_module_pb +
             " " + flags
         ),
-        tools = [tfcompile_tool],
+        exec_tools = [tfcompile_tool],
         visibility = visibility,
         testonly = testonly,
         local = 1,
@@ -84,6 +84,23 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
   return Status::OK();
 }
 
+xla::StatusOr<std::vector<NodeDef>> MakeCallNodesFromAttribute(
+    const Node& node, absl::string_view attr_name,
+    absl::string_view call_name) {
+  std::vector<NameAttrList> attr_lists;
+  TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &attr_lists));
+
+  std::vector<NodeDef> out;
+  for (int i = 0; i < attr_lists.size(); i++) {
+    out.emplace_back();
+    NodeDef& inserted = out.back();
+    inserted.set_name(absl::StrCat(call_name, "_", i));
+    inserted.set_op(attr_lists[i].name());
+    *inserted.mutable_attr() = attr_lists[i].attr();
+  }
+  return out;
+}
+
 // Utility which searches for values in a sorted list by scanning over it once.
 // No matter how many times ScanForValue is called, the list is scanned at most
 // once. However, if a call to ScanForValue skips over a value, that value is
@@ -227,6 +244,30 @@ bool RecursiveCompilabilityChecker::IsCompilableIf(
   return is_compilable;
 }
 
+bool RecursiveCompilabilityChecker::IsCompilableCase(
+    const Node& case_node, FunctionLibraryRuntime* lib_runtime,
+    std::vector<StackFrameView>* stack_trace,
+    NameAttrList* encapsulating_function,
+    RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
+    const {
+  xla::StatusOr<std::vector<NodeDef>> calls =
+      MakeCallNodesFromAttribute(case_node, "branches", "branch");
+  if (!calls.ok()) {
+    VLOG(2) << "Rejecting node " << case_node.name() << ": "
+            << "missing attribute 'branches'";
+    return false;
+  }
+
+  bool is_compilable = true;
+
+  for (const NodeDef& call : *calls) {
+    is_compilable &=
+        IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
+                         uncompilable_nodes);
+  }
+  return is_compilable;
+}
+
 // Tests whether 'while_node' is a completely compilable loop.
 // Every operator in the condition and body functions must be compilable for a
 // while loop to be compilable.
@@ -417,6 +458,13 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
     return false;
   }
 
+  if (op_filter_.require_always_compilable && node.IsCaseNode() &&
+      !IsCompilableCase(node, lib_runtime, stack_trace, encapsulating_function,
+                        uncompilable_nodes)) {
+    LogNotCompilable(node, "unsupported case");
+    return false;
+  }
+
   if (!op_filter_.allow_stateful_rng_ops &&
       IsStatefulRandomOp(node.type_string())) {
     absl::string_view uncompilable_reason = "stateful random op";
@@ -124,6 +124,10 @@ class RecursiveCompilabilityChecker {
     // Whether ops known to have numerical accuracy issues should be considered
     // compilable..
     bool allow_inaccurate_ops = false;
+
+    // Require the function to be always compilable, regardless whether some
+    // control flow branches might be dead for a given input.
+    bool require_always_compilable = false;
   };
 
   RecursiveCompilabilityChecker(OperationFilter op_filter,
@@ -211,6 +215,14 @@ class RecursiveCompilabilityChecker {
                         NameAttrList* encapsulating_function,
                         UncompilableNodesMap* uncompilable_nodes) const;
 
+  // Tests whether 'case_node' is compilable. Every operator in all branches
+  // must be compilable.
+  bool IsCompilableCase(const Node& case_node,
+                        FunctionLibraryRuntime* lib_runtime,
+                        std::vector<StackFrameView>* stack_trace,
+                        NameAttrList* encapsulating_function,
+                        UncompilableNodesMap* uncompilable_nodes) const;
+
   // Returns compilability of node def retrieved from `node`'s attribute with
   // name `attr_name`.
   bool ExtractNodeDefAndCheckCompilability(
@@ -34,7 +34,16 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
+AttrValue FuncListAttr(const absl::Span<const char* const> names) {
+  AttrValue attr;
+  for (const char* name : names) {
+    attr.mutable_list()->add_func()->set_name(name);
+  }
+  return attr;
+}
+
 constexpr char kFunctionalIfNodeName[] = "If";
+constexpr char kFunctionalCaseNodeName[] = "Case";
 constexpr char kFunctionalWhileNodeName[] = "While";
 constexpr char kCompilableFunctionName[] = "CompilableFn";
 constexpr char kCompilableFunctionNodeName[] = "n_c";
@@ -76,8 +85,12 @@ class CompilabilityCheckUtilTest : public ::testing::Test {
     op_filter_.allow_inaccurate_ops = false;
     op_filter_.allow_slow_ops = false;
 
-    checker_ = absl::make_unique<RecursiveCompilabilityChecker>(op_filter_,
-                                                                device_type_);
+    checker_ = CreateCompilabilityChecker();
+  }
+
+  std::unique_ptr<RecursiveCompilabilityChecker> CreateCompilabilityChecker() {
+    return absl::make_unique<RecursiveCompilabilityChecker>(op_filter_,
+                                                            device_type_);
   }
 
   FunctionLibraryRuntime* GetFunctionLibraryRuntime() {
@@ -355,6 +368,57 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
                                 "unsupported op"));
 }
 
+TEST_F(CompilabilityCheckUtilTest, CheckFunctionalCaseNode) {
+  FunctionDefLibrary flib;
+  *flib.add_function() = FunctionDefHelper::Define(
+      /*Function*/ kUncompilableFunctionName,
+      /*Inputs*/ {"n_a:float"},
+      /*Outputs*/ {"n_c_uncompilable:float"},
+      /*Attributes*/ {},
+      // Node info
+      {{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});
+  *flib.add_function() = FunctionDefHelper::Define(
+      /*Function*/ kUncompilableFunctionTwoName,
+      /*Inputs*/ {"n_a:float"},
+      /*Outputs*/ {"n_d_uncompilable:float"},
+      /*Attribute*/ {},
+      // Node info
+      {{{kUncompilableFunctionNodeTwoName}, "MissingKernel", {"n_a"}}});
+
+  Scope root = Scope::NewRootScope().ExitOnError();
+  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib));
+  auto branch_index = ops::Placeholder(root.WithOpName("pred"), DT_INT32);
+  auto placeholder = ops::Placeholder(root.WithOpName("A"), DT_INT32);
+  std::vector<NodeBuilder::NodeOut> inputes(
+      {NodeBuilder::NodeOut(placeholder.node())});
+  Node* case_node;
+  TF_ASSERT_OK(
+      NodeBuilder(kFunctionalCaseNodeName, "Case", &root.graph()->flib_def())
+          .Input(branch_index.node())
+          .Input(inputes)
+          .Attr("branches", FuncListAttr({kUncompilableFunctionName,
+                                          kUncompilableFunctionTwoName}))
+          .Attr("Tout", {DT_INT32})
+          .Finalize(root.graph(), &case_node));
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+  flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
+
+  auto case_node_it = std::find_if(
+      graph->nodes().begin(), graph->nodes().end(),
+      [&](const Node* n) { return n->name() == kFunctionalCaseNodeName; });
+  EXPECT_NE(case_node_it, graph->nodes().end());
+  auto* flib_runtime = GetFunctionLibraryRuntime();
+
+  op_filter_.require_always_compilable = false;
+  checker_ = CreateCompilabilityChecker();
+  EXPECT_TRUE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
+  op_filter_.require_always_compilable = true;
+  checker_ = CreateCompilabilityChecker();
+  EXPECT_FALSE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
+}
+
 TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) {
   GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
   Scope root = Scope::NewRootScope().ExitOnError();
@@ -1196,10 +1196,14 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
       continue;
     }
 
-    if (!RecursiveCompilabilityChecker{
-            CreateOperationFilter(*registration),
-            DeviceType{registration->compilation_device_name}}
-             .IsCompilableNode(*node, lib_runtime)) {
+    RecursiveCompilabilityChecker::OperationFilter filter =
+        CreateOperationFilter(*registration);
+    filter.require_always_compilable = true;
+
+    RecursiveCompilabilityChecker checker(
+        filter, DeviceType{registration->compilation_device_name});
+
+    if (!checker.IsCompilableNode(*node, lib_runtime)) {
       continue;
     }
 
@@ -303,8 +303,12 @@ Status XlaCompilationCache::CompileSingleOp(
   }
 
   GraphDebugInfo debug_info;
+  std::vector<std::string> control_rets;
+  if (result_dtypes.empty()) {
+    control_rets.push_back(node_def.name());
+  }
   return CompileGraphToXlaHlo(
-      *graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
+      *graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), control_rets,
       options.device_type.type_string(), compile_options.use_tuple_arg,
       *options.flib_def, debug_info, options.shape_representation_fn, result);
 #endif
@@ -29,7 +29,7 @@ LLVM_SRC=...
 
 # Create basic workspace file
 echo 'workspace(name = "llvm-project")' > $LLVM_SRC/WORKSPACE
-# and over the bazel BUILD files.
+# and copy over the bazel BUILD files.
 cp third_party/llvm/llvm.autogenerated.BUILD $LLVM_SRC/llvm/BUILD
 cp third_party/mlir/BUILD $LLVM_SRC/mlir
 cp third_party/mlir/test.BUILD $LLVM_SRC/mlir/test/BUILD
@@ -48,6 +48,7 @@ filegroup(
         "include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td",
        "include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td",
        "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td",
+        "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td",
        "@llvm-project//mlir:OpBaseTdFiles",
        "@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td",
        "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
@@ -360,6 +360,19 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
   }];
 }
 
+def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
+    HLO_FpOrComplexTensor> {
+  let summary = "Atan operator";
+
+  let description = [{
+    Returns `Atan(operand)` element-wise.
+
+    $$
+    \atan(x) = \atan2(x, 1)
+    $$
+  }];
+}
+
 def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
     HLO_FpOrComplexTensor> {
   let summary = "Sinh operation";
@@ -353,7 +353,9 @@ def HLO_PowOp : HLO_BinaryElementwiseOp<"power",
     [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp;
 
 def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
-    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp;
+    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp {
+  let hasFolder = 1;
+}
 
 def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
     [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp;
@@ -910,39 +912,12 @@ def HLO_CollectivePermuteOp: HLO_Op<"collective_permute",
   let results = (outs HLO_Tensor);
 }
 
-// TODO(hinsu): Make this struct dialect independent so that it can be shared
-// between HLO and LHLO dialect.
-def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [
-    StructFieldAttr<"input_batch_dimension",I64Attr>,
-    StructFieldAttr<"input_feature_dimension", I64Attr>,
-    StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
-    StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
-    StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
-    StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
-    StructFieldAttr<"output_batch_dimension", I64Attr>,
-    StructFieldAttr<"output_feature_dimension", I64Attr>,
-    StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
-
-  let description = "Structure of dimension information for conv op";
-}
-
 def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
-  let arguments = (ins
-    HLO_Tensor:$lhs,
-    HLO_Tensor:$rhs,
-    // Default value: one for each of the spatial dimension.
-    OptionalAttr<I64ElementsAttr>:$window_strides,
-    // Default value: zero for each of the spatial dimension.
-    OptionalAttr<I64ElementsAttr>:$padding,
-    // Default value: one for each of the spatial dimension.
-    OptionalAttr<I64ElementsAttr>:$lhs_dilation,
-    // Default value: one for each of the spatial dimension.
-    OptionalAttr<I64ElementsAttr>:$rhs_dilation,
-    ConvDimensionNumbers:$dimension_numbers,
-    I64Attr:$feature_group_count,
-    I64Attr:$batch_group_count,
-    HLO_PrecisionConfigAttr:$precision_config
-  );
+  let arguments = !con(
+    (ins
+       HLO_Tensor:$lhs,
+       HLO_Tensor:$rhs),
+    ConvolutionAttributes<HLO_Dialect>.attributes);
 
   let results = (outs HLO_Tensor);
 }
@@ -1007,6 +1007,42 @@ class BASE_HLO_ConcatenateOp {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Common convolution attributes
+//===----------------------------------------------------------------------===//
+
+class ConvDimensionNumbersBase<Dialect dialect>
+  : StructAttr<"ConvDimensionNumbers", dialect, [
+    StructFieldAttr<"input_batch_dimension",I64Attr>,
+    StructFieldAttr<"input_feature_dimension", I64Attr>,
+    StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
+    StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
+    StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
+    StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
+    StructFieldAttr<"output_batch_dimension", I64Attr>,
+    StructFieldAttr<"output_feature_dimension", I64Attr>,
+    StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
+
+  let description = "Structure of dimension information for conv op";
+}
+
+class ConvolutionAttributes<Dialect dialect> {
+  dag attributes = (ins
+    // Default value: one for each of the spatial dimension.
+    OptionalAttr<I64ElementsAttr>:$window_strides,
+    // Default value: zero for each of the spatial dimension.
+    OptionalAttr<I64ElementsAttr>:$padding,
+    // Default value: one for each of the spatial dimension.
+    OptionalAttr<I64ElementsAttr>:$lhs_dilation,
+    // Default value: one for each of the spatial dimension.
+    OptionalAttr<I64ElementsAttr>:$rhs_dilation,
+    ConvDimensionNumbersBase<dialect>:$dimension_numbers,
+    I64Attr:$feature_group_count,
+    I64Attr:$batch_group_count,
+    HLO_PrecisionConfigAttr:$precision_config
+  );
+}
+
 class BASE_HLO_ConvOp {
   string summary = "Convolution operator";
 
@ -37,38 +37,13 @@ include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/CopyOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
-include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
+include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"

 def LHLO_Dialect : Dialect {
   let name = "lmhlo";
   let cppNamespace = "::mlir::lmhlo";
 }

-//===----------------------------------------------------------------------===//
-// LMHLO type definitions.
-//===----------------------------------------------------------------------===//
-
-// Any integer tensor types
-def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
-
-// Any floating-point tensor types
-def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
-
-def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
-
-def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
-
-def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
-
-// Any integer or floating-point tensor types
-def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
-
-def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
-
-def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
-
-def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
-
 //===----------------------------------------------------------------------===//
 // LMHLO nullary op definitions.
 //===----------------------------------------------------------------------===//
@ -345,10 +320,11 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
 def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
     [NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
   let summary = [{
-    "modifies the offset, sizes and strides of a statically shaped memref.
+    modifies the offset, sizes and strides of a statically shaped memref
   }];
   let description = [{
-    Allows to modify the offset, sizes and strides of a statically shaped memref.
+    Casts the statically shaped memref operand to a memref with optionally
+    modified offsets, sizes and strides.

     Example:
     ```mlir
@ -592,40 +568,13 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
   );
 }

-// TODO(bondhugula): Make this struct dialect independent so that it can be
-// shared between the HLO and LHLO dialects.
-def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
-    StructFieldAttr<"input_batch_dimension",I64Attr>,
-    StructFieldAttr<"input_feature_dimension", I64Attr>,
-    StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
-    StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
-    StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
-    StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
-    StructFieldAttr<"output_batch_dimension", I64Attr>,
-    StructFieldAttr<"output_feature_dimension", I64Attr>,
-    StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
-
-  let description = "Structure of dimension information for conv op";
-}
-
 def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
-  let arguments = (ins
-    Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
-    Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
-    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
-    // Default value: one for each of the spatial dimension.
-    OptionalAttr<I64ElementsAttr>:$window_strides,
-    // Default value: zero for each of the spatial dimension.
-    OptionalAttr<I64ElementsAttr>:$padding,
-    // Default value: one for each of the spatial dimension.
-    OptionalAttr<I64ElementsAttr>:$lhs_dilation,
-    // Default value: one for each of the spatial dimension.
-    OptionalAttr<I64ElementsAttr>:$rhs_dilation,
-    ConvDimensionNumbers:$dimension_numbers,
-    I64Attr:$feature_group_count,
-    I64Attr:$batch_group_count,
-    HLO_PrecisionConfigAttr:$precision_config
-  );
+  let arguments = !con(
+    (ins
+       Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
+       Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
+       Arg<LHLO_Buffer, "", [MemWrite]>:$output),
+    ConvolutionAttributes<LHLO_Dialect>.attributes);
 }

 def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {
@ -0,0 +1,47 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef LHLO_OPS_BASE
+#define LHLO_OPS_BASE
+
+include "mlir/IR/OpBase.td"
+include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
+
+//===----------------------------------------------------------------------===//
+// LMHLO type definitions.
+//===----------------------------------------------------------------------===//
+
+// Any integer tensor types
+def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
+
+// Any floating-point tensor types
+def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
+
+def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
+
+def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
+
+def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
+
+// Any integer or floating-point tensor types
+def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
+
+def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
+
+def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
+
+def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
+
+#endif // LHLO_OPS_BASE
@ -149,6 +149,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
                                                   loc, result_types, args, b);
 }

+template <>
+inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc,
+                                                    ArrayRef<Type> result_types,
+                                                    ArrayRef<Value> args,
+                                                    OpBuilder* b) {
+  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::Atan2Op>{}(
+      loc, result_types, args, b);
+}
+
 template <typename PredicateType>
 inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
   return llvm::None;
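The new specialization follows the existing pattern in this header: restrict by element type (FloatType here) and forward to the corresponding standard-dialect op. A standalone sketch of the same dispatch idea using ordinary C++ templates; the tag type and helper name are stand-ins for the real MLIR classes, not part of the diff.

```cpp
#include <cassert>
#include <cmath>

// Stand-in tag for a named elementwise op; the real code instantiates
// MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::Atan2Op>, which only fires
// when the element type is a float type.
struct Atan2Tag {};

template <typename OpTag>
double MapToScalar(double a, double b);  // one specialization per named op

template <>
double MapToScalar<Atan2Tag>(double a, double b) {
  return std::atan2(a, b);  // lowered to the standard dialect's atan2
}

int main() {
  assert(std::fabs(MapToScalar<Atan2Tag>(1.0, 1.0) - 0.7853981633974483) < 1e-12);
  return 0;
}
```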
@ -15,9 +15,9 @@ limitations under the License.

 include "mlir/Pass/PassBase.td"

-def TestChloLegalizeToHloPass : Pass<"mhlo-test-chlo-legalize-to-hlo", "FuncOp"> {
-  let summary = "Test pass for applying chlo -> hlo legalization patterns.";
-  let constructor = "createTestChloLegalizeToHloPass()";
+def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "FuncOp"> {
+  let summary = "Legalize CHLO to HLO.";
+  let constructor = "createChloLegalizeToHloPass()";
 }

 def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
@ -44,6 +44,9 @@ std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass();
 /// Lowers from HLO dialect to Standard dialect.
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();

+/// Lowers from the CHLO dialect to the HLO dialect.
+std::unique_ptr<FunctionPass> createChloLegalizeToHloPass();
+
 /// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
 /// buffers if necessary. If `results_escape_functions` is set to true,
 /// allocated buffers for function results will be returned and escape the
@ -63,7 +66,7 @@ std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
 std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();

 /// Lowers trigonometric operations from the standard dialect to approximations
-// that do not use intrinsics.
+/// that do not use intrinsics.
 std::unique_ptr<OperationPass<FuncOp>>
 createLegalizeTrigonometricToApproximationPass();
@ -22,7 +22,6 @@ limitations under the License.
 namespace mlir {
 namespace mhlo {

-std::unique_ptr<Pass> createTestChloLegalizeToHloPass();
 std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass();
 std::unique_ptr<Pass> createTestMaterializeBroadcastsPass();
 std::unique_ptr<Pass> createTestUnfuseBatchNormPass();
@ -2001,6 +2001,23 @@ struct divide<APInt> {
   APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
 };

+template <typename T>
+struct remainder : std::modulus<T> {};
+
+template <>
+struct remainder<APInt> {
+  APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); }
+};
+
+template <>
+struct remainder<APFloat> {
+  APFloat operator()(const APFloat& a, const APFloat& b) const {
+    APFloat result(a);
+    result.remainder(b);
+    return result;
+  }
+};
+
 template <typename T>
 struct max {
   T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
@ -2042,6 +2059,7 @@ BINARY_FOLDER(AddOp, std::plus);
 BINARY_FOLDER(SubOp, std::minus);
 BINARY_FOLDER(MulOp, std::multiplies);
 BINARY_FOLDER(DivOp, divide);
+BINARY_FOLDER(RemOp, remainder);
 BINARY_FOLDER(MaxOp, max);
 BINARY_FOLDER(MinOp, min);
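The new `remainder` functor mirrors the existing `divide` folder: signed `srem` for integers and `APFloat::remainder` for floats. Below is a minimal standalone sketch of the same element-wise folding on plain `int64_t`/`double` instead of `APInt`/`APFloat` (an assumption made only for readability), checked against the vectors used by the new canonicalization tests further down; `std::fmod` reproduces those expected values, though this sketch makes no claim that it matches `APFloat::remainder` in every corner case.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Element-wise constant folding for a remainder op, sketched on plain scalar
// types: truncated (srem-style) remainder for integers, std::fmod for floats.
template <typename T, typename F>
std::vector<T> FoldElementwise(const std::vector<T>& lhs,
                               const std::vector<T>& rhs, F op) {
  std::vector<T> out;
  out.reserve(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) out.push_back(op(lhs[i], rhs[i]));
  return out;
}

int main() {
  // Integer vectors from the remainder_fold_int test below.
  auto ri = FoldElementwise<int64_t>({5, 66, 5, 1}, {3, 5, 1, 2},
                                     [](int64_t a, int64_t b) { return a % b; });
  assert((ri == std::vector<int64_t>{2, 1, 0, 1}));

  // Float vectors from the remainder_fold_float test below.
  auto rf = FoldElementwise<double>({7.0, 66.5, 5.0, 3.1}, {3.0, 5.0, 1.0, 2.6},
                                    [](double a, double b) { return std::fmod(a, b); });
  assert(std::abs(rf[0] - 1.0) < 1e-9 && std::abs(rf[1] - 1.5) < 1e-9);
  assert(std::abs(rf[2] - 0.0) < 1e-9 && std::abs(rf[3] - 0.5) < 1e-9);
  return 0;
}
```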
@ -27,8 +27,8 @@ namespace mhlo {

 namespace {

-struct TestChloLegalizeToHloPass
-    : public PassWrapper<TestChloLegalizeToHloPass, FunctionPass> {
+struct ChloLegalizeToHloPass
+    : public PassWrapper<ChloLegalizeToHloPass, FunctionPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<mhlo::MhloDialect, shape::ShapeDialect, scf::SCFDialect>();
   }
@ -36,11 +36,12 @@ struct TestChloLegalizeToHloPass
   void runOnFunction() override {
     ConversionTarget conversionTarget(getContext());
     OwningRewritePatternList conversionPatterns;

     conversionTarget.addIllegalDialect<chlo::HloClientDialect>();

     // Consider the mhlo dialect legal for tests.
     conversionTarget.addLegalDialect<mhlo::MhloDialect>();
-    // The conversion uses helpers from the Standard dialect.
+
+    // The conversion uses helpers from the standard dialect.
     conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
     conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
     conversionTarget.addLegalDialect<mlir::scf::SCFDialect>();
@ -56,8 +57,8 @@ struct TestChloLegalizeToHloPass

 }  // namespace

-std::unique_ptr<FunctionPass> createTestChloLegalizeToHloPass() {
-  return std::make_unique<TestChloLegalizeToHloPass>();
+std::unique_ptr<FunctionPass> createChloLegalizeToHloPass() {
+  return std::make_unique<ChloLegalizeToHloPass>();
 }

 } // namespace mhlo
@ -24,16 +24,17 @@ include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td"
 //===----------------------------------------------------------------------===//

 // Expand acos to MHLO dialect as follows:
-//   acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1
+//   acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1
 //           = pi                                if x == -1
 def : Pat<(HLOClient_AcosOp $input),
   (HLO_SelectOp
-    (HLO_CompareOp $input,
-      (HLO_ConstantLike<"0"> $input),
+    (HLO_CompareOp
+      $input,
+      (HLO_ConstantLike<"-1"> $input),
       HLO_COMPARISON_DIRECTION_NE
     ),
     (HLO_MulOp
-      (HLO_ConstantLike<"2.0f"> $input),
+      (HLO_ConstantLike<"2"> $input),
       (HLO_Atan2Op
         (HLO_SqrtOp
           (HLO_SubOp
@ -47,7 +48,16 @@ def : Pat<(HLOClient_AcosOp $input),
       )
     )
   ),
-  (HLO_ConstantLike<"M_PI"> $input))>;
+  (HLO_ConstantLike<"M_PI"> $input)
+)>;
+
+// Express `atan` as
+//   atan(x) = atan2(x, 1)
+def : Pat<(HLOClient_AtanOp $input),
+  (HLO_Atan2Op
+    $input,
+    (HLO_ConstantLike<"1"> $input)
+  )>;

 // Express `sinh` as
 //   sinh(x) = (e^x - e^-x) / 2 if |x| < 1
@ -95,4 +105,3 @@ def : Pat<(HLOClient_TanOp $input),
       (HLO_SinOp $input),
       (HLO_CosOp $input)
     )>;
-
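The rewritten acos expansion now compares against -1 (the one point where the atan2 form degenerates) and feeds atan2 instead of atan, and a new pattern expresses atan through atan2. A small standalone check of the two identities the patterns rely on, acos(x) = 2*atan2(sqrt(1 - x^2), 1 + x) and atan(x) = atan2(x, 1); the sample values are arbitrary, chosen only for illustration.

```cpp
#include <cassert>
#include <cmath>

int main() {
  // acos(x) = 2 * atan2(sqrt(1 - x^2), 1 + x) for x in (-1, 1].
  for (double x : {-0.99, -0.5, 0.0, 0.3, 0.75, 1.0}) {
    double expanded = 2.0 * std::atan2(std::sqrt(1.0 - x * x), 1.0 + x);
    assert(std::abs(expanded - std::acos(x)) < 1e-12);
  }
  // At x == -1 the expansion degenerates to atan2(0, 0), which is why the
  // pattern selects the constant pi for that case instead.

  // atan(x) = atan2(x, 1) for all finite x.
  for (double x : {-10.0, -1.0, -0.1, 0.0, 0.5, 42.0}) {
    assert(std::abs(std::atan2(x, 1.0) - std::atan(x)) < 1e-12);
  }
  return 0;
}
```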
@ -45,7 +45,7 @@ using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
 Value InsertDynamicAllocAndDealloc(Location loc, Value result,
                                    Value shape_operand,
                                    ConversionPatternRewriter* rewriter) {
-  auto result_type = result.getType().dyn_cast<ShapedType>();
+  auto result_type = result.getType().dyn_cast<RankedTensorType>();
   if (!result_type) {
     result.getDefiningOp()->emitOpError()
         << "tensor to buffer conversion expects ranked results";
@ -53,17 +53,13 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
   auto memref_type =
       MemRefType::get(result_type.getShape(), result_type.getElementType());

-  Operation* op = result.getDefiningOp();
-
   // Extract the required element out of the vector.
   SmallVector<Value, 4> dynamic_operands;
   for (auto shape_element : llvm::enumerate(result_type.getShape())) {
     if (shape_element.value() != ShapedType::kDynamicSize) continue;
-    Value index = rewriter->create<ConstantOp>(
-        loc, rewriter->getIntegerAttr(rewriter->getIndexType(),
-                                      shape_element.index()));
-    Value alloc_operand = rewriter->create<ExtractElementOp>(loc, shape_operand,
-                                                             ValueRange{index});
+    Value index = rewriter->create<ConstantIndexOp>(loc, shape_element.index());
+    Value alloc_operand =
+        rewriter->create<ExtractElementOp>(loc, shape_operand, index);
     if (!alloc_operand.getType().isIndex()) {
       alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
                                                     rewriter->getIndexType());
@ -71,15 +67,12 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
     dynamic_operands.push_back(alloc_operand);
   }

-  // Insert in front of op to ensure sizes are available.
-  OpBuilder allocBuilder(op);
-  auto alloc = allocBuilder.create<AllocOp>(loc, memref_type, dynamic_operands);
-  return alloc;
+  return rewriter->create<AllocOp>(loc, memref_type, dynamic_operands);
 }

 Value InsertAlloc(Location loc, OpResult result,
                   ConversionPatternRewriter* rewriter) {
-  auto result_type = result.getType().dyn_cast<ShapedType>();
+  auto result_type = result.getType().dyn_cast<RankedTensorType>();
   if (!result_type || !result_type.hasStaticShape()) {
     result.getDefiningOp()->emitOpError()
         << "tensor to buffer conversion expects statically shaped results";
@ -112,19 +105,21 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
         buffer_args.push_back(
             InsertAlloc(op->getLoc(), result.value(), &rewriter));
       } else {
-        SmallVector<Value, 1> results_shape;
         auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
         if (!shape_type_op) return failure();
-        if (failed(
-                shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
-          return failure();
+
+        SmallVector<Value, 1> results_shape;
+        auto status =
+            shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
+        if (failed(status)) return failure();
         buffer_args.push_back(InsertDynamicAllocAndDealloc(
             op->getLoc(), result.value(), results_shape.front(), &rewriter));
       }
     }
     rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
                                                 buffer_args, op->getAttrs());
-    rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
+    rewriter.replaceOp(
+        op, llvm::makeArrayRef(buffer_args).drop_front(operands.size()));
     return success();
   }
 };
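The rewritten helper emits one index constant plus one element extraction per dynamic dimension of the ranked result and feeds those values straight into the alloc, instead of routing through a separate OpBuilder. A standalone sketch of the same bookkeeping on plain containers; the -1 sentinel below is an assumption standing in for ShapedType::kDynamicSize, and the helper name is ours.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Given a static shape with -1 marking dynamic dimensions and the reified
// runtime shape, collect the operands an alloc of that buffer would need:
// one runtime size per dynamic dimension, in order.
std::vector<int64_t> CollectDynamicAllocOperands(
    const std::vector<int64_t>& static_shape,
    const std::vector<int64_t>& runtime_shape) {
  std::vector<int64_t> dynamic_operands;
  for (size_t i = 0; i < static_shape.size(); ++i) {
    if (static_shape[i] != -1) continue;  // Static sizes need no operand.
    dynamic_operands.push_back(runtime_shape[i]);
  }
  return dynamic_operands;
}

int main() {
  // A memref<?x4x?xf32> with runtime shape [8, 4, 3] needs operands {8, 3}.
  auto ops = CollectDynamicAllocOperands({-1, 4, -1}, {8, 4, 3});
  assert((ops == std::vector<int64_t>{8, 3}));
  return 0;
}
```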
@ -32,8 +32,6 @@ limitations under the License.
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Support/LogicalResult.h"

-using mlir::PassRegistration;
-
 namespace mlir {
 namespace mhlo {
 namespace {
@ -822,6 +822,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                    PointwiseToLinalgConverter<lmhlo::AbsOp>,
                    PointwiseToLinalgConverter<lmhlo::AddOp>,
                    PointwiseToLinalgConverter<lmhlo::AndOp>,
+                   PointwiseToLinalgConverter<lmhlo::Atan2Op>,
                    PointwiseToLinalgConverter<lmhlo::CeilOp>,
                    PointwiseToLinalgConverter<lmhlo::CompareOp>,
                    PointwiseToLinalgConverter<lmhlo::ComplexOp>,
@ -932,6 +933,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
                    PointwiseToLinalgConverter<mhlo::AbsOp, false>,
                    PointwiseToLinalgConverter<mhlo::AddOp, false>,
                    PointwiseToLinalgConverter<mhlo::AndOp, false>,
+                   PointwiseToLinalgConverter<mhlo::Atan2Op, false>,
                    PointwiseToLinalgConverter<mhlo::CeilOp, false>,
                    PointwiseToLinalgConverter<mhlo::CompareOp, false>,
                    PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
@ -235,6 +235,23 @@ class ApproximateAtan2Lowering
   }
 };

+class ApproximateAtanLowering
+    : public ApproximateOnExtendedF32Lowering<AtanOp> {
+ public:
+  explicit ApproximateAtanLowering(MLIRContext *ctx)
+      : ApproximateOnExtendedF32Lowering<AtanOp>(ctx) {}
+
+  // Reduce atan(x) to atan2(x, 1) to subsequently rely on an atan approximation
+  // for the argument range [-1, 1].
+  Value emitApproximation(ValueRange args, Location loc,
+                          PatternRewriter &rewriter) const override {
+    Value x = args.front();
+    assert(x.getType().isF32());
+    Value one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1));
+    return rewriter.create<Atan2Op>(loc, x, one);
+  }
+};
+
 struct LegalizeTrigonometricToApproximationPass
     : public PassWrapper<LegalizeTrigonometricToApproximationPass,
                          FunctionPass> {
@ -257,6 +274,7 @@ void PopulateTrigonometricToApproximationPatterns(
     mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
   // clang-format off
   patterns->insert<
+      ApproximateAtanLowering,
      ApproximateAtan2Lowering,
      ApproximateTanhLowering>(context);
   // clang-format on
@ -37,7 +37,6 @@ limitations under the License.

 using mlir::FunctionPass;
 using mlir::OwningRewritePatternList;
-using mlir::PassRegistration;
 using mlir::PassWrapper;

 namespace {
@ -38,7 +38,6 @@ using mlir::LogicalResult;
 using mlir::MLIRContext;
 using mlir::OpRewritePattern;
 using mlir::OwningRewritePatternList;
-using mlir::PassRegistration;
 using mlir::PassWrapper;
 using mlir::PatternRewriter;
 using mlir::RankedTensorType;
@ -24,7 +24,6 @@ limitations under the License.
 #include "mlir/Transforms/DialectConversion.h"

 using mlir::FunctionPass;
-using mlir::PassRegistration;
 using mlir::PassWrapper;

 namespace {
@ -48,7 +48,7 @@ namespace {

 // TODO(herhut): Generate these out of op definitions.
 #define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
-  fn(TanOp) sep fn(AcosOp) sep fn(SinhOp)
+  fn(AcosOp) sep fn(AtanOp) sep fn(SinhOp) sep fn(TanOp)

 template <typename OpTy>
 inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
|||||||
return %2 : tensor<4xf64>
|
return %2 : tensor<4xf64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: remainder_fold_int
|
||||||
|
func @remainder_fold_int() -> tensor<4xi32> {
|
||||||
|
%0 = mhlo.constant dense<[5, 66, 5, 1]> : tensor<4xi32>
|
||||||
|
%1 = mhlo.constant dense<[3, 5, 1, 2]> : tensor<4xi32>
|
||||||
|
// CHECK: mhlo.constant dense<[2, 1, 0, 1]>
|
||||||
|
%2 = "mhlo.remainder"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>)
|
||||||
|
return %2 : tensor<4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: remainder_fold_float
|
||||||
|
func @remainder_fold_float() -> tensor<4xf32> {
|
||||||
|
%0 = mhlo.constant dense<[7.0, 66.5, 5.0, 3.1]> : tensor<4xf32>
|
||||||
|
%1 = mhlo.constant dense<[3.0, 5.0, 1.0, 2.6]> : tensor<4xf32>
|
||||||
|
// CHECK: mhlo.constant dense<[1.000000e+00, 1.500000e+00, 0.000000e+00, 5.000000e-01]>
|
||||||
|
%2 = "mhlo.remainder"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
|
||||||
|
return %2 : tensor<4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: max_scalar_fold
|
// CHECK-LABEL: max_scalar_fold
|
||||||
func @max_scalar_fold() -> tensor<4xi64> {
|
func @max_scalar_fold() -> tensor<4xi64> {
|
||||||
%0 = mhlo.constant dense<7> : tensor<4xi64>
|
%0 = mhlo.constant dense<7> : tensor<4xi64>
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: mlir-hlo-opt -mhlo-test-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
|
// RUN: mlir-hlo-opt -chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
|
||||||
|
|
||||||
// Check the non-broadcast case for each registered op, then just check a
|
// Check the non-broadcast case for each registered op, then just check a
|
||||||
// representative op for detailed broadcast semantics.
|
// representative op for detailed broadcast semantics.
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: mlir-hlo-opt --mhlo-test-chlo-legalize-to-hlo --split-input-file %s | FileCheck %s
|
// RUN: mlir-hlo-opt --chlo-legalize-to-hlo --split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
// Lower statically shaped `constant_like` to constant.
|
// Lower statically shaped `constant_like` to constant.
|
||||||
// CHECK-LABEL: @constant_like_static_shape
|
// CHECK-LABEL: @constant_like_static_shape
|
||||||
|
@ -261,3 +261,120 @@ func @atan2_f16(%arg0 : f16, %arg1 : f16) -> f16 {
   %res = atan2 %arg0, %arg1 : f16
   return %res : f16
 }
+
+// -----
+
+// CHECK-LABEL: @atan_f64
+func @atan_f64(%arg : f64) -> f64 {
+  // CHECK: atan
+  %res = atan %arg : f64
+  return %res : f64
+}
+
+// -----
+
+// CHECK-LABEL: func @atan_f32
+// CHECK-SAME: (%[[ARG:.*]]: f32) -> f32
+func @atan_f32(%arg : f32) -> f32 {
+  // CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32
+  // CHECK: %[[CST_0:.*]] = constant 0.0027856871 : f32
+  // CHECK: %[[CST_1:.*]] = constant -1.586600e-02 : f32
+  // CHECK: %[[CST_2:.*]] = constant 0.042472221 : f32
+  // CHECK: %[[CST_3:.*]] = constant -0.0749753043 : f32
+  // CHECK: %[[CST_4:.*]] = constant 0.106448799 : f32
+  // CHECK: %[[CST_5:.*]] = constant -0.142070308 : f32
+  // CHECK: %[[CST_6:.*]] = constant 0.199934542 : f32
+  // CHECK: %[[CST_7:.*]] = constant -0.333331466 : f32
+  // CHECK: %[[CST_8:.*]] = constant 1.57079637 : f32
+  // CHECK: %[[CST_9:.*]] = constant 0.000000e+00 : f32
+  // CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32
+  // CHECK: %[[VAL_0:.*]] = absf %[[CST]] : f32
+  // CHECK: %[[VAL_1:.*]] = absf %arg0 : f32
+  // CHECK: %[[VAL_2:.*]] = cmpf "ole", %[[VAL_0]], %[[VAL_1]] : f32
+  // CHECK: %[[VAL_3:.*]] = select %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : f32
+  // CHECK: %[[VAL_4:.*]] = select %[[VAL_2]], %[[VAL_1]], %[[VAL_0]] : f32
+  // CHECK: %[[VAL_5:.*]] = divf %[[VAL_3]], %[[VAL_4]] : f32
+  // CHECK: %[[VAL_6:.*]] = mulf %[[VAL_5]], %[[VAL_5]] : f32
+  // CHECK: %[[VAL_7:.*]] = mulf %[[CST_0]], %[[VAL_6]] : f32
+  // CHECK: %[[VAL_8:.*]] = addf %[[VAL_7]], %[[CST_1]] : f32
+  // CHECK: %[[VAL_9:.*]] = mulf %[[VAL_8]], %[[VAL_6]] : f32
+  // CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_2]] : f32
+  // CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_6]] : f32
+  // CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_3]] : f32
+  // CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_6]] : f32
+  // CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_4]] : f32
+  // CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_6]] : f32
+  // CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_5]] : f32
+  // CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_6]] : f32
+  // CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_6]] : f32
+  // CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_6]] : f32
+  // CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_7]] : f32
+  // CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_6]] : f32
+  // CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_5]] : f32
+  // CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32
+  // CHECK: %[[VAL_24:.*]] = subf %[[CST_8]], %[[VAL_23]] : f32
+  // CHECK: %[[VAL_25:.*]] = select %[[VAL_2]], %[[VAL_24]], %[[VAL_23]] : f32
+  // CHECK: %[[VAL_26:.*]] = cmpf "oeq", %arg0, %[[CST_9]] : f32
+  // CHECK: %[[VAL_27:.*]] = select %[[VAL_26]], %[[CST_9]], %[[VAL_25]] : f32
+  // CHECK: %[[VAL_28:.*]] = cmpf "uno", %arg0, %[[CST]] : f32
+  // CHECK: %[[VAL_29:.*]] = select %[[VAL_28]], %[[CST_10]], %[[VAL_27]] : f32
+  // CHECK: %[[VAL_30:.*]] = copysign %[[VAL_29]], %arg0 : f32
+  // CHECK: return %[[VAL_30]] : f32
+  %res = atan %arg : f32
+  return %res : f32
+}
+
+// -----
+
+// CHECK-LABEL: @atan_f16
+// CHECK-SAME: (%[[ARG:.*]]: f16) -> f16
+func @atan_f16(%arg : f16) -> f16 {
+  // CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32
+  // CHECK: %[[CST_0:.*]] = constant 0.0027856871 : f32
+  // CHECK: %[[CST_1:.*]] = constant -1.586600e-02 : f32
+  // CHECK: %[[CST_2:.*]] = constant 0.042472221 : f32
+  // CHECK: %[[CST_3:.*]] = constant -0.0749753043 : f32
+  // CHECK: %[[CST_4:.*]] = constant 0.106448799 : f32
+  // CHECK: %[[CST_5:.*]] = constant -0.142070308 : f32
+  // CHECK: %[[CST_6:.*]] = constant 0.199934542 : f32
+  // CHECK: %[[CST_7:.*]] = constant -0.333331466 : f32
+  // CHECK: %[[CST_8:.*]] = constant 1.57079637 : f32
+  // CHECK: %[[CST_9:.*]] = constant 0.000000e+00 : f32
+  // CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32
+  // CHECK: %[[VAL_0:.*]] = fpext %arg0 : f16 to f32
+  // CHECK: %[[VAL_1:.*]] = absf %[[CST]] : f32
+  // CHECK: %[[VAL_2:.*]] = absf %[[VAL_0]] : f32
+  // CHECK: %[[VAL_3:.*]] = cmpf "ole", %[[VAL_1]], %[[VAL_2]] : f32
+  // CHECK: %[[VAL_4:.*]] = select %[[VAL_3]], %[[VAL_1]], %[[VAL_2]] : f32
+  // CHECK: %[[VAL_5:.*]] = select %[[VAL_3]], %[[VAL_2]], %[[VAL_1]] : f32
+  // CHECK: %[[VAL_6:.*]] = divf %[[VAL_4]], %[[VAL_5]] : f32
+  // CHECK: %[[VAL_7:.*]] = mulf %[[VAL_6]], %[[VAL_6]] : f32
+  // CHECK: %[[VAL_8:.*]] = mulf %[[CST_0]], %[[VAL_7]] : f32
+  // CHECK: %[[VAL_9:.*]] = addf %[[VAL_8]], %[[CST_1]] : f32
+  // CHECK: %[[VAL_10:.*]] = mulf %[[VAL_9]], %[[VAL_7]] : f32
+  // CHECK: %[[VAL_11:.*]] = addf %[[VAL_10]], %[[CST_2]] : f32
+  // CHECK: %[[VAL_12:.*]] = mulf %[[VAL_11]], %[[VAL_7]] : f32
+  // CHECK: %[[VAL_13:.*]] = addf %[[VAL_12]], %[[CST_3]] : f32
+  // CHECK: %[[VAL_14:.*]] = mulf %[[VAL_13]], %[[VAL_7]] : f32
+  // CHECK: %[[VAL_15:.*]] = addf %[[VAL_14]], %[[CST_4]] : f32
+  // CHECK: %[[VAL_16:.*]] = mulf %[[VAL_15]], %[[VAL_7]] : f32
+  // CHECK: %[[VAL_17:.*]] = addf %[[VAL_16]], %[[CST_5]] : f32
+  // CHECK: %[[VAL_18:.*]] = mulf %[[VAL_17]], %[[VAL_7]] : f32
+  // CHECK: %[[VAL_19:.*]] = addf %[[VAL_18]], %[[CST_6]] : f32
+  // CHECK: %[[VAL_20:.*]] = mulf %[[VAL_19]], %[[VAL_7]] : f32
+  // CHECK: %[[VAL_21:.*]] = addf %[[VAL_20]], %[[CST_7]] : f32
+  // CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_7]] : f32
+  // CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_6]] : f32
+  // CHECK: %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_6]] : f32
+  // CHECK: %[[VAL_25:.*]] = subf %[[CST_8]], %[[VAL_24]] : f32
+  // CHECK: %[[VAL_26:.*]] = select %[[VAL_3]], %[[VAL_25]], %[[VAL_24]] : f32
+  // CHECK: %[[VAL_27:.*]] = cmpf "oeq", %[[VAL_0]], %[[CST_9]] : f32
+  // CHECK: %[[VAL_28:.*]] = select %[[VAL_27]], %[[CST_9]], %[[VAL_26]] : f32
+  // CHECK: %[[VAL_29:.*]] = cmpf "uno", %[[VAL_0]], %[[CST]] : f32
+  // CHECK: %[[VAL_30:.*]] = select %[[VAL_29]], %[[CST_10]], %[[VAL_28]] : f32
+  // CHECK: %[[VAL_31:.*]] = copysign %[[VAL_30]], %[[VAL_0]] : f32
+  // CHECK: %[[VAL_32:.*]] = fptrunc %[[VAL_31]] : f32 to f16
+  // CHECK: return %[[VAL_32]] : f16
+  %res = atan %arg : f16
+  return %res : f16
+}
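The checked sequence is a polynomial approximation of atan on [0, 1] combined with an argument swap borrowed from the atan2 lowering: divide the smaller of |x| and 1 by the larger, evaluate an odd polynomial, and subtract from pi/2 when the arguments were swapped. Below is a standalone C++ transcription of exactly that sequence, using the constants the test expects; the function name is made up for illustration and the code is a sketch of the checked dataflow, not the pass itself.

```cpp
#include <cassert>
#include <cmath>

// Mirrors the lowered f32 instruction sequence checked above:
// q = min(|x|, 1) / max(|x|, 1), p = q + q^3 * poly(q^2), then pi/2 - p when
// |x| > 1, with special cases for x == 0 and NaN, and the sign restored last.
float ApproximateAtanF32(float x) {
  const float kOne = 1.0f;
  const float c[] = {0.0027856871f,  -0.01586600f,  0.042472221f, -0.0749753043f,
                     0.106448799f,   -0.142070308f, 0.199934542f, -0.333331466f};
  const float kHalfPi = 1.57079637f;
  float ax = std::fabs(x);
  bool swapped = kOne <= ax;              // the cmpf "ole" in the test
  float smaller = swapped ? kOne : ax;
  float larger = swapped ? ax : kOne;
  float q = smaller / larger;
  float q2 = q * q;
  float p = c[0];
  for (int i = 1; i < 8; ++i) p = p * q2 + c[i];  // Horner evaluation in q^2
  p = p * q2 * q + q;
  if (swapped) p = kHalfPi - p;
  if (x == 0.0f) p = 0.0f;
  if (std::isnan(x)) p = std::nanf("");
  return std::copysign(p, x);
}

int main() {
  for (float x : {-5.0f, -1.0f, -0.25f, 0.0f, 0.5f, 2.0f, 100.0f}) {
    assert(std::fabs(ApproximateAtanF32(x) - std::atan(x)) < 1e-4f);
  }
  return 0;
}
```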
@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s
+// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s

 // CHECK-LABEL: @add
 func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
@ -477,7 +477,6 @@ cc_library(
         "transforms/default_quant_params.cc",
         "transforms/generated_post_quantize.inc",
         "transforms/generated_quantize.inc",
-        "transforms/load_quantization_recipe.cc",
         "transforms/post_quantize.cc",
         "transforms/prepare_quantize.cc",
         "transforms/quantize.cc",
@ -670,6 +669,7 @@ cc_library(
         "//tensorflow/lite/delegates/flex:allowlisted_flex_ops_lib",
         "//tensorflow/lite/kernels/internal:kernel_utils",
         "//tensorflow/lite/schema:schema_fbs",
+        "//tensorflow/lite/schema:schema_utils",
         "//tensorflow/lite/tools/versioning",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:flat_hash_map",
@ -706,6 +706,7 @@ cc_library(
         "//tensorflow/core/platform:status",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/schema:schema_fbs",
+        "//tensorflow/lite/schema:schema_utils",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
@ -75,6 +75,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
 #include "tensorflow/lite/kernels/internal/kernel_utils.h"
 #include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/schema/schema_utils.h"
 #include "tensorflow/lite/string_util.h"
 #include "tensorflow/lite/tools/versioning/op_version.h"
 #include "tensorflow/lite/tools/versioning/runtime_version.h"
@ -75,6 +75,7 @@ limitations under the License.
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/schema/schema_utils.h"

 using llvm::ArrayRef;
 using mlir::Builder;
@ -271,18 +272,18 @@ StatusOr<std::string> GetMlirOpName(const tflite::OperatorT& op,
     return std::string("tfl.basic_lstm");
   }

-  if (op_code.builtin_code == tflite::BuiltinOperator_CUSTOM) {
+  auto builtin_code = tflite::GetBuiltinCode(&op_code);
+  if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
     return std::string("tfl.custom");
   }
-  if (op_code.builtin_code == tflite::BuiltinOperator_IF) {
+  if (builtin_code == tflite::BuiltinOperator_IF) {
     return std::string("tf.If");
   }
-  if (op_code.builtin_code == tflite::BuiltinOperator_WHILE) {
+  if (builtin_code == tflite::BuiltinOperator_WHILE) {
     return std::string("tf.While");
   }

-  llvm::StringRef op_name(
-      tflite::EnumNameBuiltinOperator(op_code.builtin_code));
+  llvm::StringRef op_name(tflite::EnumNameBuiltinOperator(builtin_code));
   return llvm::Twine("tfl.", op_name.lower()).str();
 }

@ -637,7 +638,8 @@ StatusOr<Operation*> ConvertOp(
   }

   llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
-  if (op_code.builtin_code == tflite::BuiltinOperator_CUSTOM) {
+  auto builtin_code = tflite::GetBuiltinCode(&op_code);
+  if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
     auto status = mlir::CustomOptionsToAttributes(
         op_code.custom_code, op.custom_options, builder, loc, &attrs);
     if (!status.ok()) {
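The importer now resolves the operator's builtin code through tflite::GetBuiltinCode(&op_code) rather than reading op_code.builtin_code directly, which lets it accept models that only populate the older deprecated_builtin_code field. A rough standalone sketch of what such a resolution helper has to do; the struct, enum values, and helper below are illustrative stand-ins written for this note, not the actual schema_utils implementation, and the max-based rule is an assumption about its behavior.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Illustrative stand-ins; the real OperatorCodeT and BuiltinOperator come
// from schema_generated.h.
enum BuiltinOperator : int32_t { CONV_2D = 3, MUL = 18 };

struct OperatorCodeT {
  int8_t deprecated_builtin_code = 0;  // legacy 8-bit field
  int32_t builtin_code = 0;            // new 32-bit field, may be unset in old models
};

// Prefer whichever field actually carries the operator: old producers only
// filled the deprecated 8-bit code, new producers fill the 32-bit one.
int32_t ResolveBuiltinCode(const OperatorCodeT* code) {
  return std::max<int32_t>(code->builtin_code, code->deprecated_builtin_code);
}

int main() {
  OperatorCodeT old_model;                 // written before the schema change
  old_model.deprecated_builtin_code = 18;  // MUL, builtin_code left at 0
  OperatorCodeT extended;                  // op whose code exceeds the 8-bit range
  extended.deprecated_builtin_code = 127;  // capped legacy value
  extended.builtin_code = 150;             // hypothetical new operator code
  assert(ResolveBuiltinCode(&old_model) == MUL);
  assert(ResolveBuiltinCode(&extended) == 150);
  return 0;
}
```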
@ -459,11 +459,13 @@ node {
 # CHECK-LABEL: {
 # CHECK: version: 3,
 # CHECK: operator_codes: [ {
-# CHECK: builtin_code: CONV_2D,
-# CHECK: version: 3
+# CHECK: deprecated_builtin_code: 3,
+# CHECK: version: 3,
+# CHECK: builtin_code: CONV_2D
 # CHECK: }, {
-# CHECK: builtin_code: RESHAPE,
-# CHECK: version: 1
+# CHECK: deprecated_builtin_code: 22,
+# CHECK: version: 1
+# CHECK: builtin_code: RESHAPE
 # CHECK: } ],
 # CHECK: subgraphs: [ {
 # CHECK: tensors: [ {
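All of the flatbuffer tests below are updated the same way: the textual dump now prints the legacy numeric code as deprecated_builtin_code before version, and the enum name as builtin_code after it. A standalone sketch that reproduces that ordering for the codes appearing in these tests; the struct-free formatting helper and the table are ours, not the converter's actual output code.

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

// Numeric codes as they appear in the updated CHECK lines below.
const std::map<int32_t, std::string> kBuiltinNames = {
    {1, "AVERAGE_POOL_2D"}, {3, "CONV_2D"},            {4, "DEPTHWISE_CONV_2D"},
    {6, "DEQUANTIZE"},      {9, "FULLY_CONNECTED"},    {16, "LSTM"},
    {18, "MUL"},            {22, "RESHAPE"},           {32, "CUSTOM"},
    {42, "DIV"},            {47, "EXP"},               {58, "LESS"},
    {59, "NEG"},            {80, "FAKE_QUANT"},        {84, "LOGICAL_OR"},
    {86, "LOGICAL_AND"},    {99, "SQUARED_DIFFERENCE"}, {118, "IF"}};

// Emit one operator_codes entry in the order the updated tests expect:
// deprecated_builtin_code, then version, then builtin_code.
void PrintOperatorCode(int32_t code, int32_t version) {
  std::cout << "  deprecated_builtin_code: " << code << ",\n"
            << "  version: " << version << ",\n"
            << "  builtin_code: " << kBuiltinNames.at(code) << "\n";
}

int main() {
  PrintOperatorCode(16, 2);  // the LSTM entry from the first test below
  return 0;
}
```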
@ -4,8 +4,9 @@ func @main(tensor<1x384xf32>, tensor<1x96xf32>, tensor<384x480xf32>, tensor<384x
 // CHECK: {
 // CHECK-NEXT: version: 3,
 // CHECK-NEXT: operator_codes: [ {
-// CHECK-NEXT: builtin_code: LSTM,
-// CHECK-NEXT: version: 2
+// CHECK-NEXT: deprecated_builtin_code: 16,
+// CHECK-NEXT: version: 2
+// CHECK-NEXT: builtin_code: LSTM
 // CHECK-NEXT: } ],
 // CHECK-NEXT: subgraphs: [ {
 // CHECK-NEXT: tensors: [ {
@ -6,14 +6,17 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
 // CHECK: {
 // CHECK-NEXT: version: 3,
 // CHECK-NEXT: operator_codes: [ {
-// CHECK-NEXT: builtin_code: MUL,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 18,
+// CHECK-NEXT: version: 1
+// CHECK-NEXT: builtin_code: MUL
 // CHECK-NEXT: }, {
-// CHECK-NEXT: builtin_code: CUSTOM,
-// CHECK-NEXT: custom_code: "MyCustomOp"
+// CHECK-NEXT: deprecated_builtin_code: 32,
+// CHECK-NEXT: custom_code: "MyCustomOp",
+// CHECK-NEXT: builtin_code: CUSTOM
 // CHECK-NEXT: }, {
-// CHECK-NEXT: builtin_code: EXP,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 47,
+// CHECK-NEXT: version: 1,
+// CHECK-NEXT: builtin_code: EXP
 // CHECK-NEXT: } ],
 // CHECK-NEXT: subgraphs: [ {
 // CHECK-NEXT: tensors: [ {
@ -5,11 +5,13 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
 // CHECK: {
 // CHECK-NEXT: version: 3,
 // CHECK-NEXT: operator_codes: [ {
-// CHECK-NEXT: builtin_code: DEQUANTIZE,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 6,
+// CHECK-NEXT: version: 1
+// CHECK-NEXT: builtin_code: DEQUANTIZE
 // CHECK-NEXT: }, {
-// CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 4,
+// CHECK-NEXT: version: 1
+// CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D
 // CHECK-NEXT: } ],
 // CHECK-NEXT: subgraphs: [ {
 // CHECK-NEXT: tensors: [ {
@ -5,11 +5,13 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
 // CHECK: {
 // CHECK-NEXT: version: 3,
 // CHECK-NEXT: operator_codes: [ {
-// CHECK-NEXT: builtin_code: DEQUANTIZE,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 6,
+// CHECK-NEXT: version: 1,
+// CHECK-NEXT: builtin_code: DEQUANTIZE
 // CHECK-NEXT: }, {
-// CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D,
-// CHECK-NEXT: version: 2
+// CHECK-NEXT: deprecated_builtin_code: 4,
+// CHECK-NEXT: version: 2,
+// CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D
 // CHECK-NEXT: } ],
 // CHECK-NEXT: subgraphs: [ {
 // CHECK-NEXT: tensors: [ {
@ -5,11 +5,13 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
 // CHECK: {
 // CHECK-NEXT: version: 3,
 // CHECK-NEXT: operator_codes: [ {
-// CHECK-NEXT: builtin_code: MUL,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 18,
+// CHECK-NEXT: version: 1,
+// CHECK-NEXT: builtin_code: MUL
 // CHECK-NEXT: }, {
-// CHECK-NEXT: builtin_code: EXP,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 47,
+// CHECK-NEXT: version: 1,
+// CHECK-NEXT: builtin_code: EXP
 // CHECK-NEXT: } ],
 // CHECK-NEXT: subgraphs: [ {
 // CHECK-NEXT: tensors: [ {
@ -6,8 +6,9 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
 // CHECK: {
 // CHECK-NEXT: version: 3,
 // CHECK-NEXT: operator_codes: [ {
-// CHECK-NEXT: builtin_code: FAKE_QUANT,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 80,
+// CHECK-NEXT: version: 1,
+// CHECK-NEXT: builtin_code: FAKE_QUANT
 // CHECK-NEXT: } ],
 // CHECK-NEXT: subgraphs: [ {
 // CHECK-NEXT: tensors: [ {
@ -4,8 +4,9 @@ func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
 // CHECK: {
 // CHECK-NEXT: version: 3,
 // CHECK-NEXT: operator_codes: [ {
-// CHECK-NEXT: builtin_code: CUSTOM,
-// CHECK-NEXT: custom_code: "FlexAddV2"
+// CHECK-NEXT: deprecated_builtin_code: 32,
+// CHECK-NEXT: custom_code: "FlexAddV2"
+// CHECK-NEXT: builtin_code: CUSTOM
 // CHECK-NEXT: } ],
 // CHECK-NEXT: subgraphs: [ {
 // CHECK-NEXT: tensors: [ {
@ -5,8 +5,9 @@ func @main(tensor<4xcomplex<f64>>, tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f
 // CHECK: {
 // CHECK-NEXT: version: 3,
 // CHECK-NEXT: operator_codes: [ {
-// CHECK-NEXT: builtin_code: CUSTOM,
-// CHECK-NEXT: custom_code: "FlexAdd"
+// CHECK-NEXT: deprecated_builtin_code: 32,
+// CHECK-NEXT: custom_code: "FlexAdd",
+// CHECK-NEXT: builtin_code: CUSTOM
 // CHECK-NEXT: } ],
 // CHECK-NEXT: subgraphs: [ {
 // CHECK-NEXT: tensors: [ {
@ -5,8 +5,9 @@ func @main(tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> {
 // CHECK: {
 // CHECK-NEXT: version: 3,
 // CHECK-NEXT: operator_codes: [ {
-// CHECK-NEXT: builtin_code: CUSTOM,
-// CHECK-NEXT: custom_code: "FlexAdd"
+// CHECK-NEXT: deprecated_builtin_code: 32,
+// CHECK-NEXT: custom_code: "FlexAdd",
+// CHECK-NEXT: builtin_code: CUSTOM
 // CHECK-NEXT: } ],
 // CHECK-NEXT: subgraphs: [ {
 // CHECK-NEXT: tensors: [ {
@ -5,14 +5,17 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
 // CHECK: {
 // CHECK-NEXT: version: 3,
 // CHECK-NEXT: operator_codes: [ {
+// CHECK-NEXT: deprecated_builtin_code: 18,
+// CHECK-NEXT: version: 1,
 // CHECK-NEXT: builtin_code: MUL
-// CHECK-NEXT: version: 1
 // CHECK-NEXT: }, {
-// CHECK-NEXT: builtin_code: CUSTOM,
-// CHECK-NEXT: custom_code: "FlexDiv"
+// CHECK-NEXT: deprecated_builtin_code: 32,
+// CHECK-NEXT: custom_code: "FlexDiv",
+// CHECK-NEXT: builtin_code: CUSTOM
 // CHECK-NEXT: }, {
+// CHECK-NEXT: deprecated_builtin_code: 47,
+// CHECK-NEXT: version: 1,
 // CHECK-NEXT: builtin_code: EXP
-// CHECK-NEXT: version: 1
 // CHECK-NEXT: } ],
 // CHECK-NEXT: subgraphs: [ {
 // CHECK-NEXT: tensors: [ {
@ -5,8 +5,9 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {
 // CHECK: {
 // CHECK-NEXT: version: 3,
 // CHECK-NEXT: operator_codes: [ {
-// CHECK-NEXT: builtin_code: FULLY_CONNECTED,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 9,
+// CHECK-NEXT: version: 1,
+// CHECK-NEXT: builtin_code: FULLY_CONNECTED
 // CHECK-NEXT: } ],
 // CHECK-NEXT: subgraphs: [ {
 // CHECK-NEXT: tensors: [ {
@ -5,8 +5,9 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {
 // CHECK: {
 // CHECK-NEXT: version: 3,
 // CHECK-NEXT: operator_codes: [ {
-// CHECK-NEXT: builtin_code: FULLY_CONNECTED,
-// CHECK-NEXT: version: 2
+// CHECK-NEXT: deprecated_builtin_code: 9,
+// CHECK-NEXT: version: 2,
+// CHECK-NEXT: builtin_code: FULLY_CONNECTED
 // CHECK-NEXT: } ],
 // CHECK-NEXT: subgraphs: [ {
 // CHECK-NEXT: tensors: [ {
@ -3,8 +3,9 @@
 // CHECK: {
 // CHECK: version: 3,
 // CHECK: operator_codes: [ {
-// CHECK: builtin_code: CUSTOM,
-// CHECK: custom_code: "HashTableV2"
+// CHECK: deprecated_builtin_code: 32,
+// CHECK: custom_code: "HashTableV2",
+// CHECK: builtin_code: CUSTOM
 // CHECK: } ],
 // CHECK: subgraphs: [ {
 // CHECK: tensors: [ {
@ -4,16 +4,19 @@
 // CHECK: {
 // CHECK-NEXT: version: 3,
 // CHECK-NEXT: operator_codes: [ {
-// CHECK-NEXT: builtin_code: LESS,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 58,
+// CHECK-NEXT: version: 1,
+// CHECK-NEXT: builtin_code: LESS
 // CHECK-NEXT: }, {
-// CHECK-NEXT: builtin_code: IF,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 118,
+// CHECK-NEXT: version: 1,
+// CHECK-NEXT: builtin_code: IF
 // CHECK-NEXT: }, {
 // CHECK-NEXT: version: 1
 // CHECK-NEXT: }, {
-// CHECK-NEXT: builtin_code: MUL,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 18,
+// CHECK-NEXT: version: 1,
+// CHECK-NEXT: builtin_code: MUL
 // CHECK-NEXT: } ],
 // CHECK-NEXT: subgraphs: [ {
 // CHECK-NEXT: tensors: [ {
@ -5,11 +5,13 @@ func @main(tensor<4xi1>) -> tensor<4xi1> {
 // CHECK: {
 // CHECK-NEXT: version: 3,
 // CHECK-NEXT: operator_codes: [ {
-// CHECK-NEXT: builtin_code: LOGICAL_OR,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 84,
+// CHECK-NEXT: version: 1,
+// CHECK-NEXT: builtin_code: LOGICAL_OR
 // CHECK-NEXT: }, {
-// CHECK-NEXT: builtin_code: LOGICAL_AND,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 86,
+// CHECK-NEXT: version: 1,
+// CHECK-NEXT: builtin_code: LOGICAL_AND
 // CHECK-NEXT: } ],
 // CHECK-NEXT: subgraphs: [ {
 // CHECK-NEXT: tensors: [ {
@ -4,8 +4,9 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
 // CHECK: {
 // CHECK-NEXT: version: 3,
 // CHECK-NEXT: operator_codes: [ {
-// CHECK-NEXT: builtin_code: LSTM,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 16,
+// CHECK-NEXT: version: 1,
+// CHECK-NEXT: builtin_code: LSTM
 // CHECK-NEXT: } ],
 // CHECK-NEXT: subgraphs: [ {
 // CHECK-NEXT: tensors: [ {
@ -7,8 +7,9 @@ func @main(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>
 // CHECK: {
 // CHECK-NEXT: version: 3,
 // CHECK-NEXT: operator_codes: [ {
-// CHECK-NEXT: builtin_code: LSTM,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 16,
+// CHECK-NEXT: version: 1,
+// CHECK-NEXT: builtin_code: LSTM
 // CHECK-NEXT: } ],
 // CHECK-NEXT: subgraphs: [ {
 // CHECK-NEXT: tensors: [ {
@ -5,20 +5,25 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
 // CHECK: {
 // CHECK-NEXT: version: 3,
 // CHECK-NEXT: operator_codes: [ {
-// CHECK-NEXT: builtin_code: SQUARED_DIFFERENCE,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 99,
+// CHECK-NEXT: version: 1,
+// CHECK-NEXT: builtin_code: SQUARED_DIFFERENCE
 // CHECK-NEXT: }, {
-// CHECK-NEXT: builtin_code: MUL,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 18,
+// CHECK-NEXT: version: 1,
+// CHECK-NEXT: builtin_code: MUL
 // CHECK-NEXT: }, {
-// CHECK-NEXT: builtin_code: DIV,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 42,
+// CHECK-NEXT: version: 1,
+// CHECK-NEXT: builtin_code: DIV
 // CHECK-NEXT: }, {
-// CHECK-NEXT: builtin_code: EXP,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 47,
+// CHECK-NEXT: version: 1,
+// CHECK-NEXT: builtin_code: EXP
 // CHECK-NEXT: }, {
-// CHECK-NEXT: builtin_code: NEG,
-// CHECK-NEXT: version: 1
+// CHECK-NEXT: deprecated_builtin_code: 59,
+// CHECK-NEXT: version: 1,
+// CHECK-NEXT: builtin_code: NEG
 // CHECK-NEXT: } ],
 // CHECK-NEXT: subgraphs: [ {
 // CHECK-NEXT: tensors: [ {
|
@ -5,8 +5,9 @@ func @main(tensor<3x!quant.uniform<i8:f32, 0.1>>) -> tensor<3x!quant.uniform<i8:
|
|||||||
// CHECK: {
|
// CHECK: {
|
||||||
// CHECK-NEXT: version: 3,
|
// CHECK-NEXT: version: 3,
|
||||||
// CHECK-NEXT: operator_codes: [ {
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
// CHECK-NEXT: builtin_code: MUL,
|
// CHECK-NEXT: deprecated_builtin_code: 18,
|
||||||
// CHECK-NEXT: version: 2
|
// CHECK-NEXT: version: 2,
|
||||||
|
// CHECK-NEXT: builtin_code: MUL
|
||||||
// CHECK-NEXT: } ],
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: subgraphs: [ {
|
// CHECK-NEXT: subgraphs: [ {
|
||||||
// CHECK-NEXT: tensors: [ {
|
// CHECK-NEXT: tensors: [ {
|
||||||
|
@ -5,8 +5,9 @@ func @main(tensor<3x!quant.uniform<i8:f32, 1.0>>) -> tensor<3x!quant.uniform<i8:
|
|||||||
// CHECK: {
|
// CHECK: {
|
||||||
// CHECK-NEXT: version: 3,
|
// CHECK-NEXT: version: 3,
|
||||||
// CHECK-NEXT: operator_codes: [ {
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
// CHECK-NEXT: builtin_code: MUL,
|
// CHECK-NEXT: deprecated_builtin_code: 18,
|
||||||
// CHECK-NEXT: version: 3
|
// CHECK-NEXT: version: 3,
|
||||||
|
// CHECK-NEXT: builtin_code: MUL
|
||||||
// CHECK-NEXT: } ],
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: subgraphs: [ {
|
// CHECK-NEXT: subgraphs: [ {
|
||||||
// CHECK-NEXT: tensors: [ {
|
// CHECK-NEXT: tensors: [ {
|
||||||
|
@ -5,8 +5,9 @@ func @main(tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> {
|
|||||||
// CHECK: {
|
// CHECK: {
|
||||||
// CHECK-NEXT: version: 3,
|
// CHECK-NEXT: version: 3,
|
||||||
// CHECK-NEXT: operator_codes: [ {
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
// CHECK-NEXT: builtin_code: AVERAGE_POOL_2D,
|
// CHECK-NEXT: deprecated_builtin_code: 1,
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1,
|
||||||
|
// CHECK-NEXT: builtin_code: AVERAGE_POOL_2D
|
||||||
// CHECK-NEXT: } ],
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: subgraphs: [ {
|
// CHECK-NEXT: subgraphs: [ {
|
||||||
// CHECK-NEXT: tensors: [ {
|
// CHECK-NEXT: tensors: [ {
|
||||||
|
@ -3,8 +3,9 @@
|
|||||||
// CHECK: {
|
// CHECK: {
|
||||||
// CHECK-NEXT: version: 3,
|
// CHECK-NEXT: version: 3,
|
||||||
// CHECK-NEXT: operator_codes: [ {
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
// CHECK-NEXT: deprecated_builtin_code: 32,
|
||||||
// CHECK-NEXT: custom_code: "NumericVerify"
|
// CHECK-NEXT: custom_code: "NumericVerify",
|
||||||
|
// CHECK-NEXT: builtin_code: CUSTOM
|
||||||
// CHECK-NEXT: } ],
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: subgraphs: [ {
|
// CHECK-NEXT: subgraphs: [ {
|
||||||
// CHECK-NEXT: tensors: [ {
|
// CHECK-NEXT: tensors: [ {
|
||||||
|
@ -4,20 +4,25 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
|
|||||||
// CHECK: {
|
// CHECK: {
|
||||||
// CHECK-NEXT: version: 3,
|
// CHECK-NEXT: version: 3,
|
||||||
// CHECK-NEXT: operator_codes: [ {
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
// CHECK-NEXT: builtin_code: QUANTIZE,
|
// CHECK-NEXT: deprecated_builtin_code: 114,
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1,
|
||||||
|
// CHECK-NEXT: builtin_code: QUANTIZE
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
// CHECK-NEXT: builtin_code: CONV_2D,
|
// CHECK-NEXT: deprecated_builtin_code: 3,
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1,
|
||||||
|
// CHECK-NEXT: builtin_code: CONV_2D
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
// CHECK-NEXT: builtin_code: RESHAPE,
|
// CHECK-NEXT: deprecated_builtin_code: 22,
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1,
|
||||||
|
// CHECK-NEXT: builtin_code: RESHAPE
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
// CHECK-NEXT: builtin_code: SOFTMAX,
|
// CHECK-NEXT: deprecated_builtin_code: 25,
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1,
|
||||||
|
// CHECK-NEXT: builtin_code: SOFTMAX
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
// CHECK-NEXT: builtin_code: DEQUANTIZE,
|
// CHECK-NEXT: deprecated_builtin_code: 6,
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1,
|
||||||
|
// CHECK-NEXT: builtin_code: DEQUANTIZE
|
||||||
// CHECK-NEXT: } ],
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: subgraphs: [ {
|
// CHECK-NEXT: subgraphs: [ {
|
||||||
// CHECK-NEXT: tensors: [ {
|
// CHECK-NEXT: tensors: [ {
|
||||||
|
@ -5,8 +5,9 @@ func @main(tensor<3x2xi32>) -> tensor<6xi32> {
|
|||||||
// CHECK: {
|
// CHECK: {
|
||||||
// CHECK-NEXT: version: 3,
|
// CHECK-NEXT: version: 3,
|
||||||
// CHECK-NEXT: operator_codes: [ {
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
// CHECK-NEXT: builtin_code: RESHAPE,
|
// CHECK-NEXT: deprecated_builtin_code: 22,
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1,
|
||||||
|
// CHECK-NEXT: builtin_code: RESHAPE
|
||||||
// CHECK-NEXT: } ],
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: subgraphs: [ {
|
// CHECK-NEXT: subgraphs: [ {
|
||||||
// CHECK-NEXT: tensors: [ {
|
// CHECK-NEXT: tensors: [ {
|
||||||
|
@ -7,8 +7,9 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32>
|
|||||||
// CHECK: {
|
// CHECK: {
|
||||||
// CHECK-NEXT: version: 3,
|
// CHECK-NEXT: version: 3,
|
||||||
// CHECK-NEXT: operator_codes: [ {
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
// CHECK-NEXT: builtin_code: SUB,
|
// CHECK-NEXT: deprecated_builtin_code: 41,
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1,
|
||||||
|
// CHECK-NEXT: builtin_code: SUB
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1
|
||||||
// CHECK-NEXT: } ],
|
// CHECK-NEXT: } ],
|
||||||
|
@ -4,8 +4,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -
|
|||||||
// CHECK: {
|
// CHECK: {
|
||||||
// CHECK-NEXT: version: 3,
|
// CHECK-NEXT: version: 3,
|
||||||
// CHECK-NEXT: operator_codes: [ {
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
// CHECK-NEXT: builtin_code: SVDF,
|
// CHECK-NEXT: deprecated_builtin_code: 27,
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1,
|
||||||
|
// CHECK-NEXT: builtin_code: SVDF
|
||||||
// CHECK-NEXT: } ],
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: subgraphs: [ {
|
// CHECK-NEXT: subgraphs: [ {
|
||||||
// CHECK-NEXT: tensors: [ {
|
// CHECK-NEXT: tensors: [ {
|
||||||
|
@ -4,8 +4,9 @@ func @main(tensor<4 x f32>, tensor<4 x i8>, tensor<4 x f32>, tensor<4 x f32>) ->
|
|||||||
// CHECK: {
|
// CHECK: {
|
||||||
// CHECK-NEXT: version: 3,
|
// CHECK-NEXT: version: 3,
|
||||||
// CHECK-NEXT: operator_codes: [ {
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
// CHECK-NEXT: builtin_code: SVDF,
|
// CHECK-NEXT: deprecated_builtin_code: 27,
|
||||||
// CHECK-NEXT: version: 2
|
// CHECK-NEXT: version: 2,
|
||||||
|
// CHECK-NEXT: builtin_code: SVDF
|
||||||
// CHECK-NEXT: } ],
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: subgraphs: [ {
|
// CHECK-NEXT: subgraphs: [ {
|
||||||
// CHECK-NEXT: tensors: [ {
|
// CHECK-NEXT: tensors: [ {
|
||||||
|
@ -3,14 +3,17 @@
|
|||||||
// CHECK: {
|
// CHECK: {
|
||||||
// CHECK-NEXT: version: 3,
|
// CHECK-NEXT: version: 3,
|
||||||
// CHECK-NEXT: operator_codes: [ {
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
// CHECK-NEXT: builtin_code: WHILE,
|
// CHECK-NEXT: deprecated_builtin_code: 119,
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1,
|
||||||
|
// CHECK-NEXT: builtin_code: WHILE
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
// CHECK-NEXT: builtin_code: GREATER,
|
// CHECK-NEXT: deprecated_builtin_code: 61,
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1,
|
||||||
|
// CHECK-NEXT: builtin_code: GREATER
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
// CHECK-NEXT: builtin_code: SUB,
|
// CHECK-NEXT: deprecated_builtin_code: 41,
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1,
|
||||||
|
// CHECK-NEXT: builtin_code: SUB
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1
|
||||||
// CHECK-NEXT: } ],
|
// CHECK-NEXT: } ],
|
||||||
|
@ -4,8 +4,9 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x
|
|||||||
// CHECK: {
|
// CHECK: {
|
||||||
// CHECK-NEXT: version: 3,
|
// CHECK-NEXT: version: 3,
|
||||||
// CHECK-NEXT: operator_codes: [ {
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
// CHECK-NEXT: builtin_code: TRANSPOSE_CONV,
|
// CHECK-NEXT: deprecated_builtin_code: 67,
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1,
|
||||||
|
// CHECK-NEXT: builtin_code: TRANSPOSE_CONV
|
||||||
// CHECK-NEXT: } ],
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: subgraphs: [ {
|
// CHECK-NEXT: subgraphs: [ {
|
||||||
// CHECK-NEXT: tensors: [ {
|
// CHECK-NEXT: tensors: [ {
|
||||||
|
@ -3,8 +3,9 @@
|
|||||||
// CHECK: {
|
// CHECK: {
|
||||||
// CHECK: version: 3,
|
// CHECK: version: 3,
|
||||||
// CHECK: operator_codes: [ {
|
// CHECK: operator_codes: [ {
|
||||||
// CHECK: builtin_code: CUSTOM,
|
// CHECK: deprecated_builtin_code: 32,
|
||||||
// CHECK: custom_code: "SomeOperation"
|
// CHECK: custom_code: "SomeOperation",
|
||||||
|
// CHECK: builtin_code: CUSTOM
|
||||||
// CHECK: } ],
|
// CHECK: } ],
|
||||||
// CHECK: subgraphs: [ {
|
// CHECK: subgraphs: [ {
|
||||||
// CHECK: tensors: [ {
|
// CHECK: tensors: [ {
|
||||||
|
@ -4,8 +4,9 @@ func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
|
|||||||
// CHECK: {
|
// CHECK: {
|
||||||
// CHECK-NEXT: version: 3,
|
// CHECK-NEXT: version: 3,
|
||||||
// CHECK-NEXT: operator_codes: [ {
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_LSTM,
|
// CHECK-NEXT: deprecated_builtin_code: 44,
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1,
|
||||||
|
// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_LSTM
|
||||||
// CHECK-NEXT: } ],
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: subgraphs: [ {
|
// CHECK-NEXT: subgraphs: [ {
|
||||||
// CHECK-NEXT: tensors: [ {
|
// CHECK-NEXT: tensors: [ {
|
||||||
|
@ -4,8 +4,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -
|
|||||||
// CHECK: {
|
// CHECK: {
|
||||||
// CHECK-NEXT: version: 3,
|
// CHECK-NEXT: version: 3,
|
||||||
// CHECK-NEXT: operator_codes: [ {
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_RNN,
|
// CHECK-NEXT: deprecated_builtin_code: 35,
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1,
|
||||||
|
// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_RNN
|
||||||
// CHECK-NEXT: } ],
|
// CHECK-NEXT: } ],
|
||||||
// CHECK-NEXT: subgraphs: [ {
|
// CHECK-NEXT: subgraphs: [ {
|
||||||
// CHECK-NEXT: tensors: [ {
|
// CHECK-NEXT: tensors: [ {
|
||||||
|
@ -3,14 +3,17 @@
|
|||||||
// CHECK: {
|
// CHECK: {
|
||||||
// CHECK-NEXT: version: 3,
|
// CHECK-NEXT: version: 3,
|
||||||
// CHECK-NEXT: operator_codes: [ {
|
// CHECK-NEXT: operator_codes: [ {
|
||||||
// CHECK-NEXT: builtin_code: WHILE,
|
// CHECK-NEXT: deprecated_builtin_code: 119,
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1,
|
||||||
|
// CHECK-NEXT: builtin_code: WHILE
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
// CHECK-NEXT: builtin_code: GREATER,
|
// CHECK-NEXT: deprecated_builtin_code: 61,
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1,
|
||||||
|
// CHECK-NEXT: builtin_code: GREATER
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
// CHECK-NEXT: builtin_code: SUB,
|
// CHECK-NEXT: deprecated_builtin_code: 41,
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1,
|
||||||
|
// CHECK-NEXT: builtin_code: SUB
|
||||||
// CHECK-NEXT: }, {
|
// CHECK-NEXT: }, {
|
||||||
// CHECK-NEXT: version: 1
|
// CHECK-NEXT: version: 1
|
||||||
// CHECK-NEXT: } ],
|
// CHECK-NEXT: } ],
|
||||||
|
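The updated expectations above reflect the new TFLite operator_code layout: the numeric builtin code now appears as `deprecated_builtin_code`, while the enum name is emitted in a trailing `builtin_code` field. A minimal sketch of how a consumer of these JSON dumps might resolve an operator under either layout — the `OpCode` container below is a hypothetical stand-in for illustration, not the real schema API:

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class OpCode:
    # Mirrors the fields in the JSON dumps above; new files carry both.
    deprecated_builtin_code: int = 0
    builtin_code: Optional[str] = None
    custom_code: Optional[str] = None
    version: int = 1


def resolve_builtin(op: OpCode, name_by_code: Dict[int, str]) -> str:
    """Prefer the new string field, fall back to the legacy numeric code."""
    if op.builtin_code is not None:
        return op.builtin_code
    return name_by_code.get(op.deprecated_builtin_code, "CUSTOM")


# The LESS entry from the first test above resolves the same way under
# either layout.
print(resolve_builtin(OpCode(deprecated_builtin_code=58), {58: "LESS"}))
print(resolve_builtin(OpCode(deprecated_builtin_code=58, builtin_code="LESS"), {}))
```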
@ -272,6 +272,22 @@ func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
   // CHECK: return %[[RES]] : tensor<4x2xf32>
 }
+
+// CHECK-LABEL: @fuseBroadcastMulIntoFullyConnected
+func @fuseBroadcastMulIntoFullyConnected(%arg0: tensor<1x10368xbf16>) -> tensor<32x1x256xbf16> {
+  %cst_0 = constant dense<2.0> : tensor<256x10368xbf16>
+  %cst_1 = constant unit
+  %cst_2 = constant dense<3.0> : tensor<32x1x256xbf16>
+  %0 = "tfl.fully_connected"(%arg0, %cst_0, %cst_1) {
+    fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"
+  } : (tensor<1x10368xbf16>, tensor<256x10368xbf16>, none) -> tensor<1x256xbf16>
+  %1 = "tfl.mul"(%0, %cst_2) {fused_activation_function = "NONE"} : (tensor<1x256xbf16>, tensor<32x1x256xbf16>) -> tensor<32x1x256xbf16>
+  return %1 : tensor<32x1x256xbf16>
+
+  // CHECK: %[[V0:.*]] = "tfl.fully_connected"(%arg0, {{.*}}) {{{.*}}} : (tensor<1x10368xbf16>, tensor<256x10368xbf16>, none) -> tensor<1x256xbf16>
+  // CHECK: %[[V1:.*]] = "tfl.mul"(%[[V0]], {{.*}}) {{{.*}}} : (tensor<1x256xbf16>, tensor<32x1x256xbf16>) -> tensor<32x1x256xbf16>
+  // CHECK: return %[[V1]] : tensor<32x1x256xbf16>
+}
 
 // CHECK-LABEL: @fuseAddIntoFollowingFullyConnectedWithQDQs
 func @fuseAddIntoFollowingFullyConnectedWithQDQs(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
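The new `@fuseBroadcastMulIntoFullyConnected` case above multiplies a 1x256 fully-connected result by a 32x1x256 constant, so folding the constant into the filter would force the fully-connected op itself to produce the broadcast shape and do roughly 32x the matmul work. A small NumPy sketch of the shape arithmetic (illustrative only; shapes taken from the test):

```python
import numpy as np

x = np.ones((1, 10368), dtype=np.float32)          # fully_connected input
w = np.full((256, 10368), 2.0, dtype=np.float32)   # filter
m = np.full((32, 1, 256), 3.0, dtype=np.float32)   # broadcasting multiplier

fc = x @ w.T            # (1, 256): one cheap matmul
out = fc * m            # broadcasts the result to (32, 1, 256)
print(fc.shape, out.shape)  # (1, 256) (32, 1, 256)

# Fusing m into w would require the fully connected op to emit (32, 1, 256)
# directly, multiplying the work done by the matmul itself.
```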
@ -139,6 +139,9 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
     bool emit_select_tf_ops, bool emit_custom_ops,
     const mlir::TFL::QuantizationSpecs& quant_specs, std::string* result,
     mlir::PassManager* pass_manager) {
+  // Explicitly disable dumping Op details on failures.
+  module.getContext()->printOpOnDiagnostic(false);
+
   // Register a warning handler only log to std out.
   mlir::ScopedDiagnosticHandler s(
       module.getContext(), [](mlir::Diagnostic& diag) {
@ -416,6 +416,10 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
 
   LogicalResult matchAndRewrite(TFL::MulOp mul_op,
                                 PatternRewriter &rewriter) const override {
+    // If we are broadcasting on the lhs then don't fold the multiply as it
+    // would increase the amount of compute done by the fully connected op.
+    if (mul_op.lhs().getType() != mul_op.getType()) return failure();
+
     // Mul.
     DenseElementsAttr cst;
     Value constant_val = mul_op.rhs();
@ -74,8 +74,8 @@ tool_names = [
     'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
     'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
     'json_to_flatbuffer', 'xla-gpu-opt', 'xla-mlir-gpu-opt', 'xla-opt',
-    'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_gpu_binary', 'xla-thunks-opt',
-    'tfjs-opt'
+    'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_kernel', 'tf_to_gpu_binary',
+    'xla-thunks-opt', 'tfjs-opt'
 ]
 tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
 llvm_config.add_tool_substitutions(tools, tool_dirs)
@ -1394,6 +1394,7 @@ cc_library(
         ":decode_constant_pass",
         ":eval_util",
         ":tensorflow",
+        ":tensorflow_traits",
         ":tensorflow_types",
         "//tensorflow/c:tf_status",
         "//tensorflow/c/eager:c_api",
@ -1961,6 +1962,7 @@ cc_library(
     deps = [
         ":convert_tensor",
        ":convert_type",
+        ":export_tf_dialect_op",
         ":export_utils",
         ":tensorflow",
         ":tensorflow_attributes",
@ -297,6 +297,33 @@ Equivalent to np.angle.
|
|||||||
TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>;
|
TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TF_AnonymousIteratorOp : TF_Op<"AnonymousIterator", []> {
|
||||||
|
let summary = "A container for an iterator resource.";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||||
|
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def TF_AnonymousIteratorV2Op : TF_Op<"AnonymousIteratorV2", []> {
|
||||||
|
let summary = "A container for an iterator resource.";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||||
|
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle,
|
||||||
|
TF_VariantTensor:$deleter
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
def TF_AnonymousMemoryCacheOp : TF_Op<"AnonymousMemoryCache", []> {
|
def TF_AnonymousMemoryCacheOp : TF_Op<"AnonymousMemoryCache", []> {
|
||||||
let summary = "";
|
let summary = "";
|
||||||
|
|
||||||
@ -308,6 +335,21 @@ def TF_AnonymousMemoryCacheOp : TF_Op<"AnonymousMemoryCache", []> {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TF_AnonymousMultiDeviceIteratorOp : TF_Op<"AnonymousMultiDeviceIterator", []> {
|
||||||
|
let summary = "A container for a multi device iterator resource.";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Confined<StrArrayAttr, [ArrayMinCount<1>]>:$devices,
|
||||||
|
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||||
|
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle,
|
||||||
|
TF_VariantTensor:$deleter
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
def TF_AnonymousRandomSeedGeneratorOp : TF_Op<"AnonymousRandomSeedGenerator", []> {
|
def TF_AnonymousRandomSeedGeneratorOp : TF_Op<"AnonymousRandomSeedGenerator", []> {
|
||||||
let summary = "";
|
let summary = "";
|
||||||
|
|
||||||
@ -2485,6 +2527,17 @@ is the same, though it is cleaner to use `tf.io.decode_image`.
|
|||||||
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TF_DeleteIteratorOp : TF_Op<"DeleteIterator", []> {
|
||||||
|
let summary = "A container for an iterator resource.";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorFree]>:$handle,
|
||||||
|
TF_VariantTensor:$deleter
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs);
|
||||||
|
}
|
||||||
|
|
||||||
def TF_DeleteMemoryCacheOp : TF_Op<"DeleteMemoryCache", []> {
|
def TF_DeleteMemoryCacheOp : TF_Op<"DeleteMemoryCache", []> {
|
||||||
let summary = "";
|
let summary = "";
|
||||||
|
|
||||||
@ -2496,6 +2549,20 @@ def TF_DeleteMemoryCacheOp : TF_Op<"DeleteMemoryCache", []> {
|
|||||||
let results = (outs);
|
let results = (outs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TF_DeleteMultiDeviceIteratorOp : TF_Op<"DeleteMultiDeviceIterator", []> {
|
||||||
|
let summary = "A container for an iterator resource.";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorFree]>:$multi_device_iterator,
|
||||||
|
Arg<Variadic<TF_ResourceTensor>, "", [TF_DatasetIteratorRead]>:$iterators,
|
||||||
|
TF_VariantTensor:$deleter
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs);
|
||||||
|
|
||||||
|
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<1>;
|
||||||
|
}
|
||||||
|
|
||||||
def TF_DeleteRandomSeedGeneratorOp : TF_Op<"DeleteRandomSeedGenerator", []> {
|
def TF_DeleteRandomSeedGeneratorOp : TF_Op<"DeleteRandomSeedGenerator", []> {
|
||||||
let summary = "";
|
let summary = "";
|
||||||
|
|
||||||
@ -2719,6 +2786,19 @@ Computes the gradients of depthwise convolution with respect to the input.
|
|||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TF_DeserializeIteratorOp : TF_Op<"DeserializeIterator", []> {
|
||||||
|
let summary = [{
|
||||||
|
Converts the given variant tensor to an iterator and stores it in the given resource.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorWrite]>:$resource_handle,
|
||||||
|
TF_VariantTensor:$serialized
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs);
|
||||||
|
}
|
||||||
|
|
||||||
def TF_DeviceIndexOp : TF_Op<"DeviceIndex", [NoSideEffect]> {
|
def TF_DeviceIndexOp : TF_Op<"DeviceIndex", [NoSideEffect]> {
|
||||||
let summary = "Return the index of device the op runs.";
|
let summary = "Return the index of device the op runs.";
|
||||||
|
|
||||||
@ -4965,11 +5045,58 @@ tf.math.is_nan(x) ==> [False, True, False, True, False]
|
|||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TF_IteratorOp : TF_Op<"Iterator", []> {
|
||||||
|
let summary = "A container for an iterator resource.";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
StrAttr:$shared_name,
|
||||||
|
StrAttr:$container,
|
||||||
|
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||||
|
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def TF_IteratorFromStringHandleOp : TF_Op<"IteratorFromStringHandle", []> {
|
||||||
|
let summary = [{
|
||||||
|
Converts the given string representing a handle to an iterator to a resource.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TF_StrTensor:$string_handle,
|
||||||
|
|
||||||
|
DefaultValuedAttr<TypeArrayAttr, "{}">:$output_types,
|
||||||
|
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$resource_handle
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def TF_IteratorFromStringHandleV2Op : TF_Op<"IteratorFromStringHandleV2", []> {
|
||||||
|
let summary = "";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TF_StrTensor:$string_handle,
|
||||||
|
|
||||||
|
DefaultValuedAttr<TypeArrayAttr, "{}">:$output_types,
|
||||||
|
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$resource_handle
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
def TF_IteratorGetNextOp : TF_Op<"IteratorGetNext", []> {
|
def TF_IteratorGetNextOp : TF_Op<"IteratorGetNext", []> {
|
||||||
let summary = "Gets the next output from the given iterator .";
|
let summary = "Gets the next output from the given iterator .";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TF_ResourceTensor:$iterator
|
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead, TF_DatasetIteratorWrite]>:$iterator
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
@ -4980,6 +5107,74 @@ def TF_IteratorGetNextOp : TF_Op<"IteratorGetNext", []> {
|
|||||||
TF_DerivedResultTypeListAttr output_types = TF_DerivedResultTypeListAttr<0>;
|
TF_DerivedResultTypeListAttr output_types = TF_DerivedResultTypeListAttr<0>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TF_IteratorGetNextAsOptionalOp : TF_Op<"IteratorGetNextAsOptional", []> {
|
||||||
|
let summary = [{
|
||||||
|
Gets the next output from the given iterator as an Optional variant.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead, TF_DatasetIteratorWrite]>:$iterator,
|
||||||
|
|
||||||
|
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||||
|
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TF_VariantTensor:$optional
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def TF_IteratorGetNextSyncOp : TF_Op<"IteratorGetNextSync", []> {
|
||||||
|
let summary = "Gets the next output from the given iterator.";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
This operation is a synchronous version IteratorGetNext. It should only be used
|
||||||
|
in situations where the iterator does not block the calling thread, or where
|
||||||
|
the calling thread is not a member of the thread pool used to execute parallel
|
||||||
|
operations (e.g. in eager mode).
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead, TF_DatasetIteratorWrite]>:$iterator
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
Variadic<TF_Tensor>:$components
|
||||||
|
);
|
||||||
|
|
||||||
|
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
|
||||||
|
TF_DerivedResultTypeListAttr output_types = TF_DerivedResultTypeListAttr<0>;
|
||||||
|
}
|
||||||
|
|
||||||
|
def TF_IteratorToStringHandleOp : TF_Op<"IteratorToStringHandle", []> {
|
||||||
|
let summary = [{
|
||||||
|
Converts the given `resource_handle` representing an iterator to a string.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead]>:$resource_handle
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TF_StrTensor:$string_handle
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def TF_IteratorV2Op : TF_Op<"IteratorV2", []> {
|
||||||
|
let summary = "";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
StrAttr:$shared_name,
|
||||||
|
StrAttr:$container,
|
||||||
|
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||||
|
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
def TF_L2LossOp : TF_Op<"L2Loss", [NoSideEffect]> {
|
def TF_L2LossOp : TF_Op<"L2Loss", [NoSideEffect]> {
|
||||||
let summary = "L2 Loss.";
|
let summary = "L2 Loss.";
|
||||||
|
|
||||||
@ -5586,6 +5781,24 @@ A 2-D example:
|
|||||||
TF_DerivedResultTypeAttr out_type = TF_DerivedResultTypeAttr<0>;
|
TF_DerivedResultTypeAttr out_type = TF_DerivedResultTypeAttr<0>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TF_MakeIteratorOp : TF_Op<"MakeIterator", []> {
|
||||||
|
let summary = [{
|
||||||
|
Makes a new iterator from the given `dataset` and stores it in `iterator`.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
This operation may be executed multiple times. Each execution will reset the
|
||||||
|
iterator in `iterator` to the first element of `dataset`.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TF_VariantTensor:$dataset,
|
||||||
|
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorWrite]>:$iterator
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs);
|
||||||
|
}
|
||||||
|
|
||||||
def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> {
|
def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> {
|
||||||
let summary = [{
|
let summary = [{
|
||||||
Multiply the matrix "a" by the matrix "b".
|
Multiply the matrix "a" by the matrix "b".
|
||||||
@ -6909,6 +7122,82 @@ Returns x * y element-wise. Returns zero if y is zero, even if x if infinite or
|
|||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TF_MultiDeviceIteratorOp : TF_Op<"MultiDeviceIterator", []> {
|
||||||
|
let summary = "Creates a MultiDeviceIterator resource.";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Confined<StrArrayAttr, [ArrayMinCount<1>]>:$devices,
|
||||||
|
StrAttr:$shared_name,
|
||||||
|
StrAttr:$container,
|
||||||
|
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||||
|
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def TF_MultiDeviceIteratorFromStringHandleOp : TF_Op<"MultiDeviceIteratorFromStringHandle", []> {
|
||||||
|
let summary = [{
|
||||||
|
Generates a MultiDeviceIterator resource from its provided string handle.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TF_StrTensor:$string_handle,
|
||||||
|
|
||||||
|
DefaultValuedAttr<TypeArrayAttr, "{}">:$output_types,
|
||||||
|
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$multi_device_iterator
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def TF_MultiDeviceIteratorGetNextFromShardOp : TF_Op<"MultiDeviceIteratorGetNextFromShard", []> {
|
||||||
|
let summary = "Gets next element for the provided shard number.";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead, TF_DatasetIteratorWrite]>:$multi_device_iterator,
|
||||||
|
TF_Int32Tensor:$shard_num,
|
||||||
|
TF_Int64Tensor:$incarnation_id
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
Variadic<TF_Tensor>:$components
|
||||||
|
);
|
||||||
|
|
||||||
|
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
|
||||||
|
TF_DerivedResultTypeListAttr output_types = TF_DerivedResultTypeListAttr<0>;
|
||||||
|
}
|
||||||
|
|
||||||
|
def TF_MultiDeviceIteratorInitOp : TF_Op<"MultiDeviceIteratorInit", []> {
|
||||||
|
let summary = "Initializes the multi device iterator with the given dataset.";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TF_VariantTensor:$dataset,
|
||||||
|
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorWrite]>:$multi_device_iterator,
|
||||||
|
TF_Int64Tensor:$max_buffer_size
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TF_Int64Tensor:$incarnation_id
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def TF_MultiDeviceIteratorToStringHandleOp : TF_Op<"MultiDeviceIteratorToStringHandle", []> {
|
||||||
|
let summary = "Produces a string handle for the given MultiDeviceIterator.";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead]>:$multi_device_iterator
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TF_StrTensor:$string_handle
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
def TF_MultinomialOp : TF_Op<"Multinomial", [TF_CannotDuplicate]> {
|
def TF_MultinomialOp : TF_Op<"Multinomial", [TF_CannotDuplicate]> {
|
||||||
let summary = "Draws samples from a multinomial distribution.";
|
let summary = "Draws samples from a multinomial distribution.";
|
||||||
|
|
||||||
@ -7363,6 +7652,44 @@ output =
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TF_OneShotIteratorOp : TF_Op<"OneShotIterator", []> {
|
||||||
|
let summary = [{
|
||||||
|
Makes a "one-shot" iterator that can be iterated only once.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
A one-shot iterator bundles the logic for defining the dataset and
|
||||||
|
the state of the iterator in a single op, which allows simple input
|
||||||
|
pipelines to be defined without an additional initialization
|
||||||
|
("MakeIterator") step.
|
||||||
|
|
||||||
|
One-shot iterators have the following limitations:
|
||||||
|
|
||||||
|
* They do not support parameterization: all logic for creating the underlying
|
||||||
|
dataset must be bundled in the `dataset_factory` function.
|
||||||
|
* They are not resettable. Once a one-shot iterator reaches the end of its
|
||||||
|
underlying dataset, subsequent "IteratorGetNext" operations on that
|
||||||
|
iterator will always produce an `OutOfRange` error.
|
||||||
|
|
||||||
|
For greater flexibility, use "Iterator" and "MakeIterator" to define
|
||||||
|
an iterator using an arbitrary subgraph, which may capture tensors
|
||||||
|
(including fed values) as parameters, and which may be reset multiple
|
||||||
|
times by rerunning "MakeIterator".
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
SymbolRefAttr:$dataset_factory,
|
||||||
|
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||||
|
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
|
||||||
|
StrAttr:$container,
|
||||||
|
StrAttr:$shared_name
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle
|
||||||
|
);
|
||||||
|
}
|
||||||
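In the `tf.data` Python API, the op definitions above correspond to the two classic graph-mode iterator styles: an initializable iterator (Iterator + MakeIterator + IteratorGetNext), which can be reset by re-running its initializer, and a one-shot iterator, which bundles dataset construction and iterator state into a single op and cannot be reset. A short sketch, assuming a TensorFlow installation (illustrative only):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
ds = tf.data.Dataset.range(2)

# Initializable iterator: lowers to Iterator/MakeIterator/IteratorGetNext.
it = tf.data.make_initializable_iterator(ds)
nxt = it.get_next()

# One-shot iterator: dataset construction and iterator state in one op.
one_shot = tf.data.make_one_shot_iterator(ds)
one_shot_nxt = one_shot.get_next()

with tf.Session() as sess:
    sess.run(it.initializer)       # MakeIterator resets to the first element
    print(sess.run(nxt))           # 0
    sess.run(it.initializer)
    print(sess.run(nxt))           # 0 again after re-initialization
    print(sess.run(one_shot_nxt))  # 0
    print(sess.run(one_shot_nxt))  # 1; the next call raises OutOfRangeError
```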
|
|
||||||
def TF_OutfeedEnqueueTupleOp : TF_Op<"OutfeedEnqueueTuple", []> {
|
def TF_OutfeedEnqueueTupleOp : TF_Op<"OutfeedEnqueueTuple", []> {
|
||||||
let summary = "Enqueue multiple Tensor values on the computation outfeed.";
|
let summary = "Enqueue multiple Tensor values on the computation outfeed.";
|
||||||
|
|
||||||
@ -10353,6 +10680,22 @@ Computes gradients for the scaled exponential linear (Selu) operation.
|
|||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TF_SerializeIteratorOp : TF_Op<"SerializeIterator", []> {
|
||||||
|
let summary = [{
|
||||||
|
Converts the given `resource_handle` representing an iterator to a variant tensor.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead]>:$resource_handle,
|
||||||
|
|
||||||
|
DefaultValuedAttr<I64Attr, "0">:$external_state_policy
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TF_VariantTensor:$serialized
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
def TF_ShapeOp : TF_Op<"Shape", [NoSideEffect]> {
|
def TF_ShapeOp : TF_Op<"Shape", [NoSideEffect]> {
|
||||||
let summary = "Returns the shape of a tensor.";
|
let summary = "Returns the shape of a tensor.";
|
||||||
|
|
||||||
@ -11409,7 +11752,7 @@ def TF_StackV2Op : TF_Op<"StackV2", []> {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect]> {
|
def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect, TF_NoConstantFold]> {
|
||||||
let summary = "Draws samples from a multinomial distribution.";
|
let summary = "Draws samples from a multinomial distribution.";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
@ -11427,7 +11770,82 @@ def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect]> {
|
|||||||
TF_DerivedResultTypeAttr output_dtype = TF_DerivedResultTypeAttr<0>;
|
TF_DerivedResultTypeAttr output_dtype = TF_DerivedResultTypeAttr<0>;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_StatelessRandomNormalOp : TF_Op<"StatelessRandomNormal", [NoSideEffect]> {
|
def TF_StatelessParameterizedTruncatedNormalOp : TF_Op<"StatelessParameterizedTruncatedNormal", [NoSideEffect, TF_NoConstantFold]> {
|
||||||
|
let summary = "";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TF_I32OrI64Tensor:$shape,
|
||||||
|
TF_I32OrI64Tensor:$seed,
|
||||||
|
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$means,
|
||||||
|
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$stddevs,
|
||||||
|
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$minvals,
|
||||||
|
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$maxvals
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
TF_DerivedOperandTypeAttr S = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
|
||||||
|
TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>;
|
||||||
|
}
|
||||||
|
|
||||||
|
def TF_StatelessRandomBinomialOp : TF_Op<"StatelessRandomBinomial", [NoSideEffect, TF_NoConstantFold]> {
|
||||||
|
let summary = [{
|
||||||
|
Outputs deterministic pseudorandom random numbers from a binomial distribution.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Outputs random values from a binomial distribution.
|
||||||
|
|
||||||
|
The outputs are a deterministic function of `shape`, `seed`, `counts`, and `probs`.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TF_I32OrI64Tensor:$shape,
|
||||||
|
TF_I32OrI64Tensor:$seed,
|
||||||
|
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$counts,
|
||||||
|
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$probs
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
TF_DerivedOperandTypeAttr S = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>;
|
||||||
|
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
|
||||||
|
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
||||||
|
}
|
||||||
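These stateless sampling ops are pure functions of `shape`, `seed`, and the distribution parameters, which is also why the change tags them with the new `TF_NoConstantFold` trait rather than letting the compiler fold them like ordinary constants. The determinism is easy to see through the public TF 2.x wrappers (assumes a TensorFlow build that ships these kernels; illustrative only):

```python
import tensorflow as tf

seed = [1, 2]
a = tf.random.stateless_binomial(shape=[4], seed=seed, counts=10.0, probs=0.3)
b = tf.random.stateless_binomial(shape=[4], seed=seed, counts=10.0, probs=0.3)
print(bool(tf.reduce_all(a == b)))  # True: same shape/seed/params, same draws

# The same holds for the other stateless distributions touched here.
g = tf.random.stateless_gamma(shape=[2], seed=seed, alpha=[0.5, 1.5])
u = tf.random.stateless_uniform(shape=[2], seed=seed)
print(g.numpy(), u.numpy())
```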
|
|
||||||
|
def TF_StatelessRandomGammaV2Op : TF_Op<"StatelessRandomGammaV2", [NoSideEffect, TF_NoConstantFold]> {
|
||||||
|
let summary = [{
|
||||||
|
Outputs deterministic pseudorandom random numbers from a gamma distribution.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Outputs random values from a gamma distribution.
|
||||||
|
|
||||||
|
The outputs are a deterministic function of `shape`, `seed`, and `alpha`.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TF_I32OrI64Tensor:$shape,
|
||||||
|
TF_I32OrI64Tensor:$seed,
|
||||||
|
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$alpha
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
|
||||||
|
TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>;
|
||||||
|
}
|
||||||
|
|
||||||
|
def TF_StatelessRandomNormalOp : TF_Op<"StatelessRandomNormal", [NoSideEffect, TF_NoConstantFold]> {
|
||||||
let summary = [{
|
let summary = [{
|
||||||
Outputs deterministic pseudorandom values from a normal distribution.
|
Outputs deterministic pseudorandom values from a normal distribution.
|
||||||
}];
|
}];
|
||||||
@ -11452,7 +11870,34 @@ The outputs are a deterministic function of `shape` and `seed`.
|
|||||||
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_StatelessRandomUniformOp : TF_Op<"StatelessRandomUniform", [NoSideEffect]> {
|
def TF_StatelessRandomPoissonOp : TF_Op<"StatelessRandomPoisson", [NoSideEffect, TF_NoConstantFold]> {
|
||||||
|
let summary = [{
|
||||||
|
Outputs deterministic pseudorandom random numbers from a Poisson distribution.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Outputs random values from a Poisson distribution.
|
||||||
|
|
||||||
|
The outputs are a deterministic function of `shape`, `seed`, and `lam`.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TF_I32OrI64Tensor:$shape,
|
||||||
|
TF_I32OrI64Tensor:$seed,
|
||||||
|
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$lam
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
|
||||||
|
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
||||||
|
TF_DerivedOperandTypeAttr Rtype = TF_DerivedOperandTypeAttr<2>;
|
||||||
|
}
|
||||||
|
|
||||||
|
def TF_StatelessRandomUniformOp : TF_Op<"StatelessRandomUniform", [NoSideEffect, TF_NoConstantFold]> {
|
||||||
let summary = [{
|
let summary = [{
|
||||||
Outputs deterministic pseudorandom random values from a uniform distribution.
|
Outputs deterministic pseudorandom random values from a uniform distribution.
|
||||||
}];
|
}];
|
||||||
@ -11478,7 +11923,32 @@ The outputs are a deterministic function of `shape` and `seed`.
|
|||||||
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect]> {
|
def TF_StatelessRandomUniformFullIntOp : TF_Op<"StatelessRandomUniformFullInt", [NoSideEffect, TF_NoConstantFold]> {
|
||||||
|
let summary = [{
|
||||||
|
Outputs deterministic pseudorandom random integers from a uniform distribution.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
The generated values are uniform integers covering the whole range of `dtype`.
|
||||||
|
|
||||||
|
The outputs are a deterministic function of `shape` and `seed`.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TF_I32OrI64Tensor:$shape,
|
||||||
|
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$seed
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
|
||||||
|
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
||||||
|
}
|
||||||
|
|
||||||
|
def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect, TF_NoConstantFold]> {
|
||||||
let summary = [{
|
let summary = [{
|
||||||
Outputs deterministic pseudorandom random integers from a uniform distribution.
|
Outputs deterministic pseudorandom random integers from a uniform distribution.
|
||||||
}];
|
}];
|
||||||
@ -11505,7 +11975,7 @@ The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxv
|
|||||||
TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>;
|
TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_StatelessTruncatedNormalOp : TF_Op<"StatelessTruncatedNormal", [NoSideEffect]> {
|
def TF_StatelessTruncatedNormalOp : TF_Op<"StatelessTruncatedNormal", [NoSideEffect, TF_NoConstantFold]> {
|
||||||
let summary = [{
|
let summary = [{
|
||||||
Outputs deterministic pseudorandom values from a truncated normal distribution.
|
Outputs deterministic pseudorandom values from a truncated normal distribution.
|
||||||
}];
|
}];
|
||||||
|
@ -73,6 +73,9 @@ def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;
 // certain state around within their implementations.
 def TF_CannotDuplicate : NativeOpTrait<"TF::CannotDuplicate">;
 
+// Trait to indicate an operation cannot be constant folded.
+def TF_NoConstantFold : NativeOpTrait<"TF::NoConstantFold">;
+
 // Coefficient wise binary operation with implicit broadcasting support, for
 // example tf.Sub operation.
 def TF_CwiseBinary : NativeOpTrait<"TF::CwiseBinary">;
@ -112,6 +115,7 @@ def TF_SummaryResource : TF_ResourceBase<"Summary">;
 def TF_LookupTableResource : TF_ResourceBase<"LookupTable">;
 def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">;
 def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">;
+def TF_DatasetIteratorResource : TF_ResourceBase<"DatasetIterator">;
 
 def TF_VariableRead : MemRead<TF_VariableResource>;
 def TF_StackRead : MemRead<TF_StackResource>;
@ -119,6 +123,7 @@ def TF_TensorArrayRead : MemRead<TF_TensorArrayResource>;
 def TF_LookupTableRead : MemRead<TF_LookupTableResource>;
 def TF_DatasetSeedGeneratorRead : MemRead<TF_DatasetSeedGeneratorResource>;
 def TF_DatasetMemoryCacheRead : MemRead<TF_DatasetMemoryCacheResource>;
+def TF_DatasetIteratorRead : MemRead<TF_DatasetIteratorResource>;
 
 def TF_VariableWrite : MemWrite<TF_VariableResource>;
 def TF_StackWrite : MemWrite<TF_StackResource>;
@ -127,6 +132,7 @@ def TF_SummaryWrite : MemWrite<TF_SummaryResource>;
 def TF_LookupTableWrite : MemWrite<TF_LookupTableResource>;
 def TF_DatasetSeedGeneratorWrite : MemWrite<TF_DatasetSeedGeneratorResource>;
 def TF_DatasetMemoryCacheWrite : MemWrite<TF_DatasetMemoryCacheResource>;
+def TF_DatasetIteratorWrite : MemWrite<TF_DatasetIteratorResource>;
 
 def TF_VariableAlloc : MemAlloc<TF_VariableResource>;
 def TF_StackAlloc : MemAlloc<TF_StackResource>;
@ -135,12 +141,14 @@ def TF_SummaryAlloc : MemAlloc<TF_SummaryResource>;
 def TF_LookupTableAlloc : MemAlloc<TF_LookupTableResource>;
 def TF_DatasetSeedGeneratorAlloc : MemAlloc<TF_DatasetSeedGeneratorResource>;
 def TF_DatasetMemoryCacheAlloc : MemAlloc<TF_DatasetMemoryCacheResource>;
+def TF_DatasetIteratorAlloc : MemAlloc<TF_DatasetIteratorResource>;
 
 def TF_StackFree : MemFree<TF_StackResource>;
 def TF_TensorArrayFree : MemFree<TF_TensorArrayResource>;
 def TF_SummaryFree : MemFree<TF_SummaryResource>;
 def TF_DatasetSeedGeneratorFree : MemFree<TF_DatasetSeedGeneratorResource>;
 def TF_DatasetMemoryCacheFree : MemFree<TF_DatasetMemoryCacheResource>;
+def TF_DatasetIteratorFree : MemFree<TF_DatasetIteratorResource>;
 
 //===----------------------------------------------------------------------===//
 // TensorFlow op definitions
@ -1446,7 +1446,8 @@ static LogicalResult Verify(SplitVOp op) {
   if (!split_sizes_type) return success();
 
   if (split_sizes_type.getRank() != 1 ||
-      split_sizes_type.getDimSize(0) != op.getNumResults())
+      (split_sizes_type.getDimSize(0) != ShapedType::kDynamicSize &&
+       split_sizes_type.getDimSize(0) != op.getNumResults()))
     return op.emitOpError("split sizes should be a 1D tensor of ")
            << op.getNumResults() << " elements";
 
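The relaxed check above lets `size_splits` have a dynamic leading dimension, since the number of results can only be validated once that value is known. In the Python API this is the `tf.split` form with an explicit size list, which lowers to SplitV; a tiny example:

```python
import tensorflow as tf

x = tf.reshape(tf.range(12), [6, 2])
# Explicit split sizes lower to the SplitV op with three results.
parts = tf.split(x, num_or_size_splits=[1, 2, 3], axis=0)
print([p.shape.as_list() for p in parts])  # [[1, 2], [2, 2], [3, 2]]
```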
@ -53,6 +53,10 @@ struct DatasetMemoryCache
   StringRef getName() final { return "DatasetMemoryCache"; }
 };
 
+struct DatasetIterator : ::mlir::SideEffects::Resource::Base<DatasetIterator> {
+  StringRef getName() final { return "DatasetIterator"; }
+};
+
 } // namespace ResourceEffects
 } // namespace TF
 } // namespace mlir
@ -124,6 +124,10 @@ class CannotDuplicate : public TraitBase<ConcreteType, CannotDuplicate> {
   }
 };
 
+// Trait to indicate an operation cannot be constant folded.
+template <typename ConcreteType>
+class NoConstantFold : public TraitBase<ConcreteType, NoConstantFold> {};
+
 // Coefficient-wise binary operation with implicit broadcasting support, for
 // example tf.Sub operation.
 template <typename ConcreteType>
@ -502,3 +502,12 @@ func @fold_conv() -> tensor<1x520x520x1xf32> {
   // CHECK: tf.Const
   // CHECK-NOT: tf.DepthwiseConv2dNative
 }
+
+// CHECK-LABEL: DontFoldNoConstantFold
+func @DontFoldNoConstantFold() -> tensor<8xf32> {
+  %0 = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> tensor<1xi32>
+  %1 = "tf.Const"() {value = dense<[2, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: tf.StatelessRandomUniform
+  %2 = "tf.StatelessRandomUniform"(%0, %1) : (tensor<1xi32>, tensor<2xi32>) -> tensor<8xf32>
+  return %2 : tensor<8xf32>
+}
@ -138,17 +138,17 @@ func @op_string_operand_string_result(%arg0: tensor<!tf.string>) -> tensor<i32>
   return %0 : tensor<i32>
 }
 
-// Test that a tf.IfRegion op with a captured string operand is marked for outside compilation.
+// Test that operations inside a tf.IfRegion op are correctly marked for outside
+// compilation.
 
-// CHECK-LABEL: func @if_region_captured_string
-func @if_region_captured_string(%arg0: tensor<i1>, %arg1: tensor<!tf.string>) -> tensor<f32> {
+// CHECK-LABEL: func @ops_inside_tf_if_outside_compiled
+func @ops_inside_tf_if_outside_compiled(%arg0: tensor<i1>, %arg1: tensor<!tf.string>) -> tensor<f32> {
   %0 = "tf_device.cluster"() ( {
     // CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
     // CHECK-NOT: _xla_outside_compilation
     // CHECK: "tf.IfRegion"
     // CHECK: "tf.StringToNumber"
-    // CHECK-NOT: _xla_outside_compilation
-    // CHECK: _xla_outside_compilation = "auto1", is_stateless = true
+    // CHECK-SAME: _xla_outside_compilation
     %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
     %2 = "tf.IfRegion"(%arg0) ( {
       %3 = "tf.StringToNumber"(%arg1) {out_type = f32} : (tensor<!tf.string>) -> tensor<f32>
@ -163,7 +163,8 @@ func @if_region_captured_string(%arg0: tensor<i1>, %arg1: tensor<!tf.string>) ->
   return %0 : tensor<f32>
 }
 
-// Test that ops with string results/operands inside a tf.IfRegion branch are marked for outside compilation.
+// Test that ops with string results/operands inside a tf.IfRegion branch are
+// marked for outside compilation.
 
 // CHECK-LABEL: func @if_region_string_op
 func @if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<f32> {
@ -191,7 +192,8 @@ func @if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<f32
   return %0 : tensor<f32>
 }
 
-// Test that ops with string results/operands inside a nested tf.IfRegion branch are marked for outside compilation.
+// Test that ops with string results/operands inside a nested tf.IfRegion branch
+// are marked for outside compilation.
 
 // CHECK-LABEL: func @nested_if_region_string_op
 func @nested_if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<f32> {
@ -231,16 +233,17 @@ func @nested_if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> ten
   return %0 : tensor<f32>
 }
 
-// Test that a tf.WhileRegion op with a captured string operand is marked for outside compilation.
+// Test that ops inside a tf.WhileRegion op are correctly marked for outside
+// compilation.
 
-// CHECK-LABEL: func @while_region_captured_string
-func @while_region_captured_string(%arg0: tensor<i32>, %arg1: tensor<!tf.string>) -> tensor<f32> {
+// CHECK-LABEL: func @ops_inside_while_outside_compiled
+func @ops_inside_while_outside_compiled(%arg0: tensor<i32>, %arg1: tensor<!tf.string>) -> tensor<f32> {
   %0 = "tf_device.cluster"() ( {
     // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
     // CHECK-NOT: _xla_outside_compilation
     // CHECK: "tf.WhileRegion"
     // CHECK: "tf.StringToNumber"
-    // CHECK: _xla_outside_compilation = "auto1", is_stateless = true
+    // CHECK-SAME: _xla_outside_compilation
     %1 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
     %2:2 = "tf.WhileRegion"(%1, %arg0) ( {
     ^bb0(%carg0: tensor<f32>, %carg1: tensor<i32>):
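The renamed tests above check that string-handling ops nested inside `tf.IfRegion`/`tf.WhileRegion` bodies receive an `_xla_outside_compilation` attribute, since ops like `tf.StringToNumber` have no TPU kernel. The hand-written equivalent in the Python API is to wrap the host-only op explicitly; a sketch, assuming TF >= 2.3 and an enclosing TPU replica context (illustrative only):

```python
import tensorflow as tf


@tf.function
def step(flag: tf.Tensor, s: tf.Tensor) -> tf.Tensor:
    def parse_on_host():
        # tf.strings.to_number has no TPU kernel, so run it on the host.
        return tf.tpu.outside_compilation(tf.strings.to_number, s, tf.float32)

    return tf.cond(flag, parse_on_host, lambda: tf.constant(1.0))
```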