Resolved merge conflicts

This commit is contained in:
Mahmoud Abuzaina 2020-10-09 08:42:49 -07:00
commit f5762da2e7
890 changed files with 19616 additions and 6917 deletions

View File

@ -323,6 +323,8 @@ build:windows --copt=/experimental:preprocessor
build:windows --host_copt=/experimental:preprocessor
# Misc build options we need for windows.
build:windows --linkopt=/DEBUG
build:windows --host_linkopt=/DEBUG
build:windows --linkopt=/OPT:REF
build:windows --host_linkopt=/OPT:REF
build:windows --linkopt=/OPT:ICF

View File

@ -12,12 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
#
# THIS IS A GENERATED DOCKERFILE.
#
# This file was assembled from multiple pieces, whose use is documented
# throughout. Please refer to the TensorFlow dockerfiles documentation
# for more information.
# A list of assignees
assignees:

28
.github/workflows/update-nightly.yml vendored Normal file
View File

@ -0,0 +1,28 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
on:
workflow_dispatch: # Allow manual triggers
schedule:
- cron: 0 4 * * * # 4am UTC is 9pm PDT and 8pm PST
name: Set nightly branch to master HEAD
jobs:
master-to-nightly:
runs-on: ubuntu-latest
steps:
- uses: zofrex/mirror-branch@v1
name: Set nightly branch to master HEAD
with:
target-branch: 'nightly'

View File

@ -1,4 +1,4 @@
# Release 2.4.0
h# Release 2.4.0
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
@ -209,8 +209,13 @@
* Improvements to Keras preprocessing layers:
* TextVectorization can now accept a vocabulary list or file as an
init arg.
* In `Attention` and `AdditiveAttention` layers, the `call()` method now
accepts a `return_attention_scores` argument. When set to
True, the layer returns the attention scores as an additional output
argument.
* Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints
with the same implementation as their `tf.losses` equivalent.
* `tf.function` / AutoGraph:
* Added `experimental_follow_type_hints` argument for `tf.function`. When
True, the function may use type annotations to optimize the tracing
performance.
@ -296,16 +301,21 @@
* `tf.debugging.assert_shapes()` now works on `SparseTensor`s (#36268).
* `TensorRT`
* Add parameter allow_mixed_precision_on_unconverted_ops to
TrtConversionParams.
* `tf.print`:
* Bug fix in `tf.print()` with `OrderedDict` where if an `OrderedDict`
didn't have the keys sorted, the keys and values were not being printed
in accordance with their correct mapping.
* Other:
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
and "denylist" where possible. Please see
https://developers.google.com/style/word-list#blacklist for more
context. <ADD RELEASE NOTES HERE>
context.
* Add `tf.config.experimental.mlir_bridge_rollout` which will help us
rollout the new MLIR TPU bridge.
* <ADD RELEASE NOTES HERE>
## Thanks to our Contributors

View File

@ -545,7 +545,9 @@ TEST(CAPI, DistributedFunctionNoError) {
TestDistributedFunctionCancellation(false);
}
TEST(CAPI, DistributedFunctionCancelledOnError) {
// TODO(b/170399182): Update test once an alternative to using the function
// optimization hook is in place.
TEST(CAPI, DISABLED_DistributedFunctionCancelledOnError) {
TestDistributedFunctionCancellation(true);
}

View File

@ -61,6 +61,7 @@ Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
return Status::OK();
}
@ -131,6 +132,37 @@ Status ExpGradModel(AbstractContext* ctx,
return Status::OK();
}
// Computes
// y = sqrt(inputs[0])
// return grad(y, {inputs[0]})
Status SqrtGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
std::vector<AbstractTensorHandle*> sqrt_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(
ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(sqrt_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto sqrt_output : sqrt_outputs) {
sqrt_output->Unref();
}
outputs[0] = out_grads[0];
delete tape;
return Status::OK();
}
// Computes
// ignored, y = IdentityN(inputs[0], inputs[1])
// return grad(y, {inputs[0], inputs[1]})
@ -401,6 +433,50 @@ TEST_P(CppGradients, TestExpGrad) {
result_tensor = nullptr;
}
TEST_P(CppGradients, TestSqrtGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// y = sqrt(x)
// outputs = tape.gradient(y, x)
std::vector<AbstractTensorHandle*> outputs(1);
s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_NEAR(*result_value, 0.5, 0.001);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
TEST_P(CppGradients, TestIdentityNGrad) {
// Pseudo-code:
//

View File

@ -29,6 +29,7 @@ cc_library(
}),
deps = [
"//tensorflow/c:env",
"//tensorflow/c:logging",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"//third_party/hadoop:hdfs",

View File

@ -22,9 +22,9 @@ limitations under the License.
#include <sstream>
#include <string>
#include "absl/synchronization/mutex.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/logging.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for HADOOP environments.
@ -148,15 +148,20 @@ class LibHDFS {
char* hdfs_home = getenv("HADOOP_HDFS_HOME");
if (hdfs_home != nullptr) {
auto JoinPath = [](std::string home, std::string lib) {
#if defined(_WIN32)
if (home.back() != '\\') home.push_back('\\');
return home + "lib\\native\\" + lib;
#else
if (home.back() != '/') home.push_back('/');
return home + "lib/native/" + lib;
#endif
};
std::string path = JoinPath(hdfs_home, kLibHdfsDso);
TryLoadAndBind(path.c_str(), &handle_, status);
if (TF_GetCode(status) == TF_OK) {
return;
} else {
std::cerr << "HadoopFileSystem load error: " << TF_Message(status);
TF_Log(TF_FATAL, "HadoopFileSystem load error: %s", TF_Message(status));
}
}
@ -207,6 +212,7 @@ hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file,
builder, namenode.empty() ? "default" : namenode.c_str());
cacheKey += namenode;
}
absl::MutexLock l(&hadoop_file->connection_cache_lock);
if (hadoop_file->connection_cache.find(cacheKey) ==
hadoop_file->connection_cache.end()) {
auto cacheFs = libhdfs->hdfsBuilderConnect(builder);
@ -418,17 +424,20 @@ void Close(const TF_WritableFile* file, TF_Status* status) {
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
// ----------------------------------------------------------------------------
namespace tf_read_only_memory_region {
// TODO(vnvo2409): Implement later
// Hadoop doesn't support Readonly Memory Region
} // namespace tf_read_only_memory_region
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_hadoop_filesystem {
HadoopFile::HadoopFile(TF_Status* status)
: libhdfs(new LibHDFS(status)),
connection_cache_lock(),
connection_cache() {}
void Init(TF_Filesystem* filesystem, TF_Status* status) {
filesystem->plugin_filesystem = new HadoopFile({new LibHDFS(status), {}});
filesystem->plugin_filesystem = new HadoopFile(status);
if (TF_GetCode(status) != TF_OK) return;
TF_SetStatus(status, TF_OK, "");
}
@ -699,7 +708,9 @@ int GetChildren(const TF_Filesystem* filesystem, const char* path,
return num_entries;
}
// TODO(vnvo2409): Implement later
static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
return strdup(uri);
}
} // namespace tf_hadoop_filesystem
@ -707,6 +718,42 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
ops->random_access_file_ops->read = tf_random_access_file::Read;
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->writable_file_ops->append = tf_writable_file::Append;
ops->writable_file_ops->tell = tf_writable_file::Tell;
ops->writable_file_ops->flush = tf_writable_file::Flush;
ops->writable_file_ops->sync = tf_writable_file::Sync;
ops->writable_file_ops->close = tf_writable_file::Close;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_hadoop_filesystem::Init;
ops->filesystem_ops->cleanup = tf_hadoop_filesystem::Cleanup;
ops->filesystem_ops->new_random_access_file =
tf_hadoop_filesystem::NewRandomAccessFile;
ops->filesystem_ops->new_writable_file =
tf_hadoop_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_hadoop_filesystem::NewAppendableFile;
ops->filesystem_ops->new_read_only_memory_region_from_file =
tf_hadoop_filesystem::NewReadOnlyMemoryRegionFromFile;
ops->filesystem_ops->path_exists = tf_hadoop_filesystem::PathExists;
ops->filesystem_ops->stat = tf_hadoop_filesystem::Stat;
ops->filesystem_ops->get_file_size = tf_hadoop_filesystem::GetFileSize;
ops->filesystem_ops->delete_file = tf_hadoop_filesystem::DeleteFile;
ops->filesystem_ops->create_dir = tf_hadoop_filesystem::CreateDir;
ops->filesystem_ops->delete_dir = tf_hadoop_filesystem::DeleteDir;
ops->filesystem_ops->rename_file = tf_hadoop_filesystem::RenameFile;
ops->filesystem_ops->get_children = tf_hadoop_filesystem::GetChildren;
ops->filesystem_ops->translate_name = tf_hadoop_filesystem::TranslateName;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <map>
#include <string>
#include "absl/synchronization/mutex.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
#include "third_party/hadoop/hdfs.h"
@ -47,7 +48,10 @@ void Close(const TF_WritableFile* file, TF_Status* status);
namespace tf_hadoop_filesystem {
typedef struct HadoopFile {
LibHDFS* libhdfs;
std::map<std::string, hdfsFS> connection_cache;
absl::Mutex connection_cache_lock;
std::map<std::string, hdfsFS> connection_cache
ABSL_GUARDED_BY(connection_cache_lock);
HadoopFile(TF_Status* status);
} HadoopFile;
void Init(TF_Filesystem* filesystem, TF_Status* status);

View File

@ -24,6 +24,7 @@ using std::vector;
using tensorflow::ops::Conj;
using tensorflow::ops::MatMul;
using tensorflow::ops::Mul;
using tensorflow::ops::SqrtGrad;
namespace tensorflow {
namespace gradients {
@ -72,6 +73,25 @@ class ExpGradientFunction : public GradientFunction {
AbstractTensorHandlePtr exp_;
};
class SqrtGradientFunction : public GradientFunction {
public:
explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) {
sqrt->Ref();
}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
std::string name = "Sqrt_Grad";
grad_outputs->resize(1);
TF_RETURN_IF_ERROR(SqrtGrad(ctx->ctx, {sqrt_.get(), grad_inputs[0]},
absl::MakeSpan(*grad_outputs), name.c_str()));
return Status::OK();
}
~SqrtGradientFunction() override {}
private:
AbstractTensorHandlePtr sqrt_;
};
class MatMulGradientFunction : public GradientFunction {
public:
explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
@ -210,5 +230,14 @@ BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
return new BackwardFunction(gradient_function, default_gradients);
}
BackwardFunction* SqrtRegisterer(const ForwardOperation& op) {
auto gradient_function = new SqrtGradientFunction(op.outputs[0]);
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
} // namespace gradients
} // namespace tensorflow

View File

@ -19,10 +19,13 @@ limitations under the License.
namespace tensorflow {
namespace gradients {
BackwardFunction* AddRegisterer(const ForwardOperation& op);
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_

View File

@ -144,5 +144,33 @@ Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
return exp_op->Execute(outputs, &num_retvals);
}
Status Sqrt(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr sqrt_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(sqrt_op->Reset("Sqrt", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_op.get(), name));
TF_RETURN_IF_ERROR(sqrt_op->AddInput(inputs[0]));
int num_retvals = 1;
Status s = sqrt_op->Execute(outputs, &num_retvals);
return s;
}
Status SqrtGrad(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr sqrt_grad_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
sqrt_grad_op->Reset("SqrtGrad", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_grad_op.get(), name));
TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[1]));
int num_retvals = 1;
Status s = sqrt_grad_op->Execute(outputs, &num_retvals);
return s;
}
} // namespace ops
} // namespace tensorflow

View File

@ -50,6 +50,15 @@ Status DivNoNan(AbstractContext* ctx,
Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Sqrt(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status SqrtGrad(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow

View File

@ -166,6 +166,8 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -20,8 +20,10 @@ limitations under the License.
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
@ -62,15 +64,53 @@ Status Variable::ReadValue(ImmediateTensorHandlePtr* out) {
return internal::ReadVariable(ctx_, handle_.get(), dtype_, out);
}
Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
absl::optional<std::string> name,
const char* raw_device_name,
std::unique_ptr<Variable>* output) {
Status Variable::CreateUninitialized(
ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
absl::optional<std::string> name, const char* raw_device_name,
const std::vector<std::string>& component_devices,
std::unique_ptr<Variable>* output) {
ImmediateTensorHandlePtr handle;
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
ctx, dtype, shape, raw_device_name, &handle));
if (component_devices.empty()) {
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
ctx, dtype, shape, raw_device_name, &handle));
output->reset(
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
return Status();
}
if (!tensorflow::isa<EagerContext>(ctx)) {
return errors::InvalidArgument(
"Can only load distributed variables with EagerContext.");
}
EagerContext* eager_ctx = reinterpret_cast<EagerContext*>(ctx);
std::vector<TensorHandle*> handles;
for (const auto& device : component_devices) {
ImmediateTensorHandlePtr handlePtr;
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
ctx, dtype, shape, device.empty() ? nullptr : device.c_str(),
&handlePtr));
if (!tensorflow::isa<TensorHandle>(handlePtr.get())) {
return errors::Internal("Returned replica handle has unsupported type.");
}
handles.push_back(reinterpret_cast<TensorHandle*>(handlePtr.release()));
}
TensorHandle* packed_handle;
TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
std::move(handles), eager_ctx, &packed_handle));
// The call to `CreatePackedHandle` incremented the handles' reference count,
// which we must now decrement to make the packed handle the owner of those
// handles. We can't loop through the `handles` vector because it was
// `std::move`d in the call above.
for (int i = 0; i != packed_handle->NumPackedHandles(); ++i) {
TensorHandle* component;
TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &component));
component->Unref();
}
handle.reset(packed_handle);
output->reset(
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
return Status();

View File

@ -34,11 +34,11 @@ class Variable : public TensorHandleConvertible {
public:
// Creates an uninitialized resource variable. Note that a caller must
// call "assign" to associate a value with the variable.
static Status CreateUninitialized(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
absl::optional<std::string> name,
const char* raw_device_name,
std::unique_ptr<Variable>* output);
static Status CreateUninitialized(
ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
absl::optional<std::string> name, const char* raw_device_name,
const std::vector<std::string>& component_devices,
std::unique_ptr<Variable>* output);
// The dtype of the underlying variable.
DataType dtype();

View File

@ -235,10 +235,17 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx,
const std::string& name = variable.name();
tensorflow::TensorShape shape(variable.shape());
tensorflow::DataType dtype = variable.dtype();
std::vector<std::string> component_devices;
for (const auto& component :
variable.experimental_distributed_variable_components()) {
component_devices.push_back(component.device());
}
TF_RETURN_IF_ERROR(Variable::CreateUninitialized(
ctx, dtype, shape, name,
variable.device().empty() ? nullptr : variable.device().c_str(), output));
variable.device().empty() ? nullptr : variable.device().c_str(),
component_devices, output));
return Status();
}

View File

@ -119,7 +119,7 @@ TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) {
Status status;
std::unique_ptr<Variable> var;
TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
absl::nullopt, nullptr, &var));
absl::nullopt, nullptr, {}, &var));
// Create a TensorHandle
ImmediateTensorHandlePtr expected_handle =

View File

@ -127,7 +127,7 @@ def tf_library(
"$(location " + tfcompile_tool + ")" +
" --config=$(location " + config + ")" +
" --dump_fetch_nodes > $@"),
tools = [tfcompile_tool],
exec_tools = [tfcompile_tool],
# Run tfcompile on the build host, rather than forge, since it's
# typically way faster on the local machine.
local = 1,
@ -242,7 +242,7 @@ def tf_library(
" --out_function_object=$(@D)/" + function_object_file +
" " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag
),
tools = [tfcompile_tool],
exec_tools = [tfcompile_tool],
visibility = visibility,
testonly = testonly,
# Run tfcompile on the build host since it's typically faster on the
@ -281,7 +281,7 @@ def tf_library(
" --out_session_module=$(@D)/" + session_module_pb +
" " + flags
),
tools = [tfcompile_tool],
exec_tools = [tfcompile_tool],
visibility = visibility,
testonly = testonly,
local = 1,

View File

@ -84,6 +84,23 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name,
return Status::OK();
}
xla::StatusOr<std::vector<NodeDef>> MakeCallNodesFromAttribute(
const Node& node, absl::string_view attr_name,
absl::string_view call_name) {
std::vector<NameAttrList> attr_lists;
TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &attr_lists));
std::vector<NodeDef> out;
for (int i = 0; i < attr_lists.size(); i++) {
out.emplace_back();
NodeDef& inserted = out.back();
inserted.set_name(absl::StrCat(call_name, "_", i));
inserted.set_op(attr_lists[i].name());
*inserted.mutable_attr() = attr_lists[i].attr();
}
return out;
}
// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
@ -227,6 +244,30 @@ bool RecursiveCompilabilityChecker::IsCompilableIf(
return is_compilable;
}
bool RecursiveCompilabilityChecker::IsCompilableCase(
const Node& case_node, FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
NameAttrList* encapsulating_function,
RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes)
const {
xla::StatusOr<std::vector<NodeDef>> calls =
MakeCallNodesFromAttribute(case_node, "branches", "branch");
if (!calls.ok()) {
VLOG(2) << "Rejecting node " << case_node.name() << ": "
<< "missing attribute 'branches'";
return false;
}
bool is_compilable = true;
for (const NodeDef& call : *calls) {
is_compilable &=
IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function,
uncompilable_nodes);
}
return is_compilable;
}
// Tests whether 'while_node' is a completely compilable loop.
// Every operator in the condition and body functions must be compilable for a
// while loop to be compilable.
@ -417,6 +458,13 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
return false;
}
if (op_filter_.require_always_compilable && node.IsCaseNode() &&
!IsCompilableCase(node, lib_runtime, stack_trace, encapsulating_function,
uncompilable_nodes)) {
LogNotCompilable(node, "unsupported case");
return false;
}
if (!op_filter_.allow_stateful_rng_ops &&
IsStatefulRandomOp(node.type_string())) {
absl::string_view uncompilable_reason = "stateful random op";

View File

@ -124,6 +124,10 @@ class RecursiveCompilabilityChecker {
// Whether ops known to have numerical accuracy issues should be considered
// compilable..
bool allow_inaccurate_ops = false;
// Require the function to be always compilable, regardless whether some
// control flow branches might be dead for a given input.
bool require_always_compilable = false;
};
RecursiveCompilabilityChecker(OperationFilter op_filter,
@ -211,6 +215,14 @@ class RecursiveCompilabilityChecker {
NameAttrList* encapsulating_function,
UncompilableNodesMap* uncompilable_nodes) const;
// Tests whether 'case_node' is compilable. Every operator in all branches
// must be compilable.
bool IsCompilableCase(const Node& case_node,
FunctionLibraryRuntime* lib_runtime,
std::vector<StackFrameView>* stack_trace,
NameAttrList* encapsulating_function,
UncompilableNodesMap* uncompilable_nodes) const;
// Returns compilability of node def retrieved from `node`'s attribute with
// name `attr_name`.
bool ExtractNodeDefAndCheckCompilability(

View File

@ -34,7 +34,16 @@ limitations under the License.
namespace tensorflow {
namespace {
AttrValue FuncListAttr(const absl::Span<const char* const> names) {
AttrValue attr;
for (const char* name : names) {
attr.mutable_list()->add_func()->set_name(name);
}
return attr;
}
constexpr char kFunctionalIfNodeName[] = "If";
constexpr char kFunctionalCaseNodeName[] = "Case";
constexpr char kFunctionalWhileNodeName[] = "While";
constexpr char kCompilableFunctionName[] = "CompilableFn";
constexpr char kCompilableFunctionNodeName[] = "n_c";
@ -76,8 +85,12 @@ class CompilabilityCheckUtilTest : public ::testing::Test {
op_filter_.allow_inaccurate_ops = false;
op_filter_.allow_slow_ops = false;
checker_ = absl::make_unique<RecursiveCompilabilityChecker>(op_filter_,
device_type_);
checker_ = CreateCompilabilityChecker();
}
std::unique_ptr<RecursiveCompilabilityChecker> CreateCompilabilityChecker() {
return absl::make_unique<RecursiveCompilabilityChecker>(op_filter_,
device_type_);
}
FunctionLibraryRuntime* GetFunctionLibraryRuntime() {
@ -355,6 +368,57 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
"unsupported op"));
}
TEST_F(CompilabilityCheckUtilTest, CheckFunctionalCaseNode) {
FunctionDefLibrary flib;
*flib.add_function() = FunctionDefHelper::Define(
/*Function*/ kUncompilableFunctionName,
/*Inputs*/ {"n_a:float"},
/*Outputs*/ {"n_c_uncompilable:float"},
/*Attributes*/ {},
// Node info
{{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});
*flib.add_function() = FunctionDefHelper::Define(
/*Function*/ kUncompilableFunctionTwoName,
/*Inputs*/ {"n_a:float"},
/*Outputs*/ {"n_d_uncompilable:float"},
/*Attribute*/ {},
// Node info
{{{kUncompilableFunctionNodeTwoName}, "MissingKernel", {"n_a"}}});
Scope root = Scope::NewRootScope().ExitOnError();
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib));
auto branch_index = ops::Placeholder(root.WithOpName("pred"), DT_INT32);
auto placeholder = ops::Placeholder(root.WithOpName("A"), DT_INT32);
std::vector<NodeBuilder::NodeOut> inputes(
{NodeBuilder::NodeOut(placeholder.node())});
Node* case_node;
TF_ASSERT_OK(
NodeBuilder(kFunctionalCaseNodeName, "Case", &root.graph()->flib_def())
.Input(branch_index.node())
.Input(inputes)
.Attr("branches", FuncListAttr({kUncompilableFunctionName,
kUncompilableFunctionTwoName}))
.Attr("Tout", {DT_INT32})
.Finalize(root.graph(), &case_node));
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
auto case_node_it = std::find_if(
graph->nodes().begin(), graph->nodes().end(),
[&](const Node* n) { return n->name() == kFunctionalCaseNodeName; });
EXPECT_NE(case_node_it, graph->nodes().end());
auto* flib_runtime = GetFunctionLibraryRuntime();
op_filter_.require_always_compilable = false;
checker_ = CreateCompilabilityChecker();
EXPECT_TRUE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
op_filter_.require_always_compilable = true;
checker_ = CreateCompilabilityChecker();
EXPECT_FALSE(checker_->IsCompilableNode(**case_node_it, flib_runtime));
}
TEST_F(CompilabilityCheckUtilTest, TestCanNotTriggerXlaCompilation) {
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Scope root = Scope::NewRootScope().ExitOnError();

View File

@ -1196,10 +1196,14 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
continue;
}
if (!RecursiveCompilabilityChecker{
CreateOperationFilter(*registration),
DeviceType{registration->compilation_device_name}}
.IsCompilableNode(*node, lib_runtime)) {
RecursiveCompilabilityChecker::OperationFilter filter =
CreateOperationFilter(*registration);
filter.require_always_compilable = true;
RecursiveCompilabilityChecker checker(
filter, DeviceType{registration->compilation_device_name});
if (!checker.IsCompilableNode(*node, lib_runtime)) {
continue;
}

View File

@ -303,8 +303,12 @@ Status XlaCompilationCache::CompileSingleOp(
}
GraphDebugInfo debug_info;
std::vector<std::string> control_rets;
if (result_dtypes.empty()) {
control_rets.push_back(node_def.name());
}
return CompileGraphToXlaHlo(
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), control_rets,
options.device_type.type_string(), compile_options.use_tuple_arg,
*options.flib_def, debug_info, options.shape_representation_fn, result);
#endif

View File

@ -29,7 +29,7 @@ LLVM_SRC=...
# Create basic workspace file
echo 'workspace(name = "llvm-project")' > $LLVM_SRC/WORKSPACE
# and over the bazel BUILD files.
# and copy over the bazel BUILD files.
cp third_party/llvm/llvm.autogenerated.BUILD $LLVM_SRC/llvm/BUILD
cp third_party/mlir/BUILD $LLVM_SRC/mlir
cp third_party/mlir/test.BUILD $LLVM_SRC/mlir/test/BUILD

View File

@ -48,6 +48,7 @@ filegroup(
"include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td",
"include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",

View File

@ -360,6 +360,19 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
}];
}
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
HLO_FpOrComplexTensor> {
let summary = "Atan operator";
let description = [{
Returns `Atan(operand)` element-wise.
$$
\atan(x) = \atan2(x, 1)
$$
}];
}
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
HLO_FpOrComplexTensor> {
let summary = "Sinh operation";

View File

@ -353,7 +353,9 @@ def HLO_PowOp : HLO_BinaryElementwiseOp<"power",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp;
def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp;
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp {
let hasFolder = 1;
}
def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp;
@ -910,39 +912,12 @@ def HLO_CollectivePermuteOp: HLO_Op<"collective_permute",
let results = (outs HLO_Tensor);
}
// TODO(hinsu): Make this struct dialect independent so that it can be shared
// between HLO and LHLO dialect.
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
let arguments = (ins
HLO_Tensor:$lhs,
HLO_Tensor:$rhs,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$window_strides,
// Default value: zero for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
ConvDimensionNumbers:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
HLO_PrecisionConfigAttr:$precision_config
);
let arguments = !con(
(ins
HLO_Tensor:$lhs,
HLO_Tensor:$rhs),
ConvolutionAttributes<HLO_Dialect>.attributes);
let results = (outs HLO_Tensor);
}

View File

@ -1007,6 +1007,42 @@ class BASE_HLO_ConcatenateOp {
}];
}
//===----------------------------------------------------------------------===//
// Common convolution attributes
//===----------------------------------------------------------------------===//
class ConvDimensionNumbersBase<Dialect dialect>
: StructAttr<"ConvDimensionNumbers", dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
class ConvolutionAttributes<Dialect dialect> {
dag attributes = (ins
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$window_strides,
// Default value: zero for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
ConvDimensionNumbersBase<dialect>:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
HLO_PrecisionConfigAttr:$precision_config
);
}
class BASE_HLO_ConvOp {
string summary = "Convolution operator";

View File

@ -37,38 +37,13 @@ include "mlir/IR/OpBase.td"
include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td"
def LHLO_Dialect : Dialect {
let name = "lmhlo";
let cppNamespace = "::mlir::lmhlo";
}
//===----------------------------------------------------------------------===//
// LMHLO type definitions.
//===----------------------------------------------------------------------===//
// Any integer tensor types
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
// Any floating-point tensor types
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
// Any integer or floating-point tensor types
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
//===----------------------------------------------------------------------===//
// LMHLO nullary op definitions.
//===----------------------------------------------------------------------===//
@ -345,10 +320,11 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
[NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
let summary = [{
"modifies the offset, sizes and strides of a statically shaped memref.
modifies the offset, sizes and strides of a statically shaped memref
}];
let description = [{
Allows to modify the offset, sizes and strides of a statically shaped memref.
Casts the statically shaped memref operand to a memref with optionally
modified offsets, sizes and strides.
Example:
```mlir
@ -592,40 +568,13 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
);
}
// TODO(bondhugula): Make this struct dialect independent so that it can be
// shared between the HLO and LHLO dialects.
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$window_strides,
// Default value: zero for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
ConvDimensionNumbers:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
HLO_PrecisionConfigAttr:$precision_config
);
let arguments = !con(
(ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$output),
ConvolutionAttributes<LHLO_Dialect>.attributes);
}
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp {

View File

@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef LHLO_OPS_BASE
#define LHLO_OPS_BASE
include "mlir/IR/OpBase.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
//===----------------------------------------------------------------------===//
// LMHLO type definitions.
//===----------------------------------------------------------------------===//
// Any integer tensor types
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
// Any floating-point tensor types
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
// Any integer or floating-point tensor types
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
#endif // LHLO_OPS_BASE

View File

@ -149,6 +149,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::Atan2Op>{}(
loc, result_types, args, b);
}
template <typename PredicateType>
inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
return llvm::None;

View File

@ -15,9 +15,9 @@ limitations under the License.
include "mlir/Pass/PassBase.td"
def TestChloLegalizeToHloPass : Pass<"mhlo-test-chlo-legalize-to-hlo", "FuncOp"> {
let summary = "Test pass for applying chlo -> hlo legalization patterns.";
let constructor = "createTestChloLegalizeToHloPass()";
def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "FuncOp"> {
let summary = "Legalize CHLO to HLO.";
let constructor = "createChloLegalizeToHloPass()";
}
def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {

View File

@ -44,6 +44,9 @@ std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass();
/// Lowers from HLO dialect to Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
/// Lowers from the CHLO dialect to the HLO dialect.
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass();
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
/// buffers if necessary. If `results_escape_functions` is set to true,
/// allocated buffers for function results will be returned and escape the
@ -63,7 +66,7 @@ std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
/// Lowers trigonometric operations from the standard dialect to approximations
// that do not use intrinsics.
/// that do not use intrinsics.
std::unique_ptr<OperationPass<FuncOp>>
createLegalizeTrigonometricToApproximationPass();

View File

@ -22,7 +22,6 @@ limitations under the License.
namespace mlir {
namespace mhlo {
std::unique_ptr<Pass> createTestChloLegalizeToHloPass();
std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass();
std::unique_ptr<Pass> createTestMaterializeBroadcastsPass();
std::unique_ptr<Pass> createTestUnfuseBatchNormPass();

View File

@ -2001,6 +2001,23 @@ struct divide<APInt> {
APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); }
};
template <typename T>
struct remainder : std::modulus<T> {};
template <>
struct remainder<APInt> {
APInt operator()(const APInt& a, const APInt& b) const { return a.srem(b); }
};
template <>
struct remainder<APFloat> {
APFloat operator()(const APFloat& a, const APFloat& b) const {
APFloat result(a);
result.remainder(b);
return result;
}
};
template <typename T>
struct max {
T operator()(const T& a, const T& b) const { return std::max<T>(a, b); }
@ -2042,6 +2059,7 @@ BINARY_FOLDER(AddOp, std::plus);
BINARY_FOLDER(SubOp, std::minus);
BINARY_FOLDER(MulOp, std::multiplies);
BINARY_FOLDER(DivOp, divide);
BINARY_FOLDER(RemOp, remainder);
BINARY_FOLDER(MaxOp, max);
BINARY_FOLDER(MinOp, min);

View File

@ -27,8 +27,8 @@ namespace mhlo {
namespace {
struct TestChloLegalizeToHloPass
: public PassWrapper<TestChloLegalizeToHloPass, FunctionPass> {
struct ChloLegalizeToHloPass
: public PassWrapper<ChloLegalizeToHloPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mhlo::MhloDialect, shape::ShapeDialect, scf::SCFDialect>();
}
@ -36,11 +36,12 @@ struct TestChloLegalizeToHloPass
void runOnFunction() override {
ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns;
conversionTarget.addIllegalDialect<chlo::HloClientDialect>();
// Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<mhlo::MhloDialect>();
// The conversion uses helpers from the Standard dialect.
// The conversion uses helpers from the standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
conversionTarget.addLegalDialect<mlir::scf::SCFDialect>();
@ -56,8 +57,8 @@ struct TestChloLegalizeToHloPass
} // namespace
std::unique_ptr<FunctionPass> createTestChloLegalizeToHloPass() {
return std::make_unique<TestChloLegalizeToHloPass>();
std::unique_ptr<FunctionPass> createChloLegalizeToHloPass() {
return std::make_unique<ChloLegalizeToHloPass>();
}
} // namespace mhlo

View File

@ -24,16 +24,17 @@ include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.td"
//===----------------------------------------------------------------------===//
// Expand acos to MHLO dialect as follows:
// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1
// acos(x) = 2 * atan2(sqrt(1 - x^2), (1 + x)) if x != -1
// = pi if x == -1
def : Pat<(HLOClient_AcosOp $input),
(HLO_SelectOp
(HLO_CompareOp $input,
(HLO_ConstantLike<"0"> $input),
(HLO_CompareOp
$input,
(HLO_ConstantLike<"-1"> $input),
HLO_COMPARISON_DIRECTION_NE
),
(HLO_MulOp
(HLO_ConstantLike<"2.0f"> $input),
(HLO_ConstantLike<"2"> $input),
(HLO_Atan2Op
(HLO_SqrtOp
(HLO_SubOp
@ -47,7 +48,16 @@ def : Pat<(HLOClient_AcosOp $input),
)
)
),
(HLO_ConstantLike<"M_PI"> $input))>;
(HLO_ConstantLike<"M_PI"> $input)
)>;
// Express `atan` as
// atan(x) = atan2(x, 1)
def : Pat<(HLOClient_AtanOp $input),
(HLO_Atan2Op
$input,
(HLO_ConstantLike<"1"> $input)
)>;
// Express `sinh` as
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
@ -95,4 +105,3 @@ def : Pat<(HLOClient_TanOp $input),
(HLO_SinOp $input),
(HLO_CosOp $input)
)>;

View File

@ -45,7 +45,7 @@ using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
Value InsertDynamicAllocAndDealloc(Location loc, Value result,
Value shape_operand,
ConversionPatternRewriter* rewriter) {
auto result_type = result.getType().dyn_cast<ShapedType>();
auto result_type = result.getType().dyn_cast<RankedTensorType>();
if (!result_type) {
result.getDefiningOp()->emitOpError()
<< "tensor to buffer conversion expects ranked results";
@ -53,17 +53,13 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
auto memref_type =
MemRefType::get(result_type.getShape(), result_type.getElementType());
Operation* op = result.getDefiningOp();
// Extract the required element out of the vector.
SmallVector<Value, 4> dynamic_operands;
for (auto shape_element : llvm::enumerate(result_type.getShape())) {
if (shape_element.value() != ShapedType::kDynamicSize) continue;
Value index = rewriter->create<ConstantOp>(
loc, rewriter->getIntegerAttr(rewriter->getIndexType(),
shape_element.index()));
Value alloc_operand = rewriter->create<ExtractElementOp>(loc, shape_operand,
ValueRange{index});
Value index = rewriter->create<ConstantIndexOp>(loc, shape_element.index());
Value alloc_operand =
rewriter->create<ExtractElementOp>(loc, shape_operand, index);
if (!alloc_operand.getType().isIndex()) {
alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
rewriter->getIndexType());
@ -71,15 +67,12 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
dynamic_operands.push_back(alloc_operand);
}
// Insert in front of op to ensure sizes are available.
OpBuilder allocBuilder(op);
auto alloc = allocBuilder.create<AllocOp>(loc, memref_type, dynamic_operands);
return alloc;
return rewriter->create<AllocOp>(loc, memref_type, dynamic_operands);
}
Value InsertAlloc(Location loc, OpResult result,
ConversionPatternRewriter* rewriter) {
auto result_type = result.getType().dyn_cast<ShapedType>();
auto result_type = result.getType().dyn_cast<RankedTensorType>();
if (!result_type || !result_type.hasStaticShape()) {
result.getDefiningOp()->emitOpError()
<< "tensor to buffer conversion expects statically shaped results";
@ -112,19 +105,21 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
buffer_args.push_back(
InsertAlloc(op->getLoc(), result.value(), &rewriter));
} else {
SmallVector<Value, 1> results_shape;
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
if (!shape_type_op) return failure();
if (failed(
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
return failure();
SmallVector<Value, 1> results_shape;
auto status =
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
if (failed(status)) return failure();
buffer_args.push_back(InsertDynamicAllocAndDealloc(
op->getLoc(), result.value(), results_shape.front(), &rewriter));
}
}
rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
buffer_args, op->getAttrs());
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
rewriter.replaceOp(
op, llvm::makeArrayRef(buffer_args).drop_front(operands.size()));
return success();
}
};

View File

@ -32,8 +32,6 @@ limitations under the License.
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LogicalResult.h"
using mlir::PassRegistration;
namespace mlir {
namespace mhlo {
namespace {

View File

@ -822,6 +822,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<lmhlo::AbsOp>,
PointwiseToLinalgConverter<lmhlo::AddOp>,
PointwiseToLinalgConverter<lmhlo::AndOp>,
PointwiseToLinalgConverter<lmhlo::Atan2Op>,
PointwiseToLinalgConverter<lmhlo::CeilOp>,
PointwiseToLinalgConverter<lmhlo::CompareOp>,
PointwiseToLinalgConverter<lmhlo::ComplexOp>,
@ -932,6 +933,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
PointwiseToLinalgConverter<mhlo::AddOp, false>,
PointwiseToLinalgConverter<mhlo::AndOp, false>,
PointwiseToLinalgConverter<mhlo::Atan2Op, false>,
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,

View File

@ -235,6 +235,23 @@ class ApproximateAtan2Lowering
}
};
class ApproximateAtanLowering
: public ApproximateOnExtendedF32Lowering<AtanOp> {
public:
explicit ApproximateAtanLowering(MLIRContext *ctx)
: ApproximateOnExtendedF32Lowering<AtanOp>(ctx) {}
// Reduce atan(x) to atan2(x, 1) to subsequently rely on an atan approximation
// for the argument range [-1, 1].
Value emitApproximation(ValueRange args, Location loc,
PatternRewriter &rewriter) const override {
Value x = args.front();
assert(x.getType().isF32());
Value one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1));
return rewriter.create<Atan2Op>(loc, x, one);
}
};
struct LegalizeTrigonometricToApproximationPass
: public PassWrapper<LegalizeTrigonometricToApproximationPass,
FunctionPass> {
@ -257,6 +274,7 @@ void PopulateTrigonometricToApproximationPatterns(
mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
// clang-format off
patterns->insert<
ApproximateAtanLowering,
ApproximateAtan2Lowering,
ApproximateTanhLowering>(context);
// clang-format on

View File

@ -37,7 +37,6 @@ limitations under the License.
using mlir::FunctionPass;
using mlir::OwningRewritePatternList;
using mlir::PassRegistration;
using mlir::PassWrapper;
namespace {

View File

@ -38,7 +38,6 @@ using mlir::LogicalResult;
using mlir::MLIRContext;
using mlir::OpRewritePattern;
using mlir::OwningRewritePatternList;
using mlir::PassRegistration;
using mlir::PassWrapper;
using mlir::PatternRewriter;
using mlir::RankedTensorType;

View File

@ -24,7 +24,6 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h"
using mlir::FunctionPass;
using mlir::PassRegistration;
using mlir::PassWrapper;
namespace {

View File

@ -48,7 +48,7 @@ namespace {
// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
fn(TanOp) sep fn(AcosOp) sep fn(SinhOp)
fn(AcosOp) sep fn(AtanOp) sep fn(SinhOp) sep fn(TanOp)
template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {

View File

@ -63,6 +63,24 @@ func @divide_fold_float() -> tensor<4xf64> {
return %2 : tensor<4xf64>
}
// CHECK-LABEL: remainder_fold_int
func @remainder_fold_int() -> tensor<4xi32> {
%0 = mhlo.constant dense<[5, 66, 5, 1]> : tensor<4xi32>
%1 = mhlo.constant dense<[3, 5, 1, 2]> : tensor<4xi32>
// CHECK: mhlo.constant dense<[2, 1, 0, 1]>
%2 = "mhlo.remainder"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>)
return %2 : tensor<4xi32>
}
// CHECK-LABEL: remainder_fold_float
func @remainder_fold_float() -> tensor<4xf32> {
%0 = mhlo.constant dense<[7.0, 66.5, 5.0, 3.1]> : tensor<4xf32>
%1 = mhlo.constant dense<[3.0, 5.0, 1.0, 2.6]> : tensor<4xf32>
// CHECK: mhlo.constant dense<[1.000000e+00, 1.500000e+00, 0.000000e+00, 5.000000e-01]>
%2 = "mhlo.remainder"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
return %2 : tensor<4xf32>
}
// CHECK-LABEL: max_scalar_fold
func @max_scalar_fold() -> tensor<4xi64> {
%0 = mhlo.constant dense<7> : tensor<4xi64>

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt -mhlo-test-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// Check the non-broadcast case for each registered op, then just check a
// representative op for detailed broadcast semantics.

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt --mhlo-test-chlo-legalize-to-hlo --split-input-file %s | FileCheck %s
// RUN: mlir-hlo-opt --chlo-legalize-to-hlo --split-input-file %s | FileCheck %s
// Lower statically shaped `constant_like` to constant.
// CHECK-LABEL: @constant_like_static_shape

View File

@ -261,3 +261,120 @@ func @atan2_f16(%arg0 : f16, %arg1 : f16) -> f16 {
%res = atan2 %arg0, %arg1 : f16
return %res : f16
}
// -----
// CHECK-LABEL: @atan_f64
func @atan_f64(%arg : f64) -> f64 {
// CHECK: atan
%res = atan %arg : f64
return %res : f64
}
// -----
// CHECK-LABEL: func @atan_f32
// CHECK-SAME: (%[[ARG:.*]]: f32) -> f32
func @atan_f32(%arg : f32) -> f32 {
// CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32
// CHECK: %[[CST_0:.*]] = constant 0.0027856871 : f32
// CHECK: %[[CST_1:.*]] = constant -1.586600e-02 : f32
// CHECK: %[[CST_2:.*]] = constant 0.042472221 : f32
// CHECK: %[[CST_3:.*]] = constant -0.0749753043 : f32
// CHECK: %[[CST_4:.*]] = constant 0.106448799 : f32
// CHECK: %[[CST_5:.*]] = constant -0.142070308 : f32
// CHECK: %[[CST_6:.*]] = constant 0.199934542 : f32
// CHECK: %[[CST_7:.*]] = constant -0.333331466 : f32
// CHECK: %[[CST_8:.*]] = constant 1.57079637 : f32
// CHECK: %[[CST_9:.*]] = constant 0.000000e+00 : f32
// CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32
// CHECK: %[[VAL_0:.*]] = absf %[[CST]] : f32
// CHECK: %[[VAL_1:.*]] = absf %arg0 : f32
// CHECK: %[[VAL_2:.*]] = cmpf "ole", %[[VAL_0]], %[[VAL_1]] : f32
// CHECK: %[[VAL_3:.*]] = select %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : f32
// CHECK: %[[VAL_4:.*]] = select %[[VAL_2]], %[[VAL_1]], %[[VAL_0]] : f32
// CHECK: %[[VAL_5:.*]] = divf %[[VAL_3]], %[[VAL_4]] : f32
// CHECK: %[[VAL_6:.*]] = mulf %[[VAL_5]], %[[VAL_5]] : f32
// CHECK: %[[VAL_7:.*]] = mulf %[[CST_0]], %[[VAL_6]] : f32
// CHECK: %[[VAL_8:.*]] = addf %[[VAL_7]], %[[CST_1]] : f32
// CHECK: %[[VAL_9:.*]] = mulf %[[VAL_8]], %[[VAL_6]] : f32
// CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_2]] : f32
// CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_6]] : f32
// CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_3]] : f32
// CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_6]] : f32
// CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_4]] : f32
// CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_6]] : f32
// CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_5]] : f32
// CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_6]] : f32
// CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_6]] : f32
// CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_6]] : f32
// CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_7]] : f32
// CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_6]] : f32
// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_5]] : f32
// CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32
// CHECK: %[[VAL_24:.*]] = subf %[[CST_8]], %[[VAL_23]] : f32
// CHECK: %[[VAL_25:.*]] = select %[[VAL_2]], %[[VAL_24]], %[[VAL_23]] : f32
// CHECK: %[[VAL_26:.*]] = cmpf "oeq", %arg0, %[[CST_9]] : f32
// CHECK: %[[VAL_27:.*]] = select %[[VAL_26]], %[[CST_9]], %[[VAL_25]] : f32
// CHECK: %[[VAL_28:.*]] = cmpf "uno", %arg0, %[[CST]] : f32
// CHECK: %[[VAL_29:.*]] = select %[[VAL_28]], %[[CST_10]], %[[VAL_27]] : f32
// CHECK: %[[VAL_30:.*]] = copysign %[[VAL_29]], %arg0 : f32
// CHECK: return %[[VAL_30]] : f32
%res = atan %arg : f32
return %res : f32
}
// -----
// CHECK-LABEL: @atan_f16
// CHECK-SAME: (%[[ARG:.*]]: f16) -> f16
func @atan_f16(%arg : f16) -> f16 {
// CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32
// CHECK: %[[CST_0:.*]] = constant 0.0027856871 : f32
// CHECK: %[[CST_1:.*]] = constant -1.586600e-02 : f32
// CHECK: %[[CST_2:.*]] = constant 0.042472221 : f32
// CHECK: %[[CST_3:.*]] = constant -0.0749753043 : f32
// CHECK: %[[CST_4:.*]] = constant 0.106448799 : f32
// CHECK: %[[CST_5:.*]] = constant -0.142070308 : f32
// CHECK: %[[CST_6:.*]] = constant 0.199934542 : f32
// CHECK: %[[CST_7:.*]] = constant -0.333331466 : f32
// CHECK: %[[CST_8:.*]] = constant 1.57079637 : f32
// CHECK: %[[CST_9:.*]] = constant 0.000000e+00 : f32
// CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32
// CHECK: %[[VAL_0:.*]] = fpext %arg0 : f16 to f32
// CHECK: %[[VAL_1:.*]] = absf %[[CST]] : f32
// CHECK: %[[VAL_2:.*]] = absf %[[VAL_0]] : f32
// CHECK: %[[VAL_3:.*]] = cmpf "ole", %[[VAL_1]], %[[VAL_2]] : f32
// CHECK: %[[VAL_4:.*]] = select %[[VAL_3]], %[[VAL_1]], %[[VAL_2]] : f32
// CHECK: %[[VAL_5:.*]] = select %[[VAL_3]], %[[VAL_2]], %[[VAL_1]] : f32
// CHECK: %[[VAL_6:.*]] = divf %[[VAL_4]], %[[VAL_5]] : f32
// CHECK: %[[VAL_7:.*]] = mulf %[[VAL_6]], %[[VAL_6]] : f32
// CHECK: %[[VAL_8:.*]] = mulf %[[CST_0]], %[[VAL_7]] : f32
// CHECK: %[[VAL_9:.*]] = addf %[[VAL_8]], %[[CST_1]] : f32
// CHECK: %[[VAL_10:.*]] = mulf %[[VAL_9]], %[[VAL_7]] : f32
// CHECK: %[[VAL_11:.*]] = addf %[[VAL_10]], %[[CST_2]] : f32
// CHECK: %[[VAL_12:.*]] = mulf %[[VAL_11]], %[[VAL_7]] : f32
// CHECK: %[[VAL_13:.*]] = addf %[[VAL_12]], %[[CST_3]] : f32
// CHECK: %[[VAL_14:.*]] = mulf %[[VAL_13]], %[[VAL_7]] : f32
// CHECK: %[[VAL_15:.*]] = addf %[[VAL_14]], %[[CST_4]] : f32
// CHECK: %[[VAL_16:.*]] = mulf %[[VAL_15]], %[[VAL_7]] : f32
// CHECK: %[[VAL_17:.*]] = addf %[[VAL_16]], %[[CST_5]] : f32
// CHECK: %[[VAL_18:.*]] = mulf %[[VAL_17]], %[[VAL_7]] : f32
// CHECK: %[[VAL_19:.*]] = addf %[[VAL_18]], %[[CST_6]] : f32
// CHECK: %[[VAL_20:.*]] = mulf %[[VAL_19]], %[[VAL_7]] : f32
// CHECK: %[[VAL_21:.*]] = addf %[[VAL_20]], %[[CST_7]] : f32
// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_7]] : f32
// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_6]] : f32
// CHECK: %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_6]] : f32
// CHECK: %[[VAL_25:.*]] = subf %[[CST_8]], %[[VAL_24]] : f32
// CHECK: %[[VAL_26:.*]] = select %[[VAL_3]], %[[VAL_25]], %[[VAL_24]] : f32
// CHECK: %[[VAL_27:.*]] = cmpf "oeq", %[[VAL_0]], %[[CST_9]] : f32
// CHECK: %[[VAL_28:.*]] = select %[[VAL_27]], %[[CST_9]], %[[VAL_26]] : f32
// CHECK: %[[VAL_29:.*]] = cmpf "uno", %[[VAL_0]], %[[CST]] : f32
// CHECK: %[[VAL_30:.*]] = select %[[VAL_29]], %[[CST_10]], %[[VAL_28]] : f32
// CHECK: %[[VAL_31:.*]] = copysign %[[VAL_30]], %[[VAL_0]] : f32
// CHECK: %[[VAL_32:.*]] = fptrunc %[[VAL_31]] : f32 to f16
// CHECK: return %[[VAL_32]] : f16
%res = atan %arg : f16
return %res : f16
}

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt %s -mhlo-test-chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s
// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -mhlo-test-lower-complex | FileCheck %s
// CHECK-LABEL: @add
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {

View File

@ -477,7 +477,6 @@ cc_library(
"transforms/default_quant_params.cc",
"transforms/generated_post_quantize.inc",
"transforms/generated_quantize.inc",
"transforms/load_quantization_recipe.cc",
"transforms/post_quantize.cc",
"transforms/prepare_quantize.cc",
"transforms/quantize.cc",
@ -670,6 +669,7 @@ cc_library(
"//tensorflow/lite/delegates/flex:allowlisted_flex_ops_lib",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/schema:schema_utils",
"//tensorflow/lite/tools/versioning",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
@ -706,6 +706,7 @@ cc_library(
"//tensorflow/core/platform:status",
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/schema:schema_utils",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",

View File

@ -75,6 +75,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/versioning/op_version.h"
#include "tensorflow/lite/tools/versioning/runtime_version.h"

View File

@ -75,6 +75,7 @@ limitations under the License.
#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
using llvm::ArrayRef;
using mlir::Builder;
@ -271,18 +272,18 @@ StatusOr<std::string> GetMlirOpName(const tflite::OperatorT& op,
return std::string("tfl.basic_lstm");
}
if (op_code.builtin_code == tflite::BuiltinOperator_CUSTOM) {
auto builtin_code = tflite::GetBuiltinCode(&op_code);
if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
return std::string("tfl.custom");
}
if (op_code.builtin_code == tflite::BuiltinOperator_IF) {
if (builtin_code == tflite::BuiltinOperator_IF) {
return std::string("tf.If");
}
if (op_code.builtin_code == tflite::BuiltinOperator_WHILE) {
if (builtin_code == tflite::BuiltinOperator_WHILE) {
return std::string("tf.While");
}
llvm::StringRef op_name(
tflite::EnumNameBuiltinOperator(op_code.builtin_code));
llvm::StringRef op_name(tflite::EnumNameBuiltinOperator(builtin_code));
return llvm::Twine("tfl.", op_name.lower()).str();
}
@ -637,7 +638,8 @@ StatusOr<Operation*> ConvertOp(
}
llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
if (op_code.builtin_code == tflite::BuiltinOperator_CUSTOM) {
auto builtin_code = tflite::GetBuiltinCode(&op_code);
if (builtin_code == tflite::BuiltinOperator_CUSTOM) {
auto status = mlir::CustomOptionsToAttributes(
op_code.custom_code, op.custom_options, builder, loc, &attrs);
if (!status.ok()) {

View File

@ -459,11 +459,13 @@ node {
# CHECK-LABEL: {
# CHECK: version: 3,
# CHECK: operator_codes: [ {
# CHECK: builtin_code: CONV_2D,
# CHECK: version: 3
# CHECK: deprecated_builtin_code: 3,
# CHECK: version: 3,
# CHECK: builtin_code: CONV_2D
# CHECK: }, {
# CHECK: builtin_code: RESHAPE,
# CHECK: deprecated_builtin_code: 22,
# CHECK: version: 1
# CHECK: builtin_code: RESHAPE
# CHECK: } ],
# CHECK: subgraphs: [ {
# CHECK: tensors: [ {

View File

@ -4,8 +4,9 @@ func @main(tensor<1x384xf32>, tensor<1x96xf32>, tensor<384x480xf32>, tensor<384x
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: LSTM,
// CHECK-NEXT: deprecated_builtin_code: 16,
// CHECK-NEXT: version: 2
// CHECK-NEXT: builtin_code: LSTM
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -6,14 +6,17 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: MUL,
// CHECK-NEXT: deprecated_builtin_code: 18,
// CHECK-NEXT: version: 1
// CHECK-NEXT: builtin_code: MUL
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "MyCustomOp"
// CHECK-NEXT: deprecated_builtin_code: 32,
// CHECK-NEXT: custom_code: "MyCustomOp",
// CHECK-NEXT: builtin_code: CUSTOM
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: EXP,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 47,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: EXP
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -5,11 +5,13 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: DEQUANTIZE,
// CHECK-NEXT: deprecated_builtin_code: 6,
// CHECK-NEXT: version: 1
// CHECK-NEXT: builtin_code: DEQUANTIZE
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D,
// CHECK-NEXT: deprecated_builtin_code: 4,
// CHECK-NEXT: version: 1
// CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -5,11 +5,13 @@ func @main(tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: DEQUANTIZE,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 6,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: DEQUANTIZE
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D,
// CHECK-NEXT: version: 2
// CHECK-NEXT: deprecated_builtin_code: 4,
// CHECK-NEXT: version: 2,
// CHECK-NEXT: builtin_code: DEPTHWISE_CONV_2D
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -5,11 +5,13 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: MUL,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 18,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: MUL
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: EXP,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 47,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: EXP
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -6,8 +6,9 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: FAKE_QUANT,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 80,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: FAKE_QUANT
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -4,8 +4,9 @@ func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: deprecated_builtin_code: 32,
// CHECK-NEXT: custom_code: "FlexAddV2"
// CHECK-NEXT: builtin_code: CUSTOM
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -5,8 +5,9 @@ func @main(tensor<4xcomplex<f64>>, tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "FlexAdd"
// CHECK-NEXT: deprecated_builtin_code: 32,
// CHECK-NEXT: custom_code: "FlexAdd",
// CHECK-NEXT: builtin_code: CUSTOM
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -5,8 +5,9 @@ func @main(tensor<4xf64>, tensor<4xf64>) -> tensor<4xf64> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "FlexAdd"
// CHECK-NEXT: deprecated_builtin_code: 32,
// CHECK-NEXT: custom_code: "FlexAdd",
// CHECK-NEXT: builtin_code: CUSTOM
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -5,14 +5,17 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: deprecated_builtin_code: 18,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: MUL
// CHECK-NEXT: version: 1
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "FlexDiv"
// CHECK-NEXT: deprecated_builtin_code: 32,
// CHECK-NEXT: custom_code: "FlexDiv",
// CHECK-NEXT: builtin_code: CUSTOM
// CHECK-NEXT: }, {
// CHECK-NEXT: deprecated_builtin_code: 47,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: EXP
// CHECK-NEXT: version: 1
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -5,8 +5,9 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: FULLY_CONNECTED,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 9,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: FULLY_CONNECTED
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -5,8 +5,9 @@ func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: FULLY_CONNECTED,
// CHECK-NEXT: version: 2
// CHECK-NEXT: deprecated_builtin_code: 9,
// CHECK-NEXT: version: 2,
// CHECK-NEXT: builtin_code: FULLY_CONNECTED
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -3,8 +3,9 @@
// CHECK: {
// CHECK: version: 3,
// CHECK: operator_codes: [ {
// CHECK: builtin_code: CUSTOM,
// CHECK: custom_code: "HashTableV2"
// CHECK: deprecated_builtin_code: 32,
// CHECK: custom_code: "HashTableV2",
// CHECK: builtin_code: CUSTOM
// CHECK: } ],
// CHECK: subgraphs: [ {
// CHECK: tensors: [ {

View File

@ -4,16 +4,19 @@
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: LESS,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 58,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: LESS
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: IF,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 118,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: IF
// CHECK-NEXT: }, {
// CHECK-NEXT: version: 1
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: MUL,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 18,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: MUL
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -5,11 +5,13 @@ func @main(tensor<4xi1>) -> tensor<4xi1> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: LOGICAL_OR,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 84,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: LOGICAL_OR
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: LOGICAL_AND,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 86,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: LOGICAL_AND
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -4,8 +4,9 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: LSTM,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 16,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: LSTM
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -7,8 +7,9 @@ func @main(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: LSTM,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 16,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: LSTM
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -5,20 +5,25 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: SQUARED_DIFFERENCE,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 99,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: SQUARED_DIFFERENCE
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: MUL,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 18,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: MUL
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: DIV,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 42,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: DIV
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: EXP,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 47,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: EXP
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: NEG,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 59,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: NEG
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -5,8 +5,9 @@ func @main(tensor<3x!quant.uniform<i8:f32, 0.1>>) -> tensor<3x!quant.uniform<i8:
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: MUL,
// CHECK-NEXT: version: 2
// CHECK-NEXT: deprecated_builtin_code: 18,
// CHECK-NEXT: version: 2,
// CHECK-NEXT: builtin_code: MUL
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -5,8 +5,9 @@ func @main(tensor<3x!quant.uniform<i8:f32, 1.0>>) -> tensor<3x!quant.uniform<i8:
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: MUL,
// CHECK-NEXT: version: 3
// CHECK-NEXT: deprecated_builtin_code: 18,
// CHECK-NEXT: version: 3,
// CHECK-NEXT: builtin_code: MUL
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -5,8 +5,9 @@ func @main(tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: AVERAGE_POOL_2D,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 1,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: AVERAGE_POOL_2D
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -3,8 +3,9 @@
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "NumericVerify"
// CHECK-NEXT: deprecated_builtin_code: 32,
// CHECK-NEXT: custom_code: "NumericVerify",
// CHECK-NEXT: builtin_code: CUSTOM
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -4,20 +4,25 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: QUANTIZE,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 114,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: QUANTIZE
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: CONV_2D,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 3,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: CONV_2D
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: RESHAPE,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 22,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: RESHAPE
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: SOFTMAX,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 25,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: SOFTMAX
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: DEQUANTIZE,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 6,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: DEQUANTIZE
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -5,8 +5,9 @@ func @main(tensor<3x2xi32>) -> tensor<6xi32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: RESHAPE,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 22,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: RESHAPE
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -7,8 +7,9 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32>
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: SUB,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 41,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: SUB
// CHECK-NEXT: }, {
// CHECK-NEXT: version: 1
// CHECK-NEXT: } ],

View File

@ -4,8 +4,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: SVDF,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 27,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: SVDF
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -4,8 +4,9 @@ func @main(tensor<4 x f32>, tensor<4 x i8>, tensor<4 x f32>, tensor<4 x f32>) ->
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: SVDF,
// CHECK-NEXT: version: 2
// CHECK-NEXT: deprecated_builtin_code: 27,
// CHECK-NEXT: version: 2,
// CHECK-NEXT: builtin_code: SVDF
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -3,14 +3,17 @@
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: WHILE,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 119,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: WHILE
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: GREATER,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 61,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: GREATER
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: SUB,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 41,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: SUB
// CHECK-NEXT: }, {
// CHECK-NEXT: version: 1
// CHECK-NEXT: } ],

View File

@ -4,8 +4,9 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: TRANSPOSE_CONV,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 67,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: TRANSPOSE_CONV
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -3,8 +3,9 @@
// CHECK: {
// CHECK: version: 3,
// CHECK: operator_codes: [ {
// CHECK: builtin_code: CUSTOM,
// CHECK: custom_code: "SomeOperation"
// CHECK: deprecated_builtin_code: 32,
// CHECK: custom_code: "SomeOperation",
// CHECK: builtin_code: CUSTOM
// CHECK: } ],
// CHECK: subgraphs: [ {
// CHECK: tensors: [ {

View File

@ -4,8 +4,9 @@ func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_LSTM,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 44,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_LSTM
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -4,8 +4,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_RNN,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 35,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_RNN
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {

View File

@ -3,14 +3,17 @@
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: WHILE,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 119,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: WHILE
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: GREATER,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 61,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: GREATER
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: SUB,
// CHECK-NEXT: version: 1
// CHECK-NEXT: deprecated_builtin_code: 41,
// CHECK-NEXT: version: 1,
// CHECK-NEXT: builtin_code: SUB
// CHECK-NEXT: }, {
// CHECK-NEXT: version: 1
// CHECK-NEXT: } ],

View File

@ -272,6 +272,22 @@ func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
// CHECK: return %[[RES]] : tensor<4x2xf32>
}
// CHECK-LABEL: @fuseBroadcastMulIntoFullyConnected
func @fuseBroadcastMulIntoFullyConnected(%arg0: tensor<1x10368xbf16>) -> tensor<32x1x256xbf16> {
%cst_0 = constant dense<2.0> : tensor<256x10368xbf16>
%cst_1 = constant unit
%cst_2 = constant dense<3.0> : tensor<32x1x256xbf16>
%0 = "tfl.fully_connected"(%arg0, %cst_0, %cst_1) {
fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"
} : (tensor<1x10368xbf16>, tensor<256x10368xbf16>, none) -> tensor<1x256xbf16>
%1 = "tfl.mul"(%0, %cst_2) {fused_activation_function = "NONE"} : (tensor<1x256xbf16>, tensor<32x1x256xbf16>) -> tensor<32x1x256xbf16>
return %1 : tensor<32x1x256xbf16>
// CHECK: %[[V0:.*]] = "tfl.fully_connected"(%arg0, {{.*}}) {{{.*}}} : (tensor<1x10368xbf16>, tensor<256x10368xbf16>, none) -> tensor<1x256xbf16>
// CHECK: %[[V1:.*]] = "tfl.mul"(%[[V0]], {{.*}}) {{{.*}}} : (tensor<1x256xbf16>, tensor<32x1x256xbf16>) -> tensor<32x1x256xbf16>
// CHECK: return %[[V1]] : tensor<32x1x256xbf16>
}
// CHECK-LABEL: @fuseAddIntoFollowingFullyConnectedWithQDQs
func @fuseAddIntoFollowingFullyConnectedWithQDQs(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {

View File

@ -139,6 +139,9 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
bool emit_select_tf_ops, bool emit_custom_ops,
const mlir::TFL::QuantizationSpecs& quant_specs, std::string* result,
mlir::PassManager* pass_manager) {
// Explicitly disable dumping Op details on failures.
module.getContext()->printOpOnDiagnostic(false);
// Register a warning handler only log to std out.
mlir::ScopedDiagnosticHandler s(
module.getContext(), [](mlir::Diagnostic& diag) {

View File

@ -416,6 +416,10 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
LogicalResult matchAndRewrite(TFL::MulOp mul_op,
PatternRewriter &rewriter) const override {
// If we are broadcasting on the lhs then don't fold the multiply as it
// would increase the amount of compute done by the fully connected op.
if (mul_op.lhs().getType() != mul_op.getType()) return failure();
// Mul.
DenseElementsAttr cst;
Value constant_val = mul_op.rhs();

View File

@ -74,8 +74,8 @@ tool_names = [
'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-mlir-gpu-opt', 'xla-opt',
'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_gpu_binary', 'xla-thunks-opt',
'tfjs-opt'
'hlo_to_llvm_ir', 'kernel-gen-opt', 'tf_to_kernel', 'tf_to_gpu_binary',
'xla-thunks-opt', 'tfjs-opt'
]
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
llvm_config.add_tool_substitutions(tools, tool_dirs)

View File

@ -1394,6 +1394,7 @@ cc_library(
":decode_constant_pass",
":eval_util",
":tensorflow",
":tensorflow_traits",
":tensorflow_types",
"//tensorflow/c:tf_status",
"//tensorflow/c/eager:c_api",
@ -1961,6 +1962,7 @@ cc_library(
deps = [
":convert_tensor",
":convert_type",
":export_tf_dialect_op",
":export_utils",
":tensorflow",
":tensorflow_attributes",

View File

@ -297,6 +297,33 @@ Equivalent to np.angle.
TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>;
}
def TF_AnonymousIteratorOp : TF_Op<"AnonymousIterator", []> {
let summary = "A container for an iterator resource.";
let arguments = (ins
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
);
let results = (outs
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle
);
}
def TF_AnonymousIteratorV2Op : TF_Op<"AnonymousIteratorV2", []> {
let summary = "A container for an iterator resource.";
let arguments = (ins
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
);
let results = (outs
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle,
TF_VariantTensor:$deleter
);
}
def TF_AnonymousMemoryCacheOp : TF_Op<"AnonymousMemoryCache", []> {
let summary = "";
@ -308,6 +335,21 @@ def TF_AnonymousMemoryCacheOp : TF_Op<"AnonymousMemoryCache", []> {
);
}
def TF_AnonymousMultiDeviceIteratorOp : TF_Op<"AnonymousMultiDeviceIterator", []> {
let summary = "A container for a multi device iterator resource.";
let arguments = (ins
Confined<StrArrayAttr, [ArrayMinCount<1>]>:$devices,
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
);
let results = (outs
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle,
TF_VariantTensor:$deleter
);
}
def TF_AnonymousRandomSeedGeneratorOp : TF_Op<"AnonymousRandomSeedGenerator", []> {
let summary = "";
@ -2485,6 +2527,17 @@ is the same, though it is cleaner to use `tf.io.decode_image`.
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
def TF_DeleteIteratorOp : TF_Op<"DeleteIterator", []> {
let summary = "A container for an iterator resource.";
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorFree]>:$handle,
TF_VariantTensor:$deleter
);
let results = (outs);
}
def TF_DeleteMemoryCacheOp : TF_Op<"DeleteMemoryCache", []> {
let summary = "";
@ -2496,6 +2549,20 @@ def TF_DeleteMemoryCacheOp : TF_Op<"DeleteMemoryCache", []> {
let results = (outs);
}
def TF_DeleteMultiDeviceIteratorOp : TF_Op<"DeleteMultiDeviceIterator", []> {
let summary = "A container for an iterator resource.";
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorFree]>:$multi_device_iterator,
Arg<Variadic<TF_ResourceTensor>, "", [TF_DatasetIteratorRead]>:$iterators,
TF_VariantTensor:$deleter
);
let results = (outs);
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<1>;
}
def TF_DeleteRandomSeedGeneratorOp : TF_Op<"DeleteRandomSeedGenerator", []> {
let summary = "";
@ -2719,6 +2786,19 @@ Computes the gradients of depthwise convolution with respect to the input.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
}
def TF_DeserializeIteratorOp : TF_Op<"DeserializeIterator", []> {
let summary = [{
Converts the given variant tensor to an iterator and stores it in the given resource.
}];
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorWrite]>:$resource_handle,
TF_VariantTensor:$serialized
);
let results = (outs);
}
def TF_DeviceIndexOp : TF_Op<"DeviceIndex", [NoSideEffect]> {
let summary = "Return the index of device the op runs.";
@ -4965,11 +5045,58 @@ tf.math.is_nan(x) ==> [False, True, False, True, False]
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_IteratorOp : TF_Op<"Iterator", []> {
let summary = "A container for an iterator resource.";
let arguments = (ins
StrAttr:$shared_name,
StrAttr:$container,
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
);
let results = (outs
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle
);
}
def TF_IteratorFromStringHandleOp : TF_Op<"IteratorFromStringHandle", []> {
let summary = [{
Converts the given string representing a handle to an iterator to a resource.
}];
let arguments = (ins
TF_StrTensor:$string_handle,
DefaultValuedAttr<TypeArrayAttr, "{}">:$output_types,
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes
);
let results = (outs
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$resource_handle
);
}
def TF_IteratorFromStringHandleV2Op : TF_Op<"IteratorFromStringHandleV2", []> {
let summary = "";
let arguments = (ins
TF_StrTensor:$string_handle,
DefaultValuedAttr<TypeArrayAttr, "{}">:$output_types,
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes
);
let results = (outs
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$resource_handle
);
}
def TF_IteratorGetNextOp : TF_Op<"IteratorGetNext", []> {
let summary = "Gets the next output from the given iterator .";
let arguments = (ins
TF_ResourceTensor:$iterator
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead, TF_DatasetIteratorWrite]>:$iterator
);
let results = (outs
@ -4980,6 +5107,74 @@ def TF_IteratorGetNextOp : TF_Op<"IteratorGetNext", []> {
TF_DerivedResultTypeListAttr output_types = TF_DerivedResultTypeListAttr<0>;
}
def TF_IteratorGetNextAsOptionalOp : TF_Op<"IteratorGetNextAsOptional", []> {
let summary = [{
Gets the next output from the given iterator as an Optional variant.
}];
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead, TF_DatasetIteratorWrite]>:$iterator,
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
);
let results = (outs
TF_VariantTensor:$optional
);
}
def TF_IteratorGetNextSyncOp : TF_Op<"IteratorGetNextSync", []> {
let summary = "Gets the next output from the given iterator.";
let description = [{
This operation is a synchronous version IteratorGetNext. It should only be used
in situations where the iterator does not block the calling thread, or where
the calling thread is not a member of the thread pool used to execute parallel
operations (e.g. in eager mode).
}];
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead, TF_DatasetIteratorWrite]>:$iterator
);
let results = (outs
Variadic<TF_Tensor>:$components
);
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
TF_DerivedResultTypeListAttr output_types = TF_DerivedResultTypeListAttr<0>;
}
def TF_IteratorToStringHandleOp : TF_Op<"IteratorToStringHandle", []> {
let summary = [{
Converts the given `resource_handle` representing an iterator to a string.
}];
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead]>:$resource_handle
);
let results = (outs
TF_StrTensor:$string_handle
);
}
def TF_IteratorV2Op : TF_Op<"IteratorV2", []> {
let summary = "";
let arguments = (ins
StrAttr:$shared_name,
StrAttr:$container,
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
);
let results = (outs
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle
);
}
def TF_L2LossOp : TF_Op<"L2Loss", [NoSideEffect]> {
let summary = "L2 Loss.";
@ -5586,6 +5781,24 @@ A 2-D example:
TF_DerivedResultTypeAttr out_type = TF_DerivedResultTypeAttr<0>;
}
def TF_MakeIteratorOp : TF_Op<"MakeIterator", []> {
let summary = [{
Makes a new iterator from the given `dataset` and stores it in `iterator`.
}];
let description = [{
This operation may be executed multiple times. Each execution will reset the
iterator in `iterator` to the first element of `dataset`.
}];
let arguments = (ins
TF_VariantTensor:$dataset,
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorWrite]>:$iterator
);
let results = (outs);
}
def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> {
let summary = [{
Multiply the matrix "a" by the matrix "b".
@ -6909,6 +7122,82 @@ Returns x * y element-wise. Returns zero if y is zero, even if x if infinite or
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_MultiDeviceIteratorOp : TF_Op<"MultiDeviceIterator", []> {
let summary = "Creates a MultiDeviceIterator resource.";
let arguments = (ins
Confined<StrArrayAttr, [ArrayMinCount<1>]>:$devices,
StrAttr:$shared_name,
StrAttr:$container,
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes
);
let results = (outs
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle
);
}
def TF_MultiDeviceIteratorFromStringHandleOp : TF_Op<"MultiDeviceIteratorFromStringHandle", []> {
let summary = [{
Generates a MultiDeviceIterator resource from its provided string handle.
}];
let arguments = (ins
TF_StrTensor:$string_handle,
DefaultValuedAttr<TypeArrayAttr, "{}">:$output_types,
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes
);
let results = (outs
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$multi_device_iterator
);
}
def TF_MultiDeviceIteratorGetNextFromShardOp : TF_Op<"MultiDeviceIteratorGetNextFromShard", []> {
let summary = "Gets next element for the provided shard number.";
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead, TF_DatasetIteratorWrite]>:$multi_device_iterator,
TF_Int32Tensor:$shard_num,
TF_Int64Tensor:$incarnation_id
);
let results = (outs
Variadic<TF_Tensor>:$components
);
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
TF_DerivedResultTypeListAttr output_types = TF_DerivedResultTypeListAttr<0>;
}
def TF_MultiDeviceIteratorInitOp : TF_Op<"MultiDeviceIteratorInit", []> {
let summary = "Initializes the multi device iterator with the given dataset.";
let arguments = (ins
TF_VariantTensor:$dataset,
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorWrite]>:$multi_device_iterator,
TF_Int64Tensor:$max_buffer_size
);
let results = (outs
TF_Int64Tensor:$incarnation_id
);
}
def TF_MultiDeviceIteratorToStringHandleOp : TF_Op<"MultiDeviceIteratorToStringHandle", []> {
let summary = "Produces a string handle for the given MultiDeviceIterator.";
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead]>:$multi_device_iterator
);
let results = (outs
TF_StrTensor:$string_handle
);
}
def TF_MultinomialOp : TF_Op<"Multinomial", [TF_CannotDuplicate]> {
let summary = "Draws samples from a multinomial distribution.";
@ -7363,6 +7652,44 @@ output =
}];
}
def TF_OneShotIteratorOp : TF_Op<"OneShotIterator", []> {
let summary = [{
Makes a "one-shot" iterator that can be iterated only once.
}];
let description = [{
A one-shot iterator bundles the logic for defining the dataset and
the state of the iterator in a single op, which allows simple input
pipelines to be defined without an additional initialization
("MakeIterator") step.
One-shot iterators have the following limitations:
* They do not support parameterization: all logic for creating the underlying
dataset must be bundled in the `dataset_factory` function.
* They are not resettable. Once a one-shot iterator reaches the end of its
underlying dataset, subsequent "IteratorGetNext" operations on that
iterator will always produce an `OutOfRange` error.
For greater flexibility, use "Iterator" and "MakeIterator" to define
an iterator using an arbitrary subgraph, which may capture tensors
(including fed values) as parameters, and which may be reset multiple
times by rerunning "MakeIterator".
}];
let arguments = (ins
SymbolRefAttr:$dataset_factory,
Confined<TypeArrayAttr, [ArrayMinCount<1>]>:$output_types,
Confined<TF_ShapeAttrArray, [ArrayMinCount<1>]>:$output_shapes,
StrAttr:$container,
StrAttr:$shared_name
);
let results = (outs
Res<TF_ResourceTensor, "", [TF_DatasetIteratorAlloc]>:$handle
);
}
def TF_OutfeedEnqueueTupleOp : TF_Op<"OutfeedEnqueueTuple", []> {
let summary = "Enqueue multiple Tensor values on the computation outfeed.";
@ -10353,6 +10680,22 @@ Computes gradients for the scaled exponential linear (Selu) operation.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_SerializeIteratorOp : TF_Op<"SerializeIterator", []> {
let summary = [{
Converts the given `resource_handle` representing an iterator to a variant tensor.
}];
let arguments = (ins
Arg<TF_ResourceTensor, "", [TF_DatasetIteratorRead]>:$resource_handle,
DefaultValuedAttr<I64Attr, "0">:$external_state_policy
);
let results = (outs
TF_VariantTensor:$serialized
);
}
def TF_ShapeOp : TF_Op<"Shape", [NoSideEffect]> {
let summary = "Returns the shape of a tensor.";
@ -11409,7 +11752,7 @@ def TF_StackV2Op : TF_Op<"StackV2", []> {
);
}
def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect]> {
def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect, TF_NoConstantFold]> {
let summary = "Draws samples from a multinomial distribution.";
let arguments = (ins
@ -11427,7 +11770,82 @@ def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect]> {
TF_DerivedResultTypeAttr output_dtype = TF_DerivedResultTypeAttr<0>;
}
def TF_StatelessRandomNormalOp : TF_Op<"StatelessRandomNormal", [NoSideEffect]> {
def TF_StatelessParameterizedTruncatedNormalOp : TF_Op<"StatelessParameterizedTruncatedNormal", [NoSideEffect, TF_NoConstantFold]> {
let summary = "";
let arguments = (ins
TF_I32OrI64Tensor:$shape,
TF_I32OrI64Tensor:$seed,
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$means,
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$stddevs,
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$minvals,
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$maxvals
);
let results = (outs
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$output
);
TF_DerivedOperandTypeAttr S = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>;
}
def TF_StatelessRandomBinomialOp : TF_Op<"StatelessRandomBinomial", [NoSideEffect, TF_NoConstantFold]> {
let summary = [{
Outputs deterministic pseudorandom random numbers from a binomial distribution.
}];
let description = [{
Outputs random values from a binomial distribution.
The outputs are a deterministic function of `shape`, `seed`, `counts`, and `probs`.
}];
let arguments = (ins
TF_I32OrI64Tensor:$shape,
TF_I32OrI64Tensor:$seed,
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$counts,
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$probs
);
let results = (outs
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$output
);
TF_DerivedOperandTypeAttr S = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>;
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
def TF_StatelessRandomGammaV2Op : TF_Op<"StatelessRandomGammaV2", [NoSideEffect, TF_NoConstantFold]> {
let summary = [{
Outputs deterministic pseudorandom random numbers from a gamma distribution.
}];
let description = [{
Outputs random values from a gamma distribution.
The outputs are a deterministic function of `shape`, `seed`, and `alpha`.
}];
let arguments = (ins
TF_I32OrI64Tensor:$shape,
TF_I32OrI64Tensor:$seed,
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$alpha
);
let results = (outs
TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>;
}
def TF_StatelessRandomNormalOp : TF_Op<"StatelessRandomNormal", [NoSideEffect, TF_NoConstantFold]> {
let summary = [{
Outputs deterministic pseudorandom values from a normal distribution.
}];
@ -11452,7 +11870,34 @@ The outputs are a deterministic function of `shape` and `seed`.
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
def TF_StatelessRandomUniformOp : TF_Op<"StatelessRandomUniform", [NoSideEffect]> {
def TF_StatelessRandomPoissonOp : TF_Op<"StatelessRandomPoisson", [NoSideEffect, TF_NoConstantFold]> {
let summary = [{
Outputs deterministic pseudorandom random numbers from a Poisson distribution.
}];
let description = [{
Outputs random values from a Poisson distribution.
The outputs are a deterministic function of `shape`, `seed`, and `lam`.
}];
let arguments = (ins
TF_I32OrI64Tensor:$shape,
TF_I32OrI64Tensor:$seed,
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$lam
);
let results = (outs
TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
TF_DerivedOperandTypeAttr Rtype = TF_DerivedOperandTypeAttr<2>;
}
def TF_StatelessRandomUniformOp : TF_Op<"StatelessRandomUniform", [NoSideEffect, TF_NoConstantFold]> {
let summary = [{
Outputs deterministic pseudorandom random values from a uniform distribution.
}];
@ -11478,7 +11923,32 @@ The outputs are a deterministic function of `shape` and `seed`.
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect]> {
def TF_StatelessRandomUniformFullIntOp : TF_Op<"StatelessRandomUniformFullInt", [NoSideEffect, TF_NoConstantFold]> {
let summary = [{
Outputs deterministic pseudorandom random integers from a uniform distribution.
}];
let description = [{
The generated values are uniform integers covering the whole range of `dtype`.
The outputs are a deterministic function of `shape` and `seed`.
}];
let arguments = (ins
TF_I32OrI64Tensor:$shape,
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$seed
);
let results = (outs
TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>;
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}
def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect, TF_NoConstantFold]> {
let summary = [{
Outputs deterministic pseudorandom random integers from a uniform distribution.
}];
@ -11505,7 +11975,7 @@ The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxv
TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>;
}
def TF_StatelessTruncatedNormalOp : TF_Op<"StatelessTruncatedNormal", [NoSideEffect]> {
def TF_StatelessTruncatedNormalOp : TF_Op<"StatelessTruncatedNormal", [NoSideEffect, TF_NoConstantFold]> {
let summary = [{
Outputs deterministic pseudorandom values from a truncated normal distribution.
}];

View File

@ -73,6 +73,9 @@ def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;
// certain state around within their implementations.
def TF_CannotDuplicate : NativeOpTrait<"TF::CannotDuplicate">;
// Trait to indicate an operation cannot be constant folded.
def TF_NoConstantFold : NativeOpTrait<"TF::NoConstantFold">;
// Coefficient wise binary operation with implicit broadcasting support, for
// example tf.Sub operation.
def TF_CwiseBinary : NativeOpTrait<"TF::CwiseBinary">;
@ -112,6 +115,7 @@ def TF_SummaryResource : TF_ResourceBase<"Summary">;
def TF_LookupTableResource : TF_ResourceBase<"LookupTable">;
def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">;
def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">;
def TF_DatasetIteratorResource : TF_ResourceBase<"DatasetIterator">;
def TF_VariableRead : MemRead<TF_VariableResource>;
def TF_StackRead : MemRead<TF_StackResource>;
@ -119,6 +123,7 @@ def TF_TensorArrayRead : MemRead<TF_TensorArrayResource>;
def TF_LookupTableRead : MemRead<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorRead : MemRead<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheRead : MemRead<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorRead : MemRead<TF_DatasetIteratorResource>;
def TF_VariableWrite : MemWrite<TF_VariableResource>;
def TF_StackWrite : MemWrite<TF_StackResource>;
@ -127,6 +132,7 @@ def TF_SummaryWrite : MemWrite<TF_SummaryResource>;
def TF_LookupTableWrite : MemWrite<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorWrite : MemWrite<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheWrite : MemWrite<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorWrite : MemWrite<TF_DatasetIteratorResource>;
def TF_VariableAlloc : MemAlloc<TF_VariableResource>;
def TF_StackAlloc : MemAlloc<TF_StackResource>;
@ -135,12 +141,14 @@ def TF_SummaryAlloc : MemAlloc<TF_SummaryResource>;
def TF_LookupTableAlloc : MemAlloc<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorAlloc : MemAlloc<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheAlloc : MemAlloc<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorAlloc : MemAlloc<TF_DatasetIteratorResource>;
def TF_StackFree : MemFree<TF_StackResource>;
def TF_TensorArrayFree : MemFree<TF_TensorArrayResource>;
def TF_SummaryFree : MemFree<TF_SummaryResource>;
def TF_DatasetSeedGeneratorFree : MemFree<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheFree : MemFree<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorFree : MemFree<TF_DatasetIteratorResource>;
//===----------------------------------------------------------------------===//
// TensorFlow op definitions

View File

@ -1446,7 +1446,8 @@ static LogicalResult Verify(SplitVOp op) {
if (!split_sizes_type) return success();
if (split_sizes_type.getRank() != 1 ||
split_sizes_type.getDimSize(0) != op.getNumResults())
(split_sizes_type.getDimSize(0) != ShapedType::kDynamicSize &&
split_sizes_type.getDimSize(0) != op.getNumResults()))
return op.emitOpError("split sizes should be a 1D tensor of ")
<< op.getNumResults() << " elements";

View File

@ -53,6 +53,10 @@ struct DatasetMemoryCache
StringRef getName() final { return "DatasetMemoryCache"; }
};
struct DatasetIterator : ::mlir::SideEffects::Resource::Base<DatasetIterator> {
StringRef getName() final { return "DatasetIterator"; }
};
} // namespace ResourceEffects
} // namespace TF
} // namespace mlir

View File

@ -124,6 +124,10 @@ class CannotDuplicate : public TraitBase<ConcreteType, CannotDuplicate> {
}
};
// Trait to indicate an operation cannot be constant folded.
template <typename ConcreteType>
class NoConstantFold : public TraitBase<ConcreteType, NoConstantFold> {};
// Coefficient-wise binary operation with implicit broadcasting support, for
// example tf.Sub operation.
template <typename ConcreteType>

View File

@ -502,3 +502,12 @@ func @fold_conv() -> tensor<1x520x520x1xf32> {
// CHECK: tf.Const
// CHECK-NOT: tf.DepthwiseConv2dNative
}
// CHECK-LABEL: DontFoldNoConstantFold
func @DontFoldNoConstantFold() -> tensor<8xf32> {
%0 = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<[2, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: tf.StatelessRandomUniform
%2 = "tf.StatelessRandomUniform"(%0, %1) : (tensor<1xi32>, tensor<2xi32>) -> tensor<8xf32>
return %2 : tensor<8xf32>
}

View File

@ -138,17 +138,17 @@ func @op_string_operand_string_result(%arg0: tensor<!tf.string>) -> tensor<i32>
return %0 : tensor<i32>
}
// Test that a tf.IfRegion op with a captured string operand is marked for outside compilation.
// Test that operations inside tf.IfRegion op are corrected marked for outside
// compilation.
// CHECK-LABEL: func @if_region_captured_string
func @if_region_captured_string(%arg0: tensor<i1>, %arg1: tensor<!tf.string>) -> tensor<f32> {
// CHECK-LABEL: func @ops_inside_tf_if_outside_compiled
func @ops_inside_tf_if_outside_compiled(%arg0: tensor<i1>, %arg1: tensor<!tf.string>) -> tensor<f32> {
%0 = "tf_device.cluster"() ( {
// CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK-NOT: _xla_outside_compilation
// CHECK: "tf.IfRegion"
// CHECK: "tf.StringToNumber"
// CHECK-NOT: _xla_outside_compilation
// CHECK: _xla_outside_compilation = "auto1", is_stateless = true
// CHECK: "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK-NOT: _xla_outside_compilation
// CHECK: "tf.IfRegion"
// CHECK: "tf.StringToNumber"
// CHECK-SAME: _xla_outside_compilation
%1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%2 = "tf.IfRegion"(%arg0) ( {
%3 = "tf.StringToNumber"(%arg1) {out_type = f32} : (tensor<!tf.string>) -> tensor<f32>
@ -163,7 +163,8 @@ func @if_region_captured_string(%arg0: tensor<i1>, %arg1: tensor<!tf.string>) ->
return %0 : tensor<f32>
}
// Test that ops with string results/operands inside a tf.IfRegion branch are marked for outside compilation.
// Test that ops with string results/operands inside a tf.IfRegion branch are
// marked for outside compilation.
// CHECK-LABEL: func @if_region_string_op
func @if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<f32> {
@ -191,7 +192,8 @@ func @if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<f32
return %0 : tensor<f32>
}
// Test that ops with string results/operands inside a nested tf.IfRegion branch are marked for outside compilation.
// Test that ops with string results/operands inside a nested tf.IfRegion branch
// are marked for outside compilation.
// CHECK-LABEL: func @nested_if_region_string_op
func @nested_if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> tensor<f32> {
@ -231,16 +233,17 @@ func @nested_if_region_string_op(%arg0: tensor<i1>, %arg1: tensor<?xi32>) -> ten
return %0 : tensor<f32>
}
// Test that a tf.WhileRegion op with a captured string operand is marked for outside compilation.
// Test that ops inside tf.WhileRegion op are correct marked for outside
// compilation.
// CHECK-LABEL: func @while_region_captured_string
func @while_region_captured_string(%arg0: tensor<i32>, %arg1: tensor<!tf.string>) -> tensor<f32> {
// CHECK-LABEL: func @ops_inside_while_outside_compiled
func @ops_inside_while_outside_compiled(%arg0: tensor<i32>, %arg1: tensor<!tf.string>) -> tensor<f32> {
%0 = "tf_device.cluster"() ( {
// CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
// CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
// CHECK-NOT: _xla_outside_compilation
// CHECK: "tf.WhileRegion"
// CHECK: "tf.StringToNumber"
// CHECK: _xla_outside_compilation = "auto1", is_stateless = true
// CHECK: "tf.WhileRegion"
// CHECK: "tf.StringToNumber"
// CHECK-SAME: _xla_outside_compilation
%1 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
%2:2 = "tf.WhileRegion"(%1, %arg0) ( {
^bb0(%carg0: tensor<f32>, %carg1: tensor<i32>):

Some files were not shown because too many files have changed in this diff Show More