Merge remote-tracking branch origin/upstream/master

This commit is contained in:
Måns Nilsson 2020-11-25 15:23:56 +01:00
commit 477e3022a8
1243 changed files with 27511 additions and 21386 deletions

View File

@ -1,7 +1,9 @@
If you open a GitHub Issue, here is our policy: 1. It must be a bug/performance
issue or a feature request or a build issue or a documentation issue (for small
doc fixes please send a PR instead). 2. Make sure the Issue Template is filled
out. 3. The issue should be related to the repo it is created in.
If you open a GitHub Issue, here is our policy:
1. It must be a bug/performance issue or a feature request or a build issue or
a documentation issue (for small doc fixes please send a PR instead).
1. Make sure the Issue Template is filled out.
1. The issue should be related to the repo it is created in.
**Here's why we have this policy:** We want to focus on the work that benefits
the whole community, e.g., fixing bugs and adding features. Individual support

View File

@ -49,11 +49,13 @@
* Added int16x8 support for ABS, REDUCE_MAX and REDUCE_MIN operators.
* Added support for saved model's session initializer through
`TFLiteConverter.from_saved_model`.
* Added dynamic range quantization support for the BatchMatMul op.
* TF Core:
* Corrected higher-order gradients of control flow constructs (`tf.cond`,
`tf.while_loop`, and compositions like `tf.foldl`) computed with
`tf.GradientTape` inside a `tf.function`.
* Changed the default step size in `gradient_checker_v2.compute_gradients` to be exactly representable as a binary floating point number. This avoids polluting gradient approximations needlessly, which in some cases leads to false negatives in op gradient tests (see the short sketch after this list).
* `tf.summary`:
* New `tf.summary.graph` allows manual write of TensorFlow graph
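A minimal C++ sketch, added here only to illustrate the step-size point above (it is not part of the release notes): a decimal step such as 1e-3 is rounded when stored as a binary float, while a power of two such as 2^-10 is stored exactly and therefore adds no rounding noise of its own to finite-difference gradient approximations.

#include <cstdio>

int main() {
  // 1e-3 cannot be represented exactly in binary floating point; the stored
  // value differs slightly from the intended step.
  double decimal_step = 1e-3;
  // 1/1024 = 2^-10 is a power of two and is therefore exactly representable.
  double binary_step = 1.0 / 1024.0;
  std::printf("decimal step: %.20f\n", decimal_step);
  std::printf("binary  step: %.20f\n", binary_step);
  return 0;
}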
@ -65,6 +67,19 @@
supported MSVC version to 16.4 (current: 16.8).
* See: https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
* TensorRT
* Removed the deprecated `session_config` parameter for the TF1-TRT
converter `TrtGraphConverter`. Previously, we issued a warning when the
value of the parameter was not None.
* The TF2-TRT converter `TrtGraphConverterV2` takes an object of class
`TrtConversionParams` as a parameter. Removed three deprecated fields from
this class: `rewriter_config_template`, `is_dynamic_op`, and
`max_batch_size`. Previously, we issued a warning when the value of
`rewriter_config_template` was not None, and an error when the value of
`is_dynamic_op` was not True. We did not use the value of
`max_batch_size` when building TensorRT engines.
* Issue a warning when the function `get_tensorrt_rewriter_config` is used.
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:

View File

@ -55,16 +55,16 @@ NCCL_LIB_PATHS = [
# List of files to configure when building Bazel on Apple platforms.
APPLE_BAZEL_FILES = [
'tensorflow/lite/experimental/ios/BUILD',
'tensorflow/lite/experimental/objc/BUILD',
'tensorflow/lite/experimental/swift/BUILD',
'tensorflow/lite/ios/BUILD',
'tensorflow/lite/objc/BUILD',
'tensorflow/lite/swift/BUILD',
'tensorflow/lite/tools/benchmark/experimental/ios/BUILD'
]
# List of files to move when building for iOS.
IOS_FILES = [
'tensorflow/lite/experimental/objc/TensorFlowLiteObjC.podspec',
'tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec',
'tensorflow/lite/objc/TensorFlowLiteObjC.podspec',
'tensorflow/lite/swift/TensorFlowLiteSwift.podspec',
]

View File

@ -72,6 +72,14 @@ config_setting(
visibility = ["//visibility:public"],
)
# Config setting that disables the default logger, only logging
# to registered TFLogSinks
config_setting(
name = "no_default_logger",
define_values = {"no_default_logger": "true"},
visibility = ["//visibility:public"],
)
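For reference (standard Bazel behavior, not part of this change): a config_setting keyed on define_values like the one above is selected by passing the matching --define flag on the command line, for example:

bazel build --define=no_default_logger=true //tensorflow/...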
# Config setting for determining if we are building for Android.
config_setting(
name = "android",
@ -588,9 +596,11 @@ config_setting(
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
# Instead, please use public APIs or public build rules TF provides.
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
# TODO(b/173549186): Move Google-internal TF code out of learning/brain
package_group(
name = "internal",
packages = [
"//learning/brain/mlir/...",
"//learning/lib/ami/simple_ml/...",
"//tensorflow/...",
],
@ -730,6 +740,7 @@ tf_cc_shared_object(
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
"//tensorflow/c:kernels_hdrs",
"//tensorflow/c:logging",
"//tensorflow/c:ops_hdrs",
"//tensorflow/cc/saved_model:loader_lite_impl",
"//tensorflow/core/common_runtime:core_cpu_impl",

View File

@ -522,6 +522,7 @@ cc_library(
":tf_datatype",
":tf_status",
":tf_tensor",
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
],
)
@ -542,13 +543,17 @@ tf_cuda_library(
] + select({
"//tensorflow:android": [
":c_api_internal",
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":c_api_internal",
":tf_tensor",
"//tensorflow/stream_executor:stream",
"//tensorflow/core:framework",
"//tensorflow/core:framework_lite",
"//tensorflow/c/experimental/stream_executor:stream_executor",
"//tensorflow/c/experimental/stream_executor:stream_executor_internal",
],
}),
)

View File

@ -51,6 +51,7 @@ tf_cuda_library(
":immediate_execution_context",
":immediate_execution_operation",
":immediate_execution_tensor_handle",
":immediate_execution_distributed_manager",
":abstract_tensor_handle",
":tfe_context_internal",
":tfe_cancellation_manager_internal",
@ -70,6 +71,7 @@ tf_cuda_library(
"//tensorflow/core:core_cpu",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:context_distributed_manager",
"//tensorflow/core/common_runtime/eager:core",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:execute",
@ -119,6 +121,7 @@ filegroup(
"gradients.h",
"gradients_internal.h",
"immediate_execution_context.h",
"immediate_execution_distributed_manager.h",
"immediate_execution_operation.h",
"immediate_execution_tensor_handle.h",
"tape.h",
@ -298,7 +301,7 @@ tf_cuda_cc_test(
],
args = ["--heap_check=local"],
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags(),
tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156
deps = [
":c_api_experimental",
":c_api_unified_internal",
@ -466,6 +469,7 @@ tf_cuda_cc_test(
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + [
"nomac",
"no_cuda_asan", # b/173825513
],
deps = [
":abstract_tensor_handle",
@ -584,6 +588,19 @@ cc_library(
],
)
cc_library(
name = "immediate_execution_distributed_manager",
hdrs = ["immediate_execution_distributed_manager.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "immediate_execution_context",
hdrs = ["immediate_execution_context.h"],
@ -592,12 +609,14 @@ cc_library(
],
deps = [
":abstract_context",
":immediate_execution_distributed_manager",
":immediate_execution_operation",
":immediate_execution_tensor_handle",
"//tensorflow/c:tensor_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],

View File

@ -21,16 +21,11 @@ limitations under the License.
#include <string>
#include <vector>
#include "tensorflow/c/eager/abstract_tensor_handle.h"
// clang-format off
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
@ -39,59 +34,40 @@ limitations under the License.
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_tensor_internal.h"
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
#endif
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#endif // !IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/public/version.h"
// "tensorflow/core/platform/platform.h" must be included first before using
// PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
#include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed.h"
#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
#endif // !IS_MOBILE_PLATFORM
using tensorflow::string;
namespace {
@ -100,611 +76,6 @@ string DeviceName(const tensorflow::Device* d) {
return (d == nullptr) ? "cpu:0" : d->name();
}
#if !defined(IS_MOBILE_PLATFORM)
bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
const tensorflow::ServerDef& server_def) {
if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
return false;
}
return server_def.default_session_config().SerializeAsString() ==
context->session_options().config.SerializeAsString();
}
tensorflow::Status AddRemoteDevicesToMgr(
const std::vector<string>& added_remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
tensorflow::DynamicDeviceMgr* remote_device_mgr) {
std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
tensorflow::mutex remote_devices_mu;
int num_added_workers = added_remote_workers.size();
tensorflow::BlockingCounter counter(num_added_workers);
std::vector<tensorflow::Status> statuses(num_added_workers);
for (int i = 0; i < num_added_workers; i++) {
tensorflow::NewRemoteDevices(
tensorflow::Env::Default(), worker_cache, added_remote_workers[i],
[i, &statuses, &counter, &remote_devices, &remote_devices_mu](
const tensorflow::Status& s,
std::vector<tensorflow::Device*>* devices) {
statuses[i] = s;
if (s.ok()) {
tensorflow::mutex_lock l(remote_devices_mu);
for (tensorflow::Device* d : *devices) {
remote_devices.emplace_back(d);
}
}
counter.DecrementCount();
});
}
counter.Wait();
for (int i = 0; i < num_added_workers; i++) {
TF_RETURN_IF_ERROR(statuses[i]);
}
TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices)));
return tensorflow::Status::OK();
}
tensorflow::Status GetAllRemoteDevices(
const std::vector<string>& remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
std::unique_ptr<tensorflow::DynamicDeviceMgr>* device_mgr) {
auto remote_device_mgr = absl::make_unique<tensorflow::DynamicDeviceMgr>();
TF_RETURN_IF_ERROR(AddRemoteDevicesToMgr(remote_workers, worker_cache,
remote_device_mgr.get()));
*device_mgr = std::move(remote_device_mgr);
return tensorflow::Status::OK();
}
tensorflow::Status RemoveRemoteDevicesFromMgr(
const std::vector<string>& removed_remote_workers,
tensorflow::DynamicDeviceMgr* remote_device_mgr) {
const std::vector<tensorflow::Device*> remote_devices =
(remote_device_mgr->ListDevices());
std::vector<tensorflow::Device*> devices_to_remove;
for (tensorflow::Device* d : remote_devices) {
for (const string& remote_worker : removed_remote_workers) {
if (tensorflow::DeviceNameUtils::IsSameAddressSpace(remote_worker,
d->name())) {
devices_to_remove.emplace_back(d);
break;
}
}
}
TF_RETURN_IF_ERROR(remote_device_mgr->RemoveDevices(devices_to_remove));
return tensorflow::Status::OK();
}
tensorflow::Status ListRemoteWorkers(tensorflow::ServerInterface* server,
const string& local_worker,
std::vector<string>* remote_workers) {
tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(server);
if (grpc_server == nullptr) {
return tensorflow::errors::Internal(
"Currently, TFE_NewContext only supports tensorflow::GrpcServer.");
}
grpc_server->master_env()->worker_cache->ListWorkers(remote_workers);
remote_workers->erase(
std::remove(remote_workers->begin(), remote_workers->end(), local_worker),
remote_workers->end());
return tensorflow::Status::OK();
}
void DifferentiateWorkerLists(const std::vector<string>* current_list,
const std::vector<string>* new_list,
std::vector<string>* added,
std::vector<string>* removed,
std::vector<string>* existing) {
// Get STL set_difference and set_intersection with one list traversal.
// Similar to the set_difference library function, the input lists
// (`current_list` and `new_list`) must be sorted before calling the function.
added->resize(new_list->size());
removed->resize(current_list->size());
existing->resize(current_list->size());
std::vector<string>::const_iterator curr_it = current_list->begin();
std::vector<string>::const_iterator new_it = new_list->begin();
std::vector<string>::iterator added_it = added->begin();
std::vector<string>::iterator removed_it = removed->begin();
std::vector<string>::iterator existing_it = existing->begin();
while (curr_it != current_list->end() && new_it != new_list->end()) {
if (*curr_it < *new_it) {
*removed_it++ = *curr_it++;
} else if (*curr_it > *new_it) {
*added_it++ = *new_it++;
} else {
*existing_it++ = *curr_it++;
new_it++;
}
}
removed_it = std::copy(curr_it, current_list->end(), removed_it);
added_it = std::copy(new_it, new_list->end(), added_it);
added->resize(added_it - added->begin());
removed->resize(removed_it - removed->begin());
existing->resize(existing_it - existing->begin());
}
tensorflow::Status GetReplacedFromExistingWorkers(
const std::vector<string>* existing_workers, tensorflow::uint64 context_id,
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* client_cache,
std::vector<string>* replaced_workers) {
tensorflow::BlockingCounter counter(existing_workers->size());
std::vector<tensorflow::Status> statuses(existing_workers->size());
tensorflow::eager::KeepAliveRequest request;
request.set_context_id(context_id);
std::vector<tensorflow::eager::KeepAliveResponse> responses(
existing_workers->size());
for (int i = 0; i < existing_workers->size(); i++) {
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
statuses[i] =
client_cache->GetClient(existing_workers->at(i), &eager_client);
if (!statuses[i].ok()) {
counter.DecrementCount();
continue;
}
eager_client->KeepAliveAsync(
&request, &responses[i],
[i, &statuses, &counter](const tensorflow::Status& s) {
statuses[i] = s;
counter.DecrementCount();
});
}
counter.Wait();
for (int i = 0; i < existing_workers->size(); i++) {
// If the RPC fails (indicating that the requested ID doesn't exist on
// remote), or the returned view ID is not equal to the local one
// (indicating that the remote worker has a stale view of cluster), treat
// the worker as replaced.
if (!statuses[i].ok() ||
responses[i].context_view_id() != context_view_id) {
replaced_workers->emplace_back(existing_workers->at(i));
}
}
return tensorflow::Status::OK();
}
tensorflow::Status CreateRemoteContexts(
TFE_Context* ctx, const std::vector<string>& remote_workers,
tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
int keep_alive_secs, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
const bool lazy_copy_remote_function_inputs,
const tensorflow::eager::CreateContextRequest& base_request) {
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
std::vector<tensorflow::Status> statuses(num_remote_workers);
for (int i = 0; i < num_remote_workers; i++) {
const string& remote_worker = remote_workers[i];
tensorflow::DeviceNameUtils::ParsedName parsed_name;
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
&parsed_name)) {
statuses[i] = tensorflow::errors::InvalidArgument(
"Unable to parse ", remote_worker, " as a device name");
counter.DecrementCount();
continue;
}
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
if (eager_client == nullptr) {
statuses[i] = tensorflow::errors::Internal(
"Cannot find a client for the given target:", remote_worker);
}
if (!statuses[i].ok()) {
counter.DecrementCount();
continue;
}
tensorflow::eager::CreateContextRequest request;
tensorflow::eager::CreateContextResponse* response =
new tensorflow::eager::CreateContextResponse();
request.set_context_id(context_id);
request.set_context_view_id(context_view_id);
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
server_def.default_session_config());
std::vector<bool> filtered_device_mask;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(),
base_request.cluster_device_attributes_size());
for (int i = 0; i < filtered_device_mask.size(); i++) {
if (filtered_device_mask[i]) {
const auto& da = base_request.cluster_device_attributes(i);
*request.add_cluster_device_attributes() = da;
}
}
request.set_async(async);
request.set_keep_alive_secs(keep_alive_secs);
request.set_lazy_copy_remote_function_inputs(
lazy_copy_remote_function_inputs);
eager_client->CreateContextAsync(
&request, response,
[i, &statuses, &counter, response](const tensorflow::Status& s) {
statuses[i] = s;
delete response;
counter.DecrementCount();
});
}
counter.Wait();
tensorflow::StatusGroup sg;
for (int i = 0; i < num_remote_workers; i++) {
if (TF_PREDICT_FALSE(!statuses[i].ok())) {
sg.Update(statuses[i]);
}
}
return sg.as_summary_status();
}
tensorflow::Status UpdateRemoteContexts(
TFE_Context* ctx, const std::vector<string>& remote_workers,
const std::vector<string>& added_workers,
const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers,
const tensorflow::eager::CreateContextRequest& base_request) {
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
std::vector<tensorflow::Status> statuses(num_remote_workers);
int cluster_device_count = base_request.cluster_device_attributes_size();
std::unordered_set<string> added_or_removed(added_workers.begin(),
added_workers.end());
std::copy(removed_workers.begin(), removed_workers.end(),
std::inserter(added_or_removed, added_or_removed.end()));
// Whether each device is in the updated (added or removed) workers
std::vector<bool> device_added_or_removed(cluster_device_count);
for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
const auto& da = base_request.cluster_device_attributes().at(i);
tensorflow::DeviceNameUtils::ParsedName pn;
tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
string task_name;
tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
if (added_or_removed.find(task_name) != added_or_removed.end()) {
device_added_or_removed[i] = true;
}
}
for (int i = 0; i < num_remote_workers; i++) {
const string& remote_worker = remote_workers[i];
tensorflow::DeviceNameUtils::ParsedName parsed_name;
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
&parsed_name)) {
statuses[i] = tensorflow::errors::InvalidArgument(
"Unable to parse ", remote_worker, " as a device name");
counter.DecrementCount();
continue;
}
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
if (eager_client == nullptr) {
statuses[i] = tensorflow::errors::Internal(
"Cannot find a client for the given target:", remote_worker);
}
if (!statuses[i].ok()) {
counter.DecrementCount();
continue;
}
std::vector<bool> filtered_device_mask;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
// If any of the devices that match the device filters are in the set of
// added or removed workers, we must send a complete UpdateContextRequest.
// Otherwise, only send a simple request to increment context view ID.
std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
std::transform(device_added_or_removed.begin(),
device_added_or_removed.end(), filtered_device_mask.begin(),
added_or_removed_filtered_devices.begin(),
std::logical_and<bool>());
const bool full_update_request =
std::accumulate(added_or_removed_filtered_devices.begin(),
added_or_removed_filtered_devices.end(), false,
std::logical_or<bool>());
tensorflow::eager::UpdateContextRequest request;
auto* response = new tensorflow::eager::UpdateContextResponse();
request.set_context_id(context_id);
request.set_context_view_id(context_view_id);
if (full_update_request) {
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
server_def.default_session_config());
for (int i = 0; i < cluster_device_count; i++) {
if (filtered_device_mask[i]) {
const auto& da = base_request.cluster_device_attributes(i);
*request.add_cluster_device_attributes() = da;
}
}
}
eager_client->UpdateContextAsync(
&request, response,
[i, &statuses, &counter, response](const tensorflow::Status& s) {
statuses[i] = s;
delete response;
counter.DecrementCount();
});
}
counter.Wait();
for (int i = 0; i < num_remote_workers; i++) {
TF_RETURN_IF_ERROR(statuses[i]);
}
return tensorflow::Status::OK();
}
tensorflow::Status UpdateTFE_ContextWithServerDef(
int keep_alive_secs, const tensorflow::ServerDef& server_def,
TFE_Context* ctx, bool reset_context) {
// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
// server object (which currently CHECK-fails) and we miss the error; instead,
// we log the error and then return to allow the user to see the error
// message.
#define LOG_AND_RETURN_IF_ERROR(...) \
do { \
const ::tensorflow::Status _status = (__VA_ARGS__); \
if (TF_PREDICT_FALSE(!_status.ok())) { \
LOG(ERROR) << _status.error_message(); \
return _status; \
} \
} while (0);
string worker_name =
tensorflow::strings::StrCat("/job:", server_def.job_name(),
"/replica:0/task:", server_def.task_index());
// List of current remote workers before updating server_def. Unused if
// resetting the server_def.
std::vector<string> curr_remote_workers;
// List of updated remote workers.
std::vector<string> remote_workers;
// New server created for new server_def. Unused if updating server_def.
std::unique_ptr<tensorflow::ServerInterface> new_server;
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::GrpcServer* grpc_server;
if (reset_context) {
const tensorflow::DeviceMgr* device_mgr =
AreLocalDevicesCompatible(context, server_def)
? context->local_device_mgr()
: nullptr;
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
server_def, {device_mgr}, &new_server));
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
} else {
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
&curr_remote_workers));
// No need to check the cast here, since `ListRemoteWorkers` already checks
// if the server is a GRPC server or not.
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
}
tensorflow::uint64 context_id = context->GetContextId();
tensorflow::uint64 context_view_id = context->GetContextViewId();
if (reset_context) {
context_id = tensorflow::EagerContext::NewContextId();
context_view_id = 0;
// Make master eager context accessible by local eager service, which might
// receive send tensor requests from remote workers.
LOG_AND_RETURN_IF_ERROR(
grpc_server->AddMasterEagerContextToEagerService(context_id, context));
}
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
LOG_AND_RETURN_IF_ERROR(
grpc_server->master_env()->worker_cache->GetEagerClientCache(
&remote_eager_workers));
// For cluster update, use a status group to aggregate statuses from
// * adding and removing remote devices
// * creating remote contexts on newly added workers
// * updating remote contexts on existing workers
// * updating the master context
// Note that we should not return immediately on errors in the middle of these
// updates to prevent cluster from having inconsistent context views.
//
// Unused if `reset_context` is True.
tensorflow::StatusGroup sg;
// When updating an existing context, populate the following lists with:
// * added_workers: set(remote_workers) - set(curr_remote_workers)
// * removed_workers: set(curr_remote_workers) - set(remote_workers)
// * existing_workers: set(curr_remote_workers) intersect set(remote_workers)
// * replaced_workers: workers with the same task names and potentially the
// same `hostname:port`s, but replaced by different processes
std::vector<string> added_workers;
std::vector<string> removed_workers;
std::vector<string> existing_workers;
std::vector<string> replaced_workers;
// New remote device manager created for new server_def. Unused if updating
// server_def.
std::unique_ptr<tensorflow::DynamicDeviceMgr> new_remote_device_mgr;
tensorflow::DynamicDeviceMgr* remote_device_mgr = nullptr;
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
remote_workers, grpc_server->master_env()->worker_cache,
&new_remote_device_mgr));
remote_device_mgr = new_remote_device_mgr.get();
} else {
context->ClearCachesAndDefaultExecutor();
// TODO(b/143914772): Potential memory leak if rendezvous has pending
// tensors for removed / replaced workers.
remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
if (remote_device_mgr == nullptr) {
LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
"Updating context with an invalid set of remote devices."));
}
std::sort(curr_remote_workers.begin(), curr_remote_workers.end());
std::sort(remote_workers.begin(), remote_workers.end());
DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
&added_workers, &removed_workers,
&existing_workers);
sg.Update(GetReplacedFromExistingWorkers(
&existing_workers, context_id, context->GetContextViewId(), server_def,
remote_eager_workers.get(), &replaced_workers));
if (VLOG_IS_ON(1)) {
VLOG(1) << "Updating cluster with following changes";
for (const string& w : added_workers) VLOG(1) << " Added worker " << w;
for (const string& w : removed_workers)
VLOG(1) << " Removed worker " << w;
for (const string& w : replaced_workers)
VLOG(1) << " Replaced worker " << w;
}
if (!replaced_workers.empty()) {
// Treat replaced workers as removed then added back, so that we recreate
// remote devices and contexts, and re-register functions on those workers
removed_workers.insert(removed_workers.end(), replaced_workers.begin(),
replaced_workers.end());
added_workers.insert(added_workers.end(), replaced_workers.begin(),
replaced_workers.end());
for (const string& w : replaced_workers) {
existing_workers.erase(
std::remove(existing_workers.begin(), existing_workers.end(), w),
existing_workers.end());
}
}
sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
sg.Update(AddRemoteDevicesToMgr(added_workers,
grpc_server->master_env()->worker_cache,
remote_device_mgr));
}
std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);
std::vector<tensorflow::DeviceAttributes> local_device_attributes;
grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
&local_device_attributes);
// This request makes sure that we can create Rendezvous properly between
// local and remote contexts.
tensorflow::eager::CreateContextRequest base_request;
for (const auto& da : cluster_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
for (const auto& da : local_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
// Initialize remote eager workers.
if (reset_context) {
const tensorflow::Status s = CreateRemoteContexts(
ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
context->LazyCopyFunctionRemoteInputs(), base_request);
// NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause
// the CreateRemoteContexts to fail. We currently only log instead of
// directly returning the error, since returning here will cause the server
// object to be destroyed (which currently CHECK-fails). The client will
// see additional errors if ops are subsequently sent to the failed workers.
if (TF_PREDICT_FALSE(!s.ok())) {
LOG(ERROR) << "Error when creating contexts on remote targets: "
<< s.error_message()
<< "\nExecuting remote ops or functions on these remote "
"targets will fail.";
}
} else {
if (sg.ok()) {
// Create remote contexts on the newly added workers only if the master
// has collected all device information from them (i.e., the
// GetAllRemoteDevices call returns successfully). Note that in rare cases
// GetAllRemoteDevices can still fail even with RPCs configured to wait
// until the remote workers become alive. If the master creates remote
// contexts on the workers whose devices are still not collected, those
// workers will be treated as existing workers subsequently, so the master
// will never get devices from them even with retrying UpdateServerDef.
sg.Update(CreateRemoteContexts(
ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
context->LazyCopyFunctionRemoteInputs(), base_request));
}
if (!existing_workers.empty()) {
if (VLOG_IS_ON(1)) {
for (const string& w : existing_workers) {
VLOG(1) << "Updating cluster with existing worker " << w;
}
}
// The master's context_view_id will be incremented by one in the
// UpdateRemoteMaster call later. We want existing workers to also have
// the updated context_view_id, so we must set their context_view_id to
// the master's current context_view_id + 1.
sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers,
removed_workers, context_id,
context_view_id + 1, server_def,
remote_eager_workers.get(), base_request));
}
}
auto session_name = tensorflow::strings::StrCat("eager_", context_id);
if (reset_context) {
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
auto* device_mgr = grpc_server->worker_env()->device_mgr;
std::shared_ptr<tensorflow::WorkerSession> worker_session;
LOG_AND_RETURN_IF_ERROR(
grpc_server->worker_env()->session_mgr->CreateSession(
session_name, server_def, base_request.cluster_device_attributes(),
true));
LOG_AND_RETURN_IF_ERROR(
grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
// Initialize remote tensor communication based on worker session.
LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
tensorflow::eager::CreateClusterFLR(context_id, context,
worker_session.get());
auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
/*is_master=*/true, context);
LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
std::move(new_server), grpc_server->worker_env(), worker_session,
std::move(remote_eager_workers), std::move(new_remote_device_mgr),
remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
std::move(remote_mgr)));
// NOTE: We start the server after all other initialization, because the
// GrpcServer cannot be destroyed after it is started.
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
} else {
sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
session_name, server_def, base_request.cluster_device_attributes(),
/*isolate_session_state=*/true));
sg.Update(context->UpdateRemoteMaster(context_id,
std::move(remote_eager_workers),
added_workers, removed_workers));
LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
}
#undef LOG_AND_RETURN_IF_ERROR
return tensorflow::Status::OK();
}
#endif // !IS_MOBILE_PLATFORM
} // namespace
extern "C" {
@ -731,11 +102,21 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
if (opts->use_tfrt) {
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
tfrt::tf::ContextInterface* tfrt_context = new tfrt::tf::ContextInterface(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
opts->async);
#if !defined(IS_MOBILE_PLATFORM)
tfrt_context->SetDistributedManager(
std::make_unique<tfrt::tf::DistributedManagerContextInterface>(
tfrt_context->GetCoreRuntime()->GetHostContext()));
#endif // !IS_MOBILE_PLATFORM
return tensorflow::wrap(tfrt_context);
#else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
return nullptr;
#endif
#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
}
std::vector<std::unique_ptr<tensorflow::Device>> devices;
status->status = tensorflow::DeviceFactory::AddDevices(
@ -747,13 +128,18 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
return tensorflow::wrap(new tensorflow::EagerContext(
tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
/*device_mgr_owned*/ true, r));
/*device_mgr_owned*/ true, r);
#if !defined(IS_MOBILE_PLATFORM)
eager_context->SetDistributedManager(
std::make_unique<tensorflow::EagerContextDistributedManager>(
eager_context));
#endif // !IS_MOBILE_PLATFORM
return tensorflow::wrap(eager_context);
}
void TFE_DeleteContext(TFE_Context* ctx) {
@ -791,26 +177,9 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
"Invalid tensorflow.ServerDef protocol buffer");
return;
}
if (server_def.has_cluster_device_filters()) {
const auto& cdf = server_def.cluster_device_filters();
for (const auto& jdf : cdf.jobs()) {
const string remote_prefix = "/job:" + jdf.name() + "/task:";
for (const auto& tdf : jdf.tasks()) {
const int32_t task_index = tdf.first;
std::vector<string> device_filters(tdf.second.device_filters_size());
for (int i = 0; i < tdf.second.device_filters_size(); i++) {
device_filters[i] = tdf.second.device_filters(i);
}
const string remote_worker = remote_prefix + std::to_string(task_index);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status =
context->SetRemoteDeviceFilters(remote_worker, device_filters);
}
}
}
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
ctx, /*reset_context=*/true);
status->status =
tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
server_def, /*reset_context=*/true, keep_alive_secs);
#endif // !IS_MOBILE_PLATFORM
}
@ -835,14 +204,9 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
status->status = tensorflow::errors::InvalidArgument(
"Trying to update a context with invalid context id.");
}
if (server_def.has_cluster_device_filters()) {
LOG(WARNING) << "Device filters can only be specified when initializing "
"the cluster. Any changes in device filters are ignored "
"when updating the server def.";
}
// TODO(haoyuzhang): Check server_def compatibility before the update
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
ctx, /*reset_context=*/false);
status->status =
tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
server_def, /*reset_context=*/false, keep_alive_secs);
#endif // !IS_MOBILE_PLATFORM
}
@ -854,44 +218,11 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
"TFE_ContextSetServerDef not supported on mobile");
return false;
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
if (grpc_server == nullptr) {
status->status =
tensorflow::errors::Internal("Failed to get tensorflow::GrpcServer.");
return false;
}
tensorflow::WorkerInterface* wi =
grpc_server->master_env()->worker_cache->GetOrCreateWorker(worker_name);
if (wi == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"Unable to find worker interface corresponding to task ", worker_name);
return false;
}
tensorflow::GetStatusRequest request;
tensorflow::GetStatusResponse response;
tensorflow::Status remote_status;
tensorflow::Notification done;
wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true,
[&remote_status, &done](const tensorflow::Status& s) {
remote_status = s;
done.Notify();
});
done.WaitForNotification();
// We set OK status so the call does not raise any exceptions. Instead, the
// caller uses the return value to tell if the remote worker is alive.
status->status = tensorflow::Status::OK();
if (remote_status.ok()) {
return true;
}
LOG(INFO) << "Remote worker " << worker_name
<< " is not alive: " << remote_status.error_message();
return false;
bool is_alive;
status->status =
tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive(
worker_name, &is_alive);
return is_alive;
#endif // !IS_MOBILE_PLATFORM
}

View File

@ -226,7 +226,7 @@ void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
// Helper functions which delegate to `AbstractOperation`, update
// the state of the ForwardOperation and call the tape as appropriate.
// These APIs are mainly to faciliate testing and are subject to change.
// These APIs are mainly to facilitate testing and are subject to change.
namespace internal {
Status Reset(AbstractOperation* op_, const char* op,
const char* raw_device_name, ForwardOperation* forward_op_) {

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/immediate_execution_distributed_manager.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
@ -28,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/config.pb.h"
@ -174,6 +176,18 @@ class ImmediateExecutionContext : public AbstractContext {
virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
ImmediateExecutionTensorHandle* handle) = 0;
//===--------------------------------------------------------------------===//
// Distributed runtime related functions.
//===--------------------------------------------------------------------===//
#if !defined(IS_MOBILE_PLATFORM)
// Set a distributed manager that helps set up, update, and check liveness
// of member tasks in the cluster.
virtual void SetDistributedManager(
std::unique_ptr<ImmediateExecutionDistributedManager> distributed) = 0;
virtual ImmediateExecutionDistributedManager* GetDistributedManager() = 0;
#endif // !IS_MOBILE_PLATFORM
protected:
explicit ImmediateExecutionContext(AbstractContextKind kind)
: AbstractContext(kind) {}

View File

@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_DISTRIBUTED_MANAGER_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_DISTRIBUTED_MANAGER_H_
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
class ImmediateExecutionContext;
class ServerDef;
class ImmediateExecutionDistributedManager {
public:
virtual ~ImmediateExecutionDistributedManager() {}
// Set up a distributed execution environment on local and remote tasks.
// When `reset_context` is true, initialize new cluster context state based on
// cluster configurations provided in `server_def`; otherwise, update existing
// context state with the provided `server_def`.
// Contexts created on remote tasks will be considered stale and garbage
// collected after `keep_alive_secs` of inactivity.
virtual Status SetOrUpdateServerDef(const ServerDef& server_def,
bool reset_context,
int keep_alive_secs) = 0;
// Check if the remote task is alive.
virtual Status CheckRemoteAlive(const std::string& remote_task_name,
bool* is_alive) = 0;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_DISTRIBUTED_MANAGER_H_
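A minimal usage sketch, assuming a context whose distributed manager has already been installed; the helper name ConnectToCluster is hypothetical and added only for illustration. It mirrors what TFE_ContextSetServerDef does later in this change.

#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

// Hypothetical helper: (re)initialize cluster state for `ctx` from `server_def`.
// Remote contexts become stale and are garbage collected after 600 seconds of
// inactivity.
tensorflow::Status ConnectToCluster(tensorflow::ImmediateExecutionContext* ctx,
                                    const tensorflow::ServerDef& server_def) {
  return ctx->GetDistributedManager()->SetOrUpdateServerDef(
      server_def, /*reset_context=*/true, /*keep_alive_secs=*/600);
}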

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/util/abstract_stack_trace.h"
#include "tensorflow/core/util/managed_stack_trace.h"
struct TFE_Op;
@ -48,10 +48,10 @@ class ImmediateExecutionOperation : public AbstractOperation {
virtual Status OutputLength(const char* output_name, int* length) = 0;
// Set stack trace to be used for potential async error reporting.
virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0;
virtual void SetStackTrace(ManagedStackTrace stack_trace) = 0;
// Returns the stack trace set by `SetStackTrace` if exists.
virtual absl::optional<AbstractStackTrace> GetStackTrace() = 0;
virtual absl::optional<ManagedStackTrace> GetStackTrace() = 0;
// For LLVM style RTTI.
static bool classof(const AbstractOperation* ptr) {

View File

@ -7,10 +7,21 @@ load(
"tf_cc_test",
)
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "filegroup")
package(
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "headers",
srcs = [
"stream_executor.h",
],
visibility = ["//tensorflow:__subpackages__"],
)
cc_library(
name = "stream_executor_hdrs",
hdrs = ["stream_executor.h"],
@ -49,9 +60,11 @@ cc_library(
"stream_executor.h",
"stream_executor_internal.h",
],
visibility = ["//tensorflow/c:__subpackages__"],
deps = [
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/stream_executor:executor_cache",
"//tensorflow/stream_executor/lib",
],

View File

@ -24,7 +24,6 @@ limitations under the License.
#include <string>
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
@ -44,6 +43,7 @@ using tensorflow::StatusFromTF_Status;
namespace stream_executor {
using tensorflow::StringPiece;
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
namespace {
@ -188,41 +188,6 @@ port::Status ValidateSEPlatformRegistrationParams(
}
#undef VALIDATE_MEMBER
struct TFStatusDeleter {
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
};
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
class CStream : public internal::StreamInterface {
public:
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
stream_executor_(stream_executor),
stream_handle_(nullptr) {}
~CStream() override { Destroy(); }
port::Status Create() {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
port::Status s = StatusFromTF_Status(c_status.get());
return s;
}
void Destroy() {
if (stream_handle_ != nullptr) {
stream_executor_->destroy_stream(device_, stream_handle_);
stream_handle_ = nullptr;
}
}
SP_Stream Handle() { return stream_handle_; }
private:
SP_Device* device_;
SP_StreamExecutor* stream_executor_;
SP_Stream stream_handle_;
};
// Converts SE_EventStatus to Event::Status.
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
switch (s) {
@ -237,82 +202,6 @@ Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
}
}
class CEvent : public internal::EventInterface {
public:
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
stream_executor_(stream_executor),
event_handle_(nullptr) {}
~CEvent() override { Destroy(); }
port::Status Create() {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->create_event(device_, &event_handle_, c_status.get());
return StatusFromTF_Status(c_status.get());
}
port::Status Record(SP_Stream stream_handle) {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->record_event(device_, stream_handle, event_handle_,
c_status.get());
return StatusFromTF_Status(c_status.get());
}
void Destroy() {
if (event_handle_ != nullptr) {
stream_executor_->destroy_event(device_, event_handle_);
event_handle_ = nullptr;
}
}
SP_Event Handle() { return event_handle_; }
private:
SP_Device* device_;
SP_StreamExecutor* stream_executor_;
SP_Event event_handle_;
};
class CTimer : public internal::TimerInterface {
public:
CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
SP_TimerFns* timer_fns)
: device_(device),
stream_executor_(stream_executor),
timer_handle_(nullptr),
timer_fns_(timer_fns) {}
~CTimer() override { Destroy(); }
port::Status Create() {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
return StatusFromTF_Status(c_status.get());
}
void Destroy() {
if (timer_handle_ != nullptr) {
stream_executor_->destroy_timer(device_, timer_handle_);
timer_handle_ = nullptr;
}
}
SP_Timer Handle() { return timer_handle_; }
uint64 Microseconds() const override {
return timer_fns_->nanoseconds(timer_handle_) / 1000;
}
uint64 Nanoseconds() const override {
return timer_fns_->nanoseconds(timer_handle_);
}
private:
SP_Device* device_;
SP_StreamExecutor* stream_executor_;
SP_Timer timer_handle_;
SP_TimerFns* timer_fns_;
};
// Converts DeviceMemoryBase to a C struct.
SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
@ -321,14 +210,12 @@ SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
device_memory_base.opaque = const_cast<void*>(mem->opaque());
device_memory_base.size = mem->size();
device_memory_base.payload = mem->payload();
// TODO(annarev): Add `ext` field to DeviceMemoryBase and set it here.
return device_memory_base;
}
DeviceMemoryBase DeviceMemoryBaseFromC(const SP_DeviceMemoryBase& mem) {
DeviceMemoryBase base(mem.opaque, mem.size);
base.SetPayload(mem.payload);
// TODO(annarev): Add `ext` field to DeviceMemoryBase and set it here.
return base;
}
@ -426,7 +313,6 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
LOG(ERROR) << status.error_message();
return absl::nullopt;
}
// TODO(annarev): validate SP_AllocatorStats.
::stream_executor::AllocatorStats stats;
stats.num_allocs = c_stats.num_allocs;
stats.bytes_in_use = c_stats.bytes_in_use;

View File

@ -140,8 +140,9 @@ typedef enum SE_EventStatus {
// https://cs.opensource.google/tensorflow/tensorflow/+/refs/tags/v2.3.0:tensorflow/stream_executor/device_memory.h;l=57
typedef struct SP_DeviceMemoryBase {
size_t struct_size;
void* ext; // free-form data set by plugin
void* ext; // Reserved for future use
// Platform-dependent value representing allocated memory.
// Note that the pointer does not have to be the virtual address itself.
void* opaque;
uint64_t size; // Size in bytes of this allocation.
uint64_t payload; // Value for plugin's use

View File

@ -19,6 +19,7 @@ limitations under the License.
#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform.h"
@ -37,6 +38,13 @@ port::Status InitStreamExecutorPlugin(void* dso_handle);
// testing).
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn);
struct TFStatusDeleter {
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
};
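// Used below as std::unique_ptr<TF_Status, TFStatusDeleter> so that statuses
// created with TF_NewStatus are released automatically.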
// This file implements core stream executor base classes in terms of
// the C API defined in stream_executor.h. A class "CSomething" represents a
// "Something" that can be manipulated via calls in the C interface.
class CPlatform : public Platform {
public:
explicit CPlatform(SP_Platform platform,
@ -83,5 +91,111 @@ class CPlatform : public Platform {
stream_executor::ExecutorCache executor_cache_;
};
class CStream : public internal::StreamInterface {
public:
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
stream_executor_(stream_executor),
stream_handle_(nullptr) {}
~CStream() override { Destroy(); }
port::Status Create() {
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
port::Status s = tensorflow::StatusFromTF_Status(c_status.get());
return s;
}
void Destroy() {
if (stream_handle_ != nullptr) {
stream_executor_->destroy_stream(device_, stream_handle_);
stream_handle_ = nullptr;
}
}
SP_Stream Handle() { return stream_handle_; }
private:
SP_Device* device_;
SP_StreamExecutor* stream_executor_;
SP_Stream stream_handle_;
};
class CEvent : public internal::EventInterface {
public:
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
: device_(device),
stream_executor_(stream_executor),
event_handle_(nullptr) {}
~CEvent() override { Destroy(); }
port::Status Create() {
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
stream_executor_->create_event(device_, &event_handle_, c_status.get());
return tensorflow::StatusFromTF_Status(c_status.get());
}
port::Status Record(SP_Stream stream_handle) {
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
stream_executor_->record_event(device_, stream_handle, event_handle_,
c_status.get());
return tensorflow::StatusFromTF_Status(c_status.get());
}
void Destroy() {
if (event_handle_ != nullptr) {
stream_executor_->destroy_event(device_, event_handle_);
event_handle_ = nullptr;
}
}
SP_Event Handle() { return event_handle_; }
private:
SP_Device* device_;
SP_StreamExecutor* stream_executor_;
SP_Event event_handle_;
};
class CTimer : public internal::TimerInterface {
public:
CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
SP_TimerFns* timer_fns)
: device_(device),
stream_executor_(stream_executor),
timer_handle_(nullptr),
timer_fns_(timer_fns) {}
~CTimer() override { Destroy(); }
port::Status Create() {
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
return tensorflow::StatusFromTF_Status(c_status.get());
}
void Destroy() {
if (timer_handle_ != nullptr) {
stream_executor_->destroy_timer(device_, timer_handle_);
timer_handle_ = nullptr;
}
}
SP_Timer Handle() { return timer_handle_; }
uint64 Microseconds() const override {
return timer_fns_->nanoseconds(timer_handle_) / 1000;
}
uint64 Nanoseconds() const override {
return timer_fns_->nanoseconds(timer_handle_);
}
private:
SP_Device* device_;
SP_StreamExecutor* stream_executor_;
SP_Timer timer_handle_;
SP_TimerFns* timer_fns_;
};
} // namespace stream_executor
#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_

View File

@ -24,7 +24,13 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
// Required for IS_MOBILE_PLATFORM definition
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/types.h"
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/stream.h"
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
// This file forms the basis of a stable ABI for third-party kernel
// implementations. It is crucial that changes to this file are made cautiously
@ -168,6 +174,35 @@ void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
TF_SetStatus(status, TF_OK, "");
}
// This function is only for pluggable devices.
// It will return nullptr in all other cases.
// This function is experimental and subject to change.
SP_Stream TF_GetStream(TF_OpKernelContext* ctx, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
status->status = tensorflow::errors::Unimplemented(
"Accessing device stream is not supported on mobile. File a bug at "
"https://github.com/tensorflow/tensorflow/issues if this feature is "
"important to you");
return nullptr;
#else
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
if (cc_ctx->op_device_context() == nullptr) { // CPU Device
status->status = tensorflow::errors::FailedPrecondition(
"Accessing device stream is not supported for a CPU device.");
return nullptr;
} else if (!cc_ctx->op_device_context()->IsPluggableDevice()) {
status->status = tensorflow::errors::FailedPrecondition(
"Accessing device stream is only supported for pluggable devices.");
return nullptr;
} else { // Is a PluggableDevice
TF_SetStatus(status, TF_OK, "");
auto c_stream = static_cast<stream_executor::CStream*>(
cc_ctx->op_device_context()->stream()->implementation());
return c_stream->Handle();
}
#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
int TF_NumInputs(TF_OpKernelContext* ctx) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
return cc_ctx->num_inputs();

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <stdint.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
@ -65,6 +66,11 @@ typedef struct TF_KernelBuilder TF_KernelBuilder;
typedef struct TF_OpKernelConstruction TF_OpKernelConstruction;
typedef struct TF_OpKernelContext TF_OpKernelContext;
// TF_InitKernel is the entry point for op/kernel registration.
// Plugins should implement TF_InitKernel to register kernels. This function
// should register all kernels that the plugin provides.
void TF_InitKernel();
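For illustration only (not part of this change), a plugin's TF_InitKernel could register a single kernel through the builder API documented below; the op name "MyPluginOp", the device name "MY_DEVICE", the kernel name "MyPluginOpKernel", and the three callbacks are hypothetical placeholders:
// Hypothetical sketch only: the names below are placeholders, not real TF symbols.
static void* MyKernelCreate(TF_OpKernelConstruction* ctx) { return nullptr; }
static void MyKernelCompute(void* kernel, TF_OpKernelContext* ctx) {}
static void MyKernelDelete(void* kernel) {}
void TF_InitKernel() {
  TF_KernelBuilder* builder = TF_NewKernelBuilder(
      "MyPluginOp", "MY_DEVICE", &MyKernelCreate, &MyKernelCompute,
      &MyKernelDelete);
  TF_Status* status = TF_NewStatus();
  TF_RegisterKernelBuilder("MyPluginOpKernel", builder, status);
  if (TF_GetCode(status) != TF_OK) {
    // Registration failed; a real plugin would log or surface the error here.
  }
  TF_DeleteStatus(status);
}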
// Allocates a new kernel builder and returns a pointer to it.
//
// If non-null, TensorFlow will call create_func when it needs to instantiate
@ -128,6 +134,16 @@ TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
// --------------------------------------------------------------------------
// OpKernelContext routines
// TF_GetStream returns the SP_Stream available in ctx.
// This function returns a stream only for devices registered using the
// StreamExecutor C API
// (tensorflow/c/experimental/stream_executor/stream_executor.h). It will return
// nullptr and set an error status in all other cases.
// Experimental: this function has no compatibility guarantees and is subject
// to change at any time.
TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx,
TF_Status* status);
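As a hedged usage sketch (not part of this diff), a compute callback for a kernel registered on a pluggable device could retrieve its stream as follows; the callback name and the device work are placeholders:
// Illustrative sketch; MyCompute is a placeholder compute callback.
static void MyCompute(void* kernel, TF_OpKernelContext* ctx) {
  TF_Status* status = TF_NewStatus();
  SP_Stream stream = TF_GetStream(ctx, status);
  if (TF_GetCode(status) == TF_OK && stream != nullptr) {
    // Enqueue device work on `stream` through the plugin's own StreamExecutor
    // implementation.
  }
  TF_DeleteStatus(status);
}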
// TF_NumInputs returns the number of inputs available in ctx.
TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx);

View File

@ -49,14 +49,12 @@ Graph* BM_ScalarSummaryOp(TensorShape shape, std::string tag, float value) {
constexpr char longTagParam[] = "LONGTAG____________________________";
constexpr float largeValueParam = 2352352.2623433;
#define BM_ScalarSummaryDev(device, dims, name, tag, value) \
void BM_ScalarSummary##name##device(int iters) { \
testing::StopTiming(); \
TensorShape tensorshape(DIMARGS dims); \
auto g = BM_ScalarSummaryOp(tensorshape, #tag, value); \
testing::StartTiming(); \
test::Benchmark("cpu", g).Run(iters); \
} \
#define BM_ScalarSummaryDev(device, dims, name, tag, value) \
void BM_ScalarSummary##name##device(::testing::benchmark::State& state) { \
TensorShape tensorshape(DIMARGS dims); \
auto g = BM_ScalarSummaryOp(tensorshape, #tag, value); \
test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state); \
} \
BENCHMARK(BM_ScalarSummary##name##device);
BM_ScalarSummaryDev(Cpu, (5, 10, 100), Base, Tag, 5.2);

View File

@ -378,6 +378,23 @@ template <typename T>
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
TF_OpKernelContext* ctx);
REGISTER_OP("StreamOp").Output("output1: float");
TEST_F(DeviceKernelOpTest, TestStream) {
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
TF_Status* s = TF_NewStatus();
SP_Stream stream = TF_GetStream(ctx, s);
// Stream is always null if the device is not a pluggable device. More test
// cases will be added once the pluggable device mechanism is supported.
EXPECT_EQ(stream, nullptr);
EXPECT_NE(TF_OK, TF_GetCode(s));
TF_DeleteStatus(s);
};
SetupOp("StreamOp", "StreamOp", my_compute_func);
TF_ASSERT_OK(RunOpKernel());
}
REGISTER_OP("AllocateOutputOp1").Output("output1: float");
TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {

View File

@ -115,8 +115,8 @@ genrule(
# have control of the full GPU.
cmd = "CUDA_VISIBLE_DEVICES='' " +
"$(location :make_test_graphs) --out_dir $(@D)",
exec_tools = [":make_test_graphs"],
tags = ["manual"],
tools = [":make_test_graphs"],
)
tf_library(

View File

@ -127,7 +127,7 @@ def tf_library(
"$(location " + tfcompile_tool + ")" +
" --config=$(location " + config + ")" +
" --dump_fetch_nodes > $@"),
exec_tools = [tfcompile_tool],
tools = [tfcompile_tool],
# Run tfcompile on the build host, rather than forge, since it's
# typically way faster on the local machine.
local = 1,
@ -162,7 +162,7 @@ def tf_library(
"//tensorflow/python/tools:freeze_graph)" +
freeze_args
),
exec_tools = ["//tensorflow/python/tools:freeze_graph"],
tools = ["//tensorflow/python/tools:freeze_graph"],
tags = tags,
)
tfcompile_graph = freeze_file
@ -242,7 +242,7 @@ def tf_library(
" --out_function_object=$(@D)/" + function_object_file +
" " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag
),
exec_tools = [tfcompile_tool],
tools = [tfcompile_tool],
visibility = visibility,
testonly = testonly,
# Run tfcompile on the build host since it's typically faster on the
@ -281,7 +281,7 @@ def tf_library(
" --out_session_module=$(@D)/" + session_module_pb +
" " + flags
),
exec_tools = [tfcompile_tool],
tools = [tfcompile_tool],
visibility = visibility,
testonly = testonly,
local = 1,

View File

@ -39,7 +39,7 @@ struct XlaAutoJitFlag {
int32 optimization_level_general;
};
// Sets the xla_auto_jit_flag based on the given flag sting. Supported syntax
// Sets the xla_auto_jit_flag based on the given flag string. Supported syntax
// is:
// <number>: sets general and single_gpu setting to the provided number.
// single-gpu(<number>): sets the single_gpu setting to the provided number.

View File

@ -103,7 +103,9 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
if (flr->config_proto()) {
config_proto = *flr->config_proto();
}
if (!IsMlirBridgePassEnabled(*fbody->graph, config_proto)) {
MlirBridgeRolloutPolicy policy =
GetMlirBridgeRolloutPolicy(*fbody->graph, config_proto);
if (policy != MlirBridgeRolloutPolicy::kEnabledByUser) {
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>

View File

@ -449,15 +449,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
auto transfer_manager,
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
xla::Shape output_host_shape = output.on_host_shape();
xla::Shape output_device_shape = output.on_device_shape();
TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
stream, &output, &output_host_shape, &output_device_shape));
stream, &output, &output_device_shape));
output.set_shapes(output_host_shape, output_device_shape);
output.set_shapes(output_device_shape, output_device_shape);
for (int i = 0; i < ctx->num_outputs(); ++i) {
const xla::Shape& subshape =
xla::ShapeUtil::GetSubshape(output_host_shape, {i});
xla::ShapeUtil::GetSubshape(output_device_shape, {i});
TensorShape shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
output_tensor_shapes.push_back(shape);

View File

@ -51,10 +51,10 @@ filegroup(
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td",
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/ViewLikeInterface.td",
],
)
@ -464,6 +464,7 @@ cc_library(
":hlo",
":lhlo",
":lhlo_gpu",
"@llvm-project//mlir:IR",
],
)
@ -499,6 +500,7 @@ cc_library(
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)
@ -637,10 +639,12 @@ cc_library(
":lhlo",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:ViewLikeInterface",
],
@ -664,6 +668,7 @@ cc_library(
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:ShapeTransforms",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:StandardOpsTransforms",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
@ -698,10 +703,12 @@ cc_library(
deps = [
":cycle_detector",
":hlo",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
@ -732,6 +739,7 @@ cc_library(
deps = [
":hlo",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
@ -752,6 +760,7 @@ cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,
@ -769,6 +778,8 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
@ -787,6 +798,7 @@ cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
@ -827,9 +839,11 @@ cc_library(
":hlo",
":lower_complex_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,

View File

@ -417,6 +417,32 @@ def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
let hasCanonicalizer = 1;
}
def HLOClient_ErfOp : HLOClient_UnaryElementwiseOp<"erf",
[NoSideEffect, SameOperandsAndResultShape],
HLO_FpTensor> {
let summary = "Erfc operator";
let description = [{
Computes the Gauss error function of `x` element-wise.
erf(x) = erf_impl(x) if |x| < 1
= 1 - erfc_impl(x) otherwise
}];
}
def HLOClient_ErfcOp : HLOClient_UnaryElementwiseOp<"erfc",
[NoSideEffect, SameOperandsAndResultShape],
HLO_FpTensor> {
let summary = "Erfc operator";
let description = [{
Computes an approximation of the error function complement (1 - erf(x)).
erfc(x) = erfc_impl(x) if |x| > 1
= 1 - erf_impl(x) otherwise
}];
}
//===----------------------------------------------------------------------===//
// Broadcasting compare op
//===----------------------------------------------------------------------===//

View File

@ -1046,6 +1046,7 @@ def HLO_ReshapeOp: HLO_Op<"reshape",
let results = (outs HLO_StaticShapeTensor);
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasCustomHLOConverter = 1;
}

View File

@ -202,9 +202,9 @@ def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<UntypedBuffer, "", [MemWrite]>:$scratch,
Arg<LHLO_Buffer, "", [MemWrite]>:$scratch,
Arg<I32Buffer, "", [MemWrite]>:$info,
BoolAttr:$is_upper);
BoolAttr:$is_lower);
}
#endif // LHLO_GPU_OPS

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"

View File

@ -197,10 +197,11 @@ def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO
//===----------------------------------------------------------------------===//
// TODO(b/139813999): specify required function signature in a type-safe way.
def LHLO_ReduceOp: LHLO_Op<"reduce", [
SameVariadicOperandSize,
SingleBlockImplicitTerminator<"TerminatorOp">
]>, BASE_HLO_ReduceOp {
//
// The region `body` may return lmhlo.TerminatorOp or mhlo.ReturnOp. We are
// moving towards mhlo.ReturnOp, but some code that needs cleanup still
// assumes lmhlo.TerminatorOp.
// TODO(timshen): cleanup lmhlo.TerminatorOp.
def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]>, BASE_HLO_ReduceOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$init_values,
@ -655,6 +656,44 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
let builders = [
OpBuilderDAG<(ins "ArrayRef<NamedAttribute>":$attributes)>
];
let extraClassDeclaration = [{
SmallVector<Value, 4> getInputBuffers() {
SmallVector<Value, 4> buffers;
this->region().walk([&](TensorLoadOp load) {
if (load.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(load.memref());
});
return buffers;
}
SmallVector<Value, 4> getOutputBuffers() {
SmallVector<Value, 4> buffers;
this->region().walk([&](TensorStoreOp store) {
if (store.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(store.memref());
});
return buffers;
}
SmallVector<Value, 4> getFusionParameters() {
SmallVector<Value, 4> buffers;
this->region().walk([&](TensorLoadOp load) {
if (load.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(load);
});
return buffers;
}
SmallVector<Value, 4> getFusionResults() {
SmallVector<Value, 4> buffers;
this->region().walk([&](TensorStoreOp store) {
if (store.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(store.tensor());
});
return buffers;
}
}];
}
def TerminatorOp :

View File

@ -1939,6 +1939,12 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
return {};
}
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
MLIRContext* context) {
results.insert<IdentityBroadcastReshape, IdentityBroadcastInDimReshape>(
context);
}
//===----------------------------------------------------------------------===//
// Case Op
//===----------------------------------------------------------------------===//

View File

@ -31,3 +31,15 @@ def DynamicBroadcastToOwnShape_2 : Pat<
def ShapeOfDynamicReshape : Pat<
(Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
(replaceWithValue $shape)>;
def HasSameType : Constraint<CPred<"$0.getType() == $1.getType()">>;
def IdentityBroadcastReshape : Pat<
(HLO_ReshapeOp:$op (HLO_BroadcastOp $input, $dims)),
(replaceWithValue $input),
[(HasSameType $input, $op)]>;
def IdentityBroadcastInDimReshape : Pat<
(HLO_ReshapeOp:$op (HLO_BroadcastInDimOp $input, $dims)),
(replaceWithValue $input),
[(HasSameType $input, $op)]>;

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
@ -578,9 +579,8 @@ struct HloLegalizeToLhlo
});
populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
populateWithBufferizeOpConversionPatterns<mlir::ReturnOp, mlir::ReturnOp,
lmhlo::CopyOp>(
&context, converter, patterns);
populateFuncOpTypeConversionPattern(patterns, &context, converter);
populateCallOpTypeConversionPattern(patterns, &context, converter);
populateShapeStructuralTypeConversionsAndLegality(&context, converter,
patterns, target);
if (failed(applyPartialConversion(getOperation(), target,

View File

@ -91,6 +91,31 @@ class LhloFuseLinalgPass
if (result_buffers.insert(alias).second) {
worklist.push_back(alias);
}
continue;
}
if (auto tensor_load = dyn_cast<TensorLoadOp>(definingOp)) {
auto alias = tensor_load.memref();
if (result_buffers.insert(alias).second) {
worklist.push_back(alias);
}
continue;
}
if (auto tensor_to_memref = dyn_cast<TensorToMemrefOp>(definingOp)) {
auto alias = tensor_to_memref.tensor();
if (result_buffers.insert(alias).second) {
worklist.push_back(alias);
}
continue;
}
if (auto tensor_cast = dyn_cast<TensorCastOp>(definingOp)) {
auto alias = tensor_cast.source();
if (result_buffers.insert(alias).second) {
worklist.push_back(alias);
}
continue;
}
if (auto regionInterface =
@ -142,10 +167,10 @@ class LhloFuseLinalgPass
// Fuse producers of tiled linalg ops.
llvm::SmallDenseSet<Operation*> erase_set;
SmallVector<Operation*, 8> linalg_ops;
SmallVector<LinalgOp, 8> linalg_ops;
func.walk([&](LinalgOp op) { linalg_ops.push_back(op); });
for (auto* op : llvm::reverse(linalg_ops)) {
for (unsigned id = 0, e = LinalgOp(op).getNumInputs(); id < e; ++id) {
for (LinalgOp op : llvm::reverse(linalg_ops)) {
for (unsigned id = 0, e = op.getNumInputs(); id < e; ++id) {
linalg::Aliases aliases;
linalg::LinalgDependenceGraph graph(aliases, linalg_ops);
if (auto info = fuseProducerOfBuffer(b, op, id, graph)) {

View File

@ -49,8 +49,9 @@ namespace {
sep fn(ShiftRightLogicalOp) sep fn(SubOp)
// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
fn(AcosOp) sep fn(AtanOp) sep fn(SinhOp) sep fn(TanOp)
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
fn(AcosOp) sep fn(AtanOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) \
sep fn(TanOp)
template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {

View File

@ -61,12 +61,16 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
// mapValues always takes a function returning APInt, even when the output
// is actually float.
using func_type = llvm::APInt(const llvm::APInt&);
// TODO(hinsu): Correctly handle unsigned element types.
bool is_bool = old_type.isInteger(1);
if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
// Int -> Float
return elements.mapValues(
new_type, llvm::function_ref<func_type>([&newFloatType](
new_type, llvm::function_ref<func_type>([&newFloatType, &is_bool](
const llvm::APInt& intVal) {
llvm::APFloat newDouble(static_cast<double>(intVal.getSExtValue()));
int64_t val = is_bool ? intVal.getZExtValue() : intVal.getSExtValue();
llvm::APFloat newDouble(static_cast<double>(val));
bool loses_info = false;
newDouble.convert(newFloatType.getFloatSemantics(),
llvm::APFloat::rmNearestTiesToEven, &loses_info);
@ -76,9 +80,10 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
// new_type is Integer
// Int -> Int
return elements.mapValues(
new_type,
llvm::function_ref<func_type>([&bit_width](const llvm::APInt& intVal) {
return llvm::APInt(bit_width, intVal.getSExtValue());
new_type, llvm::function_ref<func_type>([&bit_width, &is_bool](
const llvm::APInt& intVal) {
int64_t val = is_bool ? intVal.getZExtValue() : intVal.getSExtValue();
return llvm::APInt(bit_width, val);
}));
}

View File

@ -1483,3 +1483,21 @@ func @pad_fold() -> tensor<4x5xi32> {
// CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1]
// CHECK-SAME: ]> : tensor<4x5xi32>
}
// CHECK-LABEL: @identity_broadcast_reshape
func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
%0 = "mhlo.broadcast"(%arg0) {
broadcast_sizes = dense<[1]> : tensor<1xi64>} : (tensor<128xf32>) -> tensor<1x128xf32>
%1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32>
return %1 : tensor<128xf32>
// CHECK: return %arg0 : tensor<128xf32>
}
// CHECK-LABEL: @identity_broadcast_in_dim_reshape
func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
%0 = "mhlo.broadcast_in_dim"(%arg0) {
broadcast_dimensions = dense<[1]> : tensor<1xi64> } : (tensor<128xf32>) -> tensor<1x128xf32>
%1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32>
return %1 : tensor<128xf32>
// CHECK: return %arg0 : tensor<128xf32>
}

View File

@ -123,6 +123,17 @@ func @const_int_bf16() -> tensor<bf16> {
// -----
// CHECK-LABEL: func @const_bool_f32
func @const_bool_f32() -> tensor<2xf32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>
%cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
%0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xf32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2xf32>
}
// -----
// CHECK-LABEL: func @const_bf16_int
func @const_bf16_int() -> tensor<i16> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i16>
@ -145,8 +156,8 @@ func @const_int_narrowing() -> tensor<i32> {
// -----
// CHECK-LABEL: func @const_int_widening
func @const_int_widening() -> tensor<i64> {
// CHECK-LABEL: func @const_bool_widening
func @const_bool_widening() -> tensor<i64> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
%cst = mhlo.constant dense<42> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
@ -156,6 +167,17 @@ func @const_int_widening() -> tensor<i64> {
// -----
// CHECK-LABEL: func @const_int_widening
func @const_int_widening() -> tensor<2xi32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0, 1]> : tensor<2xi32>
%cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
%0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2xi32>
}
// -----
// CHECK-LABEL: func @const_negative_int_widening
func @const_negative_int_widening() -> tensor<i64> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor<i64>

View File

@ -372,3 +372,58 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
// PLOOP: else
// PLOOP: memref_reshape
// PLOOP: scf.yield
// -----
// Confirm that tiling information is passed through tensor_load, tensor_cast
// and tensor_to_memref operations.
func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
-> memref<?xf32> {
%c1 = constant 1 : index
%1 = alloc() : memref<32xf32>
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%arg0 : memref<32xf32>) outs(%1 : memref<32xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%13 = absf %arg3 : f32
linalg.yield %13 : f32
}
%2 = tensor_load %1 : memref<32xf32>
%3 = tensor_cast %2 : tensor<32xf32> to tensor<?xf32>
%4 = tensor_to_memref %3 : memref<?xf32>
return %4 : memref<?xf32>
}
// CHECK-LABEL: func @tensor_ops
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: absf
// CHECK: tensor_load
// CHECK: tensor_cast
// CHECK: tensor_to_memref
// TILED-LABEL: func @tensor_ops
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: absf
// TILED: tensor_load
// TILED: tensor_cast
// TILED: tensor_to_memref
// PLOOP-LABEL: func @tensor_ops
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: tensor_load
// PLOOP: tensor_cast
// PLOOP: tensor_to_memref

View File

@ -93,7 +93,7 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
%scratch = alloc() : memref<32xi8>
%info = alloc() : memref<32xi32>
"lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_upper = true }
"lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_lower = true }
: (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
return
}

View File

@ -35,9 +35,9 @@ filegroup(
"ir/tfl_ops.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
],
)
@ -667,6 +667,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:logging",
@ -704,6 +705,7 @@ cc_library(
":convert_type",
":flatbuffer_tflite_operator_lib",
":tensorflow_lite",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",

View File

@ -167,7 +167,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
case 32:
return tflite::TensorType_INT32;
case 64:
return tflite::TensorType_INT64;
return itype.isUnsigned() ? tflite::TensorType_UINT64
: tflite::TensorType_INT64;
}
} else if (auto q_uniform_type =
type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
@ -386,19 +387,23 @@ class Translator {
// internal error.
static Optional<std::string> Translate(
ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
bool emit_custom_ops, const std::unordered_set<std::string>& tags,
bool emit_custom_ops,
const std::unordered_set<std::string>& select_user_tf_ops,
const std::unordered_set<std::string>& tags,
OpOrArgNameMapper* op_or_arg_name_mapper);
private:
enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops,
bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& select_user_tf_ops,
const std::unordered_set<std::string>& saved_model_tags,
OpOrArgNameMapper* op_or_arg_name_mapper)
: module_(module),
name_mapper_(*op_or_arg_name_mapper),
builder_(kInitialBufferSize),
saved_model_tags_(saved_model_tags) {
saved_model_tags_(saved_model_tags),
select_user_tf_ops_(select_user_tf_ops) {
// The first buffer must be empty according to the schema definition.
empty_buffer_ = tflite::CreateBuffer(builder_);
buffers_.push_back(empty_buffer_);
@ -574,6 +579,8 @@ class Translator {
// Set of saved model tags, if any.
const std::unordered_set<std::string> saved_model_tags_;
// User-defined ops allowed with Flex.
const std::unordered_set<std::string> select_user_tf_ops_;
};
std::string Translator::UniqueName(mlir::Value val) {
@ -1103,12 +1110,15 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
resource_ops_.insert(node_def->op());
}
const bool is_allowed_flex_op =
IsAllowlistedFlexOp(node_def->op()) ||
((select_user_tf_ops_.count(node_def->op()) != 0) &&
(tensorflow::OpRegistry::Global()->LookUp(node_def->op()) != nullptr));
// Flex op case
// Eventually, the allowlist will go away and we will rely on some TF op
// trait (e.g. No side effect) to determine if it is a supported "Flex"
// op or not.
if (enabled_op_types_.contains(OpType::kSelectTf) &&
IsAllowlistedFlexOp(node_def->op())) {
if (is_allowed_flex_op && enabled_op_types_.contains(OpType::kSelectTf)) {
// Construct ops as flex op encoding TensorFlow node definition
// as custom options.
// Flex ops are named with the kFlexOpNamePrefix prefix to the actual
@ -1159,7 +1169,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
}
// Insert failed op to `flex_ops` or `custom_ops`.
if (IsAllowlistedFlexOp(node_def->op())) {
if (is_allowed_flex_op) {
failed_flex_ops_.insert(os.str());
} else {
failed_custom_ops_.insert(os.str());
@ -1619,12 +1629,15 @@ bool UpdateEntryFunction(ModuleOp module) {
Optional<std::string> Translator::Translate(
ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
bool emit_custom_ops, const std::unordered_set<std::string>& tags,
bool emit_custom_ops,
const std::unordered_set<std::string>& select_user_tf_ops,
const std::unordered_set<std::string>& tags,
OpOrArgNameMapper* op_or_arg_name_mapper) {
if (!UpdateEntryFunction(module)) return llvm::None;
if (!IsValidTFLiteMlirModule(module)) return llvm::None;
Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
emit_custom_ops, tags, op_or_arg_name_mapper);
emit_custom_ops, select_user_tf_ops, tags,
op_or_arg_name_mapper);
return translator.TranslateInternal();
}
@ -1876,9 +1889,22 @@ bool tflite::MlirToFlatBufferTranslateFunction(
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& saved_model_tags,
OpOrArgNameMapper* op_or_arg_name_mapper) {
std::unordered_set<std::string> select_user_tf_ops;
return MlirToFlatBufferTranslateFunction(
module, serialized_flatbuffer, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, select_user_tf_ops, saved_model_tags,
op_or_arg_name_mapper);
}
bool tflite::MlirToFlatBufferTranslateFunction(
ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& select_user_tf_ops,
const std::unordered_set<std::string>& saved_model_tags,
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper) {
auto maybe_translated = Translator::Translate(
module, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops,
saved_model_tags, op_or_arg_name_mapper);
select_user_tf_ops, saved_model_tags, op_or_arg_name_mapper);
if (!maybe_translated) return true;
*serialized_flatbuffer = std::move(*maybe_translated);
return false;

View File

@ -52,6 +52,14 @@ bool MlirToFlatBufferTranslateFunction(
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& saved_model_tags,
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
// Same as above, but with a list of allowed user-defined ops.
bool MlirToFlatBufferTranslateFunction(
mlir::ModuleOp module, std::string* serialized_flatbuffer,
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& select_user_tf_ops,
const std::unordered_set<std::string>& saved_model_tags,
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
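A minimal, hypothetical call site for the new overload might look like the following; the op name "MyUserTFOp" and the surrounding variables (`module`, `op_or_arg_name_mapper`) are assumptions set up by the caller as for the existing overload, and the function returns true on failure:
// Hypothetical usage sketch; keeps one user-defined TF op as a Flex op.
std::unordered_set<std::string> select_user_tf_ops = {"MyUserTFOp"};
std::string serialized_flatbuffer;
bool failed = tflite::MlirToFlatBufferTranslateFunction(
    module, &serialized_flatbuffer, /*emit_builtin_tflite_ops=*/true,
    /*emit_select_tf_ops=*/true, /*emit_custom_ops=*/false,
    select_user_tf_ops, /*saved_model_tags=*/{}, op_or_arg_name_mapper);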
} // namespace tflite
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_

View File

@ -64,6 +64,7 @@ limitations under the License.
#include "mlir/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
@ -149,7 +150,7 @@ StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
int64_t storage_min = QuantizedType::getDefaultMinimumForInteger(
is_signed, storage_type.getWidth()) +
is_weight_buffer;
static_cast<int>(is_weight_buffer);
int64_t storage_max = QuantizedType::getDefaultMaximumForInteger(
is_signed, storage_type.getWidth());
uint32_t flags =
@ -177,12 +178,25 @@ StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
quant_params.zero_point.at(0), storage_min, storage_max);
}
// Imports a float tensor with calibration values into a calibrated quantized
// type.
StatusOr<QuantizedType> GetCalibratedQuantizedType(const TensorT& tensor,
Builder builder) {
if (tensor.quantization == nullptr) {
return errors::InvalidArgument("The tensor is not quantized.");
}
auto raw_elem_type = ConvertElementType(tensor.type, builder);
float min = tensor.quantization->min[0];
float max = tensor.quantization->max[0];
return mlir::quant::CalibratedQuantizedType::get(raw_elem_type, min, max);
}
// TODO(b/138222071) Remove shapeless_are_scalars once we can reliably
// make that distinction and don't have to rely on context
// (input to main and constants must have static shape)
StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
bool shapeless_are_scalars = false,
bool is_constant = false) {
bool is_constant = false,
bool is_intermediate = false) {
mlir::Type elem_type = ConvertElementType(tensor.type, builder);
// TODO(b/139554398) Store min/max (even for non-quantized tensors) somewhere
// if it's set
@ -191,6 +205,13 @@ StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
GetQuantizedType(tensor, builder, is_constant));
}
// Intermediate tensors with calibration value (but not scale and zero points)
// should return calibrated quantized type.
if (is_intermediate && tensor.quantization != nullptr &&
!IsQuantized(tensor)) {
TF_ASSIGN_OR_RETURN(elem_type, GetCalibratedQuantizedType(tensor, builder));
}
if (IsScalar(tensor) || (shapeless_are_scalars && tensor.shape.empty())) {
return RankedTensorType::get({}, elem_type);
}
@ -1033,7 +1054,7 @@ StatusOr<FuncOp> ConvertSubgraph(
}
}
// Intermediate tensors for tfl.lstm are used to carry quantization range
// Intermediate tensors for LSTMs are used to carry quantization range
// in their types, so we only need to extract their types.
std::vector<mlir::TensorType> intermediate_types;
intermediate_types.reserve(5);
@ -1041,7 +1062,8 @@ StatusOr<FuncOp> ConvertSubgraph(
TF_ASSIGN_OR_RETURN(
auto type, GetTensorType(*subgraph.tensors[intermediate], builder,
/*shapeless_are_scalars=*/true,
/*is_constant=*/true));
/*is_constant=*/false,
/*is_intermediate=*/true));
intermediate_types.emplace_back(type);
}
@ -1135,7 +1157,6 @@ OwningModuleRef tflite::FlatBufferToMlir(
auto builder = Builder(context);
std::vector<std::string> func_names;
for (auto& subgraph : model->subgraphs) {
func_names.push_back(subgraph->name);

View File

@ -1978,6 +1978,10 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 1);
if (getElementTypeOrSelf(input()) == getElementTypeOrSelf(getType())) {
return input();
}
// For now, only supports cast between integer types.
auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
if (!elements_attr) {

View File

@ -483,9 +483,9 @@ value of each element in `x`. For example, if x is an input element and y is
an output element, this operation computes \\(y = |x|\\).
}];
let arguments = (ins TFL_FpTensor:$x);
let arguments = (ins TFL_TensorOf<[F32, QI8, QI16]>:$x);
let results = (outs TFL_FpTensor:$y);
let results = (outs TFL_TensorOf<[F32, QI8, QI16]>:$y);
let hasFolder = 1;
}
@ -587,15 +587,15 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
let arguments = (ins
TFL_I32Tensor:$output_shape,
TFL_TensorOf<[F32, QI8, QUI8]>:$weights,
TFL_TensorOf<[F32, QI8, QUI8]>:$input,
TFL_TensorOfOrNone<[F32, QI32]>:$bias,
TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$weights,
TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
TFL_TensorOfOrNone<[F32, QI32, I64]>:$bias,
TFL_PaddingAttr:$padding,
Confined<I32Attr, [IntPositive]>:$stride_h,
Confined<I32Attr, [IntPositive]>:$stride_w
);
let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);
let hasOptions = 1;
@ -624,7 +624,7 @@ def TFL_AveragePool2DOp:
}];
let arguments = (
ins TFL_TensorOf<[F32, QI8, QUI8]>:$input,
ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
I32Attr:$filter_height,
I32Attr:$filter_width,
TFL_PaddingAttr:$padding,
@ -633,7 +633,7 @@ def TFL_AveragePool2DOp:
TFL_AFAttr:$fused_activation_function
);
let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);
let hasOptions = 1;
let customOption = "Pool2DOptions";
@ -947,7 +947,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
let arguments = (ins
TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$input,
TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$filter,
TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$filter,
TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias,
TFL_AFAttr:$fused_activation_function,
@ -999,14 +999,14 @@ in the batch dimensions and broadcasting.
}];
let arguments = (ins
TFL_TensorOf<[F32, QI8]>:$x,
TFL_TensorOf<[F32, QI8]>:$y,
TFL_TensorOf<[F32, QI8, QI16]>:$x,
TFL_TensorOf<[F32, QI8, QI16]>:$y,
DefaultValuedAttr<BoolAttr, "false">:$adj_x,
DefaultValuedAttr<BoolAttr, "false">:$adj_y
);
let results = (outs
TFL_TensorOf<[F32, QI8]>:$output
TFL_TensorOf<[F32, QI8, QI16]>:$output
);
let hasOptions = 1;
@ -1026,7 +1026,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
}];
let arguments = (ins
TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8]>:$params,
TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8, QI16]>:$params,
TFL_TensorOf<[I32, I64]>:$indices,
I32Attr:$axis
);
@ -1038,7 +1038,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
];
let results = (outs
TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8]>:$output
TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8, QI16]>:$output
);
let hasOptions = 1;
@ -1750,12 +1750,12 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [
}];
let arguments = (
ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input,
ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8, QI16]>:$input,
// Slope of the activation function at x < 0.
F32Attr:$alpha
);
let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output);
let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8, QI16]>:$output);
let hasOptions = 0b1;
}
@ -1977,12 +1977,12 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
}];
let arguments = (
ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$lhs,
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$rhs
ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$lhs,
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$rhs
);
let results = (outs
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$max
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$max
);
let builders = [TFL_BroadcastableBinaryBuilder];
@ -2005,13 +2005,13 @@ def TFL_MeanOp : TFL_Op<"mean", [
}];
let arguments = (ins
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8]>:$input,
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8, QI16]>:$input,
TFL_TensorOf<[I32, I64]>:$axis,
BoolAttr:$keep_dims
);
let results = (outs
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8]>:$output);
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8, QI16]>:$output);
let hasOptions = 1;
let customOption = "ReducerOptions";
@ -2090,13 +2090,13 @@ equivalent to setting:
}];
let arguments = (ins
TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$input,
TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$input,
TFL_I32OrI64Tensor:$begin,
TFL_I32OrI64Tensor:$size
);
let results = (outs
TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$output
TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$output
);
let verifier = [{ return Verify(*this); }];
@ -2116,12 +2116,12 @@ def TFL_SumOp: TFL_Op<"sum", [
}];
let arguments = (ins
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
TFL_I32Tensor:$axes,
BoolAttr:$keep_dims
);
let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
let hasOptions = 1;
let customOption = "ReducerOptions";
@ -2139,13 +2139,13 @@ def TFL_ReduceMinOp: TFL_Op<"reduce_min", [
}];
let arguments = (ins
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
TFL_I32Tensor:$axes,
BoolAttr:$keep_dims
);
let results = (outs
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
let hasOptions = 1;
let customOption = "ReducerOptions";
@ -2163,13 +2163,13 @@ def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [
}];
let arguments = (ins
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
TFL_I32Tensor:$axes,
BoolAttr:$keep_dims
);
let results = (outs
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
let hasOptions = 1;
let customOption = "ReducerOptions";
@ -2186,13 +2186,13 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [
}];
let arguments = (ins
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
TFL_I32Tensor:$axes,
BoolAttr:$keep_dims
);
let results = (outs
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
let hasOptions = 1;
let customOption = "ReducerOptions";
@ -2210,12 +2210,12 @@ def TFL_MinimumOp : TFL_Op<"minimum", [
}];
let arguments = (
ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$lhs,
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$rhs
ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$lhs,
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$rhs
);
let results = (outs
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$min
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$min
);
let builders = [TFL_BroadcastableBinaryBuilder];
@ -2364,10 +2364,10 @@ def TFL_PadOp : TFL_Op<"pad", [
```
}];
let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
TFL_I32OrI64Tensor:$padding);
let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);
let hasOptions = 1;
}
@ -2500,9 +2500,9 @@ def TFL_ReluOp: TFL_Op<"relu", [
x -> max(0, x)
}];
let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x);
let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8, QI16]>:$x);
let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y);
let results = (outs TFL_TensorOf<[F32, QUI8, QI8, QI16]>:$y);
// This builder doesn't work with quantized type, so it can only be used by
// non-quantization tablegen patterns. Currently, it is used by the
@ -2828,11 +2828,11 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
}];
let arguments = (
ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input,
ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8, QI16]>:$input,
F32Attr:$beta
);
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output);
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8, QI16]>:$output);
let hasOptions = 1;
@ -3058,12 +3058,12 @@ def TFL_TransposeOp : TFL_Op<"transpose", [
}];
let arguments = (ins
TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64]>:$input,
TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64, QI16]>:$input,
TFL_TensorOf<[I32]>:$perm
);
let results = (outs
TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64]>:$output
TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64, QI16]>:$output
);
let verifier = [{ return Verify(*this); }];
@ -3330,14 +3330,14 @@ def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", [
}];
let arguments = (ins
TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$input,
TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8, QI16]>:$input,
TFL_I32Tensor:$size,
BoolAttr:$align_corners,
DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
);
let results = (outs
TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$output
TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8, QI16]>:$output
);
let hasOptions = 1;
@ -3830,6 +3830,9 @@ Ba et al. 'Layer Normalization'
// Types of the optional intermediate tensors, which exist for fully
// quantized LSTM op and hold the ranges of the intermediate tensors.
// The type for intermediate tensors will be quant.calibrated when imported
// to only store calibrated min/max values. The proper quantization spec is
// determined while going through the quantization passes.
OptionalAttr<TypeAttr>:$input_to_input_intermediate,
OptionalAttr<TypeAttr>:$input_to_forget_intermediate,
OptionalAttr<TypeAttr>:$input_to_cell_intermediate,
@ -3945,6 +3948,9 @@ def TFL_UnidirectionalSequenceLSTMOp :
// Types of the optional intermediate tensors, which exist for fully
// quantized op and hold the ranges of the intermediate tensors.
// The type for intermediate tensors will be quant.calibrated when imported
// to only store calibrated min/max values. The proper quantization spec is
// determined while going through the quantization passes.
OptionalAttr<TypeAttr>:$input_to_input_intermediate,
OptionalAttr<TypeAttr>:$input_to_forget_intermediate,
OptionalAttr<TypeAttr>:$input_to_cell_intermediate,

View File

@ -55,7 +55,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
// Parse input arrays.
std::vector<string> node_names;
std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes;
std::vector<llvm::Optional<std::vector<int>>> node_shapes;
std::vector<llvm::Optional<double>> node_mins;
std::vector<llvm::Optional<double>> node_maxs;
@ -89,6 +89,11 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
pass_config.lower_tensor_list_ops = true;
// Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the
// conversion to an unsupported 16x16 TFL::FullyConnectedOp.
if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) {
pass_config.unfold_batch_matmul = false;
}
return internal::ConvertMLIRToTFLiteFlatBuffer(
toco_flags, std::move(module), pass_config, /*saved_model_tags=*/{},

View File

@ -128,7 +128,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
// Parse input arrays.
std::vector<string> node_names;
std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes;
std::vector<llvm::Optional<std::vector<int>>> node_shapes;
std::vector<llvm::Optional<double>> node_mins;
std::vector<llvm::Optional<double>> node_maxs;
@ -174,6 +174,11 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
pass_config.lower_tensor_list_ops = true;
// Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the
// conversion to an unsupported 16x16 TFL::FullyConnectedOp.
if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) {
pass_config.unfold_batch_matmul = false;
}
// TODO(b/153507667): Pass the session object when importing logic is removed.
auto status = internal::ConvertMLIRToTFLiteFlatBuffer(

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
#include <ostream>
#include <unordered_set>
#include <utility>
#include "llvm/Support/ToolOutputFile.h"
@ -119,6 +120,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
return DT_INT32;
case toco::IODataType::INT64:
return DT_INT64;
case toco::IODataType::UINT64:
return DT_UINT64;
case toco::IODataType::STRING:
return DT_STRING;
case toco::IODataType::BOOL:
@ -185,7 +188,7 @@ Status PopulateQuantizationSpecs(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<llvm::Optional<std::vector<int>>>* node_shapes,
std::vector<llvm::Optional<double>>* node_mins,
std::vector<llvm::Optional<double>>* node_maxs) {
quant_specs->inference_input_type =
@ -210,8 +213,12 @@ Status PopulateQuantizationSpecs(
node_dtypes->push_back(
DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
}
node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
flag.shape().dims().end()));
if (flag.shape().unknown_rank()) {
node_shapes->push_back(llvm::None);
} else {
node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
flag.shape().dims().end()));
}
// Currently, only UINT8 and INT8 require inputs stats
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
if (flag.has_mean_value() && flag.has_std_value()) {
@ -282,6 +289,10 @@ Status ConvertMLIRToTFLiteFlatBuffer(
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
bool emit_custom_ops = toco_flags.allow_custom_ops();
const std::unordered_set<std::string> select_user_tf_ops(
toco_flags.select_user_tf_ops().begin(),
toco_flags.select_user_tf_ops().end());
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
module.get(),
@ -301,8 +312,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(
auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs,
saved_model_tags, result, &pm);
emit_select_tf_ops, emit_custom_ops, select_user_tf_ops,
pass_config.quant_specs, saved_model_tags, result, &pm);
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
// rename once we enable the new converter feature flag.

View File

@ -41,7 +41,7 @@ Status PopulateQuantizationSpecs(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<llvm::Optional<std::vector<int>>>* node_shapes,
std::vector<llvm::Optional<double>>* node_mins,
std::vector<llvm::Optional<double>>* node_maxs);

View File

@ -53,10 +53,11 @@ struct QuantizationSpecs {
bool disable_per_channel = false;
// When set to true, the fixed output ranges of the activation ops (tanh,
// sigmoid, etc.) are not enforced. Then, to quantize these ops, quantization
// emulation ops should be specified after the ops in the input graph. This
// flag should be set to false for post-training quantization.
bool disable_enforced_fixed_output_range = false;
// sigmoid, etc.) and the weight constants are not inferred. Then, to quantize
// these ops, quantization emulation ops should be placed after the ops in the
// input graph. This flag should be set to false for post-training
// quantization.
bool disable_infer_tensor_range = false;
// The node type when the model is exported. Currently this is limited to
// DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the

View File

@ -100,13 +100,13 @@ class QuantizationDriver {
explicit QuantizationDriver(FuncOp fn, bool is_signed,
bool disable_per_channel,
OpQuantSpecGetter op_quant_spec_getter,
bool enforce_fixed_output_range)
bool infer_tensor_range)
: fn_(fn),
builder_(fn.getBody()),
is_signed_(is_signed),
disable_per_channel_(disable_per_channel),
op_quant_spec_getter_(op_quant_spec_getter),
enforce_fixed_output_range_(enforce_fixed_output_range) {}
infer_tensor_range_(infer_tensor_range) {}
// The entry point of the quantization parameters propagation.
void Run();
@ -384,7 +384,9 @@ class QuantizationDriver {
OpQuantSpecGetter op_quant_spec_getter_;
bool enforce_fixed_output_range_;
// Infer output ranges for activation ops and constants. This is usually
// required for post-training quantization.
bool infer_tensor_range_;
};
} // namespace
@ -670,33 +672,43 @@ void QuantizationDriver::PreprocessConstantOps() {
Value value = cst.getResult();
builder_.setInsertionPoint(cst);
for (auto indexed_use : llvm::enumerate(value.getUses())) {
auto &use = indexed_use.value();
auto spec = GetQuantSpec(use.getOwner());
auto biases = spec->biases_params;
Operation *user = use.getOwner();
int operand_num = use.getOperandNumber();
// The following loop will change the value uses, so we cache all the uses
// that need to be changed.
llvm::SmallVector<std::pair<Operation *, int>, 4> uses;
for (auto &use : value.getUses()) {
uses.push_back({use.getOwner(), use.getOperandNumber()});
}
for (auto indexed_use : llvm::enumerate(uses)) {
Operation *user = indexed_use.value().first;
int operand_num = indexed_use.value().second;
auto spec = GetQuantSpec(user);
auto biases = spec->biases_params;
// The quantization parameters of a `weight` shouldn't be determined by
// other values. So any constant which is not a bias, is not an operand of
// an op with same-scale requirements, and hasn't been quantized yet is a
// weight.
if (biases.find(operand_num) == biases.end() &&
!llvm::dyn_cast<mlir::SameScalesOpInterface>(user) &&
!llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
// Needs to scan the content to get the quantization parameters if there
// are no quantization parameters (FakeQuant ops).
// For this case, the weight isn't duplicated.
// Needs to scan the content of weights to get the quantization
// parameters if there are no quantization parameters (FakeQuant ops).
// For this case, the weight will not be duplicated.
weights_.insert(cst);
auto affine_user =
llvm::dyn_cast<mlir::AffineQuantizedOpInterface>(user);
if (affine_user &&
affine_user.GetAffineOperandIndex() == use.getOperandNumber() &&
if (affine_user && affine_user.GetAffineOperandIndex() == operand_num &&
affine_user.RequiredNarrowRangeAffineOperand()) {
optimized_weights_.insert(
{cst, affine_user.GetQuantizationDimIndex()});
}
} else {
// This is a bias, so the quantization parameter isn't determined by the
// local content. Same if the user can have quantization parameter
// propagated from other places.
// Duplicate this constant in case it is shared by different users.
// This is a bias or an operand of an op with same scale requirements,
// so the quantization parameters are propagated from or determined by
// other values. Duplicate this constant in case it is shared by
// different users.
if (indexed_use.index() > 0) {
cst = builder_.create<ConstantOp>(cst.getLoc(), cst.getValue());
}
@ -786,12 +798,14 @@ bool QuantizationDriver::PropagateParams() {
quantized_.insert(op);
if (auto cst = llvm::dyn_cast<ConstantOp>(op)) {
// If it isn't a weight or has been quantized, skip.
if (!IsWeight(cst) || IsQuantized(op)) continue;
// The quantization parameters are determined by the content of the
// constant.
changed |= SetConstantResultParams(op);
// If the workflow requires inferring ranges from the content
// (post-training quantization) and it is a weight (filter) that hasn't
// been quantized, we infer the quantization parameters from the content.
if (infer_tensor_range_ && IsWeight(cst) && !IsQuantized(op)) {
// The quantization parameters are determined by the content of the
// constant.
changed |= SetConstantResultParams(op);
}
continue;
}
@ -826,7 +840,9 @@ bool QuantizationDriver::PropagateParams() {
// TODO(fengliuai): make the bit width configurable.
auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op);
if (restricted && enforce_fixed_output_range_) {
if (restricted && infer_tensor_range_) {
// Infer ranges from the activation ops. This is usually required for
// the post-training quantization workflow.
// TODO(fengliuai): different result can have different fixed range.
auto params = restricted.GetFixedOutputRange(is_signed_, /*bit_width=*/8);
for (auto i = 0; i < op->getNumResults(); ++i) {
@ -903,9 +919,9 @@ void QuantizationDriver::Run() {
void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
bool disable_per_channel,
OpQuantSpecGetter op_quant_spec_getter,
bool post_training_quantization) {
bool infer_tensor_ranges) {
QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter,
post_training_quantization)
infer_tensor_ranges)
.Run();
}

View File

@ -19,6 +19,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
#include <string>
#include <unordered_map>
#include "llvm/ADT/SmallVector.h"
@ -103,8 +104,11 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
if (!stats) return failure();
for (auto it = stats.begin(), e = stats.end(); it != e; ++it) {
mins.push_back(FloatAttr::getValueAsDouble(*it++));
maxs.push_back(FloatAttr::getValueAsDouble(*it));
double min = FloatAttr::getValueAsDouble(*it++);
double max = FloatAttr::getValueAsDouble(*it);
TensorRangeSanityCheck(op, min, max);
mins.push_back(min);
maxs.push_back(max);
}
quant_type =
quant::fakeQuantAttrsToType(op.getLoc(), num_bits, *op.axis(), mins,
@ -112,6 +116,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
} else if (auto stats = op.layerStats().dyn_cast<DenseFPElementsAttr>()) {
double rmin = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}));
double rmax = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}));
TensorRangeSanityCheck(op, rmin, rmax);
quant_type =
quant::fakeQuantAttrsToType(op.getLoc(), num_bits, rmin, rmax,
narrow_range, expressed, is_signed);
@ -134,6 +139,19 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
int num_bits;
bool narrow_range;
bool is_signed;
// Emits a warning on the op if the calibrated range is at least 10.0 and
// the storage type is 8 bits or narrower.
void TensorRangeSanityCheck(quant::StatisticsOp op, double min,
double max) const {
double range = std::fabs(max - min);
if (num_bits <= 8 && range >= 10.0) {
op.emitWarning(
"Tensor range is too wide to be quantized. Use tf.clip_by_value or "
"tf.relu6 to narrow the tensor range. Range: " +
std::to_string(range) + ", bit width: " + std::to_string(num_bits));
}
}
};
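For reference, the remedy the new warning message suggests can be sketched at the TensorFlow Python level. The tensor, shapes and clamp bounds below are invented for illustration; the point is only that bounding an activation keeps its calibrated range well under the 10.0 threshold checked above.

import tensorflow as tf

# Hypothetical activation whose calibrated range is too wide for 8-bit
# quantization, roughly [-20, 20].
x = tf.random.uniform([1, 128], minval=-20.0, maxval=20.0)

# Option 1: clamp to an explicit interval before quantization-sensitive ops.
narrow_a = tf.clip_by_value(x, -3.0, 3.0)

# Option 2: use relu6, which bounds the activation to [0, 6].
narrow_b = tf.nn.relu6(x)

# Either way the range seen by post-training quantization stays well under
# the 10.0 threshold that triggers the warning above.
print(float(tf.reduce_max(narrow_a) - tf.reduce_min(narrow_a)))
print(float(tf.reduce_max(narrow_b) - tf.reduce_min(narrow_b)))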
// A base rewrite pattern which matches any N-in-M-out operations with
@ -490,13 +508,13 @@ quant::QuantizedType GetUniformQuantizedTypeForBias(
// and the propagation results are materialized by inserting pairs of quantize
// and dequantize ops to this function. Set `disable_per_channel` to true to not
// use per-channel quantization even if the op supports it.
// Setting `enforce_fixed_output_range` to true, to infer quantization
// parameters from the fixed output range ops. This is only used for
// post-training quantization.
// Set `infer_tensor_ranges` to true to infer quantization parameters from
// the activation ops and weight constants. This is only used for post-training
// quantization.
void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
bool disable_per_channel,
OpQuantSpecGetter op_quant_spec_getter,
bool enforce_fixed_output_range);
bool infer_tensor_ranges);
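On the user side, the path that ends up passing `infer_tensor_ranges` is post-training quantization driven from the TFLite converter, where activation ranges are inferred from calibration data and weight ranges from the constants themselves. A minimal sketch of that workflow, assuming a throwaway Keras model and a synthetic calibration dataset (both invented here):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
    tf.keras.layers.Dense(4),
])

def representative_dataset():
    for _ in range(100):
        # Calibration samples from which activation ranges are inferred.
        yield [np.random.rand(1, 8).astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
tflite_model = converter.convert()  # weights and most activations are quantized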
// The function might contain more stats ops than required, and it will
// introduce requantize ops if the calibration stats have conflicts. This method

View File

@ -638,4 +638,10 @@ func @cast_ui8_to_i32() -> tensor<4xi32> {
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @cast_identity
func @cast_identity(%arg0 : tensor<7xf32>) -> tensor<7xf32> {
%0 = "tfl.cast"(%arg0) : (tensor<7xf32>) -> tensor<7xf32>
return %0 : tensor<7xf32>
// CHECK: return %arg0 : tensor<7xf32>
}

View File

@ -0,0 +1,335 @@
// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck %s
// CHECK: effective_hidden_scale_intermediate = tensor<!quant.calibrated<f32<-5.000000e-01:5.000000e-01>>>
// CHECK: input_to_cell_intermediate = tensor<!quant.calibrated<f32<-4.000000e+00:4.000000e+00>>>
// CHECK: input_to_forget_intermediate = tensor<!quant.calibrated<f32<-1.600000e+01:1.600000e+01>>>
// CHECK: input_to_input_intermediate = tensor<!quant.calibrated<f32<-3.200000e+01:3.200000e+01>>>
// CHECK: input_to_output_intermediate = tensor<!quant.calibrated<f32<-1.000000e+00:1.000000e+00>>>
{
"version": 3,
"operator_codes": [
{
"builtin_code": "UNIDIRECTIONAL_SEQUENCE_LSTM"
}
],
"subgraphs": [
{
"tensors": [
{
"shape": [1, 5],
"name": "input0"
},
{
"shape": [2, 5],
"buffer": 1,
"name": "input2input_weights1"
},
{
"shape": [2, 5],
"buffer": 2,
"name": "input2forget_weights2"
},
{
"shape": [2, 5],
"buffer": 3,
"name": "input2cell_weights3"
},
{
"shape": [2, 5],
"buffer": 4,
"name": "input2output_weights4"
},
{
"shape": [2, 4],
"buffer": 5,
"name": "rec2input_weights5"
},
{
"shape": [2, 4],
"buffer": 6,
"name": "rec2forget_weights6"
},
{
"shape": [2, 4],
"buffer": 7,
"name": "rec2cell_weights7"
},
{
"shape": [2, 4],
"buffer": 8,
"name": "rec2output_weights8"
},
{
"shape": [2],
"buffer": 9,
"name": "cell2input_weights9"
},
{
"shape": [2],
"buffer": 10,
"name": "cell2forget_weights10"
},
{
"shape": [2],
"buffer": 11,
"name": "cell2output_weights11"
},
{
"shape": [2],
"buffer": 12,
"name": "input_gate_bias12"
},
{
"shape": [2],
"buffer": 13,
"name": "forget_gate_bias13"
},
{
"shape": [2],
"buffer": 14,
"name": "cell_gate_bias14"
},
{
"shape": [2],
"buffer": 15,
"name": "output_gate_bias15"
},
{
"shape": [4, 2],
"buffer": 16,
"name": "proj_weights16"
},
{
"shape": [4],
"buffer": 17,
"name": "proj_bias17"
},
{
"shape": [1, 4],
"name": "input_activation_state18",
"is_variable": true,
"quantization": {
"min": [-0.9],
"max": [0.9]
}
},
{
"shape": [1, 2],
"name": "input_cell_state19",
"is_variable": true,
"quantization": {
"min": [-0.8],
"max": [0.8]
}
},
{
"shape": [2],
"buffer": 18,
"name": "input_norm20"
},
{
"shape": [2],
"buffer": 19,
"name": "forget_norm21"
},
{
"shape": [2],
"buffer": 20,
"name": "cell_norm22"
},
{
"shape": [2],
"buffer": 21,
"name": "output_norm23"
},
{
"shape": [],
"name": "output24"
},
{
"shape": [],
"name": "intermediate_0",
"is_variable": true,
"quantization": {
"min": [-32],
"max": [32]
}
},
{
"shape": [],
"name": "intermediate_1",
"is_variable": true,
"quantization": {
"min": [-16],
"max": [16]
}
},
{
"shape": [],
"name": "intermediate_2",
"is_variable": true,
"quantization": {
"min": [-4],
"max": [4]
}
},
{
"shape": [],
"name": "intermediate_3",
"is_variable": true,
"quantization": {
"min": [-1.0],
"max": [1.0]
}
},
{
"shape": [],
"name": "intermediate_4",
"is_variable": true,
"quantization": {
"min": [-0.5],
"max": [0.5]
}
}
],
"inputs": [0],
"outputs": [24],
"operators": [
{
"inputs": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23
],
"outputs": [24],
"intermediates": [
25, 26, 27, 28, 29
],
"builtin_options_type": "UnidirectionalSequenceLSTMOptions",
"builtin_options": {
"fused_activation_function": "TANH",
"cell_clip": 50.0
},
"mutating_variable_inputs": [
false,
false, false, false, false,
false, false, false, false,
false, false, false,
false, false, false, false,
true, true,
false, false, false, false
]
}
]
}
],
"buffers": [
{
"data": []
},
{
"data": [
36, 167, 168, 63, 0, 140, 72, 191, 120, 20, 147, 62, 20, 152, 196, 190, 121, 98, 82, 187, 95, 128, 213, 61, 189, 3, 138, 63, 54, 103, 13, 62, 46, 224, 66, 63, 157, 204, 180, 191
]
},
{
"data": [
223, 20, 21, 64, 246, 166, 31, 191, 6, 51, 157, 188, 114, 90, 167, 62, 118, 240, 59, 63, 49, 162, 255, 62, 17, 91, 160, 63, 32, 47, 26, 63, 40, 136, 178, 191, 243, 154, 236, 61
]
},
{
"data": [
137, 231, 86, 63, 41, 154, 16, 63, 239, 37, 77, 191, 55, 189, 24, 189, 86, 63, 18, 63, 42, 55, 13, 191, 110, 139, 138, 191, 219, 148, 181, 63, 71, 232, 108, 191, 66, 226, 145, 191
]
},
{
"data": [
245, 179, 225, 190, 51, 202, 176, 189, 132, 47, 53, 191, 155, 25, 50, 191, 197, 130, 240, 191, 98, 125, 45, 62, 243, 70, 83, 62, 85, 155, 139, 63, 113, 239, 11, 192, 35, 251, 139, 62
]
},
{
"data": [
248, 188, 211, 191, 142, 11, 73, 62, 36, 8, 84, 63, 186, 208, 11, 191, 76, 208, 190, 191, 223, 200, 210, 63, 183, 170, 103, 63, 116, 129, 145, 63
]
},
{
"data": [
235, 202, 222, 190, 159, 201, 112, 191, 217, 248, 166, 63, 165, 199, 131, 191, 130, 59, 47, 63, 179, 11, 186, 62, 55, 168, 18, 192, 152, 213, 26, 64
]
},
{
"data": [
245, 123, 138, 62, 213, 106, 231, 59, 211, 218, 250, 62, 25, 157, 134, 63, 147, 22, 164, 63, 25, 221, 139, 62, 1, 230, 247, 62, 210, 185, 142, 63
]
},
{
"data": [
197, 123, 23, 192, 45, 96, 178, 190, 174, 87, 165, 62, 213, 225, 200, 191, 119, 248, 15, 191, 128, 125, 171, 189, 90, 125, 222, 63, 4, 76, 95, 62
]
},
{
"data": [
210, 73, 183, 63, 248, 177, 13, 191
]
},
{
"data": [
78, 251, 212, 191, 169, 29, 147, 63
]
},
{
"data": [
178, 227, 203, 191, 247, 155, 103, 63
]
},
{
"data": [
206, 111, 165, 190, 153, 77, 227, 63
]
},
{
"data": [
255, 114, 132, 191, 253, 202, 140, 191
]
},
{
"data": [
90, 247, 1, 192, 125, 120, 209, 191
]
},
{
"data": [
65, 75, 243, 191, 58, 122, 146, 190
]
},
{
"data": [
40, 135, 20, 63, 109, 50, 220, 191, 56, 241, 189, 63, 65, 12, 92, 63, 61, 14, 162, 62, 157, 138, 81, 63, 125, 61, 191, 61, 102, 231, 20, 63
]
},
{
"data": [
145, 79, 49, 189, 175, 235, 220, 190, 182, 111, 157, 190, 144, 236, 97, 191
]
},
{
"data": [
76, 188, 109, 63, 228, 150, 201, 190
]
},
{
"data": [
6, 146, 66, 191, 122, 127, 100, 191
]
},
{
"data": [
216, 59, 169, 190, 161, 178, 215, 191
]
},
{
"data": [
208, 144, 101, 191, 127, 233, 195, 190
]
}
]
}
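The `min`/`max` pairs on the intermediate tensors above are what surface as the `!quant.calibrated` ranges in the CHECK lines at the top of this test. As a rough reference, assuming symmetric 16-bit quantization with zero point 0 (which is how TFLite treats most LSTM intermediates), a calibrated range maps to a scale like this:

# Symmetric int16: scale = max(|min|, |max|) / 32767, zero_point = 0.
for name, lo, hi in [
    ("intermediate_0", -32.0, 32.0),
    ("intermediate_1", -16.0, 16.0),
    ("intermediate_2", -4.0, 4.0),
    ("intermediate_3", -1.0, 1.0),
    ("intermediate_4", -0.5, 0.5),
]:
    scale = max(abs(lo), abs(hi)) / 32767.0
    print(name, scale)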

View File

@ -20,20 +20,20 @@ func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32
func @testFullyQuantizedLSTM(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> {
%cst = constant unit
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
return %0 : tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
// CHECK-LABEL: testFullyQuantizedLSTM
// CHECK: %cst = constant unit
// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18)
// CHECK: }) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.6013896674849093E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
// CHECK: }) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0075630000792443752:2>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.6013896674849093E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
}
// -----
// CHECK-LABEL: testUnidirectionalSequenceLstmWithIntermediates
func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

View File

@ -318,6 +318,15 @@ func @any(%arg0: tensor<2x2xi1>, %arg1: tensor<i32>) -> tensor<i1> {
// CHECK: "tfl.reduce_any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i32>) -> tensor<i1>
}
func @any_i64axes(%arg0: tensor<8x16x16xi1>, %arg1: tensor<2xi64>) -> tensor<?xi1> {
%0 = "tf.Any"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xi1>, tensor<2xi64>) -> tensor<?xi1>
return %0 : tensor<?xi1>
// CHECK-LABEL: any_i64axes
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
// CHECK: "tfl.reduce_any"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xi1>, tensor<2xi32>) -> tensor<?xi1>
}
func @ceil(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Ceil"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
return %0 : tensor<8x16xf32>
@ -435,6 +444,16 @@ func @scatterNdHigherRankIndices(%arg0: tensor<4x2x2xi32>, %arg1: tensor<4x2x3xf
// CHECK: return %[[RES]]
}
func @scatter_nd_i64(%arg0: tensor<4x2x2xi64>, %arg1: tensor<4x2x3xf32>, %arg2: tensor<3xi64>) -> tensor<10x2x3xf32> {
%0 = "tf.ScatterNd"(%arg0, %arg1, %arg2) : (tensor<4x2x2xi64>, tensor<4x2x3xf32>, tensor<3xi64>) -> tensor<10x2x3xf32>
return %0 : tensor<10x2x3xf32>
// CHECK-LABEL:scatter_nd_i64
// CHECK: "tfl.cast"
// CHECK: "tfl.cast"
// CHECK: "tfl.scatter_nd"
}
func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x3x5x20xf32> {
%0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32>
%1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x3x5x20xf32>
@ -689,6 +708,16 @@ func @reverse_v2(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1xi32>) -> tensor<1x2
// CHECK: return
}
func @reverse_v2_i64(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1xi64>) -> tensor<1x2x3x4xf32> {
%0 = "tf.ReverseV2"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<1xi64>) -> tensor<1x2x3x4xf32>
return %0 : tensor<1x2x3x4xf32>
// CHECK-LABEL:reverse_v2_i64
// CHECK: "tfl.cast"
// CHECK: "tfl.reverse_v2"
// CHECK: return
}
func @matrix_diag(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
%0 = "tf.MatrixDiag"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16x16xf32>
return %0 : tensor<8x16x16xf32>
@ -763,13 +792,31 @@ func @matrix_diag_v3(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
// CHECK: return [[VAL_6]] : tensor<8x16x16xf32>
}
func @matrix_set_diag(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%0 = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
return %0 : tensor<3x3xi32>
func @matrix_set_diag_v3(%arg0: tensor<3x3xi64>, %arg1: tensor<3xi32>) -> tensor<3x3xi64> {
%cst = constant dense<0> : tensor<i32>
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) {align = "RIGHT_LEFT"} : (tensor<3x3xi64>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi64>
return %0 : tensor<3x3xi64>
// CHECK-LABEL: func @matrix_set_diag(
// CHECK: [[VAL_0:%.*]] = "tfl.matrix_set_diag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
// CHECK: return [[VAL_0]]
// CHECK-LABEL: func @matrix_set_diag_v3
// CHECK: "tfl.matrix_set_diag"(%arg0, %arg1) : (tensor<3x3xi64>, tensor<3xi32>) -> tensor<3x3xi64>
}
func @matrix_set_diag_v3_non_zero_k(%arg0: tensor<3x3xi64>, %arg1: tensor<3xi32>) -> tensor<3x3xi64> {
%cst = constant dense<1> : tensor<i32>
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) : (tensor<3x3xi64>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi64>
return %0 : tensor<3x3xi64>
// CHECK-LABEL: @matrix_set_diag_v3_non_zero_k
// CHECK: tf.MatrixSetDiagV3
}
func @matrix_set_diag_v3_default_align(%arg0: tensor<3x3xi64>, %arg1: tensor<3xi32>) -> tensor<3x3xi64> {
%cst = constant dense<0> : tensor<i32>
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) : (tensor<3x3xi64>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi64>
return %0 : tensor<3x3xi64>
// CHECK-LABEL: @matrix_set_diag_v3_default_align
// CHECK: "tfl.matrix_set_diag"(%arg0, %arg1) : (tensor<3x3xi64>, tensor<3xi32>) -> tensor<3x3xi64>
}
func @maximum(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
@ -934,6 +981,15 @@ func @sum_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32
// CHECK: "tfl.sum"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
}
func @sum_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
%0 = "tf.Sum"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: sum_i64axes
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
// CHECK: "tfl.sum"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
}
func @reduce_min(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
%0 = "tf.Min"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
@ -950,6 +1006,15 @@ func @reduce_min_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tenso
// CHECK: "tfl.reduce_min"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
}
func @reduce_min_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
%0 = "tf.Min"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: reduce_min_i64axes
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
// CHECK: "tfl.reduce_min"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
}
func @reduce_max(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
%0 = "tf.Max"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
@ -966,6 +1031,15 @@ func @reduce_max_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tenso
// CHECK: "tfl.reduce_max"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
}
func @reduce_max_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
%0 = "tf.Max"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: reduce_max_i64axes
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
// CHECK: "tfl.reduce_max"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
}
func @reduce_prod(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
%0 = "tf.Prod"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
@ -982,6 +1056,15 @@ func @reduce_prod_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tens
// CHECK: "tfl.reduce_prod"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
}
func @reduce_prod_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
%0 = "tf.Prod"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: reduce_prod_i64axes
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
// CHECK: "tfl.reduce_prod"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
}
func @batch_to_space_nd(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<?xf32> {
%0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
@ -996,6 +1079,15 @@ func @batch_to_space_nd_unsupported(%arg0: tensor<?x1x1x1x4xf32>, %arg1: tensor<
// CHECK: "tf.BatchToSpaceND"
}
func @batch_to_space_nd_i64(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi64>, %arg2: tensor<2x2xi64>) -> tensor<?xf32> {
%0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: batch_to_space_nd_i64
// CHECK: "tfl.cast"
// CHECK: "tfl.cast"
// CHECK: "tfl.batch_to_space_nd"
}
func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<*xf32> {
%0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
@ -1003,6 +1095,15 @@ func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2:
// CHECK: "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<*xf32>
}
func @space_to_batch_nd_i64(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi64>, %arg2: tensor<2x2xi64>) -> tensor<*xf32> {
%0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<*xf32>
return %0 : tensor<*xf32>
// CHECK-LABEL: space_to_batch_nd_i64
// CHECK: "tfl.cast"
// CHECK: "tfl.cast"
// CHECK: "tfl.space_to_batch_nd"
}
func @split(%arg0: tensor<i32>, %arg1: tensor<1x4x3x3xf32>) -> tensor<1x4x3xf32> {
%0:3 = "tf.Split"(%arg0, %arg1) : (tensor<i32>, tensor<1x4x3x3xf32>) -> (tensor<1x4x3xf32>, tensor<1x4x3xf32>, tensor<1x4x3xf32>)
return %0#0 : tensor<1x4x3xf32>
@ -1116,10 +1217,10 @@ func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1:
%0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
return %0 : tensor<10x10xf32>
// CHECK-LABEL: strided_slice_with_constant_attributes
// CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32>
// CHECK-DAG: [[END:%cst.*]] = constant dense<[0, 10, 10]> : tensor<3xi32>
// CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<3xi32>
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i32, ellipsis_mask = 0 : i32, end_mask = 6 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32>
// CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<-1> : tensor<1xi32>
// CHECK-DAG: [[END:%cst.*]] = constant dense<0> : tensor<1xi32>
// CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<1xi32>
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
}
func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
@ -1129,6 +1230,39 @@ func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tenso
// CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
}
func @strided_slice_with_unranked_input_and_i64_parameters(%arg0: tensor<*xf32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>) -> tensor<*xf32> {
%0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<*xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<*xf32>
return %0 : tensor<*xf32>
// CHECK-LABEL: strided_slice_with_unranked_input_and_i64_parameters
// CHECK-DAG: [[BEGIN:%.*]] = "tfl.cast"(%arg1) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK-DAG: [[END:%.*]] = "tfl.cast"(%arg2) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK-DAG: [[STRIDES:%.*]] = "tfl.cast"(%arg3) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<*xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xf32>
}
func @strided_slice_with_i64_parameters(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>) -> tensor<1x2x2x5xf32> {
%0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1x2x2x5xf32>
return %0 : tensor<1x2x2x5xf32>
// CHECK-LABEL: strided_slice_with_i64_parameters
// CHECK-DAG: [[BEGIN:%.*]] = "tfl.cast"(%arg1) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK-DAG: [[END:%.*]] = "tfl.cast"(%arg2) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK-DAG: [[STRIDES:%.*]] = "tfl.cast"(%arg3) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
}
func @strided_slice_with_i64_constant_attributes(%arg0: tensor<10x10x10xf32>) -> tensor<10x10xf32> {
%cst = constant dense<-1> : tensor<1xi64>
%cst_1 = constant dense<0> : tensor<1xi64>
%cst_2 = constant dense<1> : tensor<1xi64>
%0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<10x10xf32>
return %0 : tensor<10x10xf32>
// CHECK-LABEL: strided_slice_with_i64_constant_attributes
// CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<-1> : tensor<1xi32>
// CHECK-DAG: [[END:%cst.*]] = constant dense<0> : tensor<1xi32>
// CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<1xi32>
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
}
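The new strided-slice tests above correspond to TF graphs in which begin/end/strides are int64; the legalization now casts them to int32 rather than leaving the op unconverted. A small Python sketch (tensor contents and the single-dimension slice are arbitrary):

import tensorflow as tf

x = tf.reshape(tf.range(1000, dtype=tf.float32), [10, 10, 10])
begin = tf.constant([-1], dtype=tf.int64)
end = tf.constant([0], dtype=tf.int64)
strides = tf.constant([1], dtype=tf.int64)

# shrink_axis_mask=1 drops the sliced leading dimension, matching the
# strided_slice_with_i64_constant_attributes test above.
y = tf.strided_slice(x, begin, end, strides, shrink_axis_mask=1)
print(y.shape)  # (10, 10)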
func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> {
%0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
return %0 : tensor<?x3x5xf32>
@ -1361,8 +1495,7 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, %
// CHECK-LABEL: conv2d_backprop_input
// CHECK: %[[CST:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32>
// CHECK: %[[CAST:.*]] = "tfl.cast"(%[[CST]]) : (tensor<4xi32>) -> tensor<4xi32>
// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CAST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
// CHECK: %[[CST_0:.*]] = constant unit
// CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
// CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
@ -1797,10 +1930,25 @@ func @cumsum(%arg0: tensor<3x3xf32>, %arg1: tensor<i32>) -> tensor<3x3xf32> {
// CHECK: "tfl.cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i32>) -> tensor<3x3xf32>
}
func @cumsum_invalid(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<3x3xf32> {
func @cumsum_i64(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<3x3xf32> {
%0 = "tf.Cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i64>) -> tensor<3x3xf32>
return %0 : tensor<3x3xf32>
// CHECK-LABEL: cumsum_invalid
// CHECK-NOT: "tfl.cumsum"
// CHECK-LABEL: cumsum_i64
// CHECK: "tfl.cast"
// CHECK: "tfl.cumsum"
}
func @segmentsum(%arg0: tensor<3x3xf32>, %arg1: tensor<i32>) -> tensor<*xf32> {
%0 = "tf.SegmentSum"(%arg0, %arg1) : (tensor<3x3xf32>, tensor<i32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
// CHECK-LABEL: segmentsum
// CHECK: "tfl.segment_sum"(%arg0, %arg1) : (tensor<3x3xf32>, tensor<i32>) -> tensor<*xf32>
}
func @segmentsum_i64(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<*xf32> {
%0 = "tf.SegmentSum"(%arg0, %arg1) : (tensor<3x3xf32>, tensor<i64>) -> tensor<*xf32>
return %0 : tensor<*xf32>
// CHECK-LABEL: segmentsum_i64
// CHECK: "tfl.cast"
// CHECK: "tfl.segment_sum"
}
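The `*_i64axes`, `cumsum_i64` and `segmentsum_i64` tests above all exercise the same idea: TF ops whose axis, index or segment-id operands are int64 now receive a `tfl.cast` to int32 during legalization instead of being rejected. Graphs like the following, with invented shapes, are what such tests model:

import tensorflow as tf

x = tf.random.uniform([8, 16, 16])
axes = tf.constant([1, 2], dtype=tf.int64)

s = tf.reduce_sum(x, axis=axes)   # tf.Sum with i64 axes
m = tf.reduce_max(x, axis=axes)   # tf.Max with i64 axes
c = tf.math.cumsum(x, axis=tf.constant(1, dtype=tf.int64))  # tf.Cumsum with i64 axis
seg = tf.math.segment_sum(
    tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
    tf.constant([0, 0, 1], dtype=tf.int64))  # tf.SegmentSum with i64 segment ids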

View File

@ -625,6 +625,35 @@ func @QuantizeSharedBiases2(
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
}
// Make sure constants are duplicated for all users.
// CHECK-LABEL: QuantizeSharedConstantsMultipleUsers
func @QuantizeSharedConstantsMultipleUsers(
%arg0: tensor<32x!quant.uniform<u8:f32, 1.0>>,
%arg1: tensor<32x!quant.uniform<u8:f32, 2.0>>,
%arg2: tensor<32x!quant.uniform<u8:f32, 3.0>>,
%arg3: tensor<32x!quant.uniform<u8:f32, 4.0>>) -> (tensor<32xf32>, tensor<32xf32>, tensor<32xf32>, tensor<32xf32>) {
%cst = constant dense<0.0> : tensor<32xf32>
%0 = "tfl.dequantize"(%arg0) : (tensor<32x!quant.uniform<u8:f32, 1.0>>) -> tensor<32xf32>
%1 = "tfl.dequantize"(%arg1) : (tensor<32x!quant.uniform<u8:f32, 2.0>>) -> tensor<32xf32>
%2 = "tfl.dequantize"(%arg2) : (tensor<32x!quant.uniform<u8:f32, 3.0>>) -> tensor<32xf32>
%3 = "tfl.dequantize"(%arg3) : (tensor<32x!quant.uniform<u8:f32, 4.0>>) -> tensor<32xf32>
%4 = "tfl.minimum"(%0, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
%5 = "tfl.minimum"(%1, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
%6 = "tfl.minimum"(%2, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
%7 = "tfl.minimum"(%3, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
return %4, %5, %6, %7 : tensor<32xf32>, tensor<32xf32>, tensor<32xf32>, tensor<32xf32>
// CHECK: %[[cst1:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<32xf32>
// CHECK: %[[cst2:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 3.000000e+00>>) -> tensor<32xf32>
// CHECK: %[[cst3:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 4.000000e+00>>) -> tensor<32xf32>
// CHECK: %[[cst4:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 1.000000e+00>>) -> tensor<32xf32>
// CHECK: "tfl.minimum"(%{{.*}}, %[[cst4]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
// CHECK: "tfl.minimum"(%{{.*}}, %[[cst1]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
// CHECK: "tfl.minimum"(%{{.*}}, %[[cst2]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
// CHECK: "tfl.minimum"(%{{.*}}, %[[cst3]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
}
// Make sure quantization parameters are scanned from the weight, but not from the bias.
// CHECK-LABEL: QuantizeWeight
func @QuantizeWeight(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {

View File

@ -551,35 +551,17 @@ func @StridedSliceRewriteMasks(%arg0: tensor<8x4x16x2xf32>) -> tensor<8x4x16x1xf
return %0 : tensor<8x4x16x1xf32>
}
// CHECK-LABEL: @MatrixSetDiagV2Conversion
func @MatrixSetDiagV2Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%cst = constant dense<0> : tensor<i32>
%0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
return %0 : tensor<3x3xi32>
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
// CHECK: return %[[RES]]
}
// CHECK-LABEL: @MatrixSetDiagV2NonZeroK
func @MatrixSetDiagV2NonZeroK(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%cst = constant dense<1> : tensor<i32>
%0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
return %0 : tensor<3x3xi32>
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<i32>
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiagV2"(%arg0, %arg1, %[[CST]]) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
// CHECK: return %[[RES]]
}
// CHECK-LABEL: @MatrixSetDiagV3Conversion
func @MatrixSetDiagV3Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
%cst = constant dense<0> : tensor<i32>
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
return %0 : tensor<3x3xi32>
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
// CHECK: return %[[RES]]
func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<10x10xf32> {
%cst = constant dense<-1> : tensor<1xi32>
%cst_1 = constant dense<0> : tensor<1xi32>
%cst_2 = constant dense<1> : tensor<1xi32>
%0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
return %0 : tensor<10x10xf32>
// CHECK-LABEL: strided_slice_with_constant_attributes
// CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32>
// CHECK-DAG: [[END:%cst.*]] = constant dense<[0, 10, 10]> : tensor<3xi32>
// CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<3xi32>
// CHECK-NEXT: "tf.StridedSlice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32>
}
func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {

View File

@ -242,7 +242,8 @@ int main(int argc, char **argv) {
std::string result;
auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
module.ValueOrDie().get(), output_mlir, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, quant_specs, tags, &result, &pm);
emit_select_tf_ops, emit_custom_ops,
/*select_user_tf_ops=*/{}, quant_specs, tags, &result, &pm);
if (!status.ok()) return kTrFailure;
std::string error_msg;

View File

@ -137,6 +137,7 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
Status ConvertTFExecutorToTFLOrFlatbuffer(
mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& select_user_tf_ops,
const mlir::TFL::QuantizationSpecs& quant_specs,
const std::unordered_set<std::string>& saved_model_tags,
std::string* result, mlir::PassManager* pass_manager) {
@ -169,10 +170,12 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
}
// Write MLIR TFLite dialect into FlatBuffer
OpOrArgLocNameMapper op_or_arg_name_mapper;
if (!quant_specs.RunWeightQuantization()) {
if (tflite::MlirToFlatBufferTranslateFunction(
module, result, emit_builtin_tflite_ops, emit_select_tf_ops,
emit_custom_ops, saved_model_tags)) {
emit_custom_ops, select_user_tf_ops, saved_model_tags,
&op_or_arg_name_mapper)) {
return statusHandler.ConsumeStatus();
}
} else {
@ -181,7 +184,8 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
std::string pre_quantized_result;
if (tflite::MlirToFlatBufferTranslateFunction(
module, &pre_quantized_result, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, saved_model_tags)) {
emit_select_tf_ops, emit_custom_ops, select_user_tf_ops,
saved_model_tags, &op_or_arg_name_mapper)) {
return statusHandler.ConsumeStatus();
}
flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);

View File

@ -63,6 +63,7 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
Status ConvertTFExecutorToTFLOrFlatbuffer(
mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
bool emit_select_tf_ops, bool emit_custom_ops,
const std::unordered_set<std::string>& select_user_tf_ops,
const mlir::TFL::QuantizationSpecs& quant_specs,
const std::unordered_set<std::string>& saved_model_tags,
std::string* result, mlir::PassManager* pass_manager);

View File

@ -54,7 +54,7 @@ def ExtractSingleElementAsInt32 : NativeCodeCall<
"$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger($_self.cast<ElementsAttr>()).getInt())">;
// Converts a tensor with int64 elements to int32.
def CreateTFLCastToInt32Op : NativeCodeCall<
def CreateTFCastToInt32Op : NativeCodeCall<
"CreateCastToInt32($0, $_loc, $_builder)">;
// Checks whether the given operation has static shapes and same shapes of all inputs.
@ -193,8 +193,8 @@ def LegalizeRound : Pat<(TF_RoundOp $arg), (TFL_RoundOp $arg)>;
def LegalizeRsqrt : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>;
def LegalizeSqrt : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>;
def LegalizeSquare : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
def LegalizeSegmentSum : Pat<(TF_SegmentSumOp $data, I32Tensor:$segment_ids),
(TFL_SegmentSumOp $data, $segment_ids)>;
def LegalizeSegmentSum : Pat<(TF_SegmentSumOp $data, $segment_ids),
(TFL_SegmentSumOp $data, (CreateTFCastToInt32Op $segment_ids))>;
def LegalizeSelect : Pat<(TF_SelectOp $cond, $x, $y),
(TFL_SelectOp $cond, $x, $y)>;
def LegalizeSelectV2SameStaticShape : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y),
@ -221,7 +221,7 @@ def LegalizeTanh : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>;
def LegalizeTranspose : Pat<(TF_TransposeOp $arg, $perm),
(TFL_TransposeOp $arg,
(CreateTFLCastToInt32Op $perm))>;
(CreateTFCastToInt32Op $perm))>;
def LegalizeWhere : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>;
def LegalizeZerosLike : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
@ -309,8 +309,9 @@ def LegalizeRank : Pat<(TF_RankOp $input), (TFL_RankOp $input)>;
def LegalizeSquaredDifference : Pat<(TF_SquaredDifferenceOp $l, $r),
(TFL_SquaredDifferenceOp $l, $r)>;
def LegalizeReverseV2 : Pat<(TF_ReverseV2Op $arg0, $arg1),
(TFL_ReverseV2Op $arg0, $arg1)>;
def LegalizeReverseV2 : Pat<
(TF_ReverseV2Op $arg0, $axis),
(TFL_ReverseV2Op $arg0, (CreateTFCastToInt32Op $axis))>;
def LegalizeEqual : Pat<(TF_EqualOp $arg0, $arg1,
/*incompatible_shape_error=*/ConstBoolAttrTrue),
@ -327,33 +328,40 @@ def LegalizeMean : Pat<(TF_MeanOp $arg0, $arg1, BoolAttr:$arg2),
(TFL_MeanOp $arg0, $arg1, $arg2)>;
def LegalizeSum : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2),
(TFL_SumOp $arg, $axes, $arg2)>;
(TFL_SumOp $arg, (CreateTFCastToInt32Op $axes), $arg2)>;
// TopK in TFL is always sorted so we ignore that attribute here.
def LegalizeTopKV2 : Pat<(TF_TopKV2Op $input, $k, $ignored_sorted),
(TFL_TopKV2Op $input, $k)>;
def LegalizeMin : Pat<(TF_MinOp $arg0, $arg1, BoolAttr:$arg2),
(TFL_ReduceMinOp $arg0, $arg1, $arg2)>;
def LegalizeMin : Pat<
(TF_MinOp $arg0, $axes, BoolAttr:$arg2),
(TFL_ReduceMinOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>;
def LegalizeMax : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2),
(TFL_ReduceMaxOp $arg0, $arg1, $arg2)>;
def LegalizeMax : Pat<
(TF_MaxOp $arg0, $axes, BoolAttr:$arg2),
(TFL_ReduceMaxOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>;
def LegalizeProd : Pat<(TF_ProdOp $arg0, $arg1, BoolAttr:$arg2),
(TFL_ReduceProdOp $arg0, $arg1, $arg2)>;
def LegalizeProd : Pat<
(TF_ProdOp $arg0, $axes, BoolAttr:$arg2),
(TFL_ReduceProdOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>;
def LegalizeAny : Pat<(TF_AnyOp $input, $reduction_indices, $keep_dims),
(TFL_ReduceAnyOp $input, $reduction_indices, $keep_dims)>;
def LegalizeAny : Pat<
(TF_AnyOp $input, $reduction_indices, $keep_dims),
(TFL_ReduceAnyOp $input, (CreateTFCastToInt32Op $reduction_indices),
$keep_dims)>;
def LegalizeCast : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>;
def LegalizeBatchToSpaceND : Pat<
(TF_BatchToSpaceNDOp $input, $block_shape, $crops),
(TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>;
(TFL_BatchToSpaceNdOp $input, (CreateTFCastToInt32Op $block_shape),
(CreateTFCastToInt32Op $crops))>;
def LegalizeSpaceToBatchND : Pat<
(TF_SpaceToBatchNDOp $input, $block_shape, $paddings),
(TFL_SpaceToBatchNdOp $input, $block_shape, $paddings)>;
(TFL_SpaceToBatchNdOp $input, (CreateTFCastToInt32Op $block_shape),
(CreateTFCastToInt32Op $paddings))>;
def LegalizeSpaceToDepth : Pat<
(TF_SpaceToDepthOp $input, $block_size, IsDataFormatNHWC:$data_format),
@ -437,14 +445,45 @@ def LegalizeConv2DBackpropInput : Pat<
/*stride_h=*/ ExtractI32At<1>:$strides,
/*stride_w=*/ ExtractI32At<2>:$strides)>;
def IsRankZeroAttr
: CPred<"$_self.cast<DenseElementsAttr>().getType().getRank() == 0">;
def HasValueZero
: CPred<"$_self.cast<DenseElementsAttr>().getSplatValue()."
"cast<::mlir::IntegerAttr>().getInt() == 0">;
// TFLite only supports MatrixSetDiag ops with scalar zero k attribute.
def IsSupportedByTFLiteMatrixSetDiag
: ElementsAttrBase<And<[ElementsAttr.predicate,
IsRankZeroAttr, HasValueZero]>,
"MatrixSetDiag attribute verification">;
// Attribute align doesn't matter when k is zero.
def LegalizeMatrixSetDiag : Pat<
(TF_MatrixSetDiagOp $input, $diagonal),
(TF_MatrixSetDiagV3Op $input, $diagonal,
(ConstantLikeMatcher IsSupportedByTFLiteMatrixSetDiag:$k), $align),
(TFL_MatrixSetDiagOp $input, $diagonal)>;
def LegalizeScatterNd : Pat<
(TF_ScatterNdOp I32Tensor:$indices, $updates, $shape),
(TFL_ScatterNdOp I32Tensor:$indices, $updates, $shape)>;
(TF_ScatterNdOp $indices, $updates, $shape),
(TFL_ScatterNdOp (CreateTFCastToInt32Op $indices), $updates,
(CreateTFCastToInt32Op $shape))>;
def LegalizeCumsum : Pat<
(TF_CumsumOp $input, $axis, $exclusive, $reverse),
(TFL_CumsumOp $input, $axis, $exclusive, $reverse)>;
(TFL_CumsumOp $input, (CreateTFCastToInt32Op $axis), $exclusive, $reverse)>;
def LegalizeReshape : Pat<
(TF_ReshapeOp $input, $shape),
(TFL_ReshapeOp $input, (CreateTFCastToInt32Op $shape))>;
def LegalizeStridedSlice : Pat<
(TF_StridedSliceOp
$input, $begin, $end, $strides, $begin_mask, $end_mask, $ellipsis_mask,
$new_axis_mask, $shrink_axis_mask),
(TFL_StridedSliceOp $input,
(CreateTFCastToInt32Op $begin), (CreateTFCastToInt32Op $end),
(CreateTFCastToInt32Op $strides), (convertIntAttrTo32Bit $begin_mask),
(convertIntAttrTo32Bit $end_mask), (convertIntAttrTo32Bit $ellipsis_mask),
(convertIntAttrTo32Bit $new_axis_mask),
(convertIntAttrTo32Bit $shrink_axis_mask))>;
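The `IsSupportedByTFLiteMatrixSetDiag` constraint above only matches a rank-0 constant `k` equal to 0, which is the only form `tfl.matrix_set_diag` can express; any other `k` is left as a TF op, as the `matrix_set_diag_v3_non_zero_k` test checks. In Python terms this is the default `k` of `tf.linalg.set_diag` (values below are arbitrary):

import tensorflow as tf

m = tf.zeros([3, 3], dtype=tf.int64)

# k defaults to 0: this is the form that legalizes to tfl.matrix_set_diag.
main_diag = tf.linalg.set_diag(m, tf.constant([1, 2, 3], dtype=tf.int64))

# k=1 writes the first superdiagonal (length 2 for a 3x3 matrix); this form
# stays as tf.MatrixSetDiagV3.
upper_diag = tf.linalg.set_diag(m, tf.constant([1, 2], dtype=tf.int64), k=1)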

View File

@ -123,7 +123,8 @@ Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) {
auto shape = val.getType().dyn_cast<RankedTensorType>().getShape();
IntegerType new_ele_type = rewriter.getIntegerType(32);
ShapedType new_type = RankedTensorType::get(shape, new_ele_type);
return rewriter.create<TFL::CastOp>(loc, new_type, val);
return rewriter.createOrFold<TF::CastOp>(loc, new_type, val,
rewriter.getBoolAttr(false));
}
#include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc"
@ -145,10 +146,8 @@ DECL_CONVERT_OP(MatMul);
DECL_CONVERT_OP(MatrixDiagV2);
DECL_CONVERT_OP(MatrixDiagV3);
DECL_CONVERT_OP(Pack);
DECL_CONVERT_OP(Reshape);
DECL_CONVERT_OP(Split);
DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(StridedSlice);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(RandomUniform);
@ -299,30 +298,6 @@ LogicalResult ConvertTFPackOp::matchAndRewrite(
return success();
}
LogicalResult ConvertTFReshapeOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_reshape_op = cast<TF::ReshapeOp>(op);
auto input = tf_reshape_op.tensor();
auto shape = tf_reshape_op.shape();
ShapedType shape_type = shape.getType().cast<ShapedType>();
// The tfl reshape's #2 operand needs to be an i32 tensor type, so we have to cast.
if (!shape_type.getElementType().isSignlessInteger(32)) {
auto new_shape = shape_type.getShape();
IntegerType new_ele_type = rewriter.getIntegerType(32);
ShapedType new_type = RankedTensorType::get(new_shape, new_ele_type);
// Uses TF::CastOp to be folded if the shape input is a constant.
shape = rewriter
.create<TF::CastOp>(op->getLoc(), new_type, shape,
rewriter.getBoolAttr(false))
.y();
}
rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output().getType(),
input, shape);
return success();
}
LogicalResult ConvertTFSplitOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_split_op = cast<TF::SplitOp>(op);
@ -349,81 +324,6 @@ LogicalResult ConvertTFSplitVOp::matchAndRewrite(
return success();
}
Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
Value attribute,
ArrayRef<int32_t> padding_val, int* mask) {
DenseIntElementsAttr dense_elem_attr;
SmallVector<int32_t, 8> padded_val;
auto ranked_attr_type = attribute.getType().dyn_cast<RankedTensorType>();
if (!ranked_attr_type ||
!matchPattern(attribute, m_Constant(&dense_elem_attr))) {
// If the input attribute is neither ranked type nor constant, we
// can't do any padding. Instead we just return it.
return attribute;
}
for (const auto& idx : dense_elem_attr.getIntValues()) {
padded_val.push_back(idx.getSExtValue());
}
auto attr_dim_count = ranked_attr_type.getShape()[0];
int full_dim_count = padding_val.size();
for (int i = attr_dim_count; i < full_dim_count; ++i) {
padded_val.push_back(padding_val[i]);
if (mask) *mask |= 1 << i;
}
auto type =
RankedTensorType::get({full_dim_count}, rewriter.getIntegerType(32));
auto attr = DenseElementsAttr::get<int32_t>(type, padded_val);
return rewriter.create<ConstantOp>(op->getLoc(), type, attr);
}
LogicalResult ConvertTFStridedSliceOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_strided_slice_op = cast<TF::StridedSliceOp>(op);
auto ranked_input_type =
tf_strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
if (!ranked_input_type) {
// If input is not a ranked tensor, we can't deduce the padding dimensions
// from it, so we just do a plain conversion here.
rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
tf_strided_slice_op.begin(), tf_strided_slice_op.end(),
tf_strided_slice_op.strides(),
rewriter.getI32IntegerAttr(tf_strided_slice_op.begin_mask()),
rewriter.getI32IntegerAttr(tf_strided_slice_op.end_mask()),
rewriter.getI32IntegerAttr(tf_strided_slice_op.ellipsis_mask()),
rewriter.getI32IntegerAttr(tf_strided_slice_op.new_axis_mask()),
rewriter.getI32IntegerAttr(tf_strided_slice_op.shrink_axis_mask()));
return success();
}
int num_input_dims = ranked_input_type.getRank();
// Pad `begin` array with zero values and update the `begin_mask`.
SmallVector<int32_t, 8> begin_pad_val(num_input_dims, 0);
int begin_mask = tf_strided_slice_op.begin_mask();
Value padded_begin = PadStridedSliceAttributeArray(
op, rewriter, tf_strided_slice_op.begin(), begin_pad_val, &begin_mask);
// Pad `end` array with `input_shape` and update the `end_mask`.
int end_mask = tf_strided_slice_op.end_mask();
auto input_shape = ranked_input_type.getShape();
SmallVector<int32_t, 8> end_pad_val(input_shape.begin(), input_shape.end());
Value padded_end = PadStridedSliceAttributeArray(
op, rewriter, tf_strided_slice_op.end(), end_pad_val, &end_mask);
// Pad `strides` array with ones.
SmallVector<int32_t, 8> strides_pad_val(num_input_dims, 1);
Value padded_strides = PadStridedSliceAttributeArray(
op, rewriter, tf_strided_slice_op.strides(), strides_pad_val, nullptr);
rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
padded_begin, padded_end, padded_strides,
rewriter.getI32IntegerAttr(begin_mask),
rewriter.getI32IntegerAttr(end_mask),
rewriter.getI32IntegerAttr(tf_strided_slice_op.ellipsis_mask()),
rewriter.getI32IntegerAttr(tf_strided_slice_op.new_axis_mask()),
rewriter.getI32IntegerAttr(tf_strided_slice_op.shrink_axis_mask()));
return success();
}
LogicalResult ConvertTFUnpackOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_unpack_op = cast<TF::UnpackOp>(op);
@ -792,10 +692,9 @@ void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
populateWithGenerated(context, patterns);
patterns
.insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp,
ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp,
ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFRandomUniformOp>(
context);
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFSplitOp,
ConvertTFSplitVOp, ConvertTFUnpackOp, ConvertTFAssertOp,
ConvertTFRandomUniformOp>(context);
// Ophint python converter converted tf node pattern.
patterns.insert<LegalizeUnidirectionalSequenceLstm,

View File

@ -376,14 +376,17 @@ void PrepareQuantizePass::runOnFunction() {
OwningRewritePatternList patterns;
bool is_signed = quant_specs_.IsSignedInferenceType();
int bit_width = quant_specs_.GetQuantizationTypeWidth();
bool quantization_aware_training_mode = ContainsQuantizeOps(func);
// Enforce fixed output range for post-training quantization and
// when the model has quantization emulation ops, unless it was disabled
// explicitly by the flag.
bool enforced_output_range =
(quant_specs_.post_training_quantization ||
quantization_aware_training_mode) &&
!quant_specs_.disable_enforced_fixed_output_range;
// When this is true, the quantizer will try its best to extract the
// quantization parameters from the op quantization property and constant
// content. This is also set to true when either the `quantize_allowlist` or
// `quantize_signed` test flag is enabled.
bool eager_quantize = ContainsQuantizeOps(func) ||
(!quantize_allowlist.empty() || quantize_signed);
// Infer the tensor range for the activation ops and weight constants unless
// it is disabled explicitly.
bool infer_tensor_range =
(quant_specs_.post_training_quantization || eager_quantize) &&
!quant_specs_.disable_infer_tensor_range;
if (is_signed) {
patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
// Convert quant stats to int8 quantization parameters.
@ -403,7 +406,7 @@ void PrepareQuantizePass::runOnFunction() {
// values (tensors).
ApplyQuantizationParamsPropagation(
func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
GetOpQuantSpec, enforced_output_range);
GetOpQuantSpec, infer_tensor_range);
ConvertMlirQuantOpsToTFLQuantOps(func);
}
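Outside the patch: the `post_training_quantization` flag this pass reads is normally driven from the Python converter. A minimal sketch with a hypothetical saved-model path; enabling the default optimizations is what (roughly) corresponds to the post-training-quantization branch above.

```python
import tensorflow as tf

saved_model_dir = "/tmp/my_saved_model"  # hypothetical path
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
# Requests post-training (dynamic-range) quantization, which causes the
# prepare-quantize pass to infer tensor ranges as described above.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
```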

View File

@ -519,9 +519,10 @@ struct ConvertTFStridedSlice : public RewritePattern {
explicit ConvertTFStridedSlice(MLIRContext *context)
: RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {}
LogicalResult RewriteNewAxisMask(Operation *op, uint64_t new_axis_mask,
LogicalResult RewriteNewAxisMask(Operation *op,
PatternRewriter &rewriter) const {
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
uint64_t new_axis_mask = strided_slice_op.new_axis_mask();
// Insert a new reshape op.
Value original_input = strided_slice_op.input();
@ -529,51 +530,51 @@ struct ConvertTFStridedSlice : public RewritePattern {
original_input.getType().cast<RankedTensorType>();
const ArrayRef<int64_t> &original_input_shape =
original_input_type.getShape();
SmallVector<int64_t, 4> new_shape;
SmallVector<int64_t, 4> revised_shape;
int index = 0;
const int original_input_rank = original_input_shape.size();
while (index < original_input_rank || new_axis_mask) {
if (new_axis_mask & 1) {
new_shape.emplace_back(1);
revised_shape.emplace_back(1);
} else {
new_shape.emplace_back(original_input_shape[index++]);
revised_shape.emplace_back(original_input_shape[index++]);
}
new_axis_mask >>= 1;
}
if (failed(TF::VerifyShapeOfReshapeOp(new_shape))) return failure();
if (failed(TF::VerifyShapeOfReshapeOp(revised_shape))) return failure();
const int dim_size = new_shape.size();
const int dim_size = revised_shape.size();
Location loc = strided_slice_op.getLoc();
auto shape_type =
RankedTensorType::get({dim_size}, rewriter.getIntegerType(32));
SmallVector<Attribute, 4> result_shape_data(dim_size);
for (int i = 0; i < dim_size; ++i) {
result_shape_data[i] =
rewriter.getI32IntegerAttr(static_cast<int32_t>(new_shape[i]));
rewriter.getI32IntegerAttr(static_cast<int32_t>(revised_shape[i]));
}
auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data);
auto shape = rewriter.create<ConstantOp>(loc, shape_type, shape_attr);
auto new_output_type =
RankedTensorType::get(new_shape, original_input_type.getElementType());
auto revised_output_type = RankedTensorType::get(
revised_shape, original_input_type.getElementType());
TF::ReshapeOp reshape = rewriter.create<TF::ReshapeOp>(
loc, new_output_type, original_input, shape);
loc, revised_output_type, original_input, shape);
// Replace the original strided_slice.
uint64_t new_begin_mask = strided_slice_op.begin_mask();
uint64_t new_end_mask = strided_slice_op.end_mask();
uint64_t revised_begin_mask = strided_slice_op.begin_mask();
uint64_t revised_end_mask = strided_slice_op.end_mask();
// Since we expand the dims, we need to apply them to the begin_mask &
// end_mask.
new_begin_mask |= strided_slice_op.new_axis_mask();
new_end_mask |= strided_slice_op.new_axis_mask();
revised_begin_mask |= strided_slice_op.new_axis_mask();
revised_end_mask |= strided_slice_op.new_axis_mask();
auto attribute_type = rewriter.getIntegerType(64);
rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
op, strided_slice_op.getType(), reshape, strided_slice_op.begin(),
strided_slice_op.end(), strided_slice_op.strides(),
rewriter.getIntegerAttr(attribute_type, new_begin_mask),
rewriter.getIntegerAttr(attribute_type, new_end_mask),
rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
rewriter.getIntegerAttr(attribute_type, revised_end_mask),
rewriter.getIntegerAttr(attribute_type,
strided_slice_op.ellipsis_mask()),
rewriter.getI64IntegerAttr(0),
@ -582,10 +583,11 @@ struct ConvertTFStridedSlice : public RewritePattern {
return success();
}
LogicalResult RewriteEllipsisMask(Operation *op, uint64_t ellipsis_mask,
LogicalResult RewriteEllipsisMask(Operation *op,
PatternRewriter &rewriter) const {
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask();
uint64_t shrink_axis_mask = strided_slice_op.shrink_axis_mask();
// Enforce operator precedence.
@ -632,8 +634,8 @@ struct ConvertTFStridedSlice : public RewritePattern {
int64_t begin_mask = strided_slice_op.begin_mask();
int64_t end_mask = strided_slice_op.end_mask();
int64_t new_begin_mask = 0;
int64_t new_end_mask = 0;
int64_t revised_begin_mask = 0;
int64_t revised_end_mask = 0;
int64_t revised_shrink_axis_mask = 0;
SmallVector<int32_t, 4> padded_begin;
@ -647,8 +649,8 @@ struct ConvertTFStridedSlice : public RewritePattern {
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index);
if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
if ((shrink_axis_mask >> index) & 1)
revised_shrink_axis_mask |= (1 << new_index);
++index;
@ -657,8 +659,8 @@ struct ConvertTFStridedSlice : public RewritePattern {
// Ellipsis.
for (; new_index < index + ellipsis_filled_dim_size; ++new_index) {
new_begin_mask |= (1 << new_index);
new_end_mask |= (1 << new_index);
revised_begin_mask |= (1 << new_index);
revised_end_mask |= (1 << new_index);
// Mimic the begin/end/strides mask behavior.
padded_begin.push_back(0);
@ -675,8 +677,8 @@ struct ConvertTFStridedSlice : public RewritePattern {
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index);
if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
if ((shrink_axis_mask >> index) & 1)
revised_shrink_axis_mask |= (1 << new_index);
@ -701,35 +703,135 @@ struct ConvertTFStridedSlice : public RewritePattern {
rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
op, strided_slice_op.getType(), input, begin_op.getResult(),
end_op.getResult(), stride_op.getResult(),
rewriter.getIntegerAttr(attribute_type, new_begin_mask),
rewriter.getIntegerAttr(attribute_type, new_end_mask),
/*ellipsis_maks=*/rewriter.getI64IntegerAttr(0),
rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
rewriter.getIntegerAttr(attribute_type, revised_end_mask),
/*ellipsis_mask=*/rewriter.getI64IntegerAttr(0),
rewriter.getIntegerAttr(attribute_type,
strided_slice_op.new_axis_mask()),
rewriter.getIntegerAttr(attribute_type, revised_shrink_axis_mask));
return success();
}
void PadStridedSliceAttributeArray(DenseIntElementsAttr dense_elem_attr,
SmallVectorImpl<int32_t> &val,
SmallVectorImpl<int32_t> &padded_val,
ArrayRef<int32_t> padding_val,
int *mask) const {
for (const auto &idx : dense_elem_attr.getIntValues()) {
val.push_back(idx.getSExtValue());
padded_val.push_back(idx.getSExtValue());
}
int attr_dim_count = val.size();
int full_dim_count = padding_val.size();
for (int i = attr_dim_count; i < full_dim_count; ++i) {
padded_val.push_back(padding_val[i]);
if (mask) *mask |= 1 << i;
}
}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
// Handle new axis mask.
uint64_t new_axis_mask = strided_slice_op.new_axis_mask();
if (new_axis_mask != 0) {
if (strided_slice_op.new_axis_mask() != 0) {
// We currently don't handle simultaneous shrink_ and new_axis masks.
if (strided_slice_op.shrink_axis_mask()) {
return failure();
if (!strided_slice_op.shrink_axis_mask()) {
return RewriteNewAxisMask(strided_slice_op, rewriter);
}
return RewriteNewAxisMask(strided_slice_op, new_axis_mask, rewriter);
}
// Handle ellipsis mask.
uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask();
if (ellipsis_mask != 0) {
return RewriteEllipsisMask(strided_slice_op, ellipsis_mask, rewriter);
if (strided_slice_op.ellipsis_mask() != 0) {
return RewriteEllipsisMask(strided_slice_op, rewriter);
}
return failure();
auto ranked_input_type =
strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
if (!ranked_input_type) {
return failure();
}
auto begin_attr = strided_slice_op.begin();
auto end_attr = strided_slice_op.end();
auto strides_attr = strided_slice_op.strides();
auto begin_attr_type = begin_attr.getType().dyn_cast<RankedTensorType>();
auto end_attr_type = end_attr.getType().dyn_cast<RankedTensorType>();
auto strides_attr_type =
strides_attr.getType().dyn_cast<RankedTensorType>();
DenseIntElementsAttr begin_elem_attr;
DenseIntElementsAttr end_elem_attr;
DenseIntElementsAttr strides_elem_attr;
if (!begin_attr_type ||
!matchPattern(begin_attr, m_Constant(&begin_elem_attr))) {
return failure();
}
if (!end_attr_type || !matchPattern(end_attr, m_Constant(&end_elem_attr))) {
return failure();
}
if (!strides_attr_type ||
!matchPattern(strides_attr, m_Constant(&strides_elem_attr))) {
return failure();
}
SmallVector<int32_t, 4> begin, end, strides;
SmallVector<int32_t, 4> padded_begin, padded_end, padded_strides;
int num_input_dims = ranked_input_type.getRank();
SmallVector<int32_t, 4> padding_begin(num_input_dims, 0);
auto input_shape = ranked_input_type.getShape();
SmallVector<int32_t, 4> padding_end(input_shape.begin(), input_shape.end());
SmallVector<int32_t, 4> padding_strides(num_input_dims, 1);
int begin_mask = strided_slice_op.begin_mask();
int end_mask = strided_slice_op.end_mask();
PadStridedSliceAttributeArray(begin_elem_attr, begin, padded_begin,
padding_begin, &begin_mask);
PadStridedSliceAttributeArray(end_elem_attr, end, padded_end, padding_end,
&end_mask);
PadStridedSliceAttributeArray(strides_elem_attr, strides, padded_strides,
padding_strides, nullptr);
if (begin == padded_begin && end == padded_end &&
strides == padded_strides &&
begin_mask == strided_slice_op.begin_mask() &&
end_mask == strided_slice_op.end_mask()) {
return failure();
}
auto begin_end_type =
RankedTensorType::get({num_input_dims}, rewriter.getIntegerType(32));
auto new_begin_attr = rewriter.create<ConstantOp>(
op->getLoc(), begin_end_type,
DenseElementsAttr::get<int32_t>(begin_end_type, padded_begin));
auto new_end_attr = rewriter.create<ConstantOp>(
op->getLoc(), begin_end_type,
DenseElementsAttr::get<int32_t>(begin_end_type, padded_end));
auto strides_type =
RankedTensorType::get({static_cast<long>(padded_strides.size())},
rewriter.getIntegerType(32));
auto new_strides_attr = rewriter.create<ConstantOp>(
op->getLoc(), strides_type,
DenseElementsAttr::get<int32_t>(strides_type, padded_strides));
auto attribute_type = rewriter.getIntegerType(64);
rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
op, strided_slice_op.output().getType(), strided_slice_op.input(),
new_begin_attr, new_end_attr, new_strides_attr,
rewriter.getIntegerAttr(attribute_type, begin_mask),
rewriter.getIntegerAttr(attribute_type, end_mask),
rewriter.getIntegerAttr(attribute_type,
strided_slice_op.ellipsis_mask()),
rewriter.getIntegerAttr(attribute_type,
strided_slice_op.new_axis_mask()),
rewriter.getIntegerAttr(attribute_type,
strided_slice_op.shrink_axis_mask()));
return success();
}
};
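For orientation only (not part of the change): the masks this pattern normalizes correspond to ordinary Python slicing, as in this small sketch of the semantics being rewritten.

```python
import tensorflow as tf

x = tf.ones([2, 3])
# `tf.newaxis` in a slice sets new_axis_mask on the underlying StridedSlice;
# RewriteNewAxisMask above turns that into an explicit Reshape followed by a
# plain StridedSlice.
print(x[:, tf.newaxis, :].shape)             # (2, 1, 3)

# An ellipsis sets ellipsis_mask; RewriteEllipsisMask expands it into
# full-range begin/end/stride entries for the skipped dimensions.
print(tf.ones([2, 3, 4, 5])[..., 1].shape)   # (2, 3, 4)
```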

View File

@ -57,6 +57,8 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
return mlir::ComplexType::get(builder.getF64Type());
case tflite::TensorType_INT8:
return builder.getIntegerType(8);
case tflite::TensorType_UINT64:
return builder.getIntegerType(64, /*isSigned=*/false);
}
}
@ -86,6 +88,8 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
return tensorflow::DT_STRING;
case tflite::TensorType_UINT8:
return tensorflow::DT_UINT8;
case tflite::TensorType_UINT64:
return tensorflow::DT_UINT64;
}
}

View File

@ -41,6 +41,7 @@ cc_library(
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core/common_runtime:core_cpu_base_no_ops",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
],

View File

@ -30,6 +30,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/common_runtime/function_def_utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op.h"
@ -111,8 +113,24 @@ std::string ImportFunction(const std::string &functiondef_proto,
}
const std::string &function_name = functiondef.signature().name();
const tensorflow::FunctionDef *fdef = flib_def.Find(function_name);
if (fdef == nullptr) {
s = tensorflow::errors::NotFound("Cannot find function ", function_name);
Set_TF_Status_from_Status(status, s);
return "// error";
}
std::unique_ptr<tensorflow::FunctionBody> fbody;
s = FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(), &flib_def,
&fbody);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return "// error";
}
mlir::MLIRContext context;
auto module = ConvertFunctionToMlir(function_name, flib_def, &context);
auto module = ConvertFunctionToMlir(fbody.get(), flib_def, &context);
if (!module.ok()) {
Set_TF_Status_from_Status(status, module.status());
return "// error";

View File

@ -42,10 +42,10 @@ filegroup(
"ir/tf_ops.td",
"ir/tfrt_ops.td",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
],
)
@ -337,6 +337,7 @@ cc_library(
":tensorflow",
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect",
@ -824,6 +825,22 @@ cc_library(
],
)
gentbl(
name = "tf_pass_inc_gen",
compatible_with = get_compatible_with_cloud(),
tbl_outs = [
(
"-gen-pass-decls -name TensorFlow",
"transforms/tf_passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/tf_passes.td",
td_srcs = [
"@llvm-project//mlir:PassBaseTdFiles",
],
)
cc_library(
name = "tensorflow_passes",
srcs = [
@ -859,6 +876,7 @@ cc_library(
"transforms/mark_ops_for_outside_compilation.cc",
"transforms/materialize_mlir_passthrough_op.cc",
"transforms/optimize.cc",
"transforms/outside_compiled_to_host_launch.cc",
"transforms/parallel_execute_to_islands.cc",
"transforms/parallelize_embedding_params_ops_pass.cc",
"transforms/promote_resources_to_args.cc",
@ -918,6 +936,8 @@ cc_library(
includes = ["include"],
textual_hdrs = [
"ir/tf_ops_helpers.inc",
"transforms/passes_detail.h",
"transforms/tf_passes.h.inc",
],
deps = [
":attribute_utils",
@ -940,6 +960,7 @@ cc_library(
":tensorflow_optimize_inc_gen",
":tensorflow_types",
":tf_data_optimization",
":tf_pass_inc_gen",
":tpu_rewrite_device_util",
":translate_utils",
":unroll_batch_matmul_pass",
@ -1043,6 +1064,7 @@ cc_library(
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"@llvm-project//llvm:Support",
],
)

View File

@ -31,7 +31,7 @@ limitations under the License.
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
def TF_AbsOp : TF_Op<"Abs", [NoSideEffect, SameOperandsAndResultType]> {
def TF_AbsOp : TF_Op<"Abs", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the absolute value of a tensor.";
let description = [{
@ -604,29 +604,6 @@ If `condition` evaluates to false, print the list of tensors in `data`.
let hasCanonicalizer = 1;
}
def TF_AssignOp : TF_Op<"Assign", [NoSideEffect]> {
let summary = "Update 'ref' by assigning 'value' to it.";
let description = [{
This operation outputs "ref" after the assignment is done.
This makes it easier to chain operations that need to use the reset value.
}];
let arguments = (ins
TF_Tensor:$ref,
TF_Tensor:$value,
DefaultValuedAttr<BoolAttr, "true">:$validate_shape,
DefaultValuedAttr<BoolAttr, "true">:$use_locking
);
let results = (outs
TF_Tensor:$output_ref
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_AssignAddVariableOp : TF_Op<"AssignAddVariableOp", []> {
let summary = "Adds a value to the current value of a variable.";
@ -1025,13 +1002,13 @@ reverse of SpaceToBatch. See below for a precise description.
TF_Tensor:$output
);
let verifier = [{
return Verify(*this);
}];
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tcrops = TF_DerivedOperandTypeAttr<2>;
TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>;
let verifier = [{
return Verify(*this);
}];
}
def TF_BetaincOp : TF_Op<"Betainc", [NoSideEffect]> {
@ -1509,7 +1486,7 @@ def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> {
let hasFolder = 1;
}
def TF_CeilOp : TF_Op<"Ceil", [NoSideEffect, SameOperandsAndResultType]> {
def TF_CeilOp : TF_Op<"Ceil", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns element-wise smallest integer not less than x.";
let arguments = (ins
@ -2936,6 +2913,67 @@ Converts the given variant tensor to an iterator and stores it in the given reso
let results = (outs);
}
def TF_DeserializeSparseOp : TF_Op<"DeserializeSparse", [NoSideEffect]> {
let summary = "Deserialize `SparseTensor` objects.";
let description = [{
The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where
the last dimension stores serialized `SparseTensor` objects and the other N
dimensions (N >= 0) correspond to a batch. The ranks of the original
`SparseTensor` objects must all match. When the final `SparseTensor` is
created, its rank is the rank of the incoming `SparseTensor` objects plus N;
the sparse tensors have been concatenated along new dimensions, one for each
batch.
The output `SparseTensor` object's shape values for the original dimensions
are the max across the input `SparseTensor` objects' shape values for the
corresponding dimensions. The new dimensions match the size of the batch.
The input `SparseTensor` objects' indices are assumed ordered in
standard lexicographic order. If this is not the case, after this
step run `SparseReorder` to restore index ordering.
For example, if the serialized input is a `[2 x 3]` matrix representing two
original `SparseTensor` objects:
index = [ 0]
[10]
[20]
values = [1, 2, 3]
shape = [50]
and
index = [ 2]
[10]
values = [4, 5]
shape = [30]
then the final deserialized `SparseTensor` will be:
index = [0 0]
[0 10]
[0 20]
[1 2]
[1 10]
values = [1, 2, 3, 4, 5]
shape = [2 50]
}];
let arguments = (ins
TensorOf<[TF_Str, TF_Variant]>:$serialized_sparse
);
let results = (outs
TF_Int64Tensor:$sparse_indices,
TF_Tensor:$sparse_values,
TF_Int64Tensor:$sparse_shape
);
TF_DerivedOperandTypeAttr Tserialized = TF_DerivedOperandTypeAttr<0>;
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<1>;
}
def TF_DestroyResourceOp : TF_Op<"DestroyResourceOp", []> {
let summary = "Deletes the resource specified by the handle.";
@ -3464,8 +3502,8 @@ tf.math.equal(x, y) ==> array([True, True])
}];
let arguments = (ins
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y,
TF_Tensor:$x,
TF_Tensor:$y,
DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
);
@ -3805,8 +3843,8 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien
let summary = "Compute gradients for a FakeQuantWithMinMaxArgs operation.";
let arguments = (ins
F32Tensor:$gradients,
F32Tensor:$inputs,
TF_Float32Tensor:$gradients,
TF_Float32Tensor:$inputs,
DefaultValuedAttr<F32Attr, "-6.0f">:$min,
DefaultValuedAttr<F32Attr, "6.0f">:$max,
@ -3815,7 +3853,7 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien
);
let results = (outs
F32Tensor:$backprops
TF_Float32Tensor:$backprops
);
}
@ -3873,19 +3911,19 @@ def TF_FakeQuantWithMinMaxVarsGradientOp : TF_Op<"FakeQuantWithMinMaxVarsGradien
let summary = "Compute gradients for a FakeQuantWithMinMaxVars operation.";
let arguments = (ins
F32Tensor:$gradients,
F32Tensor:$inputs,
F32Tensor:$min,
F32Tensor:$max,
TF_Float32Tensor:$gradients,
TF_Float32Tensor:$inputs,
TF_Float32Tensor:$min,
TF_Float32Tensor:$max,
DefaultValuedAttr<I64Attr, "8">:$num_bits,
DefaultValuedAttr<BoolAttr, "false">:$narrow_range
);
let results = (outs
F32Tensor:$backprops_wrt_input,
F32Tensor:$backprop_wrt_min,
F32Tensor:$backprop_wrt_max
TF_Float32Tensor:$backprops_wrt_input,
TF_Float32Tensor:$backprop_wrt_min,
TF_Float32Tensor:$backprop_wrt_max
);
}
@ -3988,7 +4026,7 @@ fill([2, 3], 9) ==> [[9, 9, 9]
];
}
def TF_FloorOp : TF_Op<"Floor", [NoSideEffect, SameOperandsAndResultType]> {
def TF_FloorOp : TF_Op<"Floor", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns element-wise largest integer not greater than x.";
let arguments = (ins
@ -4939,13 +4977,13 @@ $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
}];
let arguments = (ins
F32Tensor:$predictions,
TF_Float32Tensor:$predictions,
TF_I32OrI64Tensor:$targets,
TF_I32OrI64Tensor:$k
);
let results = (outs
I1Tensor:$precision
TF_BoolTensor:$precision
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
@ -6523,6 +6561,8 @@ tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where:
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasCanonicalizer = 1;
}
def TF_MatrixSetDiagV2Op : TF_Op<"MatrixSetDiagV2", [NoSideEffect]> {
@ -6615,6 +6655,8 @@ tf.matrix_set_diag(diagonals, k = (-1, 0))
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasCanonicalizer = 1;
}
def TF_MatrixSetDiagV3Op : TF_Op<"MatrixSetDiagV3", [NoSideEffect]> {
@ -6870,7 +6912,7 @@ retained with length 1.
];
}
def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
let summary = "Performs max pooling on the input.";
let arguments = (ins
@ -6894,6 +6936,9 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInter
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
// TF_LayoutSensitiveInterface:
StringRef GetOptimalLayout(const RuntimeDevices& devices);
LogicalResult UpdateDataFormat(StringRef data_format);
}];
}
@ -7810,8 +7855,8 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> {
}];
let arguments = (ins
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y,
TF_Tensor:$x,
TF_Tensor:$y,
DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
);
@ -7989,7 +8034,7 @@ times by rerunning "MakeIterator".
);
}
def TF_OnesLikeOp : TF_Op<"OnesLike", [NoSideEffect, SameOperandsAndResultType]> {
def TF_OnesLikeOp : TF_Op<"OnesLike", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns a tensor of ones with the same shape and type as x.";
let arguments = (ins
@ -8387,6 +8432,10 @@ def TF_QrOp : TF_Op<"Qr", [NoSideEffect]> {
Computes the QR decomposition of each inner matrix in `tensor` such that
`tensor[..., :, :] = q[..., :, :] * r[..., :, :]`
Currently, the gradient for the QR decomposition is well-defined only when
the first `P` columns of the inner matrix are linearly independent, where
`P` is the minimum of `M` and `N`, the 2 inner-most dimensions of `tensor`.
```python
# a is a tensor.
# q is a tensor of orthonormal matrices.
@ -9072,7 +9121,7 @@ most one RecvTPUEmbeddingActivations op in the TPU graph.
TF_DerivedResultSizeAttr num_outputs = TF_DerivedResultSizeAttr<0>;
}
def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> {
def TF_ReluOp : TF_Op<"Relu", [Idempotent, NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> {
let summary = "Computes rectified linear: `max(features, 0)`.";
let description = [{
@ -9098,7 +9147,7 @@ array([ 0., 0., -0., 3.], dtype=float32)
}];
}
def TF_Relu6Op : TF_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> {
def TF_Relu6Op : TF_Op<"Relu6", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes rectified linear 6: `min(max(features, 0), 6)`.";
let arguments = (ins
@ -10496,7 +10545,7 @@ bitwise_ops.right_shift(lhs, rhs)
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_RintOp : TF_Op<"Rint", [NoSideEffect, SameOperandsAndResultType]> {
def TF_RintOp : TF_Op<"Rint", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns element-wise integer closest to x.";
let description = [{
@ -10563,7 +10612,7 @@ roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]]
TF_DerivedOperandTypeAttr Taxis = TF_DerivedOperandTypeAttr<2>;
}
def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> {
def TF_RoundOp : TF_Op<"Round", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = [{
Rounds the values of a tensor to the nearest integer, element-wise.
}];
@ -11293,7 +11342,7 @@ Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_SignOp : TF_Op<"Sign", [NoSideEffect, SameOperandsAndResultType]> {
def TF_SignOp : TF_Op<"Sign", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns an element-wise indication of the sign of a number.";
let description = [{
@ -12303,14 +12352,14 @@ The outputs are a deterministic function of `shape`, `seed`, and `alpha`.
}
def TF_StatelessRandomGetAlgOp : TF_Op<"StatelessRandomGetAlg", []> {
let summary = [{
Picks the best counter-based RNG algorithm based on device.
}];
let summary = "Picks the best counter-based RNG algorithm based on device.";
let description = [{
This op picks the best counter-based RNG algorithm based on device.
}];
let arguments = (ins);
let results = (outs
TF_Int32Tensor:$alg
);
@ -14046,73 +14095,35 @@ This operation is very similar to `tf.scatter_nd`, except that the updates are
scattered onto an existing tensor (as opposed to a zero-tensor). If the memory
for the existing tensor cannot be re-used, a copy is made and updated.
If `indices` contains duplicates, then their updates are accumulated (summed).
If `indices` contains duplicates, then we pick the last update for the index.
**WARNING**: The order in which updates are applied is nondeterministic, so the
output will be nondeterministic if `indices` contains duplicates -- because
of some numerical approximation issues, numbers summed in different order
may yield different results.
If an out of bound index is found on CPU, an error is returned.
**WARNING**: There are some GPU specific semantics for this operation.
- If an out of bound index is found, the index is ignored.
- The order in which updates are applied is nondeterministic, so the output
will be nondeterministic if `indices` contains duplicates.
`indices` is an integer tensor containing indices into a new tensor of shape
`shape`. The last dimension of `indices` can be at most the rank of `shape`:
`shape`.
indices.shape[-1] <= shape.rank
* `indices` must have at least 2 axes: `(num_updates, index_depth)`.
* The last axis of `indices` is how deep to index into `tensor`, so this index
depth can be at most the rank of `tensor`: `indices.shape[-1] <= tensor.ndim`
The last dimension of `indices` corresponds to indices into elements
(if `indices.shape[-1] = shape.rank`) or slices
(if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of
`shape`. `updates` is a tensor with shape
If `indices.shape[-1] = tensor.rank` this Op indexes and updates scalar elements.
If `indices.shape[-1] < tensor.rank` it indexes and updates slices of the input
`tensor`.
indices.shape[:-1] + shape[indices.shape[-1]:]
Each `update` has a rank of `tensor.rank - indices.shape[-1]`.
The overall shape of `updates` is:
The simplest form of scatter is to insert individual elements in a tensor by
index. For example, say we want to insert 4 scattered elements in a rank-1
tensor with 8 elements.
```
indices.shape[:-1] + tensor.shape[indices.shape[-1]:]
```
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/ScatterNd1.png" alt>
</div>
In Python, this scatter operation would look like this:
>>> indices = tf.constant([[4], [3], [1], [7]])
>>> updates = tf.constant([9, 10, 11, 12])
>>> tensor = tf.ones([8], dtype=tf.int32)
>>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
tf.Tensor([ 1 11 1 10 9 1 1 12], shape=(8,), dtype=int32)
We can also, insert entire slices of a higher rank tensor all at once. For
example, if we wanted to insert two slices in the first dimension of a
rank-3 tensor with two matrices of new values.
In Python, this scatter operation would look like this:
>>> indices = tf.constant([[0], [2]])
>>> updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
... [7, 7, 7, 7], [8, 8, 8, 8]],
... [[5, 5, 5, 5], [6, 6, 6, 6],
... [7, 7, 7, 7], [8, 8, 8, 8]]])
>>> tensor = tf.ones([4, 4, 4], dtype=tf.int32)
>>> print(tf.tensor_scatter_nd_update(tensor, indices, updates).numpy())
[[[5 5 5 5]
[6 6 6 6]
[7 7 7 7]
[8 8 8 8]]
[[1 1 1 1]
[1 1 1 1]
[1 1 1 1]
[1 1 1 1]]
[[5 5 5 5]
[6 6 6 6]
[7 7 7 7]
[8 8 8 8]]
[[1 1 1 1]
[1 1 1 1]
[1 1 1 1]
[1 1 1 1]]]
Note that on CPU, if an out of bound index is found, an error is returned.
On GPU, if an out of bound index is found, the index is ignored.
For usage examples see the python [tf.tensor_scatter_nd_update](
https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) function
}];
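The revised description defers to the Python docs for examples; for convenience, a minimal usage sketch (taken from the wording the patch removes):

```python
import tensorflow as tf

tensor = tf.ones([8], dtype=tf.int32)
indices = tf.constant([[4], [3], [1], [7]])   # (num_updates=4, index_depth=1)
updates = tf.constant([9, 10, 11, 12])        # each update has rank 1 - 1 = 0
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
# tf.Tensor([ 1 11  1 10  9  1  1 12], shape=(8,), dtype=int32)
```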
let arguments = (ins
@ -15038,7 +15049,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather
}];
let arguments = (ins
Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{The array we're gathering from.}]>:$operand,
Arg<TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{The array we're gathering from.}]>:$operand,
Arg<TF_I32OrI64Tensor, [{Array containing the starting indices of the slices we gather.}]>:$start_indices,
Arg<TF_I32OrI64Tensor, [{slice_sizes[i] is the bounds for the slice on dimension i.}]>:$slice_sizes,
@ -15047,7 +15058,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather
);
let results = (outs
TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
@ -15415,7 +15426,7 @@ def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape, TF_Sam
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType]> {
def TF_ZerosLikeOp : TF_Op<"ZerosLike", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns a tensor of zeros with the same shape and type as x.";
let arguments = (ins

View File

@ -684,12 +684,23 @@ body: A function that takes a list of tensors and returns another
FlatSymbolRefAttr:$cond,
FlatSymbolRefAttr:$body,
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
// Used to map StatelessWhile and While op defined in TensorFlow to a common
// op.
BoolAttr:$is_stateless
BoolAttr:$is_stateless,
// In TensorFlow, While has a special behavior where if `output_shapes`
// attribute is not empty, those shapes are used in its shape function
// as result shapes instead of propagating operand shapes as result shapes.
// This allows result shapes to differ from operand shapes. While these
// shapes are imported and set as a part of the result type, there is no
// indicator differentiating between having no output shapes compared to
// having all unranked shapes. Thus this attribute is set to determine
// which shape function behavior to use for this op, specifically
// propagating operand shapes as result shapes when this attribute is not
// set, or preserving result shapes as is when this attribute is set.
UnitAttr:$shape_invariant
);
let results = (outs
@ -697,6 +708,7 @@ body: A function that takes a list of tensors and returns another
);
TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
let verifier = [{
return Verify(*this);
@ -752,12 +764,23 @@ def TF_WhileRegionOp : TF_Op<"WhileRegion",
let arguments = (ins
Variadic<AnyTensor>:$input,
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
// Used to map StatelessWhile and While op defined in TensorFlow to a common
// op.
BoolAttr:$is_stateless
BoolAttr:$is_stateless,
// In TensorFlow, While has a special behavior where if `output_shapes`
// attribute is not empty, those shapes are used in its shape function
// as result shapes instead of propagating operand shapes as result shapes.
// This allows result shapes to differ from operand shapes. While these
// shapes are imported and set as a part of the result type, there is no
// indicator differentiating between having no output shapes compared to
// having all unranked shapes. Thus this attribute is set to determine
// which shape function behavior to use for this op, specifically
// propagating operand shapes as result shapes when this attribute is not
// set, or preserving result shapes as is when this attribute is set.
UnitAttr:$shape_invariant
);
let results = (outs Variadic<AnyTensor>:$output);
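A sketch, not part of the op definition, of how non-empty `output_shapes` typically arises from user code: `tf.while_loop` shape invariants let a loop variable change shape, and the declared shapes should end up on the While op, which is the behavior the `shape_invariant` attribute above preserves.

```python
import tensorflow as tf

@tf.function
def collect(n):
  i = tf.constant(0)
  acc = tf.zeros([0], dtype=tf.int32)
  # Declaring a [None] invariant lets `acc` grow across iterations instead of
  # being forced to keep the operand shape [0].
  i, acc = tf.while_loop(
      lambda i, acc: i < n,
      lambda i, acc: (i + 1, tf.concat([acc, [i]], axis=0)),
      loop_vars=[i, acc],
      shape_invariants=[i.shape, tf.TensorShape([None])])
  return acc

print(collect(tf.constant(4)))  # tf.Tensor([0 1 2 3], shape=(4,), dtype=int32)
```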
@ -1972,4 +1995,31 @@ operations inside a TPU host.
);
}
def TF_AssignOp : TF_Op<"Assign", []> {
let summary = "Update 'ref' by assigning 'value' to it.";
let description = [{
This operation outputs "ref" after the assignment is done.
This makes it easier to chain operations that need to use the reset value.
This is a side-effecting operation because it will change the value of its
argument "ref" in addition to returning the results.
}];
let arguments = (ins
TF_Tensor:$ref,
TF_Tensor:$value,
DefaultValuedAttr<BoolAttr, "true">:$validate_shape,
DefaultValuedAttr<BoolAttr, "true">:$use_locking
);
let results = (outs
TF_Tensor:$output_ref
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
#endif // TF_OPS

View File

@ -2428,6 +2428,24 @@ static LogicalResult Verify(MatrixBandPartOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// MatrixSetDiagOp
//===----------------------------------------------------------------------===//
//
void MatrixSetDiagOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<MatrixSetDiagToV3>(context);
}
//===----------------------------------------------------------------------===//
// MatrixSetDiagV2Op
//===----------------------------------------------------------------------===//
void MatrixSetDiagV2Op::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<MatrixSetDiagV2ToV3>(context);
}
//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//
@ -2449,6 +2467,33 @@ LogicalResult MaxPoolOp::FoldOperandsPermutation(
permutation, this, {{"strides", strides()}, {"ksize", ksize()}});
}
LogicalResult MaxPoolOp::UpdateDataFormat(StringRef new_data_format) {
StringRef src_data_format = data_format();
auto perm = GetDataFormatPermutation(src_data_format, new_data_format);
if (perm.empty()) return failure();
// Update data_format attribute and result types.
if (failed(::mlir::TF::UpdateDataFormat(new_data_format, this)))
return failure();
stridesAttr(ShuffleArrayAttr(strides(), perm));
explicit_paddingsAttr(ShuffleArrayAttr(explicit_paddings(), perm, 2));
ksizeAttr(ShuffleArrayAttr(ksize(), perm));
return success();
}
StringRef MaxPoolOp::GetOptimalLayout(const RuntimeDevices &devices) {
// Keep current data format if no GPUs are available or if explicit placement
// does not allow using the GPU for this operation.
if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
return data_format();
// Defaults to NCHW.
return "NCHW";
}
//===----------------------------------------------------------------------===//
// MaxPoolGradOp
//===----------------------------------------------------------------------===//

View File

@ -1280,3 +1280,24 @@ func @testSumFoldBypass(%arg0: tensor<4x?xf16>, %arg1: tensor<*xi64>) -> tensor<
%0 = "tf.Sum"(%arg0, %arg1) { keep_dims = false }: (tensor<4x?xf16>, tensor<*xi64>) -> tensor<4x?xf16>
return %0 : tensor<4x?xf16>
}
// CHECK-LABEL: @testMatrixSetDiag
func @testMatrixSetDiag(%arg0: tensor<3x3xi64>, %arg1: tensor<3xi64>) -> tensor<3x3xi64> {
%0 = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi64>, tensor<3xi64>) -> tensor<3x3xi64>
return %0 : tensor<3x3xi64>
// CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiagV3"(%arg0, %arg1, %[[ZERO]])
// CHECK-SAME: {align = "RIGHT_LEFT"}
// CHECK-SAME: (tensor<3x3xi64>, tensor<3xi64>, tensor<i32>) -> tensor<3x3xi64>
}
// CHECK-LABEL: @testMatrixSetDiagV2
func @testMatrixSetDiagV2(%arg0: tensor<3x3xi64>, %arg1: tensor<3xi64>, %arg2: tensor<i32>) -> tensor<3x3xi64> {
%0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %arg2) : (tensor<3x3xi64>, tensor<3xi64>, tensor<i32>) -> tensor<3x3xi64>
return %0 : tensor<3x3xi64>
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiagV3"(%arg0, %arg1, %arg2)
// CHECK-SAME: {align = "LEFT_LEFT"}
}

View File

@ -4,6 +4,8 @@
// CHECK: func @main(%[[ARG_0:.*]]: tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}, %[[ARG_1:.*]]: tensor<i32> {tf.device = "/job:worker/replica:0/task:1/device:CPU:0"})
// CHECK-NEXT: %[[RESULT_0:.*]]:2 = tf_device.remote_run "/job:worker/replica:0/task:0" @_job_worker_replica_0_task_0(%[[ARG_0]])
// CHECK-NEXT: %[[RESULT_1:.*]] = tf_device.remote_run "/job:worker/replica:0/task:1" @_job_worker_replica_0_task_1(%[[ARG_1]])
// CHECK-NEXT: %[[RESULT_2:.*]] = "tf.Const"() {value = dense<16> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK-NEXT: %[[RESULT_3:.*]] = "tf.AddV2"(%[[RESULT_2]], %[[RESULT_2]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: return %[[RESULT_0]]#0, %[[RESULT_0]]#1, %[[RESULT_1]] : tensor<i32>, tensor<i32>, tensor<i32>
// CHECK: func @_job_worker_replica_0_task_0(%[[ARG_0:.*]]: tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}) -> (tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}, tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:1"})
@ -21,5 +23,8 @@ func @main(%arg0: tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:
%1 = "tf.Mul"(%0, %0) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%2 = "tf.AddV2"(%arg0, %arg0) {device = "/job:worker/replica:0/task:0/device:CPU:1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%3 = "tf.AddV2"(%arg1, %arg1) {device = "/job:worker/replica:0/task:1/device:CPU:0"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%4 = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32>
%5 = "tf.AddV2"(%4, %4) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %1, %2, %3 : tensor<i32>, tensor<i32>, tensor<i32>
}

View File

@ -1,16 +1,47 @@
// RUN: tf-opt -tf-tensor-device-copy %s | FileCheck %s --dump-input=fail
// CHECK-LABEL: func @fold_identity
// CHECK-SAME: ([[arg0:%.*]]: tensor<2x2xf32>, [[arg1:%.*]]: tensor<2x2xf32>
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32}} {
func @fold_identity(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = tf_executor.graph {
// CHECK: tf.MatMul
%outputs, %control = tf_executor.island wraps "tf.MatMul"(%arg0, %arg1) {device = "", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NOT: tf.Identity
%outputs_0, %control_1 = tf_executor.island wraps "tf.Identity"(%outputs) {device = ""} : (tensor<2x2xf32>) -> tensor<2x2xf32>
tf_executor.fetch %outputs_0 : tensor<2x2xf32>
}
return %0 : tensor<2x2xf32>
// CHECK-LABEL: func @fold_identity_test
func @fold_identity_test(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = tf_executor.graph {
// CHECK: tf.MatMul
%outputs, %control = tf_executor.island wraps "tf.MatMul"(%arg0, %arg1) {device = "", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NOT: tf.Identity
%outputs_0, %control_1 = tf_executor.island wraps "tf.Identity"(%outputs) {device = ""} : (tensor<2x2xf32>) -> tensor<2x2xf32>
tf_executor.fetch %outputs_0 : tensor<2x2xf32>
}
return %0 : tensor<2x2xf32>
}
// CHECK-LABEL: func @keep_identity_test
func @keep_identity_test(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = tf_executor.graph {
// CHECK: tf.MatMul
%outputs, %control = tf_executor.island wraps "tf.MatMul"(%arg0, %arg1) {device = "/device:GPU:0", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: tf.Identity
%outputs_0, %control_1 = tf_executor.island wraps "tf.Identity"(%outputs) {device = "/device:CPU:0"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
tf_executor.fetch %outputs_0 : tensor<2x2xf32>
}
return %0 : tensor<2x2xf32>
}
// CHECK: func @while_loop_test(%[[ARG_0:.*]]: tensor<i32>, %[[ARG_1:.*]]: tensor<i32>, %arg2: tensor<*xf32>)
func @while_loop_test(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<*xf32>) {
// CHECK-NEXT: tf.WhileRegion
%0:2 = "tf.WhileRegion"(%arg0, %arg2) ( {
// CHECK-NEXT: bb0(%[[ARG_3:.*]]: tensor<i32>, %[[ARG_4:.*]]: tensor<*xf32>)
^bb0(%arg3: tensor<i32>, %arg4: tensor<*xf32>):
// CHECK-NEXT: %[[RESULT_1:.*]] = "tf.Identity"(%[[ARG_3]])
%1 = "tf.Identity"(%arg3) : (tensor<i32>) -> tensor<i32>
%2 = "tf.Identity"(%arg1) : (tensor<i32>) -> tensor<i32>
// CHECK-NEXT: %[[RESULT_2:.*]] = "tf.NotEqual"(%[[RESULT_1]], %[[ARG_1]])
%3 = "tf.NotEqual"(%1, %2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%3) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<i32>, %arg4: tensor<*xf32>):
%cst = constant dense<1> : tensor<i32>
%1 = "tf.Sub"(%arg3, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"tf.Yield"(%1, %arg4) : (tensor<i32>, tensor<*xf32>) -> ()
}) {is_stateless = true} : (tensor<i32>, tensor<*xf32>) -> (tensor<i32>, tensor<*xf32>)
return
}

View File

@ -41,3 +41,34 @@ func @broadcast_mul_implicit_no_fold(%arg0: tensor<5x7xf32>, %arg1: tensor<5xf32
// CHECK: %[[V1:.*]] = "tf.Mul"(%arg0, %[[V0]]) : (tensor<5x7xf32>, tensor<3x5x7xf32>) -> tensor<3x5x7xf32>
// CHECK: %[[V1]] : tensor<3x5x7xf32>
}
// CHECK-LABEL: @broadcast_eq
func @broadcast_eq(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xi1> {
%cst = constant dense<[5, 7]> : tensor<2xi32>
%0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
%1 = "tf.Equal"(%arg0, %0) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xi1>
return %1 : tensor<5x7xi1>
// CHECK: %[[V0:.*]] = "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xi1>
// CHECK: %[[V0]] : tensor<5x7xi1>
}
// CHECK-LABEL: @broadcast_neq
func @broadcast_neq(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xi1> {
%cst = constant dense<[5, 7]> : tensor<2xi32>
%0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
%1 = "tf.NotEqual"(%arg0, %0) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xi1>
return %1 : tensor<5x7xi1>
// CHECK: %[[V0:.*]] = "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xi1>
// CHECK: %[[V0]] : tensor<5x7xi1>
}
// CHECK-LABEL: @broadcast_both_operand
func @broadcast_both_operand(%arg0: tensor<7xf32>, %arg1: tensor<5x1xf32>) -> tensor<5x7xf32> {
%cst = constant dense<[5, 7]> : tensor<2xi64>
%0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<7xf32>, tensor<2xi64>) -> tensor<5x7xf32>
%1 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<5x1xf32>, tensor<2xi64>) -> tensor<5x7xf32>
%2 = "tf.Add"(%0, %1) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32>
return %2 : tensor<5x7xf32>
// CHECK: %[[V0:.*]] = "tf.Add"(%arg0, %arg1) : (tensor<7xf32>, tensor<5x1xf32>) -> tensor<5x7xf32>
// CHECK: %[[V0]] : tensor<5x7xf32>
}
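These folds rely on Equal/NotEqual (and Add) broadcasting their operands implicitly, so an explicit BroadcastTo feeding them is redundant. A small sketch of the equivalence the tests above check:

```python
import tensorflow as tf

a = tf.random.uniform([5, 7])
b = tf.random.uniform([7])
# tf.equal broadcasts implicitly, so materializing the broadcast first gives
# the same result and the BroadcastTo can be folded away.
print(tf.equal(a, b).shape)                            # (5, 7)
print(tf.equal(a, tf.broadcast_to(b, [5, 7])).shape)   # (5, 7), same values
```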

View File

@ -1,6 +1,8 @@
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Add -o - | FileCheck %s
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-shapes=10:10 -tf-output-arrays=Add -o - | FileCheck --check-prefix=NONE %s
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-shapes=10:10 -tf-input-data-types=',DT_INT32' -tf-output-arrays=Add -o - | FileCheck --check-prefix=SOME %s
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-shapes=*:* -tf-input-data-types=',DT_INT32' -tf-output-arrays=Add -o - | FileCheck --check-prefix=UNKNOWN %s
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-shapes=?,1,?:1,?,1 -tf-input-data-types=',DT_INT32' -tf-output-arrays=Add -o - | FileCheck --check-prefix=DYNAMIC %s
node {
name: "Add"
@ -61,3 +63,19 @@ versions {
# NONE-SAME: outputs = "Add"
# NONE: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# NONE: fetch %[[add]]
# UNKNOWN-LABEL: func @main
# UNKNOWN-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<*xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<*xi32>) -> tensor<*xi32>
# UNKNOWN-SAME: control_outputs = ""
# UNKNOWN-SAME: inputs = "input0,input1"
# UNKNOWN-SAME: outputs = "Add"
# UNKNOWN: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# UNKNOWN: fetch %[[add]]
# DYNAMIC-LABEL: func @main
# DYNAMIC-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<?x1x?xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<1x?x1xi32>) -> tensor<*xi32>
# DYNAMIC-SAME: control_outputs = ""
# DYNAMIC-SAME: inputs = "input0,input1"
# DYNAMIC-SAME: outputs = "Add"
# DYNAMIC: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# DYNAMIC: fetch %[[add]]

View File

@ -1,4 +1,4 @@
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-upgrade-legacy %s -tf-output-arrays=hash_table_node -o - | FileCheck %s
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-upgrade-legacy %s -tf-output-arrays=hash_table_node,variable_node -o - | FileCheck %s
node: {
name: "hash_table_node"
@ -22,6 +22,29 @@ node: {
}
}
}
node {
name: "variable_node"
op: "VariableV2"
attr {
key: "dtype"
value {
type: DT_INT64
}
}
attr {
key: "shape"
value {
shape {
}
}
}
attr {
key: "shared_name"
value {
s: ""
}
}
}
node {
name: "Call"
op: "PartitionedCall"
@ -90,6 +113,9 @@ library {
# CHECK: tf.HashTableV2
# CHECK-SAME: shared_name = "hash_table_node"
# CHECK: tf.VariableV2
# CHECK-SAME: shared_name = "variable_node"
# CHECK: func private @create_resource
# CHECK: tf.HashTableV2
# CHECK-SAME: shared_name = "hash_table_node@create_resource"

View File

@ -1,4 +1,4 @@
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1,WhileWithOutputShapes:1 -o - -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s
# Verify that TensorFlow While and StatelessWhile ops are mapped to the
# composite While op in MLIR with is_stateless attribute set accordingly to
@ -6,6 +6,7 @@
# CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} loc("StatefulWhile")
# CHECK-DAG: "tf.While"{{.*}} is_stateless = true{{.*}} loc("StatelessWhile")
# CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} shape_invariant{{.*}} -> (tensor<i32>, tensor<*xf32>) loc("WhileWithOutputShapes")
node {
name: "StatefulWhile"
@ -73,6 +74,51 @@ node {
experimental_debug_info {
}
}
node {
name: "WhileWithOutputShapes"
op: "While"
input: "iter"
input: "val"
attr {
key: "T"
value {
list {
type: DT_INT32
type: DT_FLOAT
}
}
}
attr {
key: "body"
value {
func {
name: "body"
}
}
}
attr {
key: "cond"
value {
func {
name: "cond"
}
}
}
attr {
key: "output_shapes"
value {
list {
shape {
}
shape {
unknown_rank: true
}
}
}
}
experimental_debug_info {
}
}
node {
name: "main"
op: "_Retval"
@ -107,6 +153,23 @@ node {
}
}
}
node {
name: "main2"
op: "_Retval"
input: "WhileWithOutputShapes:1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "index"
value {
i: 2
}
}
}
node {
name: "iter"
op: "Placeholder"

View File

@ -82,3 +82,20 @@ func @bias_add_nchw(%arg0: tensor<1x256x150x150xf32>, %arg1: tensor<256xf32>) ->
%0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW", device = ""} : (tensor<1x256x150x150xf32>, tensor<256xf32>) -> tensor<1x256x150x150xf32>
return %0 : tensor<1x256x150x150xf32>
}
// CHECK-LABEL: maxpool_nchw
func @maxpool_nchw(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32> {
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
// CHECK: %[[R0:.*]] = "tf.Transpose"(%arg0, %[[CST]])
// CHECK: %[[R1:.*]] = "tf.MaxPool"(%[[R0]]) {data_format = "NHWC", explicit_paddings = [], ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]}
// CHECK: %[[CST_0:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
// CHECK: "tf.Transpose"(%[[R1]], %[[CST_0]])
%0 = "tf.MaxPool"(%arg0)
{
data_format = "NCHW",
ksize = [1, 1, 3, 3],
padding = "SAME",
strides = [1, 1, 2, 2]
} : (tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32>
return %0 : tensor<1x64x56x56xf32>
}

View File

@ -1703,3 +1703,127 @@ func @convert_iota_3d() -> tensor<5x7x9xi32> {
return %0 : tensor<5x7x9xi32>
}
// CHECK-LABEL: func @convert_avgpool_valid(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
// CHECK: %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
// CHECK: return %[[VAL_1]] : tensor<4x7x7x8xf32>
// CHECK: }
func @convert_avgpool_valid(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
%0 = mhlo.constant dense<0.0> : tensor<f32>
%1 = mhlo.constant dense<9.0> : tensor<4x7x7x8xf32>
%2 = "mhlo.reduce_window"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%5 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%5) : (tensor<f32>) -> ()
}) {
base_dilations = dense<1> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>,
window_dilations = dense<1> : tensor<4xi64>,
window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
%3 = mhlo.divide %2, %1 : tensor<4x7x7x8xf32>
return %3 : tensor<4x7x7x8xf32>
}
// CHECK-LABEL: func @convert_avgpool_valid_rw(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
// CHECK: %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
// CHECK: return %[[VAL_1]] : tensor<4x7x7x8xf32>
// CHECK: }
func @convert_avgpool_valid_rw(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
%0 = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32>
%1 = mhlo.constant dense<0.0> : tensor<f32>
%2 = "mhlo.reduce_window"(%arg0, %1) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%6 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%6) : (tensor<f32>) -> ()
}) {
base_dilations = dense<1> : tensor<4xi64>,
padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
window_dilations = dense<1> : tensor<4xi64>,
window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
%3 = "mhlo.reduce_window"(%0, %1) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%6 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%6) : (tensor<f32>) -> ()
}) {
base_dilations = dense<1> : tensor<4xi64>,
padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
window_dilations = dense<1> : tensor<4xi64>,
window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
%4 = mhlo.divide %2, %3 : tensor<4x7x7x8xf32>
return %4 : tensor<4x7x7x8xf32>
}
// CHECK-LABEL: func @convert_avgpool_valid_3d(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
// CHECK: %[[VAL_1:.*]] = "tf.AvgPool3D"(%[[VAL_0]]) {data_format = "NDHWC", ksize = [1, 3, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 2, 1]} : (tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32>
// CHECK: return %[[VAL_1]] : tensor<4x7x7x7x8xf32>
// CHECK: }
func @convert_avgpool_valid_3d(%arg0: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
%0 = mhlo.constant dense<0.0> : tensor<f32>
%1 = mhlo.constant dense<27.0> : tensor<4x7x7x7x8xf32>
%2 = "mhlo.reduce_window"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%5 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%5) : (tensor<f32>) -> ()
}) {
base_dilations = dense<1> : tensor<5xi64>,
padding = dense<0> : tensor<5x2xi64>,
window_dilations = dense<1> : tensor<5xi64>,
window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>,
window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<4x16x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x7x8xf32>
%3 = mhlo.divide %2, %1 : tensor<4x7x7x7x8xf32>
return %3 : tensor<4x7x7x7x8xf32>
}
// CHECK-LABEL: func @convert_avgpool_same(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
// CHECK: %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
// CHECK: return %[[VAL_1]] : tensor<4x8x8x8xf32>
// CHECK: }
func @convert_avgpool_same(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
%0 = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32>
%1 = mhlo.constant dense<0.0> : tensor<f32>
%2 = "mhlo.reduce_window"(%arg0, %1) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%6 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%6) : (tensor<f32>) -> ()
}) {
base_dilations = dense<1> : tensor<4xi64>,
padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
window_dilations = dense<1> : tensor<4xi64>,
window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
%3 = "mhlo.reduce_window"(%0, %1) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%6 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%6) : (tensor<f32>) -> ()
}) {
base_dilations = dense<1> : tensor<4xi64>,
padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
window_dilations = dense<1> : tensor<4xi64>,
window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
%4 = mhlo.divide %2, %3 : tensor<4x8x8x8xf32>
return %4 : tensor<4x8x8x8xf32>
}
// CHECK-LABEL: func @convert_pad(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x128xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<f32>) -> tensor<11x131xf32> {
// CHECK: %[[VAL_2:.*]] = constant dense<{{\[\[}}1, 2], [0, 3]]> : tensor<2x2xi64>
// CHECK: %[[VAL_3:.*]] = "tf.PadV2"(%[[VAL_0]], %[[VAL_2]], %[[VAL_1]]) : (tensor<8x128xf32>, tensor<2x2xi64>, tensor<f32>) -> tensor<11x131xf32>
// CHECK: return %[[VAL_3]] : tensor<11x131xf32>
// CHECK: }
func @convert_pad(%arg0: tensor<8x128xf32>, %arg1: tensor<f32>) -> tensor<11x131xf32> {
%0 = "mhlo.pad"(%arg0, %arg1) {
edge_padding_low = dense<[1, 0]> : tensor<2xi64>,
edge_padding_high = dense<[2, 3]> : tensor<2xi64>,
interior_padding = dense<0> : tensor<2xi64>
} : (tensor<8x128xf32>, tensor<f32>) -> tensor<11x131xf32>
return %0 : tensor<11x131xf32>
}

View File

@ -244,9 +244,9 @@ func @fourdim_space_to_batch_nd(%input: tensor<3x5x7x10xf32>, %block_shape: tens
// CHECK-DAG: [[PAD_DEFAULT:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
// CHECK-DAG: [[PADDED:%.+]] = "tf.PadV2"(%arg0, [[FULL_PADDINGS]], [[PAD_DEFAULT]])
// CHECK-DAG: [[PADDINGS:%.+]]:2 = "tf.Unpack"([[FULL_PADDINGS]]) {axis = 1 : i64}
// CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Add"([[PADDINGS]]#0, [[PADDINGS]]#1)
// CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.AddV2"([[PADDINGS]]#0, [[PADDINGS]]#1)
// CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 10]> : tensor<4xi64>}
// CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.Add"([[PADDINGS_SUM]], [[INPUT_SHAPE]])
// CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.AddV2"([[PADDINGS_SUM]], [[INPUT_SHAPE]])
// CHECK-DAG: [[PADDED_SHAPE_SPLITS:%.+]]:4 = "tf.Split"([[ZERO_I32]], [[PADDED_SHAPE]])
// CHECK-DAG: [[BLOCK_SHAPE_SPLITS:%.+]]:2 = "tf.Split"([[ZERO_I32]], %arg1)
// CHECK-DAG: [[OUTER_SHAPE_0:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#1, [[BLOCK_SHAPE_SPLITS]]#0)
@ -338,10 +338,10 @@ func @fake_quant_with_min_max_args(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-DAG: [[VAL5:%.+]] = "tf.ClipByValue"(%arg0, [[VAL2]], [[VAL1]])
// CHECK-DAG: [[VAL6:%.+]] = "tf.Sub"([[VAL5]], [[VAL2]])
// CHECK-DAG: [[VAL7:%.+]] = "tf.Mul"([[VAL6]], [[VAL0]])
// CHECK-DAG: [[VAL8:%.+]] = "tf.Add"([[VAL7]], [[VAL4]])
// CHECK-DAG: [[VAL8:%.+]] = "tf.AddV2"([[VAL7]], [[VAL4]])
// CHECK-DAG: [[VAL9:%.+]] = "tf.Floor"([[VAL8]])
// CHECK-DAG: [[VAL10:%.+]] = "tf.Mul"([[VAL9]], [[VAL3]])
// CHECK-DAG: [[VAL11:%.+]] = "tf.Add"([[VAL10]], [[VAL2]])
// CHECK-DAG: [[VAL11:%.+]] = "tf.AddV2"([[VAL10]], [[VAL2]])
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {max = 1.0 : f32, min = -1.0 : f32, narrow_range = false, num_bits = 8 : i64} : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: return [[VAL11]]
@ -361,7 +361,7 @@ func @fake_quant_with_min_max_vars(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>,
// CHECK-DAG: [[VAL9:%.+]] = "tf.Floor"([[VAL8]])
// CHECK-DAG: [[VAL10:%.+]] = "tf.Sub"([[VAL8]], [[VAL9]])
// CHECK-DAG: [[VAL11:%.+]] = "tf.Less"([[VAL10]], [[VAL3]])
// CHECK-DAG: [[VAL12:%.+]] = "tf.Add"([[VAL2]], [[VAL9]])
// CHECK-DAG: [[VAL12:%.+]] = "tf.AddV2"([[VAL9]], [[VAL2]])
// CHECK-DAG: [[VAL13:%.+]] = "tf.Select"([[VAL11]], [[VAL9]], [[VAL12]])
// CHECK-DAG: [[VAL14:%.+]] = "tf.ClipByValue"([[VAL13]], [[VAL0]], [[VAL1]]) :
// CHECK-DAG: [[VAL15:%.+]] = "tf.Sub"([[VAL0]], [[VAL14]])
@ -371,10 +371,10 @@ func @fake_quant_with_min_max_vars(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>,
// CHECK-DAG: [[VAL19:%.+]] = "tf.ClipByValue"(%arg0, [[VAL17]], [[VAL18]])
// CHECK-DAG: [[VAL20:%.+]] = "tf.Sub"([[VAL19]], [[VAL17]])
// CHECK-DAG: [[VAL21:%.+]] = "tf.Mul"([[VAL20]], [[VAL6]])
// CHECK-DAG: [[VAL22:%.+]] = "tf.Add"([[VAL21]], [[VAL3]])
// CHECK-DAG: [[VAL22:%.+]] = "tf.AddV2"([[VAL21]], [[VAL3]])
// CHECK-DAG: [[VAL23:%.+]] = "tf.Floor"([[VAL22]])
// CHECK-DAG: [[VAL24:%.+]] = "tf.Mul"([[VAL23]], [[VAL5]])
// CHECK-DAG: [[VAL25:%.+]] = "tf.Add"([[VAL24]], [[VAL17]])
// CHECK-DAG: [[VAL25:%.+]] = "tf.AddV2"([[VAL24]], [[VAL17]])
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {narrow_range = false, num_bits = 8 : i64} : (tensor<?x?xf32>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: return [[VAL25]]
@ -746,7 +746,7 @@ func @round(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-DAG: [[HALF:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>}
// CHECK-DAG: [[CMP:%.+]] = "tf.Less"([[SUB]], [[HALF]])
// CHECK-DAG: [[ONE:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
// CHECK-DAG: [[ADD:%.+]] = "tf.Add"([[ONE]], [[FLOOR]])
// CHECK-DAG: [[ADD:%.+]] = "tf.AddV2"([[FLOOR]], [[ONE]])
// CHECK-DAG: [[SELECT:%.+]] = "tf.Select"([[CMP]], [[FLOOR]], [[ADD]])
%0 = "tf.Round"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
@ -761,7 +761,7 @@ func @round_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-DAG: [[HALF:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>}
// CHECK-DAG: [[CMP:%.+]] = "tf.Less"([[SUB]], [[HALF]])
// CHECK-DAG: [[ONE:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
// CHECK-DAG: [[ADD:%.+]] = "tf.Add"([[ONE]], [[FLOOR]])
// CHECK-DAG: [[ADD:%.+]] = "tf.AddV2"([[FLOOR]], [[ONE]])
// CHECK-DAG: [[SELECT:%.+]] = "tf.Select"([[CMP]], [[FLOOR]], [[ADD]])
%0 = "tf.Round"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>

View File

@ -1,9 +1,10 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -tf-graph-as-function -o - | FileCheck %s
// Verify arg attributes are exported as device assignment for arg nodes.
// Verify arg/ret attributes are exported as device assignment for arg/retval
// nodes.
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 121 : i32}} {
func @main(%arg0: tensor<*xf32> {tf.device = "/CPU:0"}, %arg1: tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<2x4x6x8xi32>)
func @main(%arg0: tensor<*xf32> {tf.device = "/CPU:0"}, %arg1: tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<2x4x6x8xi32> {tf.device = "/CPU:1"})
attributes {tf.entry_function = {inputs = "args_0,args_1", outputs = "rets_0,rets_1"}} {
%0:2 = tf_executor.graph {
%1:3 = tf_executor.island wraps "tf.IdentityN"(%arg0, %arg1) {T = ["tfdtype$DT_FLOAT", "tfdtype$DT_INT32"], device = "", name = "identity"} : (tensor<*xf32>, tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<2x4x6x8xi32>)
@ -15,18 +16,39 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// CHECK: node {
// CHECK-NEXT: name: "args_0"
// CHECK-NEXT: op: "_Arg"
// CHECK: device: "/CPU:0"
// CHECK: attr {
// CHECK: key: "index"
// CHECK-NEXT: value {
// CHECK-NEXT: i: 0
// CHECK-NEXT: }
//
// CHECK: node {
// CHECK-NEXT: name: "args_1"
// CHECK-NOT: device: "/CPU:0"
// CHECK-NEXT: op: "_Arg"
// CHECK-NOT: device
// CHECK: attr {
// CHECK: key: "index"
// CHECK-NEXT: value {
// CHECK-NEXT: i: 1
//
// CHECK: node {
// CHECK: op: "IdentityN"
//
// CHECK: node {
// CHECK-NEXT: name: "rets_0"
// CHECK-NEXT: op: "_Retval"
// CHECK-NOT: device
// CHECK: attr {
// CHECK: key: "index"
// CHECK-NEXT: value {
// CHECK-NEXT: i: 0
//
// CHECK: node {
// CHECK-NEXT: name: "rets_1"
// CHECK-NEXT: op: "_Retval"
// CHECK: device: "/CPU:1"
// CHECK: attr {
// CHECK: key: "index"
// CHECK-NEXT: value {
// CHECK-NEXT: i: 1
// CHECK-NEXT: }

View File

@ -1,12 +1,13 @@
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
%0:2 = tf_executor.graph {
%outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
%outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
tf_executor.fetch %outputs_2#1, %outputs_4#1 : tensor<5xf32>, tensor<5xf32>
func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) {
%0:3 = tf_executor.graph {
%outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
%outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
%outputs_6:2, %control_7 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, shape_invariant} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("WhileWithOutputShapes")
tf_executor.fetch %outputs_2#1, %outputs_4#1, %outputs_6#1 : tensor<5xf32>, tensor<5xf32>, tensor<5xf32>
}
return %0#0, %0#1 : tensor<5xf32>, tensor<5xf32>
return %0#0, %0#1, %0#2 : tensor<5xf32>, tensor<5xf32>, tensor<5xf32>
}
func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
@ -36,6 +37,7 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
// CHECK-NOT: name:
// CHECK: op: "While"
// CHECK-NOT: is_stateless
// CHECK-NOT: shape_invariant
// CHECK: attr {
// CHECK: key: "output_shapes"
// CHECK: value {
@ -54,6 +56,7 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
// CHECK-NOT: name:
// CHECK: op: "StatelessWhile"
// CHECK-NOT: is_stateless
// CHECK-NOT: shape_invariant
// CHECK: attr {
// CHECK: key: "output_shapes"
// CHECK: value {
@ -67,3 +70,20 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
// CHECK: }
// CHECK: }
// CHECK: name: "WhileWithOutputShapes"
// CHECK-NOT: name:
// CHECK: op: "While"
// CHECK-NOT: is_stateless
// CHECK-NOT: shape_invariant
// CHECK: attr {
// CHECK: key: "output_shapes"
// CHECK: value {
// CHECK: list {
// CHECK: shape {
// CHECK: dim {
// CHECK: size: 5
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }

View File

@ -0,0 +1,153 @@
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-outside-compiled-to-host-launch | FILECHECK_OPTS="" FileCheck %s
// expected-error@+1 {{'module' op bad 'tf.devices' attribute at index 0, not a string}}
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = [1]} {
// Tests that a non-string entry in the `tf.devices` attribute results in an error.
func @invalid_device_attribute() -> tensor<?xi32> {
%0 = "tf_device.cluster"() ( {
%1 = "tf.A"() : () -> tensor<?xi32>
%2 = "tf.B"(%1) {_xla_outside_compilation = ""}: (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
return %0 : tensor<?xi32>
}
}
// -----
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// Tests that an empty `_xla_outside_compilation` attribute value results in an error.
func @empty_outside_compilation_attribute() -> tensor<?xi32> {
%0 = "tf_device.cluster"() ( {
%1 = "tf.A"() : () -> tensor<?xi32>
// expected-error@+1 {{'tf.B' op requires non empty '_xla_outside_compilation' string attribute}}
%2 = "tf.B"(%1) {_xla_outside_compilation = ""}: (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
return %0 : tensor<?xi32>
}
}
// -----
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// Tests that a TPU cluster with no outside compilation does not generate a launch op.
// CHECK-LABEL: func @no_outside_compilation
// CHECK-NOT: "tf_device.launch"
func @no_outside_compilation() -> tensor<?xi32> {
%0 = "tf_device.cluster"() ( {
%1 = "tf.A"() : () -> tensor<?xi32>
%2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// Tests the launch wrap of a single outside compiled cluster with no input or output dependencies.
// CHECK-LABEL: func @nodep_single_outside_compilation
func @nodep_single_outside_compilation() -> () {
// CHECK: "tf.A"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.B"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
"tf_device.cluster"() ( {
"tf.A"() : () -> ()
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
"tf.C"() : () -> ()
tf_device.return
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
return
}
// Tests the launch wrap of a single outside compiled cluster with data parallelism.
// CHECK-LABEL: func @single_outside_compilation_with_replicate
func @single_outside_compilation_with_replicate(%arg0: tensor<?xi32>) -> () {
// CHECK: "tf.A"
// CHECK: tf_device.replicate
// CHECK-NEXT: "tf_device.cluster"
// CHECK-NEXT: "tf.B"
// CHECK-NEXT: "tf_device.launch"
// CHECK-NEXT: "tf.C"
// CHECK-NOT: _xla_outside_compilation
// CHECK: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_HOST"
// CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
"tf_device.cluster"() ( {
"tf.B"() : () -> ()
"tf.C"(%ri_0) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
"tf.D"() : () -> ()
tf_device.return
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
tf_device.return
}
return
}
// Tests launch wrap of a single outside compiled cluster with input/output.
// CHECK-LABEL: func @single_outside_compilation_input_output
func @single_outside_compilation_input_output(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK-NEXT: %[[LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]])
// CHECK: tf_device.return %[[B_OUTPUT]]
// CHECK: "tf.C"(%[[LAUNCH_OUTPUT]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> tensor<?xi32>
%5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %5 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// Tests launch wrap of multiple outside compiled clusters with input/output.
// CHECK-LABEL: func @multiple_outside_compilation_input_output
func @multiple_outside_compilation_input_output(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK-NEXT: %[[LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]])
// CHECK: tf_device.return %[[B_OUTPUT]]
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[LAUNCH_OUTPUT]])
// CHECK-NEXT: %[[LAUNCH_OUTPUT2:[0-9]*]] = "tf_device.launch"
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[C_OUTPUT]])
// CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[D_OUTPUT]])
// CHECK: tf_device.return %[[E_OUTPUT]]
// CHECK: "tf.F"(%[[LAUNCH_OUTPUT2]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> tensor<?xi32>
%5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
%6 = "tf.D"(%5) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> tensor<?xi32>
%7 = "tf.E"(%6) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> tensor<?xi32>
%8 = "tf.F"(%7) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %8 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
}

View File

@ -1286,3 +1286,41 @@ func @while_body(%arg0: !tf_res) -> !tf_res {
%0 = "tf.Cast"(%arg0) : (!tf_res) -> !tf_res
return %0 : !tf_res
}
// -----
// Tests that passthrough tf.Cast ops are removed.
!tf_res_static = type tensor<!tf.resource<tensor<f32>>>
!tf_res_dynamic = type tensor<*x!tf.resource<tensor<f32>>>
// CHECK-LABEL: func @tpu_computation
func @tpu_computation(%arg0: !tf_res_static) {
"tf_device.cluster"() ( {
%0 = "tf.While"(%arg0) {body = @while_body, cond = @while_cond, is_stateless = false} : (!tf_res_static) -> !tf_res_dynamic
%1 = "tf.WhileRegion"(%arg0) ( {
^cond(%carg0: !tf_res_static):
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
"tf.Yield"(%2) : (tensor<i1>) -> ()
}, {
^body(%barg0: !tf_res_static):
// CHECK-NOT: tf.Cast
%2 = "tf.Cast"(%barg0) : (!tf_res_static) -> !tf_res_dynamic
"tf.Yield"(%2) : (!tf_res_dynamic) -> ()
}) {is_stateless = false} : (!tf_res_static) -> !tf_res_dynamic
tf_device.return
}) {} : () -> ()
return
}
func @while_cond(%arg0: !tf_res_static) -> tensor<i1> {
%0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
return %0 : tensor<i1>
}
// CHECK-LABEL: func @while_body
func @while_body(%arg0: !tf_res_static) -> !tf_res_dynamic {
// CHECK-NOT: tf.Cast
%0 = "tf.Cast"(%arg0) : (!tf_res_static) -> !tf_res_dynamic
return %0 : !tf_res_dynamic
}

View File

@ -68,6 +68,35 @@ func @main() -> tensor<i32> {
return %size_out : tensor<i32>
}
// -----
// Test inferring shape from the result type of gather.
// CHECK-LABEL: func @main
func @main() -> tensor<2x3xf32> {
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<*>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
%indices = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
%gather = "tf.TensorArrayGatherV3"(%ta#0, %indices, %ta#1) : (tensor<!tf.resource>, tensor<2xi32>, tensor<f32>) -> tensor<2x3xf32>
return %gather : tensor<2x3xf32>
}
// -----
// Test inferring shape from the element_shape attribute of gather.
// CHECK-LABEL: func @main
func @main() -> tensor<*xf32> {
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<*>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
%indices = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
%gather = "tf.TensorArrayGatherV3"(%ta#0, %indices, %ta#1) {element_shape = #tf.shape<3>} : (tensor<!tf.resource>, tensor<2xi32>, tensor<f32>) -> tensor<*xf32>
return %gather : tensor<*xf32>
}
// -----
// Test tensor array concat and split.

View File

@ -104,6 +104,7 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
pm.addPass(mlir::createInlinerPass());
pm.addPass(CreateTPUClusterCleanupAttributesPass());
pm.addPass(TFDevice::CreateResourceOpLiftingPass());
pm.addNestedPass<FuncOp>(createCSEPass());
pm.addPass(TFDevice::CreateMarkOpsForOutsideCompilationPass());
pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
pm.addPass(CreateTPUOutsideCompilationClusterPass());

View File

@ -158,6 +158,26 @@ def LogicalNotOfLess : Pat<(TF_LogicalNotOp (TF_LessOp $arg0, $arg1)),
def LogicalNotOfLessEqual : Pat<(TF_LogicalNotOp (TF_LessEqualOp $arg0, $arg1)),
(TF_GreaterOp $arg0, $arg1)>;
//===----------------------------------------------------------------------===//
// MatrixSetDiag op patterns.
//===----------------------------------------------------------------------===//
class GetI32Attr<int x>: NativeCodeCall<
"$_builder.getI32IntegerAttr(" # x # ")">;
class GetStrAttr<string x>: NativeCodeCall<
"$_builder.getStringAttr(\"" # x # "\")">;
def MatrixSetDiagToV3 : Pat<(TF_MatrixSetDiagOp $input, $diag),
(TF_MatrixSetDiagV3Op $input, $diag,
(TF_ConstOp (GetI32Attr<0>)),
(GetStrAttr<"RIGHT_LEFT">))>;
// The MatrixSetDiagV2 op implicitly uses LEFT_LEFT alignment.
def MatrixSetDiagV2ToV3 : Pat<(TF_MatrixSetDiagV2Op $input, $diag, $k),
(TF_MatrixSetDiagV3Op $input, $diag, $k,
(GetStrAttr<"LEFT_LEFT">))>;
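// Rough sketch of what these patterns produce (illustrative IR, assuming the
// V3 op's string attribute is named `align`):
//   %0 = "tf.MatrixSetDiag"(%input, %diag)
// becomes
//   %k = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
//   %0 = "tf.MatrixSetDiagV3"(%input, %diag, %k) {align = "RIGHT_LEFT"}
// while a V2 op keeps its existing `k` operand and gets align = "LEFT_LEFT".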
//===----------------------------------------------------------------------===//
// RealDiv op patterns.
//===----------------------------------------------------------------------===//

View File

@ -43,6 +43,8 @@ using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName;
constexpr const char *kDeviceAttr = "device";
constexpr const char *kTFDeviceAttr = "tf.device";
// TODO(donglin): Handle the case where the address of localhost is different
// from /job:localhost/replica:0/task:0.
constexpr const char *kLocalhost = "/job:localhost/replica:0/task:0";
constexpr const char *kErrorMessage =
"The operation that uses the operand is on a different host than the "
@ -53,8 +55,9 @@ constexpr const char *kErrorMessage =
std::string GetHost(llvm::StringRef device) {
ParsedName parsed_name;
DeviceNameUtils::ParseFullName(device.str(), &parsed_name);
return DeviceNameUtils::ParsedNameToString(
std::string result = DeviceNameUtils::ParsedNameToString(
DeviceNameUtils::AddressSpace(parsed_name));
return result.empty() ? kLocalhost : result;
}
std::string GetHost(Operation *op) {
@ -70,12 +73,9 @@ std::string GetHost(Operation *op) {
// 1) None of the job/replica/task is specified in the device name.
// 2) The job/replica/task in the device name are explicitly specified as
// /job:localhost/replica:0/task:0.
//
// TODO(dnglin): Handle the case where the address of localhost is different
// from /job:localhost/replica:0/task:0.
bool IsOnLocalHost(llvm::StringRef device) {
std::string host = GetHost(device);
return host.empty() || host == kLocalhost;
return host == kLocalhost;
}
// This structure contains the metadata of the per-host function. All operations

View File

@ -36,21 +36,15 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"
namespace mlir {
namespace TF {
namespace collection_ops_util {
Value CreateScalarConst(int value, OpBuilder builder, Location loc) {
tensorflow::Tensor scalar_tensor(tensorflow::DT_INT32, {});
scalar_tensor.scalar<tensorflow::int32>()() = value;
return builder.create<TF::ConstOp>(
loc, tensorflow::ConvertTensor(scalar_tensor, &builder).ValueOrDie());
Value CreateScalarConst(int32_t value, OpBuilder builder, Location loc) {
auto attr = DenseIntElementsAttr::get(
RankedTensorType::get({}, builder.getI32Type()), value);
return builder.create<TF::ConstOp>(loc, attr);
}
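// Rough usage sketch for CreateScalarConst above: CreateScalarConst(5, builder, loc)
// now builds
//   "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
// directly from a DenseIntElementsAttr, without going through tensorflow::Tensor.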
Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc,

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/framework/types.pb.h"
namespace mlir {
namespace TF {
@ -34,7 +33,7 @@ namespace collection_ops_util {
// shape [max_element_count, element_shape].
// Creates an i32 scalar tf.Const.
Value CreateScalarConst(int value, OpBuilder builder, Location loc);
Value CreateScalarConst(int32_t value, OpBuilder builder, Location loc);
// Creates an integer vector tf.Const.
Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc,

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
@ -38,6 +39,12 @@ class ConvertResultsBroadcastableShapeOp : public RewritePattern {
LogicalResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override;
private:
template <typename Op>
LogicalResult RewriteEqOp(Operation* op, PatternRewriter& rewriter) const;
LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter) const;
};
class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
@ -47,7 +54,27 @@ class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
if (!op->hasTrait<OpTrait::ResultsBroadcastableShape>()) return failure();
if (op->hasTrait<OpTrait::ResultsBroadcastableShape>())
return RewriteOp(op, rewriter);
// tf.Equal and tf.NotEqual ops only satisfy ResultsBroadcastableShape when
// incompatible_shape_error is `true` (which the verifier also checks).
if (succeeded(RewriteEqOp<TF::EqualOp>(op, rewriter))) return success();
if (succeeded(RewriteEqOp<TF::NotEqualOp>(op, rewriter))) return success();
return failure();
}
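// Rough illustration of the RewriteOp fold below (hypothetical IR, not from a
// test in this file): given a static result shape,
//   %b = "tf.BroadcastTo"(%x, %shape)
//   %r = "tf.AddV2"(%b, %y)
// can be rewritten to "tf.AddV2"(%x, %y); with the `changed` flag below, a
// broadcast feeding either operand is now folded in a single rewrite.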
template <typename Op>
LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp(
Operation* op, PatternRewriter& rewriter) const {
auto eq_op = llvm::dyn_cast_or_null<Op>(op);
if (eq_op && eq_op.incompatible_shape_error()) return RewriteOp(op, rewriter);
return failure();
}
LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
Operation* op, PatternRewriter& rewriter) const {
if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
return failure();
@ -56,6 +83,7 @@ LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
op->getResultTypes().front().dyn_cast_or_null<RankedTensorType>();
if (!result_type || !result_type.hasStaticShape()) return failure();
bool changed = false;
for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) {
// Check that the i'th operand is a broadcast.
auto broadcast = llvm::dyn_cast_or_null<TF::BroadcastToOp>(
@ -89,10 +117,9 @@ LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
// Update the operand of the op to be the operand of the broadcast.
rewriter.updateRootInPlace(
op, [&]() { op->getOpOperand(i).set(broadcast.input()); });
return success();
changed = true;
}
return failure();
return success(changed);
}
void BroadcastFoldPass::runOnFunction() {

View File

@ -112,8 +112,8 @@ LogicalResult ConvertIfOp(IfOp if_op) {
LogicalResult ConvertWhileOp(WhileOp while_op) {
auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>(
while_op.getLoc(), while_op.getResultTypes(), while_op.input(),
while_op.output_shapes(), while_op.parallel_iterations(),
while_op.is_stateless());
while_op.parallel_iterations(), while_op.is_stateless(),
while_op.shape_invariant());
CopyDeviceAndUnderscoredAttributes(while_op, while_region);
YieldOp cond_yield =

View File

@ -200,7 +200,16 @@ class FuseConv2DBiasAdd
// Performs a fusion of the following pattern(s), if possible:
// MatMulOp + BiasAdd + <Activation> -> _FusedMatMulOp
using FuseMatMulBiasAdd = FuseContractionWithBiasAdd<MatMulOp, _FusedMatMulOp>;
class FuseMatMulBiasAdd
: public FuseContractionWithBiasAdd<MatMulOp, _FusedMatMulOp> {
using FuseContractionWithBiasAdd<MatMulOp,
_FusedMatMulOp>::FuseContractionWithBiasAdd;
bool AreFuseCompatible(MatMulOp matmul, BiasAddOp bias_add,
PatternRewriter &rewriter) const override {
return matmul.T().isF32() || matmul.T().isBF16();
}
};
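// Rough illustration of AreFuseCompatible above: a MatMul with T = f32 or
// T = bf16 followed by BiasAdd is rewritten to _FusedMatMul, while e.g. an
// f16 MatMul + BiasAdd is now left unfused.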
void FusedKernelMatcherPass::runOnFunction() {
OwningRewritePatternList patterns;

View File

@ -21,15 +21,19 @@ limitations under the License.
#include <numeric>
#include <vector>
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
@ -44,6 +48,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/lib/math/math_util.h"
namespace mlir {
namespace TF {
@ -479,18 +484,16 @@ template <typename ReductionOp>
LogicalResult MatchBinaryReduceFunction(mlir::Region &function) {
Block &body = function.front();
if (body.getNumArguments() != 2) return failure();
if (body.getOperations().size() != 2) return failure();
ReductionOp reduce_op = dyn_cast<ReductionOp>(body.front());
if (!reduce_op) return failure();
if (reduce_op.lhs() != body.getArgument(0) ||
reduce_op.rhs() != body.getArgument(1))
return failure();
mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
if (!return_op) return failure();
if (return_op.getNumOperands() != 1 ||
return_op.results().front() != reduce_op)
if (return_op.getNumOperands() != 1) return failure();
ReductionOp reduce_op = dyn_cast_or_null<ReductionOp>(
return_op.getOperands().front().getDefiningOp());
if (!reduce_op) return failure();
if (reduce_op.lhs() != body.getArgument(0) ||
reduce_op.rhs() != body.getArgument(1))
return failure();
return success();
@ -654,6 +657,190 @@ class ConvertIotaOpToTfRange : public OpConversionPattern<mhlo::IotaOp> {
}
};
// Maps the following representations of AvgPool in MHLO into a tf.AvgPool{3D}
// operation when they cleanly map to 2D or 3D average pool with VALID or SAME
// padding:
// * div(reduce_sum_window(x), constant(sizeof(window)))
// * div(reduce_sum_window(x), reduce_sum_window(constant(1)))
class ConvertAvgPoolOp : public OpConversionPattern<mhlo::DivOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
mhlo::DivOp div_op, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const final {
auto rw =
dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.lhs().getDefiningOp());
if (!rw) return failure();
// Check that the reduce-window is a sum-reduce-window.
if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw.body())))
return failure();
// Check that this is a floating point reduce window with a rank of 4 or 5.
RankedTensorType rw_type = rw.getType().dyn_cast<RankedTensorType>();
if (!rw_type || !rw_type.getElementType().isa<FloatType>() ||
rw_type.getRank() <= 3 || rw_type.getRank() > 5)
return failure();
// Check that the Div op doesn't do broadcasting on the output of the reduce
// window.
if (div_op.getType() != rw.getType()) return failure();
// tf.avg_pool needs at least 3 dimensions (batch, spatial, channel).
const uint64_t rank = rw.window_dimensions().size();
if (rank <= 2) return failure();
// If the init value isn't zero then it can't be an average pool.
if (!isFloatZero(rw.init_value())) return failure();
llvm::SmallVector<int64_t, 5> window_strides;
if (rw.window_strides().hasValue()) {
window_strides.insert(window_strides.end(),
rw.window_strides()->getValues<int64_t>().begin(),
rw.window_strides()->getValues<int64_t>().end());
} else {
window_strides.resize(rank, 1);
}
llvm::SmallVector<int64_t, 10> padding;
if (rw.padding().hasValue()) {
padding.insert(padding.begin(),
rw.padding()->getValues<int64_t>().begin(),
rw.padding()->getValues<int64_t>().end());
} else {
padding.resize(2 * rank, 0);
}
// Check that we don't do any reduction along the batch (first) and channel
// (last) dimensions.
const uint64_t batch_dim = 0;
const uint64_t channel_dim = rank - 1;
if (rw.window_dimensions().getValue<int64_t>({batch_dim}) != 1 ||
rw.window_dimensions().getValue<int64_t>({channel_dim}) != 1 ||
window_strides[batch_dim] != 1 || window_strides[channel_dim] != 1 ||
padding[2 * batch_dim] != 0 || padding[2 * batch_dim + 1] != 0 ||
padding[2 * channel_dim] != 0 || padding[2 * channel_dim + 1] != 0)
return failure();
if (rw.window_dilations().hasValue() &&
!(rw.window_dilations()->isSplat() &&
rw.window_dilations()->getSplatValue<APInt>() == 1))
return failure();
if (rw.base_dilations().hasValue() &&
!(rw.base_dilations()->isSplat() &&
rw.base_dilations()->getSplatValue<APInt>() == 1))
return failure();
DenseFPElementsAttr divisor;
if (matchPattern(div_op.rhs(), m_Constant(&divisor))) {
// If the divisor is a constant, check that it matches the number of
// elements inside the window, which is what a VALID AvgPool requires.
if (!divisor.isSplat()) return failure();
int64_t window_size = 1;
for (int64_t w : rw.window_dimensions().getValues<int64_t>()) {
window_size *= w;
}
if (!divisor.getSplatValue<APFloat>().isExactlyValue(window_size))
return failure();
// Check that we have no padding.
if (!llvm::all_of(padding, [](int64_t i) { return i == 0; }))
return failure();
return replaceWithAvgPool(
div_op, rw.operand(),
llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
window_strides, "VALID", rewriter);
}
auto rw_rhs =
dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.rhs().getDefiningOp());
if (rw_rhs) {
// Check that RHS is a sum-reduce-window.
if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw_rhs.body())))
return failure();
// Check that the RHS is a reduce_window over a constant 1 input with 0 as
// the init value.
DenseFPElementsAttr rhs_input;
if (!isFloatZero(rw_rhs.init_value()) ||
!matchPattern(rw_rhs.operand(), m_Constant(&rhs_input)) ||
!rhs_input.isSplat() ||
!rhs_input.getSplatValue<APFloat>().isExactlyValue(1.0))
return failure();
// Check that the two reduce_window ops have the same window configuration.
if (rw.window_dimensions() != rw_rhs.window_dimensions() ||
rw.window_strides() != rw_rhs.window_strides() ||
rw.window_dilations() != rw_rhs.window_dilations() ||
rw.base_dilations() != rw_rhs.base_dilations() ||
rw.padding() != rw_rhs.padding())
return failure();
if (llvm::all_of(padding, [](int64_t i) { return i == 0; }))
return replaceWithAvgPool(
div_op, rw.operand(),
llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
window_strides, "VALID", rewriter);
RankedTensorType input_type =
rw.operand().getType().dyn_cast<RankedTensorType>();
RankedTensorType output_type = rw.getType().dyn_cast<RankedTensorType>();
if (!input_type || !output_type) return failure();
// Check that the individual padding values correspond to SAME padding in
// TensorFlow.
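// Worked example (matching the convert_avgpool_same test earlier in this
// diff): with input size 16, window 3, stride 2 and output size 8,
//   padding_size = (8 - 1) * 2 + 3 - 16 = 1,
// so SAME padding requires pad_low = floor(1/2) = 0 and
// pad_high = ceil(1/2) = 1, i.e. the [0, 1] entries in that test.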
for (uint64_t i = 1; i < rank - 1; ++i) {
int64_t padding_size =
(output_type.getShape()[i] - 1) * window_strides[i] +
rw.window_dimensions().getValue<int64_t>({i}) -
input_type.getShape()[i];
if (padding[2 * i] !=
tensorflow::MathUtil::FloorOfRatio(padding_size, int64_t(2)) ||
padding[2 * i + 1] !=
tensorflow::MathUtil::CeilOfRatio(padding_size, int64_t(2)))
return failure();
}
return replaceWithAvgPool(
div_op, rw.operand(),
llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
window_strides, "SAME", rewriter);
}
return failure();
}
private:
bool isFloatZero(Value value) const {
DenseFPElementsAttr initial_value;
return matchPattern(value, m_Constant(&initial_value)) &&
initial_value.getNumElements() == 1 &&
initial_value.getValue<APFloat>({}).isZero();
}
LogicalResult replaceWithAvgPool(mhlo::DivOp op, Value input,
llvm::ArrayRef<int64_t> ksizes,
llvm::ArrayRef<int64_t> kstrides,
llvm::StringRef padding,
ConversionPatternRewriter &rewriter) const {
if (ksizes.size() == 4) {
rewriter.replaceOpWithNewOp<AvgPoolOp>(
op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
rewriter.getStringAttr("NHWC"));
return success();
} else if (ksizes.size() == 5) {
rewriter.replaceOpWithNewOp<AvgPool3DOp>(
op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
rewriter.getStringAttr("NDHWC"));
return success();
}
return failure();
}
};
class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TF::TensorFlowDialect>();
@ -729,6 +916,25 @@ Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) {
return reshape.getResult();
}
// Converts mhlo.pad to tf.PadV2
Value ConvertPadOp(PatternRewriter &rewriter, Operation *old_op) {
auto pad_op = cast<mhlo::PadOp>(old_op);
mlir::Location loc = pad_op.getLoc();
llvm::SmallVector<APInt, 8> padding;
for (auto p : llvm::zip(pad_op.edge_padding_low().getValues<APInt>(),
pad_op.edge_padding_high().getValues<APInt>())) {
padding.push_back(std::get<0>(p));
padding.push_back(std::get<1>(p));
}
auto attr_type = RankedTensorType::get({pad_op.edge_padding_low().size(), 2},
rewriter.getI64Type());
auto padding_attr = DenseIntElementsAttr::get(attr_type, padding);
auto padding_op = rewriter.create<ConstantOp>(loc, attr_type, padding_attr);
return rewriter.create<PadV2Op>(loc, pad_op.getType(), pad_op.operand(),
padding_op, pad_op.padding_value());
}
// Returns true if broadcast_dimensions obey the TensorFlow convention, i.e.
// new dimensions are added as a prefix.
bool IsTFStyleBroadcast(DenseIntElementsAttr broadcast_dimensions,
@ -794,10 +1000,10 @@ static PassRegistration<LegalizeHloToTf> pass(
void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList *patterns,
MLIRContext *context) {
patterns->insert<ConvertAvgPoolOp, ConvertConvOp, ConvertSliceOp,
ConvertReduceOpToTfMax, ConvertReduceOpToTfMin,
ConvertReduceOpToTfSum, ConvertIotaOpToTfRange>(context);
populateWithGenerated(context, *patterns);
patterns->insert<ConvertConvOp, ConvertSliceOp, ConvertReduceOpToTfMax,
ConvertReduceOpToTfMin, ConvertReduceOpToTfSum,
ConvertIotaOpToTfRange>(context);
}
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass() {

View File

@ -216,3 +216,12 @@ def : Pat<(HLO_DotGeneralOp:$old_value AnyStaticShapeTensor:$lhs,
AnyStaticShapeTensor:$rhs, $dot_dimension_numbers,
$precision_config),
(ConvertDotGeneralOp $old_value)>;
def IsZero : Constraint<CPred<
"$0.isSplat() && $0.getSplatValue<APInt>() == 0">>;
def ConvertPadOp : NativeCodeCall<
"ConvertPadOp($_builder, $0.getDefiningOp())">;
def : Pat<(HLO_PadOp:$old_value $input, $pad_value, $pad_low, $pad_high,
$pad_interior),
(ConvertPadOp $old_value),
[(IsZero $pad_interior)]>;
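// For instance (mirroring the convert_pad test earlier in this diff), an
// mhlo.pad with edge_padding_low = [1, 0], edge_padding_high = [2, 3] and
// zero interior padding is rewritten by ConvertPadOp into a tf.PadV2 whose
// paddings constant is [[1, 2], [0, 3]].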

View File

@ -332,12 +332,13 @@ class LowerDynamicStitchOp : public RewritePattern {
class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
public:
explicit ConvertFakeQuantWithMinMaxVarsOp(MLIRContext *context)
: RewritePattern(FakeQuantWithMinMaxVarsOp::getOperationName(),
{SubOp::getOperationName(), ConstOp::getOperationName(),
MulOp::getOperationName(), FloorOp::getOperationName(),
ClipByValueOp::getOperationName(),
DivOp::getOperationName(), RoundOp::getOperationName()},
1, context) {}
: RewritePattern(
FakeQuantWithMinMaxVarsOp::getOperationName(),
{AddV2Op::getOperationName(), SubOp::getOperationName(),
ConstOp::getOperationName(), MulOp::getOperationName(),
FloorOp::getOperationName(), ClipByValueOp::getOperationName(),
DivOp::getOperationName(), RoundOp::getOperationName()},
1, context) {}
LogicalResult matchAndRewrite(Operation *src_op,
PatternRewriter &rewriter) const override {
@ -419,8 +420,8 @@ class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
op.getLoc(),
DenseElementsAttr::get(scalar_ty, ConvertToAPFloat(0.5, element_ty)));
quantized_input = rewriter.create<AddOp>(op.getLoc(), input_ty,
quantized_input, half_val);
quantized_input = rewriter.create<AddV2Op>(op.getLoc(), input_ty,
quantized_input, half_val);
quantized_input = rewriter.create<FloorOp>(op.getLoc(), quantized_input);
@ -428,8 +429,8 @@ class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
Value output = rewriter.create<MulOp>(op.getLoc(), input_ty,
quantized_input, quant_to_float);
output =
rewriter.create<AddOp>(op.getLoc(), input_ty, output, nudged_float_min);
output = rewriter.create<AddV2Op>(op.getLoc(), input_ty, output,
nudged_float_min);
rewriter.replaceOp(op, {output});
return success();
@ -811,7 +812,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
CastOp::getOperationName(),
ConstOp::getOperationName(),
ConcatV2Op::getOperationName(),
AddOp::getOperationName(),
AddV2Op::getOperationName(),
PadOp::getOperationName(),
SplitOp::getOperationName(),
UnpackOp::getOperationName(),
@ -907,8 +908,8 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
auto paddings_split = rewriter.create<UnpackOp>(
loc, TypeRange({paddings_sum_type, paddings_sum_type}), full_paddings,
rewriter.getI64IntegerAttr(1));
auto paddings_sum = rewriter.create<AddOp>(loc, paddings_split.getResult(0),
paddings_split.getResult(1));
auto paddings_sum = rewriter.create<AddV2Op>(
loc, paddings_split.getResult(0), paddings_split.getResult(1));
auto input_shape_tensor = rewriter.create<ConstOp>(
loc,
@ -918,7 +919,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
// padded_shape_tensor is the shape of padded.
auto padded_shape_tensor =
rewriter.create<AddOp>(loc, paddings_sum, input_shape_tensor);
rewriter.create<AddV2Op>(loc, paddings_sum, input_shape_tensor);
auto zero_i32 = rewriter.create<ConstOp>(
loc, GetScalarOfType(rewriter.getIntegerType(32), 0));

View File

@ -237,7 +237,7 @@ def : Pat<(TF_RoundOp:$res TF_FloatTensor:$input),
(TF_SubOp $input, (TF_FloorOp:$floor $input)),
(TF_ConstOp (GetScalarOfFloatType<"0.5"> $input))),
$floor,
(TF_AddOp
(TF_AddV2Op
(TF_ConstOp (GetScalarOfType<1> $input)), $floor))>;
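// Worked example of the pattern above (a sketch, not from a test): for an
// input of 2.7, floor = 2.0 and 2.7 - 2.0 = 0.7 >= 0.5, so the select picks
// AddV2(1.0, 2.0) = 3.0; for 2.3 the fractional part 0.3 < 0.5 and the
// result is the floor, 2.0.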

Some files were not shown because too many files have changed in this diff.