Merge remote-tracking branch origin/upstream/master
commit 477e3022a8
ISSUES.md (10 lines changed)
@@ -1,7 +1,9 @@
If you open a GitHub Issue, here is our policy: 1. It must be a bug/performance
issue or a feature request or a build issue or a documentation issue (for small
doc fixes please send a PR instead). 2. Make sure the Issue Template is filled
out. 3. The issue should be related to the repo it is created in.
If you open a GitHub Issue, here is our policy:

1. It must be a bug/performance issue or a feature request or a build issue or
   a documentation issue (for small doc fixes please send a PR instead).
1. Make sure the Issue Template is filled out.
1. The issue should be related to the repo it is created in.

**Here's why we have this policy:** We want to focus on the work that benefits
the whole community, e.g., fixing bugs and adding features. Individual support
RELEASE.md (15 lines changed)
@@ -49,11 +49,13 @@
    * Added int16x8 support for ABS, REDUCE_MAX and REDUCE_MIN operators.
    * Added support for saved model's session initializer through
      `TFLiteConverter.from_saved_model`.
    * Added dynamic range quantization support for the BatchMatMul op.
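The TFLite converter items above are exercised from Python. A minimal, hedged sketch of dynamic range quantization through `TFLiteConverter.from_saved_model` (the SavedModel paths are illustrative):

```python
import tensorflow as tf

# Path is illustrative; any SavedModel directory works. from_saved_model also
# picks up the saved model's session initializer, per the note above.
converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/my_saved_model")

# Optimize.DEFAULT enables dynamic range quantization, which now also covers
# BatchMatMul ops in the graph.
converter.optimizations = [tf.lite.Optimize.DEFAULT]

tflite_model = converter.convert()
with open("/tmp/model.tflite", "wb") as f:
    f.write(tflite_model)
```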

* TF Core:
    * Corrected higher-order gradients of control flow constructs (`tf.cond`,
      `tf.while_loop`, and compositions like `tf.foldl`) computed with
      `tf.GradientTape` inside a `tf.function`.
    * Changed the default step size in `gradient_checker_v2.compute_gradients` to be exactly representable as a binary floating point number. This avoids polluting gradient approximations needlessly, which in some cases leads to false negatives in op gradient tests.
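The `gradient_checker_v2` change mainly affects op gradient tests. A small, hedged sketch of the public wrapper, `tf.test.compute_gradient`, which relies on that default step size (the tolerance here is illustrative):

```python
import numpy as np
import tensorflow as tf

x = tf.constant([0.5, 1.5, 2.5])

# Returns (theoretical, numerical) Jacobians; the numerical one is built from
# finite differences using the default step size discussed above.
theoretical, numerical = tf.test.compute_gradient(tf.square, [x])
np.testing.assert_allclose(theoretical[0], numerical[0], atol=1e-3)
```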

* `tf.summary`:
    * New `tf.summary.graph` allows manual write of TensorFlow graph
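The `tf.summary.graph` note is truncated by the hunk boundary above. A minimal, hedged sketch of writing a traced graph manually (the logdir path is illustrative):

```python
import tensorflow as tf

writer = tf.summary.create_file_writer("/tmp/logdir")  # illustrative path

@tf.function
def square(x):
    return x * x

# Trace the function once to obtain a concrete graph, then write it explicitly.
graph = square.get_concrete_function(tf.constant(2.0)).graph
with writer.as_default():
    tf.summary.graph(graph)
```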
@@ -65,6 +67,19 @@
      supported MSVC version to 16.4 (current: 16.8).
    * See: https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion

* TensorRT
    * Removed the deprecated `session_config` parameter for the TF1-TRT
      converter `TrtGraphConverter`. Previously, we issued a warning when the
      value of the parameter is not None.
    * The TF2-TRT converter `TrtGraphConverterV2` takes an object of class
      TrtConversionParams as a parameter. Removed three deprecated fields from
      this class: `rewriter_config_template`, `is_dynamic_op`, and
      `max_batch_size`. Previously, we issued a warning when the value of
      `rewriter_config_template` is not None. We issued an error when the
      value of `is_dynamic_op` is not True. We didn't use the value for
      `max_batch_size` for building TensorRT engines.
    * Issue a warning when function get_tensorrt_rewriter_config is used.
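A hedged sketch of the TF2-TRT path described above, passing a `TrtConversionParams` to `TrtGraphConverterV2`. It assumes the remaining `TrtConversionParams` fields have keyword defaults after this change; the SavedModel paths and precision mode are illustrative:

```python
from tensorflow.python.compiler.tensorrt import trt_convert as trt

# Only the surviving fields are set; the removed rewriter_config_template,
# is_dynamic_op, and max_batch_size fields no longer exist on this class.
params = trt.TrtConversionParams(precision_mode=trt.TrtPrecisionMode.FP16)

converter = trt.TrtGraphConverterV2(
    input_saved_model_dir="/tmp/my_saved_model",  # illustrative path
    conversion_params=params)
converter.convert()
converter.save("/tmp/my_trt_saved_model")
```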

## Thanks to our Contributors

This release contains contributions from many people at Google, as well as:
configure.py (10 lines changed)
@@ -55,16 +55,16 @@ NCCL_LIB_PATHS = [
|
||||
|
||||
# List of files to configure when building Bazel on Apple platforms.
|
||||
APPLE_BAZEL_FILES = [
|
||||
'tensorflow/lite/experimental/ios/BUILD',
|
||||
'tensorflow/lite/experimental/objc/BUILD',
|
||||
'tensorflow/lite/experimental/swift/BUILD',
|
||||
'tensorflow/lite/ios/BUILD',
|
||||
'tensorflow/lite/objc/BUILD',
|
||||
'tensorflow/lite/swift/BUILD',
|
||||
'tensorflow/lite/tools/benchmark/experimental/ios/BUILD'
|
||||
]
|
||||
|
||||
# List of files to move when building for iOS.
|
||||
IOS_FILES = [
|
||||
'tensorflow/lite/experimental/objc/TensorFlowLiteObjC.podspec',
|
||||
'tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec',
|
||||
'tensorflow/lite/objc/TensorFlowLiteObjC.podspec',
|
||||
'tensorflow/lite/swift/TensorFlowLiteSwift.podspec',
|
||||
]
|
||||
|
||||
|
||||
|
@ -72,6 +72,14 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Config setting that disables the default logger, only logging
|
||||
# to registered TFLogSinks
|
||||
config_setting(
|
||||
name = "no_default_logger",
|
||||
define_values = {"no_default_logger": "true"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Config setting for determining if we are building for Android.
|
||||
config_setting(
|
||||
name = "android",
|
||||
@ -588,9 +596,11 @@ config_setting(
|
||||
# DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST!
|
||||
# Instead, please use public APIs or public build rules TF provides.
|
||||
# If you need functionality that is not exposed, we will work with you to expand our public APIs.
|
||||
# TODO(b/173549186): Move Google-internal TF code out of learning/brain
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = [
|
||||
"//learning/brain/mlir/...",
|
||||
"//learning/lib/ami/simple_ml/...",
|
||||
"//tensorflow/...",
|
||||
],
|
||||
@ -730,6 +740,7 @@ tf_cc_shared_object(
|
||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
|
||||
"//tensorflow/c:kernels_hdrs",
|
||||
"//tensorflow/c:logging",
|
||||
"//tensorflow/c:ops_hdrs",
|
||||
"//tensorflow/cc/saved_model:loader_lite_impl",
|
||||
"//tensorflow/core/common_runtime:core_cpu_impl",
|
||||
|
@ -522,6 +522,7 @@ cc_library(
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
":tf_tensor",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
|
||||
],
|
||||
)
|
||||
|
||||
@ -542,13 +543,17 @@ tf_cuda_library(
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
":c_api_internal",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":c_api_internal",
|
||||
":tf_tensor",
|
||||
"//tensorflow/stream_executor:stream",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor",
|
||||
"//tensorflow/c/experimental/stream_executor:stream_executor_internal",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
@ -51,6 +51,7 @@ tf_cuda_library(
|
||||
":immediate_execution_context",
|
||||
":immediate_execution_operation",
|
||||
":immediate_execution_tensor_handle",
|
||||
":immediate_execution_distributed_manager",
|
||||
":abstract_tensor_handle",
|
||||
":tfe_context_internal",
|
||||
":tfe_cancellation_manager_internal",
|
||||
@ -70,6 +71,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:context_distributed_manager",
|
||||
"//tensorflow/core/common_runtime/eager:core",
|
||||
"//tensorflow/core/common_runtime/eager:eager_executor",
|
||||
"//tensorflow/core/common_runtime/eager:execute",
|
||||
@ -119,6 +121,7 @@ filegroup(
|
||||
"gradients.h",
|
||||
"gradients_internal.h",
|
||||
"immediate_execution_context.h",
|
||||
"immediate_execution_distributed_manager.h",
|
||||
"immediate_execution_operation.h",
|
||||
"immediate_execution_tensor_handle.h",
|
||||
"tape.h",
|
||||
@ -298,7 +301,7 @@ tf_cuda_cc_test(
|
||||
],
|
||||
args = ["--heap_check=local"],
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags(),
|
||||
tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156
|
||||
deps = [
|
||||
":c_api_experimental",
|
||||
":c_api_unified_internal",
|
||||
@ -466,6 +469,7 @@ tf_cuda_cc_test(
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags() + [
|
||||
"nomac",
|
||||
"no_cuda_asan", # b/173825513
|
||||
],
|
||||
deps = [
|
||||
":abstract_tensor_handle",
|
||||
@ -584,6 +588,19 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "immediate_execution_distributed_manager",
|
||||
hdrs = ["immediate_execution_distributed_manager.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "immediate_execution_context",
|
||||
hdrs = ["immediate_execution_context.h"],
|
||||
@ -592,12 +609,14 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":abstract_context",
|
||||
":immediate_execution_distributed_manager",
|
||||
":immediate_execution_operation",
|
||||
":immediate_execution_tensor_handle",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
|
@ -21,16 +21,11 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
|
||||
// clang-format off
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
// clang-format on
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
@ -39,59 +34,40 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
|
||||
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
||||
#endif
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/device_filters.pb.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
#include "tensorflow/core/common_runtime/copy_tensor.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/execute.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
|
||||
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/remote_device.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/platform/blocking_counter.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/notification.h"
|
||||
#include "tensorflow/core/platform/random.h"
|
||||
#include "tensorflow/core/platform/refcount.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
// "tensorflow/core/platform/platform.h" must be included first before using
|
||||
// PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
|
||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
|
||||
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
||||
#include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed.h"
|
||||
#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
namespace {
|
||||
@ -100,611 +76,6 @@ string DeviceName(const tensorflow::Device* d) {
|
||||
return (d == nullptr) ? "cpu:0" : d->name();
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
|
||||
const tensorflow::ServerDef& server_def) {
|
||||
if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
|
||||
return false;
|
||||
}
|
||||
return server_def.default_session_config().SerializeAsString() ==
|
||||
context->session_options().config.SerializeAsString();
|
||||
}
|
||||
|
||||
tensorflow::Status AddRemoteDevicesToMgr(
|
||||
const std::vector<string>& added_remote_workers,
|
||||
tensorflow::WorkerCacheInterface* worker_cache,
|
||||
tensorflow::DynamicDeviceMgr* remote_device_mgr) {
|
||||
std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
|
||||
tensorflow::mutex remote_devices_mu;
|
||||
int num_added_workers = added_remote_workers.size();
|
||||
tensorflow::BlockingCounter counter(num_added_workers);
|
||||
std::vector<tensorflow::Status> statuses(num_added_workers);
|
||||
for (int i = 0; i < num_added_workers; i++) {
|
||||
tensorflow::NewRemoteDevices(
|
||||
tensorflow::Env::Default(), worker_cache, added_remote_workers[i],
|
||||
[i, &statuses, &counter, &remote_devices, &remote_devices_mu](
|
||||
const tensorflow::Status& s,
|
||||
std::vector<tensorflow::Device*>* devices) {
|
||||
statuses[i] = s;
|
||||
if (s.ok()) {
|
||||
tensorflow::mutex_lock l(remote_devices_mu);
|
||||
for (tensorflow::Device* d : *devices) {
|
||||
remote_devices.emplace_back(d);
|
||||
}
|
||||
}
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
for (int i = 0; i < num_added_workers; i++) {
|
||||
TF_RETURN_IF_ERROR(statuses[i]);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices)));
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status GetAllRemoteDevices(
|
||||
const std::vector<string>& remote_workers,
|
||||
tensorflow::WorkerCacheInterface* worker_cache,
|
||||
std::unique_ptr<tensorflow::DynamicDeviceMgr>* device_mgr) {
|
||||
auto remote_device_mgr = absl::make_unique<tensorflow::DynamicDeviceMgr>();
|
||||
TF_RETURN_IF_ERROR(AddRemoteDevicesToMgr(remote_workers, worker_cache,
|
||||
remote_device_mgr.get()));
|
||||
*device_mgr = std::move(remote_device_mgr);
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status RemoveRemoteDevicesFromMgr(
|
||||
const std::vector<string>& removed_remote_workers,
|
||||
tensorflow::DynamicDeviceMgr* remote_device_mgr) {
|
||||
const std::vector<tensorflow::Device*> remote_devices =
|
||||
(remote_device_mgr->ListDevices());
|
||||
std::vector<tensorflow::Device*> devices_to_remove;
|
||||
for (tensorflow::Device* d : remote_devices) {
|
||||
for (const string& remote_worker : removed_remote_workers) {
|
||||
if (tensorflow::DeviceNameUtils::IsSameAddressSpace(remote_worker,
|
||||
d->name())) {
|
||||
devices_to_remove.emplace_back(d);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(remote_device_mgr->RemoveDevices(devices_to_remove));
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status ListRemoteWorkers(tensorflow::ServerInterface* server,
|
||||
const string& local_worker,
|
||||
std::vector<string>* remote_workers) {
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
dynamic_cast<tensorflow::GrpcServer*>(server);
|
||||
if (grpc_server == nullptr) {
|
||||
return tensorflow::errors::Internal(
|
||||
"Currently, TFE_NewContext only supports tensorflow::GrpcServer.");
|
||||
}
|
||||
grpc_server->master_env()->worker_cache->ListWorkers(remote_workers);
|
||||
remote_workers->erase(
|
||||
std::remove(remote_workers->begin(), remote_workers->end(), local_worker),
|
||||
remote_workers->end());
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
void DifferentiateWorkerLists(const std::vector<string>* current_list,
|
||||
const std::vector<string>* new_list,
|
||||
std::vector<string>* added,
|
||||
std::vector<string>* removed,
|
||||
std::vector<string>* existing) {
|
||||
// Get STL set_difference and set_intersection with one list traversal.
|
||||
// Similar to the set_difference library function, the input lists
|
||||
// (`current_list` and `new_list`) must be sorted before calling the function.
|
||||
added->resize(new_list->size());
|
||||
removed->resize(current_list->size());
|
||||
existing->resize(current_list->size());
|
||||
std::vector<string>::const_iterator curr_it = current_list->begin();
|
||||
std::vector<string>::const_iterator new_it = new_list->begin();
|
||||
std::vector<string>::iterator added_it = added->begin();
|
||||
std::vector<string>::iterator removed_it = removed->begin();
|
||||
std::vector<string>::iterator existing_it = existing->begin();
|
||||
while (curr_it != current_list->end() && new_it != new_list->end()) {
|
||||
if (*curr_it < *new_it) {
|
||||
*removed_it++ = *curr_it++;
|
||||
} else if (*curr_it > *new_it) {
|
||||
*added_it++ = *new_it++;
|
||||
} else {
|
||||
*existing_it++ = *curr_it++;
|
||||
new_it++;
|
||||
}
|
||||
}
|
||||
removed_it = std::copy(curr_it, current_list->end(), removed_it);
|
||||
added_it = std::copy(new_it, new_list->end(), added_it);
|
||||
added->resize(added_it - added->begin());
|
||||
removed->resize(removed_it - removed->begin());
|
||||
existing->resize(existing_it - existing->begin());
|
||||
}
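The comment above describes a single-pass set difference/intersection over two sorted worker lists. A small, hedged Python illustration of the same idea (not the C++ implementation itself; worker names are illustrative):

```python
# Partition two sorted worker lists into added / removed / existing in one
# traversal, mirroring what DifferentiateWorkerLists does above.
def differentiate_worker_lists(current, new):
    added, removed, existing = [], [], []
    i = j = 0
    while i < len(current) and j < len(new):
        if current[i] < new[j]:
            removed.append(current[i]); i += 1
        elif current[i] > new[j]:
            added.append(new[j]); j += 1
        else:
            existing.append(current[i]); i += 1; j += 1
    removed.extend(current[i:])  # remaining current workers were removed
    added.extend(new[j:])        # remaining new workers were added
    return added, removed, existing

# Example: task 0 was removed, task 2 was added, task 1 is unchanged.
print(differentiate_worker_lists(
    ["/job:worker/task:0", "/job:worker/task:1"],
    ["/job:worker/task:1", "/job:worker/task:2"]))
```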
|
||||
|
||||
tensorflow::Status GetReplacedFromExistingWorkers(
|
||||
const std::vector<string>* existing_workers, tensorflow::uint64 context_id,
|
||||
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
|
||||
tensorflow::eager::EagerClientCache* client_cache,
|
||||
std::vector<string>* replaced_workers) {
|
||||
tensorflow::BlockingCounter counter(existing_workers->size());
|
||||
std::vector<tensorflow::Status> statuses(existing_workers->size());
|
||||
tensorflow::eager::KeepAliveRequest request;
|
||||
request.set_context_id(context_id);
|
||||
std::vector<tensorflow::eager::KeepAliveResponse> responses(
|
||||
existing_workers->size());
|
||||
for (int i = 0; i < existing_workers->size(); i++) {
|
||||
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
|
||||
statuses[i] =
|
||||
client_cache->GetClient(existing_workers->at(i), &eager_client);
|
||||
if (!statuses[i].ok()) {
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
eager_client->KeepAliveAsync(
|
||||
&request, &responses[i],
|
||||
[i, &statuses, &counter](const tensorflow::Status& s) {
|
||||
statuses[i] = s;
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
for (int i = 0; i < existing_workers->size(); i++) {
|
||||
// If the RPC fails (indicating that the requested ID doesn't exist on
|
||||
// remote), or the returned view ID is not equal to the local one
|
||||
// (indicating that the remote worker has a stale view of cluster), treat
|
||||
// the worker as replaced.
|
||||
if (!statuses[i].ok() ||
|
||||
responses[i].context_view_id() != context_view_id) {
|
||||
replaced_workers->emplace_back(existing_workers->at(i));
|
||||
}
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status CreateRemoteContexts(
|
||||
TFE_Context* ctx, const std::vector<string>& remote_workers,
|
||||
tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
|
||||
int keep_alive_secs, const tensorflow::ServerDef& server_def,
|
||||
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
|
||||
const bool lazy_copy_remote_function_inputs,
|
||||
const tensorflow::eager::CreateContextRequest& base_request) {
|
||||
int num_remote_workers = remote_workers.size();
|
||||
tensorflow::BlockingCounter counter(num_remote_workers);
|
||||
std::vector<tensorflow::Status> statuses(num_remote_workers);
|
||||
for (int i = 0; i < num_remote_workers; i++) {
|
||||
const string& remote_worker = remote_workers[i];
|
||||
tensorflow::DeviceNameUtils::ParsedName parsed_name;
|
||||
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
|
||||
&parsed_name)) {
|
||||
statuses[i] = tensorflow::errors::InvalidArgument(
|
||||
"Unable to parse ", remote_worker, " as a device name");
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
|
||||
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
|
||||
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
|
||||
if (eager_client == nullptr) {
|
||||
statuses[i] = tensorflow::errors::Internal(
|
||||
"Cannot find a client for the given target:", remote_worker);
|
||||
}
|
||||
if (!statuses[i].ok()) {
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
|
||||
tensorflow::eager::CreateContextRequest request;
|
||||
tensorflow::eager::CreateContextResponse* response =
|
||||
new tensorflow::eager::CreateContextResponse();
|
||||
request.set_context_id(context_id);
|
||||
request.set_context_view_id(context_view_id);
|
||||
*request.mutable_server_def() = server_def;
|
||||
request.mutable_server_def()->set_job_name(parsed_name.job);
|
||||
request.mutable_server_def()->set_task_index(parsed_name.task);
|
||||
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
|
||||
server_def.default_session_config());
|
||||
|
||||
std::vector<bool> filtered_device_mask;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->FilterDevicesForRemoteWorkers(
|
||||
remote_worker, base_request.cluster_device_attributes(),
|
||||
&filtered_device_mask);
|
||||
DCHECK_EQ(filtered_device_mask.size(),
|
||||
base_request.cluster_device_attributes_size());
|
||||
for (int i = 0; i < filtered_device_mask.size(); i++) {
|
||||
if (filtered_device_mask[i]) {
|
||||
const auto& da = base_request.cluster_device_attributes(i);
|
||||
*request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
}
|
||||
request.set_async(async);
|
||||
request.set_keep_alive_secs(keep_alive_secs);
|
||||
request.set_lazy_copy_remote_function_inputs(
|
||||
lazy_copy_remote_function_inputs);
|
||||
|
||||
eager_client->CreateContextAsync(
|
||||
&request, response,
|
||||
[i, &statuses, &counter, response](const tensorflow::Status& s) {
|
||||
statuses[i] = s;
|
||||
delete response;
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
tensorflow::StatusGroup sg;
|
||||
for (int i = 0; i < num_remote_workers; i++) {
|
||||
if (TF_PREDICT_FALSE(!statuses[i].ok())) {
|
||||
sg.Update(statuses[i]);
|
||||
}
|
||||
}
|
||||
return sg.as_summary_status();
|
||||
}
|
||||
|
||||
tensorflow::Status UpdateRemoteContexts(
|
||||
TFE_Context* ctx, const std::vector<string>& remote_workers,
|
||||
const std::vector<string>& added_workers,
|
||||
const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
|
||||
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
|
||||
tensorflow::eager::EagerClientCache* remote_eager_workers,
|
||||
const tensorflow::eager::CreateContextRequest& base_request) {
|
||||
int num_remote_workers = remote_workers.size();
|
||||
tensorflow::BlockingCounter counter(num_remote_workers);
|
||||
std::vector<tensorflow::Status> statuses(num_remote_workers);
|
||||
|
||||
int cluster_device_count = base_request.cluster_device_attributes_size();
|
||||
std::unordered_set<string> added_or_removed(added_workers.begin(),
|
||||
added_workers.end());
|
||||
std::copy(removed_workers.begin(), removed_workers.end(),
|
||||
std::inserter(added_or_removed, added_or_removed.end()));
|
||||
// Whether each device is in the updated (added or removed) workers
|
||||
std::vector<bool> device_added_or_removed(cluster_device_count);
|
||||
for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
|
||||
const auto& da = base_request.cluster_device_attributes().at(i);
|
||||
tensorflow::DeviceNameUtils::ParsedName pn;
|
||||
tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
|
||||
string task_name;
|
||||
tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
|
||||
if (added_or_removed.find(task_name) != added_or_removed.end()) {
|
||||
device_added_or_removed[i] = true;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_remote_workers; i++) {
|
||||
const string& remote_worker = remote_workers[i];
|
||||
tensorflow::DeviceNameUtils::ParsedName parsed_name;
|
||||
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
|
||||
&parsed_name)) {
|
||||
statuses[i] = tensorflow::errors::InvalidArgument(
|
||||
"Unable to parse ", remote_worker, " as a device name");
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
|
||||
tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
|
||||
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
|
||||
if (eager_client == nullptr) {
|
||||
statuses[i] = tensorflow::errors::Internal(
|
||||
"Cannot find a client for the given target:", remote_worker);
|
||||
}
|
||||
if (!statuses[i].ok()) {
|
||||
counter.DecrementCount();
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<bool> filtered_device_mask;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->FilterDevicesForRemoteWorkers(
|
||||
remote_worker, base_request.cluster_device_attributes(),
|
||||
&filtered_device_mask);
|
||||
DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
|
||||
|
||||
// If any of the devices that match the device filters are in the set of
|
||||
// added or removed workers, we must send a complete UpdateContextRequest.
|
||||
// Otherwise, only send a simple request to increment context view ID.
|
||||
std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
|
||||
std::transform(device_added_or_removed.begin(),
|
||||
device_added_or_removed.end(), filtered_device_mask.begin(),
|
||||
added_or_removed_filtered_devices.begin(),
|
||||
std::logical_and<bool>());
|
||||
const bool full_update_request =
|
||||
std::accumulate(added_or_removed_filtered_devices.begin(),
|
||||
added_or_removed_filtered_devices.end(), false,
|
||||
std::logical_or<bool>());
|
||||
|
||||
tensorflow::eager::UpdateContextRequest request;
|
||||
auto* response = new tensorflow::eager::UpdateContextResponse();
|
||||
request.set_context_id(context_id);
|
||||
request.set_context_view_id(context_view_id);
|
||||
if (full_update_request) {
|
||||
*request.mutable_server_def() = server_def;
|
||||
request.mutable_server_def()->set_job_name(parsed_name.job);
|
||||
request.mutable_server_def()->set_task_index(parsed_name.task);
|
||||
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
|
||||
server_def.default_session_config());
|
||||
for (int i = 0; i < cluster_device_count; i++) {
|
||||
if (filtered_device_mask[i]) {
|
||||
const auto& da = base_request.cluster_device_attributes(i);
|
||||
*request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eager_client->UpdateContextAsync(
|
||||
&request, response,
|
||||
[i, &statuses, &counter, response](const tensorflow::Status& s) {
|
||||
statuses[i] = s;
|
||||
delete response;
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
for (int i = 0; i < num_remote_workers; i++) {
|
||||
TF_RETURN_IF_ERROR(statuses[i]);
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
int keep_alive_secs, const tensorflow::ServerDef& server_def,
|
||||
TFE_Context* ctx, bool reset_context) {
|
||||
// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
// server object (which currently CHECK-fails) and we would miss the error.
// Instead, we log the error and then return, to allow the user to see the
// error message.
|
||||
#define LOG_AND_RETURN_IF_ERROR(...) \
|
||||
do { \
|
||||
const ::tensorflow::Status _status = (__VA_ARGS__); \
|
||||
if (TF_PREDICT_FALSE(!_status.ok())) { \
|
||||
LOG(ERROR) << _status.error_message(); \
|
||||
return _status; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
string worker_name =
|
||||
tensorflow::strings::StrCat("/job:", server_def.job_name(),
|
||||
"/replica:0/task:", server_def.task_index());
|
||||
|
||||
// List of current remote workers before updating server_def. Unused if
|
||||
// resetting the server_def.
|
||||
std::vector<string> curr_remote_workers;
|
||||
// List of updated remote workers.
|
||||
std::vector<string> remote_workers;
|
||||
|
||||
// New server created for new server_def. Unused if updating server_def.
|
||||
std::unique_ptr<tensorflow::ServerInterface> new_server;
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server;
|
||||
if (reset_context) {
|
||||
const tensorflow::DeviceMgr* device_mgr =
|
||||
AreLocalDevicesCompatible(context, server_def)
|
||||
? context->local_device_mgr()
|
||||
: nullptr;
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
|
||||
server_def, {device_mgr}, &new_server));
|
||||
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
|
||||
} else {
|
||||
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
|
||||
&curr_remote_workers));
|
||||
// No need to check the cast here, since `ListRemoteWorkers` already checks
|
||||
// if the server is a GRPC server or not.
|
||||
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
|
||||
}
|
||||
|
||||
tensorflow::uint64 context_id = context->GetContextId();
|
||||
tensorflow::uint64 context_view_id = context->GetContextViewId();
|
||||
if (reset_context) {
|
||||
context_id = tensorflow::EagerContext::NewContextId();
|
||||
context_view_id = 0;
|
||||
// Make master eager context accessible by local eager service, which might
|
||||
// receive send tensor requests from remote workers.
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->AddMasterEagerContextToEagerService(context_id, context));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->master_env()->worker_cache->GetEagerClientCache(
|
||||
&remote_eager_workers));
|
||||
|
||||
// For cluster update, use a status group to aggregate statuses from
|
||||
// * adding and removing remote devices
|
||||
// * creating remote contexts on newly added workers
|
||||
// * updating remote contexts on existing workers
|
||||
// * updating the master context
|
||||
// Note that we should not return immediately on errors in the middle of these
|
||||
// updates, to prevent the cluster from having inconsistent context views.
|
||||
//
|
||||
// Unused if `reset_context` is True.
|
||||
tensorflow::StatusGroup sg;
|
||||
|
||||
// When updating an existing context, populate the following lists with:
|
||||
// * added_workers: set(remote_workers) - set(curr_remote_workers)
|
||||
// * removed_workers: set(curr_remote_workers) - set(remote_workers)
|
||||
// * existing_workers: set(curr_remote_workers) intersect set(remote_workers)
|
||||
// * replaced_workers: workers with the same task names and potentially the
|
||||
// same `hostname:port`s, but replaced by different processes
|
||||
std::vector<string> added_workers;
|
||||
std::vector<string> removed_workers;
|
||||
std::vector<string> existing_workers;
|
||||
std::vector<string> replaced_workers;
|
||||
|
||||
// New remote device manager created for new server_def. Unused if updating
|
||||
// server_def.
|
||||
std::unique_ptr<tensorflow::DynamicDeviceMgr> new_remote_device_mgr;
|
||||
tensorflow::DynamicDeviceMgr* remote_device_mgr = nullptr;
|
||||
if (reset_context) {
|
||||
LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
|
||||
remote_workers, grpc_server->master_env()->worker_cache,
|
||||
&new_remote_device_mgr));
|
||||
remote_device_mgr = new_remote_device_mgr.get();
|
||||
} else {
|
||||
context->ClearCachesAndDefaultExecutor();
|
||||
// TODO(b/143914772): Potential memory leak if rendezvous has pending
|
||||
// tensors for removed / replaced workers.
|
||||
|
||||
remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
|
||||
if (remote_device_mgr == nullptr) {
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
|
||||
"Updating context with an invalid set of remote devices."));
|
||||
}
|
||||
std::sort(curr_remote_workers.begin(), curr_remote_workers.end());
|
||||
std::sort(remote_workers.begin(), remote_workers.end());
|
||||
DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
|
||||
&added_workers, &removed_workers,
|
||||
&existing_workers);
|
||||
sg.Update(GetReplacedFromExistingWorkers(
|
||||
&existing_workers, context_id, context->GetContextViewId(), server_def,
|
||||
remote_eager_workers.get(), &replaced_workers));
|
||||
if (VLOG_IS_ON(1)) {
|
||||
VLOG(1) << "Updating cluster with following changes";
|
||||
for (const string& w : added_workers) VLOG(1) << " Added worker " << w;
|
||||
for (const string& w : removed_workers)
|
||||
VLOG(1) << " Removed worker " << w;
|
||||
for (const string& w : replaced_workers)
|
||||
VLOG(1) << " Replaced worker " << w;
|
||||
}
|
||||
if (!replaced_workers.empty()) {
|
||||
// Treat replaced workers as removed then added back, so that we recreate
|
||||
// remote devices and contexts, and re-register functions on those workers
|
||||
removed_workers.insert(removed_workers.end(), replaced_workers.begin(),
|
||||
replaced_workers.end());
|
||||
added_workers.insert(added_workers.end(), replaced_workers.begin(),
|
||||
replaced_workers.end());
|
||||
for (const string& w : replaced_workers) {
|
||||
existing_workers.erase(
|
||||
std::remove(existing_workers.begin(), existing_workers.end(), w),
|
||||
existing_workers.end());
|
||||
}
|
||||
}
|
||||
sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
|
||||
sg.Update(AddRemoteDevicesToMgr(added_workers,
|
||||
grpc_server->master_env()->worker_cache,
|
||||
remote_device_mgr));
|
||||
}
|
||||
|
||||
std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
|
||||
remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);
|
||||
|
||||
std::vector<tensorflow::DeviceAttributes> local_device_attributes;
|
||||
grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
|
||||
&local_device_attributes);
|
||||
|
||||
// This request makes sure that we can create a Rendezvous properly between
// local and remote contexts.
|
||||
tensorflow::eager::CreateContextRequest base_request;
|
||||
for (const auto& da : cluster_device_attributes) {
|
||||
*base_request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
for (const auto& da : local_device_attributes) {
|
||||
*base_request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
|
||||
// Initialize remote eager workers.
|
||||
if (reset_context) {
|
||||
const tensorflow::Status s = CreateRemoteContexts(
|
||||
ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
|
||||
server_def, remote_eager_workers.get(), context->Executor().Async(),
|
||||
context->LazyCopyFunctionRemoteInputs(), base_request);
|
||||
// NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause
|
||||
// the CreateRemoteContexts to fail. We currently only log instead of
|
||||
// directly returning the error, since returning here will cause the server
|
||||
// object to be destroyed (which currently CHECK-fails). The client will
|
||||
// see additional errors if ops are subsequently sent to the failed workers.
|
||||
if (TF_PREDICT_FALSE(!s.ok())) {
|
||||
LOG(ERROR) << "Error when creating contexts on remote targets: "
|
||||
<< s.error_message()
|
||||
<< "\nExecuting remote ops or functions on these remote "
|
||||
"targets will fail.";
|
||||
}
|
||||
} else {
|
||||
if (sg.ok()) {
|
||||
// Create remote contexts on the newly added workers only if the master
|
||||
// has collected all device information from them (i.e., the
|
||||
// GetAllRemoteDevices call returns successfully). Note that in rare cases
|
||||
// GetAllRemoteDevices can still fail even with RPCs configured to wait
|
||||
// until the remote workers become alive. If the master creates remote
|
||||
// contexts on the workers whose devices are still not collected, those
|
||||
// workers will be treated as existing workers subsequently, so the master
|
||||
// will never get devices from them even with retrying UpdateServerDef.
|
||||
sg.Update(CreateRemoteContexts(
|
||||
ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
|
||||
server_def, remote_eager_workers.get(), context->Executor().Async(),
|
||||
context->LazyCopyFunctionRemoteInputs(), base_request));
|
||||
}
|
||||
if (!existing_workers.empty()) {
|
||||
if (VLOG_IS_ON(1)) {
|
||||
for (const string& w : existing_workers) {
|
||||
VLOG(1) << "Updating cluster with existing worker " << w;
|
||||
}
|
||||
}
|
||||
// The master's context_view_id will be incremented by one in the
|
||||
// UpdateRemoteMaster call later. We want existing workers to also have
|
||||
// the updated context_view_id, so we must set their context_view_id to
|
||||
// the master's current context_view_id + 1.
|
||||
sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers,
|
||||
removed_workers, context_id,
|
||||
context_view_id + 1, server_def,
|
||||
remote_eager_workers.get(), base_request));
|
||||
}
|
||||
}
|
||||
|
||||
auto session_name = tensorflow::strings::StrCat("eager_", context_id);
|
||||
if (reset_context) {
|
||||
tensorflow::RemoteRendezvous* r =
|
||||
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
|
||||
auto* device_mgr = grpc_server->worker_env()->device_mgr;
|
||||
std::shared_ptr<tensorflow::WorkerSession> worker_session;
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->worker_env()->session_mgr->CreateSession(
|
||||
session_name, server_def, base_request.cluster_device_attributes(),
|
||||
true));
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
|
||||
session_name, &worker_session));
|
||||
|
||||
// Initialize remote tensor communication based on worker session.
|
||||
LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
|
||||
|
||||
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
|
||||
tensorflow::eager::CreateClusterFLR(context_id, context,
|
||||
worker_session.get());
|
||||
auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
|
||||
/*is_master=*/true, context);
|
||||
|
||||
LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
|
||||
std::move(new_server), grpc_server->worker_env(), worker_session,
|
||||
std::move(remote_eager_workers), std::move(new_remote_device_mgr),
|
||||
remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
|
||||
std::move(remote_mgr)));
|
||||
|
||||
// NOTE: We start the server after all other initialization, because the
|
||||
// GrpcServer cannot be destroyed after it is started.
|
||||
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
|
||||
} else {
|
||||
sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
|
||||
session_name, server_def, base_request.cluster_device_attributes(),
|
||||
/*isolate_session_state=*/true));
|
||||
sg.Update(context->UpdateRemoteMaster(context_id,
|
||||
std::move(remote_eager_workers),
|
||||
added_workers, removed_workers));
|
||||
LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
|
||||
}
|
||||
#undef LOG_AND_RETURN_IF_ERROR
|
||||
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
@ -731,11 +102,21 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
|
||||
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
if (opts->use_tfrt) {
|
||||
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
|
||||
return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
|
||||
tfrt::tf::ContextInterface* tfrt_context = new tfrt::tf::ContextInterface(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
opts->async);
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
tfrt_context->SetDistributedManager(
|
||||
std::make_unique<tfrt::tf::DistributedManagerContextInterface>(
|
||||
tfrt_context->GetCoreRuntime()->GetHostContext()));
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
return tensorflow::wrap(tfrt_context);
|
||||
#else
|
||||
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
|
||||
return nullptr;
|
||||
#endif
|
||||
#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
|
||||
}
|
||||
std::vector<std::unique_ptr<tensorflow::Device>> devices;
|
||||
status->status = tensorflow::DeviceFactory::AddDevices(
|
||||
@ -747,13 +128,18 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||
|
||||
return tensorflow::wrap(new tensorflow::EagerContext(
|
||||
tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(),
|
||||
/*device_mgr_owned*/ true, r));
|
||||
/*device_mgr_owned*/ true, r);
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
eager_context->SetDistributedManager(
|
||||
std::make_unique<tensorflow::EagerContextDistributedManager>(
|
||||
eager_context));
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
return tensorflow::wrap(eager_context);
|
||||
}
|
||||
|
||||
void TFE_DeleteContext(TFE_Context* ctx) {
|
||||
@ -791,26 +177,9 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
||||
"Invalid tensorflow.ServerDef protocol buffer");
|
||||
return;
|
||||
}
|
||||
if (server_def.has_cluster_device_filters()) {
|
||||
const auto& cdf = server_def.cluster_device_filters();
|
||||
for (const auto& jdf : cdf.jobs()) {
|
||||
const string remote_prefix = "/job:" + jdf.name() + "/task:";
|
||||
for (const auto& tdf : jdf.tasks()) {
|
||||
const int32_t task_index = tdf.first;
|
||||
std::vector<string> device_filters(tdf.second.device_filters_size());
|
||||
for (int i = 0; i < tdf.second.device_filters_size(); i++) {
|
||||
device_filters[i] = tdf.second.device_filters(i);
|
||||
}
|
||||
const string remote_worker = remote_prefix + std::to_string(task_index);
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status =
|
||||
context->SetRemoteDeviceFilters(remote_worker, device_filters);
|
||||
}
|
||||
}
|
||||
}
|
||||
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
|
||||
ctx, /*reset_context=*/true);
|
||||
status->status =
|
||||
tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
|
||||
server_def, /*reset_context=*/true, keep_alive_secs);
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
@ -835,14 +204,9 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Trying to update a context with invalid context id.");
|
||||
}
|
||||
if (server_def.has_cluster_device_filters()) {
|
||||
LOG(WARNING) << "Device filters can only be specified when initializing "
|
||||
"the cluster. Any changes in device filters are ignored "
|
||||
"when updating the server def.";
|
||||
}
|
||||
// TODO(haoyuzhang): Check server_def compatibility before the update
|
||||
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
|
||||
ctx, /*reset_context=*/false);
|
||||
status->status =
|
||||
tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
|
||||
server_def, /*reset_context=*/false, keep_alive_secs);
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
@ -854,44 +218,11 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
"TFE_ContextSetServerDef not supported on mobile");
|
||||
return false;
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
if (grpc_server == nullptr) {
|
||||
status->status =
|
||||
tensorflow::errors::Internal("Failed to get tensorflow::GrpcServer.");
|
||||
return false;
|
||||
}
|
||||
tensorflow::WorkerInterface* wi =
|
||||
grpc_server->master_env()->worker_cache->GetOrCreateWorker(worker_name);
|
||||
if (wi == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Unable to find worker interface corresponding to task ", worker_name);
|
||||
return false;
|
||||
}
|
||||
|
||||
tensorflow::GetStatusRequest request;
|
||||
tensorflow::GetStatusResponse response;
|
||||
tensorflow::Status remote_status;
|
||||
tensorflow::Notification done;
|
||||
wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true,
|
||||
[&remote_status, &done](const tensorflow::Status& s) {
|
||||
remote_status = s;
|
||||
done.Notify();
|
||||
});
|
||||
done.WaitForNotification();
|
||||
|
||||
// We set OK status so the call does not raise any exceptions. Instead, the
// caller uses the return value to tell if the remote worker is alive.
|
||||
status->status = tensorflow::Status::OK();
|
||||
|
||||
if (remote_status.ok()) {
|
||||
return true;
|
||||
}
|
||||
LOG(INFO) << "Remote worker " << worker_name
|
||||
<< " is not alive: " << remote_status.error_message();
|
||||
return false;
|
||||
bool is_alive;
|
||||
status->status =
|
||||
tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive(
|
||||
worker_name, &is_alive);
|
||||
return is_alive;
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
|
@ -226,7 +226,7 @@ void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
|
||||
|
||||
// Helper functions which delegate to `AbstractOperation`, update
|
||||
// the state of the ForwardOperation and call the tape as appropriate.
|
||||
// These APIs are mainly to faciliate testing and are subject to change.
|
||||
// These APIs are mainly to facilitate testing and are subject to change.
|
||||
namespace internal {
|
||||
Status Reset(AbstractOperation* op_, const char* op,
|
||||
const char* raw_device_name, ForwardOperation* forward_op_) {
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_distributed_manager.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
@ -28,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
@ -174,6 +176,18 @@ class ImmediateExecutionContext : public AbstractContext {
|
||||
virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
|
||||
ImmediateExecutionTensorHandle* handle) = 0;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Distributed runtime related functions.
|
||||
//===--------------------------------------------------------------------===//
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
// Set a distributed manager that helps set up, update, and check liveness
|
||||
// of member tasks in the cluster.
|
||||
virtual void SetDistributedManager(
|
||||
std::unique_ptr<ImmediateExecutionDistributedManager> distributed) = 0;
|
||||
|
||||
virtual ImmediateExecutionDistributedManager* GetDistributedManager() = 0;
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
|
||||
protected:
|
||||
explicit ImmediateExecutionContext(AbstractContextKind kind)
|
||||
: AbstractContext(kind) {}
|
||||
|
tensorflow/c/eager/immediate_execution_distributed_manager.h (new file, 45 lines)
@@ -0,0 +1,45 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_DISTRIBUTED_MANAGER_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_DISTRIBUTED_MANAGER_H_
|
||||
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
class ImmediateExecutionContext;
|
||||
class ServerDef;
|
||||
|
||||
class ImmediateExecutionDistributedManager {
|
||||
public:
|
||||
virtual ~ImmediateExecutionDistributedManager() {}
|
||||
|
||||
// Set up distributed execution environment on local and remote tasks.
|
||||
// When `reset_context` is true, initialize new cluster context state based on
|
||||
// cluster configurations provided in `server_def`; otherwise, update existing
|
||||
// context state with the provided `server_def`.
|
||||
// Contexts created on remote tasks will be considered stale and garbage
|
||||
// collected after `keep_alive_secs` of inactivity.
|
||||
virtual Status SetOrUpdateServerDef(const ServerDef& server_def,
|
||||
bool reset_context,
|
||||
int keep_alive_secs) = 0;
|
||||
|
||||
// Check if the remote task is alive.
|
||||
virtual Status CheckRemoteAlive(const std::string& remote_task_name,
|
||||
bool* is_alive) = 0;
|
||||
};
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif  // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_DISTRIBUTED_MANAGER_H_
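At the Python level, the user-facing path that ultimately reaches `SetOrUpdateServerDef` (via the `TFE_ContextSetServerDef` change earlier in this diff) is `tf.config.experimental_connect_to_cluster`. A minimal, hedged sketch, assuming the worker servers are already running and that the addresses are illustrative:

```python
import tensorflow as tf

# Assumes two worker servers are already listening at these addresses
# (e.g., started elsewhere with tf.distribute.Server); addresses are
# illustrative only.
cluster = tf.train.ClusterSpec(
    {"worker": ["localhost:2222", "localhost:2223"]})

# Builds a ServerDef from the cluster spec and hands it to the eager context,
# which sets up or updates the distributed state on all tasks.
tf.config.experimental_connect_to_cluster(cluster)

# Eager ops can now be placed on the remote tasks.
with tf.device("/job:worker/replica:0/task:1/device:CPU:0"):
    x = tf.constant([1.0, 2.0]) * 2.0
print(x)
```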
|
@ -27,7 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/util/abstract_stack_trace.h"
|
||||
#include "tensorflow/core/util/managed_stack_trace.h"
|
||||
|
||||
struct TFE_Op;
|
||||
|
||||
@ -48,10 +48,10 @@ class ImmediateExecutionOperation : public AbstractOperation {
|
||||
virtual Status OutputLength(const char* output_name, int* length) = 0;
|
||||
|
||||
// Set stack trace to be used for potential async error reporting.
|
||||
virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0;
|
||||
virtual void SetStackTrace(ManagedStackTrace stack_trace) = 0;
|
||||
|
||||
// Returns the stack trace set by `SetStackTrace` if exists.
|
||||
virtual absl::optional<AbstractStackTrace> GetStackTrace() = 0;
|
||||
virtual absl::optional<ManagedStackTrace> GetStackTrace() = 0;
|
||||
|
||||
// For LLVM style RTTI.
|
||||
static bool classof(const AbstractOperation* ptr) {
|
||||
|
@ -7,10 +7,21 @@ load(
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
# buildifier: disable=same-origin-load
|
||||
load("//tensorflow:tensorflow.bzl", "filegroup")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "headers",
|
||||
srcs = [
|
||||
"stream_executor.h",
|
||||
],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "stream_executor_hdrs",
|
||||
hdrs = ["stream_executor.h"],
|
||||
@ -49,9 +60,11 @@ cc_library(
|
||||
"stream_executor.h",
|
||||
"stream_executor_internal.h",
|
||||
],
|
||||
visibility = ["//tensorflow/c:__subpackages__"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/stream_executor:executor_cache",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
],
|
||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -44,6 +43,7 @@ using tensorflow::StatusFromTF_Status;
|
||||
|
||||
namespace stream_executor {
|
||||
using tensorflow::StringPiece;
|
||||
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
|
||||
|
||||
namespace {
|
||||
|
||||
@ -188,41 +188,6 @@ port::Status ValidateSEPlatformRegistrationParams(
|
||||
}
|
||||
#undef VALIDATE_MEMBER
|
||||
|
||||
struct TFStatusDeleter {
|
||||
void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
|
||||
};
|
||||
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
|
||||
|
||||
class CStream : public internal::StreamInterface {
|
||||
public:
|
||||
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
stream_handle_(nullptr) {}
|
||||
~CStream() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
|
||||
port::Status s = StatusFromTF_Status(c_status.get());
|
||||
return s;
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (stream_handle_ != nullptr) {
|
||||
stream_executor_->destroy_stream(device_, stream_handle_);
|
||||
stream_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Stream Handle() { return stream_handle_; }
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Stream stream_handle_;
|
||||
};
|
||||
|
||||
// Converts SE_EventStatus to Event::Status.
|
||||
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
|
||||
switch (s) {
|
||||
@ -237,82 +202,6 @@ Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
|
||||
}
|
||||
}
|
||||
|
||||
class CEvent : public internal::EventInterface {
|
||||
public:
|
||||
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
event_handle_(nullptr) {}
|
||||
~CEvent() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->create_event(device_, &event_handle_, c_status.get());
|
||||
return StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
port::Status Record(SP_Stream stream_handle) {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->record_event(device_, stream_handle, event_handle_,
|
||||
c_status.get());
|
||||
return StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (event_handle_ != nullptr) {
|
||||
stream_executor_->destroy_event(device_, event_handle_);
|
||||
event_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Event Handle() { return event_handle_; }
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Event event_handle_;
|
||||
};
|
||||
|
||||
class CTimer : public internal::TimerInterface {
|
||||
public:
|
||||
CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
|
||||
SP_TimerFns* timer_fns)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
timer_handle_(nullptr),
|
||||
timer_fns_(timer_fns) {}
|
||||
~CTimer() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
OwnedTFStatus c_status(TF_NewStatus());
|
||||
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
|
||||
return StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (timer_handle_ != nullptr) {
|
||||
stream_executor_->destroy_timer(device_, timer_handle_);
|
||||
timer_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Timer Handle() { return timer_handle_; }
|
||||
|
||||
uint64 Microseconds() const override {
|
||||
return timer_fns_->nanoseconds(timer_handle_) / 1000;
|
||||
}
|
||||
|
||||
uint64 Nanoseconds() const override {
|
||||
return timer_fns_->nanoseconds(timer_handle_);
|
||||
}
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Timer timer_handle_;
|
||||
SP_TimerFns* timer_fns_;
|
||||
};
|
||||
|
||||
// Converts DeviceMemoryBase to a C struct.
|
||||
SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
|
||||
SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
|
||||
@ -321,14 +210,12 @@ SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
  device_memory_base.opaque = const_cast<void*>(mem->opaque());
  device_memory_base.size = mem->size();
  device_memory_base.payload = mem->payload();
  // TODO(annarev): Add `ext` field to DeviceMemoryBase and set it here.
  return device_memory_base;
}

DeviceMemoryBase DeviceMemoryBaseFromC(const SP_DeviceMemoryBase& mem) {
  DeviceMemoryBase base(mem.opaque, mem.size);
  base.SetPayload(mem.payload);
  // TODO(annarev): Add `ext` field to DeviceMemoryBase and set it here.
  return base;
}

@ -426,7 +313,6 @@ class CStreamExecutor : public internal::StreamExecutorInterface {
      LOG(ERROR) << status.error_message();
      return absl::nullopt;
    }
    // TODO(annarev): validate SP_AllocatorStats.
    ::stream_executor::AllocatorStats stats;
    stats.num_allocs = c_stats.num_allocs;
    stats.bytes_in_use = c_stats.bytes_in_use;
@ -140,8 +140,9 @@ typedef enum SE_EventStatus {
// https://cs.opensource.google/tensorflow/tensorflow/+/refs/tags/v2.3.0:tensorflow/stream_executor/device_memory.h;l=57
typedef struct SP_DeviceMemoryBase {
  size_t struct_size;
  void* ext;  // free-form data set by plugin
  void* ext;  // Reserved for future use
  // Platform-dependent value representing allocated memory.
  // Note that the pointer does not have to be to the virtual address itself.
  void* opaque;
  uint64_t size;     // Size in bytes of this allocation.
  uint64_t payload;  // Value for plugin's use
@ -19,6 +19,7 @@ limitations under the License.
#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_

#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform.h"
@ -37,6 +38,13 @@ port::Status InitStreamExecutorPlugin(void* dso_handle);
// testing).
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn);

struct TFStatusDeleter {
  void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
};

// This file implements core stream executor base classes in terms of
// the C API defined in stream_executor.h. A class "CSomething" represents a
// "Something" that can be manipulated via calls in the C interface.
class CPlatform : public Platform {
 public:
  explicit CPlatform(SP_Platform platform,
@ -83,5 +91,111 @@ class CPlatform : public Platform {
|
||||
stream_executor::ExecutorCache executor_cache_;
|
||||
};
|
||||
|
||||
class CStream : public internal::StreamInterface {
|
||||
public:
|
||||
CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
stream_handle_(nullptr) {}
|
||||
~CStream() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
|
||||
stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
|
||||
port::Status s = tensorflow::StatusFromTF_Status(c_status.get());
|
||||
return s;
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (stream_handle_ != nullptr) {
|
||||
stream_executor_->destroy_stream(device_, stream_handle_);
|
||||
stream_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Stream Handle() { return stream_handle_; }
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Stream stream_handle_;
|
||||
};
|
||||
|
||||
class CEvent : public internal::EventInterface {
|
||||
public:
|
||||
CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
event_handle_(nullptr) {}
|
||||
~CEvent() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
|
||||
stream_executor_->create_event(device_, &event_handle_, c_status.get());
|
||||
return tensorflow::StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
port::Status Record(SP_Stream stream_handle) {
|
||||
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
|
||||
stream_executor_->record_event(device_, stream_handle, event_handle_,
|
||||
c_status.get());
|
||||
return tensorflow::StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (event_handle_ != nullptr) {
|
||||
stream_executor_->destroy_event(device_, event_handle_);
|
||||
event_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Event Handle() { return event_handle_; }
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Event event_handle_;
|
||||
};
|
||||
|
||||
class CTimer : public internal::TimerInterface {
|
||||
public:
|
||||
CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
|
||||
SP_TimerFns* timer_fns)
|
||||
: device_(device),
|
||||
stream_executor_(stream_executor),
|
||||
timer_handle_(nullptr),
|
||||
timer_fns_(timer_fns) {}
|
||||
~CTimer() override { Destroy(); }
|
||||
|
||||
port::Status Create() {
|
||||
std::unique_ptr<TF_Status, TFStatusDeleter> c_status(TF_NewStatus());
|
||||
stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
|
||||
return tensorflow::StatusFromTF_Status(c_status.get());
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
if (timer_handle_ != nullptr) {
|
||||
stream_executor_->destroy_timer(device_, timer_handle_);
|
||||
timer_handle_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
SP_Timer Handle() { return timer_handle_; }
|
||||
|
||||
uint64 Microseconds() const override {
|
||||
return timer_fns_->nanoseconds(timer_handle_) / 1000;
|
||||
}
|
||||
|
||||
uint64 Nanoseconds() const override {
|
||||
return timer_fns_->nanoseconds(timer_handle_);
|
||||
}
|
||||
|
||||
private:
|
||||
SP_Device* device_;
|
||||
SP_StreamExecutor* stream_executor_;
|
||||
SP_Timer timer_handle_;
|
||||
SP_TimerFns* timer_fns_;
|
||||
};
|
||||
|
||||
} // namespace stream_executor
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
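A minimal sketch of how the CStream wrapper declared above is intended to be driven; the sp_device and sp_stream_executor names are placeholders for structures a plugin would supply, and this snippet is illustrative rather than part of the change:

  // Assuming sp_device (SP_Device) and sp_stream_executor (SP_StreamExecutor)
  // were filled in by a plugin through the stream_executor.h C API.
  stream_executor::CStream stream(&sp_device, &sp_stream_executor);
  auto s = stream.Create();           // invokes the plugin's create_stream callback
  if (s.ok()) {
    SP_Stream raw = stream.Handle();  // raw handle reusable in other C API calls
    (void)raw;
  }
  // Destroy() runs from ~CStream when `stream` leaves scope, releasing the handle.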
|
||||
|
@ -24,7 +24,13 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
// Required for IS_MOBILE_PLATFORM definition
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/types.h"
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/stream.h"
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)

// This file forms the basis of a stable ABI for third-party kernel
// implementations. It is crucial that changes to this file are made cautiously
@ -168,6 +174,35 @@ void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
  TF_SetStatus(status, TF_OK, "");
}

// This function is only for pluggable device.
// It will return nullptr in all other cases.
// This function is experimental and subject to change.
SP_Stream TF_GetStream(TF_OpKernelContext* ctx, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Accessing device stream is not supported on mobile. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
  return nullptr;
#else
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  if (cc_ctx->op_device_context() == nullptr) {  // CPU Device
    status->status = tensorflow::errors::FailedPrecondition(
        "Accessing device stream is not supported for a CPU device.");
    return nullptr;
  } else if (!cc_ctx->op_device_context()->IsPluggableDevice()) {
    status->status = tensorflow::errors::FailedPrecondition(
        "Accessing device stream is only supported for pluggable devices.");
    return nullptr;
  } else {  // Is a PluggableDevice
    TF_SetStatus(status, TF_OK, "");
    auto c_stream = static_cast<stream_executor::CStream*>(
        cc_ctx->op_device_context()->stream()->implementation());
    return c_stream->Handle();
  }
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

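A short sketch of how a plugin kernel's compute callback might query the stream through this new entry point; the kernel itself is hypothetical, only TF_GetStream and the TF_Status helpers come from the API shown here:

  static void MyKernel_Compute(void* kernel, TF_OpKernelContext* ctx) {
    TF_Status* status = TF_NewStatus();
    SP_Stream stream = TF_GetStream(ctx, status);
    if (TF_GetCode(status) == TF_OK) {
      // `stream` is the SP_Stream the plugin created for this device; it can be
      // handed back to the plugin's own StreamExecutor C API calls.
    }
    TF_DeleteStatus(status);
  }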
int TF_NumInputs(TF_OpKernelContext* ctx) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  return cc_ctx->num_inputs();
@ -19,6 +19,7 @@ limitations under the License.
#include <stdint.h>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
@ -65,6 +66,11 @@ typedef struct TF_KernelBuilder TF_KernelBuilder;
typedef struct TF_OpKernelConstruction TF_OpKernelConstruction;
typedef struct TF_OpKernelContext TF_OpKernelContext;

// TF_InitKernel to do op/kernel registration.
// Plugin should implement TF_InitKernel to register kernels. This function
// should register all kernels in a plugin.
void TF_InitKernel();

// Allocates a new kernel builder and returns a pointer to it.
//
// If non-null, TensorFlow will call create_func when it needs to instantiate
@ -128,6 +134,16 @@ TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
// --------------------------------------------------------------------------
// OpKernelContext routines

// TF_GetStream returns the SP_Stream available in ctx.
// This function returns a stream only for devices registered using the
// StreamExecutor C API
// (tensorflow/c/experimental/stream_executor/stream_executor.h). It will return
// nullptr and set error status in all other cases.
// Experimental: this function doesn't have compatibility guarantees and subject
// to change at any time.
TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx,
                                             TF_Status* status);

// TF_NumInputs returns the number of inputs available in ctx.
TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx);
@ -49,14 +49,12 @@ Graph* BM_ScalarSummaryOp(TensorShape shape, std::string tag, float value) {
|
||||
constexpr char longTagParam[] = "LONGTAG____________________________";
|
||||
constexpr float largeValueParam = 2352352.2623433;
|
||||
|
||||
#define BM_ScalarSummaryDev(device, dims, name, tag, value) \
|
||||
void BM_ScalarSummary##name##device(int iters) { \
|
||||
testing::StopTiming(); \
|
||||
TensorShape tensorshape(DIMARGS dims); \
|
||||
auto g = BM_ScalarSummaryOp(tensorshape, #tag, value); \
|
||||
testing::StartTiming(); \
|
||||
test::Benchmark("cpu", g).Run(iters); \
|
||||
} \
|
||||
#define BM_ScalarSummaryDev(device, dims, name, tag, value) \
|
||||
void BM_ScalarSummary##name##device(::testing::benchmark::State& state) { \
|
||||
TensorShape tensorshape(DIMARGS dims); \
|
||||
auto g = BM_ScalarSummaryOp(tensorshape, #tag, value); \
|
||||
test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state); \
|
||||
} \
|
||||
BENCHMARK(BM_ScalarSummary##name##device);
|
||||
|
||||
BM_ScalarSummaryDev(Cpu, (5, 10, 100), Base, Tag, 5.2);
|
||||
|
@ -378,6 +378,23 @@ template <typename T>
|
||||
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
|
||||
TF_OpKernelContext* ctx);
|
||||
|
||||
REGISTER_OP("StreamOp").Output("output1: float");
|
||||
|
||||
TEST_F(DeviceKernelOpTest, TestStream) {
|
||||
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
|
||||
TF_Status* s = TF_NewStatus();
|
||||
SP_Stream stream = TF_GetStream(ctx, s);
|
||||
// Stream is always null if device is not a pluggable device. More test
|
||||
// cases will be added when pluggable device mechanism is supported.
|
||||
EXPECT_EQ(stream, nullptr);
|
||||
EXPECT_NE(TF_OK, TF_GetCode(s));
|
||||
TF_DeleteStatus(s);
|
||||
};
|
||||
|
||||
SetupOp("StreamOp", "StreamOp", my_compute_func);
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
}
|
||||
|
||||
REGISTER_OP("AllocateOutputOp1").Output("output1: float");
|
||||
|
||||
TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
|
||||
|
@ -115,8 +115,8 @@ genrule(
|
||||
# have control of the full GPU.
|
||||
cmd = "CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location :make_test_graphs) --out_dir $(@D)",
|
||||
exec_tools = [":make_test_graphs"],
|
||||
tags = ["manual"],
|
||||
tools = [":make_test_graphs"],
|
||||
)
|
||||
|
||||
tf_library(
|
||||
|
@ -127,7 +127,7 @@ def tf_library(
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --config=$(location " + config + ")" +
|
||||
" --dump_fetch_nodes > $@"),
|
||||
exec_tools = [tfcompile_tool],
|
||||
tools = [tfcompile_tool],
|
||||
# Run tfcompile on the build host, rather than forge, since it's
|
||||
# typically way faster on the local machine.
|
||||
local = 1,
|
||||
@ -162,7 +162,7 @@ def tf_library(
|
||||
"//tensorflow/python/tools:freeze_graph)" +
|
||||
freeze_args
|
||||
),
|
||||
exec_tools = ["//tensorflow/python/tools:freeze_graph"],
|
||||
tools = ["//tensorflow/python/tools:freeze_graph"],
|
||||
tags = tags,
|
||||
)
|
||||
tfcompile_graph = freeze_file
|
||||
@ -242,7 +242,7 @@ def tf_library(
|
||||
" --out_function_object=$(@D)/" + function_object_file +
|
||||
" " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag
|
||||
),
|
||||
exec_tools = [tfcompile_tool],
|
||||
tools = [tfcompile_tool],
|
||||
visibility = visibility,
|
||||
testonly = testonly,
|
||||
# Run tfcompile on the build host since it's typically faster on the
|
||||
@ -281,7 +281,7 @@ def tf_library(
|
||||
" --out_session_module=$(@D)/" + session_module_pb +
|
||||
" " + flags
|
||||
),
|
||||
exec_tools = [tfcompile_tool],
|
||||
tools = [tfcompile_tool],
|
||||
visibility = visibility,
|
||||
testonly = testonly,
|
||||
local = 1,
|
||||
|
@ -39,7 +39,7 @@ struct XlaAutoJitFlag {
|
||||
int32 optimization_level_general;
|
||||
};
|
||||
|
||||
// Sets the xla_auto_jit_flag based on the given flag sting. Supported syntax
|
||||
// Sets the xla_auto_jit_flag based on the given flag string. Supported syntax
|
||||
// is:
|
||||
// <number>: sets general and single_gpu setting to the provided number.
|
||||
// single-gpu(<number>): sets the single_gpu setting to the provided number.
|
||||
|
@ -103,7 +103,9 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
|
||||
if (flr->config_proto()) {
|
||||
config_proto = *flr->config_proto();
|
||||
}
|
||||
if (!IsMlirBridgePassEnabled(*fbody->graph, config_proto)) {
|
||||
MlirBridgeRolloutPolicy policy =
|
||||
GetMlirBridgeRolloutPolicy(*fbody->graph, config_proto);
|
||||
if (policy != MlirBridgeRolloutPolicy::kEnabledByUser) {
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
|
||||
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
|
||||
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
|
||||
|
@ -449,15 +449,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
auto transfer_manager,
|
||||
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
|
||||
|
||||
xla::Shape output_host_shape = output.on_host_shape();
|
||||
xla::Shape output_device_shape = output.on_device_shape();
|
||||
TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
|
||||
stream, &output, &output_host_shape, &output_device_shape));
|
||||
stream, &output, &output_device_shape));
|
||||
|
||||
output.set_shapes(output_host_shape, output_device_shape);
|
||||
output.set_shapes(output_device_shape, output_device_shape);
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
const xla::Shape& subshape =
|
||||
xla::ShapeUtil::GetSubshape(output_host_shape, {i});
|
||||
xla::ShapeUtil::GetSubshape(output_device_shape, {i});
|
||||
TensorShape shape;
|
||||
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
|
||||
output_tensor_shapes.push_back(shape);
|
||||
|
@ -51,10 +51,10 @@ filegroup(
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td",
|
||||
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:SideEffectTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/ViewLikeInterface.td",
|
||||
],
|
||||
)
|
||||
@ -464,6 +464,7 @@ cc_library(
|
||||
":hlo",
|
||||
":lhlo",
|
||||
":lhlo_gpu",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
)
|
||||
|
||||
@ -499,6 +500,7 @@ cc_library(
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
)
|
||||
|
||||
@ -637,10 +639,12 @@ cc_library(
|
||||
":lhlo",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Affine",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LinalgTransforms",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
"@llvm-project//mlir:ViewLikeInterface",
|
||||
],
|
||||
@ -664,6 +668,7 @@ cc_library(
|
||||
"@llvm-project//mlir:Shape",
|
||||
"@llvm-project//mlir:ShapeTransforms",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:StandardOpsTransforms",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
@ -698,10 +703,12 @@ cc_library(
|
||||
deps = [
|
||||
":cycle_detector",
|
||||
":hlo",
|
||||
"@llvm-project//llvm:Core",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -732,6 +739,7 @@ cc_library(
|
||||
deps = [
|
||||
":hlo",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
@ -752,6 +760,7 @@ cc_library(
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -769,6 +778,8 @@ cc_library(
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -787,6 +798,7 @@ cc_library(
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -827,9 +839,11 @@ cc_library(
|
||||
":hlo",
|
||||
":lower_complex_inc_gen",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -417,6 +417,32 @@ def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
  let hasCanonicalizer = 1;
}

def HLOClient_ErfOp : HLOClient_UnaryElementwiseOp<"erf",
    [NoSideEffect, SameOperandsAndResultShape],
    HLO_FpTensor> {
  let summary = "Erf operator";

  let description = [{
    Computes the Gauss error function of `x` element-wise.

    erf(x) = erf_impl(x)            if |x| < 1
           = 1 - erfc_impl(x)       otherwise
  }];
}

def HLOClient_ErfcOp : HLOClient_UnaryElementwiseOp<"erfc",
    [NoSideEffect, SameOperandsAndResultShape],
    HLO_FpTensor> {
  let summary = "Erfc operator";

  let description = [{
    Computes an approximation of the error function complement (1 - erf(x)).

    erfc(x) = erfc_impl(x)           if |x| > 1
            = 1 - erf_impl(x)        otherwise
  }];
}

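For reference, the exact definitions these approximations target (standard identities, not part of the op specifications themselves):

  \operatorname{erf}(x)  = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2}\, dt, \qquad
  \operatorname{erfc}(x) = 1 - \operatorname{erf}(x)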
//===----------------------------------------------------------------------===//
// Broadcasting compare op
//===----------------------------------------------------------------------===//
@ -1046,6 +1046,7 @@ def HLO_ReshapeOp: HLO_Op<"reshape",
|
||||
|
||||
let results = (outs HLO_StaticShapeTensor);
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
@ -202,9 +202,9 @@ def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
Arg<UntypedBuffer, "", [MemWrite]>:$scratch,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$scratch,
|
||||
Arg<I32Buffer, "", [MemWrite]>:$info,
|
||||
BoolAttr:$is_upper);
|
||||
BoolAttr:$is_lower);
|
||||
}
|
||||
|
||||
#endif // LHLO_GPU_OPS
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
|
@ -197,10 +197,11 @@ def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO(b/139813999): specify required function signature in a type-safe way.
|
||||
def LHLO_ReduceOp: LHLO_Op<"reduce", [
|
||||
SameVariadicOperandSize,
|
||||
SingleBlockImplicitTerminator<"TerminatorOp">
|
||||
]>, BASE_HLO_ReduceOp {
|
||||
//
|
||||
// The region `body` may return lmhlo.TerminatorOp or mhlo.ReturnOp. We are
|
||||
// moving towards mhlo.ReturnOp, but some code that needs cleanup still assumes lmhlo.TerminatorOp.
|
||||
// TODO(timshen): cleanup lmhlo.TerminatorOp.
|
||||
def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]>, BASE_HLO_ReduceOp {
|
||||
let arguments = (ins
|
||||
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
|
||||
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$init_values,
|
||||
@ -655,6 +656,44 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
|
||||
let builders = [
|
||||
OpBuilderDAG<(ins "ArrayRef<NamedAttribute>":$attributes)>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
SmallVector<Value, 4> getInputBuffers() {
|
||||
SmallVector<Value, 4> buffers;
|
||||
this->region().walk([&](TensorLoadOp load) {
|
||||
if (load.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||
buffers.push_back(load.memref());
|
||||
});
|
||||
return buffers;
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> getOutputBuffers() {
|
||||
SmallVector<Value, 4> buffers;
|
||||
this->region().walk([&](TensorStoreOp store) {
|
||||
if (store.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||
buffers.push_back(store.memref());
|
||||
});
|
||||
return buffers;
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> getFusionParameters() {
|
||||
SmallVector<Value, 4> buffers;
|
||||
this->region().walk([&](TensorLoadOp load) {
|
||||
if (load.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||
buffers.push_back(load);
|
||||
});
|
||||
return buffers;
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> getFusionResults() {
|
||||
SmallVector<Value, 4> buffers;
|
||||
this->region().walk([&](TensorStoreOp store) {
|
||||
if (store.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||
buffers.push_back(store.tensor());
|
||||
});
|
||||
return buffers;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TerminatorOp :
|
||||
|
@ -1939,6 +1939,12 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
||||
return {};
|
||||
}
|
||||
|
||||
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
||||
MLIRContext* context) {
|
||||
results.insert<IdentityBroadcastReshape, IdentityBroadcastInDimReshape>(
|
||||
context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Case Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -31,3 +31,15 @@ def DynamicBroadcastToOwnShape_2 : Pat<
|
||||
def ShapeOfDynamicReshape : Pat<
|
||||
(Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
|
||||
(replaceWithValue $shape)>;
|
||||
|
||||
def HasSameType : Constraint<CPred<"$0.getType() == $1.getType()">>;
|
||||
|
||||
def IdentityBroadcastReshape : Pat<
|
||||
(HLO_ReshapeOp:$op (HLO_BroadcastOp $input, $dims)),
|
||||
(replaceWithValue $input),
|
||||
[(HasSameType $input, $op)]>;
|
||||
|
||||
def IdentityBroadcastInDimReshape : Pat<
|
||||
(HLO_ReshapeOp:$op (HLO_BroadcastInDimOp $input, $dims)),
|
||||
(replaceWithValue $input),
|
||||
[(HasSameType $input, $op)]>;
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/Shape/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
@ -578,9 +579,8 @@ struct HloLegalizeToLhlo
|
||||
});
|
||||
|
||||
populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
|
||||
populateWithBufferizeOpConversionPatterns<mlir::ReturnOp, mlir::ReturnOp,
|
||||
lmhlo::CopyOp>(
|
||||
&context, converter, patterns);
|
||||
populateFuncOpTypeConversionPattern(patterns, &context, converter);
|
||||
populateCallOpTypeConversionPattern(patterns, &context, converter);
|
||||
populateShapeStructuralTypeConversionsAndLegality(&context, converter,
|
||||
patterns, target);
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
|
@ -91,6 +91,31 @@ class LhloFuseLinalgPass
|
||||
if (result_buffers.insert(alias).second) {
|
||||
worklist.push_back(alias);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto tensor_load = dyn_cast<TensorLoadOp>(definingOp)) {
|
||||
auto alias = tensor_load.memref();
|
||||
if (result_buffers.insert(alias).second) {
|
||||
worklist.push_back(alias);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto tensor_to_memref = dyn_cast<TensorToMemrefOp>(definingOp)) {
|
||||
auto alias = tensor_to_memref.tensor();
|
||||
if (result_buffers.insert(alias).second) {
|
||||
worklist.push_back(alias);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto tensor_cast = dyn_cast<TensorCastOp>(definingOp)) {
|
||||
auto alias = tensor_cast.source();
|
||||
if (result_buffers.insert(alias).second) {
|
||||
worklist.push_back(alias);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto regionInterface =
|
||||
@ -142,10 +167,10 @@ class LhloFuseLinalgPass
|
||||
|
||||
// Fuse producers of tiled linalg ops.
|
||||
llvm::SmallDenseSet<Operation*> erase_set;
|
||||
SmallVector<Operation*, 8> linalg_ops;
|
||||
SmallVector<LinalgOp, 8> linalg_ops;
|
||||
func.walk([&](LinalgOp op) { linalg_ops.push_back(op); });
|
||||
for (auto* op : llvm::reverse(linalg_ops)) {
|
||||
for (unsigned id = 0, e = LinalgOp(op).getNumInputs(); id < e; ++id) {
|
||||
for (LinalgOp op : llvm::reverse(linalg_ops)) {
|
||||
for (unsigned id = 0, e = op.getNumInputs(); id < e; ++id) {
|
||||
linalg::Aliases aliases;
|
||||
linalg::LinalgDependenceGraph graph(aliases, linalg_ops);
|
||||
if (auto info = fuseProducerOfBuffer(b, op, id, graph)) {
|
||||
|
@ -49,8 +49,9 @@ namespace {
|
||||
sep fn(ShiftRightLogicalOp) sep fn(SubOp)
|
||||
|
||||
// TODO(herhut): Generate these out of op definitions.
|
||||
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
||||
fn(AcosOp) sep fn(AtanOp) sep fn(SinhOp) sep fn(TanOp)
|
||||
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
||||
fn(AcosOp) sep fn(AtanOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) \
|
||||
sep fn(TanOp)
|
||||
|
||||
template <typename OpTy>
|
||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||
|
@ -61,12 +61,16 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
  // mapValues always takes a function returning APInt, even when the output
  // is actually float.
  using func_type = llvm::APInt(const llvm::APInt&);

  // TODO(hinsu): Correctly handle unsigned element types.
  bool is_bool = old_type.isInteger(1);
  if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
    // Int -> Float
    return elements.mapValues(
        new_type, llvm::function_ref<func_type>([&newFloatType](
        new_type, llvm::function_ref<func_type>([&newFloatType, &is_bool](
                      const llvm::APInt& intVal) {
          llvm::APFloat newDouble(static_cast<double>(intVal.getSExtValue()));
          int64_t val = is_bool ? intVal.getZExtValue() : intVal.getSExtValue();
          llvm::APFloat newDouble(static_cast<double>(val));
          bool loses_info = false;
          newDouble.convert(newFloatType.getFloatSemantics(),
                            llvm::APFloat::rmNearestTiesToEven, &loses_info);
@ -76,9 +80,10 @@ mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
  // new_type is Integer
  // Int -> Int
  return elements.mapValues(
      new_type,
      llvm::function_ref<func_type>([&bit_width](const llvm::APInt& intVal) {
        return llvm::APInt(bit_width, intVal.getSExtValue());
      new_type, llvm::function_ref<func_type>([&bit_width, &is_bool](
                    const llvm::APInt& intVal) {
        int64_t val = is_bool ? intVal.getZExtValue() : intVal.getSExtValue();
        return llvm::APInt(bit_width, val);
      }));
}

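The is_bool special case exists because a 1-bit APInt sign-extends to -1; a standalone illustration of the distinction (not part of the patch):

  #include "llvm/ADT/APInt.h"
  // A true i1 value is the single bit 1, whose sign bit is set.
  llvm::APInt b(/*numBits=*/1, /*val=*/1);
  int64_t as_signed = b.getSExtValue();    // -1: would turn true into -1.0
  int64_t as_unsigned = b.getZExtValue();  //  1: turns true into 1.0 as intended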
@ -1483,3 +1483,21 @@ func @pad_fold() -> tensor<4x5xi32> {
|
||||
// CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1]
|
||||
// CHECK-SAME: ]> : tensor<4x5xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @identity_broadcast_reshape
|
||||
func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
|
||||
%0 = "mhlo.broadcast"(%arg0) {
|
||||
broadcast_sizes = dense<[1]> : tensor<1xi64>} : (tensor<128xf32>) -> tensor<1x128xf32>
|
||||
%1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32>
|
||||
return %1 : tensor<128xf32>
|
||||
// CHECK: return %arg0 : tensor<128xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @identity_broadcast_in_dim_reshape
|
||||
func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
|
||||
%0 = "mhlo.broadcast_in_dim"(%arg0) {
|
||||
broadcast_dimensions = dense<[1]> : tensor<1xi64> } : (tensor<128xf32>) -> tensor<1x128xf32>
|
||||
%1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32>
|
||||
return %1 : tensor<128xf32>
|
||||
// CHECK: return %arg0 : tensor<128xf32>
|
||||
}
|
||||
|
@ -123,6 +123,17 @@ func @const_int_bf16() -> tensor<bf16> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @const_bool_f32
|
||||
func @const_bool_f32() -> tensor<2xf32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32>
|
||||
%cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xf32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @const_bf16_int
|
||||
func @const_bf16_int() -> tensor<i16> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i16>
|
||||
@ -145,8 +156,8 @@ func @const_int_narrowing() -> tensor<i32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @const_int_widening
|
||||
func @const_int_widening() -> tensor<i64> {
|
||||
// CHECK-LABEL: func @const_bool_widening
|
||||
func @const_bool_widening() -> tensor<i64> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
|
||||
%cst = mhlo.constant dense<42> : tensor<i32>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
|
||||
@ -156,6 +167,17 @@ func @const_int_widening() -> tensor<i64> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @const_int_widening
|
||||
func @const_int_widening() -> tensor<2xi32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[0, 1]> : tensor<2xi32>
|
||||
%cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xi32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @const_negative_int_widening
|
||||
func @const_negative_int_widening() -> tensor<i64> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor<i64>
|
||||
|
@ -372,3 +372,58 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
|
||||
// PLOOP: else
|
||||
// PLOOP: memref_reshape
|
||||
// PLOOP: scf.yield
|
||||
|
||||
// -----
|
||||
|
||||
// Confirm that tiling information is passed through tensor_load, tensor_cast
|
||||
// and memref_to_tensor operations.
|
||||
func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
|
||||
-> memref<?xf32> {
|
||||
%c1 = constant 1 : index
|
||||
%1 = alloc() : memref<32xf32>
|
||||
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
|
||||
affine_map<(d0) -> (d0)>],
|
||||
iterator_types = ["parallel"]}
|
||||
ins(%arg0 : memref<32xf32>) outs(%1 : memref<32xf32>) {
|
||||
^bb0(%arg3: f32, %arg4: f32): // no predecessors
|
||||
%13 = absf %arg3 : f32
|
||||
linalg.yield %13 : f32
|
||||
}
|
||||
%2 = tensor_load %1 : memref<32xf32>
|
||||
%3 = tensor_cast %2 : tensor<32xf32> to tensor<?xf32>
|
||||
%4 = tensor_to_memref %3 : memref<?xf32>
|
||||
return %4 : memref<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tensor_ops
|
||||
// CHECK: %[[C1:.*]] = constant 1
|
||||
// CHECK-NOT: linalg.generic
|
||||
// CHECK: scf.for {{.*}} step %[[C1]]
|
||||
// CHECK-NOT: scf.for
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: absf
|
||||
// CHECK: tensor_load
|
||||
// CHECK: tensor_cast
|
||||
// CHECK: tensor_to_memref
|
||||
|
||||
// TILED-LABEL: func @tensor_ops
|
||||
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||
// TILED-NOT: linalg.generic
|
||||
// TILED: scf.for {{.*}} step %[[C2]]
|
||||
// TILED-NOT: scf.for
|
||||
// TILED: linalg.generic
|
||||
// TILED: absf
|
||||
// TILED: tensor_load
|
||||
// TILED: tensor_cast
|
||||
// TILED: tensor_to_memref
|
||||
|
||||
|
||||
// PLOOP-LABEL: func @tensor_ops
|
||||
// PLOOP-NOT: linalg.generic
|
||||
// PLOOP: scf.parallel
|
||||
// PLOOP-NOT: scf.parallel
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: absf
|
||||
// PLOOP: tensor_load
|
||||
// PLOOP: tensor_cast
|
||||
// PLOOP: tensor_to_memref
|
||||
|
@ -93,7 +93,7 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
|
||||
func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
|
||||
%scratch = alloc() : memref<32xi8>
|
||||
%info = alloc() : memref<32xi32>
|
||||
"lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_upper = true }
|
||||
"lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_lower = true }
|
||||
: (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -35,9 +35,9 @@ filegroup(
|
||||
"ir/tfl_ops.td",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:SideEffectTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
|
||||
],
|
||||
)
|
||||
|
||||
@ -667,6 +667,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/platform:logging",
|
||||
@ -704,6 +705,7 @@ cc_library(
|
||||
":convert_type",
|
||||
":flatbuffer_tflite_operator_lib",
|
||||
":tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
|
@ -167,7 +167,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
|
||||
case 32:
|
||||
return tflite::TensorType_INT32;
|
||||
case 64:
|
||||
return tflite::TensorType_INT64;
|
||||
return itype.isUnsigned() ? tflite::TensorType_UINT64
|
||||
: tflite::TensorType_INT64;
|
||||
}
|
||||
} else if (auto q_uniform_type =
|
||||
type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
|
||||
@ -386,19 +387,23 @@ class Translator {
|
||||
// internal error.
|
||||
static Optional<std::string> Translate(
|
||||
ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
|
||||
bool emit_custom_ops, const std::unordered_set<std::string>& tags,
|
||||
bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& select_user_tf_ops,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
OpOrArgNameMapper* op_or_arg_name_mapper);
|
||||
|
||||
private:
|
||||
enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
|
||||
explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops,
|
||||
bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& select_user_tf_ops,
|
||||
const std::unordered_set<std::string>& saved_model_tags,
|
||||
OpOrArgNameMapper* op_or_arg_name_mapper)
|
||||
: module_(module),
|
||||
name_mapper_(*op_or_arg_name_mapper),
|
||||
builder_(kInitialBufferSize),
|
||||
saved_model_tags_(saved_model_tags) {
|
||||
saved_model_tags_(saved_model_tags),
|
||||
select_user_tf_ops_(select_user_tf_ops) {
|
||||
// The first buffer must be empty according to the schema definition.
|
||||
empty_buffer_ = tflite::CreateBuffer(builder_);
|
||||
buffers_.push_back(empty_buffer_);
|
||||
@ -574,6 +579,8 @@ class Translator {
|
||||
|
||||
// Set of saved model tags, if any.
|
||||
const std::unordered_set<std::string> saved_model_tags_;
|
||||
// User's defined ops allowed with Flex.
|
||||
const std::unordered_set<std::string> select_user_tf_ops_;
|
||||
};
|
||||
|
||||
std::string Translator::UniqueName(mlir::Value val) {
|
||||
@ -1103,12 +1110,15 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
resource_ops_.insert(node_def->op());
|
||||
}
|
||||
|
||||
const bool is_allowed_flex_op =
|
||||
IsAllowlistedFlexOp(node_def->op()) ||
|
||||
((select_user_tf_ops_.count(node_def->op()) != 0) &&
|
||||
(tensorflow::OpRegistry::Global()->LookUp(node_def->op()) != nullptr));
|
||||
// Flex op case
|
||||
// Eventually, the allowlist will go away and we will rely on some TF op
|
||||
// trait (e.g. No side effect) to determine if it is a supported "Flex"
|
||||
// op or not.
|
||||
if (enabled_op_types_.contains(OpType::kSelectTf) &&
|
||||
IsAllowlistedFlexOp(node_def->op())) {
|
||||
if (is_allowed_flex_op && enabled_op_types_.contains(OpType::kSelectTf)) {
|
||||
// Construct ops as flex op encoding TensorFlow node definition
|
||||
// as custom options.
|
||||
// Flex ops are named with the kFlexOpNamePrefix prefix to the actual
|
||||
@ -1159,7 +1169,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
}
|
||||
|
||||
// Insert failed op to `flex_ops` or `custom_ops`.
|
||||
if (IsAllowlistedFlexOp(node_def->op())) {
|
||||
if (is_allowed_flex_op) {
|
||||
failed_flex_ops_.insert(os.str());
|
||||
} else {
|
||||
failed_custom_ops_.insert(os.str());
|
||||
@ -1619,12 +1629,15 @@ bool UpdateEntryFunction(ModuleOp module) {
|
||||
|
||||
Optional<std::string> Translator::Translate(
|
||||
ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops,
|
||||
bool emit_custom_ops, const std::unordered_set<std::string>& tags,
|
||||
bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& select_user_tf_ops,
|
||||
const std::unordered_set<std::string>& tags,
|
||||
OpOrArgNameMapper* op_or_arg_name_mapper) {
|
||||
if (!UpdateEntryFunction(module)) return llvm::None;
|
||||
if (!IsValidTFLiteMlirModule(module)) return llvm::None;
|
||||
Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops,
|
||||
emit_custom_ops, tags, op_or_arg_name_mapper);
|
||||
emit_custom_ops, select_user_tf_ops, tags,
|
||||
op_or_arg_name_mapper);
|
||||
return translator.TranslateInternal();
|
||||
}
|
||||
|
||||
@ -1876,9 +1889,22 @@ bool tflite::MlirToFlatBufferTranslateFunction(
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& saved_model_tags,
|
||||
OpOrArgNameMapper* op_or_arg_name_mapper) {
|
||||
std::unordered_set<std::string> select_user_tf_ops;
|
||||
return MlirToFlatBufferTranslateFunction(
|
||||
module, serialized_flatbuffer, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, select_user_tf_ops, saved_model_tags,
|
||||
op_or_arg_name_mapper);
|
||||
}
|
||||
|
||||
bool tflite::MlirToFlatBufferTranslateFunction(
|
||||
ModuleOp module, std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& select_user_tf_ops,
|
||||
const std::unordered_set<std::string>& saved_model_tags,
|
||||
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper) {
|
||||
auto maybe_translated = Translator::Translate(
|
||||
module, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops,
|
||||
saved_model_tags, op_or_arg_name_mapper);
|
||||
select_user_tf_ops, saved_model_tags, op_or_arg_name_mapper);
|
||||
if (!maybe_translated) return true;
|
||||
*serialized_flatbuffer = std::move(*maybe_translated);
|
||||
return false;
|
||||
|
@ -52,6 +52,14 @@ bool MlirToFlatBufferTranslateFunction(
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& saved_model_tags,
|
||||
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
|
||||
|
||||
// Same as the above but with a list of allowed user's defined ops.
|
||||
bool MlirToFlatBufferTranslateFunction(
|
||||
mlir::ModuleOp module, std::string* serialized_flatbuffer,
|
||||
bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& select_user_tf_ops,
|
||||
const std::unordered_set<std::string>& saved_model_tags,
|
||||
tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper);
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_
|
||||
|
@ -64,6 +64,7 @@ limitations under the License.
|
||||
#include "mlir/Translation.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
@ -149,7 +150,7 @@ StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
|
||||
|
||||
int64_t storage_min = QuantizedType::getDefaultMinimumForInteger(
|
||||
is_signed, storage_type.getWidth()) +
|
||||
is_weight_buffer;
|
||||
static_cast<int>(is_weight_buffer);
|
||||
int64_t storage_max = QuantizedType::getDefaultMaximumForInteger(
|
||||
is_signed, storage_type.getWidth());
|
||||
uint32_t flags =
|
||||
@ -177,12 +178,25 @@ StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
|
||||
quant_params.zero_point.at(0), storage_min, storage_max);
|
||||
}
|
||||
|
||||
// import float tensor with calibration value into calibrated quantized type.
|
||||
StatusOr<QuantizedType> GetCalibratedQuantizedType(const TensorT& tensor,
|
||||
Builder builder) {
|
||||
if (tensor.quantization == nullptr) {
|
||||
return errors::InvalidArgument("The tensor is not quantized.");
|
||||
}
|
||||
auto raw_elem_type = ConvertElementType(tensor.type, builder);
|
||||
float min = tensor.quantization->min[0];
|
||||
float max = tensor.quantization->max[0];
|
||||
return mlir::quant::CalibratedQuantizedType::get(raw_elem_type, min, max);
|
||||
}
|
||||
|
||||
// TODO(b/138222071) Remove shapeless_are_scalars once we can reliably
|
||||
// make that distinction and don't have to rely on context
|
||||
// (input to main and constants must have static shape)
|
||||
StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
|
||||
bool shapeless_are_scalars = false,
|
||||
bool is_constant = false) {
|
||||
bool is_constant = false,
|
||||
bool is_intermediate = false) {
|
||||
mlir::Type elem_type = ConvertElementType(tensor.type, builder);
|
||||
// TODO(b/139554398) Store min/max (even for non-quantized tensors) somewhere
|
||||
// if it's set
|
||||
@ -191,6 +205,13 @@ StatusOr<mlir::TensorType> GetTensorType(const TensorT& tensor, Builder builder,
|
||||
GetQuantizedType(tensor, builder, is_constant));
|
||||
}
|
||||
|
||||
// Intermediate tensors with calibration value (but not scale and zero points)
|
||||
// should return calibrated quantized type.
|
||||
if (is_intermediate && tensor.quantization != nullptr &&
|
||||
!IsQuantized(tensor)) {
|
||||
TF_ASSIGN_OR_RETURN(elem_type, GetCalibratedQuantizedType(tensor, builder));
|
||||
}
|
||||
|
||||
if (IsScalar(tensor) || (shapeless_are_scalars && tensor.shape.empty())) {
|
||||
return RankedTensorType::get({}, elem_type);
|
||||
}
|
||||
@ -1033,7 +1054,7 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
}
|
||||
}
|
||||
|
||||
  // Intermediate tensors for tfl.lstm are used to carry quantization range
  // Intermediate tensors for LSTMs are used to carry quantization range
  // in their types, so we only need to extract their types.
std::vector<mlir::TensorType> intermediate_types;
|
||||
intermediate_types.reserve(5);
|
||||
@ -1041,7 +1062,8 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto type, GetTensorType(*subgraph.tensors[intermediate], builder,
|
||||
/*shapeless_are_scalars=*/true,
|
||||
/*is_constant=*/true));
|
||||
/*is_constant=*/false,
|
||||
/*is_intermediate=*/true));
|
||||
intermediate_types.emplace_back(type);
|
||||
}
|
||||
|
||||
@ -1135,7 +1157,6 @@ OwningModuleRef tflite::FlatBufferToMlir(
|
||||
|
||||
auto builder = Builder(context);
|
||||
|
||||
|
||||
std::vector<std::string> func_names;
|
||||
for (auto& subgraph : model->subgraphs) {
|
||||
func_names.push_back(subgraph->name);
|
||||
|
@ -1978,6 +1978,10 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
||||
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 1);
|
||||
if (getElementTypeOrSelf(input()) == getElementTypeOrSelf(getType())) {
|
||||
return input();
|
||||
}
|
||||
|
||||
// For now, only supports cast between integer types.
|
||||
auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
|
||||
if (!elements_attr) {
|
||||
|
@ -483,9 +483,9 @@ value of each element in `x`. For example, if x is an input element and y is
|
||||
an output element, this operation computes \\(y = |x|\\).
|
||||
}];
|
||||
|
||||
let arguments = (ins TFL_FpTensor:$x);
|
||||
let arguments = (ins TFL_TensorOf<[F32, QI8, QI16]>:$x);
|
||||
|
||||
let results = (outs TFL_FpTensor:$y);
|
||||
let results = (outs TFL_TensorOf<[F32, QI8, QI16]>:$y);
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
@ -587,15 +587,15 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
|
||||
|
||||
let arguments = (ins
|
||||
TFL_I32Tensor:$output_shape,
|
||||
TFL_TensorOf<[F32, QI8, QUI8]>:$weights,
|
||||
TFL_TensorOf<[F32, QI8, QUI8]>:$input,
|
||||
TFL_TensorOfOrNone<[F32, QI32]>:$bias,
|
||||
TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$weights,
|
||||
TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
|
||||
TFL_TensorOfOrNone<[F32, QI32, I64]>:$bias,
|
||||
TFL_PaddingAttr:$padding,
|
||||
Confined<I32Attr, [IntPositive]>:$stride_h,
|
||||
Confined<I32Attr, [IntPositive]>:$stride_w
|
||||
);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);
|
||||
|
||||
let hasOptions = 1;
|
||||
|
||||
@ -624,7 +624,7 @@ def TFL_AveragePool2DOp:
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins TFL_TensorOf<[F32, QI8, QUI8]>:$input,
|
||||
ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
|
||||
I32Attr:$filter_height,
|
||||
I32Attr:$filter_width,
|
||||
TFL_PaddingAttr:$padding,
|
||||
@ -633,7 +633,7 @@ def TFL_AveragePool2DOp:
|
||||
TFL_AFAttr:$fused_activation_function
|
||||
);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);
|
||||
|
||||
let hasOptions = 1;
|
||||
let customOption = "Pool2DOptions";
|
||||
@ -947,7 +947,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$input,
|
||||
TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$filter,
|
||||
TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$filter,
|
||||
TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias,
|
||||
|
||||
TFL_AFAttr:$fused_activation_function,
|
||||
@ -999,14 +999,14 @@ in the batch dimensions and broadcasting.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, QI8]>:$x,
|
||||
TFL_TensorOf<[F32, QI8]>:$y,
|
||||
TFL_TensorOf<[F32, QI8, QI16]>:$x,
|
||||
TFL_TensorOf<[F32, QI8, QI16]>:$y,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$adj_x,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$adj_y
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, QI8]>:$output
|
||||
TFL_TensorOf<[F32, QI8, QI16]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
@ -1026,7 +1026,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8]>:$params,
|
||||
TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8, QI16]>:$params,
|
||||
TFL_TensorOf<[I32, I64]>:$indices,
|
||||
I32Attr:$axis
|
||||
);
|
||||
@ -1038,7 +1038,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
|
||||
];
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8]>:$output
|
||||
TFL_TensorOf<[F32, I1, I8, I32, I64, TFL_Str, UI8, QI8, QUI8, QI16]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
@ -1750,12 +1750,12 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [
}];

let arguments = (
ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$input,
ins TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8, QI16]>:$input,
// Slope of the activation function at x < 0.
F32Attr:$alpha
);

let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output);
let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8, QI16]>:$output);

let hasOptions = 0b1;
}
@ -1977,12 +1977,12 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
}];

let arguments = (
ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$lhs,
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$rhs
ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$lhs,
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$rhs
);

let results = (outs
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$max
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$max
);

let builders = [TFL_BroadcastableBinaryBuilder];
@ -2005,13 +2005,13 @@ def TFL_MeanOp : TFL_Op<"mean", [
}];

let arguments = (ins
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8]>:$input,
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8, QI16]>:$input,
TFL_TensorOf<[I32, I64]>:$axis,
BoolAttr:$keep_dims
);

let results = (outs
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8]>:$output);
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, UI8, QI16]>:$output);

let hasOptions = 1;
let customOption = "ReducerOptions";
@ -2090,13 +2090,13 @@ equivalent to setting:
}];

let arguments = (ins
TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$input,
TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$input,
TFL_I32OrI64Tensor:$begin,
TFL_I32OrI64Tensor:$size
);

let results = (outs
TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$output
TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$output
);

let verifier = [{ return Verify(*this); }];
@ -2116,12 +2116,12 @@ def TFL_SumOp: TFL_Op<"sum", [
}];

let arguments = (ins
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
TFL_I32Tensor:$axes,
BoolAttr:$keep_dims
);

let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);

let hasOptions = 1;
let customOption = "ReducerOptions";
@ -2139,13 +2139,13 @@ def TFL_ReduceMinOp: TFL_Op<"reduce_min", [
}];

let arguments = (ins
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
TFL_I32Tensor:$axes,
BoolAttr:$keep_dims
);

let results = (outs
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);

let hasOptions = 1;
let customOption = "ReducerOptions";
@ -2163,13 +2163,13 @@ def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [
}];

let arguments = (ins
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
TFL_I32Tensor:$axes,
BoolAttr:$keep_dims
);

let results = (outs
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);

let hasOptions = 1;
let customOption = "ReducerOptions";
@ -2186,13 +2186,13 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [
}];

let arguments = (ins
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
TFL_I32Tensor:$axes,
BoolAttr:$keep_dims
);

let results = (outs
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);

let hasOptions = 1;
let customOption = "ReducerOptions";
@ -2210,12 +2210,12 @@ def TFL_MinimumOp : TFL_Op<"minimum", [
}];

let arguments = (
ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$lhs,
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$rhs
ins TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$lhs,
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$rhs
);

let results = (outs
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8]>:$min
TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$min
);

let builders = [TFL_BroadcastableBinaryBuilder];
@ -2364,10 +2364,10 @@ def TFL_PadOp : TFL_Op<"pad", [
```
}];

let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
let arguments = (ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$input,
TFL_I32OrI64Tensor:$padding);

let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8, QI16]>:$output);

let hasOptions = 1;
}
@ -2500,9 +2500,9 @@ def TFL_ReluOp: TFL_Op<"relu", [
x -> max(0, x)
}];

let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x);
let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8, QI16]>:$x);

let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y);
let results = (outs TFL_TensorOf<[F32, QUI8, QI8, QI16]>:$y);

// This builder doesn't work with quantized type, so it can only be used by
// non-quantization tablegen patterns. Currently, it is used by the
@ -2828,11 +2828,11 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
}];

let arguments = (
ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input,
ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8, QI16]>:$input,
F32Attr:$beta
);

let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output);
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8, QI16]>:$output);

let hasOptions = 1;
@ -3058,12 +3058,12 @@ def TFL_TransposeOp : TFL_Op<"transpose", [
}];

let arguments = (ins
TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64]>:$input,
TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64, QI16]>:$input,
TFL_TensorOf<[I32]>:$perm
);

let results = (outs
TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64]>:$output
TFL_TensorOf<[I32, F32, I8, UI8, QI8, QUI8, TFL_Quint8, I1, I64, QI16]>:$output
);

let verifier = [{ return Verify(*this); }];
@ -3330,14 +3330,14 @@ def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", [
}];

let arguments = (ins
TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$input,
TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8, QI16]>:$input,
TFL_I32Tensor:$size,
BoolAttr:$align_corners,
DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
);

let results = (outs
TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$output
TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8, QI16]>:$output
);

let hasOptions = 1;
@ -3830,6 +3830,9 @@ Ba et al. 'Layer Normalization'

// Types of the optional intermediate tensors, which exist for fully
// quantized LSTM op and hold the ranges of the intermediate tensors.
// The types of the intermediate tensors are quant.calibrated when imported
// to only store calibrated min, max values. The proper quantization spec is
// determined while going through quantization passes.
OptionalAttr<TypeAttr>:$input_to_input_intermediate,
OptionalAttr<TypeAttr>:$input_to_forget_intermediate,
OptionalAttr<TypeAttr>:$input_to_cell_intermediate,
@ -3945,6 +3948,9 @@ def TFL_UnidirectionalSequenceLSTMOp :

// Types of the optional intermediate tensors, which exist for fully
// quantized op and hold the ranges of the intermediate tensors.
// The types of the intermediate tensors are quant.calibrated when imported
// to only store calibrated min, max values. The proper quantization spec is
// determined while going through quantization passes.
OptionalAttr<TypeAttr>:$input_to_input_intermediate,
OptionalAttr<TypeAttr>:$input_to_forget_intermediate,
OptionalAttr<TypeAttr>:$input_to_cell_intermediate,
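The intermediate attributes above are imported as `quant.calibrated` types that carry only a (min, max) pair; a later quantization pass picks the concrete parameters. A minimal sketch of that derivation, assuming the usual TFLite conventions (the struct and function names below are hypothetical; only the arithmetic follows the standard recipe):

```cpp
// Sketch only, not the MLIR pass: turn a calibrated (min, max) range into
// quantization parameters for a symmetric 16-bit or asymmetric 8-bit type.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

struct QuantParams {
  double scale;
  int64_t zero_point;
};

// Symmetric 16-bit: zero point pinned to 0, scale covers the larger bound.
QuantParams FromCalibratedInt16(double min, double max) {
  double bound = std::max(std::fabs(min), std::fabs(max));
  return {bound / 32767.0, 0};
}

// Asymmetric 8-bit: both scale and zero point derived from the range.
QuantParams FromCalibratedInt8(double min, double max) {
  double scale = (max - min) / 255.0;
  int64_t zp = static_cast<int64_t>(std::round(-128 - min / scale));
  return {scale, zp};
}

int main() {
  // e.g. an input_to_cell intermediate calibrated to [-4, 4].
  QuantParams p = FromCalibratedInt16(-4.0, 4.0);
  std::cout << p.scale << " zp=" << p.zero_point << "\n";
}
```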
@ -55,7 +55,7 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
// Parse input arrays.
std::vector<string> node_names;
std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes;
std::vector<llvm::Optional<std::vector<int>>> node_shapes;
std::vector<llvm::Optional<double>> node_mins;
std::vector<llvm::Optional<double>> node_maxs;

@ -89,6 +89,11 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
pass_config.lower_tensor_list_ops = true;
// Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the
// conversion to an unsupported 16x16 TFL::FullyConnectedOp.
if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) {
  pass_config.unfold_batch_matmul = false;
}

return internal::ConvertMLIRToTFLiteFlatBuffer(
toco_flags, std::move(module), pass_config, /*saved_model_tags=*/{},
@ -128,7 +128,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
// Parse input arrays.
std::vector<string> node_names;
std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes;
std::vector<llvm::Optional<std::vector<int>>> node_shapes;
std::vector<llvm::Optional<double>> node_mins;
std::vector<llvm::Optional<double>> node_maxs;

@ -174,6 +174,11 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops;
pass_config.lower_tensor_list_ops = true;
// Disable the unfolding of the 16x16 TF::BatchMatMulOp to avoid the
// conversion to an unsupported 16x16 TFL::FullyConnectedOp.
if (toco_flags.inference_type() == toco::IODataType::QUANTIZED_INT16) {
  pass_config.unfold_batch_matmul = false;
}

// TODO(b/153507667): Pass the session object when importing logic is removed.
auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"

#include <ostream>
#include <unordered_set>
#include <utility>

#include "llvm/Support/ToolOutputFile.h"
@ -119,6 +120,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
return DT_INT32;
case toco::IODataType::INT64:
return DT_INT64;
case toco::IODataType::UINT64:
return DT_UINT64;
case toco::IODataType::STRING:
return DT_STRING;
case toco::IODataType::BOOL:
@ -185,7 +188,7 @@ Status PopulateQuantizationSpecs(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<llvm::Optional<std::vector<int>>>* node_shapes,
std::vector<llvm::Optional<double>>* node_mins,
std::vector<llvm::Optional<double>>* node_maxs) {
quant_specs->inference_input_type =
@ -210,8 +213,12 @@ Status PopulateQuantizationSpecs(
node_dtypes->push_back(
DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
}
node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
flag.shape().dims().end()));
if (flag.shape().unknown_rank()) {
  node_shapes->push_back(llvm::None);
} else {
  node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
                                          flag.shape().dims().end()));
}
// Currently, only UINT8 and INT8 require input stats
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
if (flag.has_mean_value() && flag.has_std_value()) {
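For the UINT8/INT8 path, the per-input `mean_value`/`std_value` flags are what get turned into the (min, max) stats used for quantization. A hedged sketch of that conversion, assuming the conventional TOCO mapping where a quantized value q represents the real value (q - mean) / std; `InputStatsToMinMax` below is an illustrative stand-in, not the converter's helper:

```cpp
// Sketch: derive the representable real range from mean/std input flags.
#include <iostream>
#include <utility>

std::pair<double, double> InputStatsToMinMax(double mean, double std,
                                             bool is_signed_int8) {
  const double qmin = is_signed_int8 ? -128.0 : 0.0;
  const double qmax = is_signed_int8 ? 127.0 : 255.0;
  // Plug the extreme quantized values into (q - mean) / std.
  return {(qmin - mean) / std, (qmax - mean) / std};
}

int main() {
  auto [min, max] = InputStatsToMinMax(/*mean=*/127.5, /*std=*/127.5,
                                       /*is_signed_int8=*/false);
  std::cout << min << " .. " << max << "\n";  // approximately -1 .. 1
}
```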
@ -282,6 +289,10 @@ Status ConvertMLIRToTFLiteFlatBuffer(
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
bool emit_custom_ops = toco_flags.allow_custom_ops();

const std::unordered_set<std::string> select_user_tf_ops(
    toco_flags.select_user_tf_ops().begin(),
    toco_flags.select_user_tf_ops().end());

if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
module.get(),
@ -301,8 +312,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(

auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops,
emit_select_tf_ops, emit_custom_ops, pass_config.quant_specs,
saved_model_tags, result, &pm);
emit_select_tf_ops, emit_custom_ops, select_user_tf_ops,
pass_config.quant_specs, saved_model_tags, result, &pm);
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
// rename once we enable the new converter feature flag.
@ -41,7 +41,7 @@ Status PopulateQuantizationSpecs(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<llvm::Optional<std::vector<int>>>* node_shapes,
std::vector<llvm::Optional<double>>* node_mins,
std::vector<llvm::Optional<double>>* node_maxs);
@ -53,10 +53,11 @@ struct QuantizationSpecs {
bool disable_per_channel = false;

// When set to true, the fixed output ranges of the activation ops (tanh,
// sigmoid, etc.) are not enforced. Then, to quantize these ops, quantization
// emulation ops should be specified after the ops in the input graph. This
// flag should be set to false for post-training quantization.
bool disable_enforced_fixed_output_range = false;
// sigmoid, etc.) and the weight constants are not inferred. Then, to quantize
// these ops, quantization emulation ops should be placed after the ops in the
// input graph. This flag should be set to false for post-training
// quantization.
bool disable_infer_tensor_range = false;

// The node type when the model is exported. Currently this is limited to
// DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the
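The renamed flag separates two workflows: post-training quantization has to infer tensor ranges (from weight contents and fixed-output-range activations), while a quantization-aware-training graph already carries its ranges in emulation ops. A minimal sketch of that decision, using illustrative names rather than the actual pass options:

```cpp
// Sketch of the flag's intent; the struct and helper are hypothetical.
#include <iostream>

struct QuantSpecsSketch {
  bool post_training_quantization = true;
  bool disable_infer_tensor_range = false;
};

bool ShouldInferTensorRange(const QuantSpecsSketch& specs) {
  // Post-training quantization infers ranges unless explicitly disabled,
  // which is the quantization-aware-training path.
  return specs.post_training_quantization && !specs.disable_infer_tensor_range;
}

int main() {
  QuantSpecsSketch ptq;                    // post-training: infer ranges
  QuantSpecsSketch qat;
  qat.post_training_quantization = false;  // QAT: ranges come from the graph
  std::cout << ShouldInferTensorRange(ptq) << " "
            << ShouldInferTensorRange(qat) << "\n";
}
```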
@ -100,13 +100,13 @@ class QuantizationDriver {
explicit QuantizationDriver(FuncOp fn, bool is_signed,
bool disable_per_channel,
OpQuantSpecGetter op_quant_spec_getter,
bool enforce_fixed_output_range)
bool infer_tensor_range)
: fn_(fn),
builder_(fn.getBody()),
is_signed_(is_signed),
disable_per_channel_(disable_per_channel),
op_quant_spec_getter_(op_quant_spec_getter),
enforce_fixed_output_range_(enforce_fixed_output_range) {}
infer_tensor_range_(infer_tensor_range) {}

// The entry point of the quantization parameters propagation.
void Run();
@ -384,7 +384,9 @@ class QuantizationDriver {

OpQuantSpecGetter op_quant_spec_getter_;

bool enforce_fixed_output_range_;
// Infer output ranges for activation ops and constants. This is usually
// required for post-training quantization.
bool infer_tensor_range_;
};
} // namespace
@ -670,33 +672,43 @@ void QuantizationDriver::PreprocessConstantOps() {

Value value = cst.getResult();
builder_.setInsertionPoint(cst);
for (auto indexed_use : llvm::enumerate(value.getUses())) {
auto &use = indexed_use.value();
auto spec = GetQuantSpec(use.getOwner());
auto biases = spec->biases_params;
Operation *user = use.getOwner();
int operand_num = use.getOperandNumber();

// The following loop will change the value uses, so we cache all the uses
// that need to be changed.
llvm::SmallVector<std::pair<Operation *, int>, 4> uses;
for (auto &use : value.getUses()) {
  uses.push_back({use.getOwner(), use.getOperandNumber()});
}
for (auto indexed_use : llvm::enumerate(uses)) {
Operation *user = indexed_use.value().first;
int operand_num = indexed_use.value().second;

auto spec = GetQuantSpec(user);
auto biases = spec->biases_params;

// The quantization parameters of a `weight` shouldn't be determined by
// other values. So any constant which is not a bias, not an operand of an
// op with same-scale requirements, and hasn't been quantized is a
// weight.
if (biases.find(operand_num) == biases.end() &&
!llvm::dyn_cast<mlir::SameScalesOpInterface>(user) &&
!llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
// Needs to scan the content to get the quantization parameters if there
// are no quantization parameters (FakeQuant ops).
// For this case, the weight isn't duplicated.
// Needs to scan the content of weights to get the quantization
// parameters if there are no quantization parameters (FakeQuant ops).
// For this case, the weight will not be duplicated.
weights_.insert(cst);
auto affine_user =
llvm::dyn_cast<mlir::AffineQuantizedOpInterface>(user);
if (affine_user &&
affine_user.GetAffineOperandIndex() == use.getOperandNumber() &&
if (affine_user && affine_user.GetAffineOperandIndex() == operand_num &&
affine_user.RequiredNarrowRangeAffineOperand()) {
optimized_weights_.insert(
{cst, affine_user.GetQuantizationDimIndex()});
}
} else {
// This is a bias, so the quantization parameter isn't determined by the
// local content. Same if the user can have quantization parameter
// propagated from other places.
// Duplicate this constant in case it is shared by different users.
// This is a bias or an operand of an op with same scale requirements,
// so the quantization parameters are propagated from or determined by
// other values. Duplicate this constant in case it is shared by
// different users.
if (indexed_use.index() > 0) {
cst = builder_.create<ConstantOp>(cst.getLoc(), cst.getValue());
}
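The cached `uses` vector exists because the loop body may clone the constant, which rewrites the value's use list mid-iteration. The same pattern, stripped of MLIR types, looks like this (purely illustrative):

```cpp
// Snapshot the uses of a value before a rewrite loop mutates them.
#include <iostream>
#include <utility>
#include <vector>

struct Use { int user_id; int operand_index; };

int main() {
  std::vector<Use> live_uses = {{1, 0}, {2, 1}, {3, 0}};

  // Iterating `live_uses` directly while appending to it would invalidate
  // the traversal; the snapshot keeps the loop bound fixed.
  std::vector<std::pair<int, int>> snapshot;
  for (const Use& u : live_uses) snapshot.push_back({u.user_id, u.operand_index});

  for (size_t i = 0; i < snapshot.size(); ++i) {
    if (i > 0) {
      // Stand-in for duplicating the shared constant for every later user.
      live_uses.push_back({/*user_id=*/100 + static_cast<int>(i), 0});
    }
  }
  std::cout << live_uses.size() << "\n";  // 5: two duplicates were added
}
```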
@ -786,12 +798,14 @@ bool QuantizationDriver::PropagateParams() {
quantized_.insert(op);

if (auto cst = llvm::dyn_cast<ConstantOp>(op)) {
// If it isn't a weight or has been quantized, skip.
if (!IsWeight(cst) || IsQuantized(op)) continue;

// The quantization parameters are determined by the content of the
// constant.
changed |= SetConstantResultParams(op);
// If the workflow requires inferring ranges from the content
// (post-training quantization) and it is a weight (filter) and hasn't
// been quantized, we infer the quantization parameters from the content.
if (infer_tensor_range_ && IsWeight(cst) && !IsQuantized(op)) {
  // The quantization parameters are determined by the content of the
  // constant.
  changed |= SetConstantResultParams(op);
}
continue;
}

@ -826,7 +840,9 @@ bool QuantizationDriver::PropagateParams() {

// TODO(fengliuai): make the bit width configurable.
auto restricted = llvm::dyn_cast<FixedOutputRangeInterface>(op);
if (restricted && enforce_fixed_output_range_) {
if (restricted && infer_tensor_range_) {
// Infer ranges from the activation ops. This is usually required for
// the post-training quantization workflow.
// TODO(fengliuai): different results can have different fixed ranges.
auto params = restricted.GetFixedOutputRange(is_signed_, /*bit_width=*/8);
for (auto i = 0; i < op->getNumResults(); ++i) {
@ -903,9 +919,9 @@ void QuantizationDriver::Run() {
void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
bool disable_per_channel,
OpQuantSpecGetter op_quant_spec_getter,
bool post_training_quantization) {
bool infer_tensor_ranges) {
QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter,
post_training_quantization)
infer_tensor_ranges)
.Run();
}
@ -19,6 +19,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_

#include <string>
#include <unordered_map>

#include "llvm/ADT/SmallVector.h"
@ -103,8 +104,11 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
if (!stats) return failure();

for (auto it = stats.begin(), e = stats.end(); it != e; ++it) {
mins.push_back(FloatAttr::getValueAsDouble(*it++));
maxs.push_back(FloatAttr::getValueAsDouble(*it));
double min = FloatAttr::getValueAsDouble(*it++);
double max = FloatAttr::getValueAsDouble(*it);
TensorRangeSanityCheck(op, min, max);
mins.push_back(min);
maxs.push_back(max);
}
quant_type =
quant::fakeQuantAttrsToType(op.getLoc(), num_bits, *op.axis(), mins,
@ -112,6 +116,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
} else if (auto stats = op.layerStats().dyn_cast<DenseFPElementsAttr>()) {
double rmin = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}));
double rmax = FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}));
TensorRangeSanityCheck(op, rmin, rmax);
quant_type =
quant::fakeQuantAttrsToType(op.getLoc(), num_bits, rmin, rmax,
narrow_range, expressed, is_signed);
@ -134,6 +139,19 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
int num_bits;
bool narrow_range;
bool is_signed;

// Emits an op warning message if the calibrated range is larger than 10.0 and
// the storage type is less than or equal to 8 bits.
void TensorRangeSanityCheck(quant::StatisticsOp op, double min,
                            double max) const {
  double range = std::fabs(max - min);
  if (num_bits <= 8 && range >= 10.0) {
    op.emitWarning(
        "Tensor range is too wide to be quantized. Use tf.clip_by_value or "
        "tf.relu6 to narrow the tensor range. Range: " +
        std::to_string(range) + ", bit width: " + std::to_string(num_bits));
  }
}
};

// A base rewrite pattern which matches any N-in-M-out operations with
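A quick worked example of why the 10.0 threshold is a reasonable sanity bound: with 8 bits the quantization step grows linearly with the calibrated range, so a range of 10.0 already rounds values by up to roughly 0.02, whereas a 16-bit type keeps the step two orders of magnitude smaller.

```cpp
// Worked example: step size of an 8-bit vs 16-bit quantizer over range 10.0.
#include <iostream>

int main() {
  const double range = 10.0;
  const double step8 = range / 255.0;     // ~0.0392, max rounding error ~0.02
  const double step16 = range / 65535.0;  // ~0.000153
  std::cout << "8-bit step:  " << step8 << "\n"
            << "16-bit step: " << step16 << "\n";
}
```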
@ -490,13 +508,13 @@ quant::QuantizedType GetUniformQuantizedTypeForBias(
// and the propagation results are materialized by inserting pairs of quantize
// and dequantize ops to this function. Set `disable_per_channel` to true to not
// use per-channel quantization even when the op supports it.
// Set `enforce_fixed_output_range` to true to infer quantization
// parameters from the fixed output range ops. This is only used for
// post-training quantization.
// Set `infer_tensor_range` to true to infer quantization parameters from
// the activation ops and weight constants. This is only used for post-training
// quantization.
void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
bool disable_per_channel,
OpQuantSpecGetter op_quant_spec_getter,
bool enforce_fixed_output_range);
bool infer_tensor_ranges);

// The function might contain more stats ops than required, and it will
// introduce requantize if the calibration stats have conflicts. This method
@ -638,4 +638,10 @@ func @cast_ui8_to_i32() -> tensor<4xi32> {
// CHECK: return %[[CST]]
}

// CHECK-LABEL: @cast_identity
func @cast_identity(%arg0 : tensor<7xf32>) -> tensor<7xf32> {
%0 = "tfl.cast"(%arg0) : (tensor<7xf32>) -> tensor<7xf32>
return %0 : tensor<7xf32>
// CHECK: return %arg0 : tensor<7xf32>
}
335
tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.json
Normal file
@ -0,0 +1,335 @@
// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck %s

// CHECK: effective_hidden_scale_intermediate = tensor<!quant.calibrated<f32<-5.000000e-01:5.000000e-01>>>
// CHECK: input_to_cell_intermediate = tensor<!quant.calibrated<f32<-4.000000e+00:4.000000e+00>>>
// CHECK: input_to_forget_intermediate = tensor<!quant.calibrated<f32<-1.600000e+01:1.600000e+01>>>
// CHECK: input_to_input_intermediate = tensor<!quant.calibrated<f32<-3.200000e+01:3.200000e+01>>>
// CHECK: input_to_output_intermediate = tensor<!quant.calibrated<f32<-1.000000e+00:1.000000e+00>>>
{
|
||||
"version": 3,
|
||||
"operator_codes": [
|
||||
{
|
||||
"builtin_code": "UNIDIRECTIONAL_SEQUENCE_LSTM"
|
||||
}
|
||||
],
|
||||
"subgraphs": [
|
||||
{
|
||||
"tensors": [
|
||||
{
|
||||
"shape": [1, 5],
|
||||
"name": "input0"
|
||||
},
|
||||
{
|
||||
"shape": [2, 5],
|
||||
"buffer": 1,
|
||||
"name": "input2input_weights1"
|
||||
},
|
||||
{
|
||||
"shape": [2, 5],
|
||||
"buffer": 2,
|
||||
"name": "input2forget_weights2"
|
||||
},
|
||||
{
|
||||
"shape": [2, 5],
|
||||
"buffer": 3,
|
||||
"name": "input2cell_weights3"
|
||||
},
|
||||
{
|
||||
"shape": [2, 5],
|
||||
"buffer": 4,
|
||||
"name": "input2output_weights4"
|
||||
},
|
||||
{
|
||||
"shape": [2, 4],
|
||||
"buffer": 5,
|
||||
"name": "rec2input_weights5"
|
||||
},
|
||||
{
|
||||
"shape": [2, 4],
|
||||
"buffer": 6,
|
||||
"name": "rec2forget_weights6"
|
||||
},
|
||||
{
|
||||
"shape": [2, 4],
|
||||
"buffer": 7,
|
||||
"name": "rec2cell_weights7"
|
||||
},
|
||||
{
|
||||
"shape": [2, 4],
|
||||
"buffer": 8,
|
||||
"name": "rec2output_weights8"
|
||||
},
|
||||
{
|
||||
"shape": [2],
|
||||
"buffer": 9,
|
||||
"name": "cell2input_weights9"
|
||||
},
|
||||
{
|
||||
"shape": [2],
|
||||
"buffer": 10,
|
||||
"name": "cell2forget_weights10"
|
||||
},
|
||||
{
|
||||
"shape": [2],
|
||||
"buffer": 11,
|
||||
"name": "cell2output_weights11"
|
||||
},
|
||||
{
|
||||
"shape": [2],
|
||||
"buffer": 12,
|
||||
"name": "input_gate_bias12"
|
||||
},
|
||||
{
|
||||
"shape": [2],
|
||||
"buffer": 13,
|
||||
"name": "forget_gate_bias13"
|
||||
},
|
||||
{
|
||||
"shape": [2],
|
||||
"buffer": 14,
|
||||
"name": "cell_gate_bias14"
|
||||
},
|
||||
{
|
||||
"shape": [2],
|
||||
"buffer": 15,
|
||||
"name": "output_gate_bias15"
|
||||
},
|
||||
{
|
||||
"shape": [4, 2],
|
||||
"buffer": 16,
|
||||
"name": "proj_weights16"
|
||||
},
|
||||
{
|
||||
"shape": [4],
|
||||
"buffer": 17,
|
||||
"name": "proj_bias17"
|
||||
},
|
||||
{
|
||||
"shape": [1, 4],
|
||||
"name": "input_activation_state18",
|
||||
"is_variable": true,
|
||||
"quantization": {
|
||||
"min": [-0.9],
|
||||
"max": [0.9]
|
||||
}
|
||||
},
|
||||
{
|
||||
"shape": [1, 2],
|
||||
"name": "input_cell_state19",
|
||||
"is_variable": true,
|
||||
"quantization": {
|
||||
"min": [-0.8],
|
||||
"max": [0.8]
|
||||
}
|
||||
},
|
||||
{
|
||||
"shape": [2],
|
||||
"buffer": 18,
|
||||
"name": "input_norm20"
|
||||
},
|
||||
{
|
||||
"shape": [2],
|
||||
"buffer": 19,
|
||||
"name": "forget_norm21"
|
||||
},
|
||||
{
|
||||
"shape": [2],
|
||||
"buffer": 20,
|
||||
"name": "cell_norm22"
|
||||
},
|
||||
{
|
||||
"shape": [2],
|
||||
"buffer": 21,
|
||||
"name": "output_norm23"
|
||||
},
|
||||
{
|
||||
"shape": [],
|
||||
"name": "output24"
|
||||
},
|
||||
{
|
||||
"shape": [],
|
||||
"name": "intermediate_0",
|
||||
"is_variable": true,
|
||||
"quantization": {
|
||||
"min": [-32],
|
||||
"max": [32]
|
||||
}
|
||||
},
|
||||
{
|
||||
"shape": [],
|
||||
"name": "intermediate_1",
|
||||
"is_variable": true,
|
||||
"quantization": {
|
||||
"min": [-16],
|
||||
"max": [16]
|
||||
}
|
||||
},
|
||||
{
|
||||
"shape": [],
|
||||
"name": "intermediate_2",
|
||||
"is_variable": true,
|
||||
"quantization": {
|
||||
"min": [-4],
|
||||
"max": [4]
|
||||
}
|
||||
},
|
||||
{
|
||||
"shape": [],
|
||||
"name": "intermediate_3",
|
||||
"is_variable": true,
|
||||
"quantization": {
|
||||
"min": [-1.0],
|
||||
"max": [1.0]
|
||||
}
|
||||
},
|
||||
{
|
||||
"shape": [],
|
||||
"name": "intermediate_4",
|
||||
"is_variable": true,
|
||||
"quantization": {
|
||||
"min": [-0.5],
|
||||
"max": [0.5]
|
||||
}
|
||||
}
|
||||
],
|
||||
"inputs": [0],
|
||||
"outputs": [24],
|
||||
"operators": [
|
||||
{
|
||||
"inputs": [
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23
|
||||
],
|
||||
"outputs": [24],
|
||||
"intermediates": [
|
||||
25, 26, 27, 28, 29
|
||||
],
|
||||
"builtin_options_type": "UnidirectionalSequenceLSTMOptions",
|
||||
"builtin_options": {
|
||||
"fused_activation_function": "TANH",
|
||||
"cell_clip": 50.0
|
||||
},
|
||||
"mutating_variable_inputs": [
|
||||
false,
|
||||
false, false, false, false,
|
||||
false, false, false, false,
|
||||
false, false, false,
|
||||
false, false, false, false,
|
||||
true, true,
|
||||
false, false, false, false
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"buffers": [
|
||||
{
|
||||
"data": []
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
36, 167, 168, 63, 0, 140, 72, 191, 120, 20, 147, 62, 20, 152, 196, 190, 121, 98, 82, 187, 95, 128, 213, 61, 189, 3, 138, 63, 54, 103, 13, 62, 46, 224, 66, 63, 157, 204, 180, 191
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
223, 20, 21, 64, 246, 166, 31, 191, 6, 51, 157, 188, 114, 90, 167, 62, 118, 240, 59, 63, 49, 162, 255, 62, 17, 91, 160, 63, 32, 47, 26, 63, 40, 136, 178, 191, 243, 154, 236, 61
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
137, 231, 86, 63, 41, 154, 16, 63, 239, 37, 77, 191, 55, 189, 24, 189, 86, 63, 18, 63, 42, 55, 13, 191, 110, 139, 138, 191, 219, 148, 181, 63, 71, 232, 108, 191, 66, 226, 145, 191
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
245, 179, 225, 190, 51, 202, 176, 189, 132, 47, 53, 191, 155, 25, 50, 191, 197, 130, 240, 191, 98, 125, 45, 62, 243, 70, 83, 62, 85, 155, 139, 63, 113, 239, 11, 192, 35, 251, 139, 62
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
248, 188, 211, 191, 142, 11, 73, 62, 36, 8, 84, 63, 186, 208, 11, 191, 76, 208, 190, 191, 223, 200, 210, 63, 183, 170, 103, 63, 116, 129, 145, 63
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
235, 202, 222, 190, 159, 201, 112, 191, 217, 248, 166, 63, 165, 199, 131, 191, 130, 59, 47, 63, 179, 11, 186, 62, 55, 168, 18, 192, 152, 213, 26, 64
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
245, 123, 138, 62, 213, 106, 231, 59, 211, 218, 250, 62, 25, 157, 134, 63, 147, 22, 164, 63, 25, 221, 139, 62, 1, 230, 247, 62, 210, 185, 142, 63
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
197, 123, 23, 192, 45, 96, 178, 190, 174, 87, 165, 62, 213, 225, 200, 191, 119, 248, 15, 191, 128, 125, 171, 189, 90, 125, 222, 63, 4, 76, 95, 62
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
210, 73, 183, 63, 248, 177, 13, 191
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
78, 251, 212, 191, 169, 29, 147, 63
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
178, 227, 203, 191, 247, 155, 103, 63
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
206, 111, 165, 190, 153, 77, 227, 63
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
255, 114, 132, 191, 253, 202, 140, 191
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
90, 247, 1, 192, 125, 120, 209, 191
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
65, 75, 243, 191, 58, 122, 146, 190
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
40, 135, 20, 63, 109, 50, 220, 191, 56, 241, 189, 63, 65, 12, 92, 63, 61, 14, 162, 62, 157, 138, 81, 63, 125, 61, 191, 61, 102, 231, 20, 63
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
145, 79, 49, 189, 175, 235, 220, 190, 182, 111, 157, 190, 144, 236, 97, 191
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
76, 188, 109, 63, 228, 150, 201, 190
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
6, 146, 66, 191, 122, 127, 100, 191
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
216, 59, 169, 190, 161, 178, 215, 191
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
208, 144, 101, 191, 127, 233, 195, 190
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
@ -20,20 +20,20 @@ func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32
|
||||
|
||||
func @testFullyQuantizedLSTM(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, %arg1: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, %arg2: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, %arg3: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, %arg4: tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, %arg5: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, %arg6: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, %arg7: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, %arg8: tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, %arg9: tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, %arg10: tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, %arg11: tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, %arg12: tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, %arg13: tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, %arg14: tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, %arg15: tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, %arg16: tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, %arg17: tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, %arg18: tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>, %arg19: tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, %arg20: tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>> {
|
||||
%cst = constant unit
|
||||
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
|
||||
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18) ({}) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0075630000792443752:2>>, kernel_type = "FULL", proj_clip = 0.01 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8<-127:127>:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8<-127:127>:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8<-127:127>:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.601389680352559E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
|
||||
return %0 : tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
|
||||
// CHECK-LABEL: testFullyQuantizedLSTM
|
||||
// CHECK: %cst = constant unit
|
||||
// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %cst, %cst, %cst, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg19, %arg20, %arg15, %arg16, %arg17, %arg18)
|
||||
// CHECK: }) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0075630000792443752:2>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.6013896674849093E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
|
||||
// CHECK: }) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0075630000792443752:2>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0087630003690719604>>, input_to_forget_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0078849997371435165>>, input_to_input_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0049890000373125076>>, input_to_output_intermediate = tensor<0x!quant.uniform<i16:f32, 0.0057529998011887074>>, kernel_type = "FULL", proj_clip = 0.00999999977 : f32} : (tensor<1x528x!quant.uniform<i8:f32, 0.037248000502586365:-19>>, tensor<2048x528x!quant.uniform<i8:f32, 0.059801999479532242>>, tensor<2048x528x!quant.uniform<i8:f32, 0.031925998628139496>>, tensor<2048x528x!quant.uniform<i8:f32, 0.056272000074386597>>, tensor<2048x528x!quant.uniform<i8:f32, 0.063763998448848724>>, tensor<2048x640x!quant.uniform<i8:f32, 0.013358999975025654>>, tensor<2048x640x!quant.uniform<i8:f32, 0.022830000147223473>>, tensor<2048x640x!quant.uniform<i8:f32, 0.032276000827550888>>, tensor<2048x640x!quant.uniform<i8:f32, 0.035427000373601913>>, none, none, none, tensor<2048x!quant.uniform<i32:f32, 4.2675782196965883E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.0742187583900886E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.6406249869760359E-7>>, tensor<2048x!quant.uniform<i32:f32, 1.523437447303877E-7>>, tensor<640x2048x!quant.uniform<i8:f32, 0.021174000576138496>>, tensor<640x!quant.uniform<i32:f32, 1.6013896674849093E-4>>, tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>, tensor<1x2048x!quant.uniform<i16:f32, 4.8799999058246613E-4>>, tensor<2048x!quant.uniform<i16:f32, 4.3700000969693065E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.1000000085914508E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.6799999866634607E-4>>, tensor<2048x!quant.uniform<i16:f32, 1.55999994603917E-4>>) -> tensor<1x640x!quant.uniform<i8:f32, 0.09671100229024887:10>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testUnidirectionalSequenceLstmWithIntermediates
|
||||
func @testUnidirectionalSequenceLstmWithIntermediates(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8<-127:127>:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform<i8:f32, 0.0077881771139800549>>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -318,6 +318,15 @@ func @any(%arg0: tensor<2x2xi1>, %arg1: tensor<i32>) -> tensor<i1> {
|
||||
// CHECK: "tfl.reduce_any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i32>) -> tensor<i1>
|
||||
}
|
||||
|
||||
func @any_i64axes(%arg0: tensor<8x16x16xi1>, %arg1: tensor<2xi64>) -> tensor<?xi1> {
|
||||
%0 = "tf.Any"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xi1>, tensor<2xi64>) -> tensor<?xi1>
|
||||
return %0 : tensor<?xi1>
|
||||
|
||||
// CHECK-LABEL: any_i64axes
|
||||
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
|
||||
// CHECK: "tfl.reduce_any"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xi1>, tensor<2xi32>) -> tensor<?xi1>
|
||||
}
|
||||
|
||||
func @ceil(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.Ceil"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
@ -435,6 +444,16 @@ func @scatterNdHigherRankIndices(%arg0: tensor<4x2x2xi32>, %arg1: tensor<4x2x3xf
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
||||
func @scatter_nd_i64(%arg0: tensor<4x2x2xi64>, %arg1: tensor<4x2x3xf32>, %arg2: tensor<3xi64>) -> tensor<10x2x3xf32> {
|
||||
%0 = "tf.ScatterNd"(%arg0, %arg1, %arg2) : (tensor<4x2x2xi64>, tensor<4x2x3xf32>, tensor<3xi64>) -> tensor<10x2x3xf32>
|
||||
return %0 : tensor<10x2x3xf32>
|
||||
|
||||
// CHECK-LABEL:scatter_nd_i64
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.scatter_nd"
|
||||
}
|
||||
|
||||
func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x3x5x20xf32> {
|
||||
%0 = "tf.Const"() { value = dense<[1]> : tensor<1xi32> } : () -> tensor<1xi32>
|
||||
%1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x3x5x20xf32>
|
||||
@ -689,6 +708,16 @@ func @reverse_v2(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1xi32>) -> tensor<1x2
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @reverse_v2_i64(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1xi64>) -> tensor<1x2x3x4xf32> {
|
||||
%0 = "tf.ReverseV2"(%arg0, %arg1) : (tensor<1x2x3x4xf32>, tensor<1xi64>) -> tensor<1x2x3x4xf32>
|
||||
return %0 : tensor<1x2x3x4xf32>
|
||||
|
||||
// CHECK-LABEL:reverse_v2_i64
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.reverse_v2"
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
func @matrix_diag(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
|
||||
%0 = "tf.MatrixDiag"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16x16xf32>
|
||||
return %0 : tensor<8x16x16xf32>
|
||||
@ -763,13 +792,31 @@ func @matrix_diag_v3(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> {
|
||||
// CHECK: return [[VAL_6]] : tensor<8x16x16xf32>
|
||||
}
|
||||
|
||||
func @matrix_set_diag(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
|
||||
%0 = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
|
||||
return %0 : tensor<3x3xi32>
|
||||
func @matrix_set_diag_v3(%arg0: tensor<3x3xi64>, %arg1: tensor<3xi32>) -> tensor<3x3xi64> {
|
||||
%cst = constant dense<0> : tensor<i32>
|
||||
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) {align = "RIGHT_LEFT"} : (tensor<3x3xi64>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi64>
|
||||
return %0 : tensor<3x3xi64>
|
||||
|
||||
// CHECK-LABEL: func @matrix_set_diag(
|
||||
// CHECK: [[VAL_0:%.*]] = "tfl.matrix_set_diag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
|
||||
// CHECK: return [[VAL_0]]
|
||||
// CHECK-LABEL: func @matrix_set_diag_v3
|
||||
// CHECK: "tfl.matrix_set_diag"(%arg0, %arg1) : (tensor<3x3xi64>, tensor<3xi32>) -> tensor<3x3xi64>
|
||||
}
|
||||
|
||||
func @matrix_set_diag_v3_non_zero_k(%arg0: tensor<3x3xi64>, %arg1: tensor<3xi32>) -> tensor<3x3xi64> {
|
||||
%cst = constant dense<1> : tensor<i32>
|
||||
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) : (tensor<3x3xi64>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi64>
|
||||
return %0 : tensor<3x3xi64>
|
||||
|
||||
// CHECK-LABEL: @matrix_set_diag_v3_non_zero_k
|
||||
// CHECK: tf.MatrixSetDiagV3
|
||||
}
|
||||
|
||||
func @matrix_set_diag_v3_default_align(%arg0: tensor<3x3xi64>, %arg1: tensor<3xi32>) -> tensor<3x3xi64> {
|
||||
%cst = constant dense<0> : tensor<i32>
|
||||
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) : (tensor<3x3xi64>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi64>
|
||||
return %0 : tensor<3x3xi64>
|
||||
|
||||
// CHECK-LABEL: @matrix_set_diag_v3_default_align
|
||||
// CHECK: "tfl.matrix_set_diag"(%arg0, %arg1) : (tensor<3x3xi64>, tensor<3xi32>) -> tensor<3x3xi64>
|
||||
}
|
||||
|
||||
func @maximum(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
@ -934,6 +981,15 @@ func @sum_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32
|
||||
// CHECK: "tfl.sum"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
func @sum_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
|
||||
%0 = "tf.Sum"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
|
||||
// CHECK-LABEL: sum_i64axes
|
||||
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
|
||||
// CHECK: "tfl.sum"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
func @reduce_min(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.Min"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
@ -950,6 +1006,15 @@ func @reduce_min_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tenso
|
||||
// CHECK: "tfl.reduce_min"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
func @reduce_min_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
|
||||
%0 = "tf.Min"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
|
||||
// CHECK-LABEL: reduce_min_i64axes
|
||||
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
|
||||
// CHECK: "tfl.reduce_min"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
func @reduce_max(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.Max"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
@ -966,6 +1031,15 @@ func @reduce_max_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tenso
|
||||
// CHECK: "tfl.reduce_max"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
func @reduce_max_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
|
||||
%0 = "tf.Max"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
|
||||
// CHECK-LABEL: reduce_max_i64axes
|
||||
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
|
||||
// CHECK: "tfl.reduce_max"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
func @reduce_prod(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.Prod"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
@ -982,6 +1056,15 @@ func @reduce_prod_true(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi32>) -> tens
|
||||
// CHECK: "tfl.reduce_prod"(%arg0, %arg1) {keep_dims = true} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
func @reduce_prod_i64axes(%arg0: tensor<8x16x16xf32>, %arg1: tensor<2xi64>) -> tensor<?xf32> {
|
||||
%0 = "tf.Prod"(%arg0, %arg1) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi64>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
|
||||
// CHECK-LABEL: reduce_prod_i64axes
|
||||
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
|
||||
// CHECK: "tfl.reduce_prod"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xf32>, tensor<2xi32>) -> tensor<?xf32>
|
||||
}
|
||||
|
||||
func @batch_to_space_nd(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
@ -996,6 +1079,15 @@ func @batch_to_space_nd_unsupported(%arg0: tensor<?x1x1x1x4xf32>, %arg1: tensor<
|
||||
// CHECK: "tf.BatchToSpaceND"
|
||||
}
|
||||
|
||||
func @batch_to_space_nd_i64(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi64>, %arg2: tensor<2x2xi64>) -> tensor<?xf32> {
|
||||
%0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
// CHECK-LABEL: batch_to_space_nd_i64
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.batch_to_space_nd"
|
||||
}
|
||||
|
||||
func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor<*xf32> {
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
@ -1003,6 +1095,15 @@ func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2:
|
||||
// CHECK: "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<*xf32>
|
||||
}
|
||||
|
||||
func @space_to_batch_nd_i64(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi64>, %arg2: tensor<2x2xi64>) -> tensor<*xf32> {
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi64>, tensor<2x2xi64>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
// CHECK-LABEL: space_to_batch_nd_i64
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.space_to_batch_nd"
|
||||
}
|
||||
|
||||
func @split(%arg0: tensor<i32>, %arg1: tensor<1x4x3x3xf32>) -> tensor<1x4x3xf32> {
|
||||
%0:3 = "tf.Split"(%arg0, %arg1) : (tensor<i32>, tensor<1x4x3x3xf32>) -> (tensor<1x4x3xf32>, tensor<1x4x3xf32>, tensor<1x4x3xf32>)
|
||||
return %0#0 : tensor<1x4x3xf32>
|
||||
@ -1116,10 +1217,10 @@ func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1:
|
||||
%0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
|
||||
return %0 : tensor<10x10xf32>
|
||||
// CHECK-LABEL: strided_slice_with_constant_attributes
|
||||
// CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32>
|
||||
// CHECK-DAG: [[END:%cst.*]] = constant dense<[0, 10, 10]> : tensor<3xi32>
|
||||
// CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<3xi32>
|
||||
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i32, ellipsis_mask = 0 : i32, end_mask = 6 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32>
|
||||
// CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<-1> : tensor<1xi32>
|
||||
// CHECK-DAG: [[END:%cst.*]] = constant dense<0> : tensor<1xi32>
|
||||
// CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<1xi32>
|
||||
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
|
||||
}
|
||||
|
||||
func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
|
||||
@ -1129,6 +1230,39 @@ func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tenso
|
||||
// CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
|
||||
}
|
||||
|
||||
func @strided_slice_with_unranked_input_and_i64_parameters(%arg0: tensor<*xf32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>) -> tensor<*xf32> {
|
||||
%0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<*xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
// CHECK-LABEL: strided_slice_with_unranked_input_and_i64_parameters
|
||||
// CHECK-DAG: [[BEGIN:%.*]] = "tfl.cast"(%arg1) : (tensor<1xi64>) -> tensor<1xi32>
|
||||
// CHECK-DAG: [[END:%.*]] = "tfl.cast"(%arg2) : (tensor<1xi64>) -> tensor<1xi32>
|
||||
// CHECK-DAG: [[STRIDES:%.*]] = "tfl.cast"(%arg3) : (tensor<1xi64>) -> tensor<1xi32>
|
||||
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<*xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xf32>
|
||||
}
|
||||
|
||||
func @strided_slice_with_i64_parameters(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>, %arg3: tensor<1xi64>) -> tensor<1x2x2x5xf32> {
|
||||
%0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1x2x2x5xf32>
|
||||
return %0 : tensor<1x2x2x5xf32>
|
||||
// CHECK-LABEL: strided_slice_with_i64_parameters
|
||||
// CHECK-DAG: [[BEGIN:%.*]] = "tfl.cast"(%arg1) : (tensor<1xi64>) -> tensor<1xi32>
|
||||
// CHECK-DAG: [[END:%.*]] = "tfl.cast"(%arg2) : (tensor<1xi64>) -> tensor<1xi32>
|
||||
// CHECK-DAG: [[STRIDES:%.*]] = "tfl.cast"(%arg3) : (tensor<1xi64>) -> tensor<1xi32>
|
||||
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
|
||||
}
|
||||
|
||||
func @strided_slice_with_i64_constant_attributes(%arg0: tensor<10x10x10xf32>) -> tensor<10x10xf32> {
|
||||
%cst = constant dense<-1> : tensor<1xi64>
|
||||
%cst_1 = constant dense<0> : tensor<1xi64>
|
||||
%cst_2 = constant dense<1> : tensor<1xi64>
|
||||
%0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<10x10xf32>
|
||||
return %0 : tensor<10x10xf32>
|
||||
// CHECK-LABEL: strided_slice_with_i64_constant_attributes
|
||||
// CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<-1> : tensor<1xi32>
|
||||
// CHECK-DAG: [[END:%cst.*]] = constant dense<0> : tensor<1xi32>
|
||||
// CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<1xi32>
|
||||
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
|
||||
}
|
||||
|
||||
func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> {
|
||||
%0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
|
||||
return %0 : tensor<?x3x5xf32>
|
||||
@ -1361,8 +1495,7 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, %
|
||||
|
||||
// CHECK-LABEL: conv2d_backprop_input
|
||||
// CHECK: %[[CST:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32>
|
||||
// CHECK: %[[CAST:.*]] = "tfl.cast"(%[[CST]]) : (tensor<4xi32>) -> tensor<4xi32>
|
||||
// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CAST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
|
||||
// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
|
||||
// CHECK: %[[CST_0:.*]] = constant unit
|
||||
// CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
|
||||
// CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
|
||||
@ -1797,10 +1930,25 @@ func @cumsum(%arg0: tensor<3x3xf32>, %arg1: tensor<i32>) -> tensor<3x3xf32> {
|
||||
// CHECK: "tfl.cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i32>) -> tensor<3x3xf32>
|
||||
}
|
||||
|
||||
func @cumsum_invalid(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<3x3xf32> {
|
||||
func @cumsum_i64(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<3x3xf32> {
|
||||
%0 = "tf.Cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i64>) -> tensor<3x3xf32>
|
||||
return %0 : tensor<3x3xf32>
|
||||
// CHECK-LABEL: cumsum_invalid
|
||||
// CHECK-NOT: "tfl.cumsum"
|
||||
// CHECK-LABEL: cumsum_i64
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.cumsum"
|
||||
}
|
||||
|
||||
func @segmentsum(%arg0: tensor<3x3xf32>, %arg1: tensor<i32>) -> tensor<*xf32> {
|
||||
%0 = "tf.SegmentSum"(%arg0, %arg1) : (tensor<3x3xf32>, tensor<i32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
// CHECK-LABEL: segmentsum
|
||||
// CHECK: "tfl.segment_sum"(%arg0, %arg1) : (tensor<3x3xf32>, tensor<i32>) -> tensor<*xf32>
|
||||
}
|
||||
|
||||
func @segmentsum_i64(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<*xf32> {
|
||||
%0 = "tf.SegmentSum"(%arg0, %arg1) : (tensor<3x3xf32>, tensor<i64>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
// CHECK-LABEL: segmentsum_i64
|
||||
// CHECK: "tfl.cast"
|
||||
// CHECK: "tfl.segment_sum"
|
||||
}
|
||||
|
@ -625,6 +625,35 @@ func @QuantizeSharedBiases2(
|
||||
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
|
||||
}

// Make sure constants are duplicated for all users.
// CHECK-LABEL: QuantizeSharedConstantsMultipleUsers
func @QuantizeSharedConstantsMultipleUsers(
|
||||
%arg0: tensor<32x!quant.uniform<u8:f32, 1.0>>,
|
||||
%arg1: tensor<32x!quant.uniform<u8:f32, 2.0>>,
|
||||
%arg2: tensor<32x!quant.uniform<u8:f32, 3.0>>,
|
||||
%arg3: tensor<32x!quant.uniform<u8:f32, 4.0>>) -> (tensor<32xf32>, tensor<32xf32>, tensor<32xf32>, tensor<32xf32>) {
|
||||
%cst = constant dense<0.0> : tensor<32xf32>
|
||||
%0 = "tfl.dequantize"(%arg0) : (tensor<32x!quant.uniform<u8:f32, 1.0>>) -> tensor<32xf32>
|
||||
%1 = "tfl.dequantize"(%arg1) : (tensor<32x!quant.uniform<u8:f32, 2.0>>) -> tensor<32xf32>
|
||||
%2 = "tfl.dequantize"(%arg2) : (tensor<32x!quant.uniform<u8:f32, 3.0>>) -> tensor<32xf32>
|
||||
%3 = "tfl.dequantize"(%arg3) : (tensor<32x!quant.uniform<u8:f32, 4.0>>) -> tensor<32xf32>
|
||||
|
||||
%4 = "tfl.minimum"(%0, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
|
||||
%5 = "tfl.minimum"(%1, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
|
||||
%6 = "tfl.minimum"(%2, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
|
||||
%7 = "tfl.minimum"(%3, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
|
||||
return %4, %5, %6, %7 : tensor<32xf32>, tensor<32xf32>, tensor<32xf32>, tensor<32xf32>
|
||||
|
||||
// CHECK: %[[cst1:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<32xf32>
|
||||
// CHECK: %[[cst2:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 3.000000e+00>>) -> tensor<32xf32>
|
||||
// CHECK: %[[cst3:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 4.000000e+00>>) -> tensor<32xf32>
|
||||
// CHECK: %[[cst4:.*]] = "tfl.dequantize"(%{{.*}}) : (tensor<32x!quant.uniform<u8:f32, 1.000000e+00>>) -> tensor<32xf32>
|
||||
// CHECK: "tfl.minimum"(%{{.*}}, %[[cst4]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
|
||||
// CHECK: "tfl.minimum"(%{{.*}}, %[[cst1]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
|
||||
// CHECK: "tfl.minimum"(%{{.*}}, %[[cst2]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
|
||||
// CHECK: "tfl.minimum"(%{{.*}}, %[[cst3]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32>
|
||||
}
|
||||
|
||||
// Make sure quantization parameters are scanned from weight, but not from bias.
|
||||
// CHECK-LABEL: QuantizeWeight
|
||||
func @QuantizeWeight(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> {
|
||||
|
@ -551,35 +551,17 @@ func @StridedSliceRewriteMasks(%arg0: tensor<8x4x16x2xf32>) -> tensor<8x4x16x1xf
|
||||
return %0 : tensor<8x4x16x1xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @MatrixSetDiagV2Conversion
|
||||
func @MatrixSetDiagV2Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
|
||||
%cst = constant dense<0> : tensor<i32>
|
||||
%0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
|
||||
return %0 : tensor<3x3xi32>
|
||||
|
||||
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @MatrixSetDiagV2NonZeroK
|
||||
func @MatrixSetDiagV2NonZeroK(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
|
||||
%cst = constant dense<1> : tensor<i32>
|
||||
%0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
|
||||
return %0 : tensor<3x3xi32>
|
||||
|
||||
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<i32>
|
||||
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiagV2"(%arg0, %arg1, %[[CST]]) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @MatrixSetDiagV3Conversion
|
||||
func @MatrixSetDiagV3Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> {
|
||||
%cst = constant dense<0> : tensor<i32>
|
||||
%0 = "tf.MatrixSetDiagV3"(%arg0, %arg1, %cst) : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>
|
||||
return %0 : tensor<3x3xi32>
|
||||
|
||||
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
|
||||
// CHECK: return %[[RES]]
|
||||
func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<10x10xf32> {
|
||||
%cst = constant dense<-1> : tensor<1xi32>
|
||||
%cst_1 = constant dense<0> : tensor<1xi32>
|
||||
%cst_2 = constant dense<1> : tensor<1xi32>
|
||||
%0 = "tf.StridedSlice"(%arg0, %cst, %cst_1, %cst_2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<10x10xf32>
|
||||
return %0 : tensor<10x10xf32>
|
||||
// CHECK-LABEL: strided_slice_with_constant_attributes
|
||||
// CHECK-DAG: [[BEGIN:%cst.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32>
|
||||
// CHECK-DAG: [[END:%cst.*]] = constant dense<[0, 10, 10]> : tensor<3xi32>
|
||||
// CHECK-DAG: [[STRIDES:%cst.*]] = constant dense<1> : tensor<3xi32>
|
||||
// CHECK-NEXT: "tf.StridedSlice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32>
|
||||
}
|
||||
|
||||
func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
|
||||
|
@ -242,7 +242,8 @@ int main(int argc, char **argv) {
|
||||
std::string result;
|
||||
auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
module.ValueOrDie().get(), output_mlir, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, quant_specs, tags, &result, &pm);
|
||||
emit_select_tf_ops, emit_custom_ops,
|
||||
/*select_user_tf_ops=*/{}, quant_specs, tags, &result, &pm);
|
||||
if (!status.ok()) return kTrFailure;
|
||||
|
||||
std::string error_msg;
|
||||
|
@ -137,6 +137,7 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
|
||||
Status ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
|
||||
bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& select_user_tf_ops,
|
||||
const mlir::TFL::QuantizationSpecs& quant_specs,
|
||||
const std::unordered_set<std::string>& saved_model_tags,
|
||||
std::string* result, mlir::PassManager* pass_manager) {
|
||||
@ -169,10 +170,12 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
}
|
||||
|
||||
// Write MLIR TFLite dialect into FlatBuffer
|
||||
OpOrArgLocNameMapper op_or_arg_name_mapper;
|
||||
if (!quant_specs.RunWeightQuantization()) {
|
||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
||||
module, result, emit_builtin_tflite_ops, emit_select_tf_ops,
|
||||
emit_custom_ops, saved_model_tags)) {
|
||||
emit_custom_ops, select_user_tf_ops, saved_model_tags,
|
||||
&op_or_arg_name_mapper)) {
|
||||
return statusHandler.ConsumeStatus();
|
||||
}
|
||||
} else {
|
||||
@ -181,7 +184,8 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
std::string pre_quantized_result;
|
||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
||||
module, &pre_quantized_result, emit_builtin_tflite_ops,
|
||||
emit_select_tf_ops, emit_custom_ops, saved_model_tags)) {
|
||||
emit_select_tf_ops, emit_custom_ops, select_user_tf_ops,
|
||||
saved_model_tags, &op_or_arg_name_mapper)) {
|
||||
return statusHandler.ConsumeStatus();
|
||||
}
|
||||
flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
|
||||
|
@ -63,6 +63,7 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||
Status ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops,
|
||||
bool emit_select_tf_ops, bool emit_custom_ops,
|
||||
const std::unordered_set<std::string>& select_user_tf_ops,
|
||||
const mlir::TFL::QuantizationSpecs& quant_specs,
|
||||
const std::unordered_set<std::string>& saved_model_tags,
|
||||
std::string* result, mlir::PassManager* pass_manager);
|
||||
|
@ -54,7 +54,7 @@ def ExtractSingleElementAsInt32 : NativeCodeCall<
|
||||
"$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger($_self.cast<ElementsAttr>()).getInt())">;
|
||||
|
||||
// Converts tensor with int64 to int32.
|
||||
def CreateTFLCastToInt32Op : NativeCodeCall<
|
||||
def CreateTFCastToInt32Op : NativeCodeCall<
|
||||
"CreateCastToInt32($0, $_loc, $_builder)">;
|
||||
|
||||
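Illustration only (not part of this change): the helper named here builds a tf.Cast with Truncate = false around its operand, so a pattern that wraps an i64 operand in CreateTFCastToInt32Op effectively inserts something like

%axes_i32 = "tf.Cast"(%axes) {Truncate = false} : (tensor<2xi64>) -> tensor<2xi32>

before the TFL op is emitted; for constant operands the cast is expected to fold away (see the C++ helper further down in this change).
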
// Checks whether the given operation has static shapes and same shapes of all inputs.
|
||||
@ -193,8 +193,8 @@ def LegalizeRound : Pat<(TF_RoundOp $arg), (TFL_RoundOp $arg)>;
|
||||
def LegalizeRsqrt : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>;
|
||||
def LegalizeSqrt : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>;
|
||||
def LegalizeSquare : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>;
|
||||
def LegalizeSegmentSum : Pat<(TF_SegmentSumOp $data, I32Tensor:$segment_ids),
|
||||
(TFL_SegmentSumOp $data, $segment_ids)>;
|
||||
def LegalizeSegmentSum : Pat<(TF_SegmentSumOp $data, $segment_ids),
|
||||
(TFL_SegmentSumOp $data, (CreateTFCastToInt32Op $segment_ids))>;
|
||||
def LegalizeSelect : Pat<(TF_SelectOp $cond, $x, $y),
|
||||
(TFL_SelectOp $cond, $x, $y)>;
|
||||
def LegalizeSelectV2SameStaticShape : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y),
|
||||
@ -221,7 +221,7 @@ def LegalizeTanh : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>;
|
||||
|
||||
def LegalizeTranspose : Pat<(TF_TransposeOp $arg, $perm),
|
||||
(TFL_TransposeOp $arg,
|
||||
(CreateTFLCastToInt32Op $perm))>;
|
||||
(CreateTFCastToInt32Op $perm))>;
|
||||
|
||||
def LegalizeWhere : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>;
|
||||
def LegalizeZerosLike : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
|
||||
@ -309,8 +309,9 @@ def LegalizeRank : Pat<(TF_RankOp $input), (TFL_RankOp $input)>;
|
||||
def LegalizeSquaredDifference : Pat<(TF_SquaredDifferenceOp $l, $r),
|
||||
(TFL_SquaredDifferenceOp $l, $r)>;
|
||||
|
||||
def LegalizeReverseV2 : Pat<(TF_ReverseV2Op $arg0, $arg1),
|
||||
(TFL_ReverseV2Op $arg0, $arg1)>;
|
||||
def LegalizeReverseV2 : Pat<
|
||||
(TF_ReverseV2Op $arg0, $axis),
|
||||
(TFL_ReverseV2Op $arg0, (CreateTFCastToInt32Op $axis))>;
|
||||
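
A hypothetical check in the style of the legalize-tf tests earlier in this change (the function name and shapes are made up, not part of the patch) would exercise the axis cast like so:

func @reverse_v2_i64axes(%arg0: tensor<3x4xf32>, %arg1: tensor<1xi64>) -> tensor<3x4xf32> {
  %0 = "tf.ReverseV2"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<1xi64>) -> tensor<3x4xf32>
  return %0 : tensor<3x4xf32>

// CHECK-LABEL: reverse_v2_i64axes
// CHECK: "tfl.cast"(%arg1) : (tensor<1xi64>) -> tensor<1xi32>
// CHECK: "tfl.reverse_v2"
}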
|
||||
def LegalizeEqual : Pat<(TF_EqualOp $arg0, $arg1,
|
||||
/*incompatible_shape_error=*/ConstBoolAttrTrue),
|
||||
@ -327,33 +328,40 @@ def LegalizeMean : Pat<(TF_MeanOp $arg0, $arg1, BoolAttr:$arg2),
|
||||
(TFL_MeanOp $arg0, $arg1, $arg2)>;
|
||||
|
||||
def LegalizeSum : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2),
|
||||
(TFL_SumOp $arg, $axes, $arg2)>;
|
||||
(TFL_SumOp $arg, (CreateTFCastToInt32Op $axes), $arg2)>;
|
||||
|
||||
// TopK in TFL is always sorted so we ignore that attribute here.
|
||||
def LegalizeTopKV2 : Pat<(TF_TopKV2Op $input, $k, $ignored_sorted),
|
||||
(TFL_TopKV2Op $input, $k)>;
|
||||
|
||||
def LegalizeMin : Pat<(TF_MinOp $arg0, $arg1, BoolAttr:$arg2),
|
||||
(TFL_ReduceMinOp $arg0, $arg1, $arg2)>;
|
||||
def LegalizeMin : Pat<
|
||||
(TF_MinOp $arg0, $axes, BoolAttr:$arg2),
|
||||
(TFL_ReduceMinOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>;
|
||||
|
||||
def LegalizeMax : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2),
|
||||
(TFL_ReduceMaxOp $arg0, $arg1, $arg2)>;
|
||||
def LegalizeMax : Pat<
|
||||
(TF_MaxOp $arg0, $axes, BoolAttr:$arg2),
|
||||
(TFL_ReduceMaxOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>;
|
||||
|
||||
def LegalizeProd : Pat<(TF_ProdOp $arg0, $arg1, BoolAttr:$arg2),
|
||||
(TFL_ReduceProdOp $arg0, $arg1, $arg2)>;
|
||||
def LegalizeProd : Pat<
|
||||
(TF_ProdOp $arg0, $axes, BoolAttr:$arg2),
|
||||
(TFL_ReduceProdOp $arg0, (CreateTFCastToInt32Op $axes), $arg2)>;
|
||||
|
||||
def LegalizeAny : Pat<(TF_AnyOp $input, $reduction_indices, $keep_dims),
|
||||
(TFL_ReduceAnyOp $input, $reduction_indices, $keep_dims)>;
|
||||
def LegalizeAny : Pat<
|
||||
(TF_AnyOp $input, $reduction_indices, $keep_dims),
|
||||
(TFL_ReduceAnyOp $input, (CreateTFCastToInt32Op $reduction_indices),
|
||||
$keep_dims)>;
|
||||
|
||||
def LegalizeCast : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>;
|
||||
|
||||
def LegalizeBatchToSpaceND : Pat<
|
||||
(TF_BatchToSpaceNDOp $input, $block_shape, $crops),
|
||||
(TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>;
|
||||
(TFL_BatchToSpaceNdOp $input, (CreateTFCastToInt32Op $block_shape),
|
||||
(CreateTFCastToInt32Op $crops))>;
|
||||
|
||||
def LegalizeSpaceToBatchND : Pat<
|
||||
(TF_SpaceToBatchNDOp $input, $block_shape, $paddings),
|
||||
(TFL_SpaceToBatchNdOp $input, $block_shape, $paddings)>;
|
||||
(TFL_SpaceToBatchNdOp $input, (CreateTFCastToInt32Op $block_shape),
|
||||
(CreateTFCastToInt32Op $paddings))>;
|
||||
|
||||
def LegalizeSpaceToDepth : Pat<
|
||||
(TF_SpaceToDepthOp $input, $block_size, IsDataFormatNHWC:$data_format),
|
||||
@ -437,14 +445,45 @@ def LegalizeConv2DBackpropInput : Pat<
|
||||
/*stride_h=*/ ExtractI32At<1>:$strides,
|
||||
/*stride_w=*/ ExtractI32At<2>:$strides)>;
|
||||
|
||||
def IsRankZeroAttr
|
||||
: CPred<"$_self.cast<DenseElementsAttr>().getType().getRank() == 0">;
|
||||
|
||||
def HasValueZero
|
||||
: CPred<"$_self.cast<DenseElementsAttr>().getSplatValue()."
|
||||
"cast<::mlir::IntegerAttr>().getInt() == 0">;
|
||||
|
||||
// TFLite only supports MatrixSetDiag ops with scalar zero k attribute.
|
||||
def IsSupportedByTFLiteMatrixSetDiag
|
||||
: ElementsAttrBase<And<[ElementsAttr.predicate,
|
||||
IsRankZeroAttr, HasValueZero]>,
|
||||
"MatrixSetDiag attribute verification">;
|
||||
|
||||
// Attribute align doesn't matter when k is zero.
|
||||
def LegalizeMatrixSetDiag : Pat<
|
||||
(TF_MatrixSetDiagOp $input, $diagonal),
|
||||
(TF_MatrixSetDiagV3Op $input, $diagonal,
|
||||
(ConstantLikeMatcher IsSupportedByTFLiteMatrixSetDiag:$k), $align),
|
||||
(TFL_MatrixSetDiagOp $input, $diagonal)>;
|
||||
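
Sketch of the intended match (illustrative, not part of the patch; the align value shown is just the op's default): only a constant scalar zero k is accepted, e.g.

%k = constant dense<0> : tensor<i32>
%0 = "tf.MatrixSetDiagV3"(%input, %diag, %k) {align = "RIGHT_LEFT"} : (tensor<3x3xi32>, tensor<3xi32>, tensor<i32>) -> tensor<3x3xi32>

which legalizes to "tfl.matrix_set_diag"(%input, %diag); a non-zero or non-scalar k does not match and is left for other handling.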
|
||||
def LegalizeScatterNd : Pat<
|
||||
(TF_ScatterNdOp I32Tensor:$indices, $updates, $shape),
|
||||
(TFL_ScatterNdOp I32Tensor:$indices, $updates, $shape)>;
|
||||
(TF_ScatterNdOp $indices, $updates, $shape),
|
||||
(TFL_ScatterNdOp (CreateTFCastToInt32Op $indices), $updates,
|
||||
(CreateTFCastToInt32Op $shape))>;
|
||||
|
||||
def LegalizeCumsum : Pat<
|
||||
(TF_CumsumOp $input, $axis, $exclusive, $reverse),
|
||||
(TFL_CumsumOp $input, $axis, $exclusive, $reverse)>;
|
||||
(TFL_CumsumOp $input, (CreateTFCastToInt32Op $axis), $exclusive, $reverse)>;
|
||||
|
||||
def LegalizeReshape : Pat<
|
||||
(TF_ReshapeOp $input, $shape),
|
||||
(TFL_ReshapeOp $input, (CreateTFCastToInt32Op $shape))>;
|
||||
|
||||
def LegalizeStridedSlice : Pat<
|
||||
(TF_StridedSliceOp
|
||||
$input, $begin, $end, $strides, $begin_mask, $end_mask, $ellipsis_mask,
|
||||
$new_axis_mask, $shrink_axis_mask),
|
||||
(TFL_StridedSliceOp $input,
|
||||
(CreateTFCastToInt32Op $begin), (CreateTFCastToInt32Op $end),
|
||||
(CreateTFCastToInt32Op $strides), (convertIntAttrTo32Bit $begin_mask),
|
||||
(convertIntAttrTo32Bit $end_mask), (convertIntAttrTo32Bit $ellipsis_mask),
|
||||
(convertIntAttrTo32Bit $new_axis_mask),
|
||||
(convertIntAttrTo32Bit $shrink_axis_mask))>;
|
||||
|
@ -123,7 +123,8 @@ Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) {
|
||||
auto shape = val.getType().dyn_cast<RankedTensorType>().getShape();
|
||||
IntegerType new_ele_type = rewriter.getIntegerType(32);
|
||||
ShapedType new_type = RankedTensorType::get(shape, new_ele_type);
|
||||
return rewriter.create<TFL::CastOp>(loc, new_type, val);
|
||||
return rewriter.createOrFold<TF::CastOp>(loc, new_type, val,
|
||||
rewriter.getBoolAttr(false));
|
||||
}
|
||||
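
A note on the switch to createOrFold with tf.Cast (illustrative, not part of the patch): TF_CastOp declares a folder, so when the incoming value is a constant the cast can fold on the spot and downstream patterns see an i32 constant instead of a cast, roughly

%cst = constant dense<-1> : tensor<1xi64>
%0 = "tf.Cast"(%cst) {Truncate = false} : (tensor<1xi64>) -> tensor<1xi32>
// can fold to
%cst_i32 = constant dense<-1> : tensor<1xi32>

which is what the strided_slice_with_i64_constant_attributes test above expects: i32 constants feeding tfl.strided_slice with no cast in between.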
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc"
|
||||
@ -145,10 +146,8 @@ DECL_CONVERT_OP(MatMul);
|
||||
DECL_CONVERT_OP(MatrixDiagV2);
|
||||
DECL_CONVERT_OP(MatrixDiagV3);
|
||||
DECL_CONVERT_OP(Pack);
|
||||
DECL_CONVERT_OP(Reshape);
|
||||
DECL_CONVERT_OP(Split);
|
||||
DECL_CONVERT_OP(SplitV);
|
||||
DECL_CONVERT_OP(StridedSlice);
|
||||
DECL_CONVERT_OP(Unpack);
|
||||
DECL_CONVERT_OP(RandomUniform);
|
||||
|
||||
@ -299,30 +298,6 @@ LogicalResult ConvertTFPackOp::matchAndRewrite(
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertTFReshapeOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_reshape_op = cast<TF::ReshapeOp>(op);
|
||||
|
||||
auto input = tf_reshape_op.tensor();
|
||||
auto shape = tf_reshape_op.shape();
|
||||
|
||||
ShapedType shape_type = shape.getType().cast<ShapedType>();
|
||||
// The tfl reshape's #2 operand needs to be an i32 tensor type, so we have to cast.
|
||||
if (!shape_type.getElementType().isSignlessInteger(32)) {
|
||||
auto new_shape = shape_type.getShape();
|
||||
IntegerType new_ele_type = rewriter.getIntegerType(32);
|
||||
ShapedType new_type = RankedTensorType::get(new_shape, new_ele_type);
|
||||
// Uses TF::CastOp to be folded if the shape input is a constant.
|
||||
shape = rewriter
|
||||
.create<TF::CastOp>(op->getLoc(), new_type, shape,
|
||||
rewriter.getBoolAttr(false))
|
||||
.y();
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output().getType(),
|
||||
input, shape);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertTFSplitOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_split_op = cast<TF::SplitOp>(op);
|
||||
@ -349,81 +324,6 @@ LogicalResult ConvertTFSplitVOp::matchAndRewrite(
|
||||
return success();
|
||||
}
|
||||
|
||||
Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
|
||||
Value attribute,
|
||||
ArrayRef<int32_t> padding_val, int* mask) {
|
||||
DenseIntElementsAttr dense_elem_attr;
|
||||
SmallVector<int32_t, 8> padded_val;
|
||||
|
||||
auto ranked_attr_type = attribute.getType().dyn_cast<RankedTensorType>();
|
||||
if (!ranked_attr_type ||
|
||||
!matchPattern(attribute, m_Constant(&dense_elem_attr))) {
|
||||
// If the input attribute is neither ranked type nor constant, we
|
||||
// can't do any padding. Instead we just return it.
|
||||
return attribute;
|
||||
}
|
||||
for (const auto& idx : dense_elem_attr.getIntValues()) {
|
||||
padded_val.push_back(idx.getSExtValue());
|
||||
}
|
||||
auto attr_dim_count = ranked_attr_type.getShape()[0];
|
||||
int full_dim_count = padding_val.size();
|
||||
for (int i = attr_dim_count; i < full_dim_count; ++i) {
|
||||
padded_val.push_back(padding_val[i]);
|
||||
if (mask) *mask |= 1 << i;
|
||||
}
|
||||
auto type =
|
||||
RankedTensorType::get({full_dim_count}, rewriter.getIntegerType(32));
|
||||
auto attr = DenseElementsAttr::get<int32_t>(type, padded_val);
|
||||
return rewriter.create<ConstantOp>(op->getLoc(), type, attr);
|
||||
}
|
||||
|
||||
LogicalResult ConvertTFStridedSliceOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_strided_slice_op = cast<TF::StridedSliceOp>(op);
|
||||
auto ranked_input_type =
|
||||
tf_strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
|
||||
if (!ranked_input_type) {
|
||||
// If input is not a ranked tensor, we can't deduce the padding dimensions
|
||||
// from it, so we just do a plain conversion here.
|
||||
rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
|
||||
op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
|
||||
tf_strided_slice_op.begin(), tf_strided_slice_op.end(),
|
||||
tf_strided_slice_op.strides(),
|
||||
rewriter.getI32IntegerAttr(tf_strided_slice_op.begin_mask()),
|
||||
rewriter.getI32IntegerAttr(tf_strided_slice_op.end_mask()),
|
||||
rewriter.getI32IntegerAttr(tf_strided_slice_op.ellipsis_mask()),
|
||||
rewriter.getI32IntegerAttr(tf_strided_slice_op.new_axis_mask()),
|
||||
rewriter.getI32IntegerAttr(tf_strided_slice_op.shrink_axis_mask()));
|
||||
return success();
|
||||
}
|
||||
|
||||
int num_input_dims = ranked_input_type.getRank();
|
||||
// Pad `begin` array with zero values and update the `begin_mask`.
|
||||
SmallVector<int32_t, 8> begin_pad_val(num_input_dims, 0);
|
||||
int begin_mask = tf_strided_slice_op.begin_mask();
|
||||
Value padded_begin = PadStridedSliceAttributeArray(
|
||||
op, rewriter, tf_strided_slice_op.begin(), begin_pad_val, &begin_mask);
|
||||
// Pad `end` array with `input_shape` and update the `end_mask`.
|
||||
int end_mask = tf_strided_slice_op.end_mask();
|
||||
auto input_shape = ranked_input_type.getShape();
|
||||
SmallVector<int32_t, 8> end_pad_val(input_shape.begin(), input_shape.end());
|
||||
Value padded_end = PadStridedSliceAttributeArray(
|
||||
op, rewriter, tf_strided_slice_op.end(), end_pad_val, &end_mask);
|
||||
// Pad `strides` array with ones.
|
||||
SmallVector<int32_t, 8> strides_pad_val(num_input_dims, 1);
|
||||
Value padded_strides = PadStridedSliceAttributeArray(
|
||||
op, rewriter, tf_strided_slice_op.strides(), strides_pad_val, nullptr);
|
||||
rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
|
||||
op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
|
||||
padded_begin, padded_end, padded_strides,
|
||||
rewriter.getI32IntegerAttr(begin_mask),
|
||||
rewriter.getI32IntegerAttr(end_mask),
|
||||
rewriter.getI32IntegerAttr(tf_strided_slice_op.ellipsis_mask()),
|
||||
rewriter.getI32IntegerAttr(tf_strided_slice_op.new_axis_mask()),
|
||||
rewriter.getI32IntegerAttr(tf_strided_slice_op.shrink_axis_mask()));
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertTFUnpackOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_unpack_op = cast<TF::UnpackOp>(op);
|
||||
@ -792,10 +692,9 @@ void addPatterns(MLIRContext* context, OwningRewritePatternList& patterns) {
|
||||
populateWithGenerated(context, patterns);
|
||||
patterns
|
||||
.insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
|
||||
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp,
|
||||
ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp,
|
||||
ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFRandomUniformOp>(
|
||||
context);
|
||||
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFSplitOp,
|
||||
ConvertTFSplitVOp, ConvertTFUnpackOp, ConvertTFAssertOp,
|
||||
ConvertTFRandomUniformOp>(context);
|
||||
|
||||
// Ophint python converter converted tf node pattern.
|
||||
patterns.insert<LegalizeUnidirectionalSequenceLstm,
|
||||
|
@ -376,14 +376,17 @@ void PrepareQuantizePass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
bool is_signed = quant_specs_.IsSignedInferenceType();
|
||||
int bit_width = quant_specs_.GetQuantizationTypeWidth();
|
||||
bool quantization_aware_training_mode = ContainsQuantizeOps(func);
|
||||
// Enforce fixed output range for post-training quantization and
|
||||
// when the model has quantization emulation ops, unless it was disabled
|
||||
// explicitly by the flag.
|
||||
bool enforced_output_range =
|
||||
(quant_specs_.post_training_quantization ||
|
||||
quantization_aware_training_mode) &&
|
||||
!quant_specs_.disable_enforced_fixed_output_range;
|
||||
// When this is true, the quantizer will try its best to extract the
|
||||
// quantization parameters from the op quantization property and constant
|
||||
// content. This is also set to true when the `quantize_allowlist` and
|
||||
// `quantize_signed` test flags are enabled.
|
||||
bool eager_quantize = ContainsQuantizeOps(func) ||
|
||||
(!quantize_allowlist.empty() || quantize_signed);
|
||||
// Infer the tensor range for the activation ops and weight constants unless
|
||||
// it is disabled explicitly.
|
||||
bool infer_tensor_range =
|
||||
(quant_specs_.post_training_quantization || eager_quantize) &&
|
||||
!quant_specs_.disable_infer_tensor_range;
|
||||
if (is_signed) {
|
||||
patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
|
||||
// Convert quant stats to int8 quantization parameters.
|
||||
@ -403,7 +406,7 @@ void PrepareQuantizePass::runOnFunction() {
|
||||
// values (tensors).
|
||||
ApplyQuantizationParamsPropagation(
|
||||
func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
|
||||
GetOpQuantSpec, enforced_output_range);
|
||||
GetOpQuantSpec, infer_tensor_range);
|
||||
|
||||
ConvertMlirQuantOpsToTFLQuantOps(func);
|
||||
}
|
||||
|
@ -519,9 +519,10 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
||||
explicit ConvertTFStridedSlice(MLIRContext *context)
|
||||
: RewritePattern(TF::StridedSliceOp::getOperationName(), 2, context) {}
|
||||
|
||||
LogicalResult RewriteNewAxisMask(Operation *op, uint64_t new_axis_mask,
|
||||
LogicalResult RewriteNewAxisMask(Operation *op,
|
||||
PatternRewriter &rewriter) const {
|
||||
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
|
||||
uint64_t new_axis_mask = strided_slice_op.new_axis_mask();
|
||||
|
||||
// Insert a new reshape op.
|
||||
Value original_input = strided_slice_op.input();
|
||||
@ -529,51 +530,51 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
||||
original_input.getType().cast<RankedTensorType>();
|
||||
const ArrayRef<int64_t> &original_input_shape =
|
||||
original_input_type.getShape();
|
||||
SmallVector<int64_t, 4> new_shape;
|
||||
SmallVector<int64_t, 4> revised_shape;
|
||||
int index = 0;
|
||||
const int original_input_rank = original_input_shape.size();
|
||||
while (index < original_input_rank || new_axis_mask) {
|
||||
if (new_axis_mask & 1) {
|
||||
new_shape.emplace_back(1);
|
||||
revised_shape.emplace_back(1);
|
||||
} else {
|
||||
new_shape.emplace_back(original_input_shape[index++]);
|
||||
revised_shape.emplace_back(original_input_shape[index++]);
|
||||
}
|
||||
new_axis_mask >>= 1;
|
||||
}
|
||||
|
||||
if (failed(TF::VerifyShapeOfReshapeOp(new_shape))) return failure();
|
||||
if (failed(TF::VerifyShapeOfReshapeOp(revised_shape))) return failure();
|
||||
|
||||
const int dim_size = new_shape.size();
|
||||
const int dim_size = revised_shape.size();
|
||||
Location loc = strided_slice_op.getLoc();
|
||||
auto shape_type =
|
||||
RankedTensorType::get({dim_size}, rewriter.getIntegerType(32));
|
||||
SmallVector<Attribute, 4> result_shape_data(dim_size);
|
||||
for (int i = 0; i < dim_size; ++i) {
|
||||
result_shape_data[i] =
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(new_shape[i]));
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(revised_shape[i]));
|
||||
}
|
||||
|
||||
auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data);
|
||||
auto shape = rewriter.create<ConstantOp>(loc, shape_type, shape_attr);
|
||||
auto new_output_type =
|
||||
RankedTensorType::get(new_shape, original_input_type.getElementType());
|
||||
auto revised_output_type = RankedTensorType::get(
|
||||
revised_shape, original_input_type.getElementType());
|
||||
TF::ReshapeOp reshape = rewriter.create<TF::ReshapeOp>(
|
||||
loc, new_output_type, original_input, shape);
|
||||
loc, revised_output_type, original_input, shape);
|
||||
|
||||
// Replace the original strided_slice.
|
||||
uint64_t new_begin_mask = strided_slice_op.begin_mask();
|
||||
uint64_t new_end_mask = strided_slice_op.end_mask();
|
||||
uint64_t revised_begin_mask = strided_slice_op.begin_mask();
|
||||
uint64_t revised_end_mask = strided_slice_op.end_mask();
|
||||
// Since we expand the dims, we need to apply them to the begin_mask &
|
||||
// end_mask.
|
||||
new_begin_mask |= strided_slice_op.new_axis_mask();
|
||||
new_end_mask |= strided_slice_op.new_axis_mask();
|
||||
revised_begin_mask |= strided_slice_op.new_axis_mask();
|
||||
revised_end_mask |= strided_slice_op.new_axis_mask();
|
||||
|
||||
auto attribute_type = rewriter.getIntegerType(64);
|
||||
rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
|
||||
op, strided_slice_op.getType(), reshape, strided_slice_op.begin(),
|
||||
strided_slice_op.end(), strided_slice_op.strides(),
|
||||
rewriter.getIntegerAttr(attribute_type, new_begin_mask),
|
||||
rewriter.getIntegerAttr(attribute_type, new_end_mask),
|
||||
rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
|
||||
rewriter.getIntegerAttr(attribute_type, revised_end_mask),
|
||||
rewriter.getIntegerAttr(attribute_type,
|
||||
strided_slice_op.ellipsis_mask()),
|
||||
rewriter.getI64IntegerAttr(0),
|
||||
@ -582,10 +583,11 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
||||
return success();
|
||||
}
|
||||
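
Worked example of the mask bookkeeping (hypothetical values, not from the patch): for an input of type tensor<8x16xf32> with new_axis_mask = 2 and begin_mask = end_mask = 0, the input is first reshaped to tensor<8x1x16xf32>, and the strided slice is re-created on the reshaped value with new_axis_mask = 0 and begin_mask = end_mask = 2 (the new-axis bit is ORed into both masks); the begin/end/strides operands are passed through unchanged.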
|
||||
LogicalResult RewriteEllipsisMask(Operation *op, uint64_t ellipsis_mask,
|
||||
LogicalResult RewriteEllipsisMask(Operation *op,
|
||||
PatternRewriter &rewriter) const {
|
||||
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
|
||||
|
||||
uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask();
|
||||
uint64_t shrink_axis_mask = strided_slice_op.shrink_axis_mask();
|
||||
|
||||
// Enforce operator precedence.
|
||||
@ -632,8 +634,8 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
||||
|
||||
int64_t begin_mask = strided_slice_op.begin_mask();
|
||||
int64_t end_mask = strided_slice_op.end_mask();
|
||||
int64_t new_begin_mask = 0;
|
||||
int64_t new_end_mask = 0;
|
||||
int64_t revised_begin_mask = 0;
|
||||
int64_t revised_end_mask = 0;
|
||||
int64_t revised_shrink_axis_mask = 0;
|
||||
|
||||
SmallVector<int32_t, 4> padded_begin;
|
||||
@ -647,8 +649,8 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
||||
padded_begin.push_back(begin_dense_elem_attr.getValue<int32_t>(index));
|
||||
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
|
||||
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
|
||||
if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
|
||||
if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index);
|
||||
if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
|
||||
if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
|
||||
if ((shrink_axis_mask >> index) & 1)
|
||||
revised_shrink_axis_mask |= (1 << new_index);
|
||||
++index;
|
||||
@ -657,8 +659,8 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
||||
|
||||
// Ellipsis.
|
||||
for (; new_index < index + ellipsis_filled_dim_size; ++new_index) {
|
||||
new_begin_mask |= (1 << new_index);
|
||||
new_end_mask |= (1 << new_index);
|
||||
revised_begin_mask |= (1 << new_index);
|
||||
revised_end_mask |= (1 << new_index);
|
||||
|
||||
// Mimic the begin/end/strides mask behavior.
|
||||
padded_begin.push_back(0);
|
||||
@ -675,8 +677,8 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
||||
padded_end.push_back(end_dense_elem_attr.getValue<int32_t>(index));
|
||||
padded_stride.push_back(stride_dense_elem_attr.getValue<int32_t>(index));
|
||||
|
||||
if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index);
|
||||
if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index);
|
||||
if ((begin_mask >> index) & 1) revised_begin_mask |= (1 << new_index);
|
||||
if ((end_mask >> index) & 1) revised_end_mask |= (1 << new_index);
|
||||
if ((shrink_axis_mask >> index) & 1)
|
||||
revised_shrink_axis_mask |= (1 << new_index);
|
||||
|
||||
@ -701,35 +703,135 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
||||
rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
|
||||
op, strided_slice_op.getType(), input, begin_op.getResult(),
|
||||
end_op.getResult(), stride_op.getResult(),
|
||||
rewriter.getIntegerAttr(attribute_type, new_begin_mask),
|
||||
rewriter.getIntegerAttr(attribute_type, new_end_mask),
|
||||
/*ellipsis_maks=*/rewriter.getI64IntegerAttr(0),
|
||||
rewriter.getIntegerAttr(attribute_type, revised_begin_mask),
|
||||
rewriter.getIntegerAttr(attribute_type, revised_end_mask),
|
||||
/*ellipsis_mask=*/rewriter.getI64IntegerAttr(0),
|
||||
rewriter.getIntegerAttr(attribute_type,
|
||||
strided_slice_op.new_axis_mask()),
|
||||
rewriter.getIntegerAttr(attribute_type, revised_shrink_axis_mask));
|
||||
return success();
|
||||
}
|
||||
|
||||
void PadStridedSliceAttributeArray(DenseIntElementsAttr dense_elem_attr,
|
||||
SmallVectorImpl<int32_t> &val,
|
||||
SmallVectorImpl<int32_t> &padded_val,
|
||||
ArrayRef<int32_t> padding_val,
|
||||
int *mask) const {
|
||||
for (const auto &idx : dense_elem_attr.getIntValues()) {
|
||||
val.push_back(idx.getSExtValue());
|
||||
padded_val.push_back(idx.getSExtValue());
|
||||
}
|
||||
int attr_dim_count = val.size();
|
||||
int full_dim_count = padding_val.size();
|
||||
for (int i = attr_dim_count; i < full_dim_count; ++i) {
|
||||
padded_val.push_back(padding_val[i]);
|
||||
if (mask) *mask |= 1 << i;
|
||||
}
|
||||
}
|
||||
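
Worked example (values match the strided_slice_with_constant_attributes test earlier in this change): for a tensor<10x10x10xf32> input with begin = [-1], end = [0], strides = [1] and begin_mask = end_mask = 0, the arrays are padded to begin = [-1, 0, 0], end = [0, 10, 10], strides = [1, 1, 1], and begin_mask/end_mask both become 6, i.e. bits 1 and 2 are set for the dimensions that were filled in.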
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
TF::StridedSliceOp strided_slice_op = llvm::cast<TF::StridedSliceOp>(op);
|
||||
|
||||
// Handle new axis mask.
|
||||
uint64_t new_axis_mask = strided_slice_op.new_axis_mask();
|
||||
if (new_axis_mask != 0) {
|
||||
if (strided_slice_op.new_axis_mask() != 0) {
|
||||
// We currently don't handle simultaneous shrink_ and new_axis masks.
|
||||
if (strided_slice_op.shrink_axis_mask()) {
|
||||
return failure();
|
||||
if (!strided_slice_op.shrink_axis_mask()) {
|
||||
return RewriteNewAxisMask(strided_slice_op, rewriter);
|
||||
}
|
||||
return RewriteNewAxisMask(strided_slice_op, new_axis_mask, rewriter);
|
||||
}
|
||||
|
||||
// Handle ellipsis mask.
|
||||
uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask();
|
||||
if (ellipsis_mask != 0) {
|
||||
return RewriteEllipsisMask(strided_slice_op, ellipsis_mask, rewriter);
|
||||
if (strided_slice_op.ellipsis_mask() != 0) {
|
||||
return RewriteEllipsisMask(strided_slice_op, rewriter);
|
||||
}
|
||||
return failure();
|
||||
|
||||
auto ranked_input_type =
|
||||
strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
|
||||
if (!ranked_input_type) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto begin_attr = strided_slice_op.begin();
|
||||
auto end_attr = strided_slice_op.end();
|
||||
auto strides_attr = strided_slice_op.strides();
|
||||
|
||||
auto begin_attr_type = begin_attr.getType().dyn_cast<RankedTensorType>();
|
||||
auto end_attr_type = end_attr.getType().dyn_cast<RankedTensorType>();
|
||||
auto strides_attr_type =
|
||||
strides_attr.getType().dyn_cast<RankedTensorType>();
|
||||
|
||||
DenseIntElementsAttr begin_elem_attr;
|
||||
DenseIntElementsAttr end_elem_attr;
|
||||
DenseIntElementsAttr strides_elem_attr;
|
||||
|
||||
if (!begin_attr_type ||
|
||||
!matchPattern(begin_attr, m_Constant(&begin_elem_attr))) {
|
||||
return failure();
|
||||
}
|
||||
if (!end_attr_type || !matchPattern(end_attr, m_Constant(&end_elem_attr))) {
|
||||
return failure();
|
||||
}
|
||||
if (!strides_attr_type ||
|
||||
!matchPattern(strides_attr, m_Constant(&strides_elem_attr))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<int32_t, 4> begin, end, strides;
|
||||
SmallVector<int32_t, 4> padded_begin, padded_end, padded_strides;
|
||||
|
||||
int num_input_dims = ranked_input_type.getRank();
|
||||
SmallVector<int32_t, 4> padding_begin(num_input_dims, 0);
|
||||
auto input_shape = ranked_input_type.getShape();
|
||||
SmallVector<int32_t, 4> padding_end(input_shape.begin(), input_shape.end());
|
||||
SmallVector<int32_t, 4> padding_strides(num_input_dims, 1);
|
||||
|
||||
int begin_mask = strided_slice_op.begin_mask();
|
||||
int end_mask = strided_slice_op.end_mask();
|
||||
|
||||
PadStridedSliceAttributeArray(begin_elem_attr, begin, padded_begin,
|
||||
padding_begin, &begin_mask);
|
||||
PadStridedSliceAttributeArray(end_elem_attr, end, padded_end, padding_end,
|
||||
&end_mask);
|
||||
PadStridedSliceAttributeArray(strides_elem_attr, strides, padded_strides,
|
||||
padding_strides, nullptr);
|
||||
|
||||
if (begin == padded_begin && end == padded_end &&
|
||||
strides == padded_strides &&
|
||||
begin_mask == strided_slice_op.begin_mask() &&
|
||||
end_mask == strided_slice_op.end_mask()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto begin_end_type =
|
||||
RankedTensorType::get({num_input_dims}, rewriter.getIntegerType(32));
|
||||
auto new_begin_attr = rewriter.create<ConstantOp>(
|
||||
op->getLoc(), begin_end_type,
|
||||
DenseElementsAttr::get<int32_t>(begin_end_type, padded_begin));
|
||||
auto new_end_attr = rewriter.create<ConstantOp>(
|
||||
op->getLoc(), begin_end_type,
|
||||
DenseElementsAttr::get<int32_t>(begin_end_type, padded_end));
|
||||
auto strides_type =
|
||||
RankedTensorType::get({static_cast<long>(padded_strides.size())},
|
||||
rewriter.getIntegerType(32));
|
||||
auto new_strides_attr = rewriter.create<ConstantOp>(
|
||||
op->getLoc(), strides_type,
|
||||
DenseElementsAttr::get<int32_t>(strides_type, padded_strides));
|
||||
|
||||
auto attribute_type = rewriter.getIntegerType(64);
|
||||
rewriter.replaceOpWithNewOp<TF::StridedSliceOp>(
|
||||
op, strided_slice_op.output().getType(), strided_slice_op.input(),
|
||||
new_begin_attr, new_end_attr, new_strides_attr,
|
||||
rewriter.getIntegerAttr(attribute_type, begin_mask),
|
||||
rewriter.getIntegerAttr(attribute_type, end_mask),
|
||||
rewriter.getIntegerAttr(attribute_type,
|
||||
strided_slice_op.ellipsis_mask()),
|
||||
rewriter.getIntegerAttr(attribute_type,
|
||||
strided_slice_op.new_axis_mask()),
|
||||
rewriter.getIntegerAttr(attribute_type,
|
||||
strided_slice_op.shrink_axis_mask()));
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -57,6 +57,8 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) {
|
||||
return mlir::ComplexType::get(builder.getF64Type());
|
||||
case tflite::TensorType_INT8:
|
||||
return builder.getIntegerType(8);
|
||||
case tflite::TensorType_UINT64:
|
||||
return builder.getIntegerType(64, /*isSigned=*/false);
|
||||
}
|
||||
}
|
||||
|
||||
@ -86,6 +88,8 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) {
|
||||
return tensorflow::DT_STRING;
|
||||
case tflite::TensorType_UINT8:
|
||||
return tensorflow::DT_UINT8;
|
||||
case tflite::TensorType_UINT64:
|
||||
return tensorflow::DT_UINT64;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -41,6 +41,7 @@ cc_library(
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"//tensorflow/core/common_runtime:core_cpu_base_no_ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
|
@ -30,6 +30,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
|
||||
#include "tensorflow/core/common_runtime/function_body.h"
|
||||
#include "tensorflow/core/common_runtime/function_def_utils.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
@ -111,8 +113,24 @@ std::string ImportFunction(const std::string &functiondef_proto,
|
||||
}
|
||||
|
||||
const std::string &function_name = functiondef.signature().name();
|
||||
|
||||
const tensorflow::FunctionDef *fdef = flib_def.Find(function_name);
|
||||
if (fdef == nullptr) {
|
||||
s = tensorflow::errors::NotFound("Cannot find function ", function_name);
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
return "// error";
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::FunctionBody> fbody;
|
||||
s = FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(), &flib_def,
|
||||
&fbody);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
return "// error";
|
||||
}
|
||||
|
||||
mlir::MLIRContext context;
|
||||
auto module = ConvertFunctionToMlir(function_name, flib_def, &context);
|
||||
auto module = ConvertFunctionToMlir(fbody.get(), flib_def, &context);
|
||||
if (!module.ok()) {
|
||||
Set_TF_Status_from_Status(status, module.status());
|
||||
return "// error";
|
||||
|
@ -42,10 +42,10 @@ filegroup(
|
||||
"ir/tf_ops.td",
|
||||
"ir/tfrt_ops.td",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
"@llvm-project//mlir:SideEffectTdFiles",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
|
||||
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
|
||||
],
|
||||
)
|
||||
|
||||
@ -337,6 +337,7 @@ cc_library(
|
||||
":tensorflow",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
@ -824,6 +825,22 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "tf_pass_inc_gen",
|
||||
compatible_with = get_compatible_with_cloud(),
|
||||
tbl_outs = [
|
||||
(
|
||||
"-gen-pass-decls -name TensorFlow",
|
||||
"transforms/tf_passes.h.inc",
|
||||
),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "transforms/tf_passes.td",
|
||||
td_srcs = [
|
||||
"@llvm-project//mlir:PassBaseTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorflow_passes",
|
||||
srcs = [
|
||||
@ -859,6 +876,7 @@ cc_library(
|
||||
"transforms/mark_ops_for_outside_compilation.cc",
|
||||
"transforms/materialize_mlir_passthrough_op.cc",
|
||||
"transforms/optimize.cc",
|
||||
"transforms/outside_compiled_to_host_launch.cc",
|
||||
"transforms/parallel_execute_to_islands.cc",
|
||||
"transforms/parallelize_embedding_params_ops_pass.cc",
|
||||
"transforms/promote_resources_to_args.cc",
|
||||
@ -918,6 +936,8 @@ cc_library(
|
||||
includes = ["include"],
|
||||
textual_hdrs = [
|
||||
"ir/tf_ops_helpers.inc",
|
||||
"transforms/passes_detail.h",
|
||||
"transforms/tf_passes.h.inc",
|
||||
],
|
||||
deps = [
|
||||
":attribute_utils",
|
||||
@ -940,6 +960,7 @@ cc_library(
|
||||
":tensorflow_optimize_inc_gen",
|
||||
":tensorflow_types",
|
||||
":tf_data_optimization",
|
||||
":tf_pass_inc_gen",
|
||||
":tpu_rewrite_device_util",
|
||||
":translate_utils",
|
||||
":unroll_batch_matmul_pass",
|
||||
@ -1043,6 +1064,7 @@ cc_library(
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"@llvm-project//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -31,7 +31,7 @@ limitations under the License.
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
|
||||
def TF_AbsOp : TF_Op<"Abs", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_AbsOp : TF_Op<"Abs", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Computes the absolute value of a tensor.";
|
||||
|
||||
let description = [{
|
||||
@ -604,29 +604,6 @@ If `condition` evaluates to false, print the list of tensors in `data`.
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_AssignOp : TF_Op<"Assign", [NoSideEffect]> {
|
||||
let summary = "Update 'ref' by assigning 'value' to it.";
|
||||
|
||||
let description = [{
|
||||
This operation outputs "ref" after the assignment is done.
|
||||
This makes it easier to chain operations that need to use the reset value.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$ref,
|
||||
TF_Tensor:$value,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "true">:$validate_shape,
|
||||
DefaultValuedAttr<BoolAttr, "true">:$use_locking
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$output_ref
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_AssignAddVariableOp : TF_Op<"AssignAddVariableOp", []> {
|
||||
let summary = "Adds a value to the current value of a variable.";
|
||||
|
||||
@ -1025,13 +1002,13 @@ reverse of SpaceToBatch. See below for a precise description.
|
||||
TF_Tensor:$output
|
||||
);
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tcrops = TF_DerivedOperandTypeAttr<2>;
|
||||
TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_BetaincOp : TF_Op<"Betainc", [NoSideEffect]> {
|
||||
@ -1509,7 +1486,7 @@ def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_CeilOp : TF_Op<"Ceil", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_CeilOp : TF_Op<"Ceil", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Returns element-wise smallest integer not less than x.";
|
||||
|
||||
let arguments = (ins
|
||||
@ -2936,6 +2913,67 @@ Converts the given variant tensor to an iterator and stores it in the given reso
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def TF_DeserializeSparseOp : TF_Op<"DeserializeSparse", [NoSideEffect]> {
|
||||
let summary = "Deserialize `SparseTensor` objects.";
|
||||
|
||||
let description = [{
|
||||
The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where
|
||||
the last dimension stores serialized `SparseTensor` objects and the other N
|
||||
dimensions (N >= 0) correspond to a batch. The ranks of the original
|
||||
`SparseTensor` objects must all match. When the final `SparseTensor` is
|
||||
created, its rank is the rank of the incoming `SparseTensor` objects plus N;
|
||||
the sparse tensors have been concatenated along new dimensions, one for each
|
||||
batch.
|
||||
|
||||
The output `SparseTensor` object's shape values for the original dimensions
|
||||
are the max across the input `SparseTensor` objects' shape values for the
|
||||
corresponding dimensions. The new dimensions match the size of the batch.
|
||||
|
||||
The input `SparseTensor` objects' indices are assumed ordered in
|
||||
standard lexicographic order. If this is not the case, after this
|
||||
step run `SparseReorder` to restore index ordering.
|
||||
|
||||
For example, if the serialized input is a `[2 x 3]` matrix representing two
|
||||
original `SparseTensor` objects:
|
||||
|
||||
index = [ 0]
|
||||
[10]
|
||||
[20]
|
||||
values = [1, 2, 3]
|
||||
shape = [50]
|
||||
|
||||
and
|
||||
|
||||
index = [ 2]
|
||||
[10]
|
||||
values = [4, 5]
|
||||
shape = [30]
|
||||
|
||||
then the final deserialized `SparseTensor` will be:
|
||||
|
||||
index = [0 0]
|
||||
[0 10]
|
||||
[0 20]
|
||||
[1 2]
|
||||
[1 10]
|
||||
values = [1, 2, 3, 4, 5]
|
||||
shape = [2 50]
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[TF_Str, TF_Variant]>:$serialized_sparse
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Int64Tensor:$sparse_indices,
|
||||
TF_Tensor:$sparse_values,
|
||||
TF_Int64Tensor:$sparse_shape
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr Tserialized = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<1>;
|
||||
}
|
||||
|
||||
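The `DeserializeSparse` definition above documents the batching rule for serialized `SparseTensor` objects. As a rough usage sketch with the public Python API (using the batched `tf.io.serialize_many_sparse`/`tf.io.deserialize_many_sparse` wrappers, which follow the same batching description; illustrative only, not part of the diff):

```python
import tensorflow as tf

sp = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[2, 50])
serialized = tf.io.serialize_many_sparse(sp)                  # shape [2, 3]: one serialized row per batch entry
restored = tf.io.deserialize_many_sparse(serialized, dtype=tf.int32)
print(restored.dense_shape.numpy())                           # [ 2 50]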
def TF_DestroyResourceOp : TF_Op<"DestroyResourceOp", []> {
|
||||
let summary = "Deletes the resource specified by the handle.";
|
||||
|
||||
@ -3464,8 +3502,8 @@ tf.math.equal(x, y) ==> array([True, True])
}];

let arguments = (ins
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y,
TF_Tensor:$x,
TF_Tensor:$y,

DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
);
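The change above only widens the ODS element-type constraint on `Equal`'s operands to `TF_Tensor`; the op's broadcasting semantics are unchanged. A small usage sketch of the corresponding public API, for reference:

```python
import tensorflow as tf

x = tf.constant([2, 4])
y = tf.constant(2)
# Broadcasting comparison; see also the `incompatible_shape_error` attribute above.
print(tf.math.equal(x, y))  # tf.Tensor([ True False], shape=(2,), dtype=bool)
```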
@ -3805,8 +3843,8 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien
let summary = "Compute gradients for a FakeQuantWithMinMaxArgs operation.";

let arguments = (ins
F32Tensor:$gradients,
F32Tensor:$inputs,
TF_Float32Tensor:$gradients,
TF_Float32Tensor:$inputs,

DefaultValuedAttr<F32Attr, "-6.0f">:$min,
DefaultValuedAttr<F32Attr, "6.0f">:$max,
@ -3815,7 +3853,7 @@ def TF_FakeQuantWithMinMaxArgsGradientOp : TF_Op<"FakeQuantWithMinMaxArgsGradien
);

let results = (outs
F32Tensor:$backprops
TF_Float32Tensor:$backprops
);
}

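The `F32Tensor` → `TF_Float32Tensor` substitutions in this and the following hunk are type-spelling changes only. For reference, the forward op these gradients belong to can be exercised from Python as follows (illustrative sketch):

```python
import tensorflow as tf

x = tf.constant([-8.0, -1.0, 0.0, 3.0, 8.0])
# Values are clipped to [min, max] and snapped to an 8-bit quantization grid.
y = tf.quantization.fake_quant_with_min_max_args(x, min=-6.0, max=6.0, num_bits=8)
print(y.numpy())
```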
@ -3873,19 +3911,19 @@ def TF_FakeQuantWithMinMaxVarsGradientOp : TF_Op<"FakeQuantWithMinMaxVarsGradien
|
||||
let summary = "Compute gradients for a FakeQuantWithMinMaxVars operation.";
|
||||
|
||||
let arguments = (ins
|
||||
F32Tensor:$gradients,
|
||||
F32Tensor:$inputs,
|
||||
F32Tensor:$min,
|
||||
F32Tensor:$max,
|
||||
TF_Float32Tensor:$gradients,
|
||||
TF_Float32Tensor:$inputs,
|
||||
TF_Float32Tensor:$min,
|
||||
TF_Float32Tensor:$max,
|
||||
|
||||
DefaultValuedAttr<I64Attr, "8">:$num_bits,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$narrow_range
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
F32Tensor:$backprops_wrt_input,
|
||||
F32Tensor:$backprop_wrt_min,
|
||||
F32Tensor:$backprop_wrt_max
|
||||
TF_Float32Tensor:$backprops_wrt_input,
|
||||
TF_Float32Tensor:$backprop_wrt_min,
|
||||
TF_Float32Tensor:$backprop_wrt_max
|
||||
);
|
||||
}
|
||||
|
||||
@ -3988,7 +4026,7 @@ fill([2, 3], 9) ==> [[9, 9, 9]
|
||||
];
|
||||
}
|
||||
|
||||
def TF_FloorOp : TF_Op<"Floor", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_FloorOp : TF_Op<"Floor", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Returns element-wise largest integer not greater than x.";
|
||||
|
||||
let arguments = (ins
|
||||
@ -4939,13 +4977,13 @@ $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
F32Tensor:$predictions,
|
||||
TF_Float32Tensor:$predictions,
|
||||
TF_I32OrI64Tensor:$targets,
|
||||
TF_I32OrI64Tensor:$k
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
I1Tensor:$precision
|
||||
TF_BoolTensor:$precision
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
|
||||
@ -6523,6 +6561,8 @@ tensor of rank `k+1` with dimensions `[I, J, K, ..., M, N]` where:
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_MatrixSetDiagV2Op : TF_Op<"MatrixSetDiagV2", [NoSideEffect]> {
|
||||
@ -6615,6 +6655,8 @@ tf.matrix_set_diag(diagonals, k = (-1, 0))
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_MatrixSetDiagV3Op : TF_Op<"MatrixSetDiagV3", [NoSideEffect]> {
|
||||
@ -6870,7 +6912,7 @@ retained with length 1.
|
||||
];
|
||||
}
|
||||
|
||||
def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
|
||||
def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> {
|
||||
let summary = "Performs max pooling on the input.";
|
||||
|
||||
let arguments = (ins
|
||||
@ -6894,6 +6936,9 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInter
|
||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
|
||||
// TF_LayoutSensitiveInterface:
|
||||
StringRef GetOptimalLayout(const RuntimeDevices& devices);
|
||||
LogicalResult UpdateDataFormat(StringRef data_format);
|
||||
}];
|
||||
}
|
||||
|
||||
@ -7810,8 +7855,8 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> {
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
|
||||
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y,
|
||||
TF_Tensor:$x,
|
||||
TF_Tensor:$y,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
|
||||
);
|
||||
@ -7989,7 +8034,7 @@ times by rerunning "MakeIterator".
|
||||
);
|
||||
}
|
||||
|
||||
def TF_OnesLikeOp : TF_Op<"OnesLike", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_OnesLikeOp : TF_Op<"OnesLike", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Returns a tensor of ones with the same shape and type as x.";
|
||||
|
||||
let arguments = (ins
|
||||
@ -8387,6 +8432,10 @@ def TF_QrOp : TF_Op<"Qr", [NoSideEffect]> {
|
||||
Computes the QR decomposition of each inner matrix in `tensor` such that
|
||||
`tensor[..., :, :] = q[..., :, :] * r[..., :,:])`
|
||||
|
||||
Currently, the gradient for the QR decomposition is well-defined only when
|
||||
the first `P` columns of the inner matrix are linearly independent, where
|
||||
`P` is the minimum of `M` and `N`, the 2 inner-most dimmensions of `tensor`.
|
||||
|
||||
```python
|
||||
# a is a tensor.
|
||||
# q is a tensor of orthonormal matrices.
|
||||
@ -9072,7 +9121,7 @@ most one RecvTPUEmbeddingActivations op in the TPU graph.
|
||||
TF_DerivedResultSizeAttr num_outputs = TF_DerivedResultSizeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> {
|
||||
def TF_ReluOp : TF_Op<"Relu", [Idempotent, NoSideEffect, SameOperandsAndResultType, TF_ContractionFusableInterface, TF_LayoutAgnostic]> {
|
||||
let summary = "Computes rectified linear: `max(features, 0)`.";
|
||||
|
||||
let description = [{
|
||||
@ -9098,7 +9147,7 @@ array([ 0., 0., -0., 3.], dtype=float32)
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_Relu6Op : TF_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_Relu6Op : TF_Op<"Relu6", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Computes rectified linear 6: `min(max(features, 0), 6)`.";
|
||||
|
||||
let arguments = (ins
|
||||
@ -10496,7 +10545,7 @@ bitwise_ops.right_shift(lhs, rhs)
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_RintOp : TF_Op<"Rint", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_RintOp : TF_Op<"Rint", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Returns element-wise integer closest to x.";
|
||||
|
||||
let description = [{
|
||||
@ -10563,7 +10612,7 @@ roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]]
|
||||
TF_DerivedOperandTypeAttr Taxis = TF_DerivedOperandTypeAttr<2>;
|
||||
}
|
||||
|
||||
def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_RoundOp : TF_Op<"Round", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = [{
|
||||
Rounds the values of a tensor to the nearest integer, element-wise.
|
||||
}];
|
||||
@ -11293,7 +11342,7 @@ Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_SignOp : TF_Op<"Sign", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_SignOp : TF_Op<"Sign", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Returns an element-wise indication of the sign of a number.";
|
||||
|
||||
let description = [{
|
||||
@ -12303,14 +12352,14 @@ The outputs are a deterministic function of `shape`, `seed`, and `alpha`.
|
||||
}
|
||||
|
||||
def TF_StatelessRandomGetAlgOp : TF_Op<"StatelessRandomGetAlg", []> {
|
||||
let summary = [{
|
||||
Picks the best counter-based RNG algorithm based on device.
|
||||
}];
|
||||
let summary = "Picks the best counter-based RNG algorithm based on device.";
|
||||
|
||||
let description = [{
|
||||
This op picks the best counter-based RNG algorithm based on device.
|
||||
}];
|
||||
|
||||
let arguments = (ins);
|
||||
|
||||
let results = (outs
|
||||
TF_Int32Tensor:$alg
|
||||
);
|
||||
@ -14046,73 +14095,35 @@ This operation is very similar to `tf.scatter_nd`, except that the updates are
|
||||
scattered onto an existing tensor (as opposed to a zero-tensor). If the memory
|
||||
for the existing tensor cannot be re-used, a copy is made and updated.
|
||||
|
||||
If `indices` contains duplicates, then their updates are accumulated (summed).
|
||||
If `indices` contains duplicates, then we pick the last update for the index.
|
||||
|
||||
**WARNING**: The order in which updates are applied is nondeterministic, so the
|
||||
output will be nondeterministic if `indices` contains duplicates -- because
|
||||
of some numerical approximation issues, numbers summed in different order
|
||||
may yield different results.
|
||||
If an out of bound index is found on CPU, an error is returned.
|
||||
|
||||
**WARNING**: There are some GPU specific semantics for this operation.
|
||||
- If an out of bound index is found, the index is ignored.
|
||||
- The order in which updates are applied is nondeterministic, so the output
|
||||
will be nondeterministic if `indices` contains duplicates.
|
||||
|
||||
`indices` is an integer tensor containing indices into a new tensor of shape
|
||||
`shape`. The last dimension of `indices` can be at most the rank of `shape`:
|
||||
`shape`.
|
||||
|
||||
indices.shape[-1] <= shape.rank
|
||||
* `indices` must have at least 2 axes: `(num_updates, index_depth)`.
|
||||
* The last axis of `indices` is how deep to index into `tensor` so this index
|
||||
depth must be less than the rank of `tensor`: `indices.shape[-1] <= tensor.ndim`
|
||||
|
||||
The last dimension of `indices` corresponds to indices into elements
|
||||
(if `indices.shape[-1] = shape.rank`) or slices
|
||||
(if `indices.shape[-1] < shape.rank`) along dimension `indices.shape[-1]` of
|
||||
`shape`. `updates` is a tensor with shape
|
||||
if `indices.shape[-1] = tensor.rank` this Op indexes and updates scalar elements.
|
||||
if `indices.shape[-1] < tensor.rank` it indexes and updates slices of the input
|
||||
`tensor`.
|
||||
|
||||
indices.shape[:-1] + shape[indices.shape[-1]:]
|
||||
Each `update` has a rank of `tensor.rank - indices.shape[-1]`.
|
||||
The overall shape of `updates` is:
|
||||
|
||||
The simplest form of scatter is to insert individual elements in a tensor by
|
||||
index. For example, say we want to insert 4 scattered elements in a rank-1
|
||||
tensor with 8 elements.
|
||||
```
|
||||
indices.shape[:-1] + tensor.shape[indices.shape[-1]:]
|
||||
```
|
||||
|
||||
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||
<img style="width:100%" src="https://www.tensorflow.org/images/ScatterNd1.png" alt>
|
||||
</div>
|
||||
|
||||
In Python, this scatter operation would look like this:
|
||||
|
||||
>>> indices = tf.constant([[4], [3], [1], [7]])
|
||||
>>> updates = tf.constant([9, 10, 11, 12])
|
||||
>>> tensor = tf.ones([8], dtype=tf.int32)
|
||||
>>> print(tf.tensor_scatter_nd_update(tensor, indices, updates))
|
||||
tf.Tensor([ 1 11 1 10 9 1 1 12], shape=(8,), dtype=int32)
|
||||
|
||||
We can also, insert entire slices of a higher rank tensor all at once. For
|
||||
example, if we wanted to insert two slices in the first dimension of a
|
||||
rank-3 tensor with two matrices of new values.
|
||||
|
||||
In Python, this scatter operation would look like this:
|
||||
|
||||
>>> indices = tf.constant([[0], [2]])
|
||||
>>> updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
|
||||
... [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
... [[5, 5, 5, 5], [6, 6, 6, 6],
|
||||
... [7, 7, 7, 7], [8, 8, 8, 8]]])
|
||||
>>> tensor = tf.ones([4, 4, 4], dtype=tf.int32)
|
||||
>>> print(tf.tensor_scatter_nd_update(tensor, indices, updates).numpy())
|
||||
[[[5 5 5 5]
|
||||
[6 6 6 6]
|
||||
[7 7 7 7]
|
||||
[8 8 8 8]]
|
||||
[[1 1 1 1]
|
||||
[1 1 1 1]
|
||||
[1 1 1 1]
|
||||
[1 1 1 1]]
|
||||
[[5 5 5 5]
|
||||
[6 6 6 6]
|
||||
[7 7 7 7]
|
||||
[8 8 8 8]]
|
||||
[[1 1 1 1]
|
||||
[1 1 1 1]
|
||||
[1 1 1 1]
|
||||
[1 1 1 1]]]
|
||||
|
||||
Note that on CPU, if an out of bound index is found, an error is returned.
|
||||
On GPU, if an out of bound index is found, the index is ignored.
|
||||
For usage examples see the python [tf.tensor_scatter_nd_update](
|
||||
https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) function
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
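The rewritten description above changes the documented duplicate-index behavior of `TensorScatterUpdate` from "updates are accumulated (summed)" to "the last update wins", with separate CPU/GPU caveats. A small sketch of that semantics with the public API, based on the description above (illustrative only):

```python
import tensorflow as tf

tensor = tf.zeros([4], dtype=tf.int32)
indices = tf.constant([[1], [1]])   # duplicate index
updates = tf.constant([5, 7])
# Per the description above, the last update for a duplicated index is kept on
# CPU, so this prints [0 7 0 0]; on GPU the choice is nondeterministic.
print(tf.tensor_scatter_nd_update(tensor, indices, updates).numpy())
```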
@ -15038,7 +15049,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{The array we're gathering from.}]>:$operand,
|
||||
Arg<TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{The array we're gathering from.}]>:$operand,
|
||||
Arg<TF_I32OrI64Tensor, [{Array containing the starting indices of the slices we gather.}]>:$start_indices,
|
||||
Arg<TF_I32OrI64Tensor, [{slice_sizes[i] is the bounds for the slice on dimension i.}]>:$slice_sizes,
|
||||
|
||||
@ -15047,7 +15058,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
|
||||
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>;
|
||||
@ -15415,7 +15426,7 @@ def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape, TF_Sam
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_ZerosLikeOp : TF_Op<"ZerosLike", [Idempotent, NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Returns a tensor of zeros with the same shape and type as x.";
|
||||
|
||||
let arguments = (ins
|
||||
|
@ -684,12 +684,23 @@ body: A function that takes a list of tensors and returns another
|
||||
|
||||
FlatSymbolRefAttr:$cond,
|
||||
FlatSymbolRefAttr:$body,
|
||||
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
|
||||
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
|
||||
|
||||
// Used to map StatelessWhile and While op defined in TensorFlow to a common
|
||||
// op.
|
||||
BoolAttr:$is_stateless
|
||||
BoolAttr:$is_stateless,
|
||||
|
||||
// In TensorFlow, While has a special behavior where if `output_shapes`
|
||||
// attribute is not empty, those shapes are used in its shape function
|
||||
// as result shapes instead of propagating operand shapes as result shapes.
|
||||
// This allows for different result shapes from operand shapes. While these
|
||||
// shapes are imported and set as a part of the result type, there is no
|
||||
// indicator differentiating between having no output shapes compared to
|
||||
// having all unranked shapes. Thus this attribute is set to determine
|
||||
// which shape function behavior to use for this op, specifically
|
||||
// propagating operand shapes as result shapes when this attribute is not
|
||||
// set, or preserving result shapes as is when this attribute is set.
|
||||
UnitAttr:$shape_invariant
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@ -697,6 +708,7 @@ body: A function that takes a list of tensors and returns another
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
|
||||
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
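The new `shape_invariant` unit attribute added in this hunk records whether the While op's result shapes come from the `output_shapes` attribute instead of being propagated from operand shapes. At the Python level this corresponds to loops whose carried shapes change across iterations, declared via `shape_invariants` in `tf.while_loop`; a small illustrative sketch (not part of the diff):

```python
import tensorflow as tf

def grow(v):
  # The carried value doubles in length each iteration, so its shape cannot be
  # taken from the operand; shape_invariants declares that explicitly.
  return tf.while_loop(
      cond=lambda x: tf.shape(x)[0] < 8,
      body=lambda x: tf.concat([x, x], axis=0),
      loop_vars=[v],
      shape_invariants=[tf.TensorShape([None])])

print(grow(tf.ones([1])))  # carried tensor grown to length 8
```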
@ -752,12 +764,23 @@ def TF_WhileRegionOp : TF_Op<"WhileRegion",
|
||||
let arguments = (ins
|
||||
Variadic<AnyTensor>:$input,
|
||||
|
||||
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
|
||||
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
|
||||
|
||||
// Used to map StatelessWhile and While op defined in TensorFlow to a common
|
||||
// op.
|
||||
BoolAttr:$is_stateless
|
||||
BoolAttr:$is_stateless,
|
||||
|
||||
// In TensorFlow, While has a special behavior where if `output_shapes`
|
||||
// attribute is not empty, those shapes are used in its shape function
|
||||
// as result shapes instead of propagating operand shapes as result shapes.
|
||||
// This allows for different result shapes from operand shapes. While these
|
||||
// shapes are imported and set as a part of the result type, there is no
|
||||
// indicator differentiating between having no output shapes compared to
|
||||
// having all unranked shapes. Thus this attribute is set to determine
|
||||
// which shape function behavior to use for this op, specifically
|
||||
// propagating operand shapes as result shapes when this attribute is not
|
||||
// set, or preserving result shapes as is when this attribute is set.
|
||||
UnitAttr:$shape_invariant
|
||||
);
|
||||
let results = (outs Variadic<AnyTensor>:$output);
|
||||
|
||||
@ -1972,4 +1995,31 @@ operations inside a TPU host.
|
||||
);
|
||||
}
|
||||
|
||||
def TF_AssignOp : TF_Op<"Assign", []> {
|
||||
let summary = "Update 'ref' by assigning 'value' to it.";
|
||||
|
||||
let description = [{
|
||||
This operation outputs "ref" after the assignment is done.
|
||||
This makes it easier to chain operations that need to use the reset value.
|
||||
|
||||
This is a side-effecting operation because it will change the value of its
|
||||
argument "ref" in addition to returning the results.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$ref,
|
||||
TF_Tensor:$value,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "true">:$validate_shape,
|
||||
DefaultValuedAttr<BoolAttr, "true">:$use_locking
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$output_ref
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
|
||||
#endif // TF_OPS
|
||||
|
@ -2428,6 +2428,24 @@ static LogicalResult Verify(MatrixBandPartOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MatrixSetDiagOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
void MatrixSetDiagOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<MatrixSetDiagToV3>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MatrixSetDiagV2Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void MatrixSetDiagV2Op::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<MatrixSetDiagV2ToV3>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MaxOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -2449,6 +2467,33 @@ LogicalResult MaxPoolOp::FoldOperandsPermutation(
|
||||
permutation, this, {{"strides", strides()}, {"ksize", ksize()}});
|
||||
}
|
||||
|
||||
LogicalResult MaxPoolOp::UpdateDataFormat(StringRef new_data_format) {
|
||||
StringRef src_data_format = data_format();
|
||||
|
||||
auto perm = GetDataFormatPermutation(src_data_format, new_data_format);
|
||||
if (perm.empty()) return failure();
|
||||
|
||||
// Update data_format attribute and result types.
|
||||
if (failed(::mlir::TF::UpdateDataFormat(new_data_format, this)))
|
||||
return failure();
|
||||
|
||||
stridesAttr(ShuffleArrayAttr(strides(), perm));
|
||||
explicit_paddingsAttr(ShuffleArrayAttr(explicit_paddings(), perm, 2));
|
||||
ksizeAttr(ShuffleArrayAttr(ksize(), perm));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
StringRef MaxPoolOp::GetOptimalLayout(const RuntimeDevices &devices) {
|
||||
// Keep current data format if no GPUs are available or if explicit placement
|
||||
// does not allow to use GPU for this operation.
|
||||
if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation()))
|
||||
return data_format();
|
||||
|
||||
// Defaults to NCHW.
|
||||
return "NCHW";
|
||||
}
|
||||
|
||||
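The newly added `MaxPoolOp::UpdateDataFormat` above permutes the `strides`, `explicit_paddings`, and `ksize` attributes according to the data-format permutation before switching layouts. A short Python sketch of that attribute shuffle (the permutation value is an assumption for illustration, not taken from the diff):

```python
# Sketch of the attribute shuffle performed when switching MaxPool's layout.
def shuffle(attr, perm):
  return [attr[p] for p in perm]

nhwc_to_nchw = [0, 3, 1, 2]                   # assumed permutation NHWC -> NCHW
print(shuffle([1, 3, 3, 1], nhwc_to_nchw))    # ksize   -> [1, 1, 3, 3]
print(shuffle([1, 2, 2, 1], nhwc_to_nchw))    # strides -> [1, 1, 2, 2]
```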
//===----------------------------------------------------------------------===//
|
||||
// MaxPoolGradOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1280,3 +1280,24 @@ func @testSumFoldBypass(%arg0: tensor<4x?xf16>, %arg1: tensor<*xi64>) -> tensor<
|
||||
%0 = "tf.Sum"(%arg0, %arg1) { keep_dims = false }: (tensor<4x?xf16>, tensor<*xi64>) -> tensor<4x?xf16>
|
||||
return %0 : tensor<4x?xf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @testMatrixSetDiag
|
||||
func @testMatrixSetDiag(%arg0: tensor<3x3xi64>, %arg1: tensor<3xi64>) -> tensor<3x3xi64> {
|
||||
%0 = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi64>, tensor<3xi64>) -> tensor<3x3xi64>
|
||||
return %0 : tensor<3x3xi64>
|
||||
|
||||
// CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>}
|
||||
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiagV3"(%arg0, %arg1, %[[ZERO]])
|
||||
// CHECK-SAME: {align = "RIGHT_LEFT"}
|
||||
// CHECK-SAME: (tensor<3x3xi64>, tensor<3xi64>, tensor<i32>) -> tensor<3x3xi64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @testMatrixSetDiagV2
|
||||
func @testMatrixSetDiagV2(%arg0: tensor<3x3xi64>, %arg1: tensor<3xi64>, %arg2: tensor<i32>) -> tensor<3x3xi64> {
|
||||
%0 = "tf.MatrixSetDiagV2"(%arg0, %arg1, %arg2) : (tensor<3x3xi64>, tensor<3xi64>, tensor<i32>) -> tensor<3x3xi64>
|
||||
return %0 : tensor<3x3xi64>
|
||||
|
||||
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiagV3"(%arg0, %arg1, %arg2)
|
||||
// CHECK-SAME: {align = "LEFT_LEFT"}
|
||||
}
|
||||
|
||||
|
@ -4,6 +4,8 @@
|
||||
// CHECK: func @main(%[[ARG_0:.*]]: tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}, %[[ARG_1:.*]]: tensor<i32> {tf.device = "/job:worker/replica:0/task:1/device:CPU:0"})
|
||||
// CHECK-NEXT: %[[RESULT_0:.*]]:2 = tf_device.remote_run "/job:worker/replica:0/task:0" @_job_worker_replica_0_task_0(%[[ARG_0]])
|
||||
// CHECK-NEXT: %[[RESULT_1:.*]] = tf_device.remote_run "/job:worker/replica:0/task:1" @_job_worker_replica_0_task_1(%[[ARG_1]])
|
||||
// CHECK-NEXT: %[[RESULT_2:.*]] = "tf.Const"() {value = dense<16> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK-NEXT: %[[RESULT_3:.*]] = "tf.AddV2"(%[[RESULT_2]], %[[RESULT_2]]) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
// CHECK-NEXT: return %[[RESULT_0]]#0, %[[RESULT_0]]#1, %[[RESULT_1]] : tensor<i32>, tensor<i32>, tensor<i32>
|
||||
|
||||
// CHECK: func @_job_worker_replica_0_task_0(%[[ARG_0:.*]]: tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}) -> (tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:0"}, tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:CPU:1"})
|
||||
@ -21,5 +23,8 @@ func @main(%arg0: tensor<i32> {tf.device = "/job:worker/replica:0/task:0/device:
|
||||
%1 = "tf.Mul"(%0, %0) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
%2 = "tf.AddV2"(%arg0, %arg0) {device = "/job:worker/replica:0/task:0/device:CPU:1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
%3 = "tf.AddV2"(%arg1, %arg1) {device = "/job:worker/replica:0/task:1/device:CPU:0"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
%4 = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32>
|
||||
%5 = "tf.AddV2"(%4, %4) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
|
||||
return %1, %2, %3 : tensor<i32>, tensor<i32>, tensor<i32>
|
||||
}
|
||||
|
@ -1,16 +1,47 @@
|
||||
// RUN: tf-opt -tf-tensor-device-copy %s | FileCheck %s --dump-input=fail
|
||||
|
||||
// CHECK-LABEL: func @fold_identity
|
||||
// CHECK-SAME: ([[arg0:%.*]]: tensor<2x2xf32>, [[arg1:%.*]]: tensor<2x2xf32>
|
||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32}} {
|
||||
func @fold_identity(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%0 = tf_executor.graph {
|
||||
// CHECK: tf.MatMul
|
||||
%outputs, %control = tf_executor.island wraps "tf.MatMul"(%arg0, %arg1) {device = "", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NOT: tf.Identity
|
||||
%outputs_0, %control_1 = tf_executor.island wraps "tf.Identity"(%outputs) {device = ""} : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
tf_executor.fetch %outputs_0 : tensor<2x2xf32>
|
||||
}
|
||||
return %0 : tensor<2x2xf32>
|
||||
// CHECK-LABEL: func @fold_identity_test
|
||||
func @fold_identity_test(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%0 = tf_executor.graph {
|
||||
// CHECK: tf.MatMul
|
||||
%outputs, %control = tf_executor.island wraps "tf.MatMul"(%arg0, %arg1) {device = "", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NOT: tf.Identity
|
||||
%outputs_0, %control_1 = tf_executor.island wraps "tf.Identity"(%outputs) {device = ""} : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
tf_executor.fetch %outputs_0 : tensor<2x2xf32>
|
||||
}
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @keep_identity_test
|
||||
func @keep_identity_test(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%0 = tf_executor.graph {
|
||||
// CHECK: tf.MatMul
|
||||
%outputs, %control = tf_executor.island wraps "tf.MatMul"(%arg0, %arg1) {device = "/device:GPU:0", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: tf.Identity
|
||||
%outputs_0, %control_1 = tf_executor.island wraps "tf.Identity"(%outputs) {device = "/device:CPU:0"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
tf_executor.fetch %outputs_0 : tensor<2x2xf32>
|
||||
}
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
|
||||
// CHECK: func @while_loop_test(%[[ARG_0:.*]]: tensor<i32>, %[[ARG_1:.*]]: tensor<i32>, %arg2: tensor<*xf32>)
|
||||
func @while_loop_test(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<*xf32>) {
|
||||
// CHECK-NEXT: tf.WhileRegion
|
||||
%0:2 = "tf.WhileRegion"(%arg0, %arg2) ( {
|
||||
// CHECK-NEXT: bb0(%[[ARG_3:.*]]: tensor<i32>, %[[ARG_4:.*]]: tensor<*xf32>)
|
||||
^bb0(%arg3: tensor<i32>, %arg4: tensor<*xf32>):
|
||||
// CHECK-NEXT: %[[RESULT_1:.*]] = "tf.Identity"(%[[ARG_3]])
|
||||
%1 = "tf.Identity"(%arg3) : (tensor<i32>) -> tensor<i32>
|
||||
%2 = "tf.Identity"(%arg1) : (tensor<i32>) -> tensor<i32>
|
||||
// CHECK-NEXT: %[[RESULT_2:.*]] = "tf.NotEqual"(%[[RESULT_1]], %[[ARG_1]])
|
||||
%3 = "tf.NotEqual"(%1, %2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
"tf.Yield"(%3) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^bb0(%arg3: tensor<i32>, %arg4: tensor<*xf32>):
|
||||
%cst = constant dense<1> : tensor<i32>
|
||||
%1 = "tf.Sub"(%arg3, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
"tf.Yield"(%1, %arg4) : (tensor<i32>, tensor<*xf32>) -> ()
|
||||
}) {is_stateless = true} : (tensor<i32>, tensor<*xf32>) -> (tensor<i32>, tensor<*xf32>)
|
||||
return
|
||||
}
|
||||
|
@ -41,3 +41,34 @@ func @broadcast_mul_implicit_no_fold(%arg0: tensor<5x7xf32>, %arg1: tensor<5xf32
|
||||
// CHECK: %[[V1:.*]] = "tf.Mul"(%arg0, %[[V0]]) : (tensor<5x7xf32>, tensor<3x5x7xf32>) -> tensor<3x5x7xf32>
|
||||
// CHECK: %[[V1]] : tensor<3x5x7xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @broadcast_eq
|
||||
func @broadcast_eq(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xi1> {
|
||||
%cst = constant dense<[5, 7]> : tensor<2xi32>
|
||||
%0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
|
||||
%1 = "tf.Equal"(%arg0, %0) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xi1>
|
||||
return %1 : tensor<5x7xi1>
|
||||
// CHECK: %[[V0:.*]] = "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xi1>
|
||||
// CHECK: %[[V0]] : tensor<5x7xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @broadcast_neq
|
||||
func @broadcast_neq(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xi1> {
|
||||
%cst = constant dense<[5, 7]> : tensor<2xi32>
|
||||
%0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32>
|
||||
%1 = "tf.NotEqual"(%arg0, %0) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xi1>
|
||||
return %1 : tensor<5x7xi1>
|
||||
// CHECK: %[[V0:.*]] = "tf.NotEqual"(%arg0, %arg1) {incompatible_shape_error = true} : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xi1>
|
||||
// CHECK: %[[V0]] : tensor<5x7xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @broadcast_both_operand
|
||||
func @broadcast_both_operand(%arg0: tensor<7xf32>, %arg1: tensor<5x1xf32>) -> tensor<5x7xf32> {
|
||||
%cst = constant dense<[5, 7]> : tensor<2xi64>
|
||||
%0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<7xf32>, tensor<2xi64>) -> tensor<5x7xf32>
|
||||
%1 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<5x1xf32>, tensor<2xi64>) -> tensor<5x7xf32>
|
||||
%2 = "tf.Add"(%0, %1) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32>
|
||||
return %2 : tensor<5x7xf32>
|
||||
// CHECK: %[[V0:.*]] = "tf.Add"(%arg0, %arg1) : (tensor<7xf32>, tensor<5x1xf32>) -> tensor<5x7xf32>
|
||||
// CHECK: %[[V0]] : tensor<5x7xf32>
|
||||
}
|
||||
|
@ -1,6 +1,8 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=10:10 -tf-output-arrays=Add -o - | FileCheck %s
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-shapes=10:10 -tf-output-arrays=Add -o - | FileCheck --check-prefix=NONE %s
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-shapes=10:10 -tf-input-data-types=',DT_INT32' -tf-output-arrays=Add -o - | FileCheck --check-prefix=SOME %s
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-shapes=*:* -tf-input-data-types=',DT_INT32' -tf-output-arrays=Add -o - | FileCheck --check-prefix=UNKNOWN %s
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=input0,input1 -tf-input-shapes=?,1,?:1,?,1 -tf-input-data-types=',DT_INT32' -tf-output-arrays=Add -o - | FileCheck --check-prefix=DYNAMIC %s
|
||||
|
||||
node {
|
||||
name: "Add"
|
||||
@ -61,3 +63,19 @@ versions {
|
||||
# NONE-SAME: outputs = "Add"
|
||||
# NONE: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
||||
# NONE: fetch %[[add]]
|
||||
|
||||
# UNKNOWN-LABEL: func @main
|
||||
# UNKNOWN-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<*xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<*xi32>) -> tensor<*xi32>
|
||||
# UNKNOWN-SAME: control_outputs = ""
|
||||
# UNKNOWN-SAME: inputs = "input0,input1"
|
||||
# UNKNOWN-SAME: outputs = "Add"
|
||||
# UNKNOWN: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
||||
# UNKNOWN: fetch %[[add]]
|
||||
|
||||
# DYNAMIC-LABEL: func @main
|
||||
# DYNAMIC-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<?x1x?xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<1x?x1xi32>) -> tensor<*xi32>
|
||||
# DYNAMIC-SAME: control_outputs = ""
|
||||
# DYNAMIC-SAME: inputs = "input0,input1"
|
||||
# DYNAMIC-SAME: outputs = "Add"
|
||||
# DYNAMIC: %[[add:.*]], %[[add_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
|
||||
# DYNAMIC: fetch %[[add]]
|
||||
|
@ -1,4 +1,4 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-upgrade-legacy %s -tf-output-arrays=hash_table_node -o - | FileCheck %s
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-upgrade-legacy %s -tf-output-arrays=hash_table_node,variable_node -o - | FileCheck %s
|
||||
|
||||
node: {
|
||||
name: "hash_table_node"
|
||||
@ -22,6 +22,29 @@ node: {
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "variable_node"
|
||||
op: "VariableV2"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shared_name"
|
||||
value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Call"
|
||||
op: "PartitionedCall"
|
||||
@ -90,6 +113,9 @@ library {
|
||||
# CHECK: tf.HashTableV2
|
||||
# CHECK-SAME: shared_name = "hash_table_node"
|
||||
|
||||
# CHECK: tf.VariableV2
|
||||
# CHECK-SAME: shared_name = "variable_node"
|
||||
|
||||
# CHECK: func private @create_resource
|
||||
# CHECK: tf.HashTableV2
|
||||
# CHECK-SAME: shared_name = "hash_table_node@create_resource"
|
||||
|
@ -1,4 +1,4 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1,WhileWithOutputShapes:1 -o - -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s
|
||||
|
||||
# Verify that TensorFlow While and StatelessWhile ops are mapped to the
|
||||
# composite While op in MLIR with is_stateless attribute set accordingly to
|
||||
@ -6,6 +6,7 @@
|
||||
|
||||
# CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} loc("StatefulWhile")
|
||||
# CHECK-DAG: "tf.While"{{.*}} is_stateless = true{{.*}} loc("StatelessWhile")
|
||||
# CHECK-DAG: "tf.While"{{.*}} is_stateless = false{{.*}} shape_invariant{{.*}} -> (tensor<i32>, tensor<*xf32>) loc("WhileWithOutputShapes")
|
||||
|
||||
node {
|
||||
name: "StatefulWhile"
|
||||
@ -73,6 +74,51 @@ node {
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "WhileWithOutputShapes"
|
||||
op: "While"
|
||||
input: "iter"
|
||||
input: "val"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "body"
|
||||
value {
|
||||
func {
|
||||
name: "body"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "cond"
|
||||
value {
|
||||
func {
|
||||
name: "cond"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
shape {
|
||||
unknown_rank: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "main"
|
||||
op: "_Retval"
|
||||
@ -107,6 +153,23 @@ node {
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "main2"
|
||||
op: "_Retval"
|
||||
input: "WhileWithOutputShapes:1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "index"
|
||||
value {
|
||||
i: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "iter"
|
||||
op: "Placeholder"
|
||||
|
@ -82,3 +82,20 @@ func @bias_add_nchw(%arg0: tensor<1x256x150x150xf32>, %arg1: tensor<256xf32>) ->
|
||||
%0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW", device = ""} : (tensor<1x256x150x150xf32>, tensor<256xf32>) -> tensor<1x256x150x150xf32>
|
||||
return %0 : tensor<1x256x150x150xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: maxpool_nchw
|
||||
func @maxpool_nchw(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32> {
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
|
||||
// CHECK: %[[R0:.*]] = "tf.Transpose"(%arg0, %[[CST]])
|
||||
// CHECK: %[[R1:.*]] = "tf.MaxPool"(%[[R0]]) {data_format = "NHWC", explicit_paddings = [], ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]}
|
||||
// CHECK: %[[CST_0:.*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: "tf.Transpose"(%[[R1]], %[[CST_0]])
|
||||
%0 = "tf.MaxPool"(%arg0)
|
||||
{
|
||||
data_format = "NCHW",
|
||||
ksize = [1, 1, 3, 3],
|
||||
padding = "SAME",
|
||||
strides = [1, 1, 2, 2]
|
||||
} : (tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32>
|
||||
return %0 : tensor<1x64x56x56xf32>
|
||||
}
|
||||
|
@ -1703,3 +1703,127 @@ func @convert_iota_3d() -> tensor<5x7x9xi32> {
|
||||
return %0 : tensor<5x7x9xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @convert_avgpool_valid(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
|
||||
// CHECK: %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
|
||||
// CHECK: return %[[VAL_1]] : tensor<4x7x7x8xf32>
|
||||
// CHECK: }
|
||||
func @convert_avgpool_valid(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
|
||||
%0 = mhlo.constant dense<0.0> : tensor<f32>
|
||||
%1 = mhlo.constant dense<9.0> : tensor<4x7x7x8xf32>
|
||||
%2 = "mhlo.reduce_window"(%arg0, %0) ( {
|
||||
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
|
||||
%5 = mhlo.add %arg1, %arg2 : tensor<f32>
|
||||
"mhlo.return"(%5) : (tensor<f32>) -> ()
|
||||
}) {
|
||||
base_dilations = dense<1> : tensor<4xi64>,
|
||||
padding = dense<0> : tensor<4x2xi64>,
|
||||
window_dilations = dense<1> : tensor<4xi64>,
|
||||
window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
|
||||
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
|
||||
%3 = mhlo.divide %2, %1 : tensor<4x7x7x8xf32>
|
||||
return %3 : tensor<4x7x7x8xf32>
|
||||
}
|
||||
|
||||
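The test above checks that a summing `mhlo.reduce_window` followed by a divide by the window size is raised back to `tf.AvgPool`. For reference, the equivalent public-API call produces the same output shape as the converted op (illustrative):

```python
import tensorflow as tf

x = tf.random.normal([4, 16, 16, 8])
y = tf.nn.avg_pool2d(x, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding="VALID")
print(y.shape)  # (4, 7, 7, 8), matching the converted tf.AvgPool above
```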
// CHECK-LABEL: func @convert_avgpool_valid_rw(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
|
||||
// CHECK: %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32>
|
||||
// CHECK: return %[[VAL_1]] : tensor<4x7x7x8xf32>
|
||||
// CHECK: }
|
||||
func @convert_avgpool_valid_rw(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> {
|
||||
%0 = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32>
|
||||
%1 = mhlo.constant dense<0.0> : tensor<f32>
|
||||
%2 = "mhlo.reduce_window"(%arg0, %1) ( {
|
||||
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
|
||||
%6 = mhlo.add %arg1, %arg2 : tensor<f32>
|
||||
"mhlo.return"(%6) : (tensor<f32>) -> ()
|
||||
}) {
|
||||
base_dilations = dense<1> : tensor<4xi64>,
|
||||
padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
|
||||
window_dilations = dense<1> : tensor<4xi64>,
|
||||
window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
|
||||
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
|
||||
%3 = "mhlo.reduce_window"(%0, %1) ( {
|
||||
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
|
||||
%6 = mhlo.add %arg1, %arg2 : tensor<f32>
|
||||
"mhlo.return"(%6) : (tensor<f32>) -> ()
|
||||
}) {
|
||||
base_dilations = dense<1> : tensor<4xi64>,
|
||||
padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>,
|
||||
window_dilations = dense<1> : tensor<4xi64>,
|
||||
window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
|
||||
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x8xf32>
|
||||
%4 = mhlo.divide %2, %3 : tensor<4x7x7x8xf32>
|
||||
return %4 : tensor<4x7x7x8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @convert_avgpool_valid_3d(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
|
||||
// CHECK: %[[VAL_1:.*]] = "tf.AvgPool3D"(%[[VAL_0]]) {data_format = "NDHWC", ksize = [1, 3, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 2, 1]} : (tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32>
|
||||
// CHECK: return %[[VAL_1]] : tensor<4x7x7x7x8xf32>
|
||||
// CHECK: }
|
||||
func @convert_avgpool_valid_3d(%arg0: tensor<4x16x16x16x8xf32>) -> tensor<4x7x7x7x8xf32> {
|
||||
%0 = mhlo.constant dense<0.0> : tensor<f32>
|
||||
%1 = mhlo.constant dense<27.0> : tensor<4x7x7x7x8xf32>
|
||||
%2 = "mhlo.reduce_window"(%arg0, %0) ( {
|
||||
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
|
||||
%5 = mhlo.add %arg1, %arg2 : tensor<f32>
|
||||
"mhlo.return"(%5) : (tensor<f32>) -> ()
|
||||
}) {
|
||||
base_dilations = dense<1> : tensor<5xi64>,
|
||||
padding = dense<0> : tensor<5x2xi64>,
|
||||
window_dilations = dense<1> : tensor<5xi64>,
|
||||
window_dimensions = dense<[1, 3, 3, 3, 1]> : tensor<5xi64>,
|
||||
window_strides = dense<[1, 2, 2, 2, 1]> : tensor<5xi64>} : (tensor<4x16x16x16x8xf32>, tensor<f32>) -> tensor<4x7x7x7x8xf32>
|
||||
%3 = mhlo.divide %2, %1 : tensor<4x7x7x7x8xf32>
|
||||
return %3 : tensor<4x7x7x7x8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @convert_avgpool_same(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
|
||||
// CHECK: %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32>
|
||||
// CHECK: return %[[VAL_1]] : tensor<4x8x8x8xf32>
|
||||
// CHECK: }
|
||||
func @convert_avgpool_same(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> {
|
||||
%0 = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32>
|
||||
%1 = mhlo.constant dense<0.0> : tensor<f32>
|
||||
%2 = "mhlo.reduce_window"(%arg0, %1) ( {
|
||||
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
|
||||
%6 = mhlo.add %arg1, %arg2 : tensor<f32>
|
||||
"mhlo.return"(%6) : (tensor<f32>) -> ()
|
||||
}) {
|
||||
base_dilations = dense<1> : tensor<4xi64>,
|
||||
padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
|
||||
window_dilations = dense<1> : tensor<4xi64>,
|
||||
window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
|
||||
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
|
||||
%3 = "mhlo.reduce_window"(%0, %1) ( {
|
||||
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
|
||||
%6 = mhlo.add %arg1, %arg2 : tensor<f32>
|
||||
"mhlo.return"(%6) : (tensor<f32>) -> ()
|
||||
}) {
|
||||
base_dilations = dense<1> : tensor<4xi64>,
|
||||
padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
|
||||
window_dilations = dense<1> : tensor<4xi64>,
|
||||
window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>,
|
||||
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor<f32>) -> tensor<4x8x8x8xf32>
|
||||
%4 = mhlo.divide %2, %3 : tensor<4x8x8x8xf32>
|
||||
return %4 : tensor<4x8x8x8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @convert_pad(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x128xf32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<f32>) -> tensor<11x131xf32> {
|
||||
// CHECK: %[[VAL_2:.*]] = constant dense<{{\[\[}}1, 2], [0, 3]]> : tensor<2x2xi64>
|
||||
// CHECK: %[[VAL_3:.*]] = "tf.PadV2"(%[[VAL_0]], %[[VAL_2]], %[[VAL_1]]) : (tensor<8x128xf32>, tensor<2x2xi64>, tensor<f32>) -> tensor<11x131xf32>
|
||||
// CHECK: return %[[VAL_3]] : tensor<11x131xf32>
|
||||
// CHECK: }
|
||||
func @convert_pad(%arg0: tensor<8x128xf32>, %arg1: tensor<f32>) -> tensor<11x131xf32> {
|
||||
%0 = "mhlo.pad"(%arg0, %arg1) {
|
||||
edge_padding_low = dense<[1, 0]> : tensor<2xi64>,
|
||||
edge_padding_high = dense<[2, 3]> : tensor<2xi64>,
|
||||
interior_padding = dense<0> : tensor<2xi64>
|
||||
} : (tensor<8x128xf32>, tensor<f32>) -> tensor<11x131xf32>
|
||||
return %0 : tensor<11x131xf32>
|
||||
}
|
||||
|
||||
|
@ -244,9 +244,9 @@ func @fourdim_space_to_batch_nd(%input: tensor<3x5x7x10xf32>, %block_shape: tens
|
||||
// CHECK-DAG: [[PAD_DEFAULT:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
|
||||
// CHECK-DAG: [[PADDED:%.+]] = "tf.PadV2"(%arg0, [[FULL_PADDINGS]], [[PAD_DEFAULT]])
|
||||
// CHECK-DAG: [[PADDINGS:%.+]]:2 = "tf.Unpack"([[FULL_PADDINGS]]) {axis = 1 : i64}
|
||||
// CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Add"([[PADDINGS]]#0, [[PADDINGS]]#1)
|
||||
// CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.AddV2"([[PADDINGS]]#0, [[PADDINGS]]#1)
|
||||
// CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 10]> : tensor<4xi64>}
|
||||
// CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.Add"([[PADDINGS_SUM]], [[INPUT_SHAPE]])
|
||||
// CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.AddV2"([[PADDINGS_SUM]], [[INPUT_SHAPE]])
|
||||
// CHECK-DAG: [[PADDED_SHAPE_SPLITS:%.+]]:4 = "tf.Split"([[ZERO_I32]], [[PADDED_SHAPE]])
|
||||
// CHECK-DAG: [[BLOCK_SHAPE_SPLITS:%.+]]:2 = "tf.Split"([[ZERO_I32]], %arg1)
|
||||
// CHECK-DAG: [[OUTER_SHAPE_0:%.+]] = "tf.Div"([[PADDED_SHAPE_SPLITS]]#1, [[BLOCK_SHAPE_SPLITS]]#0)
|
||||
@ -338,10 +338,10 @@ func @fake_quant_with_min_max_args(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// CHECK-DAG: [[VAL5:%.+]] = "tf.ClipByValue"(%arg0, [[VAL2]], [[VAL1]])
|
||||
// CHECK-DAG: [[VAL6:%.+]] = "tf.Sub"([[VAL5]], [[VAL2]])
|
||||
// CHECK-DAG: [[VAL7:%.+]] = "tf.Mul"([[VAL6]], [[VAL0]])
|
||||
// CHECK-DAG: [[VAL8:%.+]] = "tf.Add"([[VAL7]], [[VAL4]])
|
||||
// CHECK-DAG: [[VAL8:%.+]] = "tf.AddV2"([[VAL7]], [[VAL4]])
|
||||
// CHECK-DAG: [[VAL9:%.+]] = "tf.Floor"([[VAL8]])
|
||||
// CHECK-DAG: [[VAL10:%.+]] = "tf.Mul"([[VAL9]], [[VAL3]])
|
||||
// CHECK-DAG: [[VAL11:%.+]] = "tf.Add"([[VAL10]], [[VAL2]])
|
||||
// CHECK-DAG: [[VAL11:%.+]] = "tf.AddV2"([[VAL10]], [[VAL2]])
|
||||
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {max = 1.0 : f32, min = -1.0 : f32, narrow_range = false, num_bits = 8 : i64} : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
|
||||
// CHECK: return [[VAL11]]
|
||||
@ -361,7 +361,7 @@ func @fake_quant_with_min_max_vars(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>,
|
||||
// CHECK-DAG: [[VAL9:%.+]] = "tf.Floor"([[VAL8]])
|
||||
// CHECK-DAG: [[VAL10:%.+]] = "tf.Sub"([[VAL8]], [[VAL9]])
|
||||
// CHECK-DAG: [[VAL11:%.+]] = "tf.Less"([[VAL10]], [[VAL3]])
|
||||
// CHECK-DAG: [[VAL12:%.+]] = "tf.Add"([[VAL2]], [[VAL9]])
|
||||
// CHECK-DAG: [[VAL12:%.+]] = "tf.AddV2"([[VAL9]], [[VAL2]])
|
||||
// CHECK-DAG: [[VAL13:%.+]] = "tf.Select"([[VAL11]], [[VAL9]], [[VAL12]])
|
||||
// CHECK-DAG: [[VAL14:%.+]] = "tf.ClipByValue"([[VAL13]], [[VAL0]], [[VAL1]]) :
|
||||
// CHECK-DAG: [[VAL15:%.+]] = "tf.Sub"([[VAL0]], [[VAL14]])
|
||||
@ -371,10 +371,10 @@ func @fake_quant_with_min_max_vars(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>,
|
||||
// CHECK-DAG: [[VAL19:%.+]] = "tf.ClipByValue"(%arg0, [[VAL17]], [[VAL18]])
|
||||
// CHECK-DAG: [[VAL20:%.+]] = "tf.Sub"([[VAL19]], [[VAL17]])
|
||||
// CHECK-DAG: [[VAL21:%.+]] = "tf.Mul"([[VAL20]], [[VAL6]])
|
||||
// CHECK-DAG: [[VAL22:%.+]] = "tf.Add"([[VAL21]], [[VAL3]])
|
||||
// CHECK-DAG: [[VAL22:%.+]] = "tf.AddV2"([[VAL21]], [[VAL3]])
|
||||
// CHECK-DAG: [[VAL23:%.+]] = "tf.Floor"([[VAL22]])
|
||||
// CHECK-DAG: [[VAL24:%.+]] = "tf.Mul"([[VAL23]], [[VAL5]])
|
||||
// CHECK-DAG: [[VAL25:%.+]] = "tf.Add"([[VAL24]], [[VAL17]])
|
||||
// CHECK-DAG: [[VAL25:%.+]] = "tf.AddV2"([[VAL24]], [[VAL17]])
|
||||
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {narrow_range = false, num_bits = 8 : i64} : (tensor<?x?xf32>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
|
||||
// CHECK: return [[VAL25]]
|
||||
@ -746,7 +746,7 @@ func @round(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// CHECK-DAG: [[HALF:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>}
|
||||
// CHECK-DAG: [[CMP:%.+]] = "tf.Less"([[SUB]], [[HALF]])
|
||||
// CHECK-DAG: [[ONE:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
|
||||
// CHECK-DAG: [[ADD:%.+]] = "tf.Add"([[ONE]], [[FLOOR]])
|
||||
// CHECK-DAG: [[ADD:%.+]] = "tf.AddV2"([[FLOOR]], [[ONE]])
|
||||
// CHECK-DAG: [[SELECT:%.+]] = "tf.Select"([[CMP]], [[FLOOR]], [[ADD]])
|
||||
%0 = "tf.Round"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
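The CHECK updates in this and the next hunk reflect that the `tf.Round` lowering now emits `tf.AddV2` for the `floor + 1` branch. A Python sketch of the select-based pattern these checks describe (illustrative only, not the exact kernel; note that `tf.round` itself rounds half to even):

```python
import tensorflow as tf

def round_like_lowering(x):
  floor = tf.floor(x)
  frac = x - floor
  # Mirrors the CHECK lines: Sub, Less(frac, 0.5), AddV2(floor, 1), Select.
  return tf.where(frac < 0.5, floor, floor + 1.0)

print(round_like_lowering(tf.constant([0.4, 0.6, 1.5])).numpy())  # [0. 1. 2.]
```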
@ -761,7 +761,7 @@ func @round_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK-DAG: [[HALF:%.+]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<f32>}
|
||||
// CHECK-DAG: [[CMP:%.+]] = "tf.Less"([[SUB]], [[HALF]])
|
||||
// CHECK-DAG: [[ONE:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
|
||||
// CHECK-DAG: [[ADD:%.+]] = "tf.Add"([[ONE]], [[FLOOR]])
|
||||
// CHECK-DAG: [[ADD:%.+]] = "tf.AddV2"([[FLOOR]], [[ONE]])
|
||||
// CHECK-DAG: [[SELECT:%.+]] = "tf.Select"([[CMP]], [[FLOOR]], [[ADD]])
|
||||
%0 = "tf.Round"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
|
||||
|
@ -1,9 +1,10 @@
|
||||
// RUN: tf-mlir-translate -mlir-to-graphdef %s -tf-graph-as-function -o - | FileCheck %s
|
||||
|
||||
// Verify arg attributes are exported as device assignment for arg nodes.
|
||||
// Verify arg/ret attributes are exported as device assignment for arg/retval
|
||||
// nodes.
|
||||
|
||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 121 : i32}} {
|
||||
func @main(%arg0: tensor<*xf32> {tf.device = "/CPU:0"}, %arg1: tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<2x4x6x8xi32>)
|
||||
func @main(%arg0: tensor<*xf32> {tf.device = "/CPU:0"}, %arg1: tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<2x4x6x8xi32> {tf.device = "/CPU:1"})
|
||||
attributes {tf.entry_function = {inputs = "args_0,args_1", outputs = "rets_0,rets_1"}} {
|
||||
%0:2 = tf_executor.graph {
|
||||
%1:3 = tf_executor.island wraps "tf.IdentityN"(%arg0, %arg1) {T = ["tfdtype$DT_FLOAT", "tfdtype$DT_INT32"], device = "", name = "identity"} : (tensor<*xf32>, tensor<2x4x6x8xi32>) -> (tensor<*xf32>, tensor<2x4x6x8xi32>)
|
||||
@ -15,18 +16,39 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
|
||||
// CHECK: node {
|
||||
// CHECK-NEXT: name: "args_0"
|
||||
// CHECK-NEXT: op: "_Arg"
|
||||
// CHECK: device: "/CPU:0"
|
||||
// CHECK: attr {
|
||||
// CHECK: key: "index"
|
||||
// CHECK-NEXT: value {
|
||||
// CHECK-NEXT: i: 0
|
||||
// CHECK_NEXT: }
|
||||
//
|
||||
// CHECK: node {
|
||||
// CHECK-NEXT: name: "args_1"
|
||||
// CHECK-NOT: device: "/CPU:0"
|
||||
// CHECK-NEXT: op: "_Arg"
|
||||
// CHECK-NOT: device
|
||||
// CHECK: attr {
|
||||
// CHECK: key: "index"
|
||||
// CHECK-NEXT: value {
|
||||
// CHECK-NEXT: i: 1
|
||||
//
|
||||
// CHECK: node {
|
||||
// CHECK: op: "IdentityN"
|
||||
//
|
||||
// CHECK: node {
|
||||
// CHECK-NEXT: name: "rets_0"
|
||||
// CHECK-NEXT: op: "_Retval"
|
||||
// CHECK-NOT: device
|
||||
// CHECK: attr {
|
||||
// CHECK: key: "index"
|
||||
// CHECK-NEXT: value {
|
||||
// CHECK-NEXT: i: 0
|
||||
//
|
||||
// CHECK: node {
|
||||
// CHECK-NEXT: name: "rets_1"
|
||||
// CHECK-NEXT: op: "_Retval"
|
||||
// CHECK: device: "/CPU:1"
|
||||
// CHECK: attr {
|
||||
// CHECK: key: "index"
|
||||
// CHECK-NEXT: value {
|
||||
// CHECK-NEXT: i: 1
|
||||
// CHECK_NEXT: }
|
@ -1,12 +1,13 @@
|
||||
// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s
|
||||
|
||||
func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
|
||||
%0:2 = tf_executor.graph {
|
||||
%outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
|
||||
%outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
|
||||
tf_executor.fetch %outputs_2#1, %outputs_4#1 : tensor<5xf32>, tensor<5xf32>
|
||||
func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) {
|
||||
%0:3 = tf_executor.graph {
|
||||
%outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
|
||||
%outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
|
||||
%outputs_6:2, %control_7 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, shape_invariant} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("WhileWithOutputShapes")
|
||||
tf_executor.fetch %outputs_2#1, %outputs_4#1, %outputs_6#1 : tensor<5xf32>, tensor<5xf32>, tensor<5xf32>
|
||||
}
|
||||
return %0#0, %0#1 : tensor<5xf32>, tensor<5xf32>
|
||||
return %0#0, %0#1, %0#2 : tensor<5xf32>, tensor<5xf32>, tensor<5xf32>
|
||||
}
|
||||
|
||||
func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
|
||||
@ -36,6 +37,7 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
|
||||
// CHECK-NOT: name:
|
||||
// CHECK: op: "While"
|
||||
// CHECK-NOT: is_stateless
|
||||
// CHECK-NOT: shape_invariant
|
||||
// CHECK: attr {
|
||||
// CHECK: key: "output_shapes"
|
||||
// CHECK: value {
|
||||
@ -54,6 +56,7 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
|
||||
// CHECK-NOT: name:
|
||||
// CHECK: op: "StatelessWhile"
|
||||
// CHECK-NOT: is_stateless
|
||||
// CHECK-NOT: shape_invariant
|
||||
// CHECK: attr {
|
||||
// CHECK: key: "output_shapes"
|
||||
// CHECK: value {
|
||||
@ -67,3 +70,20 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
|
||||
// CHECK: name: "WhileWithOutputShapes"
|
||||
// CHECK-NOT: name:
|
||||
// CHECK: op: "While"
|
||||
// CHECK-NOT: is_stateless
|
||||
// CHECK-NOT: shape_invariant
|
||||
// CHECK: attr {
|
||||
// CHECK: key: "output_shapes"
|
||||
// CHECK: value {
|
||||
// CHECK: list {
|
||||
// CHECK: shape {
|
||||
// CHECK: dim {
|
||||
// CHECK: size: 5
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
|
@ -0,0 +1,153 @@
|
||||
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-outside-compiled-to-host-launch | FILECHECK_OPTS="" FileCheck %s
|
||||
|
||||
// expected-error@+1 {{'module' op bad 'tf.devices' attribute at index 0, not a string}}
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = [1]} {
|
||||
// Tests that an invalid `tf.devices` attribute results in an error.
|
||||
func @invalid_device_attribute() -> tensor<?xi32> {
|
||||
%0 = "tf_device.cluster"() ( {
|
||||
%1 = "tf.A"() : () -> tensor<?xi32>
|
||||
%2 = "tf.B"(%1) {_xla_outside_compilation = ""}: (tensor<?xi32>) -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
|
||||
// Tests that an empty `_xla_outside_compilation` attribute value results in an error.
|
||||
func @empty_outside_compilation_attribute() -> tensor<?xi32> {
|
||||
%0 = "tf_device.cluster"() ( {
|
||||
%1 = "tf.A"() : () -> tensor<?xi32>
|
||||
// expected-error@+1 {{'tf.B' op requires non empty '_xla_outside_compilation' string attribute}}
|
||||
%2 = "tf.B"(%1) {_xla_outside_compilation = ""}: (tensor<?xi32>) -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
|
||||
|
||||
// Tests that a TPU cluster with no outside compilation does not generate a launch op.
|
||||
|
||||
// CHECK-LABEL: func @no_outside_compilation
|
||||
// CHECK-NOT: "tf_device.launch"
|
||||
func @no_outside_compilation() -> tensor<?xi32> {
|
||||
%0 = "tf_device.cluster"() ( {
|
||||
%1 = "tf.A"() : () -> tensor<?xi32>
|
||||
%2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
|
||||
// Tests the launch wrap of a single outside compiled cluster with no input or output dependencies.
|
||||
|
||||
// CHECK-LABEL: func @nodep_single_outside_compilation
|
||||
func @nodep_single_outside_compilation() -> () {
|
||||
// CHECK: "tf.A"
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.B"
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
// CHECK-NEXT: tf_device.return
|
||||
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
// CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
|
||||
"tf_device.cluster"() ( {
|
||||
"tf.A"() : () -> ()
|
||||
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
|
||||
"tf.C"() : () -> ()
|
||||
tf_device.return
|
||||
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// Tests the launch wrap of a single outside compiled cluster with data parallelism.
|
||||
|
||||
// CHECK-LABEL: func @single_outside_compilation_with_replicate
|
||||
func @single_outside_compilation_with_replicate(%arg0: tensor<?xi32>) -> () {
|
||||
// CHECK: "tf.A"
|
||||
// CHECK: tf_device.replicate
|
||||
// CHECK-NEXT: "tf_device.cluster"
|
||||
// CHECK-NEXT: "tf.B"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.C"
|
||||
// CHECK-NOT: _xla_outside_compilation
|
||||
// CHECK: tf_device.return
|
||||
// CHECK-NEXT: device = "TPU_REPLICATED_HOST"
|
||||
// CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
"tf_device.cluster"() ( {
|
||||
"tf.B"() : () -> ()
|
||||
"tf.C"(%ri_0) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
|
||||
"tf.D"() : () -> ()
|
||||
tf_device.return
|
||||
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Tests launch wrap of a single outside compiled cluster with input/output.
|
||||
|
||||
// CHECK-LABEL: func @single_outside_compilation_input_output
|
||||
func @single_outside_compilation_input_output(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK-NEXT: %[[LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]])
|
||||
// CHECK: tf_device.return %[[B_OUTPUT]]
|
||||
// CHECK: "tf.C"(%[[LAUNCH_OUTPUT]])
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
%5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
tf_device.return %5 : tensor<?xi32>
|
||||
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}
|
||||
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// Tests launch wrap of multiple outside compiled clusters with input/output.
|
||||
|
||||
// CHECK-LABEL: func @multiple_outside_compilation_input_output
|
||||
func @multiple_outside_compilation_input_output(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK-NEXT: %[[LAUNCH_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]])
|
||||
// CHECK: tf_device.return %[[B_OUTPUT]]
|
||||
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[LAUNCH_OUTPUT]])
|
||||
// CHECK-NEXT: %[[LAUNCH_OUTPUT2:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[C_OUTPUT]])
|
||||
// CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[D_OUTPUT]])
|
||||
// CHECK: tf_device.return %[[E_OUTPUT]]
|
||||
// CHECK: "tf.F"(%[[LAUNCH_OUTPUT2]])
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
%5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
%6 = "tf.D"(%5) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
%7 = "tf.E"(%6) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
%8 = "tf.F"(%7) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
tf_device.return %8 : tensor<?xi32>
|
||||
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}
|
||||
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
}
|
@ -1286,3 +1286,41 @@ func @while_body(%arg0: !tf_res) -> !tf_res {
|
||||
%0 = "tf.Cast"(%arg0) : (!tf_res) -> !tf_res
|
||||
return %0 : !tf_res
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that passthrough tf.Cast ops are removed.
|
||||
|
||||
!tf_res_static = type tensor<!tf.resource<tensor<f32>>>
|
||||
!tf_res_dynamic = type tensor<*x!tf.resource<tensor<f32>>>
|
||||
|
||||
// CHECK-LABEL: func @tpu_computation
|
||||
func @tpu_computation(%arg0: !tf_res_static) {
|
||||
"tf_device.cluster"() ( {
|
||||
%0 = "tf.While"(%arg0) {body = @while_body, cond = @while_cond, is_stateless = false} : (!tf_res_static) -> !tf_res_dynamic
|
||||
%1 = "tf.WhileRegion"(%arg0) ( {
|
||||
^cond(%carg0: !tf_res_static):
|
||||
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||
"tf.Yield"(%2) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^body(%barg0: !tf_res_static):
|
||||
// CHECK-NOT: tf.Cast
|
||||
%2 = "tf.Cast"(%barg0) : (!tf_res_static) -> !tf_res_dynamic
|
||||
"tf.Yield"(%2) : (!tf_res_dynamic) -> ()
|
||||
}) {is_stateless = false} : (!tf_res_static) -> !tf_res_dynamic
|
||||
tf_device.return
|
||||
}) {} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
func @while_cond(%arg0: !tf_res_static) -> tensor<i1> {
|
||||
%0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @while_body
|
||||
func @while_body(%arg0: !tf_res_static) -> !tf_res_dynamic {
|
||||
// CHECK-NOT: tf.Cast
|
||||
%0 = "tf.Cast"(%arg0) : (!tf_res_static) -> !tf_res_dynamic
|
||||
return %0 : !tf_res_dynamic
|
||||
}
|
||||
|
@ -68,6 +68,35 @@ func @main() -> tensor<i32> {
|
||||
return %size_out : tensor<i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test inferring shape from the result type of gather.
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
func @main() -> tensor<2x3xf32> {
|
||||
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
|
||||
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<*>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
|
||||
%indices = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%gather = "tf.TensorArrayGatherV3"(%ta#0, %indices, %ta#1) : (tensor<!tf.resource>, tensor<2xi32>, tensor<f32>) -> tensor<2x3xf32>
|
||||
return %gather : tensor<2x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test inferring shape from the element_shape attribute of gather.
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
func @main() -> tensor<*xf32> {
|
||||
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
|
||||
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = #tf.shape<*>, dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
|
||||
%indices = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%gather = "tf.TensorArrayGatherV3"(%ta#0, %indices, %ta#1) {element_shape = #tf.shape<3>} : (tensor<!tf.resource>, tensor<2xi32>, tensor<f32>) -> tensor<*xf32>
|
||||
return %gather : tensor<*xf32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Test tensor array concat and split.
|
||||
|
@ -104,6 +104,7 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
|
||||
pm.addPass(mlir::createInlinerPass());
|
||||
pm.addPass(CreateTPUClusterCleanupAttributesPass());
|
||||
pm.addPass(TFDevice::CreateResourceOpLiftingPass());
|
||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
pm.addPass(TFDevice::CreateMarkOpsForOutsideCompilationPass());
|
||||
pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
|
||||
pm.addPass(CreateTPUOutsideCompilationClusterPass());
|
||||
|
@ -158,6 +158,26 @@ def LogicalNotOfLess : Pat<(TF_LogicalNotOp (TF_LessOp $arg0, $arg1)),
|
||||
def LogicalNotOfLessEqual : Pat<(TF_LogicalNotOp (TF_LessEqualOp $arg0, $arg1)),
|
||||
(TF_GreaterOp $arg0, $arg1)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MatrixSetDiag op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class GetI32Attr<int x>: NativeCodeCall<
|
||||
"$_builder.getI32IntegerAttr(" # x # ")">;
|
||||
|
||||
class GetStrAttr<string x>: NativeCodeCall<
|
||||
"$_builder.getStringAttr(\"" # x # "\")">;
|
||||
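// As an illustration (not part of the generated code itself), GetI32Attr<0>
// expands to `$_builder.getI32IntegerAttr(0)` and GetStrAttr<"RIGHT_LEFT">
// expands to `$_builder.getStringAttr("RIGHT_LEFT")` when the pattern below is
// emitted.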
|
||||
def MatrixSetDiagToV3 : Pat<(TF_MatrixSetDiagOp $input, $diag),
|
||||
(TF_MatrixSetDiagV3Op $input, $diag,
|
||||
(TF_ConstOp (GetI32Attr<0>)),
|
||||
(GetStrAttr<"RIGHT_LEFT">))>;
|
||||
|
||||
// MatrixSetDiagToV2 op implicitly used LEFT_LEFT alignment.
|
||||
def MatrixSetDiagV2ToV3 : Pat<(TF_MatrixSetDiagV2Op $input, $diag, $k),
|
||||
(TF_MatrixSetDiagV3Op $input, $diag, $k,
|
||||
(GetStrAttr<"LEFT_LEFT">))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RealDiv op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -43,6 +43,8 @@ using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName;
|
||||
|
||||
constexpr const char *kDeviceAttr = "device";
|
||||
constexpr const char *kTFDeviceAttr = "tf.device";
|
||||
// TODO(donglin): Handle the case where the address of localhost is different
|
||||
// from /job:localhost/replica:0/task:0.
|
||||
constexpr const char *kLocalhost = "/job:localhost/replica:0/task:0";
|
||||
constexpr const char *kErrorMessage =
|
||||
"The operation that uses the operand is on a different host than the "
|
||||
@ -53,8 +55,9 @@ constexpr const char *kErrorMessage =
|
||||
std::string GetHost(llvm::StringRef device) {
|
||||
ParsedName parsed_name;
|
||||
DeviceNameUtils::ParseFullName(device.str(), &parsed_name);
|
||||
return DeviceNameUtils::ParsedNameToString(
|
||||
std::string result = DeviceNameUtils::ParsedNameToString(
|
||||
DeviceNameUtils::AddressSpace(parsed_name));
|
||||
return result.empty() ? kLocalhost : result;
|
||||
}
|
||||
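// Illustrative behavior, assuming the usual DeviceNameUtils semantics (these
// examples are not part of the change itself):
//   GetHost("/job:worker/replica:0/task:0/device:CPU:0")
//       -> "/job:worker/replica:0/task:0"
//   GetHost("/device:CPU:0"), GetHost("")
//       -> kLocalhost, since the parsed address space is empty.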
|
||||
std::string GetHost(Operation *op) {
|
||||
@ -70,12 +73,9 @@ std::string GetHost(Operation *op) {
|
||||
// 1) None of the job/replica/task is specified in the device name.
|
||||
// 2) The job/replica/task in the device name are explicitly specified as
|
||||
// /job:localhost/replica:0/task:0.
|
||||
//
|
||||
// TODO(dnglin): Handle the case where the address of localhost is different
|
||||
// from /job:localhost/replica:0/task:0.
|
||||
bool IsOnLocalHost(llvm::StringRef device) {
|
||||
std::string host = GetHost(device);
|
||||
return host.empty() || host == kLocalhost;
|
||||
return host == kLocalhost;
|
||||
}
|
||||
|
||||
// This structure contains the metadata of the per-host function. All operations
|
||||
|
@ -36,21 +36,15 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
namespace collection_ops_util {
|
||||
|
||||
Value CreateScalarConst(int value, OpBuilder builder, Location loc) {
|
||||
tensorflow::Tensor scalar_tensor(tensorflow::DT_INT32, {});
|
||||
scalar_tensor.scalar<tensorflow::int32>()() = value;
|
||||
return builder.create<TF::ConstOp>(
|
||||
loc, tensorflow::ConvertTensor(scalar_tensor, &builder).ValueOrDie());
|
||||
Value CreateScalarConst(int32_t value, OpBuilder builder, Location loc) {
|
||||
auto attr = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({}, builder.getI32Type()), value);
|
||||
return builder.create<TF::ConstOp>(loc, attr);
|
||||
}
|
||||
|
||||
Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc,
|
||||
|
@ -23,7 +23,6 @@ limitations under the License.
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
@ -34,7 +33,7 @@ namespace collection_ops_util {
|
||||
// shape [max_element_count, element_shape].
|
||||
|
||||
// Creates an i32 scalar tf.Const.
|
||||
Value CreateScalarConst(int value, OpBuilder builder, Location loc);
|
||||
Value CreateScalarConst(int32_t value, OpBuilder builder, Location loc);
|
||||
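// Illustrative usage (hypothetical call site): `CreateScalarConst(0, builder,
// loc)` materializes a `tf.Const` holding `dense<0> : tensor<i32>`.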
|
||||
// Creates an integer vector tf.Const.
|
||||
Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc,
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/Traits.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
@ -38,6 +39,12 @@ class ConvertResultsBroadcastableShapeOp : public RewritePattern {
|
||||
|
||||
LogicalResult matchAndRewrite(Operation* op,
|
||||
PatternRewriter& rewriter) const override;
|
||||
|
||||
private:
|
||||
template <typename Op>
|
||||
LogicalResult RewriteEqOp(Operation* op, PatternRewriter& rewriter) const;
|
||||
|
||||
LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter) const;
|
||||
};
|
||||
|
||||
class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
|
||||
@ -47,7 +54,27 @@ class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
|
||||
|
||||
LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
if (!op->hasTrait<OpTrait::ResultsBroadcastableShape>()) return failure();
|
||||
if (op->hasTrait<OpTrait::ResultsBroadcastableShape>())
|
||||
return RewriteOp(op, rewriter);
|
||||
|
||||
// tf.Equal and tf.NotEqual ops only satisfy ResultsBroadcastableShape when
|
||||
// incompatible_shape_error is `true` (which is also checked by the verifier).
|
||||
if (succeeded(RewriteEqOp<TF::EqualOp>(op, rewriter))) return success();
|
||||
if (succeeded(RewriteEqOp<TF::NotEqualOp>(op, rewriter))) return success();
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto eq_op = llvm::dyn_cast_or_null<Op>(op);
|
||||
if (eq_op && eq_op.incompatible_shape_error()) return RewriteOp(op, rewriter);
|
||||
return failure();
|
||||
}
|
||||
|
||||
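// Sketch of the rewrite performed below (illustrative MLIR, assuming a
// statically shaped result):
//   %b = "tf.BroadcastTo"(%x, %shape)
//          : (tensor<1x2xf32>, tensor<2xi64>) -> tensor<4x2xf32>
//   %r = "tf.AddV2"(%b, %y) : (tensor<4x2xf32>, tensor<4x2xf32>) -> tensor<4x2xf32>
// becomes
//   %r = "tf.AddV2"(%x, %y) : (tensor<1x2xf32>, tensor<4x2xf32>) -> tensor<4x2xf32>
// because the op's own broadcasting semantics make the explicit BroadcastTo
// redundant.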
LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
|
||||
return failure();
|
||||
|
||||
@ -56,6 +83,7 @@ LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
|
||||
op->getResultTypes().front().dyn_cast_or_null<RankedTensorType>();
|
||||
if (!result_type || !result_type.hasStaticShape()) return failure();
|
||||
|
||||
bool changed = false;
|
||||
for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) {
|
||||
// Check that the i'th operand is a broadcast.
|
||||
auto broadcast = llvm::dyn_cast_or_null<TF::BroadcastToOp>(
|
||||
@ -89,10 +117,9 @@ LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
|
||||
// Update the operand of the op to be the operand of the broadcast.
|
||||
rewriter.updateRootInPlace(
|
||||
op, [&]() { op->getOpOperand(i).set(broadcast.input()); });
|
||||
return success();
|
||||
changed = true;
|
||||
}
|
||||
|
||||
return failure();
|
||||
return success(changed);
|
||||
}
|
||||
|
||||
void BroadcastFoldPass::runOnFunction() {
|
||||
|
@ -112,8 +112,8 @@ LogicalResult ConvertIfOp(IfOp if_op) {
|
||||
LogicalResult ConvertWhileOp(WhileOp while_op) {
|
||||
auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>(
|
||||
while_op.getLoc(), while_op.getResultTypes(), while_op.input(),
|
||||
while_op.output_shapes(), while_op.parallel_iterations(),
|
||||
while_op.is_stateless());
|
||||
while_op.parallel_iterations(), while_op.is_stateless(),
|
||||
while_op.shape_invariant());
|
||||
CopyDeviceAndUnderscoredAttributes(while_op, while_region);
|
||||
|
||||
YieldOp cond_yield =
|
||||
|
@ -200,7 +200,16 @@ class FuseConv2DBiasAdd
|
||||
|
||||
// Performs a fusion of the following pattern(s), if possible:
|
||||
// MatMulOp + BiasAdd + <Activation> -> _FusedMatMulOp
|
||||
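// Illustrative MLIR for this fusion (a sketch; value names are hypothetical):
//   %0 = "tf.MatMul"(%lhs, %rhs) : (tensor<8x16xf32>, tensor<16x4xf32>) -> tensor<8x4xf32>
//   %1 = "tf.BiasAdd"(%0, %bias) : (tensor<8x4xf32>, tensor<4xf32>) -> tensor<8x4xf32>
//   %2 = "tf.Relu"(%1) : (tensor<8x4xf32>) -> tensor<8x4xf32>
// is rewritten to a single `tf._FusedMatMul` with
// `fused_ops = ["BiasAdd", "Relu"]`. The specialization below additionally
// restricts the fusion to f32 and bf16 operands.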
using FuseMatMulBiasAdd = FuseContractionWithBiasAdd<MatMulOp, _FusedMatMulOp>;
|
||||
class FuseMatMulBiasAdd
|
||||
: public FuseContractionWithBiasAdd<MatMulOp, _FusedMatMulOp> {
|
||||
using FuseContractionWithBiasAdd<MatMulOp,
|
||||
_FusedMatMulOp>::FuseContractionWithBiasAdd;
|
||||
|
||||
bool AreFuseCompatible(MatMulOp matmul, BiasAddOp bias_add,
|
||||
PatternRewriter &rewriter) const override {
|
||||
return matmul.T().isF32() || matmul.T().isBF16();
|
||||
}
|
||||
};
|
||||
|
||||
void FusedKernelMatcherPass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
|
@ -21,15 +21,19 @@ limitations under the License.
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Matchers.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
@ -44,6 +48,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/core/framework/kernel_shape_util.h"
|
||||
#include "tensorflow/core/lib/math/math_util.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
@ -479,18 +484,16 @@ template <typename ReductionOp>
|
||||
LogicalResult MatchBinaryReduceFunction(mlir::Region &function) {
|
||||
Block &body = function.front();
|
||||
if (body.getNumArguments() != 2) return failure();
|
||||
if (body.getOperations().size() != 2) return failure();
|
||||
|
||||
ReductionOp reduce_op = dyn_cast<ReductionOp>(body.front());
|
||||
if (!reduce_op) return failure();
|
||||
if (reduce_op.lhs() != body.getArgument(0) ||
|
||||
reduce_op.rhs() != body.getArgument(1))
|
||||
return failure();
|
||||
|
||||
mhlo::ReturnOp return_op = dyn_cast<mhlo::ReturnOp>(body.back());
|
||||
if (!return_op) return failure();
|
||||
if (return_op.getNumOperands() != 1 ||
|
||||
return_op.results().front() != reduce_op)
|
||||
if (return_op.getNumOperands() != 1) return failure();
|
||||
|
||||
ReductionOp reduce_op = dyn_cast_or_null<ReductionOp>(
|
||||
return_op.getOperands().front().getDefiningOp());
|
||||
if (!reduce_op) return failure();
|
||||
if (reduce_op.lhs() != body.getArgument(0) ||
|
||||
reduce_op.rhs() != body.getArgument(1))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
@ -654,6 +657,190 @@ class ConvertIotaOpToTfRange : public OpConversionPattern<mhlo::IotaOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Maps the following representations of AvgPool in MHLO into a tf.AvgPool{3D}
|
||||
// operation when they cleanly map to 2D or 3D average pool with VALID or SAME
|
||||
// padding:
|
||||
// * div(reduce_sum_window(x), constant(sizeof(window)))
|
||||
// * div(reduce_sum_window(x), reduce_sum_window(constant(1)))
|
||||
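// Illustrative MLIR for the first form (a sketch only; a 2x2 VALID average
// pool over an NHWC tensor):
//   %sum = "mhlo.reduce_window"(%x, %zero) ( {
//     ^bb0(%a: tensor<f32>, %b: tensor<f32>):
//       %s = mhlo.add %a, %b : tensor<f32>
//       "mhlo.return"(%s) : (tensor<f32>) -> ()
//   }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
//       window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}
//     : (tensor<1x4x4x1xf32>, tensor<f32>) -> tensor<1x2x2x1xf32>
//   %avg = mhlo.divide %sum, %four   // %four is a splat of 4.0, the window size
// which this pattern rewrites to
//   %avg = "tf.AvgPool"(%x) {ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1],
//                            padding = "VALID", data_format = "NHWC"}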
class ConvertAvgPoolOp : public OpConversionPattern<mhlo::DivOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
mhlo::DivOp div_op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto rw =
|
||||
dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.lhs().getDefiningOp());
|
||||
if (!rw) return failure();
|
||||
|
||||
// Check that the reduce-window is a sum-reduce-window.
|
||||
if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw.body())))
|
||||
return failure();
|
||||
|
||||
// Check that this is a floating point reduce window with a rank of 4 or 5.
|
||||
RankedTensorType rw_type = rw.getType().dyn_cast<RankedTensorType>();
|
||||
if (!rw_type || !rw_type.getElementType().isa<FloatType>() ||
|
||||
rw_type.getRank() <= 3 || rw_type.getRank() > 5)
|
||||
return failure();
|
||||
|
||||
// Check that the Div op doesn't do broadcasting on the output of the reduce
|
||||
// window.
|
||||
if (div_op.getType() != rw.getType()) return failure();
|
||||
|
||||
// tf.avg_pool needs at least 3 dimensions (batch, spatial, channel).
|
||||
const uint64_t rank = rw.window_dimensions().size();
|
||||
if (rank <= 2) return failure();
|
||||
|
||||
// If the init value isn't zero then it can't be an average pool.
|
||||
if (!isFloatZero(rw.init_value())) return failure();
|
||||
|
||||
llvm::SmallVector<int64_t, 5> window_strides;
|
||||
if (rw.window_strides().hasValue()) {
|
||||
window_strides.insert(window_strides.end(),
|
||||
rw.window_strides()->getValues<int64_t>().begin(),
|
||||
rw.window_strides()->getValues<int64_t>().end());
|
||||
} else {
|
||||
window_strides.resize(rank, 1);
|
||||
}
|
||||
|
||||
llvm::SmallVector<int64_t, 10> padding;
|
||||
if (rw.padding().hasValue()) {
|
||||
padding.insert(padding.begin(),
|
||||
rw.padding()->getValues<int64_t>().begin(),
|
||||
rw.padding()->getValues<int64_t>().end());
|
||||
} else {
|
||||
padding.resize(2 * rank, 0);
|
||||
}
|
||||
|
||||
// Check that we don't do any reduction along the batch (first) and channel
|
||||
// (last) dimensions.
|
||||
const uint64_t batch_dim = 0;
|
||||
const uint64_t channel_dim = rank - 1;
|
||||
if (rw.window_dimensions().getValue<int64_t>({batch_dim}) != 1 ||
|
||||
rw.window_dimensions().getValue<int64_t>({channel_dim}) != 1 ||
|
||||
window_strides[batch_dim] != 1 || window_strides[channel_dim] != 1 ||
|
||||
padding[2 * batch_dim] != 0 || padding[2 * batch_dim + 1] != 0 ||
|
||||
padding[2 * channel_dim] != 0 || padding[2 * channel_dim + 1] != 0)
|
||||
return failure();
|
||||
|
||||
if (rw.window_dilations().hasValue() &&
|
||||
!(rw.window_dilations()->isSplat() &&
|
||||
rw.window_dilations()->getSplatValue<APInt>() == 1))
|
||||
return failure();
|
||||
|
||||
if (rw.base_dilations().hasValue() &&
|
||||
!(rw.base_dilations()->isSplat() &&
|
||||
rw.base_dilations()->getSplatValue<APInt>() == 1))
|
||||
return failure();
|
||||
|
||||
DenseFPElementsAttr divisor;
|
||||
if (matchPattern(div_op.rhs(), m_Constant(&divisor))) {
|
||||
// If the divisor is a constant, check that it matches the number
|
||||
// of elements inside the window, which is required for a VALID AvgPool.
|
||||
if (!divisor.isSplat()) return failure();
|
||||
int64_t window_size = 1;
|
||||
for (int64_t w : rw.window_dimensions().getValues<int64_t>()) {
|
||||
window_size *= w;
|
||||
}
|
||||
if (!divisor.getSplatValue<APFloat>().isExactlyValue(window_size))
|
||||
return failure();
|
||||
|
||||
// Check that we have no padding.
|
||||
if (!llvm::all_of(padding, [](int64_t i) { return i == 0; }))
|
||||
return failure();
|
||||
|
||||
return replaceWithAvgPool(
|
||||
div_op, rw.operand(),
|
||||
llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
|
||||
window_strides, "VALID", rewriter);
|
||||
}
|
||||
|
||||
auto rw_rhs =
|
||||
dyn_cast_or_null<mhlo::ReduceWindowOp>(div_op.rhs().getDefiningOp());
|
||||
if (rw_rhs) {
|
||||
// Check that RHS is a sum-reduce-window.
|
||||
if (failed(MatchBinaryReduceFunction<mhlo::AddOp>(rw_rhs.body())))
|
||||
return failure();
|
||||
|
||||
// Check that the RHS is a reduce_window over a constant 1 input with 0 as
|
||||
// the init value.
|
||||
DenseFPElementsAttr rhs_input;
|
||||
if (!isFloatZero(rw_rhs.init_value()) ||
|
||||
!matchPattern(rw_rhs.operand(), m_Constant(&rhs_input)) ||
|
||||
!rhs_input.isSplat() ||
|
||||
!rhs_input.getSplatValue<APFloat>().isExactlyValue(1.0))
|
||||
return failure();
|
||||
|
||||
// Check that the two reduce windows have the same window configuration.
|
||||
if (rw.window_dimensions() != rw_rhs.window_dimensions() ||
|
||||
rw.window_strides() != rw_rhs.window_strides() ||
|
||||
rw.window_dilations() != rw_rhs.window_dilations() ||
|
||||
rw.base_dilations() != rw_rhs.base_dilations() ||
|
||||
rw.padding() != rw_rhs.padding())
|
||||
return failure();
|
||||
|
||||
if (llvm::all_of(padding, [](int64_t i) { return i == 0; }))
|
||||
return replaceWithAvgPool(
|
||||
div_op, rw.operand(),
|
||||
llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
|
||||
window_strides, "VALID", rewriter);
|
||||
|
||||
RankedTensorType input_type =
|
||||
rw.operand().getType().dyn_cast<RankedTensorType>();
|
||||
RankedTensorType output_type = rw.getType().dyn_cast<RankedTensorType>();
|
||||
if (!input_type || !output_type) return failure();
|
||||
|
||||
// Check that the individual padding values correspond to SAME
|
||||
// padding from TensorFlow.
|
||||
for (uint64_t i = 1; i < rank - 1; ++i) {
|
||||
int64_t padding_size =
|
||||
(output_type.getShape()[i] - 1) * window_strides[i] +
|
||||
rw.window_dimensions().getValue<int64_t>({i}) -
|
||||
input_type.getShape()[i];
|
||||
if (padding[2 * i] !=
|
||||
tensorflow::MathUtil::FloorOfRatio(padding_size, int64_t(2)) ||
|
||||
padding[2 * i + 1] !=
|
||||
tensorflow::MathUtil::CeilOfRatio(padding_size, int64_t(2)))
|
||||
return failure();
|
||||
}
|
||||
return replaceWithAvgPool(
|
||||
div_op, rw.operand(),
|
||||
llvm::to_vector<4>(rw.window_dimensions().getValues<int64_t>()),
|
||||
window_strides, "SAME", rewriter);
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
private:
|
||||
bool isFloatZero(Value value) const {
|
||||
DenseFPElementsAttr initial_value;
|
||||
return matchPattern(value, m_Constant(&initial_value)) &&
|
||||
initial_value.getNumElements() == 1 &&
|
||||
initial_value.getValue<APFloat>({}).isZero();
|
||||
}
|
||||
|
||||
LogicalResult replaceWithAvgPool(mhlo::DivOp op, Value input,
|
||||
llvm::ArrayRef<int64_t> ksizes,
|
||||
llvm::ArrayRef<int64_t> kstrides,
|
||||
llvm::StringRef padding,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
if (ksizes.size() == 4) {
|
||||
rewriter.replaceOpWithNewOp<AvgPoolOp>(
|
||||
op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
|
||||
rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
|
||||
rewriter.getStringAttr("NHWC"));
|
||||
return success();
|
||||
} else if (ksizes.size() == 5) {
|
||||
rewriter.replaceOpWithNewOp<AvgPool3DOp>(
|
||||
op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes),
|
||||
rewriter.getI64ArrayAttr(kstrides), rewriter.getStringAttr(padding),
|
||||
rewriter.getStringAttr("NDHWC"));
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<TF::TensorFlowDialect>();
|
||||
@ -729,6 +916,25 @@ Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) {
|
||||
return reshape.getResult();
|
||||
}
|
||||
|
||||
// Converts mhlo.pad to tf.PadV2
|
||||
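// The low/high edge paddings are interleaved into an [N, 2] paddings tensor,
// e.g. (illustrative) edge_padding_low = [0, 1] and edge_padding_high = [2, 3]
// become dense<[[0, 2], [1, 3]]> : tensor<2x2xi64>, the layout tf.PadV2
// expects. Interior padding must be zero, which the accompanying .td pattern
// enforces.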
Value ConvertPadOp(PatternRewriter &rewriter, Operation *old_op) {
|
||||
auto pad_op = cast<mhlo::PadOp>(old_op);
|
||||
mlir::Location loc = pad_op.getLoc();
|
||||
|
||||
llvm::SmallVector<APInt, 8> padding;
|
||||
for (auto p : llvm::zip(pad_op.edge_padding_low().getValues<APInt>(),
|
||||
pad_op.edge_padding_high().getValues<APInt>())) {
|
||||
padding.push_back(std::get<0>(p));
|
||||
padding.push_back(std::get<1>(p));
|
||||
}
|
||||
auto attr_type = RankedTensorType::get({pad_op.edge_padding_low().size(), 2},
|
||||
rewriter.getI64Type());
|
||||
auto padding_attr = DenseIntElementsAttr::get(attr_type, padding);
|
||||
auto padding_op = rewriter.create<ConstantOp>(loc, attr_type, padding_attr);
|
||||
return rewriter.create<PadV2Op>(loc, pad_op.getType(), pad_op.operand(),
|
||||
padding_op, pad_op.padding_value());
|
||||
}
|
||||
|
||||
// Returns true if broadcast_dimensions obey the TensorFlow convention, i.e. new
|
||||
// dimensions are added as a prefix.
|
||||
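// Illustrative example (assuming standard numpy-style broadcasting): for an
// operand of shape [3] broadcast into a rank-3 result, TF-style requires
// broadcast_dimensions = [2]; the operand maps to the trailing output
// dimension and the new dimensions [0, 1] are added as a prefix.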
bool IsTFStyleBroadcast(DenseIntElementsAttr broadcast_dimensions,
|
||||
@ -794,10 +1000,10 @@ static PassRegistration<LegalizeHloToTf> pass(
|
||||
|
||||
void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList *patterns,
|
||||
MLIRContext *context) {
|
||||
patterns->insert<ConvertAvgPoolOp, ConvertConvOp, ConvertSliceOp,
|
||||
ConvertReduceOpToTfMax, ConvertReduceOpToTfMin,
|
||||
ConvertReduceOpToTfSum, ConvertIotaOpToTfRange>(context);
|
||||
populateWithGenerated(context, *patterns);
|
||||
patterns->insert<ConvertConvOp, ConvertSliceOp, ConvertReduceOpToTfMax,
|
||||
ConvertReduceOpToTfMin, ConvertReduceOpToTfSum,
|
||||
ConvertIotaOpToTfRange>(context);
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass() {
|
||||
|
@ -216,3 +216,12 @@ def : Pat<(HLO_DotGeneralOp:$old_value AnyStaticShapeTensor:$lhs,
|
||||
AnyStaticShapeTensor:$rhs, $dot_dimension_numbers,
|
||||
$precision_config),
|
||||
(ConvertDotGeneralOp $old_value)>;
|
||||
|
||||
def IsZero : Constraint<CPred<
|
||||
"$0.isSplat() && $0.getSplatValue<APInt>() == 0">>;
|
||||
def ConvertPadOp : NativeCodeCall<
|
||||
"ConvertPadOp($_builder, $0.getDefiningOp())">;
|
||||
def : Pat<(HLO_PadOp:$old_value $input, $pad_value, $pad_low, $pad_high,
|
||||
$pad_interior),
|
||||
(ConvertPadOp $old_value),
|
||||
[(IsZero $pad_interior)]>;
|
||||
|
@ -332,12 +332,13 @@ class LowerDynamicStitchOp : public RewritePattern {
|
||||
class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
|
||||
public:
|
||||
explicit ConvertFakeQuantWithMinMaxVarsOp(MLIRContext *context)
|
||||
: RewritePattern(FakeQuantWithMinMaxVarsOp::getOperationName(),
|
||||
{SubOp::getOperationName(), ConstOp::getOperationName(),
|
||||
MulOp::getOperationName(), FloorOp::getOperationName(),
|
||||
ClipByValueOp::getOperationName(),
|
||||
DivOp::getOperationName(), RoundOp::getOperationName()},
|
||||
1, context) {}
|
||||
: RewritePattern(
|
||||
FakeQuantWithMinMaxVarsOp::getOperationName(),
|
||||
{AddV2Op::getOperationName(), SubOp::getOperationName(),
|
||||
ConstOp::getOperationName(), MulOp::getOperationName(),
|
||||
FloorOp::getOperationName(), ClipByValueOp::getOperationName(),
|
||||
DivOp::getOperationName(), RoundOp::getOperationName()},
|
||||
1, context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *src_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
@ -419,8 +420,8 @@ class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
|
||||
op.getLoc(),
|
||||
DenseElementsAttr::get(scalar_ty, ConvertToAPFloat(0.5, element_ty)));
|
||||
|
||||
quantized_input = rewriter.create<AddOp>(op.getLoc(), input_ty,
|
||||
quantized_input, half_val);
|
||||
quantized_input = rewriter.create<AddV2Op>(op.getLoc(), input_ty,
|
||||
quantized_input, half_val);
|
||||
|
||||
quantized_input = rewriter.create<FloorOp>(op.getLoc(), quantized_input);
|
||||
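// At this point, quantized_input is floor(v + 0.5), i.e. round-to-nearest of
// (clamped_x - nudged_min) / scale (a sketch of the fake-quant math, assuming
// the preceding steps clamped the input and divided by the quantization
// scale). The following statements rescale it back to float as
// quantized * scale + nudged_min.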
|
||||
@ -428,8 +429,8 @@ class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
|
||||
Value output = rewriter.create<MulOp>(op.getLoc(), input_ty,
|
||||
quantized_input, quant_to_float);
|
||||
|
||||
output =
|
||||
rewriter.create<AddOp>(op.getLoc(), input_ty, output, nudged_float_min);
|
||||
output = rewriter.create<AddV2Op>(op.getLoc(), input_ty, output,
|
||||
nudged_float_min);
|
||||
|
||||
rewriter.replaceOp(op, {output});
|
||||
return success();
|
||||
@ -811,7 +812,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
|
||||
CastOp::getOperationName(),
|
||||
ConstOp::getOperationName(),
|
||||
ConcatV2Op::getOperationName(),
|
||||
AddOp::getOperationName(),
|
||||
AddV2Op::getOperationName(),
|
||||
PadOp::getOperationName(),
|
||||
SplitOp::getOperationName(),
|
||||
UnpackOp::getOperationName(),
|
||||
@ -907,8 +908,8 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
|
||||
auto paddings_split = rewriter.create<UnpackOp>(
|
||||
loc, TypeRange({paddings_sum_type, paddings_sum_type}), full_paddings,
|
||||
rewriter.getI64IntegerAttr(1));
|
||||
auto paddings_sum = rewriter.create<AddOp>(loc, paddings_split.getResult(0),
|
||||
paddings_split.getResult(1));
|
||||
auto paddings_sum = rewriter.create<AddV2Op>(
|
||||
loc, paddings_split.getResult(0), paddings_split.getResult(1));
|
||||
|
||||
auto input_shape_tensor = rewriter.create<ConstOp>(
|
||||
loc,
|
||||
@ -918,7 +919,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern {
|
||||
|
||||
// padded_shape_tensor is the shape of padded.
|
||||
auto padded_shape_tensor =
|
||||
rewriter.create<AddOp>(loc, paddings_sum, input_shape_tensor);
|
||||
rewriter.create<AddV2Op>(loc, paddings_sum, input_shape_tensor);
|
||||
|
||||
auto zero_i32 = rewriter.create<ConstOp>(
|
||||
loc, GetScalarOfType(rewriter.getIntegerType(32), 0));
|
||||
|
@ -237,7 +237,7 @@ def : Pat<(TF_RoundOp:$res TF_FloatTensor:$input),
|
||||
(TF_SubOp $input, (TF_FloorOp:$floor $input)),
|
||||
(TF_ConstOp (GetScalarOfFloatType<"0.5"> $input))),
|
||||
$floor,
|
||||
(TF_AddOp
|
||||
(TF_AddV2Op
|
||||
(TF_ConstOp (GetScalarOfType<1> $input)), $floor))>;
|
||||
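// I.e. (a sketch of the lowering): Round(x) = floor(x) when x - floor(x) < 0.5,
// and floor(x) + 1 otherwise.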
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff.