Resolved merge conflicts
This commit is contained in:
commit
f5762da2e7
2
.bazelrc
2
.bazelrc
@ -323,6 +323,8 @@ build:windows --copt=/experimental:preprocessor
|
||||
build:windows --host_copt=/experimental:preprocessor
|
||||
|
||||
# Misc build options we need for windows.
|
||||
build:windows --linkopt=/DEBUG
|
||||
build:windows --host_linkopt=/DEBUG
|
||||
build:windows --linkopt=/OPT:REF
|
||||
build:windows --host_linkopt=/OPT:REF
|
||||
build:windows --linkopt=/OPT:ICF
|
||||
|
6
.github/bot_config.yml
vendored
6
.github/bot_config.yml
vendored
@ -12,12 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
#
|
||||
# THIS IS A GENERATED DOCKERFILE.
|
||||
#
|
||||
# This file was assembled from multiple pieces, whose use is documented
|
||||
# throughout. Please refer to the TensorFlow dockerfiles documentation
|
||||
# for more information.
|
||||
|
||||
# A list of assignees
|
||||
assignees:
|
||||
|
28
.github/workflows/update-nightly.yml
vendored
Normal file
28
.github/workflows/update-nightly.yml
vendored
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
on:
|
||||
workflow_dispatch: # Allow manual triggers
|
||||
schedule:
|
||||
- cron: 0 4 * * * # 4am UTC is 9pm PDT and 8pm PST
|
||||
name: Set nightly branch to master HEAD
|
||||
jobs:
|
||||
master-to-nightly:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: zofrex/mirror-branch@v1
|
||||
name: Set nightly branch to master HEAD
|
||||
with:
|
||||
target-branch: 'nightly'
|
22
RELEASE.md
22
RELEASE.md
@ -1,4 +1,4 @@
|
||||
# Release 2.4.0
|
||||
h# Release 2.4.0
|
||||
|
||||
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
|
||||
|
||||
@ -209,8 +209,13 @@
|
||||
* Improvements to Keras preprocessing layers:
|
||||
* TextVectorization can now accept a vocabulary list or file as an
|
||||
init arg.
|
||||
* In `Attention` and `AdditiveAttention` layers, the `call()` method now
|
||||
accepts a `return_attention_scores` argument. When set to
|
||||
True, the layer returns the attention scores as an additional output
|
||||
argument.
|
||||
* Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints
|
||||
with the same implementation as their `tf.losses` equivalent.
|
||||
* `tf.function` / AutoGraph:
|
||||
|
||||
* Added `experimental_follow_type_hints` argument for `tf.function`. When
|
||||
True, the function may use type annotations to optimize the tracing
|
||||
performance.
|
||||
@ -296,16 +301,21 @@
|
||||
|
||||
* `tf.debugging.assert_shapes()` now works on `SparseTensor`s (#36268).
|
||||
|
||||
* `TensorRT`
|
||||
* Add parameter allow_mixed_precision_on_unconverted_ops to
|
||||
TrtConversionParams.
|
||||
* `tf.print`:
|
||||
|
||||
* Bug fix in `tf.print()` with `OrderedDict` where if an `OrderedDict`
|
||||
didn't have the keys sorted, the keys and values were not being printed
|
||||
in accordance with their correct mapping.
|
||||
|
||||
* Other:
|
||||
|
||||
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
|
||||
and "denylist" where possible. Please see
|
||||
https://developers.google.com/style/word-list#blacklist for more
|
||||
context. <ADD RELEASE NOTES HERE>
|
||||
context.
|
||||
* Add `tf.config.experimental.mlir_bridge_rollout` which will help us
|
||||
rollout the new MLIR TPU bridge.
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
||||
|
@ -545,7 +545,9 @@ TEST(CAPI, DistributedFunctionNoError) {
|
||||
TestDistributedFunctionCancellation(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, DistributedFunctionCancelledOnError) {
|
||||
// TODO(b/170399182): Update test once an alternative to using the function
|
||||
// optimization hook is in place.
|
||||
TEST(CAPI, DISABLED_DistributedFunctionCancelledOnError) {
|
||||
TestDistributedFunctionCancellation(true);
|
||||
}
|
||||
|
||||
|
@ -61,6 +61,7 @@ Status RegisterGradients(GradientRegistry* registry) {
|
||||
TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -131,6 +132,37 @@ Status ExpGradModel(AbstractContext* ctx,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Computes
|
||||
// y = sqrt(inputs[0])
|
||||
// return grad(y, {inputs[0]})
|
||||
Status SqrtGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
std::vector<AbstractTensorHandle*> sqrt_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(sqrt_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
for (auto sqrt_output : sqrt_outputs) {
|
||||
sqrt_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Computes
|
||||
// ignored, y = IdentityN(inputs[0], inputs[1])
|
||||
// return grad(y, {inputs[0], inputs[1]})
|
||||
@ -401,6 +433,50 @@ TEST_P(CppGradients, TestExpGrad) {
|
||||
result_tensor = nullptr;
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestSqrtGrad) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(x)
|
||||
// y = sqrt(x)
|
||||
// outputs = tape.gradient(y, x)
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* result_tensor;
|
||||
s = getValue(outputs[0], &result_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_NEAR(*result_value, 0.5, 0.001);
|
||||
outputs[0]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
result_tensor = nullptr;
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestIdentityNGrad) {
|
||||
// Pseudo-code:
|
||||
//
|
||||
|
@ -29,6 +29,7 @@ cc_library(
|
||||
}),
|
||||
deps = [
|
||||
"//tensorflow/c:env",
|
||||
"//tensorflow/c:logging",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||
"//third_party/hadoop:hdfs",
|
||||
|
@ -22,9 +22,9 @@ limitations under the License.
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/c/env.h"
|
||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||
#include "tensorflow/c/logging.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
// Implementation of a filesystem for HADOOP environments.
|
||||
@ -148,15 +148,20 @@ class LibHDFS {
|
||||
char* hdfs_home = getenv("HADOOP_HDFS_HOME");
|
||||
if (hdfs_home != nullptr) {
|
||||
auto JoinPath = [](std::string home, std::string lib) {
|
||||
#if defined(_WIN32)
|
||||
if (home.back() != '\\') home.push_back('\\');
|
||||
return home + "lib\\native\\" + lib;
|
||||
#else
|
||||
if (home.back() != '/') home.push_back('/');
|
||||
return home + "lib/native/" + lib;
|
||||
#endif
|
||||
};
|
||||
std::string path = JoinPath(hdfs_home, kLibHdfsDso);
|
||||
TryLoadAndBind(path.c_str(), &handle_, status);
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
return;
|
||||
} else {
|
||||
std::cerr << "HadoopFileSystem load error: " << TF_Message(status);
|
||||
TF_Log(TF_FATAL, "HadoopFileSystem load error: %s", TF_Message(status));
|
||||
}
|
||||
}
|
||||
|
||||
@ -207,6 +212,7 @@ hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file,
|
||||
builder, namenode.empty() ? "default" : namenode.c_str());
|
||||
cacheKey += namenode;
|
||||
}
|
||||
absl::MutexLock l(&hadoop_file->connection_cache_lock);
|
||||
if (hadoop_file->connection_cache.find(cacheKey) ==
|
||||
hadoop_file->connection_cache.end()) {
|
||||
auto cacheFs = libhdfs->hdfsBuilderConnect(builder);
|
||||
@ -418,17 +424,20 @@ void Close(const TF_WritableFile* file, TF_Status* status) {
|
||||
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_read_only_memory_region {
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
|
||||
// Hadoop doesn't support Readonly Memory Region
|
||||
} // namespace tf_read_only_memory_region
|
||||
|
||||
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_hadoop_filesystem {
|
||||
|
||||
HadoopFile::HadoopFile(TF_Status* status)
|
||||
: libhdfs(new LibHDFS(status)),
|
||||
connection_cache_lock(),
|
||||
connection_cache() {}
|
||||
|
||||
void Init(TF_Filesystem* filesystem, TF_Status* status) {
|
||||
filesystem->plugin_filesystem = new HadoopFile({new LibHDFS(status), {}});
|
||||
filesystem->plugin_filesystem = new HadoopFile(status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
@ -699,7 +708,9 @@ int GetChildren(const TF_Filesystem* filesystem, const char* path,
|
||||
return num_entries;
|
||||
}
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
|
||||
return strdup(uri);
|
||||
}
|
||||
|
||||
} // namespace tf_hadoop_filesystem
|
||||
|
||||
@ -707,6 +718,42 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
|
||||
const char* uri) {
|
||||
TF_SetFilesystemVersionMetadata(ops);
|
||||
ops->scheme = strdup(uri);
|
||||
|
||||
ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
|
||||
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
|
||||
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
|
||||
ops->random_access_file_ops->read = tf_random_access_file::Read;
|
||||
|
||||
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
|
||||
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
|
||||
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
|
||||
ops->writable_file_ops->append = tf_writable_file::Append;
|
||||
ops->writable_file_ops->tell = tf_writable_file::Tell;
|
||||
ops->writable_file_ops->flush = tf_writable_file::Flush;
|
||||
ops->writable_file_ops->sync = tf_writable_file::Sync;
|
||||
ops->writable_file_ops->close = tf_writable_file::Close;
|
||||
|
||||
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
|
||||
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
|
||||
ops->filesystem_ops->init = tf_hadoop_filesystem::Init;
|
||||
ops->filesystem_ops->cleanup = tf_hadoop_filesystem::Cleanup;
|
||||
ops->filesystem_ops->new_random_access_file =
|
||||
tf_hadoop_filesystem::NewRandomAccessFile;
|
||||
ops->filesystem_ops->new_writable_file =
|
||||
tf_hadoop_filesystem::NewWritableFile;
|
||||
ops->filesystem_ops->new_appendable_file =
|
||||
tf_hadoop_filesystem::NewAppendableFile;
|
||||
ops->filesystem_ops->new_read_only_memory_region_from_file =
|
||||
tf_hadoop_filesystem::NewReadOnlyMemoryRegionFromFile;
|
||||
ops->filesystem_ops->path_exists = tf_hadoop_filesystem::PathExists;
|
||||
ops->filesystem_ops->stat = tf_hadoop_filesystem::Stat;
|
||||
ops->filesystem_ops->get_file_size = tf_hadoop_filesystem::GetFileSize;
|
||||
ops->filesystem_ops->delete_file = tf_hadoop_filesystem::DeleteFile;
|
||||
ops->filesystem_ops->create_dir = tf_hadoop_filesystem::CreateDir;
|
||||
ops->filesystem_ops->delete_dir = tf_hadoop_filesystem::DeleteDir;
|
||||
ops->filesystem_ops->rename_file = tf_hadoop_filesystem::RenameFile;
|
||||
ops->filesystem_ops->get_children = tf_hadoop_filesystem::GetChildren;
|
||||
ops->filesystem_ops->translate_name = tf_hadoop_filesystem::TranslateName;
|
||||
}
|
||||
|
||||
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "third_party/hadoop/hdfs.h"
|
||||
@ -47,7 +48,10 @@ void Close(const TF_WritableFile* file, TF_Status* status);
|
||||
namespace tf_hadoop_filesystem {
|
||||
typedef struct HadoopFile {
|
||||
LibHDFS* libhdfs;
|
||||
std::map<std::string, hdfsFS> connection_cache;
|
||||
absl::Mutex connection_cache_lock;
|
||||
std::map<std::string, hdfsFS> connection_cache
|
||||
ABSL_GUARDED_BY(connection_cache_lock);
|
||||
HadoopFile(TF_Status* status);
|
||||
} HadoopFile;
|
||||
|
||||
void Init(TF_Filesystem* filesystem, TF_Status* status);
|
||||
|
@ -24,6 +24,7 @@ using std::vector;
|
||||
using tensorflow::ops::Conj;
|
||||
using tensorflow::ops::MatMul;
|
||||
using tensorflow::ops::Mul;
|
||||
using tensorflow::ops::SqrtGrad;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
@ -72,6 +73,25 @@ class ExpGradientFunction : public GradientFunction {
|
||||
AbstractTensorHandlePtr exp_;
|
||||
};
|
||||
|
||||
class SqrtGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) {
|
||||
sqrt->Ref();
|
||||
}
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
std::string name = "Sqrt_Grad";
|
||||
grad_outputs->resize(1);
|
||||
TF_RETURN_IF_ERROR(SqrtGrad(ctx->ctx, {sqrt_.get(), grad_inputs[0]},
|
||||
absl::MakeSpan(*grad_outputs), name.c_str()));
|
||||
return Status::OK();
|
||||
}
|
||||
~SqrtGradientFunction() override {}
|
||||
|
||||
private:
|
||||
AbstractTensorHandlePtr sqrt_;
|
||||
};
|
||||
|
||||
class MatMulGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
|
||||
@ -210,5 +230,14 @@ BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
BackwardFunction* SqrtRegisterer(const ForwardOperation& op) {
|
||||
auto gradient_function = new SqrtGradientFunction(op.outputs[0]);
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
@ -19,10 +19,13 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
|
||||
BackwardFunction* AddRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
|
||||
|
@ -144,5 +144,33 @@ Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
return exp_op->Execute(outputs, &num_retvals);
|
||||
}
|
||||
|
||||
Status Sqrt(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr sqrt_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(sqrt_op->Reset("Sqrt", /*raw_device_name=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(sqrt_op->AddInput(inputs[0]));
|
||||
|
||||
int num_retvals = 1;
|
||||
Status s = sqrt_op->Execute(outputs, &num_retvals);
|
||||
return s;
|
||||
}
|
||||
|
||||
Status SqrtGrad(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr sqrt_grad_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(
|
||||
sqrt_grad_op->Reset("SqrtGrad", /*raw_device_name=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_grad_op.get(), name));
|
||||
TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[0]));
|
||||
TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[1]));
|
||||
|
||||
int num_retvals = 1;
|
||||
Status s = sqrt_grad_op->Execute(outputs, &num_retvals);
|
||||
return s;
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
@ -50,6 +50,15 @@ Status DivNoNan(AbstractContext* ctx,
|
||||
|
||||
Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status Sqrt(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status SqrtGrad(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -166,6 +166,8 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
@ -20,8 +20,10 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
@ -62,15 +64,53 @@ Status Variable::ReadValue(ImmediateTensorHandlePtr* out) {
|
||||
return internal::ReadVariable(ctx_, handle_.get(), dtype_, out);
|
||||
}
|
||||
|
||||
Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx,
|
||||
DataType dtype, TensorShape shape,
|
||||
absl::optional<std::string> name,
|
||||
const char* raw_device_name,
|
||||
std::unique_ptr<Variable>* output) {
|
||||
Status Variable::CreateUninitialized(
|
||||
ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
|
||||
absl::optional<std::string> name, const char* raw_device_name,
|
||||
const std::vector<std::string>& component_devices,
|
||||
std::unique_ptr<Variable>* output) {
|
||||
ImmediateTensorHandlePtr handle;
|
||||
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
|
||||
ctx, dtype, shape, raw_device_name, &handle));
|
||||
|
||||
if (component_devices.empty()) {
|
||||
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
|
||||
ctx, dtype, shape, raw_device_name, &handle));
|
||||
output->reset(
|
||||
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
|
||||
return Status();
|
||||
}
|
||||
|
||||
if (!tensorflow::isa<EagerContext>(ctx)) {
|
||||
return errors::InvalidArgument(
|
||||
"Can only load distributed variables with EagerContext.");
|
||||
}
|
||||
|
||||
EagerContext* eager_ctx = reinterpret_cast<EagerContext*>(ctx);
|
||||
|
||||
std::vector<TensorHandle*> handles;
|
||||
for (const auto& device : component_devices) {
|
||||
ImmediateTensorHandlePtr handlePtr;
|
||||
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
|
||||
ctx, dtype, shape, device.empty() ? nullptr : device.c_str(),
|
||||
&handlePtr));
|
||||
if (!tensorflow::isa<TensorHandle>(handlePtr.get())) {
|
||||
return errors::Internal("Returned replica handle has unsupported type.");
|
||||
}
|
||||
handles.push_back(reinterpret_cast<TensorHandle*>(handlePtr.release()));
|
||||
}
|
||||
TensorHandle* packed_handle;
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
|
||||
std::move(handles), eager_ctx, &packed_handle));
|
||||
// The call to `CreatePackedHandle` incremented the handles' reference count,
|
||||
// which we must now decrement to make the packed handle the owner of those
|
||||
// handles. We can't loop through the `handles` vector because it was
|
||||
// `std::move`d in the call above.
|
||||
for (int i = 0; i != packed_handle->NumPackedHandles(); ++i) {
|
||||
TensorHandle* component;
|
||||
TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &component));
|
||||
component->Unref();
|
||||
}
|
||||
|
||||
handle.reset(packed_handle);
|
||||
output->reset(
|
||||
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
|
||||
return Status();
|
||||
|
@ -34,11 +34,11 @@ class Variable : public TensorHandleConvertible {
|
||||
public:
|
||||
// Creates an uninitialized resource variable. Note that a caller must
|
||||
// call "assign" to associate a value with the variable.
|
||||
static Status CreateUninitialized(ImmediateExecutionContext* ctx,
|
||||
DataType dtype, TensorShape shape,
|
||||
absl::optional<std::string> name,
|
||||
const char* raw_device_name,
|
||||
std::unique_ptr<Variable>* output);
|
||||
static Status CreateUninitialized(
|
||||
ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
|
||||
absl::optional<std::string> name, const char* raw_device_name,
|
||||
const std::vector<std::string>& component_devices,
|
||||
std::unique_ptr<Variable>* output);
|
||||
|
||||
// The dtype of the underlying variable.
|
||||
DataType dtype();
|
||||
|
@ -235,10 +235,17 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx,
|
||||
const std::string& name = variable.name();
|
||||
tensorflow::TensorShape shape(variable.shape());
|
||||
tensorflow::DataType dtype = variable.dtype();
|
||||
std::vector<std::string> component_devices;
|
||||
|
||||
for (const auto& component :
|
||||
variable.experimental_distributed_variable_components()) {
|
||||
component_devices.push_back(component.device());
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(Variable::CreateUninitialized(
|
||||
ctx, dtype, shape, name,
|
||||
variable.device().empty() ? nullptr : variable.device().c_str(), output));
|
||||
variable.device().empty() ? nullptr : variable.device().c_str(),
|
||||
component_devices, output));
|
||||
return Status();
|
||||
}
|
||||
|
||||
|
@ -119,7 +119,7 @@ TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) {
|
||||
Status status;
|
||||
std::unique_ptr<Variable> var;
|
||||
TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
|
||||
absl::nullopt, nullptr, &var));
|
||||
absl::nullopt, nullptr, {}, &var));
|
||||
|
||||
// Create a TensorHandle
|
||||
ImmediateTensorHandlePtr expected_handle =
|
||||
|
@ -127,7 +127,7 @@ def tf_library(
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --config=$(location " + config + ")" +
|
||||
" --dump_fetch_nodes > $@"),
|
||||
tools = [tfcompile_tool],
|
||||
exec_tools = [tfcompile_tool],
|
||||
# Run tfcompile on the build host, rather than forge, since it's
|
||||
# typically way faster on the local machine.
|
||||
local = 1,
|
||||
@ -242,7 +242,7 @@ def tf_library(
|
||||
" --out_function_object=$(@D)/" + function_object_file +
|
||||
" " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag
|
||||
),
|
||||
tools = [tfcompile_tool],
|
||||
exec_tools = [tfcompile_tool],
|
||||
visibility = visibility,
|
||||
testonly = testonly,
|
||||
# Run tfcompile on the build host since it's typically faster on the
|
||||
@ -281,7 +281,7 @@ def tf_library(
|
||||
" --out_session_module=$(@D)/" + session_module_pb +
|
||||
" " + flags
|
||||
),
|
||||
tools = [tfcompile_tool],
|
||||
exec_tools = [tfcompile_tool],
|
||||
visibility = visibility,
|
||||
testonly = testonly,
|
||||
local = 1,
|
||||
|
@ -84,6 +84,23 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
xla::StatusOr<std::vector<NodeDef>> MakeCallNodesFromAttribute(
|
||||
const Node& node, absl::string_view attr_name,
|
||||
absl::string_view call_name) {
|
||||
std::vector<NameAttrList> attr_lists;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &attr_lists));
|
||||
|
||||
std::vector<NodeDef> out;
|
||||
for (int i = 0; i < attr_lists.size(); i++) {
|
||||
out.emplace_back();
|
||||
NodeDef& inserted = out.back();
|
||||
inserted.set_name(absl::StrCat(call_name, "_", i));
|
||||
inserted.set_op(attr_lists[i].name());
|
||||
*inserted.mutable_attr() = attr_lists[i].attr();
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
// Utility which searches for values in a sorted list by scanning over it once.
|
||||
// No matter how many times ScanForValue is called, the list is scanned at most
|
||||
// once. However, if a call to ScanForValue skips over a value, that value is
|
||||
@ -227,6 +244,30 @@ bool RecursiveCompilabilityChecker::IsCompilableIf(
|
||||
return is_compilable;
|
||||
}
|
||||
|
||||
bool RecursiveCompilabilityChecker::IsCompilableCase(
|
||||
const Node& case_node, FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
NameAttrList* encapsulating_function,
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
|
||||
const {
|
||||
xla::StatusOr<std::vector<NodeDef>> calls =
|
||||
MakeCallNodesFromAttribute(case_node, "branches", "branch");
|
||||
if (!calls.ok()) {
|
||||
VLOG(2) << "Rejecting node " << case_node.name() << ": "
|
||||
<< "missing attribute 'branches'";
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_compilable = true;
|
||||
|
||||
for (const NodeDef& call : *calls) {
|
||||
is_compilable &=
|
||||
IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
|
||||
uncompilable_nodes);
|
||||
}
|
||||
return is_compilable;
|
||||
}
|
||||
|
||||
// Tests whether 'while_node' is a completely compilable loop.
|
||||
// Every operator in the condition and body functions must be compilable for a
|
||||
// while loop to be compilable.
|
||||
@ -417,6 +458,13 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
|
||||
return false;
|
||||
}
|
||||
|
||||
if (op_filter_.require_always_compilable && node.IsCaseNode() &&
|
||||
!IsCompilableCase(node, lib_runtime, stack_trace, encapsulating_function,
|
||||
uncompilable_nodes)) {
|
||||
LogNotCompilable(node, "unsupported case");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!op_filter_.allow_stateful_rng_ops &&
|
||||
IsStatefulRandomOp(node.type_string())) {
|
||||
absl::string_view uncompilable_reason = "stateful random op";
|
||||
|
@ -124,6 +124,10 @@ class RecursiveCompilabilityChecker {
|
||||
// Whether ops known to have numerical accuracy issues should be considered
|
||||
// compilable..
|
||||
bool allow_inaccurate_ops = false;
|
||||
|
||||
// Require the function to be always compilable, regardless whether some
|
||||
// control flow branches might be dead for a given input.
|
||||
bool require_always_compilable = false;
|
||||
};
|
||||
|
||||
RecursiveCompilabilityChecker(OperationFilter op_filter,
|
||||
@ -211,6 +215,14 @@ class RecursiveCompilabilityChecker {
|
||||
NameAttrList* encapsulating_function,
|
||||
UncompilableNodesMap* uncompilable_nodes) const;
|
||||
|
||||
// Tests whether 'case_node' is compilable. Every operator in all branches
|
||||
// must be compilable.
|
||||
bool IsCompilableCase(const Node& case_node,
|
||||
FunctionLibraryRuntime* lib_runtime,
|
||||
std::vector<StackFrameView>* stack_trace,
|
||||
NameAttrList* encapsulating_function,
|
||||
UncompilableNodesMap* uncompilable_nodes) const;
|
||||
|
||||
// Returns compilability of node def retrieved from `node`'s attribute with
|
||||
// name `attr_name`.
|
||||
bool ExtractNodeDefAndCheckCompilability(
|
||||
|
@ -34,7 +34,16 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
AttrValue FuncListAttr(const absl::Span<const char* const> names) {
|
||||
AttrValue attr;
|
||||
for (const char* name : names) {
|
||||
attr.mutable_list()->add_func()->set_name(name);
|
||||
}
|
||||
return attr;
|
||||
}
|
||||
|
||||
constexpr char kFunctionalIfNodeName[] = "If";
|
||||
constexpr char kFunctionalCaseNodeName[] = "Case";
|
||||
constexpr char kFunctionalWhileNodeName[] = "While";
|
||||
constexpr char kCompilableFunctionName[] = "CompilableFn";
|
||||
constexpr char kCompilableFunctionNodeName[] = "n_c";
|
||||
@ -76,8 +85,12 @@ class CompilabilityCheckUtilTest : public ::testing::Test {
|
||||
op_filter_.allow_inaccurate_ops = false;
|
||||
op_filter_.allow_slow_ops = false;
|
||||
|
||||
checker_ = absl::make_unique<RecursiveCompilabilityChecker>(op_filter_,
|
||||
device_type_);
|
||||
checker_ = CreateCompilabilityChecker();
|
||||
}
|
||||
|
||||
std::unique_ptr<RecursiveCompilabilityChecker> CreateCompilabilityChecker() {
|
||||
return absl::make_unique<RecursiveCompilabilityChecker>(op_filter_,
|
||||
device_type_);
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime* GetFunctionLibraryRuntime() {
|
||||
@ -355,6 +368,57 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
|
||||
"unsupported op"));
|
||||
}
|
||||
|
||||
TEST_F(CompilabilityCheckUtilTest, CheckFunctionalCaseNode) {
|
||||
FunctionDefLibrary flib;
|
||||
*flib.add_function() = FunctionDefHelper::Define(
|
||||
/*Function*/ kUncompilableFunctionName,
|
||||
/*Inputs*/ {"n_a:float"},
|
||||
/*Outputs*/ {"n_c_uncompilable:float"},
|
||||
/*Attributes*/ {},
|
||||
// Node info
|
||||
{{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});
|
||||
*flib.add_function() = FunctionDefHelper::Define(
|
||||
/*Function*/ kUncompilableFunctionTwoName,
|
||||
/*Inputs*/ {"n_a:float"},
|
||||
/*Outputs*/ {"n_d_uncompilable:float"},
|
||||
/*Attribute*/ {},
|
||||
// Node info
|
||||
{{{kUncompilableFunctionNodeTwoName}, "MissingKernel", {"n_a"}}});
|
||||
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib));
|
||||
auto branch_index = ops::Placeholder(root.WithOpName("pred"), DT_INT32);
|
||||
auto placeholder = ops::Placeholder(root.WithOpName("A"), DT_INT32);
|
||||
std::vector<NodeBuilder::NodeOut> inputes(
|
||||
{NodeBuilder::NodeOut(placeholder.node())});
|
||||
Node* case_node;
|
||||
TF_ASSERT_OK(
|
||||
NodeBuilder(kFunctionalCaseNodeName, "Case", &root.graph()->flib_def())
|
||||
.Input(branch_index.node())
|
||||
.Input(inputes)
|
||||
.Attr("branches", FuncListAttr({kUncompilableFunctionName,
|
||||
kUncompilableFunctionTwoName}))
|
||||
.Attr("Tout", {DT_INT32})
|
||||
.Finalize(root.graph(), &case_node));
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
|
||||
|
||||
auto case_node_it = std::find_if(
|
||||
graph->nodes().begin(), graph->nodes().end(),
|
||||
[&](const Node* n) { return n->name() == kFunctionalCaseNodeName; });
|
||||
EXPECT_NE(case_node_it, graph->nodes().end());
|
||||
auto* flib_runtime = GetFunctionLibraryRuntime();
|
||||
|
||||
op_filter_.require_always_compilable = false;
|
||||
checker_ = CreateCompilabilityChecker();
|
||||
EXPECT_TRUE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
|
||||
op_filter_.require_always_compilable = true;
|
||||
checker_ = CreateCompilabilityChecker();
|
||||
EXPECT_FALSE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
|
||||
}
|
||||
|
||||
TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) {
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
@ -1196,10 +1196,14 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!RecursiveCompilabilityChecker{
|
||||
CreateOperationFilter(*registration),
|
||||
DeviceType{registration->compilation_device_name}}
|
||||
.IsCompilableNode(*node, lib_runtime)) {
|
||||
RecursiveCompilabilityChecker::OperationFilter filter =
|
||||
CreateOperationFilter(*registration);
|
||||
filter.require_always_compilable = true;
|
||||
|
||||
RecursiveCompilabilityChecker checker(
|
||||
filter, DeviceType{registration->compilation_device_name});
|
||||
|
||||
if (!checker.IsCompilableNode(*node, lib_runtime)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -303,8 +303,12 @@ Status XlaCompilationCache::CompileSingleOp(
|
||||
}
|
||||
|
||||
GraphDebugInfo debug_info;
|
||||
std::vector<std::string> control_rets;
|
||||
if (result_dtypes.empty()) {
|
||||
control_rets.push_back(node_def.name());
|
||||
}
|
||||
return CompileGraphToXlaHlo(
|
||||
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
|
||||
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), control_rets,
|
||||
options.device_type.type_string(), compile_options.use_tuple_arg,
|
||||
*options.flib_def, debug_info, options.shape_representation_fn, result);
|
||||
#endif
|
||||
|
@ -29,7 +29,7 @@ LLVM_SRC=...
|
||||
|
||||
# Create basic workspace file
|
||||
echo 'workspace(name = "llvm-project")' > $LLVM_SRC/WORKSPACE
|
||||
# and over the bazel BUILD files.
|
||||
# and copy over the bazel BUILD files.
|
||||
cp third_party/llvm/llvm.autogenerated.BUILD $LLVM_SRC/llvm/BUILD
|
||||
cp third_party/mlir/BUILD $LLVM_SRC/mlir
|
||||
cp third_party/mlir/test.BUILD $LLVM_SRC/mlir/test/BUILD
|
||||
|
@ -48,6 +48,7 @@ filegroup(
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
||||
|
@ -360,6 +360,19 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
|
||||
}];
|
||||
}
|
||||
|
||||
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
let summary = "Atan operator";
|
||||
|
||||
let description = [{
|
||||
Returns `Atan(operand)` element-wise.
|
||||
|
||||
$$
|
||||
\atan(x) = \atan2(x, 1)
|
||||
$$
|
||||
}];
|
||||
}
|
||||
|
||||
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
let summary = "Sinh operation";
|
||||
|
@ -353,7 +353,9 @@ def HLO_PowOp : HLO_BinaryElementwiseOp<"power",
|
||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp;
|
||||
|
||||
def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
|
||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp;
|
||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
|
||||
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp;
|
||||
@ -910,39 +912,12 @@ def HLO_CollectivePermuteOp: HLO_Op<"collective_permute",
|
||||
let results = (outs HLO_Tensor);
|
||||
}
|
||||
|
||||
// TODO(hinsu): Make this struct dialect independent so that it can be shared
|
||||
// between HLO and LHLO dialect.
|
||||
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [
|
||||
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
||||
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
||||
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
||||
|
||||
let description = "Structure of dimension information for conv op";
|
||||
}
|
||||
|
||||
def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
|
||||
let arguments = (ins
|
||||
HLO_Tensor:$lhs,
|
||||
HLO_Tensor:$rhs,
|
||||
// Default value: one for each of the spatial dimension.
|
||||
OptionalAttr<I64ElementsAttr>:$window_strides,
|
||||
// Default value: zero for each of the spatial dimension.
|
||||
OptionalAttr<I64ElementsAttr>:$padding,
|
||||
// Default value: one for each of the spatial dimension.
|
||||
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
|
||||
// Default value: one for each of the spatial dimension.
|
||||
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
||||
ConvDimensionNumbers:$dimension_numbers,
|
||||
I64Attr:$feature_group_count,
|
||||
I64Attr:$batch_group_count,
|
||||
HLO_PrecisionConfigAttr:$precision_config
|
||||
);
|
||||
let arguments = !con(
|
||||
(ins
|
||||
HLO_Tensor:$lhs,
|
||||
HLO_Tensor:$rhs),
|
||||
ConvolutionAttributes<HLO_Dialect>.attributes);
|
||||
|
||||
let results = (outs HLO_Tensor);
|
||||
}
|
||||
|
@ -1007,6 +1007,42 @@ class BASE_HLO_ConcatenateOp {
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Common convolution attributes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class ConvDimensionNumbersBase<Dialect dialect>
|
||||
: StructAttr<"ConvDimensionNumbers", dialect, [
|
||||
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
||||
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
||||
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
||||
|
||||
let description = "Structure of dimension information for conv op";
|
||||
}
|
||||
|
||||
class ConvolutionAttributes<Dialect dialect> {
|
||||
dag attributes = (ins
|
||||
// Default value: one for each of the spatial dimension.
|
||||
OptionalAttr<I64ElementsAttr>:$window_strides,
|
||||
// Default value: zero for each of the spatial dimension.
|
||||
OptionalAttr<I64ElementsAttr>:$padding,
|
||||
// Default value: one for each of the spatial dimension.
|
||||
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
|
||||
// Default value: one for each of the spatial dimension.
|
||||
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
||||
ConvDimensionNumbersBase<dialect>:$dimension_numbers,
|
||||
I64Attr:$feature_group_count,
|
||||
I64Attr:$batch_group_count,
|
||||
HLO_PrecisionConfigAttr:$precision_config
|
||||
);
|
||||
}
|
||||
|
||||
class BASE_HLO_ConvOp {
|
||||
string summary = "Convolution operator";
|
||||
|
||||
|
@ -37,38 +37,13 @@ include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/CopyOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
|
||||
|
||||
def LHLO_Dialect : Dialect {
|
||||
let name = "lmhlo";
|
||||
let cppNamespace = "::mlir::lmhlo";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LMHLO type definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Any integer tensor types
|
||||
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
|
||||
|
||||
// Any floating-point tensor types
|
||||
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
|
||||
|
||||
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
|
||||
|
||||
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
|
||||
|
||||
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
|
||||
|
||||
// Any integer or floating-point tensor types
|
||||
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
|
||||
|
||||
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
|
||||
|
||||
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
|
||||
|
||||
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LMHLO nullary op definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -345,10 +320,11 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
|
||||
def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
|
||||
let summary = [{
|
||||
"modifies the offset, sizes and strides of a statically shaped memref.
|
||||
modifies the offset, sizes and strides of a statically shaped memref
|
||||
}];
|
||||
let description = [{
|
||||
Allows to modify the offset, sizes and strides of a statically shaped memref.
|
||||
Casts the statically shaped memref operand to a memref with optionally
|
||||
modified offsets, sizes and strides.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
@ -592,40 +568,13 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
|
||||
);
|
||||
}
|
||||
|
||||
// TODO(bondhugula): Make this struct dialect independent so that it can be
|
||||
// shared between the HLO and LHLO dialects.
|
||||
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
|
||||
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
||||
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
||||
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
||||
|
||||
let description = "Structure of dimension information for conv op";
|
||||
}
|
||||
|
||||
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
// Default value: one for each of the spatial dimension.
|
||||
OptionalAttr<I64ElementsAttr>:$window_strides,
|
||||
// Default value: zero for each of the spatial dimension.
|
||||
OptionalAttr<I64ElementsAttr>:$padding,
|
||||
// Default value: one for each of the spatial dimension.
|
||||
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
|
||||
// Default value: one for each of the spatial dimension.
|
||||
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
||||
ConvDimensionNumbers:$dimension_numbers,
|
||||
I64Attr:$feature_group_count,
|
||||
I64Attr:$batch_group_count,
|
||||
HLO_PrecisionConfigAttr:$precision_config
|
||||
);
|
||||
let arguments = !con(
|
||||
(ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output),
|
||||
ConvolutionAttributes<LHLO_Dialect>.attributes);
|
||||
}
|
||||
|
||||
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {
|
||||
|
@ -0,0 +1,47 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef LHLO_OPS_BASE
|
||||
#define LHLO_OPS_BASE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LMHLO type definitions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Any integer tensor types
|
||||
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
|
||||
|
||||
// Any floating-point tensor types
|
||||
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
|
||||
|
||||
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
|
||||
|
||||
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
|
||||
|
||||
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
|
||||
|
||||
// Any integer or floating-point tensor types
|
||||
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
|
||||
|
||||
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
|
||||
|
||||
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
|
||||
|
||||
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
|
||||
|
||||
#endif // LHLO_OPS_BASE
|
@ -149,6 +149,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::Atan2Op>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <typename PredicateType>
|
||||
inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
|
||||
return llvm::None;
|
||||
|
@ -15,9 +15,9 @@ limitations under the License.
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def TestChloLegalizeToHloPass : Pass<"mhlo-test-chlo-legalize-to-hlo", "FuncOp"> {
|
||||
let summary = "Test pass for applying chlo -> hlo legalization patterns.";
|
||||
let constructor = "createTestChloLegalizeToHloPass()";
|
||||
def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "FuncOp"> {
|
||||
let summary = "Legalize CHLO to HLO.";
|
||||
let constructor = "createChloLegalizeToHloPass()";
|
||||
}
|
||||
|
||||
def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
|
||||
|
@ -44,6 +44,9 @@ std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass();
|
||||
/// Lowers from HLO dialect to Standard dialect.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
|
||||
|
||||
/// Lowers from the CHLO dialect to the HLO dialect.
|
||||
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass();
|
||||
|
||||
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
|
||||
/// buffers if necessary. If `results_escape_functions` is set to true,
|
||||
/// allocated buffers for function results will be returned and escape the
|
||||
@ -63,7 +66,7 @@ std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
|
||||
std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
|
||||
|
||||
/// Lowers trigonometric operations from the standard dialect to approximations
|
||||
// that do not use intrinsics.
|
||||
/// that do not use intrinsics.
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
createLegalizeTrigonometricToApproximationPass();
|
||||
|
||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
|
||||
std::unique_ptr<Pass> createTestChloLegalizeToHloPass();
|
||||
std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass();
|
||||
std::unique_ptr<Pass> createTestMaterializeBroadcastsPass();
|
||||
std::unique_ptr<Pass> createTestUnfuseBatchNormPass();
|
||||
|
@ -2001,6 +2001,23 @@ struct divide<APInt> {
|
||||
APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct remainder : std::modulus<T> {};
|
||||
|
||||
template <>
|
||||
struct remainder<APInt> {
|
||||
APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct remainder<APFloat> {
|
||||
APFloat operator()(const APFloat& a, const APFloat& b) const {
|
||||
APFloat result(a);
|
||||
result.remainder(b);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct max {
|
||||
T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
|
||||
@ -2042,6 +2059,7 @@ BINARY_FOLDER(AddOp, std::plus);
|
||||
BINARY_FOLDER(SubOp, std::minus);
|
||||
BINARY_FOLDER(MulOp, std::multiplies);
|
||||
BINARY_FOLDER(DivOp, divide);
|
||||
BINARY_FOLDER(RemOp, remainder);
|
||||
BINARY_FOLDER(MaxOp, max);
|
||||
BINARY_FOLDER(MinOp, min);
|
||||
|
||||
|
@ -27,8 +27,8 @@ namespace mhlo {
|
||||
|
||||
namespace {
|
||||
|
||||
struct TestChloLegalizeToHloPass
|
||||
: public PassWrapper<TestChloLegalizeToHloPass, FunctionPass> {
|
||||
struct ChloLegalizeToHloPass
|
||||
: public PassWrapper<ChloLegalizeToHloPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<mhlo::MhloDialect, shape::ShapeDialect, scf::SCFDialect>();
|
||||
}
|
||||
@ -36,11 +36,12 @@ struct TestChloLegalizeToHloPass
|
||||
void runOnFunction() override {
|
||||
ConversionTarget conversionTarget(getContext());
|
||||
OwningRewritePatternList conversionPatterns;
|
||||
|
||||
conversionTarget.addIllegalDialect<chlo::HloClientDialect>();
|
||||
|
||||
// Consider the mhlo dialect legal for tests.
|
||||
conversionTarget.addLegalDialect<mhlo::MhloDialect>();
|
||||
// The conversion uses helpers from the Standard dialect.
|
||||
|
||||
// The conversion uses helpers from the standard dialect.
|
||||
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
|
||||
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
|
||||
conversionTarget.addLegalDialect<mlir::scf::SCFDialect>();
|
||||
@ -56,8 +57,8 @@ struct TestChloLegalizeToHloPass
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<FunctionPass> createTestChloLegalizeToHloPass() {
|
||||
return std::make_unique<TestChloLegalizeToHloPass>();
|
||||
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass() {
|
||||
return std::make_unique<ChloLegalizeToHloPass>();
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
||||
|
@ -24,16 +24,17 @@ include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td"
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Expand acos to MHLO dialect as follows:
|
||||
// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1
|
||||
// acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1
|
||||
// = pi if x == -1
|
||||
def : Pat<(HLOClient_AcosOp $input),
|
||||
(HLO_SelectOp
|
||||
(HLO_CompareOp $input,
|
||||
(HLO_ConstantLike<"0"> $input),
|
||||
(HLO_CompareOp
|
||||
$input,
|
||||
(HLO_ConstantLike<"-1"> $input),
|
||||
HLO_COMPARISON_DIRECTION_NE
|
||||
),
|
||||
(HLO_MulOp
|
||||
(HLO_ConstantLike<"2.0f"> $input),
|
||||
(HLO_ConstantLike<"2"> $input),
|
||||
(HLO_Atan2Op
|
||||
(HLO_SqrtOp
|
||||
(HLO_SubOp
|
||||
@ -47,7 +48,16 @@ def : Pat<(HLOClient_AcosOp $input),
|
||||
)
|
||||
)
|
||||
),
|
||||
(HLO_ConstantLike<"M_PI"> $input))>;
|
||||
(HLO_ConstantLike<"M_PI"> $input)
|
||||
)>;
|
||||
|
||||
// Express `atan` as
|
||||
// atan(x) = atan2(x, 1)
|
||||
def : Pat<(HLOClient_AtanOp $input),
|
||||
(HLO_Atan2Op
|
||||
$input,
|
||||
(HLO_ConstantLike<"1"> $input)
|
||||
)>;
|
||||
|
||||
// Express `sinh` as
|
||||
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
|
||||
@ -95,4 +105,3 @@ def : Pat<(HLOClient_TanOp $input),
|
||||
(HLO_SinOp $input),
|
||||
(HLO_CosOp $input)
|
||||
)>;
|
||||
|
||||
|
@ -45,7 +45,7 @@ using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
|
||||
Value InsertDynamicAllocAndDealloc(Location loc, Value result,
|
||||
Value shape_operand,
|
||||
ConversionPatternRewriter* rewriter) {
|
||||
auto result_type = result.getType().dyn_cast<ShapedType>();
|
||||
auto result_type = result.getType().dyn_cast<RankedTensorType>();
|
||||
if (!result_type) {
|
||||
result.getDefiningOp()->emitOpError()
|
||||
<< "tensor to buffer conversion expects ranked results";
|
||||
@ -53,17 +53,13 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
|
||||
auto memref_type =
|
||||
MemRefType::get(result_type.getShape(), result_type.getElementType());
|
||||
|
||||
Operation* op = result.getDefiningOp();
|
||||
|
||||
// Extract the required element out of the vector.
|
||||
SmallVector<Value, 4> dynamic_operands;
|
||||
for (auto shape_element : llvm::enumerate(result_type.getShape())) {
|
||||
if (shape_element.value() != ShapedType::kDynamicSize) continue;
|
||||
Value index = rewriter->create<ConstantOp>(
|
||||
loc, rewriter->getIntegerAttr(rewriter->getIndexType(),
|
||||
shape_element.index()));
|
||||
Value alloc_operand = rewriter->create<ExtractElementOp>(loc, shape_operand,
|
||||
ValueRange{index});
|
||||
Value index = rewriter->create<ConstantIndexOp>(loc, shape_element.index());
|
||||
Value alloc_operand =
|
||||
rewriter->create<ExtractElementOp>(loc, shape_operand, index);
|
||||
if (!alloc_operand.getType().isIndex()) {
|
||||
alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
|
||||
rewriter->getIndexType());
|
||||
@ -71,15 +67,12 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
|
||||
dynamic_operands.push_back(alloc_operand);
|
||||
}
|
||||
|
||||
// Insert in front of op to ensure sizes are available.
|
||||
OpBuilder allocBuilder(op);
|
||||
auto alloc = allocBuilder.create<AllocOp>(loc, memref_type, dynamic_operands);
|
||||
return alloc;
|
||||
return rewriter->create<AllocOp>(loc, memref_type, dynamic_operands);
|
||||
}
|
||||
|
||||
Value InsertAlloc(Location loc, OpResult result,
|
||||
ConversionPatternRewriter* rewriter) {
|
||||
auto result_type = result.getType().dyn_cast<ShapedType>();
|
||||
auto result_type = result.getType().dyn_cast<RankedTensorType>();
|
||||
if (!result_type || !result_type.hasStaticShape()) {
|
||||
result.getDefiningOp()->emitOpError()
|
||||
<< "tensor to buffer conversion expects statically shaped results";
|
||||
@ -112,19 +105,21 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
|
||||
buffer_args.push_back(
|
||||
InsertAlloc(op->getLoc(), result.value(), &rewriter));
|
||||
} else {
|
||||
SmallVector<Value, 1> results_shape;
|
||||
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||
if (!shape_type_op) return failure();
|
||||
if (failed(
|
||||
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
|
||||
return failure();
|
||||
|
||||
SmallVector<Value, 1> results_shape;
|
||||
auto status =
|
||||
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
|
||||
if (failed(status)) return failure();
|
||||
buffer_args.push_back(InsertDynamicAllocAndDealloc(
|
||||
op->getLoc(), result.value(), results_shape.front(), &rewriter));
|
||||
}
|
||||
}
|
||||
rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
|
||||
buffer_args, op->getAttrs());
|
||||
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
|
||||
rewriter.replaceOp(
|
||||
op, llvm::makeArrayRef(buffer_args).drop_front(operands.size()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -32,8 +32,6 @@ limitations under the License.
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
using mlir::PassRegistration;
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
namespace {
|
||||
|
@ -822,6 +822,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
PointwiseToLinalgConverter<lmhlo::AbsOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::AddOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::AndOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::Atan2Op>,
|
||||
PointwiseToLinalgConverter<lmhlo::CeilOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::CompareOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::ComplexOp>,
|
||||
@ -932,6 +933,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::AddOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::Atan2Op, false>,
|
||||
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
|
||||
|
@ -235,6 +235,23 @@ class ApproximateAtan2Lowering
|
||||
}
|
||||
};
|
||||
|
||||
class ApproximateAtanLowering
|
||||
: public ApproximateOnExtendedF32Lowering<AtanOp> {
|
||||
public:
|
||||
explicit ApproximateAtanLowering(MLIRContext *ctx)
|
||||
: ApproximateOnExtendedF32Lowering<AtanOp>(ctx) {}
|
||||
|
||||
// Reduce atan(x) to atan2(x, 1) to subsequently rely on an atan approximation
|
||||
// for the argument range [-1, 1].
|
||||
Value emitApproximation(ValueRange args, Location loc,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value x = args.front();
|
||||
assert(x.getType().isF32());
|
||||
Value one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1));
|
||||
return rewriter.create<Atan2Op>(loc, x, one);
|
||||
}
|
||||
};
|
||||
|
||||
struct LegalizeTrigonometricToApproximationPass
|
||||
: public PassWrapper<LegalizeTrigonometricToApproximationPass,
|
||||
FunctionPass> {
|
||||
@ -257,6 +274,7 @@ void PopulateTrigonometricToApproximationPatterns(
|
||||
mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
|
||||
// clang-format off
|
||||
patterns->insert<
|
||||
ApproximateAtanLowering,
|
||||
ApproximateAtan2Lowering,
|
||||
ApproximateTanhLowering>(context);
|
||||
// clang-format on
|
||||
|
@ -37,7 +37,6 @@ limitations under the License.
|
||||
|
||||
using mlir::FunctionPass;
|
||||
using mlir::OwningRewritePatternList;
|
||||
using mlir::PassRegistration;
|
||||
using mlir::PassWrapper;
|
||||
|
||||
namespace {
|
||||
|
@ -38,7 +38,6 @@ using mlir::LogicalResult;
|
||||
using mlir::MLIRContext;
|
||||
using mlir::OpRewritePattern;
|
||||
using mlir::OwningRewritePatternList;
|
||||
using mlir::PassRegistration;
|
||||
using mlir::PassWrapper;
|
||||
using mlir::PatternRewriter;
|
||||
using mlir::RankedTensorType;
|
||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
using mlir::FunctionPass;
|
||||
using mlir::PassRegistration;
|
||||
using mlir::PassWrapper;
|
||||
|
||||
namespace {
|
||||
|
@ -48,7 +48,7 @@ namespace {
|
||||
|
||||
// TODO(herhut): Generate these out of op definitions.
|
||||
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
||||
fn(TanOp) sep fn(AcosOp) sep fn(SinhOp)
|
||||
fn(AcosOp) sep fn(AtanOp) sep fn(SinhOp) sep fn(TanOp)
|
||||
|
||||
template <typename OpTy>
|
||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||
|
@ -63,6 +63,24 @@ func @divide_fold_float() -> tensor<4xf64> {
|
||||
return %2 : tensor<4xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: remainder_fold_int
|
||||
func @remainder_fold_int() -> tensor<4xi32> {
|
||||
%0 = mhlo.constant dense<[5, 66, 5, 1]> : tensor<4xi32>
|
||||
%1 = mhlo.constant dense<[3, 5, 1, 2]> : tensor<4xi32>
|
||||
// CHECK: mhlo.constant dense<[2, 1, 0, 1]>
|
||||
%2 = "mhlo.remainder"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>)
|
||||
return %2 : tensor<4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: remainder_fold_float
|
||||
func @remainder_fold_float() -> tensor<4xf32> {
|
||||
%0 = mhlo.constant dense<[7.0, 66.5, 5.0, 3.1]> : tensor<4xf32>
|
||||
%1 = mhlo.constant dense<[3.0, 5.0, 1.0, 2.6]> : tensor<4xf32>
|
||||
// CHECK: mhlo.constant dense<[1.000000e+00, 1.500000e+00, 0.000000e+00, 5.000000e-01]>
|
||||
%2 = "mhlo.remainder"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
|
||||
return %2 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: max_scalar_fold
|
||||
func @max_scalar_fold() -> tensor<4xi64> {
|
||||
%0 = mhlo.constant dense<7> : tensor<4xi64>
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-hlo-opt -mhlo-test-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
|
||||
// RUN: mlir-hlo-opt -chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
|
||||
|
||||
// Check the non-broadcast case for each registered op, then just check a
|
||||
// representative op for detailed broadcast semantics.
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-hlo-opt --mhlo-test-chlo-legalize-to-hlo --split-input-file %s | FileCheck %s
|
||||
// RUN: mlir-hlo-opt --chlo-legalize-to-hlo --split-input-file %s | FileCheck %s
|
||||
|
||||
// Lower statically shaped `constant_like` to constant.
|
||||
// CHECK-LABEL: @constant_like_static_shape
|
||||
|
@ -261,3 +261,120 @@ func @atan2_f16(%arg0 : f16, %arg1 : f16) -> f16 {
|
||||
%res = atan2 %arg0, %arg1 : f16
|
||||
return %res : f16
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @atan_f64
|
||||
func @atan_f64(%arg : f64) -> f64 {
|
||||
// CHECK: atan
|
||||
%res = atan %arg : f64
|
||||
return %res : f64
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @atan_f32
|
||||
// CHECK-SAME: (%[[ARG:.*]]: f32) -> f32
|
||||
func @atan_f32(%arg : f32) -> f32 {
|
||||
// CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32
|
||||
// CHECK: %[[CST_0:.*]] = constant 0.0027856871 : f32
|
||||
// CHECK: %[[CST_1:.*]] = constant -1.586600e-02 : f32
|
||||
// CHECK: %[[CST_2:.*]] = constant 0.042472221 : f32
|
||||
// CHECK: %[[CST_3:.*]] = constant -0.0749753043 : f32
|
||||
// CHECK: %[[CST_4:.*]] = constant 0.106448799 : f32
|
||||
// CHECK: %[[CST_5:.*]] = constant -0.142070308 : f32
|
||||
// CHECK: %[[CST_6:.*]] = constant 0.199934542 : f32
|
||||
// CHECK: %[[CST_7:.*]] = constant -0.333331466 : f32
|
||||
// CHECK: %[[CST_8:.*]] = constant 1.57079637 : f32
|
||||
// CHECK: %[[CST_9:.*]] = constant 0.000000e+00 : f32
|
||||
// CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32
|
||||
// CHECK: %[[VAL_0:.*]] = absf %[[CST]] : f32
|
||||
// CHECK: %[[VAL_1:.*]] = absf %arg0 : f32
|
||||
// CHECK: %[[VAL_2:.*]] = cmpf "ole", %[[VAL_0]], %[[VAL_1]] : f32
|
||||
// CHECK: %[[VAL_3:.*]] = select %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : f32
|
||||
// CHECK: %[[VAL_4:.*]] = select %[[VAL_2]], %[[VAL_1]], %[[VAL_0]] : f32
|
||||
// CHECK: %[[VAL_5:.*]] = divf %[[VAL_3]], %[[VAL_4]] : f32
|
||||
// CHECK: %[[VAL_6:.*]] = mulf %[[VAL_5]], %[[VAL_5]] : f32
|
||||
// CHECK: %[[VAL_7:.*]] = mulf %[[CST_0]], %[[VAL_6]] : f32
|
||||
// CHECK: %[[VAL_8:.*]] = addf %[[VAL_7]], %[[CST_1]] : f32
|
||||
// CHECK: %[[VAL_9:.*]] = mulf %[[VAL_8]], %[[VAL_6]] : f32
|
||||
// CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_2]] : f32
|
||||
// CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_6]] : f32
|
||||
// CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_3]] : f32
|
||||
// CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_6]] : f32
|
||||
// CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_4]] : f32
|
||||
// CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_6]] : f32
|
||||
// CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_5]] : f32
|
||||
// CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_6]] : f32
|
||||
// CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_6]] : f32
|
||||
// CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_6]] : f32
|
||||
// CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_7]] : f32
|
||||
// CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_6]] : f32
|
||||
// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_5]] : f32
|
||||
// CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32
|
||||
// CHECK: %[[VAL_24:.*]] = subf %[[CST_8]], %[[VAL_23]] : f32
|
||||
// CHECK: %[[VAL_25:.*]] = select %[[VAL_2]], %[[VAL_24]], %[[VAL_23]] : f32
|
||||
// CHECK: %[[VAL_26:.*]] = cmpf "oeq", %arg0, %[[CST_9]] : f32
|
||||
// CHECK: %[[VAL_27:.*]] = select %[[VAL_26]], %[[CST_9]], %[[VAL_25]] : f32
|
||||
// CHECK: %[[VAL_28:.*]] = cmpf "uno", %arg0, %[[CST]] : f32
|
||||
// CHECK: %[[VAL_29:.*]] = select %[[VAL_28]], %[[CST_10]], %[[VAL_27]] : f32
|
||||
// CHECK: %[[VAL_30:.*]] = copysign %[[VAL_29]], %arg0 : f32
|
||||
// CHECK: return %[[VAL_30]] : f32
|
||||
%res = atan %arg : f32
|
||||
return %res : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @atan_f16
|
||||
// CHECK-SAME: (%[[ARG:.*]]: f16) -> f16
|
||||
func @atan_f16(%arg : f16) -> f16 {
|
||||
// CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32
|
||||
// CHECK: %[[CST_0:.*]] = constant 0.0027856871 : f32
|
||||
// CHECK: %[[CST_1:.*]] = constant -1.586600e-02 : f32
|
||||
// CHECK: %[[CST_2:.*]] = constant 0.042472221 : f32
|
||||
// CHECK: %[[CST_3:.*]] = constant -0.0749753043 : f32
|
||||
// CHECK: %[[CST_4:.*]] = constant 0.106448799 : f32
|
||||
// CHECK: %[[CST_5:.*]] = constant -0.142070308 : f32
|
||||
// CHECK: %[[CST_6:.*]] = constant 0.199934542 : f32
|
||||
// CHECK: %[[CST_7:.*]] = constant -0.333331466 : f32
|
||||
// CHECK: %[[CST_8:.*]] = constant 1.57079637 : f32
|
||||
// CHECK: %[[CST_9:.*]] = constant 0.000000e+00 : f32
|
||||
// CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32
|
||||
// CHECK: %[[VAL_0:.*]] = fpext %arg0 : f16 to f32
|
||||
// CHECK: %[[VAL_1:.*]] = absf %[[CST]] : f32
|
||||
// CHECK: %[[VAL_2:.*]] = absf %[[VAL_0]] : f32
|
||||
// CHECK: %[[VAL_3:.*]] = cmpf "ole", %[[VAL_1]], %[[VAL_2]] : f32
|
||||
// CHECK: %[[VAL_4:.*]] = select %[[VAL_3]], %[[VAL_1]], %[[VAL_2]] : f32
|
||||
// CHECK: %[[VAL_5:.*]] = select %[[VAL_3]], %[[VAL_2]], %[[VAL_1]] : f32
|
||||
// CHECK: %[[VAL_6:.*]] = divf %[[VAL_4]], %[[VAL_5]] : f32
|
||||
// CHECK: %[[VAL_7:.*]] = mulf %[[VAL_6]], %[[VAL_6]] : f32
|
||||
// CHECK: %[[VAL_8:.*]] = mulf %[[CST_0]], %[[VAL_7]] : f32
|
||||
// CHECK: %[[VAL_9:.*]] = addf %[[VAL_8]], %[[CST_1]] : f32
|
||||
// CHECK: %[[VAL_10:.*]] = mulf %[[VAL_9]], %[[VAL_7]] : f32
|
||||
// CHECK: %[[VAL_11:.*]] = addf %[[VAL_10]], %[[CST_2]] : f32
|
||||
// CHECK: %[[VAL_12:.*]] = mulf %[[VAL_11]], %[[VAL_7]] : f32
|
||||
// CHECK: %[[VAL_13:.*]] = addf %[[VAL_12]], %[[CST_3]] : f32
|
||||
// CHECK: %[[VAL_14:.*]] = mulf %[[VAL_13]], %[[VAL_7]] : f32
|
||||
// CHECK: %[[VAL_15:.*]] = addf %[[VAL_14]], %[[CST_4]] : f32
|
||||
// CHECK: %[[VAL_16:.*]] = mulf %[[VAL_15]], %[[VAL_7]] : f32
|
||||
// CHECK: %[[VAL_17:.*]] = addf %[[VAL_16]], %[[CST_5]] : f32
|
||||
// CHECK: %[[VAL_18:.*]] = mulf %[[VAL_17]], %[[VAL_7]] : f32
|
||||
// CHECK: %[[VAL_19:.*]] = addf %[[VAL_18]], %[[CST_6]] : f32
|
||||
// CHECK: %[[VAL_20:.*]] = mulf %[[VAL_19]], %[[VAL_7]] : f32
|
||||
// CHECK: %[[VAL_21:.*]] = addf %[[VAL_20]], %[[CST_7]] : f32
|
||||
// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_7]] : f32
|
||||
// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_6]] : f32
|
||||
// CHECK: %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_6]] : f32
|
||||
// CHECK: %[[VAL_25:.*]] = subf %[[CST_8]], %[[VAL_24]] : f32
|
||||
// CHECK: %[[VAL_26:.*]] = select %[[VAL_3]], %[[VAL_25]], %[[VAL_24]] : f32
|
||||
// CHECK: %[[VAL_27:.*]] = cmpf "oeq", %[[VAL_0]], %[[CST_9]] : f32
|
||||
// CHECK: %[[VAL_28:.*]] = select %[[VAL_27]], %[[CST_9]], %[[VAL_26]] : f32
|
||||
// CHECK: %[[VAL_29:.*]] = cmpf "uno", %[[VAL_0]], %[[CST]] : f32
|
||||
// CHECK: %[[VAL_30:.*]] = select %[[VAL_29]], %[[CST_10]], %[[VAL_28]] : f32
|
||||
// CHECK: %[[VAL_31:.*]] = copysign %[[VAL_30]], %[[VAL_0]] : f32
|
||||
// CHECK: %[[VAL_32:.*]] = fptrunc %[[VAL_31]] : f32 to f16
|
||||
// CHECK: return %[[VAL_32]] : f16
|
||||
%res = atan %arg : f16
|
||||
return %res : f16
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s
|
||||
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @add
|
||||
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
|
||||
|
@ -477,7 +477,6 @@ cc_library(
|
||||
"transforms/default_quant_params.cc",
|
||||
"transforms/generated_post_quantize.inc",
|
||||
"transforms/generated_quantize.inc",
|
||||
"transforms/load_quantization_recipe.cc",
|
||||
"transforms/post_quantize.cc",
|
||||
"transforms/prepare_quantize.cc",
|
||||
"transforms/quantize.cc",
|
||||
@ -670,6 +669,7 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/flex:allowlisted_flex_ops_lib",
|
||||
"//tensorflow/lite/kernels/internal:kernel_utils",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/schema:schema_utils",
|
||||
"//tensorflow/lite/tools/versioning",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
@ -706,6 +706,7 @@ cc_library(
|
||||
"//tensorflow/core/platform:status",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/schema:schema_utils",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
|
@ -75,6 +75,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/schema/schema_utils.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
#include "tensorflow/lite/tools/versioning/op_version.h"
|
||||
#include "tensorflow/lite/tools/versioning/runtime_version.h"
|
||||
|
@ -75,6 +75,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/schema/schema_utils.h"
|
||||
|
||||
using llvm::ArrayRef;
|
||||
using mlir::Builder;
|
||||
@ -271,18 +272,18 @@ StatusOr<std::string> GetMlirOpName(const tflite::OperatorT& op,
|
||||
return std::string("tfl.basic_lstm");
|
||||
}
|
||||
|
||||
if (op_code.builtin_code == tflite::BuiltinOperator_CUSTOM) {
|
||||
auto builtin_code = tflite::GetBuiltinCode(&op_code);
|
||||
if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
|
||||
return std::string("tfl.custom");
|
||||
}
|
||||
if (op_code.builtin_code == tflite::BuiltinOperator_IF) {
|
||||
if (builtin_code == tflite::BuiltinOperator_IF) {
|
||||
return std::string("tf.If");
|
||||
}
|
||||
if (op_code.builtin_code == tflite::BuiltinOperator_WHILE) {
|
||||
if (builtin_code == tflite::BuiltinOperator_WHILE) {
|
||||
return std::string("tf.While");
|
||||
}
|
||||
|
||||
llvm::StringRef op_name(
|
||||
tflite::EnumNameBuiltinOperator(op_code.builtin_code));
|
||||
llvm::StringRef op_name(tflite::EnumNameBuiltinOperator(builtin_code));
|
||||
return llvm::Twine("tfl.", op_name.lower()).str();
|
||||
}
|
||||
|
||||
@ -637,7 +638,8 @@ StatusOr<Operation*> ConvertOp(
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
|
||||
if (op_code.builtin_code == tflite::BuiltinOperator_CUSTOM) {
|
||||
auto builtin_code = tflite::GetBuiltinCode(&op_code);
|
||||
if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
|
||||
auto status = mlir::CustomOptionsToAttributes(
|
||||
op_code.custom_code, op.custom_options, builder, loc, &attrs);
|
||||
if (!status.ok()) {
|
||||
|
@ -459,11 +459,13 @@ node {
|
||||
# CHECK-LABEL: {
|
||||
# CHECK: version: 3,
|
||||
# CHECK: operator_codes: [ {
|
||||
# CHECK: builtin_code: CONV_2D,
|
||||
# CHECK: version: 3
|
||||
# CHECK: deprecated_builtin_code: 3,
|
||||
# CHECK: version: 3,
|
||||
# CHECK: builtin_code: CONV_2D
|
||||
# CHECK: }, {
|
||||
# CHECK: builtin_code: RESHAPE,
|
||||
# CHECK: deprecated_builtin_code: 22,
|
||||
# CHECK: version: 1
|
||||
# CHECK: builtin_code: RESHAPE
|
||||
# CHECK: } ],
|
||||
# CHECK: subgraphs: [ {
|
||||
# CHECK: tensors: [ {
|
||||
|
@ -4,8 +4,9 @@ func @main(tensor<1x384xf32>, tensor<1x96xf32>, tensor<384x480xf32>, tensor<384x
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: LSTM,
|
||||
// CHECK-NEXT: deprecated_builtin_code: 16,
|
||||
// CHECK-NEXT: version: 2
|
||||
// CHECK-NEXT: builtin_code: LSTM
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -6,14 +6,17 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: MUL,
|
||||
// CHECK-NEXT: deprecated_builtin_code: 18,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: builtin_code: MUL
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||
// CHECK-NEXT: custom_code: "MyCustomOp"
|
||||
// CHECK-NEXT: deprecated_builtin_code: 32,
|
||||
// CHECK-NEXT: custom_code: "MyCustomOp",
|
||||
// CHECK-NEXT: builtin_code: CUSTOM
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: EXP,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 47,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: EXP
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -5,11 +5,13 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: DEQUANTIZE,
|
||||
// CHECK-NEXT: deprecated_builtin_code: 6,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: builtin_code: DEQUANTIZE
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D,
|
||||
// CHECK-NEXT: deprecated_builtin_code: 4,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -5,11 +5,13 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: DEQUANTIZE,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 6,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: DEQUANTIZE
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D,
|
||||
// CHECK-NEXT: version: 2
|
||||
// CHECK-NEXT: deprecated_builtin_code: 4,
|
||||
// CHECK-NEXT: version: 2,
|
||||
// CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -5,11 +5,13 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: MUL,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 18,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: MUL
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: EXP,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 47,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: EXP
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -6,8 +6,9 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: FAKE_QUANT,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 80,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: FAKE_QUANT
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -4,8 +4,9 @@ func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||
// CHECK-NEXT: deprecated_builtin_code: 32,
|
||||
// CHECK-NEXT: custom_code: "FlexAddV2"
|
||||
// CHECK-NEXT: builtin_code: CUSTOM
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -5,8 +5,9 @@ func @main(tensor<4xcomplex<f64>>, tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||
// CHECK-NEXT: custom_code: "FlexAdd"
|
||||
// CHECK-NEXT: deprecated_builtin_code: 32,
|
||||
// CHECK-NEXT: custom_code: "FlexAdd",
|
||||
// CHECK-NEXT: builtin_code: CUSTOM
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -5,8 +5,9 @@ func @main(tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||
// CHECK-NEXT: custom_code: "FlexAdd"
|
||||
// CHECK-NEXT: deprecated_builtin_code: 32,
|
||||
// CHECK-NEXT: custom_code: "FlexAdd",
|
||||
// CHECK-NEXT: builtin_code: CUSTOM
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -5,14 +5,17 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: deprecated_builtin_code: 18,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: MUL
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||
// CHECK-NEXT: custom_code: "FlexDiv"
|
||||
// CHECK-NEXT: deprecated_builtin_code: 32,
|
||||
// CHECK-NEXT: custom_code: "FlexDiv",
|
||||
// CHECK-NEXT: builtin_code: CUSTOM
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: deprecated_builtin_code: 47,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: EXP
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -5,8 +5,9 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: FULLY_CONNECTED,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 9,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: FULLY_CONNECTED
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -5,8 +5,9 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: FULLY_CONNECTED,
|
||||
// CHECK-NEXT: version: 2
|
||||
// CHECK-NEXT: deprecated_builtin_code: 9,
|
||||
// CHECK-NEXT: version: 2,
|
||||
// CHECK-NEXT: builtin_code: FULLY_CONNECTED
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -3,8 +3,9 @@
|
||||
// CHECK: {
|
||||
// CHECK: version: 3,
|
||||
// CHECK: operator_codes: [ {
|
||||
// CHECK: builtin_code: CUSTOM,
|
||||
// CHECK: custom_code: "HashTableV2"
|
||||
// CHECK: deprecated_builtin_code: 32,
|
||||
// CHECK: custom_code: "HashTableV2",
|
||||
// CHECK: builtin_code: CUSTOM
|
||||
// CHECK: } ],
|
||||
// CHECK: subgraphs: [ {
|
||||
// CHECK: tensors: [ {
|
||||
|
@ -4,16 +4,19 @@
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: LESS,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 58,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: LESS
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: IF,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 118,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: IF
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: MUL,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 18,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: MUL
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -5,11 +5,13 @@ func @main(tensor<4xi1>) -> tensor<4xi1> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: LOGICAL_OR,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 84,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: LOGICAL_OR
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: LOGICAL_AND,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 86,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: LOGICAL_AND
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -4,8 +4,9 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: LSTM,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 16,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: LSTM
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -7,8 +7,9 @@ func @main(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: LSTM,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 16,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: LSTM
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -5,20 +5,25 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: SQUARED_DIFFERENCE,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 99,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: SQUARED_DIFFERENCE
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: MUL,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 18,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: MUL
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: DIV,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 42,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: DIV
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: EXP,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 47,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: EXP
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: NEG,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 59,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: NEG
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -5,8 +5,9 @@ func @main(tensor<3x!quant.uniform<i8:f32, 0.1>>) -> tensor<3x!quant.uniform<i8:
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: MUL,
|
||||
// CHECK-NEXT: version: 2
|
||||
// CHECK-NEXT: deprecated_builtin_code: 18,
|
||||
// CHECK-NEXT: version: 2,
|
||||
// CHECK-NEXT: builtin_code: MUL
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -5,8 +5,9 @@ func @main(tensor<3x!quant.uniform<i8:f32, 1.0>>) -> tensor<3x!quant.uniform<i8:
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: MUL,
|
||||
// CHECK-NEXT: version: 3
|
||||
// CHECK-NEXT: deprecated_builtin_code: 18,
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: builtin_code: MUL
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -5,8 +5,9 @@ func @main(tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: AVERAGE_POOL_2D,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 1,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: AVERAGE_POOL_2D
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -3,8 +3,9 @@
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: CUSTOM,
|
||||
// CHECK-NEXT: custom_code: "NumericVerify"
|
||||
// CHECK-NEXT: deprecated_builtin_code: 32,
|
||||
// CHECK-NEXT: custom_code: "NumericVerify",
|
||||
// CHECK-NEXT: builtin_code: CUSTOM
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -4,20 +4,25 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: QUANTIZE,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 114,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: QUANTIZE
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: CONV_2D,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 3,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: CONV_2D
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: RESHAPE,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 22,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: RESHAPE
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: SOFTMAX,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 25,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: SOFTMAX
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: DEQUANTIZE,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 6,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: DEQUANTIZE
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -5,8 +5,9 @@ func @main(tensor<3x2xi32>) -> tensor<6xi32> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: RESHAPE,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 22,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: RESHAPE
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -7,8 +7,9 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: SUB,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 41,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: SUB
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: } ],
|
||||
|
@ -4,8 +4,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: SVDF,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 27,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: SVDF
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -4,8 +4,9 @@ func @main(tensor<4 x f32>, tensor<4 x i8>, tensor<4 x f32>, tensor<4 x f32>) ->
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: SVDF,
|
||||
// CHECK-NEXT: version: 2
|
||||
// CHECK-NEXT: deprecated_builtin_code: 27,
|
||||
// CHECK-NEXT: version: 2,
|
||||
// CHECK-NEXT: builtin_code: SVDF
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -3,14 +3,17 @@
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: WHILE,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 119,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: WHILE
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: GREATER,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 61,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: GREATER
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: SUB,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 41,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: SUB
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: } ],
|
||||
|
@ -4,8 +4,9 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: TRANSPOSE_CONV,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 67,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: TRANSPOSE_CONV
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -3,8 +3,9 @@
|
||||
// CHECK: {
|
||||
// CHECK: version: 3,
|
||||
// CHECK: operator_codes: [ {
|
||||
// CHECK: builtin_code: CUSTOM,
|
||||
// CHECK: custom_code: "SomeOperation"
|
||||
// CHECK: deprecated_builtin_code: 32,
|
||||
// CHECK: custom_code: "SomeOperation",
|
||||
// CHECK: builtin_code: CUSTOM
|
||||
// CHECK: } ],
|
||||
// CHECK: subgraphs: [ {
|
||||
// CHECK: tensors: [ {
|
||||
|
@ -4,8 +4,9 @@ func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_LSTM,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 44,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_LSTM
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -4,8 +4,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_RNN,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 35,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_RNN
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
|
@ -3,14 +3,17 @@
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
// CHECK-NEXT: builtin_code: WHILE,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 119,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: WHILE
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: GREATER,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 61,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: GREATER
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: builtin_code: SUB,
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: deprecated_builtin_code: 41,
|
||||
// CHECK-NEXT: version: 1,
|
||||
// CHECK-NEXT: builtin_code: SUB
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: version: 1
|
||||
// CHECK-NEXT: } ],
|
||||
|
@ -272,6 +272,22 @@ func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
|
||||
// CHECK: return %[[RES]] : tensor<4x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @fuseBroadcastMulIntoFullyConnected
|
||||
func @fuseBroadcastMulIntoFullyConnected(%arg0: tensor<1x10368xbf16>) -> tensor<32x1x256xbf16> {
|
||||
%cst_0 = constant dense<2.0> : tensor<256x10368xbf16>
|
||||
%cst_1 = constant unit
|
||||
%cst_2 = constant dense<3.0> : tensor<32x1x256xbf16>
|
||||
%0 = "tfl.fully_connected"(%arg0, %cst_0, %cst_1) {
|
||||
fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"
|
||||
} : (tensor<1x10368xbf16>, tensor<256x10368xbf16>, none) -> tensor<1x256xbf16>
|
||||
%1 = "tfl.mul"(%0, %cst_2) {fused_activation_function = "NONE"} : (tensor<1x256xbf16>, tensor<32x1x256xbf16>) -> tensor<32x1x256xbf16>
|
||||
return %1 : tensor<32x1x256xbf16>
|
||||
|
||||
// CHECK: %[[V0:.*]] = "tfl.fully_connected"(%arg0, {{.*}}) {{{.*}}} : (tensor<1x10368xbf16>, tensor<256x10368xbf16>, none) -> tensor<1x256xbf16>
|
||||
// CHECK: %[[V1:.*]] = "tfl.mul"(%[[V0]], {{.*}}) {{{.*}}} : (tensor<1x256xbf16>, tensor<32x1x256xbf16>) -> tensor<32x1x256xbf16>
|
||||
// CHECK: return %[[V1]] : tensor<32x1x256xbf16>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: @fuseAddIntoFollowingFullyConnectedWithQDQs
|
||||
func @fuseAddIntoFollowingFullyConnectedWithQDQs(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
|
||||
|
@ -139,6 +139,9 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const mlir::TFL::QuantizationSpecs& quant_specs, std::string* result,
|
||||
mlir::PassManager* pass_manager) {
|
||||
// Explicitly disable dumping Op details on failures.
|
||||
module.getContext()->printOpOnDiagnostic(false);
|
||||
|
||||
// Register a warning handler only log to std out.
|
||||
mlir::ScopedDiagnosticHandler s(
|
||||
module.getContext(), [](mlir::Diagnostic& diag) {
|
||||
|
@ -416,6 +416,10 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
|
||||
|
||||
LogicalResult matchAndRewrite(TFL::MulOp mul_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// If we are broadcasting on the lhs then don't fold the multiply as it
|
||||
// would increase the amount of compute done by the fully connected op.
|
||||
if (mul_op.lhs().getType() != mul_op.getType()) return failure();
|
||||
|
||||
// Mul.
|
||||
DenseElementsAttr cst;
|
||||
Value constant_val = mul_op.rhs();
|
||||
|
@ -74,8 +74,8 @@ tool_names = [
|
||||
'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
|
||||
'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
|
||||
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-mlir-gpu-opt', 'xla-opt',
|
||||
'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_gpu_binary', 'xla-thunks-opt',
|
||||
'tfjs-opt'
|
||||
'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_kernel', 'tf_to_gpu_binary',
|
||||
'xla-thunks-opt', 'tfjs-opt'
|
||||
]
|
||||
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
|
||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||
|
@ -1394,6 +1394,7 @@ cc_library(
|
||||
":decode_constant_pass",
|
||||
":eval_util",
|
||||
":tensorflow",
|
||||
":tensorflow_traits",
|
||||
":tensorflow_types",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
@ -1961,6 +1962,7 @@ cc_library(
|
||||
deps = [
|
||||
":convert_tensor",
|
||||
":convert_type",
|
||||
":export_tf_dialect_op",
|
||||
":export_utils",
|
||||
":tensorflow",
|
||||
":tensorflow_attributes",
|
||||
|
@ -297,6 +297,33 @@ Equivalent to np.angle.
|
||||
TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_AnonymousIteratorOp : TF_Op<"AnonymousIterator", []> {
|
||||
let summary = "A container for an iterator resource.";
|
||||
|
||||
let arguments = (ins
|
||||
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle
|
||||
);
|
||||
}
|
||||
|
||||
def TF_AnonymousIteratorV2Op : TF_Op<"AnonymousIteratorV2", []> {
|
||||
let summary = "A container for an iterator resource.";
|
||||
|
||||
let arguments = (ins
|
||||
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle,
|
||||
TF_VariantTensor:$deleter
|
||||
);
|
||||
}
|
||||
|
||||
def TF_AnonymousMemoryCacheOp : TF_Op<"AnonymousMemoryCache", []> {
|
||||
let summary = "";
|
||||
|
||||
@ -308,6 +335,21 @@ def TF_AnonymousMemoryCacheOp : TF_Op<"AnonymousMemoryCache", []> {
|
||||
);
|
||||
}
|
||||
|
||||
def TF_AnonymousMultiDeviceIteratorOp : TF_Op<"AnonymousMultiDeviceIterator", []> {
|
||||
let summary = "A container for a multi device iterator resource.";
|
||||
|
||||
let arguments = (ins
|
||||
Confined<StrArrayAttr, [ArrayMinCount<1>]>:$devices,
|
||||
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle,
|
||||
TF_VariantTensor:$deleter
|
||||
);
|
||||
}
|
||||
|
||||
def TF_AnonymousRandomSeedGeneratorOp : TF_Op<"AnonymousRandomSeedGenerator", []> {
|
||||
let summary = "";
|
||||
|
||||
@ -2485,6 +2527,17 @@ is the same, though it is cleaner to use `tf.io.decode_image`.
|
||||
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_DeleteIteratorOp : TF_Op<"DeleteIterator", []> {
|
||||
let summary = "A container for an iterator resource.";
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorFree]>:$handle,
|
||||
TF_VariantTensor:$deleter
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def TF_DeleteMemoryCacheOp : TF_Op<"DeleteMemoryCache", []> {
|
||||
let summary = "";
|
||||
|
||||
@ -2496,6 +2549,20 @@ def TF_DeleteMemoryCacheOp : TF_Op<"DeleteMemoryCache", []> {
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def TF_DeleteMultiDeviceIteratorOp : TF_Op<"DeleteMultiDeviceIterator", []> {
|
||||
let summary = "A container for an iterator resource.";
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorFree]>:$multi_device_iterator,
|
||||
Arg<Variadic<TF_ResourceTensor>, "", [TF_DatasetIteratorRead]>:$iterators,
|
||||
TF_VariantTensor:$deleter
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
|
||||
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<1>;
|
||||
}
|
||||
|
||||
def TF_DeleteRandomSeedGeneratorOp : TF_Op<"DeleteRandomSeedGenerator", []> {
|
||||
let summary = "";
|
||||
|
||||
@ -2719,6 +2786,19 @@ Computes the gradients of depthwise convolution with respect to the input.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
|
||||
}
|
||||
|
||||
def TF_DeserializeIteratorOp : TF_Op<"DeserializeIterator", []> {
|
||||
let summary = [{
|
||||
Converts the given variant tensor to an iterator and stores it in the given resource.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorWrite]>:$resource_handle,
|
||||
TF_VariantTensor:$serialized
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def TF_DeviceIndexOp : TF_Op<"DeviceIndex", [NoSideEffect]> {
|
||||
let summary = "Return the index of device the op runs.";
|
||||
|
||||
@ -4965,11 +5045,58 @@ tf.math.is_nan(x) ==> [False, True, False, True, False]
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_IteratorOp : TF_Op<"Iterator", []> {
|
||||
let summary = "A container for an iterator resource.";
|
||||
|
||||
let arguments = (ins
|
||||
StrAttr:$shared_name,
|
||||
StrAttr:$container,
|
||||
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle
|
||||
);
|
||||
}
|
||||
|
||||
def TF_IteratorFromStringHandleOp : TF_Op<"IteratorFromStringHandle", []> {
|
||||
let summary = [{
|
||||
Converts the given string representing a handle to an iterator to a resource.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_StrTensor:$string_handle,
|
||||
|
||||
DefaultValuedAttr<TypeArrayAttr, "{}">:$output_types,
|
||||
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$resource_handle
|
||||
);
|
||||
}
|
||||
|
||||
def TF_IteratorFromStringHandleV2Op : TF_Op<"IteratorFromStringHandleV2", []> {
|
||||
let summary = "";
|
||||
|
||||
let arguments = (ins
|
||||
TF_StrTensor:$string_handle,
|
||||
|
||||
DefaultValuedAttr<TypeArrayAttr, "{}">:$output_types,
|
||||
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$resource_handle
|
||||
);
|
||||
}
|
||||
|
||||
def TF_IteratorGetNextOp : TF_Op<"IteratorGetNext", []> {
|
||||
let summary = "Gets the next output from the given iterator .";
|
||||
|
||||
let arguments = (ins
|
||||
TF_ResourceTensor:$iterator
|
||||
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead, TF_DatasetIteratorWrite]>:$iterator
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@ -4980,6 +5107,74 @@ def TF_IteratorGetNextOp : TF_Op<"IteratorGetNext", []> {
|
||||
TF_DerivedResultTypeListAttr output_types = TF_DerivedResultTypeListAttr<0>;
|
||||
}
|
||||
|
||||
def TF_IteratorGetNextAsOptionalOp : TF_Op<"IteratorGetNextAsOptional", []> {
|
||||
let summary = [{
|
||||
Gets the next output from the given iterator as an Optional variant.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead, TF_DatasetIteratorWrite]>:$iterator,
|
||||
|
||||
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_VariantTensor:$optional
|
||||
);
|
||||
}
|
||||
|
||||
def TF_IteratorGetNextSyncOp : TF_Op<"IteratorGetNextSync", []> {
|
||||
let summary = "Gets the next output from the given iterator.";
|
||||
|
||||
let description = [{
|
||||
This operation is a synchronous version IteratorGetNext. It should only be used
|
||||
in situations where the iterator does not block the calling thread, or where
|
||||
the calling thread is not a member of the thread pool used to execute parallel
|
||||
operations (e.g. in eager mode).
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead, TF_DatasetIteratorWrite]>:$iterator
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<TF_Tensor>:$components
|
||||
);
|
||||
|
||||
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
|
||||
TF_DerivedResultTypeListAttr output_types = TF_DerivedResultTypeListAttr<0>;
|
||||
}
|
||||
|
||||
def TF_IteratorToStringHandleOp : TF_Op<"IteratorToStringHandle", []> {
|
||||
let summary = [{
|
||||
Converts the given `resource_handle` representing an iterator to a string.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead]>:$resource_handle
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_StrTensor:$string_handle
|
||||
);
|
||||
}
|
||||
|
||||
def TF_IteratorV2Op : TF_Op<"IteratorV2", []> {
|
||||
let summary = "";
|
||||
|
||||
let arguments = (ins
|
||||
StrAttr:$shared_name,
|
||||
StrAttr:$container,
|
||||
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle
|
||||
);
|
||||
}
|
||||
|
||||
def TF_L2LossOp : TF_Op<"L2Loss", [NoSideEffect]> {
|
||||
let summary = "L2 Loss.";
|
||||
|
||||
@ -5586,6 +5781,24 @@ A 2-D example:
|
||||
TF_DerivedResultTypeAttr out_type = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_MakeIteratorOp : TF_Op<"MakeIterator", []> {
|
||||
let summary = [{
|
||||
Makes a new iterator from the given `dataset` and stores it in `iterator`.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
This operation may be executed multiple times. Each execution will reset the
|
||||
iterator in `iterator` to the first element of `dataset`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_VariantTensor:$dataset,
|
||||
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorWrite]>:$iterator
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> {
|
||||
let summary = [{
|
||||
Multiply the matrix "a" by the matrix "b".
|
||||
@ -6909,6 +7122,82 @@ Returns x * y element-wise. Returns zero if y is zero, even if x if infinite or
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_MultiDeviceIteratorOp : TF_Op<"MultiDeviceIterator", []> {
|
||||
let summary = "Creates a MultiDeviceIterator resource.";
|
||||
|
||||
let arguments = (ins
|
||||
Confined<StrArrayAttr, [ArrayMinCount<1>]>:$devices,
|
||||
StrAttr:$shared_name,
|
||||
StrAttr:$container,
|
||||
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle
|
||||
);
|
||||
}
|
||||
|
||||
def TF_MultiDeviceIteratorFromStringHandleOp : TF_Op<"MultiDeviceIteratorFromStringHandle", []> {
|
||||
let summary = [{
|
||||
Generates a MultiDeviceIterator resource from its provided string handle.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_StrTensor:$string_handle,
|
||||
|
||||
DefaultValuedAttr<TypeArrayAttr, "{}">:$output_types,
|
||||
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$multi_device_iterator
|
||||
);
|
||||
}
|
||||
|
||||
def TF_MultiDeviceIteratorGetNextFromShardOp : TF_Op<"MultiDeviceIteratorGetNextFromShard", []> {
|
||||
let summary = "Gets next element for the provided shard number.";
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead, TF_DatasetIteratorWrite]>:$multi_device_iterator,
|
||||
TF_Int32Tensor:$shard_num,
|
||||
TF_Int64Tensor:$incarnation_id
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<TF_Tensor>:$components
|
||||
);
|
||||
|
||||
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
|
||||
TF_DerivedResultTypeListAttr output_types = TF_DerivedResultTypeListAttr<0>;
|
||||
}
|
||||
|
||||
def TF_MultiDeviceIteratorInitOp : TF_Op<"MultiDeviceIteratorInit", []> {
|
||||
let summary = "Initializes the multi device iterator with the given dataset.";
|
||||
|
||||
let arguments = (ins
|
||||
TF_VariantTensor:$dataset,
|
||||
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorWrite]>:$multi_device_iterator,
|
||||
TF_Int64Tensor:$max_buffer_size
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Int64Tensor:$incarnation_id
|
||||
);
|
||||
}
|
||||
|
||||
def TF_MultiDeviceIteratorToStringHandleOp : TF_Op<"MultiDeviceIteratorToStringHandle", []> {
|
||||
let summary = "Produces a string handle for the given MultiDeviceIterator.";
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead]>:$multi_device_iterator
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_StrTensor:$string_handle
|
||||
);
|
||||
}
|
||||
|
||||
def TF_MultinomialOp : TF_Op<"Multinomial", [TF_CannotDuplicate]> {
|
||||
let summary = "Draws samples from a multinomial distribution.";
|
||||
|
||||
@ -7363,6 +7652,44 @@ output =
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_OneShotIteratorOp : TF_Op<"OneShotIterator", []> {
|
||||
let summary = [{
|
||||
Makes a "one-shot" iterator that can be iterated only once.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
A one-shot iterator bundles the logic for defining the dataset and
|
||||
the state of the iterator in a single op, which allows simple input
|
||||
pipelines to be defined without an additional initialization
|
||||
("MakeIterator") step.
|
||||
|
||||
One-shot iterators have the following limitations:
|
||||
|
||||
* They do not support parameterization: all logic for creating the underlying
|
||||
dataset must be bundled in the `dataset_factory` function.
|
||||
* They are not resettable. Once a one-shot iterator reaches the end of its
|
||||
underlying dataset, subsequent "IteratorGetNext" operations on that
|
||||
iterator will always produce an `OutOfRange` error.
|
||||
|
||||
For greater flexibility, use "Iterator" and "MakeIterator" to define
|
||||
an iterator using an arbitrary subgraph, which may capture tensors
|
||||
(including fed values) as parameters, and which may be reset multiple
|
||||
times by rerunning "MakeIterator".
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SymbolRefAttr:$dataset_factory,
|
||||
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
|
||||
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
|
||||
StrAttr:$container,
|
||||
StrAttr:$shared_name
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle
|
||||
);
|
||||
}
|
||||
|
||||
def TF_OutfeedEnqueueTupleOp : TF_Op<"OutfeedEnqueueTuple", []> {
|
||||
let summary = "Enqueue multiple Tensor values on the computation outfeed.";
|
||||
|
||||
@ -10353,6 +10680,22 @@ Computes gradients for the scaled exponential linear (Selu) operation.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_SerializeIteratorOp : TF_Op<"SerializeIterator", []> {
|
||||
let summary = [{
|
||||
Converts the given `resource_handle` representing an iterator to a variant tensor.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead]>:$resource_handle,
|
||||
|
||||
DefaultValuedAttr<I64Attr, "0">:$external_state_policy
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_VariantTensor:$serialized
|
||||
);
|
||||
}
|
||||
|
||||
def TF_ShapeOp : TF_Op<"Shape", [NoSideEffect]> {
|
||||
let summary = "Returns the shape of a tensor.";
|
||||
|
||||
@ -11409,7 +11752,7 @@ def TF_StackV2Op : TF_Op<"StackV2", []> {
|
||||
);
|
||||
}
|
||||
|
||||
def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect]> {
|
||||
def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect, TF_NoConstantFold]> {
|
||||
let summary = "Draws samples from a multinomial distribution.";
|
||||
|
||||
let arguments = (ins
|
||||
@ -11427,7 +11770,82 @@ def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect]> {
|
||||
TF_DerivedResultTypeAttr output_dtype = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_StatelessRandomNormalOp : TF_Op<"StatelessRandomNormal", [NoSideEffect]> {
|
||||
def TF_StatelessParameterizedTruncatedNormalOp : TF_Op<"StatelessParameterizedTruncatedNormal", [NoSideEffect, TF_NoConstantFold]> {
|
||||
let summary = "";
|
||||
|
||||
let arguments = (ins
|
||||
TF_I32OrI64Tensor:$shape,
|
||||
TF_I32OrI64Tensor:$seed,
|
||||
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$means,
|
||||
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$stddevs,
|
||||
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$minvals,
|
||||
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$maxvals
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr S = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>;
|
||||
}
|
||||
|
||||
def TF_StatelessRandomBinomialOp : TF_Op<"StatelessRandomBinomial", [NoSideEffect, TF_NoConstantFold]> {
|
||||
let summary = [{
|
||||
Outputs deterministic pseudorandom random numbers from a binomial distribution.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Outputs random values from a binomial distribution.
|
||||
|
||||
The outputs are a deterministic function of `shape`, `seed`, `counts`, and `probs`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_I32OrI64Tensor:$shape,
|
||||
TF_I32OrI64Tensor:$seed,
|
||||
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$counts,
|
||||
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$probs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr S = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>;
|
||||
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_StatelessRandomGammaV2Op : TF_Op<"StatelessRandomGammaV2", [NoSideEffect, TF_NoConstantFold]> {
|
||||
let summary = [{
|
||||
Outputs deterministic pseudorandom random numbers from a gamma distribution.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Outputs random values from a gamma distribution.
|
||||
|
||||
The outputs are a deterministic function of `shape`, `seed`, and `alpha`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_I32OrI64Tensor:$shape,
|
||||
TF_I32OrI64Tensor:$seed,
|
||||
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$alpha
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>;
|
||||
}
|
||||
|
||||
def TF_StatelessRandomNormalOp : TF_Op<"StatelessRandomNormal", [NoSideEffect, TF_NoConstantFold]> {
|
||||
let summary = [{
|
||||
Outputs deterministic pseudorandom values from a normal distribution.
|
||||
}];
|
||||
@ -11452,7 +11870,34 @@ The outputs are a deterministic function of `shape` and `seed`.
|
||||
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_StatelessRandomUniformOp : TF_Op<"StatelessRandomUniform", [NoSideEffect]> {
|
||||
def TF_StatelessRandomPoissonOp : TF_Op<"StatelessRandomPoisson", [NoSideEffect, TF_NoConstantFold]> {
|
||||
let summary = [{
|
||||
Outputs deterministic pseudorandom random numbers from a Poisson distribution.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Outputs random values from a Poisson distribution.
|
||||
|
||||
The outputs are a deterministic function of `shape`, `seed`, and `lam`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_I32OrI64Tensor:$shape,
|
||||
TF_I32OrI64Tensor:$seed,
|
||||
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$lam
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Rtype = TF_DerivedOperandTypeAttr<2>;
|
||||
}
|
||||
|
||||
def TF_StatelessRandomUniformOp : TF_Op<"StatelessRandomUniform", [NoSideEffect, TF_NoConstantFold]> {
|
||||
let summary = [{
|
||||
Outputs deterministic pseudorandom random values from a uniform distribution.
|
||||
}];
|
||||
@ -11478,7 +11923,32 @@ The outputs are a deterministic function of `shape` and `seed`.
|
||||
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect]> {
|
||||
def TF_StatelessRandomUniformFullIntOp : TF_Op<"StatelessRandomUniformFullInt", [NoSideEffect, TF_NoConstantFold]> {
|
||||
let summary = [{
|
||||
Outputs deterministic pseudorandom random integers from a uniform distribution.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
The generated values are uniform integers covering the whole range of `dtype`.
|
||||
|
||||
The outputs are a deterministic function of `shape` and `seed`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_I32OrI64Tensor:$shape,
|
||||
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$seed
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
|
||||
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect, TF_NoConstantFold]> {
|
||||
let summary = [{
|
||||
Outputs deterministic pseudorandom random integers from a uniform distribution.
|
||||
}];
|
||||
@ -11505,7 +11975,7 @@ The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxv
|
||||
TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>;
|
||||
}
|
||||
|
||||
def TF_StatelessTruncatedNormalOp : TF_Op<"StatelessTruncatedNormal", [NoSideEffect]> {
|
||||
def TF_StatelessTruncatedNormalOp : TF_Op<"StatelessTruncatedNormal", [NoSideEffect, TF_NoConstantFold]> {
|
||||
let summary = [{
|
||||
Outputs deterministic pseudorandom values from a truncated normal distribution.
|
||||
}];
|
||||
|
@ -73,6 +73,9 @@ def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;
|
||||
// certain state around within their implementations.
|
||||
def TF_CannotDuplicate : NativeOpTrait<"TF::CannotDuplicate">;
|
||||
|
||||
// Trait to indicate an operation cannot be constant folded.
|
||||
def TF_NoConstantFold : NativeOpTrait<"TF::NoConstantFold">;
|
||||
|
||||
// Coefficient wise binary operation with implicit broadcasting support, for
|
||||
// example tf.Sub operation.
|
||||
def TF_CwiseBinary : NativeOpTrait<"TF::CwiseBinary">;
|
||||
@ -112,6 +115,7 @@ def TF_SummaryResource : TF_ResourceBase<"Summary">;
|
||||
def TF_LookupTableResource : TF_ResourceBase<"LookupTable">;
|
||||
def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">;
|
||||
def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">;
|
||||
def TF_DatasetIteratorResource : TF_ResourceBase<"DatasetIterator">;
|
||||
|
||||
def TF_VariableRead : MemRead<TF_VariableResource>;
|
||||
def TF_StackRead : MemRead<TF_StackResource>;
|
||||
@ -119,6 +123,7 @@ def TF_TensorArrayRead : MemRead<TF_TensorArrayResource>;
|
||||
def TF_LookupTableRead : MemRead<TF_LookupTableResource>;
|
||||
def TF_DatasetSeedGeneratorRead : MemRead<TF_DatasetSeedGeneratorResource>;
|
||||
def TF_DatasetMemoryCacheRead : MemRead<TF_DatasetMemoryCacheResource>;
|
||||
def TF_DatasetIteratorRead : MemRead<TF_DatasetIteratorResource>;
|
||||
|
||||
def TF_VariableWrite : MemWrite<TF_VariableResource>;
|
||||
def TF_StackWrite : MemWrite<TF_StackResource>;
|
||||
@ -127,6 +132,7 @@ def TF_SummaryWrite : MemWrite<TF_SummaryResource>;
|
||||
def TF_LookupTableWrite : MemWrite<TF_LookupTableResource>;
|
||||
def TF_DatasetSeedGeneratorWrite : MemWrite<TF_DatasetSeedGeneratorResource>;
|
||||
def TF_DatasetMemoryCacheWrite : MemWrite<TF_DatasetMemoryCacheResource>;
|
||||
def TF_DatasetIteratorWrite : MemWrite<TF_DatasetIteratorResource>;
|
||||
|
||||
def TF_VariableAlloc : MemAlloc<TF_VariableResource>;
|
||||
def TF_StackAlloc : MemAlloc<TF_StackResource>;
|
||||
@ -135,12 +141,14 @@ def TF_SummaryAlloc : MemAlloc<TF_SummaryResource>;
|
||||
def TF_LookupTableAlloc : MemAlloc<TF_LookupTableResource>;
|
||||
def TF_DatasetSeedGeneratorAlloc : MemAlloc<TF_DatasetSeedGeneratorResource>;
|
||||
def TF_DatasetMemoryCacheAlloc : MemAlloc<TF_DatasetMemoryCacheResource>;
|
||||
def TF_DatasetIteratorAlloc : MemAlloc<TF_DatasetIteratorResource>;
|
||||
|
||||
def TF_StackFree : MemFree<TF_StackResource>;
|
||||
def TF_TensorArrayFree : MemFree<TF_TensorArrayResource>;
|
||||
def TF_SummaryFree : MemFree<TF_SummaryResource>;
|
||||
def TF_DatasetSeedGeneratorFree : MemFree<TF_DatasetSeedGeneratorResource>;
|
||||
def TF_DatasetMemoryCacheFree : MemFree<TF_DatasetMemoryCacheResource>;
|
||||
def TF_DatasetIteratorFree : MemFree<TF_DatasetIteratorResource>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorFlow op definitions
|
||||
|
@ -1446,7 +1446,8 @@ static LogicalResult Verify(SplitVOp op) {
|
||||
if (!split_sizes_type) return success();
|
||||
|
||||
if (split_sizes_type.getRank() != 1 ||
|
||||
split_sizes_type.getDimSize(0) != op.getNumResults())
|
||||
(split_sizes_type.getDimSize(0) != ShapedType::kDynamicSize &&
|
||||
split_sizes_type.getDimSize(0) != op.getNumResults()))
|
||||
return op.emitOpError("split sizes should be a 1D tensor of ")
|
||||
<< op.getNumResults() << " elements";
|
||||
|
||||
|
@ -53,6 +53,10 @@ struct DatasetMemoryCache
|
||||
StringRef getName() final { return "DatasetMemoryCache"; }
|
||||
};
|
||||
|
||||
struct DatasetIterator : ::mlir::SideEffects::Resource::Base<DatasetIterator> {
|
||||
StringRef getName() final { return "DatasetIterator"; }
|
||||
};
|
||||
|
||||
} // namespace ResourceEffects
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
@ -124,6 +124,10 @@ class CannotDuplicate : public TraitBase<ConcreteType, CannotDuplicate> {
|
||||
}
|
||||
};
|
||||
|
||||
// Trait to indicate an operation cannot be constant folded.
|
||||
template <typename ConcreteType>
|
||||
class NoConstantFold : public TraitBase<ConcreteType, NoConstantFold> {};
|
||||
|
||||
// Coefficient-wise binary operation with implicit broadcasting support, for
|
||||
// example tf.Sub operation.
|
||||
template <typename ConcreteType>
|
||||
|
@ -502,3 +502,12 @@ func @fold_conv() -> tensor<1x520x520x1xf32> {
|
||||
// CHECK: tf.Const
|
||||
// CHECK-NOT: tf.DepthwiseConv2dNative
|
||||
}
|
||||
|
||||
// CHECK-LABEL: DontFoldNoConstantFold
|
||||
func @DontFoldNoConstantFold() -> tensor<8xf32> {
|
||||
%0 = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%1 = "tf.Const"() {value = dense<[2, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: tf.StatelessRandomUniform
|
||||
%2 = "tf.StatelessRandomUniform"(%0, %1) : (tensor<1xi32>, tensor<2xi32>) -> tensor<8xf32>
|
||||
return %2 : tensor<8xf32>
|
||||
}
|
||||
|
@ -138,17 +138,17 @@ func @op_string_operand_string_result(%arg0: tensor<!tf.string>) -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// Test that a tf.IfRegion op with a captured string operand is marked for outside compilation.
|
||||
// Test that operations inside tf.IfRegion op are corrected marked for outside
|
||||
// compilation.
|
||||
|
||||
// CHECK-LABEL: func @if_region_captured_string
|
||||
func @if_region_captured_string(%arg0: tensor<i1>, %arg1: tensor<!tf.string>) -> tensor<f32> {
|
||||
// CHECK-LABEL: func @ops_inside_tf_if_outside_compiled
|
||||
func @ops_inside_tf_if_outside_compiled(%arg0: tensor<i1>, %arg1: tensor<!tf.string>) -> tensor<f32> {
|
||||
%0 = "tf_device.cluster"() ( {
|
||||
// CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
// CHECK: "tf.IfRegion"
|
||||
// CHECK: "tf.StringToNumber"
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
// CHECK: _xla_outside_compilation = "auto1", is_stateless = true
|
||||
// CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
// CHECK: "tf.IfRegion"
|
||||
// CHECK: "tf.StringToNumber"
|
||||
// CHECK-SAME: _xla_outside_compilation
|
||||
%1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%2 = "tf.IfRegion"(%arg0) ( {
|
||||
%3 = "tf.StringToNumber"(%arg1) {out_type = f32} : (tensor<!tf.string>) -> tensor<f32>
|
||||
@ -163,7 +163,8 @@ func @if_region_captured_string(%arg0: tensor<i1>, %arg1: tensor<!tf.string>) ->
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// Test that ops with string results/operands inside a tf.IfRegion branch are marked for outside compilation.
|
||||
// Test that ops with string results/operands inside a tf.IfRegion branch are
|
||||
// marked for outside compilation.
|
||||
|
||||
// CHECK-LABEL: func @if_region_string_op
|
||||
func @if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<f32> {
|
||||
@ -191,7 +192,8 @@ func @if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<f32
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// Test that ops with string results/operands inside a nested tf.IfRegion branch are marked for outside compilation.
|
||||
// Test that ops with string results/operands inside a nested tf.IfRegion branch
|
||||
// are marked for outside compilation.
|
||||
|
||||
// CHECK-LABEL: func @nested_if_region_string_op
|
||||
func @nested_if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<f32> {
|
||||
@ -231,16 +233,17 @@ func @nested_if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> ten
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// Test that a tf.WhileRegion op with a captured string operand is marked for outside compilation.
|
||||
// Test that ops inside tf.WhileRegion op are correct marked for outside
|
||||
// compilation.
|
||||
|
||||
// CHECK-LABEL: func @while_region_captured_string
|
||||
func @while_region_captured_string(%arg0: tensor<i32>, %arg1: tensor<!tf.string>) -> tensor<f32> {
|
||||
// CHECK-LABEL: func @ops_inside_while_outside_compiled
|
||||
func @ops_inside_while_outside_compiled(%arg0: tensor<i32>, %arg1: tensor<!tf.string>) -> tensor<f32> {
|
||||
%0 = "tf_device.cluster"() ( {
|
||||
// CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
|
||||
// CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
// CHECK: "tf.WhileRegion"
|
||||
// CHECK: "tf.StringToNumber"
|
||||
// CHECK: _xla_outside_compilation = "auto1", is_stateless = true
|
||||
// CHECK: "tf.WhileRegion"
|
||||
// CHECK: "tf.StringToNumber"
|
||||
// CHECK-SAME: _xla_outside_compilation
|
||||
%1 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
|
||||
%2:2 = "tf.WhileRegion"(%1, %arg0) ( {
|
||||
^bb0(%carg0: tensor<f32>, %carg1: tensor<i32>):
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user